This is an automated email from the ASF dual-hosted git repository.
weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d5f81a4e2d Switch AzureDataLakeStorageV2Hook to use
DefaultAzureCredential for managed identity/workload auth (#38497)
d5f81a4e2d is described below
commit d5f81a4e2de0d4236cffcf2e2d3c682b4c6ec355
Author: Tamara Janina Fingerlin <[email protected]>
AuthorDate: Mon May 27 02:28:39 2024 +0200
Switch AzureDataLakeStorageV2Hook to use DefaultAzureCredential for managed
identity/workload auth (#38497)
---
.../providers/microsoft/azure/hooks/data_lake.py | 7 +++---
.../microsoft/azure/hooks/test_data_factory.py | 29 +++++++++++++++++++++-
2 files changed, 32 insertions(+), 4 deletions(-)
diff --git a/airflow/providers/microsoft/azure/hooks/data_lake.py b/airflow/providers/microsoft/azure/hooks/data_lake.py
index 054eda087e..b2d9c5aafa 100644
--- a/airflow/providers/microsoft/azure/hooks/data_lake.py
+++ b/airflow/providers/microsoft/azure/hooks/data_lake.py
@@ -22,7 +22,7 @@ from typing import Any, Union
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
from azure.datalake.store import core, lib, multithread
-from azure.identity import ClientSecretCredential
+from azure.identity import ClientSecretCredential, DefaultAzureCredential
from azure.storage.filedatalake import (
DataLakeDirectoryClient,
DataLakeFileClient,
@@ -38,9 +38,10 @@ from airflow.providers.microsoft.azure.utils import (
AzureIdentityCredentialAdapter,
add_managed_identity_connection_widgets,
get_field,
+ get_sync_default_azure_credential,
)
-Credentials = Union[ClientSecretCredential, AzureIdentityCredentialAdapter]
+Credentials = Union[ClientSecretCredential, AzureIdentityCredentialAdapter, DefaultAzureCredential]
class AzureDataLakeHook(BaseHook):
@@ -358,7 +359,7 @@ class AzureDataLakeStorageV2Hook(BaseHook):
else:
managed_identity_client_id = self._get_field(extra, "managed_identity_client_id")
workload_identity_tenant_id = self._get_field(extra, "workload_identity_tenant_id")
- credential = AzureIdentityCredentialAdapter(
+ credential = get_sync_default_azure_credential(
managed_identity_client_id=managed_identity_client_id,
workload_identity_tenant_id=workload_identity_tenant_id,
)
diff --git a/tests/providers/microsoft/azure/hooks/test_data_factory.py b/tests/providers/microsoft/azure/hooks/test_data_factory.py
index 1ee77ad3af..a7d8786fd8 100644
--- a/tests/providers/microsoft/azure/hooks/test_data_factory.py
+++ b/tests/providers/microsoft/azure/hooks/test_data_factory.py
@@ -86,8 +86,8 @@ def setup_connections(create_mock_connections):
"factory_name": DEFAULT_FACTORY,
},
),
+ # connection_missing_subscription_id
Connection(
- # connection_missing_subscription_id
conn_id="azure_data_factory_missing_subscription_id",
conn_type="azure_data_factory",
login="clientId",
@@ -110,6 +110,18 @@ def setup_connections(create_mock_connections):
"factory_name": DEFAULT_FACTORY,
},
),
+ # connection_workload_identity
+ Connection(
+ conn_id="azure_data_factory_workload_identity",
+ conn_type="azure_data_factory",
+ extra={
+ "subscriptionId": "subscriptionId",
+ "resource_group_name": DEFAULT_RESOURCE_GROUP,
+ "factory_name": DEFAULT_FACTORY,
+ "workload_identity_tenant_id": "workload_tenant_id",
+ "managed_identity_client_id": "workload_client_id",
+ },
+ ),
)
@@ -198,6 +210,21 @@ def test_get_conn_by_default_azure_credential(mock_credential):
mock_create_client.assert_called_with(mock_credential(), "subscriptionId")
+@mock.patch(f"{MODULE}.get_sync_default_azure_credential")
+def test_get_conn_with_workload_identity(mock_credential):
+ hook = AzureDataFactoryHook("azure_data_factory_workload_identity")
+ with patch.object(hook, "_create_client") as mock_create_client:
+ mock_create_client.return_value = MagicMock()
+
+ connection = hook.get_conn()
+ assert connection is not None
+ mock_credential.assert_called_once_with(
+ managed_identity_client_id="workload_client_id",
+ workload_identity_tenant_id="workload_tenant_id",
+ )
+ mock_create_client.assert_called_with(mock_credential(), "subscriptionId")
+
+
def test_get_factory(hook: AzureDataFactoryHook):
hook.get_factory(RESOURCE_GROUP, FACTORY)