This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b27047430f Refactor azure managed identity (#35367)
b27047430f is described below

commit b27047430fa49538a737138e3c2e57368c4d33b0
Author: Wei Lee <[email protected]>
AuthorDate: Fri Nov 3 05:35:56 2023 +0800

    Refactor azure managed identity (#35367)
---
 airflow/providers/microsoft/azure/hooks/adx.py     |  14 +--
 airflow/providers/microsoft/azure/hooks/asb.py     |  15 ++-
 .../providers/microsoft/azure/hooks/base_azure.py  |  14 +--
 airflow/providers/microsoft/azure/hooks/batch.py   |  31 +++---
 .../microsoft/azure/hooks/container_registry.py    |  15 ++-
 .../microsoft/azure/hooks/container_volume.py      |  39 ++++---
 airflow/providers/microsoft/azure/hooks/cosmos.py  |  15 ++-
 .../microsoft/azure/hooks/data_factory.py          |  14 +--
 .../providers/microsoft/azure/hooks/data_lake.py   |  30 +++---
 .../providers/microsoft/azure/hooks/fileshare.py   |  79 ++++++--------
 airflow/providers/microsoft/azure/hooks/synapse.py |  15 ++-
 airflow/providers/microsoft/azure/hooks/wasb.py    |  14 +--
 airflow/providers/microsoft/azure/utils.py         |  24 +++++
 tests/providers/microsoft/azure/hooks/test_adx.py  | 118 +++++++++++----------
 tests/providers/microsoft/azure/hooks/test_asb.py  |  41 +++----
 .../microsoft/azure/hooks/test_azure_batch.py      |  16 +--
 .../azure/hooks/test_azure_container_instance.py   |   2 +-
 .../microsoft/azure/hooks/test_azure_cosmos.py     |  60 ++++++++---
 .../azure/hooks/test_azure_data_factory.py         |  68 ++++++------
 .../microsoft/azure/hooks/test_azure_data_lake.py  |  76 ++++++-------
 .../microsoft/azure/hooks/test_azure_fileshare.py  |  24 +++--
 .../microsoft/azure/hooks/test_azure_synapse.py    |  45 +++++---
 .../microsoft/azure/hooks/test_base_azure.py       |  46 +++++++-
 tests/providers/microsoft/azure/hooks/test_wasb.py |   1 +
 tests/providers/microsoft/azure/test_utils.py      |  17 ++-
 25 files changed, 466 insertions(+), 367 deletions(-)

diff --git a/airflow/providers/microsoft/azure/hooks/adx.py b/airflow/providers/microsoft/azure/hooks/adx.py
index 8ee3ae8f17..4d2ba0ae6d 100644
--- a/airflow/providers/microsoft/azure/hooks/adx.py
+++ b/airflow/providers/microsoft/azure/hooks/adx.py
@@ -34,7 +34,10 @@ from azure.kusto.data.exceptions import KustoServiceError
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.hooks.base import BaseHook
-from airflow.providers.microsoft.azure.utils import get_default_azure_credential
+from airflow.providers.microsoft.azure.utils import (
+    add_managed_identity_connection_widgets,
+    get_default_azure_credential,
+)
 
 if TYPE_CHECKING:
     from azure.kusto.data.response import KustoResponseDataSetV2
@@ -80,6 +83,7 @@ class AzureDataExplorerHook(BaseHook):
     hook_name = "Azure Data Explorer"
 
     @classmethod
+    @add_managed_identity_connection_widgets
     def get_connection_form_widgets(cls) -> dict[str, Any]:
         """Returns connection widgets to add to connection form."""
         from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget
@@ -95,12 +99,6 @@ class AzureDataExplorerHook(BaseHook):
             "thumbprint": PasswordField(
                 lazy_gettext("Application Certificate Thumbprint"), widget=BS3PasswordFieldWidget()
             ),
-            "managed_identity_client_id": StringField(
-                lazy_gettext("Managed Identity Client ID"), widget=BS3TextFieldWidget()
-            ),
-            "workload_identity_tenant_id": StringField(
-                lazy_gettext("Workload Identity Tenant ID"), widget=BS3TextFieldWidget()
-            ),
         }
 
     @classmethod
@@ -119,8 +117,6 @@ class AzureDataExplorerHook(BaseHook):
                 "tenant": "Used with AAD_APP/AAD_APP_CERT/AAD_CREDS",
                 "certificate": "Used with AAD_APP_CERT",
                 "thumbprint": "Used with AAD_APP_CERT",
-                "managed_identity_client_id": "Managed Identity Client ID",
-                "workload_identity_tenant_id": "Workload Identity Tenant ID",
             },
         }
 
diff --git a/airflow/providers/microsoft/azure/hooks/asb.py b/airflow/providers/microsoft/azure/hooks/asb.py
index e81bc84527..d1580f2aa4 100644
--- a/airflow/providers/microsoft/azure/hooks/asb.py
+++ b/airflow/providers/microsoft/azure/hooks/asb.py
@@ -22,7 +22,11 @@ from azure.servicebus import ServiceBusClient, ServiceBusMessage, ServiceBusSend
 from azure.servicebus.management import QueueProperties, ServiceBusAdministrationClient
 
 from airflow.hooks.base import BaseHook
-from airflow.providers.microsoft.azure.utils import get_default_azure_credential, get_field
+from airflow.providers.microsoft.azure.utils import (
+    add_managed_identity_connection_widgets,
+    get_default_azure_credential,
+    get_field,
+)
 
 if TYPE_CHECKING:
     from azure.identity import DefaultAzureCredential
@@ -42,6 +46,7 @@ class BaseAzureServiceBusHook(BaseHook):
     hook_name = "Azure Service Bus"
 
     @staticmethod
+    @add_managed_identity_connection_widgets
     def get_connection_form_widgets() -> dict[str, Any]:
         """Returns connection widgets to add to connection form."""
         from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
@@ -53,12 +58,6 @@ class BaseAzureServiceBusHook(BaseHook):
                 lazy_gettext("Fully Qualified Namespace"), widget=BS3TextFieldWidget()
             ),
             "credential": PasswordField(lazy_gettext("Credential"), widget=BS3TextFieldWidget()),
-            "managed_identity_client_id": StringField(
-                lazy_gettext("Managed Identity Client ID"), widget=BS3TextFieldWidget()
-            ),
-            "workload_identity_tenant_id": StringField(
-                lazy_gettext("Workload Identity Tenant ID"), widget=BS3TextFieldWidget()
-            ),
         }
 
     @staticmethod
@@ -73,8 +72,6 @@ class BaseAzureServiceBusHook(BaseHook):
                 ),
                 "credential": "credential",
                 "schema": "Endpoint=sb://<Resource group>.servicebus.windows.net/;SharedAccessKeyName=<AccessKeyName>;SharedAccessKey=<SharedAccessKey>",
-                "managed_identity_client_id": "Managed Identity Client ID",
-                "workload_identity_tenant_id": "Workload Identity Tenant ID",
             },
         }
 
diff --git a/airflow/providers/microsoft/azure/hooks/base_azure.py b/airflow/providers/microsoft/azure/hooks/base_azure.py
index 2868816379..d6ecf185c4 100644
--- a/airflow/providers/microsoft/azure/hooks/base_azure.py
+++ b/airflow/providers/microsoft/azure/hooks/base_azure.py
@@ -24,7 +24,10 @@ from azure.common.credentials import ServicePrincipalCredentials
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.hooks.base import BaseHook
-from airflow.providers.microsoft.azure.utils import AzureIdentityCredentialAdapter
+from airflow.providers.microsoft.azure.utils import (
+    AzureIdentityCredentialAdapter,
+    add_managed_identity_connection_widgets,
+)
 
 
 class AzureBaseHook(BaseHook):
@@ -45,6 +48,7 @@ class AzureBaseHook(BaseHook):
     hook_name = "Azure"
 
     @staticmethod
+    @add_managed_identity_connection_widgets
     def get_connection_form_widgets() -> dict[str, Any]:
         """Returns connection widgets to add to connection form."""
         from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
@@ -54,12 +58,6 @@ class AzureBaseHook(BaseHook):
         return {
             "tenantId": StringField(lazy_gettext("Azure Tenant ID"), widget=BS3TextFieldWidget()),
             "subscriptionId": StringField(lazy_gettext("Azure Subscription ID"), widget=BS3TextFieldWidget()),
-            "managed_identity_client_id": StringField(
-                lazy_gettext("Managed Identity Client ID"), widget=BS3TextFieldWidget()
-            ),
-            "workload_identity_tenant_id": StringField(
-                lazy_gettext("Workload Identity Tenant ID"), widget=BS3TextFieldWidget()
-            ),
         }
 
     @staticmethod
@@ -85,8 +83,6 @@ class AzureBaseHook(BaseHook):
                 "password": "secret (token credentials auth)",
                 "tenantId": "tenantId (token credentials auth)",
                 "subscriptionId": "subscriptionId (token credentials auth)",
-                "managed_identity_client_id": "Managed Identity Client ID",
-                "workload_identity_tenant_id": "Workload Identity Tenant ID",
             },
         }
 
diff --git a/airflow/providers/microsoft/azure/hooks/batch.py b/airflow/providers/microsoft/azure/hooks/batch.py
index 7ced29cdc2..c3007d0b59 100644
--- a/airflow/providers/microsoft/azure/hooks/batch.py
+++ b/airflow/providers/microsoft/azure/hooks/batch.py
@@ -26,7 +26,11 @@ from azure.batch import BatchServiceClient, batch_auth, models as batch_models
 
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
-from airflow.providers.microsoft.azure.utils import AzureIdentityCredentialAdapter, get_field
+from airflow.providers.microsoft.azure.utils import (
+    AzureIdentityCredentialAdapter,
+    add_managed_identity_connection_widgets,
+    get_field,
+)
 from airflow.utils import timezone
 
 if TYPE_CHECKING:
@@ -46,15 +50,8 @@ class AzureBatchHook(BaseHook):
     conn_type = "azure_batch"
     hook_name = "Azure Batch Service"
 
-    def _get_field(self, extras, name):
-        return get_field(
-            conn_id=self.conn_id,
-            conn_type=self.conn_type,
-            extras=extras,
-            field_name=name,
-        )
-
     @classmethod
+    @add_managed_identity_connection_widgets
     def get_connection_form_widgets(cls) -> dict[str, Any]:
         """Returns connection widgets to add to connection form."""
         from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
@@ -63,12 +60,6 @@ class AzureBatchHook(BaseHook):
 
         return {
             "account_url": StringField(lazy_gettext("Batch Account URL"), widget=BS3TextFieldWidget()),
-            "managed_identity_client_id": StringField(
-                lazy_gettext("Managed Identity Client ID"), widget=BS3TextFieldWidget()
-            ),
-            "workload_identity_tenant_id": StringField(
-                lazy_gettext("Workload Identity Tenant ID"), widget=BS3TextFieldWidget()
-            ),
         }
 
     @classmethod
@@ -79,8 +70,6 @@ class AzureBatchHook(BaseHook):
             "relabeling": {
                 "login": "Batch Account Name",
                 "password": "Batch Account Access Key",
-                "managed_identity_client_id": "Managed Identity Client ID",
-                "workload_identity_tenant_id": "Workload Identity Tenant ID",
             },
         }
 
