This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 13865ab05a add managed identity support to AsyncDefaultAzureCredential
(#35394)
13865ab05a is described below
commit 13865ab05a26bc4923f00a734889d03aa3b8d2b2
Author: Wei Lee <[email protected]>
AuthorDate: Fri Nov 3 23:47:13 2023 +0800
add managed identity support to AsyncDefaultAzureCredential (#35394)
* feat(providers/microsoft): add AsyncDefaultAzureCredential support and
make args keyword args for get_default_azure_credential
* refactor(providers/microsoft): split get_default_azure_credential into 2
functions
---
airflow/providers/microsoft/azure/hooks/adx.py | 7 ++--
airflow/providers/microsoft/azure/hooks/asb.py | 12 ++++---
.../microsoft/azure/hooks/container_instance.py | 7 ++--
.../microsoft/azure/hooks/container_registry.py | 10 +++---
.../microsoft/azure/hooks/container_volume.py | 9 ++---
airflow/providers/microsoft/azure/hooks/cosmos.py | 10 +++---
.../microsoft/azure/hooks/data_factory.py | 15 +++++++--
.../providers/microsoft/azure/hooks/data_lake.py | 14 ++++++--
.../providers/microsoft/azure/hooks/fileshare.py | 15 +++++----
airflow/providers/microsoft/azure/hooks/synapse.py | 7 ++--
airflow/providers/microsoft/azure/hooks/wasb.py | 16 ++++++---
.../providers/microsoft/azure/secrets/key_vault.py | 4 +--
airflow/providers/microsoft/azure/utils.py | 39 ++++++++++++++++++----
tests/providers/microsoft/azure/hooks/test_adx.py | 2 +-
tests/providers/microsoft/azure/hooks/test_asb.py | 12 ++++---
.../azure/hooks/test_azure_container_instance.py | 6 ++--
.../azure/hooks/test_azure_container_registry.py | 8 +++--
.../azure/hooks/test_azure_container_volume.py | 6 ++--
.../microsoft/azure/hooks/test_azure_cosmos.py | 2 +-
.../azure/hooks/test_azure_data_factory.py | 33 +++++++++++++++++-
.../microsoft/azure/hooks/test_azure_synapse.py | 2 +-
tests/providers/microsoft/azure/hooks/test_wasb.py | 2 +-
.../azure/secrets/test_azure_key_vault.py | 8 ++---
23 files changed, 180 insertions(+), 66 deletions(-)
diff --git a/airflow/providers/microsoft/azure/hooks/adx.py
b/airflow/providers/microsoft/azure/hooks/adx.py
index 4d2ba0ae6d..d573e381e7 100644
--- a/airflow/providers/microsoft/azure/hooks/adx.py
+++ b/airflow/providers/microsoft/azure/hooks/adx.py
@@ -36,7 +36,7 @@ from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarni
from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import (
add_managed_identity_connection_widgets,
- get_default_azure_credential,
+ get_sync_default_azure_credential,
)
if TYPE_CHECKING:
@@ -198,7 +198,10 @@ class AzureDataExplorerHook(BaseHook):
elif auth_method == "AZURE_TOKEN_CRED":
managed_identity_client_id =
conn.extra_dejson.get("managed_identity_client_id")
workload_identity_tenant_id =
conn.extra_dejson.get("workload_identity_tenant_id")
- credential =
get_default_azure_credential(managed_identity_client_id,
workload_identity_tenant_id)
+ credential = get_sync_default_azure_credential(
+ managed_identity_client_id=managed_identity_client_id,
+ workload_identity_tenant_id=workload_identity_tenant_id,
+ )
kcsb = KustoConnectionStringBuilder.with_azure_token_credential(
connection_string=cluster,
credential=credential,
diff --git a/airflow/providers/microsoft/azure/hooks/asb.py
b/airflow/providers/microsoft/azure/hooks/asb.py
index d1580f2aa4..1c8a55544c 100644
--- a/airflow/providers/microsoft/azure/hooks/asb.py
+++ b/airflow/providers/microsoft/azure/hooks/asb.py
@@ -24,8 +24,8 @@ from azure.servicebus.management import QueueProperties,
ServiceBusAdministratio
from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import (
add_managed_identity_connection_widgets,
- get_default_azure_credential,
get_field,
+ get_sync_default_azure_credential,
)
if TYPE_CHECKING:
@@ -119,8 +119,9 @@ class AdminClientHook(BaseAzureServiceBusHook):
workload_identity_tenant_id = self._get_field(
extras=extras, field_name="workload_identity_tenant_id"
)
- credential = get_default_azure_credential(
- managed_identity_client_id, workload_identity_tenant_id
+ credential = get_sync_default_azure_credential(
+ managed_identity_client_id=managed_identity_client_id,
+ workload_identity_tenant_id=workload_identity_tenant_id,
)
client = ServiceBusAdministrationClient(
fully_qualified_namespace=fully_qualified_namespace,
@@ -211,8 +212,9 @@ class MessageHook(BaseAzureServiceBusHook):
workload_identity_tenant_id = self._get_field(
extras=extras, field_name="workload_identity_tenant_id"
)
- credential = get_default_azure_credential(
- managed_identity_client_id, workload_identity_tenant_id
+ credential = get_sync_default_azure_credential(
+ managed_identity_client_id=managed_identity_client_id,
+ workload_identity_tenant_id=workload_identity_tenant_id,
)
client = ServiceBusClient(
fully_qualified_namespace=fully_qualified_namespace,
diff --git a/airflow/providers/microsoft/azure/hooks/container_instance.py
b/airflow/providers/microsoft/azure/hooks/container_instance.py
index ccbfdcfe69..04878b98bd 100644
--- a/airflow/providers/microsoft/azure/hooks/container_instance.py
+++ b/airflow/providers/microsoft/azure/hooks/container_instance.py
@@ -27,7 +27,7 @@ from azure.mgmt.containerinstance import
ContainerInstanceManagementClient
from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
from airflow.providers.microsoft.azure.hooks.base_azure import AzureBaseHook
-from airflow.providers.microsoft.azure.utils import
get_default_azure_credential
+from airflow.providers.microsoft.azure.utils import
get_sync_default_azure_credential
if TYPE_CHECKING:
from azure.mgmt.containerinstance.models import (
@@ -95,7 +95,10 @@ class AzureContainerInstanceHook(AzureBaseHook):
self.log.info("Using DefaultAzureCredential as credential")
managed_identity_client_id =
conn.extra_dejson.get("managed_identity_client_id")
workload_identity_tenant_id =
conn.extra_dejson.get("workload_identity_tenant_id")
- credential =
get_default_azure_credential(managed_identity_client_id,
workload_identity_tenant_id)
+ credential = get_sync_default_azure_credential(
+ managed_identity_client_id=managed_identity_client_id,
+ workload_identity_tenant_id=workload_identity_tenant_id,
+ )
subscription_id = cast(str, conn.extra_dejson.get("subscriptionId"))
return ContainerInstanceManagementClient(
diff --git a/airflow/providers/microsoft/azure/hooks/container_registry.py
b/airflow/providers/microsoft/azure/hooks/container_registry.py
index 04cfa719ef..04f7893d39 100644
--- a/airflow/providers/microsoft/azure/hooks/container_registry.py
+++ b/airflow/providers/microsoft/azure/hooks/container_registry.py
@@ -27,8 +27,8 @@ from azure.mgmt.containerregistry import
ContainerRegistryManagementClient
from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import (
add_managed_identity_connection_widgets,
- get_default_azure_credential,
get_field,
+ get_sync_default_azure_credential,
)
@@ -109,10 +109,12 @@ class AzureContainerRegistryHook(BaseHook):
resource_group = self._get_field(extras, "resource_group")
managed_identity_client_id = self._get_field(extras,
"managed_identity_client_id")
workload_identity_tenant_id = self._get_field(extras,
"workload_identity_tenant_id")
+ credential = get_sync_default_azure_credential(
+ managed_identity_client_id=managed_identity_client_id,
+ workload_identity_tenant_id=workload_identity_tenant_id,
+ )
client = ContainerRegistryManagementClient(
- credential=get_default_azure_credential(
- managed_identity_client_id, workload_identity_tenant_id
- ),
+ credential=credential,
subscription_id=subscription_id,
)
credentials = client.registries.list_credentials(resource_group,
conn.login).as_dict()
diff --git a/airflow/providers/microsoft/azure/hooks/container_volume.py
b/airflow/providers/microsoft/azure/hooks/container_volume.py
index fdf6d96407..b9a712738c 100644
--- a/airflow/providers/microsoft/azure/hooks/container_volume.py
+++ b/airflow/providers/microsoft/azure/hooks/container_volume.py
@@ -24,8 +24,8 @@ from azure.mgmt.storage import StorageManagementClient
from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import (
add_managed_identity_connection_widgets,
- get_default_azure_credential,
get_field,
+ get_sync_default_azure_credential,
)
@@ -111,10 +111,11 @@ class AzureContainerVolumeHook(BaseHook):
if subscription_id and storage_account_name and resource_group:
managed_identity_client_id = self._get_field(extras,
"managed_identity_client_id")
workload_identity_tenant_id = self._get_field(extras,
"workload_identity_tenant_id")
- credentials = get_default_azure_credential(
- managed_identity_client_id, workload_identity_tenant_id
+ credential = get_sync_default_azure_credential(
+ managed_identity_client_id=managed_identity_client_id,
+ workload_identity_tenant_id=workload_identity_tenant_id,
)
- storage_client = StorageManagementClient(credentials,
subscription_id)
+ storage_client = StorageManagementClient(credential,
subscription_id)
storage_account_list_keys_result =
storage_client.storage_accounts.list_keys(
resource_group, storage_account_name
)
diff --git a/airflow/providers/microsoft/azure/hooks/cosmos.py
b/airflow/providers/microsoft/azure/hooks/cosmos.py
index 1af863090c..4737554b50 100644
--- a/airflow/providers/microsoft/azure/hooks/cosmos.py
+++ b/airflow/providers/microsoft/azure/hooks/cosmos.py
@@ -37,8 +37,8 @@ from airflow.exceptions import AirflowBadRequest,
AirflowException
from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import (
add_managed_identity_connection_widgets,
- get_default_azure_credential,
get_field,
+ get_sync_default_azure_credential,
)
@@ -133,10 +133,12 @@ class AzureCosmosDBHook(BaseHook):
managed_identity_client_id = self._get_field(extras,
"managed_identity_client_id")
workload_identity_tenant_id = self._get_field(extras,
"workload_identity_tenant_id")
subscritption_id = self._get_field(extras, "subscription_id")
+ credential = get_sync_default_azure_credential(
+ managed_identity_client_id=managed_identity_client_id,
+ workload_identity_tenant_id=workload_identity_tenant_id,
+ )
management_client = CosmosDBManagementClient(
- credential=get_default_azure_credential(
- managed_identity_client_id, workload_identity_tenant_id
- ),
+ credential=credential,
subscription_id=subscritption_id,
)
diff --git a/airflow/providers/microsoft/azure/hooks/data_factory.py
b/airflow/providers/microsoft/azure/hooks/data_factory.py
index e8a80242ea..8047f15939 100644
--- a/airflow/providers/microsoft/azure/hooks/data_factory.py
+++ b/airflow/providers/microsoft/azure/hooks/data_factory.py
@@ -50,7 +50,8 @@ from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarni
from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import (
add_managed_identity_connection_widgets,
- get_default_azure_credential,
+ get_async_default_azure_credential,
+ get_sync_default_azure_credential,
)
if TYPE_CHECKING:
@@ -212,7 +213,10 @@ class AzureDataFactoryHook(BaseHook):
else:
managed_identity_client_id = get_field(extras,
"managed_identity_client_id")
workload_identity_tenant_id = get_field(extras,
"workload_identity_tenant_id")
- credential =
get_default_azure_credential(managed_identity_client_id,
workload_identity_tenant_id)
+ credential = get_sync_default_azure_credential(
+ managed_identity_client_id=managed_identity_client_id,
+ workload_identity_tenant_id=workload_identity_tenant_id,
+ )
self._conn = self._create_client(credential, subscription_id)
return self._conn
@@ -1147,7 +1151,12 @@ class AzureDataFactoryAsyncHook(AzureDataFactoryHook):
client_id=conn.login, client_secret=conn.password,
tenant_id=tenant
)
else:
- credential = AsyncDefaultAzureCredential()
+ managed_identity_client_id = get_field(extras,
"managed_identity_client_id")
+ workload_identity_tenant_id = get_field(extras,
"workload_identity_tenant_id")
+ credential = get_async_default_azure_credential(
+ managed_identity_client_id=managed_identity_client_id,
+ workload_identity_tenant_id=workload_identity_tenant_id,
+ )
self._async_conn = AsyncDataFactoryManagementClient(
credential=credential,
diff --git a/airflow/providers/microsoft/azure/hooks/data_lake.py
b/airflow/providers/microsoft/azure/hooks/data_lake.py
index 557c70fe99..b6f49e5cf8 100644
--- a/airflow/providers/microsoft/azure/hooks/data_lake.py
+++ b/airflow/providers/microsoft/azure/hooks/data_lake.py
@@ -123,7 +123,12 @@ class AzureDataLakeHook(BaseHook):
if tenant:
credential = lib.auth(tenant_id=tenant,
client_secret=conn.password, client_id=conn.login)
else:
- credential = AzureIdentityCredentialAdapter()
+ managed_identity_client_id = self._get_field(extras,
"managed_identity_client_id")
+ workload_identity_tenant_id = self._get_field(extras,
"workload_identity_tenant_id")
+ credential = AzureIdentityCredentialAdapter(
+ managed_identity_client_id=managed_identity_client_id,
+ workload_identity_tenant_id=workload_identity_tenant_id,
+ )
self._conn = core.AzureDLFileSystem(credential,
store_name=self.account_name)
self._conn.connect()
return self._conn
@@ -347,7 +352,12 @@ class AzureDataLakeStorageV2Hook(BaseHook):
elif conn.password:
credential = conn.password
else:
- credential = AzureIdentityCredentialAdapter()
+ managed_identity_client_id = self._get_field(extra,
"managed_identity_client_id")
+ workload_identity_tenant_id = self._get_field(extra,
"workload_identity_tenant_id")
+ credential = AzureIdentityCredentialAdapter(
+ managed_identity_client_id=managed_identity_client_id,
+ workload_identity_tenant_id=workload_identity_tenant_id,
+ )
return DataLakeServiceClient(
account_url=f"https://{conn.host}.dfs.core.windows.net",
diff --git a/airflow/providers/microsoft/azure/hooks/fileshare.py
b/airflow/providers/microsoft/azure/hooks/fileshare.py
index 92235e7ecf..5ecaee80da 100644
--- a/airflow/providers/microsoft/azure/hooks/fileshare.py
+++ b/airflow/providers/microsoft/azure/hooks/fileshare.py
@@ -24,7 +24,7 @@ from azure.storage.fileshare import FileProperties,
ShareDirectoryClient, ShareF
from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import (
add_managed_identity_connection_widgets,
- get_default_azure_credential,
+ get_sync_default_azure_credential,
)
@@ -106,12 +106,15 @@ class AzureFileShareHook(BaseHook):
return f"https://{account_url}.file.core.windows.net"
return account_url
- def _get_default_azure_credential(self):
+ def _get_sync_default_azure_credential(self):
conn = self.get_connection(self._conn_id)
extras = conn.extra_dejson
managed_identity_client_id = extras.get("managed_identity_client_id")
workload_identity_tenant_id = extras.get("workload_identity_tenant_id")
- return get_default_azure_credential(managed_identity_client_id,
workload_identity_tenant_id)
+ return get_sync_default_azure_credential(
+ managed_identity_client_id=managed_identity_client_id,
+ workload_identity_tenant_id=workload_identity_tenant_id,
+ )
@property
def share_service_client(self):
@@ -126,7 +129,7 @@ class AzureFileShareHook(BaseHook):
else:
return ShareServiceClient(
account_url=self._account_url,
- credential=self._get_default_azure_credential(),
+ credential=self._get_sync_default_azure_credential(),
token_intent="backup",
)
@@ -151,7 +154,7 @@ class AzureFileShareHook(BaseHook):
account_url=self._account_url,
share_name=self.share_name,
directory_path=self.directory_path,
- credential=self._get_default_azure_credential(),
+ credential=self._get_sync_default_azure_credential(),
token_intent="backup",
)
@@ -176,7 +179,7 @@ class AzureFileShareHook(BaseHook):
account_url=self._account_url,
share_name=self.share_name,
file_path=self.file_path,
- credential=self._get_default_azure_credential(),
+ credential=self._get_sync_default_azure_credential(),
token_intent="backup",
)
diff --git a/airflow/providers/microsoft/azure/hooks/synapse.py
b/airflow/providers/microsoft/azure/hooks/synapse.py
index a4f0f5d7c3..e284194376 100644
--- a/airflow/providers/microsoft/azure/hooks/synapse.py
+++ b/airflow/providers/microsoft/azure/hooks/synapse.py
@@ -26,8 +26,8 @@ from airflow.exceptions import AirflowTaskTimeout
from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import (
add_managed_identity_connection_widgets,
- get_default_azure_credential,
get_field,
+ get_sync_default_azure_credential,
)
if TYPE_CHECKING:
@@ -131,7 +131,10 @@ class AzureSynapseHook(BaseHook):
else:
managed_identity_client_id = self._get_field(extras,
"managed_identity_client_id")
workload_identity_tenant_id = self._get_field(extras,
"workload_identity_tenant_id")
- credential =
get_default_azure_credential(managed_identity_client_id,
workload_identity_tenant_id)
+ credential = get_sync_default_azure_credential(
+ managed_identity_client_id=managed_identity_client_id,
+ workload_identity_tenant_id=workload_identity_tenant_id,
+ )
self._conn = self._create_client(credential, conn.host, spark_pool,
livy_api_version, subscription_id)
diff --git a/airflow/providers/microsoft/azure/hooks/wasb.py
b/airflow/providers/microsoft/azure/hooks/wasb.py
index 8c7f1636a1..43aea8c99c 100644
--- a/airflow/providers/microsoft/azure/hooks/wasb.py
+++ b/airflow/providers/microsoft/azure/hooks/wasb.py
@@ -49,7 +49,8 @@ from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import (
add_managed_identity_connection_widgets,
- get_default_azure_credential,
+ get_async_default_azure_credential,
+ get_sync_default_azure_credential,
)
if TYPE_CHECKING:
@@ -214,8 +215,10 @@ class WasbHook(BaseHook):
if not credential:
managed_identity_client_id = self._get_field(extra,
"managed_identity_client_id")
workload_identity_tenant_id = self._get_field(extra,
"workload_identity_tenant_id")
- credential =
get_default_azure_credential(managed_identity_client_id,
workload_identity_tenant_id)
-
+ credential = get_sync_default_azure_credential(
+ managed_identity_client_id=managed_identity_client_id,
+ workload_identity_tenant_id=workload_identity_tenant_id,
+ )
self.log.info("Using DefaultAzureCredential as credential")
return BlobServiceClient(
account_url=account_url,
@@ -642,7 +645,12 @@ class WasbAsyncHook(WasbHook):
# Fall back to old auth (password) or use managed identity if not
provided.
credential = conn.password
if not credential:
- credential = AsyncDefaultAzureCredential()
+ managed_identity_client_id = self._get_field(extra,
"managed_identity_client_id")
+ workload_identity_tenant_id = self._get_field(extra,
"workload_identity_tenant_id")
+ credential = get_async_default_azure_credential(
+ managed_identity_client_id=managed_identity_client_id,
+ workload_identity_tenant_id=workload_identity_tenant_id,
+ )
self.log.info("Using DefaultAzureCredential as credential")
self.blob_service_client = AsyncBlobServiceClient(
account_url=account_url,
diff --git a/airflow/providers/microsoft/azure/secrets/key_vault.py
b/airflow/providers/microsoft/azure/secrets/key_vault.py
index f4c51f1f64..794788206c 100644
--- a/airflow/providers/microsoft/azure/secrets/key_vault.py
+++ b/airflow/providers/microsoft/azure/secrets/key_vault.py
@@ -35,7 +35,7 @@ from azure.identity import ClientSecretCredential,
DefaultAzureCredential
from azure.keyvault.secrets import SecretClient
from airflow.exceptions import AirflowProviderDeprecationWarning
-from airflow.providers.microsoft.azure.utils import
get_default_azure_credential
+from airflow.providers.microsoft.azure.utils import
get_sync_default_azure_credential
from airflow.secrets import BaseSecretsBackend
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.version import version as airflow_version
@@ -143,7 +143,7 @@ class AzureKeyVaultBackend(BaseSecretsBackend,
LoggingMixin):
if all([self.tenant_id, self.client_id, self.client_secret]):
credential = ClientSecretCredential(self.tenant_id,
self.client_id, self.client_secret)
else:
- credential = get_default_azure_credential(
+ credential = get_sync_default_azure_credential(
managed_identity_client_id=self.managed_identity_client_id,
workload_identity_tenant_id=self.workload_identity_tenant_id,
)
diff --git a/airflow/providers/microsoft/azure/utils.py
b/airflow/providers/microsoft/azure/utils.py
index c41a4e758c..1b738ed957 100644
--- a/airflow/providers/microsoft/azure/utils.py
+++ b/airflow/providers/microsoft/azure/utils.py
@@ -18,14 +18,19 @@
from __future__ import annotations
import warnings
-from functools import wraps
+from functools import partial, wraps
+from typing import TYPE_CHECKING
from azure.core.pipeline import PipelineContext, PipelineRequest
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
from azure.core.pipeline.transport import HttpRequest
from azure.identity import DefaultAzureCredential
+from azure.identity.aio import DefaultAzureCredential as
AsyncDefaultAzureCredential
from msrest.authentication import BasicTokenAuthentication
+if TYPE_CHECKING:
+ pass
+
def get_field(*, conn_id: str, conn_type: str, extras: dict, field_name: str):
"""Get field from extra, first checking short name, then for backcompat we
check for prefixed name."""
@@ -52,22 +57,39 @@ def get_field(*, conn_id: str, conn_type: str, extras:
dict, field_name: str):
return ret
-def get_default_azure_credential(
- managed_identity_client_id: str | None, workload_identity_tenant_id: str |
None
-) -> DefaultAzureCredential:
+def _get_default_azure_credential(
+ *,
+ managed_identity_client_id: str | None,
+ workload_identity_tenant_id: str | None,
+ use_async: bool = False,
+) -> DefaultAzureCredential | AsyncDefaultAzureCredential:
"""Get DefaultAzureCredential based on provided arguments.
If managed_identity_client_id and workload_identity_tenant_id are
provided, this function returns
DefaultAzureCredential with managed identity.
"""
+ credential_cls: type[AsyncDefaultAzureCredential] |
type[DefaultAzureCredential] = (
+ AsyncDefaultAzureCredential if use_async else DefaultAzureCredential
+ )
if managed_identity_client_id and workload_identity_tenant_id:
- return DefaultAzureCredential(
+ return credential_cls(
managed_identity_client_id=managed_identity_client_id,
workload_identity_tenant_id=workload_identity_tenant_id,
additionally_allowed_tenants=[workload_identity_tenant_id],
)
else:
- return DefaultAzureCredential()
+ return credential_cls()
+
+
+get_sync_default_azure_credential: partial[DefaultAzureCredential] = partial(
+ _get_default_azure_credential, # type: ignore[arg-type]
+ use_async=False,
+)
+
+get_async_default_azure_credential: partial[AsyncDefaultAzureCredential] =
partial(
+ _get_default_azure_credential, # type: ignore[arg-type]
+ use_async=True,
+)
def add_managed_identity_connection_widgets(func):
@@ -123,7 +145,10 @@ class
AzureIdentityCredentialAdapter(BasicTokenAuthentication):
"""
super().__init__(None) # type: ignore[arg-type]
if credential is None:
- credential =
get_default_azure_credential(managed_identity_client_id,
workload_identity_tenant_id)
+ credential = get_sync_default_azure_credential(
+ managed_identity_client_id=managed_identity_client_id,
+ workload_identity_tenant_id=workload_identity_tenant_id,
+ )
self._policy = BearerTokenCredentialPolicy(credential, resource_id,
**kwargs)
def _make_request(self):
diff --git a/tests/providers/microsoft/azure/hooks/test_adx.py
b/tests/providers/microsoft/azure/hooks/test_adx.py
index c8de2d933f..d919ce699a 100644
--- a/tests/providers/microsoft/azure/hooks/test_adx.py
+++ b/tests/providers/microsoft/azure/hooks/test_adx.py
@@ -235,7 +235,7 @@ class TestAzureDataExplorerHook:
],
indirect=True,
)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.adx.get_default_azure_credential")
+
@mock.patch("airflow.providers.microsoft.azure.hooks.adx.get_sync_default_azure_credential")
@mock.patch.object(KustoClient, "__init__")
def test_conn_method_azure_token_cred(self, mock_init,
mock_default_azure_credential, mocked_connection):
mock_init.return_value = None
diff --git a/tests/providers/microsoft/azure/hooks/test_asb.py
b/tests/providers/microsoft/azure/hooks/test_asb.py
index 195017bfd3..8777093a07 100644
--- a/tests/providers/microsoft/azure/hooks/test_asb.py
+++ b/tests/providers/microsoft/azure/hooks/test_asb.py
@@ -61,7 +61,7 @@ class TestAdminClientHook:
hook = AdminClientHook(azure_service_bus_conn_id=self.conn_id)
assert isinstance(hook.get_conn(), ServiceBusAdministrationClient)
- @mock.patch(f"{MODULE}.get_default_azure_credential")
+ @mock.patch(f"{MODULE}.get_sync_default_azure_credential")
@mock.patch(f"{MODULE}.AdminClientHook.get_connection")
def
test_get_conn_fallback_to_default_azure_credential_when_schema_is_not_provided(
self, mock_connection, mock_default_azure_credential
@@ -69,7 +69,9 @@ class TestAdminClientHook:
mock_connection.return_value = self.mock_conn_without_schema
hook = AdminClientHook(azure_service_bus_conn_id=self.conn_id)
assert isinstance(hook.get_conn(), ServiceBusAdministrationClient)
- mock_default_azure_credential.assert_called_with(None, None)
+ mock_default_azure_credential.assert_called_with(
+ managed_identity_client_id=None, workload_identity_tenant_id=None
+ )
@mock.patch("azure.servicebus.management.QueueProperties")
@mock.patch(f"{MODULE}.AdminClientHook.get_conn")
@@ -172,7 +174,7 @@ class TestMessageHook:
hook = MessageHook(azure_service_bus_conn_id=self.conn_id)
assert isinstance(hook.get_conn(), ServiceBusClient)
- @mock.patch(f"{MODULE}.get_default_azure_credential")
+ @mock.patch(f"{MODULE}.get_sync_default_azure_credential")
@mock.patch(f"{MODULE}.MessageHook.get_connection")
def
test_get_conn_fallback_to_default_azure_credential_when_schema_is_not_provided(
self, mock_connection, mock_default_azure_credential
@@ -180,7 +182,9 @@ class TestMessageHook:
mock_connection.return_value = self.mock_conn_without_schema
hook = MessageHook(azure_service_bus_conn_id=self.conn_id)
assert isinstance(hook.get_conn(), ServiceBusClient)
- mock_default_azure_credential.assert_called_with(None, None)
+ mock_default_azure_credential.assert_called_with(
+ managed_identity_client_id=None, workload_identity_tenant_id=None
+ )
@pytest.mark.parametrize(
"mock_message, mock_batch_flag",
diff --git
a/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py
b/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py
index bfe8a57484..00d2562f92 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py
@@ -131,7 +131,7 @@ class TestAzureContainerInstanceHook:
class TestAzureContainerInstanceHookWithoutSetupCredential:
@patch("airflow.providers.microsoft.azure.hooks.container_instance.ContainerInstanceManagementClient")
@patch("azure.common.credentials.ServicePrincipalCredentials")
-
@patch("airflow.providers.microsoft.azure.hooks.container_instance.get_default_azure_credential")
+
@patch("airflow.providers.microsoft.azure.hooks.container_instance.get_sync_default_azure_credential")
def test_get_conn_fallback_to_default_azure_credential(
self,
mock_default_azure_credential,
@@ -148,7 +148,9 @@ class TestAzureContainerInstanceHookWithoutSetupCredential:
hook =
AzureContainerInstanceHook(azure_conn_id=connection_without_login_password_tenant_id.conn_id)
conn = hook.get_conn()
- mock_default_azure_credential.assert_called_with(None, None)
+ mock_default_azure_credential.assert_called_with(
+ managed_identity_client_id=None, workload_identity_tenant_id=None
+ )
assert not mock_service_pricipal_credential.called
assert conn == mock_client_instance
mock_client_cls.assert_called_once_with(
diff --git
a/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py
b/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py
index 063a1290d2..def088ee2b 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py
@@ -63,7 +63,9 @@ class TestAzureContainerRegistryHook:
@mock.patch(
"airflow.providers.microsoft.azure.hooks.container_registry.ContainerRegistryManagementClient"
)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.container_registry.get_default_azure_credential")
+ @mock.patch(
+
"airflow.providers.microsoft.azure.hooks.container_registry.get_sync_default_azure_credential"
+ )
def test_get_conn_with_default_azure_credential(
self, mocked_default_azure_credential, mocked_client, mocked_connection
):
@@ -80,4 +82,6 @@ class TestAzureContainerRegistryHook:
assert hook.connection.password == "password"
assert hook.connection.server == "test.cr"
- mocked_default_azure_credential.assert_called_with(None, None)
+ mocked_default_azure_credential.assert_called_with(
+ managed_identity_client_id=None, workload_identity_tenant_id=None
+ )
diff --git
a/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py
b/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py
index 5eec510f4e..e0405f1f33 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py
@@ -86,7 +86,7 @@ class TestAzureContainerVolumeHook:
indirect=True,
)
@mock.patch("airflow.providers.microsoft.azure.hooks.container_volume.StorageManagementClient")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.container_volume.get_default_azure_credential")
+
@mock.patch("airflow.providers.microsoft.azure.hooks.container_volume.get_sync_default_azure_credential")
def test_get_file_volume_default_azure_credential(
self, mocked_default_azure_credential, mocked_client, mocked_connection
):
@@ -112,4 +112,6 @@ class TestAzureContainerVolumeHook:
assert volume.azure_file.storage_account_name == "storage"
assert volume.azure_file.read_only is True
- mocked_default_azure_credential.assert_called_with(None, None)
+ mocked_default_azure_credential.assert_called_with(
+ managed_identity_client_id=None, workload_identity_tenant_id=None
+ )
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
index bbb7f7f870..4a57b1743d 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
@@ -73,7 +73,7 @@ class TestAzureCosmosDbHook:
],
indirect=True,
)
- @mock.patch(f"{MODULE}.get_default_azure_credential")
+ @mock.patch(f"{MODULE}.get_sync_default_azure_credential")
@mock.patch(f"{MODULE}.CosmosDBManagementClient")
@mock.patch(f"{MODULE}.CosmosClient")
def test_get_conn(self, mock_cosmos, mock_cosmos_db,
mock_default_azure_credential, mocked_connection):
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
index 14a7121202..fe80cc757c 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
@@ -178,7 +178,7 @@ def
test_get_conn_by_credential_client_secret(mock_credential):
mock_create_client.assert_called_with(mock_credential(),
"subscriptionId")
[email protected](f"{MODULE}.get_default_azure_credential")
[email protected](f"{MODULE}.get_sync_default_azure_credential")
def test_get_conn_by_default_azure_credential(mock_credential):
hook = AzureDataFactoryHook(DEFAULT_CONNECTION_DEFAULT_CREDENTIAL)
@@ -657,6 +657,37 @@ class TestAzureDataFactoryAsyncHook:
with pytest.raises(AirflowException):
await hook.get_pipeline_run(RUN_ID, None, DATAFACTORY_NAME)
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "mocked_connection",
+ [
+ Connection(
+ conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
+ conn_type="azure_data_factory",
+ extra={
+ "subscriptionId": "subscriptionId",
+ "resource_group_name": RESOURCE_GROUP_NAME,
+ "factory_name": DATAFACTORY_NAME,
+ "managed_identity_client_id": "test_client_id",
+ "workload_identity_tenant_id": "test_tenant_id",
+ },
+ )
+ ],
+ indirect=True,
+ )
+ @mock.patch(f"{MODULE}.get_async_default_azure_credential")
+ async def test_get_async_conn_with_default_azure_credential(
+ self, mock_default_azure_credential, mocked_connection
+ ):
+ """Test get_async_conn returns a client and forwards identity extras to the async default credential."""
+ hook = AzureDataFactoryAsyncHook(mocked_connection.conn_id)
+ response = await hook.get_async_conn()
+ assert isinstance(response, DataFactoryManagementClient)
+
+ mock_default_azure_credential.assert_called_with(
+ managed_identity_client_id="test_client_id", workload_identity_tenant_id="test_tenant_id"
+ )
+
@pytest.mark.asyncio
@pytest.mark.parametrize(
"mocked_connection",
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_synapse.py b/tests/providers/microsoft/azure/hooks/test_azure_synapse.py
index 288a560cfa..d66268798d 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_synapse.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_synapse.py
@@ -106,7 +106,7 @@ def test_get_connection_by_credential_client_secret(mock_credential):
)
-@patch(f"{MODULE}.get_default_azure_credential")
+@patch(f"{MODULE}.get_sync_default_azure_credential")
def test_get_conn_by_default_azure_credential(mock_credential):
hook = AzureSynapseHook(DEFAULT_CONNECTION_DEFAULT_CREDENTIAL)
diff --git a/tests/providers/microsoft/azure/hooks/test_wasb.py b/tests/providers/microsoft/azure/hooks/test_wasb.py
index 7ac4fb1857..bc4bed018e 100644
--- a/tests/providers/microsoft/azure/hooks/test_wasb.py
+++ b/tests/providers/microsoft/azure/hooks/test_wasb.py
@@ -47,7 +47,7 @@ def mocked_blob_service_client():
@pytest.fixture
def mocked_default_azure_credential():
- with mock.patch("airflow.providers.microsoft.azure.hooks.wasb.get_default_azure_credential") as m:
+ with mock.patch("airflow.providers.microsoft.azure.hooks.wasb.get_sync_default_azure_credential") as m:
yield m
diff --git a/tests/providers/microsoft/azure/secrets/test_azure_key_vault.py b/tests/providers/microsoft/azure/secrets/test_azure_key_vault.py
index 51b20084ed..2c791f87a3 100644
--- a/tests/providers/microsoft/azure/secrets/test_azure_key_vault.py
+++ b/tests/providers/microsoft/azure/secrets/test_azure_key_vault.py
@@ -35,7 +35,7 @@ class TestAzureKeyVaultBackend:
conn = AzureKeyVaultBackend().get_connection("fake_conn")
assert conn.host == "host"
- @mock.patch(f"{KEY_VAULT_MODULE}.get_default_azure_credential")
+ @mock.patch(f"{KEY_VAULT_MODULE}.get_sync_default_azure_credential")
@mock.patch(f"{KEY_VAULT_MODULE}.SecretClient")
def test_get_conn_uri(self, mock_secret_client, mock_azure_cred):
mock_cred = mock.Mock()
@@ -153,7 +153,7 @@ class TestAzureKeyVaultBackend:
assert backend.get_config("test_mysql") is None
mock_get_secret.assert_not_called()
- @mock.patch(f"{KEY_VAULT_MODULE}.get_default_azure_credential")
+ @mock.patch(f"{KEY_VAULT_MODULE}.get_sync_default_azure_credential")
@mock.patch(f"{KEY_VAULT_MODULE}.ClientSecretCredential")
@mock.patch(f"{KEY_VAULT_MODULE}.SecretClient")
def test_client_authenticate_with_default_azure_credential(
@@ -169,7 +169,7 @@ class TestAzureKeyVaultBackend:
assert not mock_client_secret_credential.called
mock_defaul_azure_credential.assert_called_once()
- @mock.patch(f"{KEY_VAULT_MODULE}.get_default_azure_credential")
+ @mock.patch(f"{KEY_VAULT_MODULE}.get_sync_default_azure_credential")
@mock.patch(f"{KEY_VAULT_MODULE}.ClientSecretCredential")
@mock.patch(f"{KEY_VAULT_MODULE}.SecretClient")
def test_client_authenticate_with_default_azure_credential_and_customized_configuration(
@@ -187,7 +187,7 @@ class TestAzureKeyVaultBackend:
workload_identity_tenant_id="workload_identity_tenant_id",
)
- @mock.patch(f"{KEY_VAULT_MODULE}.get_default_azure_credential")
+ @mock.patch(f"{KEY_VAULT_MODULE}.get_sync_default_azure_credential")
@mock.patch(f"{KEY_VAULT_MODULE}.ClientSecretCredential")
@mock.patch(f"{KEY_VAULT_MODULE}.SecretClient")
def test_client_authenticate_with_client_secret_credential(