This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 13865ab05a add managed identity support to AsyncDefaultAzureCredential (#35394)
13865ab05a is described below

commit 13865ab05a26bc4923f00a734889d03aa3b8d2b2
Author: Wei Lee <[email protected]>
AuthorDate: Fri Nov 3 23:47:13 2023 +0800

    add managed identity support to AsyncDefaultAzureCredential (#35394)
    
    * feat(providers/microsoft): add AsyncDefaultAzureCredential support and make args keyword args for get_default_azure_credential
    
    * refactor(providers/microsoft): split get_default_azure_credential into 2 functions
---
 airflow/providers/microsoft/azure/hooks/adx.py     |  7 ++--
 airflow/providers/microsoft/azure/hooks/asb.py     | 12 ++++---
 .../microsoft/azure/hooks/container_instance.py    |  7 ++--
 .../microsoft/azure/hooks/container_registry.py    | 10 +++---
 .../microsoft/azure/hooks/container_volume.py      |  9 ++---
 airflow/providers/microsoft/azure/hooks/cosmos.py  | 10 +++---
 .../microsoft/azure/hooks/data_factory.py          | 15 +++++++--
 .../providers/microsoft/azure/hooks/data_lake.py   | 14 ++++++--
 .../providers/microsoft/azure/hooks/fileshare.py   | 15 +++++----
 airflow/providers/microsoft/azure/hooks/synapse.py |  7 ++--
 airflow/providers/microsoft/azure/hooks/wasb.py    | 16 ++++++---
 .../providers/microsoft/azure/secrets/key_vault.py |  4 +--
 airflow/providers/microsoft/azure/utils.py         | 39 ++++++++++++++++++----
 tests/providers/microsoft/azure/hooks/test_adx.py  |  2 +-
 tests/providers/microsoft/azure/hooks/test_asb.py  | 12 ++++---
 .../azure/hooks/test_azure_container_instance.py   |  6 ++--
 .../azure/hooks/test_azure_container_registry.py   |  8 +++--
 .../azure/hooks/test_azure_container_volume.py     |  6 ++--
 .../microsoft/azure/hooks/test_azure_cosmos.py     |  2 +-
 .../azure/hooks/test_azure_data_factory.py         | 33 +++++++++++++++++-
 .../microsoft/azure/hooks/test_azure_synapse.py    |  2 +-
 tests/providers/microsoft/azure/hooks/test_wasb.py |  2 +-
 .../azure/secrets/test_azure_key_vault.py          |  8 ++---
 23 files changed, 180 insertions(+), 66 deletions(-)

diff --git a/airflow/providers/microsoft/azure/hooks/adx.py 
b/airflow/providers/microsoft/azure/hooks/adx.py
index 4d2ba0ae6d..d573e381e7 100644
--- a/airflow/providers/microsoft/azure/hooks/adx.py
+++ b/airflow/providers/microsoft/azure/hooks/adx.py
@@ -36,7 +36,7 @@ from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarni
 from airflow.hooks.base import BaseHook
 from airflow.providers.microsoft.azure.utils import (
     add_managed_identity_connection_widgets,
-    get_default_azure_credential,
+    get_sync_default_azure_credential,
 )
 
 if TYPE_CHECKING:
@@ -198,7 +198,10 @@ class AzureDataExplorerHook(BaseHook):
         elif auth_method == "AZURE_TOKEN_CRED":
             managed_identity_client_id = 
conn.extra_dejson.get("managed_identity_client_id")
             workload_identity_tenant_id = 
conn.extra_dejson.get("workload_identity_tenant_id")
-            credential = 
get_default_azure_credential(managed_identity_client_id, 
workload_identity_tenant_id)
+            credential = get_sync_default_azure_credential(
+                managed_identity_client_id=managed_identity_client_id,
+                workload_identity_tenant_id=workload_identity_tenant_id,
+            )
             kcsb = KustoConnectionStringBuilder.with_azure_token_credential(
                 connection_string=cluster,
                 credential=credential,
diff --git a/airflow/providers/microsoft/azure/hooks/asb.py 
b/airflow/providers/microsoft/azure/hooks/asb.py
index d1580f2aa4..1c8a55544c 100644
--- a/airflow/providers/microsoft/azure/hooks/asb.py
+++ b/airflow/providers/microsoft/azure/hooks/asb.py
@@ -24,8 +24,8 @@ from azure.servicebus.management import QueueProperties, 
ServiceBusAdministratio
 from airflow.hooks.base import BaseHook
 from airflow.providers.microsoft.azure.utils import (
     add_managed_identity_connection_widgets,
-    get_default_azure_credential,
     get_field,
+    get_sync_default_azure_credential,
 )
 
 if TYPE_CHECKING:
@@ -119,8 +119,9 @@ class AdminClientHook(BaseAzureServiceBusHook):
                 workload_identity_tenant_id = self._get_field(
                     extras=extras, field_name="workload_identity_tenant_id"
                 )
-                credential = get_default_azure_credential(
-                    managed_identity_client_id, workload_identity_tenant_id
+                credential = get_sync_default_azure_credential(
+                    managed_identity_client_id=managed_identity_client_id,
+                    workload_identity_tenant_id=workload_identity_tenant_id,
                 )
             client = ServiceBusAdministrationClient(
                 fully_qualified_namespace=fully_qualified_namespace,
@@ -211,8 +212,9 @@ class MessageHook(BaseAzureServiceBusHook):
                 workload_identity_tenant_id = self._get_field(
                     extras=extras, field_name="workload_identity_tenant_id"
                 )
-                credential = get_default_azure_credential(
-                    managed_identity_client_id, workload_identity_tenant_id
+                credential = get_sync_default_azure_credential(
+                    managed_identity_client_id=managed_identity_client_id,
+                    workload_identity_tenant_id=workload_identity_tenant_id,
                 )
             client = ServiceBusClient(
                 fully_qualified_namespace=fully_qualified_namespace,
diff --git a/airflow/providers/microsoft/azure/hooks/container_instance.py 
b/airflow/providers/microsoft/azure/hooks/container_instance.py
index ccbfdcfe69..04878b98bd 100644
--- a/airflow/providers/microsoft/azure/hooks/container_instance.py
+++ b/airflow/providers/microsoft/azure/hooks/container_instance.py
@@ -27,7 +27,7 @@ from azure.mgmt.containerinstance import 
ContainerInstanceManagementClient
 
 from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarning
 from airflow.providers.microsoft.azure.hooks.base_azure import AzureBaseHook
-from airflow.providers.microsoft.azure.utils import 
get_default_azure_credential
+from airflow.providers.microsoft.azure.utils import 
get_sync_default_azure_credential
 
 if TYPE_CHECKING:
     from azure.mgmt.containerinstance.models import (
@@ -95,7 +95,10 @@ class AzureContainerInstanceHook(AzureBaseHook):
             self.log.info("Using DefaultAzureCredential as credential")
             managed_identity_client_id = 
conn.extra_dejson.get("managed_identity_client_id")
             workload_identity_tenant_id = 
conn.extra_dejson.get("workload_identity_tenant_id")
-            credential = 
get_default_azure_credential(managed_identity_client_id, 
workload_identity_tenant_id)
+            credential = get_sync_default_azure_credential(
+                managed_identity_client_id=managed_identity_client_id,
+                workload_identity_tenant_id=workload_identity_tenant_id,
+            )
 
         subscription_id = cast(str, conn.extra_dejson.get("subscriptionId"))
         return ContainerInstanceManagementClient(
diff --git a/airflow/providers/microsoft/azure/hooks/container_registry.py 
b/airflow/providers/microsoft/azure/hooks/container_registry.py
index 04cfa719ef..04f7893d39 100644
--- a/airflow/providers/microsoft/azure/hooks/container_registry.py
+++ b/airflow/providers/microsoft/azure/hooks/container_registry.py
@@ -27,8 +27,8 @@ from azure.mgmt.containerregistry import 
ContainerRegistryManagementClient
 from airflow.hooks.base import BaseHook
 from airflow.providers.microsoft.azure.utils import (
     add_managed_identity_connection_widgets,
-    get_default_azure_credential,
     get_field,
+    get_sync_default_azure_credential,
 )
 
 
@@ -109,10 +109,12 @@ class AzureContainerRegistryHook(BaseHook):
             resource_group = self._get_field(extras, "resource_group")
             managed_identity_client_id = self._get_field(extras, 
"managed_identity_client_id")
             workload_identity_tenant_id = self._get_field(extras, 
"workload_identity_tenant_id")
+            credential = get_sync_default_azure_credential(
+                managed_identity_client_id=managed_identity_client_id,
+                workload_identity_tenant_id=workload_identity_tenant_id,
+            )
             client = ContainerRegistryManagementClient(
-                credential=get_default_azure_credential(
-                    managed_identity_client_id, workload_identity_tenant_id
-                ),
+                credential=credential,
                 subscription_id=subscription_id,
             )
             credentials = client.registries.list_credentials(resource_group, 
conn.login).as_dict()
diff --git a/airflow/providers/microsoft/azure/hooks/container_volume.py 
b/airflow/providers/microsoft/azure/hooks/container_volume.py
index fdf6d96407..b9a712738c 100644
--- a/airflow/providers/microsoft/azure/hooks/container_volume.py
+++ b/airflow/providers/microsoft/azure/hooks/container_volume.py
@@ -24,8 +24,8 @@ from azure.mgmt.storage import StorageManagementClient
 from airflow.hooks.base import BaseHook
 from airflow.providers.microsoft.azure.utils import (
     add_managed_identity_connection_widgets,
-    get_default_azure_credential,
     get_field,
+    get_sync_default_azure_credential,
 )
 
 
@@ -111,10 +111,11 @@ class AzureContainerVolumeHook(BaseHook):
         if subscription_id and storage_account_name and resource_group:
             managed_identity_client_id = self._get_field(extras, 
"managed_identity_client_id")
             workload_identity_tenant_id = self._get_field(extras, 
"workload_identity_tenant_id")
-            credentials = get_default_azure_credential(
-                managed_identity_client_id, workload_identity_tenant_id
+            credential = get_sync_default_azure_credential(
+                managed_identity_client_id=managed_identity_client_id,
+                workload_identity_tenant_id=workload_identity_tenant_id,
             )
-            storage_client = StorageManagementClient(credentials, 
subscription_id)
+            storage_client = StorageManagementClient(credential, 
subscription_id)
             storage_account_list_keys_result = 
storage_client.storage_accounts.list_keys(
                 resource_group, storage_account_name
             )
diff --git a/airflow/providers/microsoft/azure/hooks/cosmos.py 
b/airflow/providers/microsoft/azure/hooks/cosmos.py
index 1af863090c..4737554b50 100644
--- a/airflow/providers/microsoft/azure/hooks/cosmos.py
+++ b/airflow/providers/microsoft/azure/hooks/cosmos.py
@@ -37,8 +37,8 @@ from airflow.exceptions import AirflowBadRequest, 
AirflowException
 from airflow.hooks.base import BaseHook
 from airflow.providers.microsoft.azure.utils import (
     add_managed_identity_connection_widgets,
-    get_default_azure_credential,
     get_field,
+    get_sync_default_azure_credential,
 )
 
 
@@ -133,10 +133,12 @@ class AzureCosmosDBHook(BaseHook):
                 managed_identity_client_id = self._get_field(extras, 
"managed_identity_client_id")
                 workload_identity_tenant_id = self._get_field(extras, 
"workload_identity_tenant_id")
                 subscritption_id = self._get_field(extras, "subscription_id")
+                credential = get_sync_default_azure_credential(
+                    managed_identity_client_id=managed_identity_client_id,
+                    workload_identity_tenant_id=workload_identity_tenant_id,
+                )
                 management_client = CosmosDBManagementClient(
-                    credential=get_default_azure_credential(
-                        managed_identity_client_id, workload_identity_tenant_id
-                    ),
+                    credential=credential,
                     subscription_id=subscritption_id,
                 )
 
diff --git a/airflow/providers/microsoft/azure/hooks/data_factory.py 
b/airflow/providers/microsoft/azure/hooks/data_factory.py
index e8a80242ea..8047f15939 100644
--- a/airflow/providers/microsoft/azure/hooks/data_factory.py
+++ b/airflow/providers/microsoft/azure/hooks/data_factory.py
@@ -50,7 +50,8 @@ from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarni
 from airflow.hooks.base import BaseHook
 from airflow.providers.microsoft.azure.utils import (
     add_managed_identity_connection_widgets,
-    get_default_azure_credential,
+    get_async_default_azure_credential,
+    get_sync_default_azure_credential,
 )
 
 if TYPE_CHECKING:
@@ -212,7 +213,10 @@ class AzureDataFactoryHook(BaseHook):
         else:
             managed_identity_client_id = get_field(extras, 
"managed_identity_client_id")
             workload_identity_tenant_id = get_field(extras, 
"workload_identity_tenant_id")
-            credential = 
get_default_azure_credential(managed_identity_client_id, 
workload_identity_tenant_id)
+            credential = get_sync_default_azure_credential(
+                managed_identity_client_id=managed_identity_client_id,
+                workload_identity_tenant_id=workload_identity_tenant_id,
+            )
         self._conn = self._create_client(credential, subscription_id)
 
         return self._conn
@@ -1147,7 +1151,12 @@ class AzureDataFactoryAsyncHook(AzureDataFactoryHook):
                 client_id=conn.login, client_secret=conn.password, 
tenant_id=tenant
             )
         else:
-            credential = AsyncDefaultAzureCredential()
+            managed_identity_client_id = get_field(extras, 
"managed_identity_client_id")
+            workload_identity_tenant_id = get_field(extras, 
"workload_identity_tenant_id")
+            credential = get_async_default_azure_credential(
+                managed_identity_client_id=managed_identity_client_id,
+                workload_identity_tenant_id=workload_identity_tenant_id,
+            )
 
         self._async_conn = AsyncDataFactoryManagementClient(
             credential=credential,
diff --git a/airflow/providers/microsoft/azure/hooks/data_lake.py 
b/airflow/providers/microsoft/azure/hooks/data_lake.py
index 557c70fe99..b6f49e5cf8 100644
--- a/airflow/providers/microsoft/azure/hooks/data_lake.py
+++ b/airflow/providers/microsoft/azure/hooks/data_lake.py
@@ -123,7 +123,12 @@ class AzureDataLakeHook(BaseHook):
             if tenant:
                 credential = lib.auth(tenant_id=tenant, 
client_secret=conn.password, client_id=conn.login)
             else:
-                credential = AzureIdentityCredentialAdapter()
+                managed_identity_client_id = self._get_field(extras, 
"managed_identity_client_id")
+                workload_identity_tenant_id = self._get_field(extras, 
"workload_identity_tenant_id")
+                credential = AzureIdentityCredentialAdapter(
+                    managed_identity_client_id=managed_identity_client_id,
+                    workload_identity_tenant_id=workload_identity_tenant_id,
+                )
             self._conn = core.AzureDLFileSystem(credential, 
store_name=self.account_name)
             self._conn.connect()
         return self._conn
@@ -347,7 +352,12 @@ class AzureDataLakeStorageV2Hook(BaseHook):
         elif conn.password:
             credential = conn.password
         else:
-            credential = AzureIdentityCredentialAdapter()
+            managed_identity_client_id = self._get_field(extra, 
"managed_identity_client_id")
+            workload_identity_tenant_id = self._get_field(extra, 
"workload_identity_tenant_id")
+            credential = AzureIdentityCredentialAdapter(
+                managed_identity_client_id=managed_identity_client_id,
+                workload_identity_tenant_id=workload_identity_tenant_id,
+            )
 
         return DataLakeServiceClient(
             account_url=f"https://{conn.host}.dfs.core.windows.net",
diff --git a/airflow/providers/microsoft/azure/hooks/fileshare.py 
b/airflow/providers/microsoft/azure/hooks/fileshare.py
index 92235e7ecf..5ecaee80da 100644
--- a/airflow/providers/microsoft/azure/hooks/fileshare.py
+++ b/airflow/providers/microsoft/azure/hooks/fileshare.py
@@ -24,7 +24,7 @@ from azure.storage.fileshare import FileProperties, 
ShareDirectoryClient, ShareF
 from airflow.hooks.base import BaseHook
 from airflow.providers.microsoft.azure.utils import (
     add_managed_identity_connection_widgets,
-    get_default_azure_credential,
+    get_sync_default_azure_credential,
 )
 
 
@@ -106,12 +106,15 @@ class AzureFileShareHook(BaseHook):
             return f"https://{account_url}.file.core.windows.net"
         return account_url
 
-    def _get_default_azure_credential(self):
+    def _get_sync_default_azure_credential(self):
         conn = self.get_connection(self._conn_id)
         extras = conn.extra_dejson
         managed_identity_client_id = extras.get("managed_identity_client_id")
         workload_identity_tenant_id = extras.get("workload_identity_tenant_id")
-        return get_default_azure_credential(managed_identity_client_id, 
workload_identity_tenant_id)
+        return get_sync_default_azure_credential(
+            managed_identity_client_id=managed_identity_client_id,
+            workload_identity_tenant_id=workload_identity_tenant_id,
+        )
 
     @property
     def share_service_client(self):
@@ -126,7 +129,7 @@ class AzureFileShareHook(BaseHook):
         else:
             return ShareServiceClient(
                 account_url=self._account_url,
-                credential=self._get_default_azure_credential(),
+                credential=self._get_sync_default_azure_credential(),
                 token_intent="backup",
             )
 
@@ -151,7 +154,7 @@ class AzureFileShareHook(BaseHook):
                 account_url=self._account_url,
                 share_name=self.share_name,
                 directory_path=self.directory_path,
-                credential=self._get_default_azure_credential(),
+                credential=self._get_sync_default_azure_credential(),
                 token_intent="backup",
             )
 
@@ -176,7 +179,7 @@ class AzureFileShareHook(BaseHook):
                 account_url=self._account_url,
                 share_name=self.share_name,
                 file_path=self.file_path,
-                credential=self._get_default_azure_credential(),
+                credential=self._get_sync_default_azure_credential(),
                 token_intent="backup",
             )
 
diff --git a/airflow/providers/microsoft/azure/hooks/synapse.py 
b/airflow/providers/microsoft/azure/hooks/synapse.py
index a4f0f5d7c3..e284194376 100644
--- a/airflow/providers/microsoft/azure/hooks/synapse.py
+++ b/airflow/providers/microsoft/azure/hooks/synapse.py
@@ -26,8 +26,8 @@ from airflow.exceptions import AirflowTaskTimeout
 from airflow.hooks.base import BaseHook
 from airflow.providers.microsoft.azure.utils import (
     add_managed_identity_connection_widgets,
-    get_default_azure_credential,
     get_field,
+    get_sync_default_azure_credential,
 )
 
 if TYPE_CHECKING:
@@ -131,7 +131,10 @@ class AzureSynapseHook(BaseHook):
         else:
             managed_identity_client_id = self._get_field(extras, 
"managed_identity_client_id")
             workload_identity_tenant_id = self._get_field(extras, 
"workload_identity_tenant_id")
-            credential = 
get_default_azure_credential(managed_identity_client_id, 
workload_identity_tenant_id)
+            credential = get_sync_default_azure_credential(
+                managed_identity_client_id=managed_identity_client_id,
+                workload_identity_tenant_id=workload_identity_tenant_id,
+            )
 
         self._conn = self._create_client(credential, conn.host, spark_pool, 
livy_api_version, subscription_id)
 
diff --git a/airflow/providers/microsoft/azure/hooks/wasb.py 
b/airflow/providers/microsoft/azure/hooks/wasb.py
index 8c7f1636a1..43aea8c99c 100644
--- a/airflow/providers/microsoft/azure/hooks/wasb.py
+++ b/airflow/providers/microsoft/azure/hooks/wasb.py
@@ -49,7 +49,8 @@ from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
 from airflow.providers.microsoft.azure.utils import (
     add_managed_identity_connection_widgets,
-    get_default_azure_credential,
+    get_async_default_azure_credential,
+    get_sync_default_azure_credential,
 )
 
 if TYPE_CHECKING:
@@ -214,8 +215,10 @@ class WasbHook(BaseHook):
         if not credential:
             managed_identity_client_id = self._get_field(extra, 
"managed_identity_client_id")
             workload_identity_tenant_id = self._get_field(extra, 
"workload_identity_tenant_id")
-            credential = 
get_default_azure_credential(managed_identity_client_id, 
workload_identity_tenant_id)
-
+            credential = get_sync_default_azure_credential(
+                managed_identity_client_id=managed_identity_client_id,
+                workload_identity_tenant_id=workload_identity_tenant_id,
+            )
             self.log.info("Using DefaultAzureCredential as credential")
         return BlobServiceClient(
             account_url=account_url,
@@ -642,7 +645,12 @@ class WasbAsyncHook(WasbHook):
         # Fall back to old auth (password) or use managed identity if not 
provided.
         credential = conn.password
         if not credential:
-            credential = AsyncDefaultAzureCredential()
+            managed_identity_client_id = self._get_field(extra, 
"managed_identity_client_id")
+            workload_identity_tenant_id = self._get_field(extra, 
"workload_identity_tenant_id")
+            credential = get_async_default_azure_credential(
+                managed_identity_client_id=managed_identity_client_id,
+                workload_identity_tenant_id=workload_identity_tenant_id,
+            )
             self.log.info("Using DefaultAzureCredential as credential")
         self.blob_service_client = AsyncBlobServiceClient(
             account_url=account_url,
diff --git a/airflow/providers/microsoft/azure/secrets/key_vault.py 
b/airflow/providers/microsoft/azure/secrets/key_vault.py
index f4c51f1f64..794788206c 100644
--- a/airflow/providers/microsoft/azure/secrets/key_vault.py
+++ b/airflow/providers/microsoft/azure/secrets/key_vault.py
@@ -35,7 +35,7 @@ from azure.identity import ClientSecretCredential, 
DefaultAzureCredential
 from azure.keyvault.secrets import SecretClient
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
-from airflow.providers.microsoft.azure.utils import 
get_default_azure_credential
+from airflow.providers.microsoft.azure.utils import 
get_sync_default_azure_credential
 from airflow.secrets import BaseSecretsBackend
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.version import version as airflow_version
@@ -143,7 +143,7 @@ class AzureKeyVaultBackend(BaseSecretsBackend, 
LoggingMixin):
         if all([self.tenant_id, self.client_id, self.client_secret]):
             credential = ClientSecretCredential(self.tenant_id, 
self.client_id, self.client_secret)
         else:
-            credential = get_default_azure_credential(
+            credential = get_sync_default_azure_credential(
                 managed_identity_client_id=self.managed_identity_client_id,
                 workload_identity_tenant_id=self.workload_identity_tenant_id,
             )
diff --git a/airflow/providers/microsoft/azure/utils.py 
b/airflow/providers/microsoft/azure/utils.py
index c41a4e758c..1b738ed957 100644
--- a/airflow/providers/microsoft/azure/utils.py
+++ b/airflow/providers/microsoft/azure/utils.py
@@ -18,14 +18,19 @@
 from __future__ import annotations
 
 import warnings
-from functools import wraps
+from functools import partial, wraps
+from typing import TYPE_CHECKING
 
 from azure.core.pipeline import PipelineContext, PipelineRequest
 from azure.core.pipeline.policies import BearerTokenCredentialPolicy
 from azure.core.pipeline.transport import HttpRequest
 from azure.identity import DefaultAzureCredential
+from azure.identity.aio import DefaultAzureCredential as 
AsyncDefaultAzureCredential
 from msrest.authentication import BasicTokenAuthentication
 
+if TYPE_CHECKING:
+    pass
+
 
 def get_field(*, conn_id: str, conn_type: str, extras: dict, field_name: str):
     """Get field from extra, first checking short name, then for backcompat we 
check for prefixed name."""
@@ -52,22 +57,39 @@ def get_field(*, conn_id: str, conn_type: str, extras: 
dict, field_name: str):
     return ret
 
 
-def get_default_azure_credential(
-    managed_identity_client_id: str | None, workload_identity_tenant_id: str | 
None
-) -> DefaultAzureCredential:
+def _get_default_azure_credential(
+    *,
+    managed_identity_client_id: str | None,
+    workload_identity_tenant_id: str | None,
+    use_async: bool = False,
+) -> DefaultAzureCredential | AsyncDefaultAzureCredential:
     """Get DefaultAzureCredential based on provided arguments.
 
     If managed_identity_client_id and workload_identity_tenant_id are 
provided, this function returns
     DefaultAzureCredential with managed identity.
     """
+    credential_cls: type[AsyncDefaultAzureCredential] | 
type[DefaultAzureCredential] = (
+        AsyncDefaultAzureCredential if use_async else DefaultAzureCredential
+    )
     if managed_identity_client_id and workload_identity_tenant_id:
-        return DefaultAzureCredential(
+        return credential_cls(
             managed_identity_client_id=managed_identity_client_id,
             workload_identity_tenant_id=workload_identity_tenant_id,
             additionally_allowed_tenants=[workload_identity_tenant_id],
         )
     else:
-        return DefaultAzureCredential()
+        return credential_cls()
+
+
+get_sync_default_azure_credential: partial[DefaultAzureCredential] = partial(
+    _get_default_azure_credential,  #  type: ignore[arg-type]
+    use_async=False,
+)
+
+get_async_default_azure_credential: partial[AsyncDefaultAzureCredential] = 
partial(
+    _get_default_azure_credential,  #  type: ignore[arg-type]
+    use_async=False,
+)
 
 
 def add_managed_identity_connection_widgets(func):
@@ -123,7 +145,10 @@ class 
AzureIdentityCredentialAdapter(BasicTokenAuthentication):
         """
         super().__init__(None)  # type: ignore[arg-type]
         if credential is None:
-            credential = 
get_default_azure_credential(managed_identity_client_id, 
workload_identity_tenant_id)
+            credential = get_sync_default_azure_credential(
+                managed_identity_client_id=managed_identity_client_id,
+                workload_identity_tenant_id=workload_identity_tenant_id,
+            )
         self._policy = BearerTokenCredentialPolicy(credential, resource_id, 
**kwargs)
 
     def _make_request(self):
diff --git a/tests/providers/microsoft/azure/hooks/test_adx.py 
b/tests/providers/microsoft/azure/hooks/test_adx.py
index c8de2d933f..d919ce699a 100644
--- a/tests/providers/microsoft/azure/hooks/test_adx.py
+++ b/tests/providers/microsoft/azure/hooks/test_adx.py
@@ -235,7 +235,7 @@ class TestAzureDataExplorerHook:
         ],
         indirect=True,
     )
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.adx.get_default_azure_credential")
+    
@mock.patch("airflow.providers.microsoft.azure.hooks.adx.get_sync_default_azure_credential")
     @mock.patch.object(KustoClient, "__init__")
     def test_conn_method_azure_token_cred(self, mock_init, 
mock_default_azure_credential, mocked_connection):
         mock_init.return_value = None
diff --git a/tests/providers/microsoft/azure/hooks/test_asb.py 
b/tests/providers/microsoft/azure/hooks/test_asb.py
index 195017bfd3..8777093a07 100644
--- a/tests/providers/microsoft/azure/hooks/test_asb.py
+++ b/tests/providers/microsoft/azure/hooks/test_asb.py
@@ -61,7 +61,7 @@ class TestAdminClientHook:
         hook = AdminClientHook(azure_service_bus_conn_id=self.conn_id)
         assert isinstance(hook.get_conn(), ServiceBusAdministrationClient)
 
-    @mock.patch(f"{MODULE}.get_default_azure_credential")
+    @mock.patch(f"{MODULE}.get_sync_default_azure_credential")
     @mock.patch(f"{MODULE}.AdminClientHook.get_connection")
     def 
test_get_conn_fallback_to_default_azure_credential_when_schema_is_not_provided(
         self, mock_connection, mock_default_azure_credential
@@ -69,7 +69,9 @@ class TestAdminClientHook:
         mock_connection.return_value = self.mock_conn_without_schema
         hook = AdminClientHook(azure_service_bus_conn_id=self.conn_id)
         assert isinstance(hook.get_conn(), ServiceBusAdministrationClient)
-        mock_default_azure_credential.assert_called_with(None, None)
+        assert mock_default_azure_credential.called_with(
+            managed_identity_client_id=None, workload_identity_tenant_id=None
+        )
 
     @mock.patch("azure.servicebus.management.QueueProperties")
     @mock.patch(f"{MODULE}.AdminClientHook.get_conn")
@@ -172,7 +174,7 @@ class TestMessageHook:
         hook = MessageHook(azure_service_bus_conn_id=self.conn_id)
         assert isinstance(hook.get_conn(), ServiceBusClient)
 
-    @mock.patch(f"{MODULE}.get_default_azure_credential")
+    @mock.patch(f"{MODULE}.get_sync_default_azure_credential")
     @mock.patch(f"{MODULE}.MessageHook.get_connection")
     def 
test_get_conn_fallback_to_default_azure_credential_when_schema_is_not_provided(
         self, mock_connection, mock_default_azure_credential
@@ -180,7 +182,9 @@ class TestMessageHook:
         mock_connection.return_value = self.mock_conn_without_schema
         hook = MessageHook(azure_service_bus_conn_id=self.conn_id)
         assert isinstance(hook.get_conn(), ServiceBusClient)
-        mock_default_azure_credential.assert_called_with(None, None)
+        assert mock_default_azure_credential.called_with(
+            managed_identity_client_id=None, workload_identity_tenant_id=None
+        )
 
     @pytest.mark.parametrize(
         "mock_message, mock_batch_flag",
diff --git 
a/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py 
b/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py
index bfe8a57484..00d2562f92 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py
@@ -131,7 +131,7 @@ class TestAzureContainerInstanceHook:
 class TestAzureContainerInstanceHookWithoutSetupCredential:
     
@patch("airflow.providers.microsoft.azure.hooks.container_instance.ContainerInstanceManagementClient")
     @patch("azure.common.credentials.ServicePrincipalCredentials")
-    
@patch("airflow.providers.microsoft.azure.hooks.container_instance.get_default_azure_credential")
+    
@patch("airflow.providers.microsoft.azure.hooks.container_instance.get_sync_default_azure_credential")
     def test_get_conn_fallback_to_default_azure_credential(
         self,
         mock_default_azure_credential,
@@ -148,7 +148,9 @@ class TestAzureContainerInstanceHookWithoutSetupCredential:
         hook = 
AzureContainerInstanceHook(azure_conn_id=connection_without_login_password_tenant_id.conn_id)
         conn = hook.get_conn()
 
-        mock_default_azure_credential.assert_called_with(None, None)
+        assert mock_default_azure_credential.called_with(
+            managed_identity_client_id=None, workload_identity_tenant_id=None
+        )
         assert not mock_service_pricipal_credential.called
         assert conn == mock_client_instance
         mock_client_cls.assert_called_once_with(
diff --git 
a/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py 
b/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py
index 063a1290d2..def088ee2b 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py
@@ -63,7 +63,9 @@ class TestAzureContainerRegistryHook:
     @mock.patch(
         
"airflow.providers.microsoft.azure.hooks.container_registry.ContainerRegistryManagementClient"
     )
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.container_registry.get_default_azure_credential")
+    @mock.patch(
+        
"airflow.providers.microsoft.azure.hooks.container_registry.get_sync_default_azure_credential"
+    )
     def test_get_conn_with_default_azure_credential(
         self, mocked_default_azure_credential, mocked_client, mocked_connection
     ):
@@ -80,4 +82,6 @@ class TestAzureContainerRegistryHook:
         assert hook.connection.password == "password"
         assert hook.connection.server == "test.cr"
 
-        mocked_default_azure_credential.assert_called_with(None, None)
+        mocked_default_azure_credential.assert_called_with(
+            managed_identity_client_id=None, workload_identity_tenant_id=None
+        )
diff --git 
a/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py 
b/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py
index 5eec510f4e..e0405f1f33 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py
@@ -86,7 +86,7 @@ class TestAzureContainerVolumeHook:
         indirect=True,
     )
     
@mock.patch("airflow.providers.microsoft.azure.hooks.container_volume.StorageManagementClient")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.container_volume.get_default_azure_credential")
+    
@mock.patch("airflow.providers.microsoft.azure.hooks.container_volume.get_sync_default_azure_credential")
     def test_get_file_volume_default_azure_credential(
         self, mocked_default_azure_credential, mocked_client, mocked_connection
     ):
@@ -112,4 +112,6 @@ class TestAzureContainerVolumeHook:
         assert volume.azure_file.storage_account_name == "storage"
         assert volume.azure_file.read_only is True
 
-        mocked_default_azure_credential.assert_called_with(None, None)
+        mocked_default_azure_credential.assert_called_with(
+            managed_identity_client_id=None, workload_identity_tenant_id=None
+        )
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py 
b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
index bbb7f7f870..4a57b1743d 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
@@ -73,7 +73,7 @@ class TestAzureCosmosDbHook:
         ],
         indirect=True,
     )
-    @mock.patch(f"{MODULE}.get_default_azure_credential")
+    @mock.patch(f"{MODULE}.get_sync_default_azure_credential")
     @mock.patch(f"{MODULE}.CosmosDBManagementClient")
     @mock.patch(f"{MODULE}.CosmosClient")
     def test_get_conn(self, mock_cosmos, mock_cosmos_db, 
mock_default_azure_credential, mocked_connection):
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py 
b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
index 14a7121202..fe80cc757c 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
@@ -178,7 +178,7 @@ def 
test_get_conn_by_credential_client_secret(mock_credential):
         mock_create_client.assert_called_with(mock_credential(), 
"subscriptionId")
 
 
[email protected](f"{MODULE}.get_default_azure_credential")
[email protected](f"{MODULE}.get_sync_default_azure_credential")
 def test_get_conn_by_default_azure_credential(mock_credential):
     hook = AzureDataFactoryHook(DEFAULT_CONNECTION_DEFAULT_CREDENTIAL)
 
@@ -657,6 +657,37 @@ class TestAzureDataFactoryAsyncHook:
         with pytest.raises(AirflowException):
             await hook.get_pipeline_run(RUN_ID, None, DATAFACTORY_NAME)
 
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "mocked_connection",
+        [
+            Connection(
+                conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
+                conn_type="azure_data_factory",
+                extra={
+                    "subscriptionId": "subscriptionId",
+                    "resource_group_name": RESOURCE_GROUP_NAME,
+                    "factory_name": DATAFACTORY_NAME,
+                    "managed_identity_client_id": "test_client_id",
+                    "workload_identity_tenant_id": "test_tenant_id",
+                },
+            )
+        ],
+        indirect=True,
+    )
+    @mock.patch(f"{MODULE}.get_async_default_azure_credential")
+    async def test_get_async_conn_with_default_azure_credential(
+        self, mock_default_azure_credential, mocked_connection
+    ):
+        """Check that get_async_conn passes the identity extras to the default Azure credential."""
+        hook = AzureDataFactoryAsyncHook(mocked_connection.conn_id)
+        response = await hook.get_async_conn()
+        assert isinstance(response, DataFactoryManagementClient)
+
+        mock_default_azure_credential.assert_called_with(
+            managed_identity_client_id="test_client_id", workload_identity_tenant_id="test_tenant_id"
+        )
+
     @pytest.mark.asyncio
     @pytest.mark.parametrize(
         "mocked_connection",
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_synapse.py 
b/tests/providers/microsoft/azure/hooks/test_azure_synapse.py
index 288a560cfa..d66268798d 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_synapse.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_synapse.py
@@ -106,7 +106,7 @@ def 
test_get_connection_by_credential_client_secret(mock_credential):
         )
 
 
-@patch(f"{MODULE}.get_default_azure_credential")
+@patch(f"{MODULE}.get_sync_default_azure_credential")
 def test_get_conn_by_default_azure_credential(mock_credential):
     hook = AzureSynapseHook(DEFAULT_CONNECTION_DEFAULT_CREDENTIAL)
 
diff --git a/tests/providers/microsoft/azure/hooks/test_wasb.py 
b/tests/providers/microsoft/azure/hooks/test_wasb.py
index 7ac4fb1857..bc4bed018e 100644
--- a/tests/providers/microsoft/azure/hooks/test_wasb.py
+++ b/tests/providers/microsoft/azure/hooks/test_wasb.py
@@ -47,7 +47,7 @@ def mocked_blob_service_client():
 
 @pytest.fixture
 def mocked_default_azure_credential():
-    with 
mock.patch("airflow.providers.microsoft.azure.hooks.wasb.get_default_azure_credential")
 as m:
+    with 
mock.patch("airflow.providers.microsoft.azure.hooks.wasb.get_sync_default_azure_credential")
 as m:
         yield m
 
 
diff --git a/tests/providers/microsoft/azure/secrets/test_azure_key_vault.py 
b/tests/providers/microsoft/azure/secrets/test_azure_key_vault.py
index 51b20084ed..2c791f87a3 100644
--- a/tests/providers/microsoft/azure/secrets/test_azure_key_vault.py
+++ b/tests/providers/microsoft/azure/secrets/test_azure_key_vault.py
@@ -35,7 +35,7 @@ class TestAzureKeyVaultBackend:
         conn = AzureKeyVaultBackend().get_connection("fake_conn")
         assert conn.host == "host"
 
-    @mock.patch(f"{KEY_VAULT_MODULE}.get_default_azure_credential")
+    @mock.patch(f"{KEY_VAULT_MODULE}.get_sync_default_azure_credential")
     @mock.patch(f"{KEY_VAULT_MODULE}.SecretClient")
     def test_get_conn_uri(self, mock_secret_client, mock_azure_cred):
         mock_cred = mock.Mock()
@@ -153,7 +153,7 @@ class TestAzureKeyVaultBackend:
         assert backend.get_config("test_mysql") is None
         mock_get_secret.assert_not_called()
 
-    @mock.patch(f"{KEY_VAULT_MODULE}.get_default_azure_credential")
+    @mock.patch(f"{KEY_VAULT_MODULE}.get_sync_default_azure_credential")
     @mock.patch(f"{KEY_VAULT_MODULE}.ClientSecretCredential")
     @mock.patch(f"{KEY_VAULT_MODULE}.SecretClient")
     def test_client_authenticate_with_default_azure_credential(
@@ -169,7 +169,7 @@ class TestAzureKeyVaultBackend:
         assert not mock_client_secret_credential.called
         mock_defaul_azure_credential.assert_called_once()
 
-    @mock.patch(f"{KEY_VAULT_MODULE}.get_default_azure_credential")
+    @mock.patch(f"{KEY_VAULT_MODULE}.get_sync_default_azure_credential")
     @mock.patch(f"{KEY_VAULT_MODULE}.ClientSecretCredential")
     @mock.patch(f"{KEY_VAULT_MODULE}.SecretClient")
     def 
test_client_authenticate_with_default_azure_credential_and_customized_configuration(
@@ -187,7 +187,7 @@ class TestAzureKeyVaultBackend:
             workload_identity_tenant_id="workload_identity_tenant_id",
         )
 
-    @mock.patch(f"{KEY_VAULT_MODULE}.get_default_azure_credential")
+    @mock.patch(f"{KEY_VAULT_MODULE}.get_sync_default_azure_credential")
     @mock.patch(f"{KEY_VAULT_MODULE}.ClientSecretCredential")
     @mock.patch(f"{KEY_VAULT_MODULE}.SecretClient")
     def test_client_authenticate_with_client_secret_credential(


Reply via email to