@@ -88,6 +77,14 @@ class AzureBatchHook(BaseHook):
         super().__init__()
         self.conn_id = azure_batch_conn_id
 
+    def _get_field(self, extras, name):
+        return get_field(
+            conn_id=self.conn_id,
+            conn_type=self.conn_type,
+            extras=extras,
+            field_name=name,
+        )
+
     @cached_property
     def connection(self) -> BatchServiceClient:
         """Get the Batch client connection (cached)."""
diff --git a/airflow/providers/microsoft/azure/hooks/container_registry.py b/airflow/providers/microsoft/azure/hooks/container_registry.py
index ea7b129a46..04cfa719ef 100644
--- a/airflow/providers/microsoft/azure/hooks/container_registry.py
+++ b/airflow/providers/microsoft/azure/hooks/container_registry.py
@@ -25,7 +25,11 @@ from azure.mgmt.containerinstance.models import ImageRegistryCredential
 from azure.mgmt.containerregistry import ContainerRegistryManagementClient
 
 from airflow.hooks.base import BaseHook
-from airflow.providers.microsoft.azure.utils import get_default_azure_credential, get_field
+from airflow.providers.microsoft.azure.utils import (
+    add_managed_identity_connection_widgets,
+    get_default_azure_credential,
+    get_field,
+)
 
 
 class AzureContainerRegistryHook(BaseHook):
@@ -43,6 +47,7 @@ class AzureContainerRegistryHook(BaseHook):
     hook_name = "Azure Container Registry"
 
     @staticmethod
+    @add_managed_identity_connection_widgets
     def get_connection_form_widgets() -> dict[str, Any]:
         """Returns connection widgets to add to connection form."""
         from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
@@ -58,12 +63,6 @@ class AzureContainerRegistryHook(BaseHook):
                 lazy_gettext("Resource group name (optional)"),
                 widget=BS3TextFieldWidget(),
             ),
-            "managed_identity_client_id": StringField(
-                lazy_gettext("Managed Identity Client ID"), widget=BS3TextFieldWidget()
-            ),
-            "workload_identity_tenant_id": StringField(
-                lazy_gettext("Workload Identity Tenant ID"), widget=BS3TextFieldWidget()
-            ),
         }
 
     @classmethod
@@ -82,8 +81,6 @@ class AzureContainerRegistryHook(BaseHook):
                 "host": "docker image registry server",
                 "subscription_id": "Subscription id (required for Azure AD authentication)",
                 "resource_group": "Resource group name (required for Azure AD authentication)",
-                "managed_identity_client_id": "Managed Identity Client ID",
-                "workload_identity_tenant_id": "Workload Identity Tenant ID",
             },
         }
 
diff --git a/airflow/providers/microsoft/azure/hooks/container_volume.py b/airflow/providers/microsoft/azure/hooks/container_volume.py
index 791c08303b..fdf6d96407 100644
--- a/airflow/providers/microsoft/azure/hooks/container_volume.py
+++ b/airflow/providers/microsoft/azure/hooks/container_volume.py
@@ -22,7 +22,11 @@ from azure.mgmt.containerinstance.models import AzureFileVolume, Volume
 from azure.mgmt.storage import StorageManagementClient
 
 from airflow.hooks.base import BaseHook
-from airflow.providers.microsoft.azure.utils import get_default_azure_credential, get_field
+from airflow.providers.microsoft.azure.utils import (
+    add_managed_identity_connection_widgets,
+    get_default_azure_credential,
+    get_field,
+)
 
 
 class AzureContainerVolumeHook(BaseHook):
@@ -39,19 +43,8 @@ class AzureContainerVolumeHook(BaseHook):
     conn_type = "azure_container_volume"
     hook_name = "Azure Container Volume"
 
-    def __init__(self, azure_container_volume_conn_id: str = "azure_container_volume_default") -> None:
-        super().__init__()
-        self.conn_id = azure_container_volume_conn_id
-
-    def _get_field(self, extras, name):
-        return get_field(
-            conn_id=self.conn_id,
-            conn_type=self.conn_type,
-            extras=extras,
-            field_name=name,
-        )
-
     @staticmethod
+    @add_managed_identity_connection_widgets
     def get_connection_form_widgets() -> dict[str, Any]:
         """Returns connection widgets to add to connection form."""
         from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget
@@ -70,12 +63,6 @@ class AzureContainerVolumeHook(BaseHook):
                 lazy_gettext("Resource group name (optional)"),
                 widget=BS3TextFieldWidget(),
             ),
-            "managed_identity_client_id": StringField(
-                lazy_gettext("Managed Identity Client ID"), widget=BS3TextFieldWidget()
-            ),
-            "workload_identity_tenant_id": StringField(
-                lazy_gettext("Workload Identity Tenant ID"), widget=BS3TextFieldWidget()
-            ),
         }
 
     @staticmethod
@@ -93,11 +80,21 @@ class AzureContainerVolumeHook(BaseHook):
                 "connection_string": "connection string auth",
                 "subscription_id": "Subscription id (required for Azure AD authentication)",
                 "resource_group": "Resource group name (required for Azure AD authentication)",
-                "managed_identity_client_id": "Managed Identity Client ID",
-                "workload_identity_tenant_id": "Workload Identity Tenant ID",
             },
         }
 
+    def __init__(self, azure_container_volume_conn_id: str = "azure_container_volume_default") -> None:
+        super().__init__()
+        self.conn_id = azure_container_volume_conn_id
+
+    def _get_field(self, extras, name):
+        return get_field(
+            conn_id=self.conn_id,
+            conn_type=self.conn_type,
+            extras=extras,
+            field_name=name,
+        )
+
     def get_storagekey(self, *, storage_account_name: str | None = None) -> str:
         """Get Azure File Volume storage key."""
         conn = self.get_connection(self.conn_id)
diff --git a/airflow/providers/microsoft/azure/hooks/cosmos.py b/airflow/providers/microsoft/azure/hooks/cosmos.py
index 596faf6b04..1af863090c 100644
--- a/airflow/providers/microsoft/azure/hooks/cosmos.py
+++ b/airflow/providers/microsoft/azure/hooks/cosmos.py
@@ -35,7 +35,11 @@ from azure.mgmt.cosmosdb import CosmosDBManagementClient
 
 from airflow.exceptions import AirflowBadRequest, AirflowException
 from airflow.hooks.base import BaseHook
-from airflow.providers.microsoft.azure.utils import get_default_azure_credential, get_field
+from airflow.providers.microsoft.azure.utils import (
+    add_managed_identity_connection_widgets,
+    get_default_azure_credential,
+    get_field,
+)
 
 
 class AzureCosmosDBHook(BaseHook):
@@ -56,6 +60,7 @@ class AzureCosmosDBHook(BaseHook):
     hook_name = "Azure CosmosDB"
 
     @staticmethod
+    @add_managed_identity_connection_widgets
     def get_connection_form_widgets() -> dict[str, Any]:
         """Returns connection widgets to add to connection form."""
         from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
@@ -77,12 +82,6 @@ class AzureCosmosDBHook(BaseHook):
                 lazy_gettext("Resource Group Name (optional)"),
                 widget=BS3TextFieldWidget(),
             ),
-            "managed_identity_client_id": StringField(
-                lazy_gettext("Managed Identity Client ID"), widget=BS3TextFieldWidget()
-            ),
-            "workload_identity_tenant_id": StringField(
-                lazy_gettext("Workload Identity Tenant ID"), widget=BS3TextFieldWidget()
-            ),
         }
 
     @staticmethod
@@ -101,8 +100,6 @@ class AzureCosmosDBHook(BaseHook):
                 "collection_name": "collection name",
                 "subscription_id": "Subscription ID (required for Azure AD authentication)",
                 "resource_group_name": "Resource Group Name (required for Azure AD authentication)",
-                "managed_identity_client_id": "Managed Identity Client ID",
-                "workload_identity_tenant_id": "Workload Identity Tenant ID",
             },
         }
 
diff --git a/airflow/providers/microsoft/azure/hooks/data_factory.py b/airflow/providers/microsoft/azure/hooks/data_factory.py
index cc9cc2ec8a..e8a80242ea 100644
--- a/airflow/providers/microsoft/azure/hooks/data_factory.py
+++ b/airflow/providers/microsoft/azure/hooks/data_factory.py
@@ -48,7 +48,10 @@ from azure.mgmt.datafactory.aio import DataFactoryManagementClient as AsyncDataF
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.hooks.base import BaseHook
-from airflow.providers.microsoft.azure.utils import get_default_azure_credential
+from airflow.providers.microsoft.azure.utils import (
+    add_managed_identity_connection_widgets,
+    get_default_azure_credential,
+)
 
 if TYPE_CHECKING:
     from azure.core.polling import LROPoller
@@ -153,6 +156,7 @@ class AzureDataFactoryHook(BaseHook):
     hook_name: str = "Azure Data Factory"
 
     @staticmethod
+    @add_managed_identity_connection_widgets
     def get_connection_form_widgets() -> dict[str, Any]:
         """Returns connection widgets to add to connection form."""
         from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
@@ -166,12 +170,6 @@ class AzureDataFactoryHook(BaseHook):
                 lazy_gettext("Resource Group Name"), widget=BS3TextFieldWidget()
             ),
             "factory_name": StringField(lazy_gettext("Factory Name"), widget=BS3TextFieldWidget()),
-            "managed_identity_client_id": StringField(
-                lazy_gettext("Managed Identity Client ID"), widget=BS3TextFieldWidget()
-            ),
-            "workload_identity_tenant_id": StringField(
-                lazy_gettext("Workload Identity Tenant ID"), widget=BS3TextFieldWidget()
-            ),
         }
 
     @staticmethod
@@ -182,8 +180,6 @@ class AzureDataFactoryHook(BaseHook):
             "relabeling": {
                 "login": "Client ID",
                 "password": "Secret",
-                "managed_identity_client_id": "Managed Identity Client ID",
-                "workload_identity_tenant_id": "Workload Identity Tenant ID",
             },
         }
 
diff --git a/airflow/providers/microsoft/azure/hooks/data_lake.py b/airflow/providers/microsoft/azure/hooks/data_lake.py
index 0b344a41ad..557c70fe99 100644
--- a/airflow/providers/microsoft/azure/hooks/data_lake.py
+++ b/airflow/providers/microsoft/azure/hooks/data_lake.py
@@ -34,7 +34,11 @@ from azure.storage.filedatalake import (
 
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
-from airflow.providers.microsoft.azure.utils import AzureIdentityCredentialAdapter, get_field
+from airflow.providers.microsoft.azure.utils import (
+    AzureIdentityCredentialAdapter,
+    add_managed_identity_connection_widgets,
+    get_field,
+)
 
 Credentials = Union[ClientSecretCredential, AzureIdentityCredentialAdapter]
 
@@ -62,6 +66,7 @@ class AzureDataLakeHook(BaseHook):
     hook_name = "Azure Data Lake"
 
     @staticmethod
+    @add_managed_identity_connection_widgets
     def get_connection_form_widgets() -> dict[str, Any]:
         """Returns connection widgets to add to connection form."""
         from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
@@ -265,6 +270,7 @@ class AzureDataLakeStorageV2Hook(BaseHook):
     hook_name = "Azure Date Lake Storage V2"
 
     @classmethod
+    @add_managed_identity_connection_widgets
     def get_connection_form_widgets(cls) -> dict[str, Any]:
         """Returns connection widgets to add to connection form."""
         from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget
@@ -305,6 +311,17 @@ class AzureDataLakeStorageV2Hook(BaseHook):
         self.conn_id = adls_conn_id
         self.public_read = public_read
 
+    def _get_field(self, extra_dict, field_name):
+        prefix = "extra__adls__"
+        if field_name.startswith("extra__"):
+            raise ValueError(
+                f"Got prefixed name {field_name}; please remove the '{prefix}' prefix "
+                f"when using this method."
+            )
+        if field_name in extra_dict:
+            return extra_dict[field_name] or None
+        return extra_dict.get(f"{prefix}{field_name}") or None
+
     @cached_property
     def service_client(self) -> DataLakeServiceClient:
         """Return the DataLakeServiceClient object (cached)."""
@@ -338,17 +355,6 @@ class AzureDataLakeStorageV2Hook(BaseHook):
             **extra,
         )
 
-    def _get_field(self, extra_dict, field_name):
-        prefix = "extra__adls__"
-        if field_name.startswith("extra__"):
-            raise ValueError(
-                f"Got prefixed name {field_name}; please remove the '{prefix}' prefix "
-                f"when using this method."
-            )
-        if field_name in extra_dict:
-            return extra_dict[field_name] or None
-        return extra_dict.get(f"{prefix}{field_name}") or None
-
     def create_file_system(self, file_system_name: str) -> None:
         """Create a new file system under the specified account.
 
diff --git a/airflow/providers/microsoft/azure/hooks/fileshare.py b/airflow/providers/microsoft/azure/hooks/fileshare.py
index 16c7218505..92235e7ecf 100644
--- a/airflow/providers/microsoft/azure/hooks/fileshare.py
+++ b/airflow/providers/microsoft/azure/hooks/fileshare.py
@@ -22,7 +22,10 @@ from typing import IO, Any
 from azure.storage.fileshare import FileProperties, ShareDirectoryClient, ShareFileClient, ShareServiceClient
 
 from airflow.hooks.base import BaseHook
-from airflow.providers.microsoft.azure.utils import get_default_azure_credential
+from airflow.providers.microsoft.azure.utils import (
+    add_managed_identity_connection_widgets,
+    get_default_azure_credential,
+)
 
 
 class AzureFileShareHook(BaseHook):
@@ -39,24 +42,8 @@ class AzureFileShareHook(BaseHook):
     conn_type = "azure_fileshare"
     hook_name = "Azure FileShare"
 
-    def __init__(
-        self,
-        share_name: str | None = None,
-        file_path: str | None = None,
-        directory_path: str | None = None,
-        azure_fileshare_conn_id: str = "azure_fileshare_default",
-    ) -> None:
-        super().__init__()
-        self._conn_id = azure_fileshare_conn_id
-        self.share_name = share_name
-        self.file_path = file_path
-        self.directory_path = directory_path
-        self._account_url: str | None = None
-        self._connection_string: str | None = None
-        self._account_access_key: str | None = None
-        self._sas_token: str | None = None
-
     @staticmethod
+    @add_managed_identity_connection_widgets
     def get_connection_form_widgets() -> dict[str, Any]:
         """Returns connection widgets to add to connection form."""
         from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget
@@ -68,12 +55,6 @@ class AzureFileShareHook(BaseHook):
             "connection_string": StringField(
                 lazy_gettext("Connection String (optional)"), widget=BS3TextFieldWidget()
             ),
-            "managed_identity_client_id": StringField(
-                lazy_gettext("Managed Identity Client ID"), widget=BS3TextFieldWidget()
-            ),
-            "workload_identity_tenant_id": StringField(
-                lazy_gettext("Workload Identity Tenant ID"), widget=BS3TextFieldWidget()
-            ),
         }
 
     @staticmethod
@@ -90,11 +71,26 @@ class AzureFileShareHook(BaseHook):
                 "password": "secret",
                 "sas_token": "account url or token (optional)",
                 "connection_string": "account url or token (optional)",
-                "managed_identity_client_id": "Managed Identity Client ID",
-                "workload_identity_tenant_id": "Workload Identity Tenant ID",
             },
         }
 
+    def __init__(
+        self,
+        share_name: str | None = None,
+        file_path: str | None = None,
+        directory_path: str | None = None,
+        azure_fileshare_conn_id: str = "azure_fileshare_default",
+    ) -> None:
+        super().__init__()
+        self._conn_id = azure_fileshare_conn_id
+        self.share_name = share_name
+        self.file_path = file_path
+        self.directory_path = directory_path
+        self._account_url: str | None = None
+        self._connection_string: str | None = None
+        self._account_access_key: str | None = None
+        self._sas_token: str | None = None
+
     def get_conn(self) -> None:
         conn = self.get_connection(self._conn_id)
         extras = conn.extra_dejson
@@ -110,6 +106,13 @@ class AzureFileShareHook(BaseHook):
             return f"https://{account_url}.file.core.windows.net"
         return account_url
 
+    def _get_default_azure_credential(self):
+        conn = self.get_connection(self._conn_id)
+        extras = conn.extra_dejson
+        managed_identity_client_id = extras.get("managed_identity_client_id")
+        workload_identity_tenant_id = extras.get("workload_identity_tenant_id")
+        return get_default_azure_credential(managed_identity_client_id, workload_identity_tenant_id)
+
     @property
     def share_service_client(self):
         self.get_conn()
@@ -121,15 +124,9 @@ class AzureFileShareHook(BaseHook):
             credential = self._sas_token or self._account_access_key
             return ShareServiceClient(account_url=self._account_url, credential=credential)
         else:
-            conn = self.get_connection(self._conn_id)
-            extras = conn.extra_dejson
-            managed_identity_client_id = extras.get("managed_identity_client_id")
-            workload_identity_tenant_id = extras.get("workload_identity_tenant_id")
             return ShareServiceClient(
                 account_url=self._account_url,
-                credential=get_default_azure_credential(
-                    managed_identity_client_id, workload_identity_tenant_id
-                ),
+                credential=self._get_default_azure_credential(),
                 token_intent="backup",
             )
 
@@ -150,17 +147,11 @@ class AzureFileShareHook(BaseHook):
                 credential=credential,
             )
         else:
-            conn = self.get_connection(self._conn_id)
-            extras = conn.extra_dejson
-            managed_identity_client_id = extras.get("managed_identity_client_id")
-            workload_identity_tenant_id = extras.get("workload_identity_tenant_id")
             return ShareDirectoryClient(
                 account_url=self._account_url,
                 share_name=self.share_name,
                 directory_path=self.directory_path,
-                credential=get_default_azure_credential(
-                    managed_identity_client_id, workload_identity_tenant_id
-                ),
+                credential=self._get_default_azure_credential(),
                 token_intent="backup",
             )
 
@@ -181,17 +172,11 @@ class AzureFileShareHook(BaseHook):
                 credential=credential,
             )
         else:
-            conn = self.get_connection(self._conn_id)
-            extras = conn.extra_dejson
-            managed_identity_client_id = extras.get("managed_identity_client_id")
-            workload_identity_tenant_id = extras.get("workload_identity_tenant_id")
             return ShareFileClient(
                 account_url=self._account_url,
                 share_name=self.share_name,
                 file_path=self.file_path,
-                credential=get_default_azure_credential(
-                    managed_identity_client_id, workload_identity_tenant_id
-                ),
+                credential=self._get_default_azure_credential(),
                 token_intent="backup",
             )
 
diff --git a/airflow/providers/microsoft/azure/hooks/synapse.py b/airflow/providers/microsoft/azure/hooks/synapse.py
index eb275781bb..a4f0f5d7c3 100644
--- a/airflow/providers/microsoft/azure/hooks/synapse.py
+++ b/airflow/providers/microsoft/azure/hooks/synapse.py
@@ -24,7 +24,11 @@ from azure.synapse.spark import SparkClient
 
 from airflow.exceptions import AirflowTaskTimeout
 from airflow.hooks.base import BaseHook
-from airflow.providers.microsoft.azure.utils import get_default_azure_credential, get_field
+from airflow.providers.microsoft.azure.utils import (
+    add_managed_identity_connection_widgets,
+    get_default_azure_credential,
+    get_field,
+)
 
 if TYPE_CHECKING:
     from azure.synapse.spark.models import SparkBatchJobOptions
@@ -63,6 +67,7 @@ class AzureSynapseHook(BaseHook):
     hook_name: str = "Azure Synapse"
 
     @staticmethod
+    @add_managed_identity_connection_widgets
     def get_connection_form_widgets() -> dict[str, Any]:
         """Returns connection widgets to add to connection form."""
         from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
@@ -72,12 +77,6 @@ class AzureSynapseHook(BaseHook):
         return {
             "tenantId": StringField(lazy_gettext("Tenant ID"), 
widget=BS3TextFieldWidget()),
             "subscriptionId": StringField(lazy_gettext("Subscription ID"), 
widget=BS3TextFieldWidget()),
-            "managed_identity_client_id": StringField(
-                lazy_gettext("Managed Identity Client ID"), 
widget=BS3TextFieldWidget()
-            ),
-            "workload_identity_tenant_id": StringField(
-                lazy_gettext("Workload Identity Tenant ID"), 
widget=BS3TextFieldWidget()
-            ),
         }
 
     @staticmethod
@@ -89,8 +88,6 @@ class AzureSynapseHook(BaseHook):
                 "login": "Client ID",
                 "password": "Secret",
                 "host": "Synapse Workspace URL",
-                "managed_identity_client_id": "Managed Identity Client ID",
-                "workload_identity_tenant_id": "Workload Identity Tenant ID",
             },
         }
 
diff --git a/airflow/providers/microsoft/azure/hooks/wasb.py 
b/airflow/providers/microsoft/azure/hooks/wasb.py
index b69e999c5b..8c7f1636a1 100644
--- a/airflow/providers/microsoft/azure/hooks/wasb.py
+++ b/airflow/providers/microsoft/azure/hooks/wasb.py
@@ -47,7 +47,10 @@ from azure.storage.blob.aio import (
 
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
-from airflow.providers.microsoft.azure.utils import 
get_default_azure_credential
+from airflow.providers.microsoft.azure.utils import (
+    add_managed_identity_connection_widgets,
+    get_default_azure_credential,
+)
 
 if TYPE_CHECKING:
     from azure.storage.blob._models import BlobProperties
@@ -78,6 +81,7 @@ class WasbHook(BaseHook):
     hook_name = "Azure Blob Storage"
 
     @staticmethod
+    @add_managed_identity_connection_widgets
     def get_connection_form_widgets() -> dict[str, Any]:
         """Returns connection widgets to add to connection form."""
         from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, 
BS3TextFieldWidget
@@ -95,12 +99,6 @@ class WasbHook(BaseHook):
                 lazy_gettext("Tenant Id (Active Directory Auth)"), 
widget=BS3TextFieldWidget()
             ),
             "sas_token": PasswordField(lazy_gettext("SAS Token (optional)"), 
widget=BS3PasswordFieldWidget()),
-            "managed_identity_client_id": StringField(
-                lazy_gettext("Managed Identity Client ID"), 
widget=BS3TextFieldWidget()
-            ),
-            "workload_identity_tenant_id": StringField(
-                lazy_gettext("Workload Identity Tenant ID"), 
widget=BS3TextFieldWidget()
-            ),
         }
 
     @staticmethod
@@ -122,8 +120,6 @@ class WasbHook(BaseHook):
                 "shared_access_key": "shared access key",
                 "sas_token": "account url or token",
                 "extra": "additional options for use with 
ClientSecretCredential or DefaultAzureCredential",
-                "managed_identity_client_id": "Managed Identity Client ID",
-                "workload_identity_tenant_id": "Workload Identity Tenant ID",
             },
         }
 
diff --git a/airflow/providers/microsoft/azure/utils.py 
b/airflow/providers/microsoft/azure/utils.py
index f63580e958..c41a4e758c 100644
--- a/airflow/providers/microsoft/azure/utils.py
+++ b/airflow/providers/microsoft/azure/utils.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 import warnings
+from functools import wraps
 
 from azure.core.pipeline import PipelineContext, PipelineRequest
 from azure.core.pipeline.policies import BearerTokenCredentialPolicy
@@ -69,6 +70,29 @@ def get_default_azure_credential(
         return DefaultAzureCredential()
 
 
+def add_managed_identity_connection_widgets(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
+        from flask_babel import lazy_gettext
+        from wtforms import StringField
+
+        widgets = func(*args, **kwargs)
+        widgets.update(
+            {
+                "managed_identity_client_id": StringField(
+                    lazy_gettext("Managed Identity Client ID"), 
widget=BS3TextFieldWidget()
+                ),
+                "workload_identity_tenant_id": StringField(
+                    lazy_gettext("Workload Identity Tenant ID"), 
widget=BS3TextFieldWidget()
+                ),
+            }
+        )
+        return widgets
+
+    return wrapper
+
+
 class AzureIdentityCredentialAdapter(BasicTokenAuthentication):
     """Adapt azure-identity credentials for backward compatibility.
 
diff --git a/tests/providers/microsoft/azure/hooks/test_adx.py 
b/tests/providers/microsoft/azure/hooks/test_adx.py
index 39b408af2c..c8de2d933f 100644
--- a/tests/providers/microsoft/azure/hooks/test_adx.py
+++ b/tests/providers/microsoft/azure/hooks/test_adx.py
@@ -34,65 +34,46 @@ pytestmark = pytest.mark.db_test
 
 class TestAzureDataExplorerHook:
     @pytest.mark.parametrize(
-        "mocked_connection",
-        [
-            Connection(
-                conn_id="missing_method",
-                conn_type="azure_data_explorer",
-                login="client_id",
-                password="client secret",
-                host="https://help.kusto.windows.net";,
-                extra={},
-            )
-        ],
-        indirect=True,
-    )
-    def test_conn_missing_method(self, mocked_connection):
-        hook = 
AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id)
-        error_pattern = "is missing: `auth_method`"
-        with pytest.raises(AirflowException, match=error_pattern):
-            assert hook.get_conn()
-        with pytest.raises(AirflowException, match=error_pattern):
-            assert hook.connection
-
-    @pytest.mark.parametrize(
-        "mocked_connection",
+        "mocked_connection, error_pattern",
         [
-            Connection(
-                conn_id="unknown_method",
-                conn_type="azure_data_explorer",
-                login="client_id",
-                password="client secret",
-                host="https://help.kusto.windows.net";,
-                extra={"auth_method": "AAD_OTHER"},
+            (
+                Connection(
+                    conn_id="missing_method",
+                    conn_type="azure_data_explorer",
+                    login="client_id",
+                    password="client secret",
+                    host="https://help.kusto.windows.net";,
+                    extra={},
+                ),
+                "is missing: `auth_method`",
             ),
-        ],
-        indirect=True,
-    )
-    def test_conn_unknown_method(self, mocked_connection):
-        hook = 
AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id)
-        error_pattern = "Unknown authentication method: AAD_OTHER"
-        with pytest.raises(AirflowException, match=error_pattern):
-            assert hook.get_conn()
-        with pytest.raises(AirflowException, match=error_pattern):
-            assert hook.connection
-
-    @pytest.mark.parametrize(
-        "mocked_connection",
-        [
-            Connection(
-                conn_id="missing_cluster",
-                conn_type="azure_data_explorer",
-                login="client_id",
-                password="client secret",
-                extra={},
+            (
+                Connection(
+                    conn_id="unknown_method",
+                    conn_type="azure_data_explorer",
+                    login="client_id",
+                    password="client secret",
+                    host="https://help.kusto.windows.net";,
+                    extra={"auth_method": "AAD_OTHER"},
+                ),
+                "Unknown authentication method: AAD_OTHER",
+            ),
+            (
+                Connection(
+                    conn_id="missing_cluster",
+                    conn_type="azure_data_explorer",
+                    login="client_id",
+                    password="client secret",
+                    extra={},
+                ),
+                "Host connection option is required",
             ),
         ],
-        indirect=True,
+        indirect=["mocked_connection"],
+        ids=["missing_method", "unknown_method", "missing_cluster"],
     )
-    def test_conn_missing_cluster(self, mocked_connection):
+    def test_conn_errors(self, mocked_connection, error_pattern):
         hook = 
AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id)
-        error_pattern = "Host connection option is required"
         with pytest.raises(AirflowException, match=error_pattern):
             assert hook.get_conn()
         with pytest.raises(AirflowException, match=error_pattern):
@@ -179,7 +160,7 @@ class TestAzureDataExplorerHook:
                 },
             )
         ],
-        indirect=True,
+        indirect=["mocked_connection"],
     )
     @mock.patch.object(KustoClient, "__init__")
     def test_conn_method_aad_app(self, mock_init, mocked_connection):
@@ -238,6 +219,35 @@ class TestAzureDataExplorerHook:
             
KustoConnectionStringBuilder.with_aad_device_authentication("https://help.kusto.windows.net";)
         )
 
+    @pytest.mark.parametrize(
+        "mocked_connection",
+        [
+            Connection(
+                conn_id=ADX_TEST_CONN_ID,
+                conn_type="azure_data_explorer",
+                host="https://help.kusto.windows.net";,
+                extra={
+                    "auth_method": "AZURE_TOKEN_CRED",
+                    "managed_identity_client_id": "test_id",
+                    "workload_identity_tenant_id": "test_tenant_id",
+                },
+            )
+        ],
+        indirect=True,
+    )
+    
@mock.patch("airflow.providers.microsoft.azure.hooks.adx.get_default_azure_credential")
+    @mock.patch.object(KustoClient, "__init__")
+    def test_conn_method_azure_token_cred(self, mock_init, 
mock_default_azure_credential, mocked_connection):
+        mock_init.return_value = None
+        
AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id).get_conn()
+        assert mock_default_azure_credential.called_with("test_id", 
"test_tenant_id")
+        assert mock_init.called_with(
+            KustoConnectionStringBuilder.with_azure_token_credential(
+                connection_string="https://help.kusto.windows.net";,
+                credential=mock_default_azure_credential,
+            )
+        )
+
     @pytest.mark.parametrize(
         "mocked_connection",
         [
diff --git a/tests/providers/microsoft/azure/hooks/test_asb.py 
b/tests/providers/microsoft/azure/hooks/test_asb.py
index 4f1939aa8b..195017bfd3 100644
--- a/tests/providers/microsoft/azure/hooks/test_asb.py
+++ b/tests/providers/microsoft/azure/hooks/test_asb.py
@@ -31,6 +31,7 @@ from airflow.providers.microsoft.azure.hooks.asb import 
AdminClientHook, Message
 
 MESSAGE = "Test Message"
 MESSAGE_LIST = [f"{MESSAGE} {n}" for n in range(10)]
+MODULE = "airflow.providers.microsoft.azure.hooks.asb"
 
 
 class TestAdminClientHook:
@@ -60,18 +61,18 @@ class TestAdminClientHook:
         hook = AdminClientHook(azure_service_bus_conn_id=self.conn_id)
         assert isinstance(hook.get_conn(), ServiceBusAdministrationClient)
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.get_default_azure_credential")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.AdminClientHook.get_connection")
+    @mock.patch(f"{MODULE}.get_default_azure_credential")
+    @mock.patch(f"{MODULE}.AdminClientHook.get_connection")
     def 
test_get_conn_fallback_to_default_azure_credential_when_schema_is_not_provided(
         self, mock_connection, mock_default_azure_credential
     ):
         mock_connection.return_value = self.mock_conn_without_schema
         hook = AdminClientHook(azure_service_bus_conn_id=self.conn_id)
         assert isinstance(hook.get_conn(), ServiceBusAdministrationClient)
-        mock_default_azure_credential.assert_called_once()
+        mock_default_azure_credential.assert_called_with(None, None)
 
     @mock.patch("azure.servicebus.management.QueueProperties")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.AdminClientHook.get_conn")
+    @mock.patch(f"{MODULE}.AdminClientHook.get_conn")
     def test_create_queue(self, mock_sb_admin_client, mock_queue_properties):
         """
         Test `create_queue` hook function with mocking connection, queue 
properties value and
@@ -85,14 +86,14 @@ class TestAdminClientHook:
         response = hook.create_queue(self.queue_name)
         assert response == mock_queue_properties
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.ServiceBusAdministrationClient")
+    @mock.patch(f"{MODULE}.ServiceBusAdministrationClient")
     def test_create_queue_exception(self, mock_sb_admin_client):
         """Test `create_queue` functionality to raise ValueError by passing 
queue name as None"""
         hook = AdminClientHook(azure_service_bus_conn_id=self.conn_id)
         with pytest.raises(TypeError):
             hook.create_queue(None)
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.AdminClientHook.get_conn")
+    @mock.patch(f"{MODULE}.AdminClientHook.get_conn")
     def test_delete_queue(self, mock_sb_admin_client):
         """
         Test Delete queue functionality by passing queue name, assert the 
function with values,
@@ -103,14 +104,14 @@ class TestAdminClientHook:
         expected_calls = 
[mock.call().__enter__().delete_queue(self.queue_name)]
         mock_sb_admin_client.assert_has_calls(expected_calls)
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.ServiceBusAdministrationClient")
+    @mock.patch(f"{MODULE}.ServiceBusAdministrationClient")
     def test_delete_queue_exception(self, mock_sb_admin_client):
         """Test `delete_queue` functionality to raise ValueError, by passing 
queue name as None"""
         hook = AdminClientHook(azure_service_bus_conn_id=self.conn_id)
         with pytest.raises(TypeError):
             hook.delete_queue(None)
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.AdminClientHook.get_conn")
+    @mock.patch(f"{MODULE}.AdminClientHook.get_conn")
     def test_delete_subscription(self, mock_sb_admin_client):
         """
         Test Delete subscription functionality by passing subscription name 
and topic name,
@@ -127,7 +128,7 @@ class TestAdminClientHook:
         "mock_subscription_name, mock_topic_name",
         [("subscription_1", None), (None, "topic_1")],
     )
-    @mock.patch("airflow.providers.microsoft.azure.hooks.asb.AdminClientHook")
+    @mock.patch(f"{MODULE}.AdminClientHook")
     def test_delete_subscription_exception(
         self, mock_sb_admin_client, mock_subscription_name, mock_topic_name
     ):
@@ -171,15 +172,15 @@ class TestMessageHook:
         hook = MessageHook(azure_service_bus_conn_id=self.conn_id)
         assert isinstance(hook.get_conn(), ServiceBusClient)
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.get_default_azure_credential")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_connection")
+    @mock.patch(f"{MODULE}.get_default_azure_credential")
+    @mock.patch(f"{MODULE}.MessageHook.get_connection")
     def 
test_get_conn_fallback_to_default_azure_credential_when_schema_is_not_provided(
         self, mock_connection, mock_default_azure_credential
     ):
         mock_connection.return_value = self.mock_conn_without_schema
         hook = MessageHook(azure_service_bus_conn_id=self.conn_id)
         assert isinstance(hook.get_conn(), ServiceBusClient)
-        mock_default_azure_credential.assert_called_once()
+        mock_default_azure_credential.assert_called_with(None, None)
 
     @pytest.mark.parametrize(
         "mock_message, mock_batch_flag",
@@ -190,9 +191,9 @@ class TestMessageHook:
             (MESSAGE_LIST, False),
         ],
     )
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.send_list_messages")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.send_batch_message")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_conn")
+    @mock.patch(f"{MODULE}.MessageHook.send_list_messages")
+    @mock.patch(f"{MODULE}.MessageHook.send_batch_message")
+    @mock.patch(f"{MODULE}.MessageHook.get_conn")
     def test_send_message(
         self, mock_sb_client, mock_batch_message, mock_list_message, 
mock_message, mock_batch_flag
     ):
@@ -225,7 +226,7 @@ class TestMessageHook:
         ]
         mock_sb_client.assert_has_calls(expected_calls, any_order=False)
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_conn")
+    @mock.patch(f"{MODULE}.MessageHook.get_conn")
     def test_send_message_exception(self, mock_sb_client):
         """
         Test `send_message` functionality to raise AirflowException in Azure 
MessageHook
@@ -236,7 +237,7 @@ class TestMessageHook:
             hook.send_message(queue_name=None, messages="", 
batch_message_flag=False)
 
     @mock.patch("azure.servicebus.ServiceBusMessage")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_conn")
+    @mock.patch(f"{MODULE}.MessageHook.get_conn")
     def test_receive_message(self, mock_sb_client, mock_service_bus_message):
         """
         Test `receive_message` hook function and assert the function with mock 
value,
@@ -260,7 +261,7 @@ class TestMessageHook:
         ]
         mock_sb_client.assert_has_calls(expected_calls)
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_conn")
+    @mock.patch(f"{MODULE}.MessageHook.get_conn")
     def test_receive_message_exception(self, mock_sb_client):
         """
         Test `receive_message` functionality to raise AirflowException in 
Azure MessageHook
@@ -270,7 +271,7 @@ class TestMessageHook:
         with pytest.raises(TypeError):
             hook.receive_message(None)
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_conn")
+    @mock.patch(f"{MODULE}.MessageHook.get_conn")
     def test_receive_subscription_message(self, mock_sb_client):
         """
         Test `receive_subscription_message` hook function and assert the 
function with mock value,
@@ -299,7 +300,7 @@ class TestMessageHook:
         "mock_subscription_name, mock_topic_name, mock_max_count, 
mock_wait_time",
         [("subscription_1", None, None, None), (None, "topic_1", None, None)],
     )
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_conn")
+    @mock.patch(f"{MODULE}.MessageHook.get_conn")
     def test_receive_subscription_message_exception(
         self, mock_sb_client, mock_subscription_name, mock_topic_name, 
mock_max_count, mock_wait_time
     ):
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_batch.py 
b/tests/providers/microsoft/azure/hooks/test_azure_batch.py
index 057342edfe..d9daa2f13d 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_batch.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_batch.py
@@ -134,7 +134,7 @@ class TestAzureBatchHook:
             )
             assert isinstance(pool, batch_models.PoolAddParameter)
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient")
+    @mock.patch(f"{MODULE}.BatchServiceClient")
     def test_create_pool_with_vm_config(self, mock_batch):
         hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id)
         mock_instance = mock_batch.return_value.pool.add
@@ -150,7 +150,7 @@ class TestAzureBatchHook:
         hook.create_pool(pool=pool)
         mock_instance.assert_called_once_with(pool)
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient")
+    @mock.patch(f"{MODULE}.BatchServiceClient")
     def test_create_pool_with_cloud_config(self, mock_batch):
         hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id)
         mock_instance = mock_batch.return_value.pool.add
@@ -166,12 +166,12 @@ class TestAzureBatchHook:
         hook.create_pool(pool=pool)
         mock_instance.assert_called_once_with(pool)
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient")
+    @mock.patch(f"{MODULE}.BatchServiceClient")
     def test_wait_for_all_nodes(self, mock_batch):
         # TODO: Add test
         pass
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient")
+    @mock.patch(f"{MODULE}.BatchServiceClient")
     def test_job_configuration_and_create_job(self, mock_batch):
         hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id)
         mock_instance = mock_batch.return_value.job.add
@@ -180,7 +180,7 @@ class TestAzureBatchHook:
         assert isinstance(job, batch_models.JobAddParameter)
         mock_instance.assert_called_once_with(job)
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient")
+    @mock.patch(f"{MODULE}.BatchServiceClient")
     def test_add_single_task_to_job(self, mock_batch):
         hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id)
         mock_instance = mock_batch.return_value.task.add
@@ -189,12 +189,12 @@ class TestAzureBatchHook:
         assert isinstance(task, batch_models.TaskAddParameter)
         mock_instance.assert_called_once_with(job_id="myjob", task=task)
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient")
+    @mock.patch(f"{MODULE}.BatchServiceClient")
     def test_wait_for_all_task_to_complete(self, mock_batch):
         # TODO: Add test
         pass
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient")
+    @mock.patch(f"{MODULE}.BatchServiceClient")
     def test_connection_success(self, mock_batch):
         hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id)
         hook.connection.job.return_value = {}
@@ -202,7 +202,7 @@ class TestAzureBatchHook:
         assert status is True
         assert msg == "Successfully connected to Azure Batch."
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient")
+    @mock.patch(f"{MODULE}.BatchServiceClient")
     def test_connection_failure(self, mock_batch):
         hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id)
         hook.connection.job.list = 
PropertyMock(side_effect=Exception("Authentication failed."))
diff --git 
a/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py 
b/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py
index 3122ba2b9e..bfe8a57484 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py
@@ -148,7 +148,7 @@ class TestAzureContainerInstanceHookWithoutSetupCredential:
         hook = 
AzureContainerInstanceHook(azure_conn_id=connection_without_login_password_tenant_id.conn_id)
         conn = hook.get_conn()
 
-        mock_default_azure_credential.assert_called_once()
+        mock_default_azure_credential.assert_called_with(None, None)
         assert not mock_service_pricipal_credential.called
         assert conn == mock_client_instance
         mock_client_cls.assert_called_once_with(
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py 
b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
index 2508cb9dd8..bbb7f7f870 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
@@ -29,6 +29,8 @@ from airflow.exceptions import AirflowException
 from airflow.models import Connection
 from airflow.providers.microsoft.azure.hooks.cosmos import AzureCosmosDBHook
 
+MODULE = "airflow.providers.microsoft.azure.hooks.cosmos"
+
 
 class TestAzureCosmosDbHook:
     # Set up an environment to test with
@@ -54,12 +56,40 @@ class TestAzureCosmosDbHook:
             )
         )
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient", 
autospec=True)
+    @pytest.mark.parametrize(
+        "mocked_connection",
+        [
+            Connection(
+                conn_id="azure_cosmos_test_default_credential",
+                conn_type="azure_cosmos",
+                login="https://test_endpoint:443";,
+                extra={
+                    "resource_group_name": "resource-group-name",
+                    "subscription_id": "subscription_id",
+                    "managed_identity_client_id": "test_client_id",
+                    "workload_identity_tenant_id": "test_tenant_id",
+                },
+            )
+        ],
+        indirect=True,
+    )
+    @mock.patch(f"{MODULE}.get_default_azure_credential")
+    @mock.patch(f"{MODULE}.CosmosDBManagementClient")
+    @mock.patch(f"{MODULE}.CosmosClient")
+    def test_get_conn(self, mock_cosmos, mock_cosmos_db, 
mock_default_azure_credential, mocked_connection):
+        
mock_cosmos_db.return_value.database_accounts.list_keys.return_value.primary_master_key
 = "master-key"
+
+        hook = 
AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_default_credential")
+        hook.get_conn()
+
+        assert mock_default_azure_credential.called_with("test_client_id", 
"test_tenant_id")
+
+    @mock.patch(f"{MODULE}.CosmosClient", autospec=True)
     def test_client(self, mock_cosmos):
         hook = 
AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
         assert isinstance(hook.get_conn(), CosmosClient)
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
+    @mock.patch(f"{MODULE}.CosmosClient")
     def test_create_database(self, mock_cosmos):
         hook = 
AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
         hook.create_database(self.test_database_name)
@@ -67,19 +97,19 @@ class TestAzureCosmosDbHook:
         mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": 
self.test_master_key})
         mock_cosmos.assert_has_calls(expected_calls)
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
+    @mock.patch(f"{MODULE}.CosmosClient")
     def test_create_database_exception(self, mock_cosmos):
         hook = 
AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
         with pytest.raises(AirflowException):
             hook.create_database(None)
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
+    @mock.patch(f"{MODULE}.CosmosClient")
     def test_create_container_exception(self, mock_cosmos):
         hook = 
AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
         with pytest.raises(AirflowException):
             hook.create_collection(None)
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
+    @mock.patch(f"{MODULE}.CosmosClient")
     def test_create_container(self, mock_cosmos):
         hook = 
AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
         hook.create_collection(self.test_collection_name, 
self.test_database_name)
@@ -91,7 +121,7 @@ class TestAzureCosmosDbHook:
         mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": 
self.test_master_key})
         mock_cosmos.assert_has_calls(expected_calls)
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
+    @mock.patch(f"{MODULE}.CosmosClient")
     def test_create_container_default(self, mock_cosmos):
         hook = 
AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
         hook.create_collection(self.test_collection_name)
@@ -103,7 +133,7 @@ class TestAzureCosmosDbHook:
         mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": 
self.test_master_key})
         mock_cosmos.assert_has_calls(expected_calls)
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
+    @mock.patch(f"{MODULE}.CosmosClient")
     def test_upsert_document_default(self, mock_cosmos):
         test_id = str(uuid.uuid4())
 
@@ -124,7 +154,7 @@ class TestAzureCosmosDbHook:
         logging.getLogger().info(returned_item)
         assert returned_item["id"] == test_id
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
+    @mock.patch(f"{MODULE}.CosmosClient")
     def test_upsert_document(self, mock_cosmos):
         test_id = str(uuid.uuid4())
 
@@ -152,7 +182,7 @@ class TestAzureCosmosDbHook:
         logging.getLogger().info(returned_item)
         assert returned_item["id"] == test_id
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
+    @mock.patch(f"{MODULE}.CosmosClient")
     def test_insert_documents(self, mock_cosmos):
         test_id1 = str(uuid.uuid4())
         test_id2 = str(uuid.uuid4())
@@ -183,7 +213,7 @@ class TestAzureCosmosDbHook:
         mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": 
self.test_master_key})
         mock_cosmos.assert_has_calls(expected_calls, any_order=True)
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
+    @mock.patch(f"{MODULE}.CosmosClient")
     def test_delete_database(self, mock_cosmos):
         hook = 
AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
         hook.delete_database(self.test_database_name)
@@ -191,7 +221,7 @@ class TestAzureCosmosDbHook:
         mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": 
self.test_master_key})
         mock_cosmos.assert_has_calls(expected_calls)
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
+    @mock.patch(f"{MODULE}.CosmosClient")
     def test_delete_database_exception(self, mock_cosmos):
         hook = 
AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
         with pytest.raises(AirflowException):
@@ -203,7 +233,7 @@ class TestAzureCosmosDbHook:
         with pytest.raises(AirflowException):
             hook.delete_collection(None)
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
+    @mock.patch(f"{MODULE}.CosmosClient")
     def test_delete_container(self, mock_cosmos):
         hook = 
AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
         hook.delete_collection(self.test_collection_name, 
self.test_database_name)
@@ -213,7 +243,7 @@ class TestAzureCosmosDbHook:
         mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": 
self.test_master_key})
         mock_cosmos.assert_has_calls(expected_calls)
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
+    @mock.patch(f"{MODULE}.CosmosClient")
     def test_delete_container_default(self, mock_cosmos):
         hook = 
AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
         hook.delete_collection(self.test_collection_name)
@@ -223,7 +253,7 @@ class TestAzureCosmosDbHook:
         mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": 
self.test_master_key})
         mock_cosmos.assert_has_calls(expected_calls)
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
+    @mock.patch(f"{MODULE}.CosmosClient")
     def test_connection_success(self, mock_cosmos):
         hook = 
AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
         hook.get_conn().list_databases.return_value = {"id": 
self.test_database_name}
@@ -231,7 +261,7 @@ class TestAzureCosmosDbHook:
         assert status is True
         assert msg == "Successfully connected to Azure Cosmos."
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
+    @mock.patch(f"{MODULE}.CosmosClient")
     def test_connection_failure(self, mock_cosmos):
         hook = 
AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
         hook.get_conn().list_databases = 
PropertyMock(side_effect=Exception("Authentication failed."))
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py 
b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
index 34d0bcb481..14a7121202 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
@@ -21,7 +21,6 @@ from unittest import mock
 from unittest.mock import MagicMock, PropertyMock, patch
 
 import pytest
-from azure.identity import ClientSecretCredential, DefaultAzureCredential
 from azure.mgmt.datafactory.aio import DataFactoryManagementClient
 from azure.mgmt.datafactory.models import FactoryListResponse
 
@@ -55,6 +54,8 @@ MODEL = object()
 NAME = "testName"
 ID = "testId"
 
+MODULE = "airflow.providers.microsoft.azure.hooks.data_factory"
+
 
 @pytest.fixture(autouse=True)
 def setup_connections(create_mock_connections):
@@ -164,23 +165,30 @@ def test_provide_targeted_factory():
         provide_targeted_factory(echo)(hook)
 
 
-@pytest.mark.parametrize(
-    ("connection_id", "credential_type"),
-    [
-        (DEFAULT_CONNECTION_CLIENT_SECRET, ClientSecretCredential),
-        (DEFAULT_CONNECTION_DEFAULT_CREDENTIAL, DefaultAzureCredential),
-    ],
-)
-def test_get_connection_by_credential_client_secret(connection_id: str, credential_type: type):
-    hook = AzureDataFactoryHook(connection_id)
+@mock.patch(f"{MODULE}.ClientSecretCredential")
+def test_get_conn_by_credential_client_secret(mock_credential):
+    hook = AzureDataFactoryHook(DEFAULT_CONNECTION_CLIENT_SECRET)
 
     with patch.object(hook, "_create_client") as mock_create_client:
         mock_create_client.return_value = MagicMock()
+
+        connection = hook.get_conn()
+        assert connection is not None
+
+        mock_create_client.assert_called_with(mock_credential(), "subscriptionId")
+
+
+@mock.patch(f"{MODULE}.get_default_azure_credential")
+def test_get_conn_by_default_azure_credential(mock_credential):
+    hook = AzureDataFactoryHook(DEFAULT_CONNECTION_DEFAULT_CREDENTIAL)
+
+    with patch.object(hook, "_create_client") as mock_create_client:
+        mock_create_client.return_value = MagicMock()
+
         connection = hook.get_conn()
         assert connection is not None
-        mock_create_client.assert_called_once()
-        assert isinstance(mock_create_client.call_args.args[0], credential_type)
-        assert mock_create_client.call_args.args[1] == "subscriptionId"
+        assert mock_credential.called_with(None, None)
+        mock_create_client.assert_called_with(mock_credential(), "subscriptionId")
 
 
 def test_get_factory(hook: AzureDataFactoryHook):
@@ -520,7 +528,7 @@ def test_connection_failure_missing_tenant_id():
         pytest.param("a://?resource_group_name=abc&factory_name=abc", id="no-prefix"),
     ],
 )
-@patch("airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook.get_conn")
+@patch(f"{MODULE}.AzureDataFactoryHook.get_conn")
 def test_provide_targeted_factory_backcompat_prefix_works(mock_connect, uri):
     with patch.dict(os.environ, {"AIRFLOW_CONN_MY_CONN": uri}):
         hook = AzureDataFactoryHook("my_conn")
@@ -539,8 +547,8 @@ def test_provide_targeted_factory_backcompat_prefix_works(mock_connect, uri):
         pytest.param("a://hi:yo@?tenantId=ten&subscriptionId=sub", id="no-prefix"),
     ],
 )
-@patch("airflow.providers.microsoft.azure.hooks.data_factory.ClientSecretCredential")
-@patch("airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook._create_client")
+@patch(f"{MODULE}.ClientSecretCredential")
+@patch(f"{MODULE}.AzureDataFactoryHook._create_client")
 def test_get_conn_backcompat_prefix_works(mock_create, mock_cred, uri):
     with patch.dict(os.environ, {"AIRFLOW_CONN_MY_CONN": uri}):
         hook = AzureDataFactoryHook("my_conn")
@@ -549,7 +557,7 @@ def test_get_conn_backcompat_prefix_works(mock_create, mock_cred, uri):
         mock_create.assert_called_with(mock_cred.return_value, "sub")
 
 
-@patch("airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook.get_conn")
+@patch(f"{MODULE}.AzureDataFactoryHook.get_conn")
 def test_backcompat_prefix_both_prefers_short(mock_connect):
     with patch.dict(
         os.environ,
@@ -573,8 +581,8 @@ def test_refresh_conn(hook):
 
 class TestAzureDataFactoryAsyncHook:
     @pytest.mark.asyncio
-    @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
-    @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run")
+    @mock.patch(f"{MODULE}.AzureDataFactoryAsyncHook.get_async_conn")
+    @mock.patch(f"{MODULE}.AzureDataFactoryAsyncHook.get_pipeline_run")
     async def test_get_adf_pipeline_run_status_queued(self, mock_get_pipeline_run, mock_conn):
         """Test get_adf_pipeline_run_status function with mocked status"""
         mock_status = "Queued"
@@ -584,8 +592,8 @@ class TestAzureDataFactoryAsyncHook:
         assert response == mock_status
 
     @pytest.mark.asyncio
-    @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
-    @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run")
+    @mock.patch(f"{MODULE}.AzureDataFactoryAsyncHook.get_async_conn")
+    @mock.patch(f"{MODULE}.AzureDataFactoryAsyncHook.get_pipeline_run")
     async def test_get_adf_pipeline_run_status_inprogress(
         self,
         mock_get_pipeline_run,
@@ -599,8 +607,8 @@ class TestAzureDataFactoryAsyncHook:
         assert response == mock_status
 
     @pytest.mark.asyncio
-    @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
-    @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run")
+    @mock.patch(f"{MODULE}.AzureDataFactoryAsyncHook.get_async_conn")
+    @mock.patch(f"{MODULE}.AzureDataFactoryAsyncHook.get_pipeline_run")
     async def test_get_adf_pipeline_run_status_success(self, mock_get_pipeline_run, mock_conn):
         """Test get_adf_pipeline_run_status function with mocked status"""
         mock_status = "Succeeded"
@@ -610,8 +618,8 @@ class TestAzureDataFactoryAsyncHook:
         assert response == mock_status
 
     @pytest.mark.asyncio
-    @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
-    @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run")
+    @mock.patch(f"{MODULE}.AzureDataFactoryAsyncHook.get_async_conn")
+    @mock.patch(f"{MODULE}.AzureDataFactoryAsyncHook.get_pipeline_run")
     async def test_get_adf_pipeline_run_status_failed(self, mock_get_pipeline_run, mock_conn):
         """Test get_adf_pipeline_run_status function with mocked status"""
         mock_status = "Failed"
@@ -621,8 +629,8 @@ class TestAzureDataFactoryAsyncHook:
         assert response == mock_status
 
     @pytest.mark.asyncio
-    @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
-    @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run")
+    @mock.patch(f"{MODULE}.AzureDataFactoryAsyncHook.get_async_conn")
+    @mock.patch(f"{MODULE}.AzureDataFactoryAsyncHook.get_pipeline_run")
     async def test_get_adf_pipeline_run_status_cancelled(self, mock_get_pipeline_run, mock_conn):
         """Test get_adf_pipeline_run_status function with mocked status"""
         mock_status = "Cancelled"
@@ -633,8 +641,8 @@ class TestAzureDataFactoryAsyncHook:
 
     @pytest.mark.asyncio
     @mock.patch("azure.mgmt.datafactory.models._models_py3.PipelineRun")
-    @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection")
-    @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
+    @mock.patch(f"{MODULE}.AzureDataFactoryAsyncHook.get_connection")
+    @mock.patch(f"{MODULE}.AzureDataFactoryAsyncHook.get_async_conn")
     async def test_get_pipeline_run_exception_without_resource(
         self, mock_conn, mock_get_connection, mock_pipeline_run
     ):
@@ -786,7 +794,7 @@ class TestAzureDataFactoryAsyncHook:
             get_field(extras, "non-existent-field", strict=True)
 
     @pytest.mark.asyncio
-    @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
+    @mock.patch(f"{MODULE}.AzureDataFactoryAsyncHook.get_async_conn")
     async def test_refresh_conn(self, mock_get_async_conn):
         """Test refresh_conn method _conn is reset and get_async_conn is 
called"""
         hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py b/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py
index e1412ca4a4..84a0a9420b 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py
@@ -68,7 +68,7 @@ class TestAzureDataLakeHook:
             )
         )
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.lib", autospec=True)
+    @mock.patch(f"{MODULE}.lib", autospec=True)
     def test_conn(self, mock_lib):
         from azure.datalake.store import core
 
@@ -96,12 +96,12 @@ class TestAzureDataLakeHook:
         assert hook._conn is None
         assert hook.conn_id == "adl_test_key_without_tenant"
         assert isinstance(hook.get_conn(), core.AzureDLFileSystem)
-        assert mock_azure_identity_credential_adapter.called
+        assert mock_azure_identity_credential_adapter.called_with(None, None)
         assert not mock_datalake_store_lib.auth.called
 
     @pytest.mark.usefixtures("connection")
-    @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.core.AzureDLFileSystem", autospec=True)
-    @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.lib", autospec=True)
+    @mock.patch(f"{MODULE}.core.AzureDLFileSystem", autospec=True)
+    @mock.patch(f"{MODULE}.lib", autospec=True)
     def test_check_for_blob(self, mock_lib, mock_filesystem):
         from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook
 
@@ -111,8 +111,8 @@ class TestAzureDataLakeHook:
         mocked_glob.assert_called()
 
     @pytest.mark.usefixtures("connection")
-    @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.multithread.ADLUploader", autospec=True)
-    @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.lib", autospec=True)
+    @mock.patch(f"{MODULE}.multithread.ADLUploader", autospec=True)
+    @mock.patch(f"{MODULE}.lib", autospec=True)
     def test_upload_file(self, mock_lib, mock_uploader):
         from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook
 
@@ -136,8 +136,8 @@ class TestAzureDataLakeHook:
         )
 
     @pytest.mark.usefixtures("connection")
-    @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.multithread.ADLDownloader", autospec=True)
-    @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.lib", autospec=True)
+    @mock.patch(f"{MODULE}.multithread.ADLDownloader", autospec=True)
+    @mock.patch(f"{MODULE}.lib", autospec=True)
     def test_download_file(self, mock_lib, mock_downloader):
         from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook
 
@@ -161,8 +161,8 @@ class TestAzureDataLakeHook:
         )
 
     @pytest.mark.usefixtures("connection")
-    @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.core.AzureDLFileSystem", autospec=True)
-    @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.lib", autospec=True)
+    @mock.patch(f"{MODULE}.core.AzureDLFileSystem", autospec=True)
+    @mock.patch(f"{MODULE}.lib", autospec=True)
     def test_list_glob(self, mock_lib, mock_fs):
         from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook
 
@@ -171,8 +171,8 @@ class TestAzureDataLakeHook:
         mock_fs.return_value.glob.assert_called_once_with("file_path/*")
 
     @pytest.mark.usefixtures("connection")
-    @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.core.AzureDLFileSystem", autospec=True)
-    @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.lib", autospec=True)
+    @mock.patch(f"{MODULE}.core.AzureDLFileSystem", autospec=True)
+    @mock.patch(f"{MODULE}.lib", autospec=True)
     def test_list_walk(self, mock_lib, mock_fs):
         from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook
 
@@ -181,8 +181,8 @@ class TestAzureDataLakeHook:
         mock_fs.return_value.walk.assert_called_once_with("file_path/some_folder/")
 
     @pytest.mark.usefixtures("connection")
-    @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.core.AzureDLFileSystem", autospec=True)
-    @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.lib", autospec=True)
+    @mock.patch(f"{MODULE}.core.AzureDLFileSystem", autospec=True)
+    @mock.patch(f"{MODULE}.lib", autospec=True)
     def test_remove(self, mock_lib, mock_fs):
         from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook
 
@@ -198,63 +198,57 @@ class TestAzureDataLakeStorageV2Hook:
         self.directory_name = "test_directory"
         self.file_name = "test_file_name"
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook.get_conn")
+    @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_conn")
     def test_create_file_system(self, mock_conn):
         hook = AzureDataLakeStorageV2Hook(adls_conn_id=self.conn_id)
         hook.create_file_system("test_file_system")
         expected_calls = [mock.call().create_file_system(file_system=self.file_system_name)]
         mock_conn.assert_has_calls(expected_calls)
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.FileSystemClient")
-    @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook.get_conn")
+    @mock.patch(f"{MODULE}.FileSystemClient")
+    @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_conn")
     def test_get_file_system(self, mock_conn, mock_file_system):
         mock_conn.return_value.get_file_system_client.return_value = mock_file_system
         hook = AzureDataLakeStorageV2Hook(adls_conn_id=self.conn_id)
         result = hook.get_file_system(self.file_system_name)
         assert result == mock_file_system
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.DataLakeDirectoryClient")
-    @mock.patch(
-        "airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook.get_file_system"
-    )
-    @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook.get_conn")
+    @mock.patch(f"{MODULE}.DataLakeDirectoryClient")
+    @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_file_system")
+    @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_conn")
     def test_create_directory(self, mock_conn, mock_get_file_system, mock_directory_client):
         mock_get_file_system.return_value.create_directory.return_value = mock_directory_client
         hook = AzureDataLakeStorageV2Hook(adls_conn_id=self.conn_id)
         result = hook.create_directory(self.file_system_name, self.directory_name)
         assert result == mock_directory_client
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.DataLakeDirectoryClient")
-    @mock.patch(
-        "airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook.get_file_system"
-    )
-    @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook.get_conn")
+    @mock.patch(f"{MODULE}.DataLakeDirectoryClient")
+    @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_file_system")
+    @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_conn")
     def test_get_directory(self, mock_conn, mock_get_file_system, mock_directory_client):
         mock_get_file_system.return_value.get_directory_client.return_value = mock_directory_client
         hook = AzureDataLakeStorageV2Hook(adls_conn_id=self.conn_id)
         result = hook.get_directory_client(self.file_system_name, self.directory_name)
         assert result == mock_directory_client
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.DataLakeFileClient")
-    @mock.patch(
-        "airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook.get_file_system"
-    )
-    @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook.get_conn")
+    @mock.patch(f"{MODULE}.DataLakeFileClient")
+    @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_file_system")
+    @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_conn")
     def test_create_file(self, mock_conn, mock_get_file_system, mock_file_client):
         mock_get_file_system.return_value.create_file.return_value = mock_file_client
         hook = AzureDataLakeStorageV2Hook(adls_conn_id=self.conn_id)
         result = hook.create_file(self.file_system_name, self.file_name)
         assert result == mock_file_client
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook.get_conn")
+    @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_conn")
     def test_delete_file_system(self, mock_conn):
         hook = AzureDataLakeStorageV2Hook(adls_conn_id=self.conn_id)
         hook.delete_file_system(self.file_system_name)
         expected_calls = [mock.call().delete_file_system(self.file_system_name)]
         mock_conn.assert_has_calls(expected_calls)
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.DataLakeDirectoryClient")
-    @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook.get_conn")
+    @mock.patch(f"{MODULE}.DataLakeDirectoryClient")
+    @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_conn")
     def test_delete_directory(self, mock_conn, mock_directory_client):
         mock_conn.return_value.get_directory_client.return_value = mock_directory_client
         hook = AzureDataLakeStorageV2Hook(adls_conn_id=self.conn_id)
@@ -267,7 +261,7 @@ class TestAzureDataLakeStorageV2Hook:
         ]
         mock_conn.assert_has_calls(expected_calls)
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook.get_conn")
+    @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_conn")
     def test_list_file_system(self, mock_conn):
         hook = AzureDataLakeStorageV2Hook(adls_conn_id=self.conn_id)
         hook.list_file_system(prefix="prefix")
@@ -275,10 +269,8 @@ class TestAzureDataLakeStorageV2Hook:
             name_starts_with="prefix", include_metadata=False
         )
 
-    @mock.patch(
-        "airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook.get_file_system"
-    )
-    @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook.get_conn")
+    @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_file_system")
+    @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_conn")
     def test_list_files_directory(self, mock_conn, mock_get_file_system):
         hook = AzureDataLakeStorageV2Hook(adls_conn_id=self.conn_id)
         hook.list_files_directory(self.file_system_name, self.directory_name)
@@ -288,7 +280,7 @@ class TestAzureDataLakeStorageV2Hook:
         argnames="list_file_systems_result",
         argvalues=[iter([FileSystemProperties]), iter([])],
     )
-    @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook.get_conn")
+    @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_conn")
     def test_connection_success(self, mock_conn, list_file_systems_result):
         hook = AzureDataLakeStorageV2Hook(adls_conn_id=self.conn_id)
         hook.get_conn().list_file_systems.return_value = list_file_systems_result
@@ -297,7 +289,7 @@ class TestAzureDataLakeStorageV2Hook:
         assert status is True
         assert msg == "Successfully connected to ADLS Gen2 Storage."
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook.get_conn")
+    @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_conn")
     def test_connection_failure(self, mock_conn):
         hook = AzureDataLakeStorageV2Hook(adls_conn_id=self.conn_id)
         hook.get_conn().list_file_systems = PropertyMock(side_effect=Exception("Authentication failed."))
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py b/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py
index aa14074363..7ef5262e3e 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py
@@ -26,6 +26,8 @@ from azure.storage.fileshare import DirectoryProperties, FileProperties, ShareSe
 from airflow.models import Connection
 from airflow.providers.microsoft.azure.hooks.fileshare import AzureFileShareHook
 
+MODULE = "airflow.providers.microsoft.azure.hooks.fileshare"
+
 
 class TestAzureFileshareHook:
     @pytest.fixture(autouse=True)
@@ -69,7 +71,7 @@ class TestAzureFileshareHook:
         share_client = hook.share_service_client
         assert isinstance(share_client, ShareServiceClient)
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.fileshare.ShareDirectoryClient", autospec=True)
+    @mock.patch(f"{MODULE}.ShareDirectoryClient", autospec=True)
     def test_check_for_directory(self, mock_service):
         mock_instance = mock_service.return_value
         mock_instance.exists.return_value = True
@@ -79,7 +81,7 @@ class TestAzureFileshareHook:
         assert hook.check_for_directory()
         mock_instance.exists.assert_called_once_with()
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.fileshare.ShareFileClient", autospec=True)
+    @mock.patch(f"{MODULE}.ShareFileClient", autospec=True)
     def test_load_data(self, mock_service):
         mock_instance = mock_service.return_value
         hook = AzureFileShareHook(
@@ -88,7 +90,7 @@ class TestAzureFileshareHook:
         hook.load_data("big string")
         mock_instance.upload_file.assert_called_once_with("big string")
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.fileshare.ShareDirectoryClient", autospec=True)
+    @mock.patch(f"{MODULE}.ShareDirectoryClient", autospec=True)
     def test_list_directories_and_files(self, mock_service):
         mock_instance = mock_service.return_value
         hook = AzureFileShareHook(
@@ -97,7 +99,7 @@ class TestAzureFileshareHook:
         hook.list_directories_and_files()
         mock_instance.list_directories_and_files.assert_called_once_with()
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.fileshare.ShareDirectoryClient", autospec=True)
+    @mock.patch(f"{MODULE}.ShareDirectoryClient", autospec=True)
     def test_list_files(self, mock_service):
         mock_instance = mock_service.return_value
         mock_instance.list_directories_and_files.return_value = [
@@ -113,7 +115,7 @@ class TestAzureFileshareHook:
         assert files == ["file1", "file2"]
         mock_instance.list_directories_and_files.assert_called_once_with()
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.fileshare.ShareDirectoryClient", autospec=True)
+    @mock.patch(f"{MODULE}.ShareDirectoryClient", autospec=True)
     def test_create_directory(self, mock_service):
         mock_instance = mock_service.return_value
         hook = AzureFileShareHook(
@@ -122,7 +124,7 @@ class TestAzureFileshareHook:
         hook.create_directory()
         mock_instance.create_directory.assert_called_once_with()
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.fileshare.ShareFileClient", autospec=True)
+    @mock.patch(f"{MODULE}.ShareFileClient", autospec=True)
     def test_get_file(self, mock_service):
         mock_instance = mock_service.return_value
         hook = AzureFileShareHook(
@@ -131,7 +133,7 @@ class TestAzureFileshareHook:
         hook.get_file("path")
         mock_instance.download_file.assert_called_once_with()
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.fileshare.ShareFileClient", autospec=True)
+    @mock.patch(f"{MODULE}.ShareFileClient", autospec=True)
     def test_get_file_to_stream(self, mock_service):
         mock_instance = mock_service.return_value
         hook = AzureFileShareHook(
@@ -141,21 +143,21 @@ class TestAzureFileshareHook:
         hook.get_file_to_stream(stream=data)
         mock_instance.download_file.assert_called_once_with()
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.fileshare.ShareServiceClient", autospec=True)
+    @mock.patch(f"{MODULE}.ShareServiceClient", autospec=True)
     def test_create_share(self, mock_service):
         mock_instance = mock_service.return_value
         hook = AzureFileShareHook(azure_fileshare_conn_id="azure_fileshare_extras")
         hook.create_share(share_name="my_share")
         mock_instance.create_share.assert_called_once_with("my_share")
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.fileshare.ShareServiceClient", autospec=True)
+    @mock.patch(f"{MODULE}.ShareServiceClient", autospec=True)
     def test_delete_share(self, mock_service):
         mock_instance = mock_service.return_value
         hook = AzureFileShareHook(azure_fileshare_conn_id="azure_fileshare_extras")
         hook.delete_share("my_share")
         mock_instance.delete_share.assert_called_once_with("my_share")
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.fileshare.ShareServiceClient", autospec=True)
+    @mock.patch(f"{MODULE}.ShareServiceClient", autospec=True)
     def test_connection_success(self, mock_service):
         mock_instance = mock_service.return_value
         hook = AzureFileShareHook(azure_fileshare_conn_id="azure_fileshare_extras")
@@ -164,7 +166,7 @@ class TestAzureFileshareHook:
         assert status is True
         assert msg == "Successfully connected to Azure File Share."
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.fileshare.ShareServiceClient", autospec=True)
+    @mock.patch(f"{MODULE}.ShareServiceClient", autospec=True)
     def test_connection_failure(self, mock_service):
         mock_instance = mock_service.return_value
         hook = AzureFileShareHook(azure_fileshare_conn_id="azure_fileshare_extras")
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_synapse.py b/tests/providers/microsoft/azure/hooks/test_azure_synapse.py
index ae3291fbd4..288a560cfa 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_synapse.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_synapse.py
@@ -19,7 +19,6 @@ from __future__ import annotations
 from unittest.mock import MagicMock, patch
 
 import pytest
-from azure.identity import ClientSecretCredential, DefaultAzureCredential
 from azure.synapse.spark import SparkClient
 
 from airflow.models.connection import Connection
@@ -38,6 +37,8 @@ NAME = "testName"
 ID = "testId"
 JOB_ID = 1
 
+MODULE = "airflow.providers.microsoft.azure.hooks.synapse"
+
 
 @pytest.fixture(autouse=True)
 def setup_connections(create_mock_connections):
@@ -87,23 +88,41 @@ def hook():
     return client
 
 
-@pytest.mark.parametrize(
-    ("connection_id", "credential_type"),
-    [
-        (DEFAULT_CONNECTION_CLIENT_SECRET, ClientSecretCredential),
-        (DEFAULT_CONNECTION_DEFAULT_CREDENTIAL, DefaultAzureCredential),
-    ],
-)
-def test_get_connection_by_credential_client_secret(connection_id: str, credential_type: type):
-    hook = AzureSynapseHook(connection_id)
+@patch(f"{MODULE}.ClientSecretCredential")
+def test_get_connection_by_credential_client_secret(mock_credential):
+    hook = AzureSynapseHook(DEFAULT_CONNECTION_CLIENT_SECRET)
 
     with patch.object(hook, "_create_client") as mock_create_client:
         mock_create_client.return_value = MagicMock()
+
         connection = hook.get_conn()
         assert connection is not None
-        mock_create_client.assert_called_once()
-        assert isinstance(mock_create_client.call_args.args[0], credential_type)
-        assert mock_create_client.call_args.args[-1] == "subscriptionId"
+        mock_create_client.assert_called_with(
+            mock_credential(),
+            "https://testsynapse.dev.azuresynapse.net",
+            "",
+            "2022-02-22-preview",
+            "subscriptionId",
+        )
+
+
+@patch(f"{MODULE}.get_default_azure_credential")
+def test_get_conn_by_default_azure_credential(mock_credential):
+    hook = AzureSynapseHook(DEFAULT_CONNECTION_DEFAULT_CREDENTIAL)
+
+    with patch.object(hook, "_create_client") as mock_create_client:
+        mock_create_client.return_value = MagicMock()
+
+        connection = hook.get_conn()
+        assert connection is not None
+        assert mock_credential.called_with(None, None)
+        mock_create_client.assert_called_with(
+            mock_credential(),
+            "https://testsynapse.dev.azuresynapse.net",
+            "",
+            "2022-02-22-preview",
+            "subscriptionId",
+        )
 
 
 def test_run_spark_job(hook: AzureSynapseHook):
diff --git a/tests/providers/microsoft/azure/hooks/test_base_azure.py b/tests/providers/microsoft/azure/hooks/test_base_azure.py
index 313d61cd11..8e666d2524 100644
--- a/tests/providers/microsoft/azure/hooks/test_base_azure.py
+++ b/tests/providers/microsoft/azure/hooks/test_base_azure.py
@@ -25,6 +25,8 @@ from airflow.providers.microsoft.azure.hooks.base_azure import AzureBaseHook
 
 pytestmark = pytest.mark.db_test
 
+MODULE = "airflow.providers.microsoft.azure.hooks.base_azure"
+
 
 class TestBaseAzureHook:
     @pytest.mark.parametrize(
@@ -32,7 +34,7 @@ class TestBaseAzureHook:
         [Connection(conn_id="azure_default", extra={"key_path": "key_file.json"})],
         indirect=True,
     )
-    @patch("airflow.providers.microsoft.azure.hooks.base_azure.get_client_from_auth_file")
+    @patch(f"{MODULE}.get_client_from_auth_file")
     def test_get_conn_with_key_path(self, mock_get_client_from_auth_file, mocked_connection):
         mock_get_client_from_auth_file.return_value = "foo-bar"
         mock_sdk_client = Mock()
@@ -49,7 +51,7 @@ class TestBaseAzureHook:
         [Connection(conn_id="azure_default", extra={"key_json": {"test": "test"}})],
         indirect=True,
     )
-    @patch("airflow.providers.microsoft.azure.hooks.base_azure.get_client_from_json_dict")
+    @patch(f"{MODULE}.get_client_from_json_dict")
     def test_get_conn_with_key_json(self, mock_get_client_from_json_dict, mocked_connection):
         mock_sdk_client = Mock()
         mock_get_client_from_json_dict.return_value = "foo-bar"
@@ -60,7 +62,7 @@ class TestBaseAzureHook:
         )
         assert auth_sdk_client == "foo-bar"
 
-    @patch("airflow.providers.microsoft.azure.hooks.base_azure.ServicePrincipalCredentials")
+    @patch(f"{MODULE}.ServicePrincipalCredentials")
     @pytest.mark.parametrize(
         "mocked_connection",
         [
@@ -88,3 +90,41 @@ class TestBaseAzureHook:
             subscription_id=mocked_connection.extra_dejson["subscriptionId"],
         )
         assert auth_sdk_client == "spam-egg"
+
+    @pytest.mark.parametrize(
+        "mocked_connection",
+        [
+            Connection(
+                conn_id="azure_default",
+                extra={
+                    "managed_identity_client_id": "test_client_id",
+                    "workload_identity_tenant_id": "test_tenant_id",
+                    "subscriptionId": "subscription_id",
+                },
+            )
+        ],
+        indirect=True,
+    )
+    @patch("azure.common.credentials.ServicePrincipalCredentials")
+    @patch("airflow.providers.microsoft.azure.hooks.base_azure.AzureIdentityCredentialAdapter")
+    def test_get_conn_fallback_to_azure_identity_credential_adapter(
+        self,
+        mock_credential_adapter,
+        mock_service_pricipal_credential,
+        mocked_connection,
+    ):
+        mock_credential = Mock()
+        mock_credential_adapter.return_value = mock_credential
+
+        mock_sdk_client = Mock()
+        AzureBaseHook(mock_sdk_client).get_conn()
+
+        mock_credential_adapter.assert_called_with(
+            managed_identity_client_id="test_client_id",
+            workload_identity_tenant_id="test_tenant_id",
+        )
+        assert not mock_service_pricipal_credential.called
+        mock_sdk_client.assert_called_once_with(
+            credentials=mock_credential,
+            subscription_id="subscription_id",
+        )
diff --git a/tests/providers/microsoft/azure/hooks/test_wasb.py b/tests/providers/microsoft/azure/hooks/test_wasb.py
index b5a8717547..7ac4fb1857 100644
--- a/tests/providers/microsoft/azure/hooks/test_wasb.py
+++ b/tests/providers/microsoft/azure/hooks/test_wasb.py
@@ -187,6 +187,7 @@ class TestWasbHook:
         )
 
     def test_managed_identity(self, mocked_default_azure_credential, mocked_blob_service_client):
+        assert mocked_default_azure_credential.called_with(None, None)
         mocked_default_azure_credential.return_value = "foo-bar"
         WasbHook(wasb_conn_id=self.managed_identity_conn_id).get_conn()
         mocked_blob_service_client.assert_called_once_with(
diff --git a/tests/providers/microsoft/azure/test_utils.py b/tests/providers/microsoft/azure/test_utils.py
index ad5e3cbc2d..5a081441ca 100644
--- a/tests/providers/microsoft/azure/test_utils.py
+++ b/tests/providers/microsoft/azure/test_utils.py
@@ -17,11 +17,16 @@
 
 from __future__ import annotations
 
+from typing import Any
 from unittest import mock
 
 import pytest
 
-from airflow.providers.microsoft.azure.utils import AzureIdentityCredentialAdapter, get_field
+from airflow.providers.microsoft.azure.utils import (
+    AzureIdentityCredentialAdapter,
+    add_managed_identity_connection_widgets,
+    get_field,
+)
 
 MODULE = "airflow.providers.microsoft.azure.utils"
 
@@ -62,6 +67,16 @@ def test_get_field_non_prefixed(input, expected):
     assert value == expected
 
 
+def test_add_managed_identity_connection_widgets():
+    def test_func() -> dict[str, Any]:
+        return {}
+
+    widgets = add_managed_identity_connection_widgets(test_func)()
+
+    assert "managed_identity_client_id" in widgets
+    assert "workload_identity_tenant_id" in widgets
+
+
 class TestAzureIdentityCredentialAdapter:
     @mock.patch(f"{MODULE}.PipelineRequest")
     @mock.patch(f"{MODULE}.BearerTokenCredentialPolicy")

Reply via email to