This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b27047430f Refactor azure managed identity (#35367)
b27047430f is described below
commit b27047430fa49538a737138e3c2e57368c4d33b0
Author: Wei Lee <[email protected]>
AuthorDate: Fri Nov 3 05:35:56 2023 +0800
Refactor azure managed identity (#35367)
---
airflow/providers/microsoft/azure/hooks/adx.py | 14 +--
airflow/providers/microsoft/azure/hooks/asb.py | 15 ++-
.../providers/microsoft/azure/hooks/base_azure.py | 14 +--
airflow/providers/microsoft/azure/hooks/batch.py | 31 +++---
.../microsoft/azure/hooks/container_registry.py | 15 ++-
.../microsoft/azure/hooks/container_volume.py | 39 ++++---
airflow/providers/microsoft/azure/hooks/cosmos.py | 15 ++-
.../microsoft/azure/hooks/data_factory.py | 14 +--
.../providers/microsoft/azure/hooks/data_lake.py | 30 +++---
.../providers/microsoft/azure/hooks/fileshare.py | 79 ++++++--------
airflow/providers/microsoft/azure/hooks/synapse.py | 15 ++-
airflow/providers/microsoft/azure/hooks/wasb.py | 14 +--
airflow/providers/microsoft/azure/utils.py | 24 +++++
tests/providers/microsoft/azure/hooks/test_adx.py | 118 +++++++++++----------
tests/providers/microsoft/azure/hooks/test_asb.py | 41 +++----
.../microsoft/azure/hooks/test_azure_batch.py | 16 +--
.../azure/hooks/test_azure_container_instance.py | 2 +-
.../microsoft/azure/hooks/test_azure_cosmos.py | 60 ++++++++---
.../azure/hooks/test_azure_data_factory.py | 68 ++++++------
.../microsoft/azure/hooks/test_azure_data_lake.py | 76 ++++++-------
.../microsoft/azure/hooks/test_azure_fileshare.py | 24 +++--
.../microsoft/azure/hooks/test_azure_synapse.py | 45 +++++---
.../microsoft/azure/hooks/test_base_azure.py | 46 +++++++-
tests/providers/microsoft/azure/hooks/test_wasb.py | 1 +
tests/providers/microsoft/azure/test_utils.py | 17 ++-
25 files changed, 466 insertions(+), 367 deletions(-)
diff --git a/airflow/providers/microsoft/azure/hooks/adx.py
b/airflow/providers/microsoft/azure/hooks/adx.py
index 8ee3ae8f17..4d2ba0ae6d 100644
--- a/airflow/providers/microsoft/azure/hooks/adx.py
+++ b/airflow/providers/microsoft/azure/hooks/adx.py
@@ -34,7 +34,10 @@ from azure.kusto.data.exceptions import KustoServiceError
from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
from airflow.hooks.base import BaseHook
-from airflow.providers.microsoft.azure.utils import
get_default_azure_credential
+from airflow.providers.microsoft.azure.utils import (
+ add_managed_identity_connection_widgets,
+ get_default_azure_credential,
+)
if TYPE_CHECKING:
from azure.kusto.data.response import KustoResponseDataSetV2
@@ -80,6 +83,7 @@ class AzureDataExplorerHook(BaseHook):
hook_name = "Azure Data Explorer"
@classmethod
+ @add_managed_identity_connection_widgets
def get_connection_form_widgets(cls) -> dict[str, Any]:
"""Returns connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget,
BS3TextFieldWidget
@@ -95,12 +99,6 @@ class AzureDataExplorerHook(BaseHook):
"thumbprint": PasswordField(
lazy_gettext("Application Certificate Thumbprint"),
widget=BS3PasswordFieldWidget()
),
- "managed_identity_client_id": StringField(
- lazy_gettext("Managed Identity Client ID"),
widget=BS3TextFieldWidget()
- ),
- "workload_identity_tenant_id": StringField(
- lazy_gettext("Workload Identity Tenant ID"),
widget=BS3TextFieldWidget()
- ),
}
@classmethod
@@ -119,8 +117,6 @@ class AzureDataExplorerHook(BaseHook):
"tenant": "Used with AAD_APP/AAD_APP_CERT/AAD_CREDS",
"certificate": "Used with AAD_APP_CERT",
"thumbprint": "Used with AAD_APP_CERT",
- "managed_identity_client_id": "Managed Identity Client ID",
- "workload_identity_tenant_id": "Workload Identity Tenant ID",
},
}
diff --git a/airflow/providers/microsoft/azure/hooks/asb.py
b/airflow/providers/microsoft/azure/hooks/asb.py
index e81bc84527..d1580f2aa4 100644
--- a/airflow/providers/microsoft/azure/hooks/asb.py
+++ b/airflow/providers/microsoft/azure/hooks/asb.py
@@ -22,7 +22,11 @@ from azure.servicebus import ServiceBusClient,
ServiceBusMessage, ServiceBusSend
from azure.servicebus.management import QueueProperties,
ServiceBusAdministrationClient
from airflow.hooks.base import BaseHook
-from airflow.providers.microsoft.azure.utils import
get_default_azure_credential, get_field
+from airflow.providers.microsoft.azure.utils import (
+ add_managed_identity_connection_widgets,
+ get_default_azure_credential,
+ get_field,
+)
if TYPE_CHECKING:
from azure.identity import DefaultAzureCredential
@@ -42,6 +46,7 @@ class BaseAzureServiceBusHook(BaseHook):
hook_name = "Azure Service Bus"
@staticmethod
+ @add_managed_identity_connection_widgets
def get_connection_form_widgets() -> dict[str, Any]:
"""Returns connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
@@ -53,12 +58,6 @@ class BaseAzureServiceBusHook(BaseHook):
lazy_gettext("Fully Qualified Namespace"),
widget=BS3TextFieldWidget()
),
"credential": PasswordField(lazy_gettext("Credential"),
widget=BS3TextFieldWidget()),
- "managed_identity_client_id": StringField(
- lazy_gettext("Managed Identity Client ID"),
widget=BS3TextFieldWidget()
- ),
- "workload_identity_tenant_id": StringField(
- lazy_gettext("Workload Identity Tenant ID"),
widget=BS3TextFieldWidget()
- ),
}
@staticmethod
@@ -73,8 +72,6 @@ class BaseAzureServiceBusHook(BaseHook):
),
"credential": "credential",
"schema": "Endpoint=sb://<Resource
group>.servicebus.windows.net/;SharedAccessKeyName=<AccessKeyName>;SharedAccessKey=<SharedAccessKey>",
- "managed_identity_client_id": "Managed Identity Client ID",
- "workload_identity_tenant_id": "Workload Identity Tenant ID",
},
}
diff --git a/airflow/providers/microsoft/azure/hooks/base_azure.py
b/airflow/providers/microsoft/azure/hooks/base_azure.py
index 2868816379..d6ecf185c4 100644
--- a/airflow/providers/microsoft/azure/hooks/base_azure.py
+++ b/airflow/providers/microsoft/azure/hooks/base_azure.py
@@ -24,7 +24,10 @@ from azure.common.credentials import
ServicePrincipalCredentials
from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
from airflow.hooks.base import BaseHook
-from airflow.providers.microsoft.azure.utils import
AzureIdentityCredentialAdapter
+from airflow.providers.microsoft.azure.utils import (
+ AzureIdentityCredentialAdapter,
+ add_managed_identity_connection_widgets,
+)
class AzureBaseHook(BaseHook):
@@ -45,6 +48,7 @@ class AzureBaseHook(BaseHook):
hook_name = "Azure"
@staticmethod
+ @add_managed_identity_connection_widgets
def get_connection_form_widgets() -> dict[str, Any]:
"""Returns connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
@@ -54,12 +58,6 @@ class AzureBaseHook(BaseHook):
return {
"tenantId": StringField(lazy_gettext("Azure Tenant ID"),
widget=BS3TextFieldWidget()),
"subscriptionId": StringField(lazy_gettext("Azure Subscription
ID"), widget=BS3TextFieldWidget()),
- "managed_identity_client_id": StringField(
- lazy_gettext("Managed Identity Client ID"),
widget=BS3TextFieldWidget()
- ),
- "workload_identity_tenant_id": StringField(
- lazy_gettext("Workload Identity Tenant ID"),
widget=BS3TextFieldWidget()
- ),
}
@staticmethod
@@ -85,8 +83,6 @@ class AzureBaseHook(BaseHook):
"password": "secret (token credentials auth)",
"tenantId": "tenantId (token credentials auth)",
"subscriptionId": "subscriptionId (token credentials auth)",
- "managed_identity_client_id": "Managed Identity Client ID",
- "workload_identity_tenant_id": "Workload Identity Tenant ID",
},
}
diff --git a/airflow/providers/microsoft/azure/hooks/batch.py
b/airflow/providers/microsoft/azure/hooks/batch.py
index 7ced29cdc2..c3007d0b59 100644
--- a/airflow/providers/microsoft/azure/hooks/batch.py
+++ b/airflow/providers/microsoft/azure/hooks/batch.py
@@ -26,7 +26,11 @@ from azure.batch import BatchServiceClient, batch_auth,
models as batch_models
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
-from airflow.providers.microsoft.azure.utils import
AzureIdentityCredentialAdapter, get_field
+from airflow.providers.microsoft.azure.utils import (
+ AzureIdentityCredentialAdapter,
+ add_managed_identity_connection_widgets,
+ get_field,
+)
from airflow.utils import timezone
if TYPE_CHECKING:
@@ -46,15 +50,8 @@ class AzureBatchHook(BaseHook):
conn_type = "azure_batch"
hook_name = "Azure Batch Service"
- def _get_field(self, extras, name):
- return get_field(
- conn_id=self.conn_id,
- conn_type=self.conn_type,
- extras=extras,
- field_name=name,
- )
-
@classmethod
+ @add_managed_identity_connection_widgets
def get_connection_form_widgets(cls) -> dict[str, Any]:
"""Returns connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
@@ -63,12 +60,6 @@ class AzureBatchHook(BaseHook):
return {
"account_url": StringField(lazy_gettext("Batch Account URL"),
widget=BS3TextFieldWidget()),
- "managed_identity_client_id": StringField(
- lazy_gettext("Managed Identity Client ID"),
widget=BS3TextFieldWidget()
- ),
- "workload_identity_tenant_id": StringField(
- lazy_gettext("Workload Identity Tenant ID"),
widget=BS3TextFieldWidget()
- ),
}
@classmethod
@@ -79,8 +70,6 @@ class AzureBatchHook(BaseHook):
"relabeling": {
"login": "Batch Account Name",
"password": "Batch Account Access Key",
- "managed_identity_client_id": "Managed Identity Client ID",
- "workload_identity_tenant_id": "Workload Identity Tenant ID",
},
}
@@ -88,6 +77,14 @@ class AzureBatchHook(BaseHook):
super().__init__()
self.conn_id = azure_batch_conn_id
+ def _get_field(self, extras, name):
+ return get_field(
+ conn_id=self.conn_id,
+ conn_type=self.conn_type,
+ extras=extras,
+ field_name=name,
+ )
+
@cached_property
def connection(self) -> BatchServiceClient:
"""Get the Batch client connection (cached)."""
diff --git a/airflow/providers/microsoft/azure/hooks/container_registry.py
b/airflow/providers/microsoft/azure/hooks/container_registry.py
index ea7b129a46..04cfa719ef 100644
--- a/airflow/providers/microsoft/azure/hooks/container_registry.py
+++ b/airflow/providers/microsoft/azure/hooks/container_registry.py
@@ -25,7 +25,11 @@ from azure.mgmt.containerinstance.models import
ImageRegistryCredential
from azure.mgmt.containerregistry import ContainerRegistryManagementClient
from airflow.hooks.base import BaseHook
-from airflow.providers.microsoft.azure.utils import
get_default_azure_credential, get_field
+from airflow.providers.microsoft.azure.utils import (
+ add_managed_identity_connection_widgets,
+ get_default_azure_credential,
+ get_field,
+)
class AzureContainerRegistryHook(BaseHook):
@@ -43,6 +47,7 @@ class AzureContainerRegistryHook(BaseHook):
hook_name = "Azure Container Registry"
@staticmethod
+ @add_managed_identity_connection_widgets
def get_connection_form_widgets() -> dict[str, Any]:
"""Returns connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
@@ -58,12 +63,6 @@ class AzureContainerRegistryHook(BaseHook):
lazy_gettext("Resource group name (optional)"),
widget=BS3TextFieldWidget(),
),
- "managed_identity_client_id": StringField(
- lazy_gettext("Managed Identity Client ID"),
widget=BS3TextFieldWidget()
- ),
- "workload_identity_tenant_id": StringField(
- lazy_gettext("Workload Identity Tenant ID"),
widget=BS3TextFieldWidget()
- ),
}
@classmethod
@@ -82,8 +81,6 @@ class AzureContainerRegistryHook(BaseHook):
"host": "docker image registry server",
"subscription_id": "Subscription id (required for Azure AD
authentication)",
"resource_group": "Resource group name (required for Azure AD
authentication)",
- "managed_identity_client_id": "Managed Identity Client ID",
- "workload_identity_tenant_id": "Workload Identity Tenant ID",
},
}
diff --git a/airflow/providers/microsoft/azure/hooks/container_volume.py
b/airflow/providers/microsoft/azure/hooks/container_volume.py
index 791c08303b..fdf6d96407 100644
--- a/airflow/providers/microsoft/azure/hooks/container_volume.py
+++ b/airflow/providers/microsoft/azure/hooks/container_volume.py
@@ -22,7 +22,11 @@ from azure.mgmt.containerinstance.models import
AzureFileVolume, Volume
from azure.mgmt.storage import StorageManagementClient
from airflow.hooks.base import BaseHook
-from airflow.providers.microsoft.azure.utils import
get_default_azure_credential, get_field
+from airflow.providers.microsoft.azure.utils import (
+ add_managed_identity_connection_widgets,
+ get_default_azure_credential,
+ get_field,
+)
class AzureContainerVolumeHook(BaseHook):
@@ -39,19 +43,8 @@ class AzureContainerVolumeHook(BaseHook):
conn_type = "azure_container_volume"
hook_name = "Azure Container Volume"
- def __init__(self, azure_container_volume_conn_id: str =
"azure_container_volume_default") -> None:
- super().__init__()
- self.conn_id = azure_container_volume_conn_id
-
- def _get_field(self, extras, name):
- return get_field(
- conn_id=self.conn_id,
- conn_type=self.conn_type,
- extras=extras,
- field_name=name,
- )
-
@staticmethod
+ @add_managed_identity_connection_widgets
def get_connection_form_widgets() -> dict[str, Any]:
"""Returns connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget,
BS3TextFieldWidget
@@ -70,12 +63,6 @@ class AzureContainerVolumeHook(BaseHook):
lazy_gettext("Resource group name (optional)"),
widget=BS3TextFieldWidget(),
),
- "managed_identity_client_id": StringField(
- lazy_gettext("Managed Identity Client ID"),
widget=BS3TextFieldWidget()
- ),
- "workload_identity_tenant_id": StringField(
- lazy_gettext("Workload Identity Tenant ID"),
widget=BS3TextFieldWidget()
- ),
}
@staticmethod
@@ -93,11 +80,21 @@ class AzureContainerVolumeHook(BaseHook):
"connection_string": "connection string auth",
"subscription_id": "Subscription id (required for Azure AD
authentication)",
"resource_group": "Resource group name (required for Azure AD
authentication)",
- "managed_identity_client_id": "Managed Identity Client ID",
- "workload_identity_tenant_id": "Workload Identity Tenant ID",
},
}
+ def __init__(self, azure_container_volume_conn_id: str =
"azure_container_volume_default") -> None:
+ super().__init__()
+ self.conn_id = azure_container_volume_conn_id
+
+ def _get_field(self, extras, name):
+ return get_field(
+ conn_id=self.conn_id,
+ conn_type=self.conn_type,
+ extras=extras,
+ field_name=name,
+ )
+
def get_storagekey(self, *, storage_account_name: str | None = None) ->
str:
"""Get Azure File Volume storage key."""
conn = self.get_connection(self.conn_id)
diff --git a/airflow/providers/microsoft/azure/hooks/cosmos.py
b/airflow/providers/microsoft/azure/hooks/cosmos.py
index 596faf6b04..1af863090c 100644
--- a/airflow/providers/microsoft/azure/hooks/cosmos.py
+++ b/airflow/providers/microsoft/azure/hooks/cosmos.py
@@ -35,7 +35,11 @@ from azure.mgmt.cosmosdb import CosmosDBManagementClient
from airflow.exceptions import AirflowBadRequest, AirflowException
from airflow.hooks.base import BaseHook
-from airflow.providers.microsoft.azure.utils import
get_default_azure_credential, get_field
+from airflow.providers.microsoft.azure.utils import (
+ add_managed_identity_connection_widgets,
+ get_default_azure_credential,
+ get_field,
+)
class AzureCosmosDBHook(BaseHook):
@@ -56,6 +60,7 @@ class AzureCosmosDBHook(BaseHook):
hook_name = "Azure CosmosDB"
@staticmethod
+ @add_managed_identity_connection_widgets
def get_connection_form_widgets() -> dict[str, Any]:
"""Returns connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
@@ -77,12 +82,6 @@ class AzureCosmosDBHook(BaseHook):
lazy_gettext("Resource Group Name (optional)"),
widget=BS3TextFieldWidget(),
),
- "managed_identity_client_id": StringField(
- lazy_gettext("Managed Identity Client ID"),
widget=BS3TextFieldWidget()
- ),
- "workload_identity_tenant_id": StringField(
- lazy_gettext("Workload Identity Tenant ID"),
widget=BS3TextFieldWidget()
- ),
}
@staticmethod
@@ -101,8 +100,6 @@ class AzureCosmosDBHook(BaseHook):
"collection_name": "collection name",
"subscription_id": "Subscription ID (required for Azure AD
authentication)",
"resource_group_name": "Resource Group Name (required for
Azure AD authentication)",
- "managed_identity_client_id": "Managed Identity Client ID",
- "workload_identity_tenant_id": "Workload Identity Tenant ID",
},
}
diff --git a/airflow/providers/microsoft/azure/hooks/data_factory.py
b/airflow/providers/microsoft/azure/hooks/data_factory.py
index cc9cc2ec8a..e8a80242ea 100644
--- a/airflow/providers/microsoft/azure/hooks/data_factory.py
+++ b/airflow/providers/microsoft/azure/hooks/data_factory.py
@@ -48,7 +48,10 @@ from azure.mgmt.datafactory.aio import
DataFactoryManagementClient as AsyncDataF
from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
from airflow.hooks.base import BaseHook
-from airflow.providers.microsoft.azure.utils import
get_default_azure_credential
+from airflow.providers.microsoft.azure.utils import (
+ add_managed_identity_connection_widgets,
+ get_default_azure_credential,
+)
if TYPE_CHECKING:
from azure.core.polling import LROPoller
@@ -153,6 +156,7 @@ class AzureDataFactoryHook(BaseHook):
hook_name: str = "Azure Data Factory"
@staticmethod
+ @add_managed_identity_connection_widgets
def get_connection_form_widgets() -> dict[str, Any]:
"""Returns connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
@@ -166,12 +170,6 @@ class AzureDataFactoryHook(BaseHook):
lazy_gettext("Resource Group Name"),
widget=BS3TextFieldWidget()
),
"factory_name": StringField(lazy_gettext("Factory Name"),
widget=BS3TextFieldWidget()),
- "managed_identity_client_id": StringField(
- lazy_gettext("Managed Identity Client ID"),
widget=BS3TextFieldWidget()
- ),
- "workload_identity_tenant_id": StringField(
- lazy_gettext("Workload Identity Tenant ID"),
widget=BS3TextFieldWidget()
- ),
}
@staticmethod
@@ -182,8 +180,6 @@ class AzureDataFactoryHook(BaseHook):
"relabeling": {
"login": "Client ID",
"password": "Secret",
- "managed_identity_client_id": "Managed Identity Client ID",
- "workload_identity_tenant_id": "Workload Identity Tenant ID",
},
}
diff --git a/airflow/providers/microsoft/azure/hooks/data_lake.py
b/airflow/providers/microsoft/azure/hooks/data_lake.py
index 0b344a41ad..557c70fe99 100644
--- a/airflow/providers/microsoft/azure/hooks/data_lake.py
+++ b/airflow/providers/microsoft/azure/hooks/data_lake.py
@@ -34,7 +34,11 @@ from azure.storage.filedatalake import (
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
-from airflow.providers.microsoft.azure.utils import
AzureIdentityCredentialAdapter, get_field
+from airflow.providers.microsoft.azure.utils import (
+ AzureIdentityCredentialAdapter,
+ add_managed_identity_connection_widgets,
+ get_field,
+)
Credentials = Union[ClientSecretCredential, AzureIdentityCredentialAdapter]
@@ -62,6 +66,7 @@ class AzureDataLakeHook(BaseHook):
hook_name = "Azure Data Lake"
@staticmethod
+ @add_managed_identity_connection_widgets
def get_connection_form_widgets() -> dict[str, Any]:
"""Returns connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
@@ -265,6 +270,7 @@ class AzureDataLakeStorageV2Hook(BaseHook):
hook_name = "Azure Date Lake Storage V2"
@classmethod
+ @add_managed_identity_connection_widgets
def get_connection_form_widgets(cls) -> dict[str, Any]:
"""Returns connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget,
BS3TextFieldWidget
@@ -305,6 +311,17 @@ class AzureDataLakeStorageV2Hook(BaseHook):
self.conn_id = adls_conn_id
self.public_read = public_read
+ def _get_field(self, extra_dict, field_name):
+ prefix = "extra__adls__"
+ if field_name.startswith("extra__"):
+ raise ValueError(
+ f"Got prefixed name {field_name}; please remove the '{prefix}'
prefix "
+ f"when using this method."
+ )
+ if field_name in extra_dict:
+ return extra_dict[field_name] or None
+ return extra_dict.get(f"{prefix}{field_name}") or None
+
@cached_property
def service_client(self) -> DataLakeServiceClient:
"""Return the DataLakeServiceClient object (cached)."""
@@ -338,17 +355,6 @@ class AzureDataLakeStorageV2Hook(BaseHook):
**extra,
)
- def _get_field(self, extra_dict, field_name):
- prefix = "extra__adls__"
- if field_name.startswith("extra__"):
- raise ValueError(
- f"Got prefixed name {field_name}; please remove the '{prefix}'
prefix "
- f"when using this method."
- )
- if field_name in extra_dict:
- return extra_dict[field_name] or None
- return extra_dict.get(f"{prefix}{field_name}") or None
-
def create_file_system(self, file_system_name: str) -> None:
"""Create a new file system under the specified account.
diff --git a/airflow/providers/microsoft/azure/hooks/fileshare.py
b/airflow/providers/microsoft/azure/hooks/fileshare.py
index 16c7218505..92235e7ecf 100644
--- a/airflow/providers/microsoft/azure/hooks/fileshare.py
+++ b/airflow/providers/microsoft/azure/hooks/fileshare.py
@@ -22,7 +22,10 @@ from typing import IO, Any
from azure.storage.fileshare import FileProperties, ShareDirectoryClient,
ShareFileClient, ShareServiceClient
from airflow.hooks.base import BaseHook
-from airflow.providers.microsoft.azure.utils import
get_default_azure_credential
+from airflow.providers.microsoft.azure.utils import (
+ add_managed_identity_connection_widgets,
+ get_default_azure_credential,
+)
class AzureFileShareHook(BaseHook):
@@ -39,24 +42,8 @@ class AzureFileShareHook(BaseHook):
conn_type = "azure_fileshare"
hook_name = "Azure FileShare"
- def __init__(
- self,
- share_name: str | None = None,
- file_path: str | None = None,
- directory_path: str | None = None,
- azure_fileshare_conn_id: str = "azure_fileshare_default",
- ) -> None:
- super().__init__()
- self._conn_id = azure_fileshare_conn_id
- self.share_name = share_name
- self.file_path = file_path
- self.directory_path = directory_path
- self._account_url: str | None = None
- self._connection_string: str | None = None
- self._account_access_key: str | None = None
- self._sas_token: str | None = None
-
@staticmethod
+ @add_managed_identity_connection_widgets
def get_connection_form_widgets() -> dict[str, Any]:
"""Returns connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget,
BS3TextFieldWidget
@@ -68,12 +55,6 @@ class AzureFileShareHook(BaseHook):
"connection_string": StringField(
lazy_gettext("Connection String (optional)"),
widget=BS3TextFieldWidget()
),
- "managed_identity_client_id": StringField(
- lazy_gettext("Managed Identity Client ID"),
widget=BS3TextFieldWidget()
- ),
- "workload_identity_tenant_id": StringField(
- lazy_gettext("Workload Identity Tenant ID"),
widget=BS3TextFieldWidget()
- ),
}
@staticmethod
@@ -90,11 +71,26 @@ class AzureFileShareHook(BaseHook):
"password": "secret",
"sas_token": "account url or token (optional)",
"connection_string": "account url or token (optional)",
- "managed_identity_client_id": "Managed Identity Client ID",
- "workload_identity_tenant_id": "Workload Identity Tenant ID",
},
}
+ def __init__(
+ self,
+ share_name: str | None = None,
+ file_path: str | None = None,
+ directory_path: str | None = None,
+ azure_fileshare_conn_id: str = "azure_fileshare_default",
+ ) -> None:
+ super().__init__()
+ self._conn_id = azure_fileshare_conn_id
+ self.share_name = share_name
+ self.file_path = file_path
+ self.directory_path = directory_path
+ self._account_url: str | None = None
+ self._connection_string: str | None = None
+ self._account_access_key: str | None = None
+ self._sas_token: str | None = None
+
def get_conn(self) -> None:
conn = self.get_connection(self._conn_id)
extras = conn.extra_dejson
@@ -110,6 +106,13 @@ class AzureFileShareHook(BaseHook):
return f"https://{account_url}.file.core.windows.net"
return account_url
+ def _get_default_azure_credential(self):
+ conn = self.get_connection(self._conn_id)
+ extras = conn.extra_dejson
+ managed_identity_client_id = extras.get("managed_identity_client_id")
+ workload_identity_tenant_id = extras.get("workload_identity_tenant_id")
+ return get_default_azure_credential(managed_identity_client_id,
workload_identity_tenant_id)
+
@property
def share_service_client(self):
self.get_conn()
@@ -121,15 +124,9 @@ class AzureFileShareHook(BaseHook):
credential = self._sas_token or self._account_access_key
return ShareServiceClient(account_url=self._account_url,
credential=credential)
else:
- conn = self.get_connection(self._conn_id)
- extras = conn.extra_dejson
- managed_identity_client_id =
extras.get("managed_identity_client_id")
- workload_identity_tenant_id =
extras.get("workload_identity_tenant_id")
return ShareServiceClient(
account_url=self._account_url,
- credential=get_default_azure_credential(
- managed_identity_client_id, workload_identity_tenant_id
- ),
+ credential=self._get_default_azure_credential(),
token_intent="backup",
)
@@ -150,17 +147,11 @@ class AzureFileShareHook(BaseHook):
credential=credential,
)
else:
- conn = self.get_connection(self._conn_id)
- extras = conn.extra_dejson
- managed_identity_client_id =
extras.get("managed_identity_client_id")
- workload_identity_tenant_id =
extras.get("workload_identity_tenant_id")
return ShareDirectoryClient(
account_url=self._account_url,
share_name=self.share_name,
directory_path=self.directory_path,
- credential=get_default_azure_credential(
- managed_identity_client_id, workload_identity_tenant_id
- ),
+ credential=self._get_default_azure_credential(),
token_intent="backup",
)
@@ -181,17 +172,11 @@ class AzureFileShareHook(BaseHook):
credential=credential,
)
else:
- conn = self.get_connection(self._conn_id)
- extras = conn.extra_dejson
- managed_identity_client_id =
extras.get("managed_identity_client_id")
- workload_identity_tenant_id =
extras.get("workload_identity_tenant_id")
return ShareFileClient(
account_url=self._account_url,
share_name=self.share_name,
file_path=self.file_path,
- credential=get_default_azure_credential(
- managed_identity_client_id, workload_identity_tenant_id
- ),
+ credential=self._get_default_azure_credential(),
token_intent="backup",
)
diff --git a/airflow/providers/microsoft/azure/hooks/synapse.py
b/airflow/providers/microsoft/azure/hooks/synapse.py
index eb275781bb..a4f0f5d7c3 100644
--- a/airflow/providers/microsoft/azure/hooks/synapse.py
+++ b/airflow/providers/microsoft/azure/hooks/synapse.py
@@ -24,7 +24,11 @@ from azure.synapse.spark import SparkClient
from airflow.exceptions import AirflowTaskTimeout
from airflow.hooks.base import BaseHook
-from airflow.providers.microsoft.azure.utils import
get_default_azure_credential, get_field
+from airflow.providers.microsoft.azure.utils import (
+ add_managed_identity_connection_widgets,
+ get_default_azure_credential,
+ get_field,
+)
if TYPE_CHECKING:
from azure.synapse.spark.models import SparkBatchJobOptions
@@ -63,6 +67,7 @@ class AzureSynapseHook(BaseHook):
hook_name: str = "Azure Synapse"
@staticmethod
+ @add_managed_identity_connection_widgets
def get_connection_form_widgets() -> dict[str, Any]:
"""Returns connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
@@ -72,12 +77,6 @@ class AzureSynapseHook(BaseHook):
return {
"tenantId": StringField(lazy_gettext("Tenant ID"),
widget=BS3TextFieldWidget()),
"subscriptionId": StringField(lazy_gettext("Subscription ID"),
widget=BS3TextFieldWidget()),
- "managed_identity_client_id": StringField(
- lazy_gettext("Managed Identity Client ID"),
widget=BS3TextFieldWidget()
- ),
- "workload_identity_tenant_id": StringField(
- lazy_gettext("Workload Identity Tenant ID"),
widget=BS3TextFieldWidget()
- ),
}
@staticmethod
@@ -89,8 +88,6 @@ class AzureSynapseHook(BaseHook):
"login": "Client ID",
"password": "Secret",
"host": "Synapse Workspace URL",
- "managed_identity_client_id": "Managed Identity Client ID",
- "workload_identity_tenant_id": "Workload Identity Tenant ID",
},
}
diff --git a/airflow/providers/microsoft/azure/hooks/wasb.py
b/airflow/providers/microsoft/azure/hooks/wasb.py
index b69e999c5b..8c7f1636a1 100644
--- a/airflow/providers/microsoft/azure/hooks/wasb.py
+++ b/airflow/providers/microsoft/azure/hooks/wasb.py
@@ -47,7 +47,10 @@ from azure.storage.blob.aio import (
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
-from airflow.providers.microsoft.azure.utils import
get_default_azure_credential
+from airflow.providers.microsoft.azure.utils import (
+ add_managed_identity_connection_widgets,
+ get_default_azure_credential,
+)
if TYPE_CHECKING:
from azure.storage.blob._models import BlobProperties
@@ -78,6 +81,7 @@ class WasbHook(BaseHook):
hook_name = "Azure Blob Storage"
@staticmethod
+ @add_managed_identity_connection_widgets
def get_connection_form_widgets() -> dict[str, Any]:
"""Returns connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget,
BS3TextFieldWidget
@@ -95,12 +99,6 @@ class WasbHook(BaseHook):
lazy_gettext("Tenant Id (Active Directory Auth)"),
widget=BS3TextFieldWidget()
),
"sas_token": PasswordField(lazy_gettext("SAS Token (optional)"),
widget=BS3PasswordFieldWidget()),
- "managed_identity_client_id": StringField(
- lazy_gettext("Managed Identity Client ID"),
widget=BS3TextFieldWidget()
- ),
- "workload_identity_tenant_id": StringField(
- lazy_gettext("Workload Identity Tenant ID"),
widget=BS3TextFieldWidget()
- ),
}
@staticmethod
@@ -122,8 +120,6 @@ class WasbHook(BaseHook):
"shared_access_key": "shared access key",
"sas_token": "account url or token",
"extra": "additional options for use with
ClientSecretCredential or DefaultAzureCredential",
- "managed_identity_client_id": "Managed Identity Client ID",
- "workload_identity_tenant_id": "Workload Identity Tenant ID",
},
}
diff --git a/airflow/providers/microsoft/azure/utils.py
b/airflow/providers/microsoft/azure/utils.py
index f63580e958..c41a4e758c 100644
--- a/airflow/providers/microsoft/azure/utils.py
+++ b/airflow/providers/microsoft/azure/utils.py
@@ -18,6 +18,7 @@
from __future__ import annotations
import warnings
+from functools import wraps
from azure.core.pipeline import PipelineContext, PipelineRequest
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
@@ -69,6 +70,29 @@ def get_default_azure_credential(
return DefaultAzureCredential()
+def add_managed_identity_connection_widgets(func):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
+ from flask_babel import lazy_gettext
+ from wtforms import StringField
+
+ widgets = func(*args, **kwargs)
+ widgets.update(
+ {
+ "managed_identity_client_id": StringField(
+ lazy_gettext("Managed Identity Client ID"),
widget=BS3TextFieldWidget()
+ ),
+ "workload_identity_tenant_id": StringField(
+ lazy_gettext("Workload Identity Tenant ID"),
widget=BS3TextFieldWidget()
+ ),
+ }
+ )
+ return widgets
+
+ return wrapper
+
+
class AzureIdentityCredentialAdapter(BasicTokenAuthentication):
"""Adapt azure-identity credentials for backward compatibility.
diff --git a/tests/providers/microsoft/azure/hooks/test_adx.py
b/tests/providers/microsoft/azure/hooks/test_adx.py
index 39b408af2c..c8de2d933f 100644
--- a/tests/providers/microsoft/azure/hooks/test_adx.py
+++ b/tests/providers/microsoft/azure/hooks/test_adx.py
@@ -34,65 +34,46 @@ pytestmark = pytest.mark.db_test
class TestAzureDataExplorerHook:
@pytest.mark.parametrize(
- "mocked_connection",
- [
- Connection(
- conn_id="missing_method",
- conn_type="azure_data_explorer",
- login="client_id",
- password="client secret",
- host="https://help.kusto.windows.net",
- extra={},
- )
- ],
- indirect=True,
- )
- def test_conn_missing_method(self, mocked_connection):
- hook =
AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id)
- error_pattern = "is missing: `auth_method`"
- with pytest.raises(AirflowException, match=error_pattern):
- assert hook.get_conn()
- with pytest.raises(AirflowException, match=error_pattern):
- assert hook.connection
-
- @pytest.mark.parametrize(
- "mocked_connection",
+ "mocked_connection, error_pattern",
[
- Connection(
- conn_id="unknown_method",
- conn_type="azure_data_explorer",
- login="client_id",
- password="client secret",
- host="https://help.kusto.windows.net",
- extra={"auth_method": "AAD_OTHER"},
+ (
+ Connection(
+ conn_id="missing_method",
+ conn_type="azure_data_explorer",
+ login="client_id",
+ password="client secret",
+ host="https://help.kusto.windows.net",
+ extra={},
+ ),
+ "is missing: `auth_method`",
),
- ],
- indirect=True,
- )
- def test_conn_unknown_method(self, mocked_connection):
- hook =
AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id)
- error_pattern = "Unknown authentication method: AAD_OTHER"
- with pytest.raises(AirflowException, match=error_pattern):
- assert hook.get_conn()
- with pytest.raises(AirflowException, match=error_pattern):
- assert hook.connection
-
- @pytest.mark.parametrize(
- "mocked_connection",
- [
- Connection(
- conn_id="missing_cluster",
- conn_type="azure_data_explorer",
- login="client_id",
- password="client secret",
- extra={},
+ (
+ Connection(
+ conn_id="unknown_method",
+ conn_type="azure_data_explorer",
+ login="client_id",
+ password="client secret",
+ host="https://help.kusto.windows.net",
+ extra={"auth_method": "AAD_OTHER"},
+ ),
+ "Unknown authentication method: AAD_OTHER",
+ ),
+ (
+ Connection(
+ conn_id="missing_cluster",
+ conn_type="azure_data_explorer",
+ login="client_id",
+ password="client secret",
+ extra={},
+ ),
+ "Host connection option is required",
),
],
- indirect=True,
+ indirect=["mocked_connection"],
+ ids=["missing_method", "unknown_method", "missing_cluster"],
)
- def test_conn_missing_cluster(self, mocked_connection):
+ def test_conn_errors(self, mocked_connection, error_pattern):
hook =
AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id)
- error_pattern = "Host connection option is required"
with pytest.raises(AirflowException, match=error_pattern):
assert hook.get_conn()
with pytest.raises(AirflowException, match=error_pattern):
@@ -179,7 +160,7 @@ class TestAzureDataExplorerHook:
},
)
],
- indirect=True,
+ indirect=["mocked_connection"],
)
@mock.patch.object(KustoClient, "__init__")
def test_conn_method_aad_app(self, mock_init, mocked_connection):
@@ -238,6 +219,35 @@ class TestAzureDataExplorerHook:
KustoConnectionStringBuilder.with_aad_device_authentication("https://help.kusto.windows.net")
)
+ @pytest.mark.parametrize(
+ "mocked_connection",
+ [
+ Connection(
+ conn_id=ADX_TEST_CONN_ID,
+ conn_type="azure_data_explorer",
+ host="https://help.kusto.windows.net",
+ extra={
+ "auth_method": "AZURE_TOKEN_CRED",
+ "managed_identity_client_id": "test_id",
+ "workload_identity_tenant_id": "test_tenant_id",
+ },
+ )
+ ],
+ indirect=True,
+ )
+ @mock.patch("airflow.providers.microsoft.azure.hooks.adx.get_default_azure_credential", autospec=True)
+ @mock.patch.object(KustoClient, "__init__")
+ def test_conn_method_azure_token_cred(self, mock_init,
mock_default_azure_credential, mocked_connection):
+ mock_init.return_value = None
+
AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id).get_conn()
+ # `called_with` is not an assertion; autospec lets assert_called_with match positional or keyword calls.
+ mock_default_azure_credential.assert_called_with("test_id", "test_tenant_id")
+ mock_init.assert_called_once()
+ kcsb = mock_init.call_args.args[0]
+ assert isinstance(kcsb, KustoConnectionStringBuilder)
+ assert kcsb.data_source == "https://help.kusto.windows.net"
+ assert mock_default_azure_credential.return_value in vars(kcsb).values()  # credential stored on the builder, whatever its attribute name
+
@pytest.mark.parametrize(
"mocked_connection",
[
diff --git a/tests/providers/microsoft/azure/hooks/test_asb.py
b/tests/providers/microsoft/azure/hooks/test_asb.py
index 4f1939aa8b..195017bfd3 100644
--- a/tests/providers/microsoft/azure/hooks/test_asb.py
+++ b/tests/providers/microsoft/azure/hooks/test_asb.py
@@ -31,6 +31,7 @@ from airflow.providers.microsoft.azure.hooks.asb import
AdminClientHook, Message
MESSAGE = "Test Message"
MESSAGE_LIST = [f"{MESSAGE} {n}" for n in range(10)]
+MODULE = "airflow.providers.microsoft.azure.hooks.asb"
class TestAdminClientHook:
@@ -60,18 +61,18 @@ class TestAdminClientHook:
hook = AdminClientHook(azure_service_bus_conn_id=self.conn_id)
assert isinstance(hook.get_conn(), ServiceBusAdministrationClient)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.get_default_azure_credential")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.AdminClientHook.get_connection")
+ @mock.patch(f"{MODULE}.get_default_azure_credential")
+ @mock.patch(f"{MODULE}.AdminClientHook.get_connection")
def
test_get_conn_fallback_to_default_azure_credential_when_schema_is_not_provided(
self, mock_connection, mock_default_azure_credential
):
mock_connection.return_value = self.mock_conn_without_schema
hook = AdminClientHook(azure_service_bus_conn_id=self.conn_id)
assert isinstance(hook.get_conn(), ServiceBusAdministrationClient)
- mock_default_azure_credential.assert_called_once()
+ mock_default_azure_credential.assert_called_with(None, None)
@mock.patch("azure.servicebus.management.QueueProperties")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.AdminClientHook.get_conn")
+ @mock.patch(f"{MODULE}.AdminClientHook.get_conn")
def test_create_queue(self, mock_sb_admin_client, mock_queue_properties):
"""
Test `create_queue` hook function with mocking connection, queue
properties value and
@@ -85,14 +86,14 @@ class TestAdminClientHook:
response = hook.create_queue(self.queue_name)
assert response == mock_queue_properties
-
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.ServiceBusAdministrationClient")
+ @mock.patch(f"{MODULE}.ServiceBusAdministrationClient")
def test_create_queue_exception(self, mock_sb_admin_client):
"""Test `create_queue` functionality to raise ValueError by passing
queue name as None"""
hook = AdminClientHook(azure_service_bus_conn_id=self.conn_id)
with pytest.raises(TypeError):
hook.create_queue(None)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.AdminClientHook.get_conn")
+ @mock.patch(f"{MODULE}.AdminClientHook.get_conn")
def test_delete_queue(self, mock_sb_admin_client):
"""
Test Delete queue functionality by passing queue name, assert the
function with values,
@@ -103,14 +104,14 @@ class TestAdminClientHook:
expected_calls =
[mock.call().__enter__().delete_queue(self.queue_name)]
mock_sb_admin_client.assert_has_calls(expected_calls)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.ServiceBusAdministrationClient")
+ @mock.patch(f"{MODULE}.ServiceBusAdministrationClient")
def test_delete_queue_exception(self, mock_sb_admin_client):
"""Test `delete_queue` functionality to raise ValueError, by passing
queue name as None"""
hook = AdminClientHook(azure_service_bus_conn_id=self.conn_id)
with pytest.raises(TypeError):
hook.delete_queue(None)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.AdminClientHook.get_conn")
+ @mock.patch(f"{MODULE}.AdminClientHook.get_conn")
def test_delete_subscription(self, mock_sb_admin_client):
"""
Test Delete subscription functionality by passing subscription name
and topic name,
@@ -127,7 +128,7 @@ class TestAdminClientHook:
"mock_subscription_name, mock_topic_name",
[("subscription_1", None), (None, "topic_1")],
)
- @mock.patch("airflow.providers.microsoft.azure.hooks.asb.AdminClientHook")
+ @mock.patch(f"{MODULE}.AdminClientHook")
def test_delete_subscription_exception(
self, mock_sb_admin_client, mock_subscription_name, mock_topic_name
):
@@ -171,15 +172,15 @@ class TestMessageHook:
hook = MessageHook(azure_service_bus_conn_id=self.conn_id)
assert isinstance(hook.get_conn(), ServiceBusClient)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.get_default_azure_credential")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_connection")
+ @mock.patch(f"{MODULE}.get_default_azure_credential")
+ @mock.patch(f"{MODULE}.MessageHook.get_connection")
def
test_get_conn_fallback_to_default_azure_credential_when_schema_is_not_provided(
self, mock_connection, mock_default_azure_credential
):
mock_connection.return_value = self.mock_conn_without_schema
hook = MessageHook(azure_service_bus_conn_id=self.conn_id)
assert isinstance(hook.get_conn(), ServiceBusClient)
- mock_default_azure_credential.assert_called_once()
+ mock_default_azure_credential.assert_called_with(None, None)
@pytest.mark.parametrize(
"mock_message, mock_batch_flag",
@@ -190,9 +191,9 @@ class TestMessageHook:
(MESSAGE_LIST, False),
],
)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.send_list_messages")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.send_batch_message")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_conn")
+ @mock.patch(f"{MODULE}.MessageHook.send_list_messages")
+ @mock.patch(f"{MODULE}.MessageHook.send_batch_message")
+ @mock.patch(f"{MODULE}.MessageHook.get_conn")
def test_send_message(
self, mock_sb_client, mock_batch_message, mock_list_message,
mock_message, mock_batch_flag
):
@@ -225,7 +226,7 @@ class TestMessageHook:
]
mock_sb_client.assert_has_calls(expected_calls, any_order=False)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_conn")
+ @mock.patch(f"{MODULE}.MessageHook.get_conn")
def test_send_message_exception(self, mock_sb_client):
"""
Test `send_message` functionality to raise AirflowException in Azure
MessageHook
@@ -236,7 +237,7 @@ class TestMessageHook:
hook.send_message(queue_name=None, messages="",
batch_message_flag=False)
@mock.patch("azure.servicebus.ServiceBusMessage")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_conn")
+ @mock.patch(f"{MODULE}.MessageHook.get_conn")
def test_receive_message(self, mock_sb_client, mock_service_bus_message):
"""
Test `receive_message` hook function and assert the function with mock
value,
@@ -260,7 +261,7 @@ class TestMessageHook:
]
mock_sb_client.assert_has_calls(expected_calls)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_conn")
+ @mock.patch(f"{MODULE}.MessageHook.get_conn")
def test_receive_message_exception(self, mock_sb_client):
"""
Test `receive_message` functionality to raise AirflowException in
Azure MessageHook
@@ -270,7 +271,7 @@ class TestMessageHook:
with pytest.raises(TypeError):
hook.receive_message(None)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_conn")
+ @mock.patch(f"{MODULE}.MessageHook.get_conn")
def test_receive_subscription_message(self, mock_sb_client):
"""
Test `receive_subscription_message` hook function and assert the
function with mock value,
@@ -299,7 +300,7 @@ class TestMessageHook:
"mock_subscription_name, mock_topic_name, mock_max_count,
mock_wait_time",
[("subscription_1", None, None, None), (None, "topic_1", None, None)],
)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_conn")
+ @mock.patch(f"{MODULE}.MessageHook.get_conn")
def test_receive_subscription_message_exception(
self, mock_sb_client, mock_subscription_name, mock_topic_name,
mock_max_count, mock_wait_time
):
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_batch.py
b/tests/providers/microsoft/azure/hooks/test_azure_batch.py
index 057342edfe..d9daa2f13d 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_batch.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_batch.py
@@ -134,7 +134,7 @@ class TestAzureBatchHook:
)
assert isinstance(pool, batch_models.PoolAddParameter)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient")
+ @mock.patch(f"{MODULE}.BatchServiceClient")
def test_create_pool_with_vm_config(self, mock_batch):
hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id)
mock_instance = mock_batch.return_value.pool.add
@@ -150,7 +150,7 @@ class TestAzureBatchHook:
hook.create_pool(pool=pool)
mock_instance.assert_called_once_with(pool)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient")
+ @mock.patch(f"{MODULE}.BatchServiceClient")
def test_create_pool_with_cloud_config(self, mock_batch):
hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id)
mock_instance = mock_batch.return_value.pool.add
@@ -166,12 +166,12 @@ class TestAzureBatchHook:
hook.create_pool(pool=pool)
mock_instance.assert_called_once_with(pool)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient")
+ @mock.patch(f"{MODULE}.BatchServiceClient")
def test_wait_for_all_nodes(self, mock_batch):
# TODO: Add test
pass
-
@mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient")
+ @mock.patch(f"{MODULE}.BatchServiceClient")
def test_job_configuration_and_create_job(self, mock_batch):
hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id)
mock_instance = mock_batch.return_value.job.add
@@ -180,7 +180,7 @@ class TestAzureBatchHook:
assert isinstance(job, batch_models.JobAddParameter)
mock_instance.assert_called_once_with(job)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient")
+ @mock.patch(f"{MODULE}.BatchServiceClient")
def test_add_single_task_to_job(self, mock_batch):
hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id)
mock_instance = mock_batch.return_value.task.add
@@ -189,12 +189,12 @@ class TestAzureBatchHook:
assert isinstance(task, batch_models.TaskAddParameter)
mock_instance.assert_called_once_with(job_id="myjob", task=task)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient")
+ @mock.patch(f"{MODULE}.BatchServiceClient")
def test_wait_for_all_task_to_complete(self, mock_batch):
# TODO: Add test
pass
-
@mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient")
+ @mock.patch(f"{MODULE}.BatchServiceClient")
def test_connection_success(self, mock_batch):
hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id)
hook.connection.job.return_value = {}
@@ -202,7 +202,7 @@ class TestAzureBatchHook:
assert status is True
assert msg == "Successfully connected to Azure Batch."
-
@mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient")
+ @mock.patch(f"{MODULE}.BatchServiceClient")
def test_connection_failure(self, mock_batch):
hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id)
hook.connection.job.list =
PropertyMock(side_effect=Exception("Authentication failed."))
diff --git
a/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py
b/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py
index 3122ba2b9e..bfe8a57484 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py
@@ -148,7 +148,7 @@ class TestAzureContainerInstanceHookWithoutSetupCredential:
hook =
AzureContainerInstanceHook(azure_conn_id=connection_without_login_password_tenant_id.conn_id)
conn = hook.get_conn()
- mock_default_azure_credential.assert_called_once()
+ mock_default_azure_credential.assert_called_with(None, None)
assert not mock_service_pricipal_credential.called
assert conn == mock_client_instance
mock_client_cls.assert_called_once_with(
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
index 2508cb9dd8..bbb7f7f870 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
@@ -29,6 +29,8 @@ from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.microsoft.azure.hooks.cosmos import AzureCosmosDBHook
+MODULE = "airflow.providers.microsoft.azure.hooks.cosmos"
+
class TestAzureCosmosDbHook:
# Set up an environment to test with
@@ -54,12 +56,40 @@ class TestAzureCosmosDbHook:
)
)
- @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient",
autospec=True)
+ @pytest.mark.parametrize(
+ "mocked_connection",
+ [
+ Connection(
+ conn_id="azure_cosmos_test_default_credential",
+ conn_type="azure_cosmos",
+ login="https://test_endpoint:443",
+ extra={
+ "resource_group_name": "resource-group-name",
+ "subscription_id": "subscription_id",
+ "managed_identity_client_id": "test_client_id",
+ "workload_identity_tenant_id": "test_tenant_id",
+ },
+ )
+ ],
+ indirect=True,
+ )
+ @mock.patch(f"{MODULE}.get_default_azure_credential", autospec=True)
+ @mock.patch(f"{MODULE}.CosmosDBManagementClient")
+ @mock.patch(f"{MODULE}.CosmosClient")
+ def test_get_conn(self, mock_cosmos, mock_cosmos_db,
mock_default_azure_credential, mocked_connection):
+
mock_cosmos_db.return_value.database_accounts.list_keys.return_value.primary_master_key
= "master-key"
+
+ hook =
AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_default_credential")
+ hook.get_conn()
+
+ mock_default_azure_credential.assert_called_with("test_client_id", "test_tenant_id")  # real assertion; autospec matches positional or keyword args
+
+ @mock.patch(f"{MODULE}.CosmosClient", autospec=True)
def test_client(self, mock_cosmos):
hook =
AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
assert isinstance(hook.get_conn(), CosmosClient)
- @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
+ @mock.patch(f"{MODULE}.CosmosClient")
def test_create_database(self, mock_cosmos):
hook =
AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
hook.create_database(self.test_database_name)
@@ -67,19 +97,19 @@ class TestAzureCosmosDbHook:
mock_cosmos.assert_any_call(self.test_end_point, {"masterKey":
self.test_master_key})
mock_cosmos.assert_has_calls(expected_calls)
- @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
+ @mock.patch(f"{MODULE}.CosmosClient")
def test_create_database_exception(self, mock_cosmos):
hook =
AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
with pytest.raises(AirflowException):
hook.create_database(None)
- @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
+ @mock.patch(f"{MODULE}.CosmosClient")
def test_create_container_exception(self, mock_cosmos):
hook =
AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
with pytest.raises(AirflowException):
hook.create_collection(None)
- @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
+ @mock.patch(f"{MODULE}.CosmosClient")
def test_create_container(self, mock_cosmos):
hook =
AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
hook.create_collection(self.test_collection_name,
self.test_database_name)
@@ -91,7 +121,7 @@ class TestAzureCosmosDbHook:
mock_cosmos.assert_any_call(self.test_end_point, {"masterKey":
self.test_master_key})
mock_cosmos.assert_has_calls(expected_calls)
- @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
+ @mock.patch(f"{MODULE}.CosmosClient")
def test_create_container_default(self, mock_cosmos):
hook =
AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
hook.create_collection(self.test_collection_name)
@@ -103,7 +133,7 @@ class TestAzureCosmosDbHook:
mock_cosmos.assert_any_call(self.test_end_point, {"masterKey":
self.test_master_key})
mock_cosmos.assert_has_calls(expected_calls)
- @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
+ @mock.patch(f"{MODULE}.CosmosClient")
def test_upsert_document_default(self, mock_cosmos):
test_id = str(uuid.uuid4())
@@ -124,7 +154,7 @@ class TestAzureCosmosDbHook:
logging.getLogger().info(returned_item)
assert returned_item["id"] == test_id
- @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
+ @mock.patch(f"{MODULE}.CosmosClient")
def test_upsert_document(self, mock_cosmos):
test_id = str(uuid.uuid4())
@@ -152,7 +182,7 @@ class TestAzureCosmosDbHook:
logging.getLogger().info(returned_item)
assert returned_item["id"] == test_id
- @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
+ @mock.patch(f"{MODULE}.CosmosClient")
def test_insert_documents(self, mock_cosmos):
test_id1 = str(uuid.uuid4())
test_id2 = str(uuid.uuid4())
@@ -183,7 +213,7 @@ class TestAzureCosmosDbHook:
mock_cosmos.assert_any_call(self.test_end_point, {"masterKey":
self.test_master_key})
mock_cosmos.assert_has_calls(expected_calls, any_order=True)
- @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
+ @mock.patch(f"{MODULE}.CosmosClient")
def test_delete_database(self, mock_cosmos):
hook =
AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
hook.delete_database(self.test_database_name)
@@ -191,7 +221,7 @@ class TestAzureCosmosDbHook:
mock_cosmos.assert_any_call(self.test_end_point, {"masterKey":
self.test_master_key})
mock_cosmos.assert_has_calls(expected_calls)
- @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
+ @mock.patch(f"{MODULE}.CosmosClient")
def test_delete_database_exception(self, mock_cosmos):
hook =
AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
with pytest.raises(AirflowException):
@@ -203,7 +233,7 @@ class TestAzureCosmosDbHook:
with pytest.raises(AirflowException):
hook.delete_collection(None)
- @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
+ @mock.patch(f"{MODULE}.CosmosClient")
def test_delete_container(self, mock_cosmos):
hook =
AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
hook.delete_collection(self.test_collection_name,
self.test_database_name)
@@ -213,7 +243,7 @@ class TestAzureCosmosDbHook:
mock_cosmos.assert_any_call(self.test_end_point, {"masterKey":
self.test_master_key})
mock_cosmos.assert_has_calls(expected_calls)
- @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
+ @mock.patch(f"{MODULE}.CosmosClient")
def test_delete_container_default(self, mock_cosmos):
hook =
AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
hook.delete_collection(self.test_collection_name)
@@ -223,7 +253,7 @@ class TestAzureCosmosDbHook:
mock_cosmos.assert_any_call(self.test_end_point, {"masterKey":
self.test_master_key})
mock_cosmos.assert_has_calls(expected_calls)
- @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
+ @mock.patch(f"{MODULE}.CosmosClient")
def test_connection_success(self, mock_cosmos):
hook =
AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
hook.get_conn().list_databases.return_value = {"id":
self.test_database_name}
@@ -231,7 +261,7 @@ class TestAzureCosmosDbHook:
assert status is True
assert msg == "Successfully connected to Azure Cosmos."
- @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
+ @mock.patch(f"{MODULE}.CosmosClient")
def test_connection_failure(self, mock_cosmos):
hook =
AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
hook.get_conn().list_databases =
PropertyMock(side_effect=Exception("Authentication failed."))
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
index 34d0bcb481..14a7121202 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
@@ -21,7 +21,6 @@ from unittest import mock
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
-from azure.identity import ClientSecretCredential, DefaultAzureCredential
from azure.mgmt.datafactory.aio import DataFactoryManagementClient
from azure.mgmt.datafactory.models import FactoryListResponse
@@ -55,6 +54,8 @@ MODEL = object()
NAME = "testName"
ID = "testId"
+MODULE = "airflow.providers.microsoft.azure.hooks.data_factory"
+
@pytest.fixture(autouse=True)
def setup_connections(create_mock_connections):
@@ -164,23 +165,30 @@ def test_provide_targeted_factory():
provide_targeted_factory(echo)(hook)
-@pytest.mark.parametrize(
- ("connection_id", "credential_type"),
- [
- (DEFAULT_CONNECTION_CLIENT_SECRET, ClientSecretCredential),
- (DEFAULT_CONNECTION_DEFAULT_CREDENTIAL, DefaultAzureCredential),
- ],
-)
-def test_get_connection_by_credential_client_secret(connection_id: str,
credential_type: type):
- hook = AzureDataFactoryHook(connection_id)
+@mock.patch(f"{MODULE}.ClientSecretCredential")
+def test_get_conn_by_credential_client_secret(mock_credential):
+ hook = AzureDataFactoryHook(DEFAULT_CONNECTION_CLIENT_SECRET)
with patch.object(hook, "_create_client") as mock_create_client:
mock_create_client.return_value = MagicMock()
+
+ connection = hook.get_conn()
+ assert connection is not None
+
+ mock_create_client.assert_called_with(mock_credential.return_value, "subscriptionId")
+
+
+@mock.patch(f"{MODULE}.get_default_azure_credential", autospec=True)
+def test_get_conn_by_default_azure_credential(mock_credential):
+ hook = AzureDataFactoryHook(DEFAULT_CONNECTION_DEFAULT_CREDENTIAL)
+
+ with patch.object(hook, "_create_client") as mock_create_client:
+ mock_create_client.return_value = MagicMock()
+
connection = hook.get_conn()
assert connection is not None
- mock_create_client.assert_called_once()
- assert isinstance(mock_create_client.call_args.args[0],
credential_type)
- assert mock_create_client.call_args.args[1] == "subscriptionId"
+ mock_credential.assert_called_with(None, None)  # real assertion; autospec matches positional or keyword args
+ mock_create_client.assert_called_with(mock_credential.return_value, "subscriptionId")
def test_get_factory(hook: AzureDataFactoryHook):
@@ -520,7 +528,7 @@ def test_connection_failure_missing_tenant_id():
pytest.param("a://?resource_group_name=abc&factory_name=abc",
id="no-prefix"),
],
)
-@patch("airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook.get_conn")
+@patch(f"{MODULE}.AzureDataFactoryHook.get_conn")
def test_provide_targeted_factory_backcompat_prefix_works(mock_connect, uri):
with patch.dict(os.environ, {"AIRFLOW_CONN_MY_CONN": uri}):
hook = AzureDataFactoryHook("my_conn")
@@ -539,8 +547,8 @@ def
test_provide_targeted_factory_backcompat_prefix_works(mock_connect, uri):
pytest.param("a://hi:yo@?tenantId=ten&subscriptionId=sub",
id="no-prefix"),
],
)
-@patch("airflow.providers.microsoft.azure.hooks.data_factory.ClientSecretCredential")
-@patch("airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook._create_client")
+@patch(f"{MODULE}.ClientSecretCredential")
+@patch(f"{MODULE}.AzureDataFactoryHook._create_client")
def test_get_conn_backcompat_prefix_works(mock_create, mock_cred, uri):
with patch.dict(os.environ, {"AIRFLOW_CONN_MY_CONN": uri}):
hook = AzureDataFactoryHook("my_conn")
@@ -549,7 +557,7 @@ def test_get_conn_backcompat_prefix_works(mock_create,
mock_cred, uri):
mock_create.assert_called_with(mock_cred.return_value, "sub")
-@patch("airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook.get_conn")
+@patch(f"{MODULE}.AzureDataFactoryHook.get_conn")
def test_backcompat_prefix_both_prefers_short(mock_connect):
with patch.dict(
os.environ,
@@ -573,8 +581,8 @@ def test_refresh_conn(hook):
class TestAzureDataFactoryAsyncHook:
@pytest.mark.asyncio
-
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
-
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run")
+ @mock.patch(f"{MODULE}.AzureDataFactoryAsyncHook.get_async_conn")
+ @mock.patch(f"{MODULE}.AzureDataFactoryAsyncHook.get_pipeline_run")
async def test_get_adf_pipeline_run_status_queued(self,
mock_get_pipeline_run, mock_conn):
"""Test get_adf_pipeline_run_status function with mocked status"""
mock_status = "Queued"
@@ -584,8 +592,8 @@ class TestAzureDataFactoryAsyncHook:
assert response == mock_status
@pytest.mark.asyncio
-
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
-
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run")
+ @mock.patch(f"{MODULE}.AzureDataFactoryAsyncHook.get_async_conn")
+ @mock.patch(f"{MODULE}.AzureDataFactoryAsyncHook.get_pipeline_run")
async def test_get_adf_pipeline_run_status_inprogress(
self,
mock_get_pipeline_run,
@@ -599,8 +607,8 @@ class TestAzureDataFactoryAsyncHook:
assert response == mock_status
@pytest.mark.asyncio
-
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
-
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run")
+ @mock.patch(f"{MODULE}.AzureDataFactoryAsyncHook.get_async_conn")
+ @mock.patch(f"{MODULE}.AzureDataFactoryAsyncHook.get_pipeline_run")
async def test_get_adf_pipeline_run_status_success(self,
mock_get_pipeline_run, mock_conn):
"""Test get_adf_pipeline_run_status function with mocked status"""
mock_status = "Succeeded"
@@ -610,8 +618,8 @@ class TestAzureDataFactoryAsyncHook:
assert response == mock_status
@pytest.mark.asyncio
-
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
-
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run")
+ @mock.patch(f"{MODULE}.AzureDataFactoryAsyncHook.get_async_conn")
+ @mock.patch(f"{MODULE}.AzureDataFactoryAsyncHook.get_pipeline_run")
async def test_get_adf_pipeline_run_status_failed(self,
mock_get_pipeline_run, mock_conn):
"""Test get_adf_pipeline_run_status function with mocked status"""
mock_status = "Failed"
@@ -621,8 +629,8 @@ class TestAzureDataFactoryAsyncHook:
assert response == mock_status
@pytest.mark.asyncio
-
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
-
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run")
+ @mock.patch(f"{MODULE}.AzureDataFactoryAsyncHook.get_async_conn")
+ @mock.patch(f"{MODULE}.AzureDataFactoryAsyncHook.get_pipeline_run")
async def test_get_adf_pipeline_run_status_cancelled(self,
mock_get_pipeline_run, mock_conn):
"""Test get_adf_pipeline_run_status function with mocked status"""
mock_status = "Cancelled"
@@ -633,8 +641,8 @@ class TestAzureDataFactoryAsyncHook:
@pytest.mark.asyncio
@mock.patch("azure.mgmt.datafactory.models._models_py3.PipelineRun")
-
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection")
-
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
+ @mock.patch(f"{MODULE}.AzureDataFactoryAsyncHook.get_connection")
+ @mock.patch(f"{MODULE}.AzureDataFactoryAsyncHook.get_async_conn")
async def test_get_pipeline_run_exception_without_resource(
self, mock_conn, mock_get_connection, mock_pipeline_run
):
@@ -786,7 +794,7 @@ class TestAzureDataFactoryAsyncHook:
get_field(extras, "non-existent-field", strict=True)
@pytest.mark.asyncio
-
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
+ @mock.patch(f"{MODULE}.AzureDataFactoryAsyncHook.get_async_conn")
async def test_refresh_conn(self, mock_get_async_conn):
"""Test refresh_conn method _conn is reset and get_async_conn is
called"""
hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py
b/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py
index e1412ca4a4..84a0a9420b 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py
@@ -68,7 +68,7 @@ class TestAzureDataLakeHook:
)
)
- @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.lib",
autospec=True)
+ @mock.patch(f"{MODULE}.lib", autospec=True)
def test_conn(self, mock_lib):
from azure.datalake.store import core
@@ -96,12 +96,12 @@ class TestAzureDataLakeHook:
assert hook._conn is None
assert hook.conn_id == "adl_test_key_without_tenant"
assert isinstance(hook.get_conn(), core.AzureDLFileSystem)
- assert mock_azure_identity_credential_adapter.called
+ # ``called_with`` is not a Mock assertion (always-truthy child mock before
+ # Python 3.12, AttributeError from 3.12 on); use the real assertion method.
+ mock_azure_identity_credential_adapter.assert_called_once_with(
+ managed_identity_client_id=None, workload_identity_tenant_id=None
+ )
assert not mock_datalake_store_lib.auth.called
@pytest.mark.usefixtures("connection")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.core.AzureDLFileSystem",
autospec=True)
- @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.lib",
autospec=True)
+ @mock.patch(f"{MODULE}.core.AzureDLFileSystem", autospec=True)
+ @mock.patch(f"{MODULE}.lib", autospec=True)
def test_check_for_blob(self, mock_lib, mock_filesystem):
from airflow.providers.microsoft.azure.hooks.data_lake import
AzureDataLakeHook
@@ -111,8 +111,8 @@ class TestAzureDataLakeHook:
mocked_glob.assert_called()
@pytest.mark.usefixtures("connection")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.multithread.ADLUploader",
autospec=True)
- @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.lib",
autospec=True)
+ @mock.patch(f"{MODULE}.multithread.ADLUploader", autospec=True)
+ @mock.patch(f"{MODULE}.lib", autospec=True)
def test_upload_file(self, mock_lib, mock_uploader):
from airflow.providers.microsoft.azure.hooks.data_lake import
AzureDataLakeHook
@@ -136,8 +136,8 @@ class TestAzureDataLakeHook:
)
@pytest.mark.usefixtures("connection")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.multithread.ADLDownloader",
autospec=True)
- @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.lib",
autospec=True)
+ @mock.patch(f"{MODULE}.multithread.ADLDownloader", autospec=True)
+ @mock.patch(f"{MODULE}.lib", autospec=True)
def test_download_file(self, mock_lib, mock_downloader):
from airflow.providers.microsoft.azure.hooks.data_lake import
AzureDataLakeHook
@@ -161,8 +161,8 @@ class TestAzureDataLakeHook:
)
@pytest.mark.usefixtures("connection")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.core.AzureDLFileSystem",
autospec=True)
- @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.lib",
autospec=True)
+ @mock.patch(f"{MODULE}.core.AzureDLFileSystem", autospec=True)
+ @mock.patch(f"{MODULE}.lib", autospec=True)
def test_list_glob(self, mock_lib, mock_fs):
from airflow.providers.microsoft.azure.hooks.data_lake import
AzureDataLakeHook
@@ -171,8 +171,8 @@ class TestAzureDataLakeHook:
mock_fs.return_value.glob.assert_called_once_with("file_path/*")
@pytest.mark.usefixtures("connection")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.core.AzureDLFileSystem",
autospec=True)
- @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.lib",
autospec=True)
+ @mock.patch(f"{MODULE}.core.AzureDLFileSystem", autospec=True)
+ @mock.patch(f"{MODULE}.lib", autospec=True)
def test_list_walk(self, mock_lib, mock_fs):
from airflow.providers.microsoft.azure.hooks.data_lake import
AzureDataLakeHook
@@ -181,8 +181,8 @@ class TestAzureDataLakeHook:
mock_fs.return_value.walk.assert_called_once_with("file_path/some_folder/")
@pytest.mark.usefixtures("connection")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.core.AzureDLFileSystem",
autospec=True)
- @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.lib",
autospec=True)
+ @mock.patch(f"{MODULE}.core.AzureDLFileSystem", autospec=True)
+ @mock.patch(f"{MODULE}.lib", autospec=True)
def test_remove(self, mock_lib, mock_fs):
from airflow.providers.microsoft.azure.hooks.data_lake import
AzureDataLakeHook
@@ -198,63 +198,57 @@ class TestAzureDataLakeStorageV2Hook:
self.directory_name = "test_directory"
self.file_name = "test_file_name"
-
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook.get_conn")
+ @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_conn")
def test_create_file_system(self, mock_conn):
hook = AzureDataLakeStorageV2Hook(adls_conn_id=self.conn_id)
hook.create_file_system("test_file_system")
expected_calls =
[mock.call().create_file_system(file_system=self.file_system_name)]
mock_conn.assert_has_calls(expected_calls)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.FileSystemClient")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook.get_conn")
+ @mock.patch(f"{MODULE}.FileSystemClient")
+ @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_conn")
def test_get_file_system(self, mock_conn, mock_file_system):
mock_conn.return_value.get_file_system_client.return_value =
mock_file_system
hook = AzureDataLakeStorageV2Hook(adls_conn_id=self.conn_id)
result = hook.get_file_system(self.file_system_name)
assert result == mock_file_system
-
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.DataLakeDirectoryClient")
- @mock.patch(
-
"airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook.get_file_system"
- )
-
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook.get_conn")
+ @mock.patch(f"{MODULE}.DataLakeDirectoryClient")
+ @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_file_system")
+ @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_conn")
def test_create_directory(self, mock_conn, mock_get_file_system,
mock_directory_client):
mock_get_file_system.return_value.create_directory.return_value =
mock_directory_client
hook = AzureDataLakeStorageV2Hook(adls_conn_id=self.conn_id)
result = hook.create_directory(self.file_system_name,
self.directory_name)
assert result == mock_directory_client
-
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.DataLakeDirectoryClient")
- @mock.patch(
-
"airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook.get_file_system"
- )
-
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook.get_conn")
+ @mock.patch(f"{MODULE}.DataLakeDirectoryClient")
+ @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_file_system")
+ @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_conn")
def test_get_directory(self, mock_conn, mock_get_file_system,
mock_directory_client):
mock_get_file_system.return_value.get_directory_client.return_value =
mock_directory_client
hook = AzureDataLakeStorageV2Hook(adls_conn_id=self.conn_id)
result = hook.get_directory_client(self.file_system_name,
self.directory_name)
assert result == mock_directory_client
-
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.DataLakeFileClient")
- @mock.patch(
-
"airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook.get_file_system"
- )
-
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook.get_conn")
+ @mock.patch(f"{MODULE}.DataLakeFileClient")
+ @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_file_system")
+ @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_conn")
def test_create_file(self, mock_conn, mock_get_file_system,
mock_file_client):
mock_get_file_system.return_value.create_file.return_value =
mock_file_client
hook = AzureDataLakeStorageV2Hook(adls_conn_id=self.conn_id)
result = hook.create_file(self.file_system_name, self.file_name)
assert result == mock_file_client
-
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook.get_conn")
+ @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_conn")
def test_delete_file_system(self, mock_conn):
hook = AzureDataLakeStorageV2Hook(adls_conn_id=self.conn_id)
hook.delete_file_system(self.file_system_name)
expected_calls =
[mock.call().delete_file_system(self.file_system_name)]
mock_conn.assert_has_calls(expected_calls)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.DataLakeDirectoryClient")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook.get_conn")
+ @mock.patch(f"{MODULE}.DataLakeDirectoryClient")
+ @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_conn")
def test_delete_directory(self, mock_conn, mock_directory_client):
mock_conn.return_value.get_directory_client.return_value =
mock_directory_client
hook = AzureDataLakeStorageV2Hook(adls_conn_id=self.conn_id)
@@ -267,7 +261,7 @@ class TestAzureDataLakeStorageV2Hook:
]
mock_conn.assert_has_calls(expected_calls)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook.get_conn")
+ @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_conn")
def test_list_file_system(self, mock_conn):
hook = AzureDataLakeStorageV2Hook(adls_conn_id=self.conn_id)
hook.list_file_system(prefix="prefix")
@@ -275,10 +269,8 @@ class TestAzureDataLakeStorageV2Hook:
name_starts_with="prefix", include_metadata=False
)
- @mock.patch(
-
"airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook.get_file_system"
- )
-
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook.get_conn")
+ @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_file_system")
+ @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_conn")
def test_list_files_directory(self, mock_conn, mock_get_file_system):
hook = AzureDataLakeStorageV2Hook(adls_conn_id=self.conn_id)
hook.list_files_directory(self.file_system_name, self.directory_name)
@@ -288,7 +280,7 @@ class TestAzureDataLakeStorageV2Hook:
argnames="list_file_systems_result",
argvalues=[iter([FileSystemProperties]), iter([])],
)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook.get_conn")
+ @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_conn")
def test_connection_success(self, mock_conn, list_file_systems_result):
hook = AzureDataLakeStorageV2Hook(adls_conn_id=self.conn_id)
hook.get_conn().list_file_systems.return_value =
list_file_systems_result
@@ -297,7 +289,7 @@ class TestAzureDataLakeStorageV2Hook:
assert status is True
assert msg == "Successfully connected to ADLS Gen2 Storage."
-
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook.get_conn")
+ @mock.patch(f"{MODULE}.AzureDataLakeStorageV2Hook.get_conn")
def test_connection_failure(self, mock_conn):
hook = AzureDataLakeStorageV2Hook(adls_conn_id=self.conn_id)
hook.get_conn().list_file_systems =
PropertyMock(side_effect=Exception("Authentication failed."))
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py
b/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py
index aa14074363..7ef5262e3e 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py
@@ -26,6 +26,8 @@ from azure.storage.fileshare import DirectoryProperties,
FileProperties, ShareSe
from airflow.models import Connection
from airflow.providers.microsoft.azure.hooks.fileshare import
AzureFileShareHook
+MODULE = "airflow.providers.microsoft.azure.hooks.fileshare"
+
class TestAzureFileshareHook:
@pytest.fixture(autouse=True)
@@ -69,7 +71,7 @@ class TestAzureFileshareHook:
share_client = hook.share_service_client
assert isinstance(share_client, ShareServiceClient)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.fileshare.ShareDirectoryClient",
autospec=True)
+ @mock.patch(f"{MODULE}.ShareDirectoryClient", autospec=True)
def test_check_for_directory(self, mock_service):
mock_instance = mock_service.return_value
mock_instance.exists.return_value = True
@@ -79,7 +81,7 @@ class TestAzureFileshareHook:
assert hook.check_for_directory()
mock_instance.exists.assert_called_once_with()
-
@mock.patch("airflow.providers.microsoft.azure.hooks.fileshare.ShareFileClient",
autospec=True)
+ @mock.patch(f"{MODULE}.ShareFileClient", autospec=True)
def test_load_data(self, mock_service):
mock_instance = mock_service.return_value
hook = AzureFileShareHook(
@@ -88,7 +90,7 @@ class TestAzureFileshareHook:
hook.load_data("big string")
mock_instance.upload_file.assert_called_once_with("big string")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.fileshare.ShareDirectoryClient",
autospec=True)
+ @mock.patch(f"{MODULE}.ShareDirectoryClient", autospec=True)
def test_list_directories_and_files(self, mock_service):
mock_instance = mock_service.return_value
hook = AzureFileShareHook(
@@ -97,7 +99,7 @@ class TestAzureFileshareHook:
hook.list_directories_and_files()
mock_instance.list_directories_and_files.assert_called_once_with()
-
@mock.patch("airflow.providers.microsoft.azure.hooks.fileshare.ShareDirectoryClient",
autospec=True)
+ @mock.patch(f"{MODULE}.ShareDirectoryClient", autospec=True)
def test_list_files(self, mock_service):
mock_instance = mock_service.return_value
mock_instance.list_directories_and_files.return_value = [
@@ -113,7 +115,7 @@ class TestAzureFileshareHook:
assert files == ["file1", "file2"]
mock_instance.list_directories_and_files.assert_called_once_with()
-
@mock.patch("airflow.providers.microsoft.azure.hooks.fileshare.ShareDirectoryClient",
autospec=True)
+ @mock.patch(f"{MODULE}.ShareDirectoryClient", autospec=True)
def test_create_directory(self, mock_service):
mock_instance = mock_service.return_value
hook = AzureFileShareHook(
@@ -122,7 +124,7 @@ class TestAzureFileshareHook:
hook.create_directory()
mock_instance.create_directory.assert_called_once_with()
-
@mock.patch("airflow.providers.microsoft.azure.hooks.fileshare.ShareFileClient",
autospec=True)
+ @mock.patch(f"{MODULE}.ShareFileClient", autospec=True)
def test_get_file(self, mock_service):
mock_instance = mock_service.return_value
hook = AzureFileShareHook(
@@ -131,7 +133,7 @@ class TestAzureFileshareHook:
hook.get_file("path")
mock_instance.download_file.assert_called_once_with()
-
@mock.patch("airflow.providers.microsoft.azure.hooks.fileshare.ShareFileClient",
autospec=True)
+ @mock.patch(f"{MODULE}.ShareFileClient", autospec=True)
def test_get_file_to_stream(self, mock_service):
mock_instance = mock_service.return_value
hook = AzureFileShareHook(
@@ -141,21 +143,21 @@ class TestAzureFileshareHook:
hook.get_file_to_stream(stream=data)
mock_instance.download_file.assert_called_once_with()
-
@mock.patch("airflow.providers.microsoft.azure.hooks.fileshare.ShareServiceClient",
autospec=True)
+ @mock.patch(f"{MODULE}.ShareServiceClient", autospec=True)
def test_create_share(self, mock_service):
mock_instance = mock_service.return_value
hook =
AzureFileShareHook(azure_fileshare_conn_id="azure_fileshare_extras")
hook.create_share(share_name="my_share")
mock_instance.create_share.assert_called_once_with("my_share")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.fileshare.ShareServiceClient",
autospec=True)
+ @mock.patch(f"{MODULE}.ShareServiceClient", autospec=True)
def test_delete_share(self, mock_service):
mock_instance = mock_service.return_value
hook =
AzureFileShareHook(azure_fileshare_conn_id="azure_fileshare_extras")
hook.delete_share("my_share")
mock_instance.delete_share.assert_called_once_with("my_share")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.fileshare.ShareServiceClient",
autospec=True)
+ @mock.patch(f"{MODULE}.ShareServiceClient", autospec=True)
def test_connection_success(self, mock_service):
mock_instance = mock_service.return_value
hook =
AzureFileShareHook(azure_fileshare_conn_id="azure_fileshare_extras")
@@ -164,7 +166,7 @@ class TestAzureFileshareHook:
assert status is True
assert msg == "Successfully connected to Azure File Share."
-
@mock.patch("airflow.providers.microsoft.azure.hooks.fileshare.ShareServiceClient",
autospec=True)
+ @mock.patch(f"{MODULE}.ShareServiceClient", autospec=True)
def test_connection_failure(self, mock_service):
mock_instance = mock_service.return_value
hook =
AzureFileShareHook(azure_fileshare_conn_id="azure_fileshare_extras")
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_synapse.py
b/tests/providers/microsoft/azure/hooks/test_azure_synapse.py
index ae3291fbd4..288a560cfa 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_synapse.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_synapse.py
@@ -19,7 +19,6 @@ from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
-from azure.identity import ClientSecretCredential, DefaultAzureCredential
from azure.synapse.spark import SparkClient
from airflow.models.connection import Connection
@@ -38,6 +37,8 @@ NAME = "testName"
ID = "testId"
JOB_ID = 1
+MODULE = "airflow.providers.microsoft.azure.hooks.synapse"
+
@pytest.fixture(autouse=True)
def setup_connections(create_mock_connections):
@@ -87,23 +88,41 @@ def hook():
return client
[email protected](
- ("connection_id", "credential_type"),
- [
- (DEFAULT_CONNECTION_CLIENT_SECRET, ClientSecretCredential),
- (DEFAULT_CONNECTION_DEFAULT_CREDENTIAL, DefaultAzureCredential),
- ],
-)
-def test_get_connection_by_credential_client_secret(connection_id: str,
credential_type: type):
- hook = AzureSynapseHook(connection_id)
+@patch(f"{MODULE}.ClientSecretCredential")
+def test_get_connection_by_credential_client_secret(mock_credential):
+ """Client-secret connections pass the ClientSecretCredential to _create_client."""
+ hook = AzureSynapseHook(DEFAULT_CONNECTION_CLIENT_SECRET)
with patch.object(hook, "_create_client") as mock_create_client:
mock_create_client.return_value = MagicMock()
+
connection = hook.get_conn()
assert connection is not None
- mock_create_client.assert_called_once()
- assert isinstance(mock_create_client.call_args.args[0],
credential_type)
- assert mock_create_client.call_args.args[-1] == "subscriptionId"
+ # Compare against return_value instead of calling the mock again, which
+ # would record a spurious extra call on the credential mock.
+ mock_create_client.assert_called_once_with(
+ mock_credential.return_value,
+ "https://testsynapse.dev.azuresynapse.net",
+ "",
+ "2022-02-22-preview",
+ "subscriptionId",
+ )
+
+
+@patch(f"{MODULE}.get_default_azure_credential")
+def test_get_conn_by_default_azure_credential(mock_credential):
+ """Connections without a secret fall back to get_default_azure_credential."""
+ hook = AzureSynapseHook(DEFAULT_CONNECTION_DEFAULT_CREDENTIAL)
+
+ with patch.object(hook, "_create_client") as mock_create_client:
+ mock_create_client.return_value = MagicMock()
+
+ connection = hook.get_conn()
+ assert connection is not None
+ # ``called_with`` is not a Mock assertion (always-truthy child mock before
+ # Python 3.12, AttributeError from 3.12 on); use the real assertion method.
+ mock_credential.assert_called_once_with(
+ managed_identity_client_id=None,
+ workload_identity_tenant_id=None,
+ )
+ mock_create_client.assert_called_once_with(
+ mock_credential.return_value,
+ "https://testsynapse.dev.azuresynapse.net",
+ "",
+ "2022-02-22-preview",
+ "subscriptionId",
+ )
def test_run_spark_job(hook: AzureSynapseHook):
diff --git a/tests/providers/microsoft/azure/hooks/test_base_azure.py
b/tests/providers/microsoft/azure/hooks/test_base_azure.py
index 313d61cd11..8e666d2524 100644
--- a/tests/providers/microsoft/azure/hooks/test_base_azure.py
+++ b/tests/providers/microsoft/azure/hooks/test_base_azure.py
@@ -25,6 +25,8 @@ from airflow.providers.microsoft.azure.hooks.base_azure
import AzureBaseHook
pytestmark = pytest.mark.db_test
+MODULE = "airflow.providers.microsoft.azure.hooks.base_azure"
+
class TestBaseAzureHook:
@pytest.mark.parametrize(
@@ -32,7 +34,7 @@ class TestBaseAzureHook:
[Connection(conn_id="azure_default", extra={"key_path":
"key_file.json"})],
indirect=True,
)
-
@patch("airflow.providers.microsoft.azure.hooks.base_azure.get_client_from_auth_file")
+ @patch(f"{MODULE}.get_client_from_auth_file")
def test_get_conn_with_key_path(self, mock_get_client_from_auth_file,
mocked_connection):
mock_get_client_from_auth_file.return_value = "foo-bar"
mock_sdk_client = Mock()
@@ -49,7 +51,7 @@ class TestBaseAzureHook:
[Connection(conn_id="azure_default", extra={"key_json": {"test":
"test"}})],
indirect=True,
)
-
@patch("airflow.providers.microsoft.azure.hooks.base_azure.get_client_from_json_dict")
+ @patch(f"{MODULE}.get_client_from_json_dict")
def test_get_conn_with_key_json(self, mock_get_client_from_json_dict,
mocked_connection):
mock_sdk_client = Mock()
mock_get_client_from_json_dict.return_value = "foo-bar"
@@ -60,7 +62,7 @@ class TestBaseAzureHook:
)
assert auth_sdk_client == "foo-bar"
-
@patch("airflow.providers.microsoft.azure.hooks.base_azure.ServicePrincipalCredentials")
+ @patch(f"{MODULE}.ServicePrincipalCredentials")
@pytest.mark.parametrize(
"mocked_connection",
[
@@ -88,3 +90,41 @@ class TestBaseAzureHook:
subscription_id=mocked_connection.extra_dejson["subscriptionId"],
)
assert auth_sdk_client == "spam-egg"
+
+ @pytest.mark.parametrize(
+ "mocked_connection",
+ [
+ Connection(
+ conn_id="azure_default",
+ extra={
+ "managed_identity_client_id": "test_client_id",
+ "workload_identity_tenant_id": "test_tenant_id",
+ "subscriptionId": "subscription_id",
+ },
+ )
+ ],
+ indirect=True,
+ )
+ # Patch the names where base_azure looks them up; patching
+ # azure.common.credentials would leave the hook's own reference untouched
+ # and make the "not called" assertion below vacuous.
+ @patch(f"{MODULE}.ServicePrincipalCredentials")
+ @patch(f"{MODULE}.AzureIdentityCredentialAdapter")
+ def test_get_conn_fallback_to_azure_identity_credential_adapter(
+ self,
+ mock_credential_adapter,
+ mock_service_principal_credential,
+ mocked_connection,
+ ):
+ """Without client credentials, get_conn uses AzureIdentityCredentialAdapter."""
+ mock_credential = Mock()
+ mock_credential_adapter.return_value = mock_credential
+
+ mock_sdk_client = Mock()
+ AzureBaseHook(mock_sdk_client).get_conn()
+
+ mock_credential_adapter.assert_called_with(
+ managed_identity_client_id="test_client_id",
+ workload_identity_tenant_id="test_tenant_id",
+ )
+ assert not mock_service_principal_credential.called
+ mock_sdk_client.assert_called_once_with(
+ credentials=mock_credential,
+ subscription_id="subscription_id",
+ )
diff --git a/tests/providers/microsoft/azure/hooks/test_wasb.py
b/tests/providers/microsoft/azure/hooks/test_wasb.py
index b5a8717547..7ac4fb1857 100644
--- a/tests/providers/microsoft/azure/hooks/test_wasb.py
+++ b/tests/providers/microsoft/azure/hooks/test_wasb.py
@@ -187,6 +187,7 @@ class TestWasbHook:
)
def test_managed_identity(self, mocked_default_azure_credential,
mocked_blob_service_client):
mocked_default_azure_credential.return_value = "foo-bar"
WasbHook(wasb_conn_id=self.managed_identity_conn_id).get_conn()
+ # Assert after get_conn() so the call can actually be observed; ``called_with``
+ # is not a Mock assertion and would never fail (or raise on Python 3.12+).
+ mocked_default_azure_credential.assert_called_once_with(
+ managed_identity_client_id=None, workload_identity_tenant_id=None
+ )
mocked_blob_service_client.assert_called_once_with(
diff --git a/tests/providers/microsoft/azure/test_utils.py
b/tests/providers/microsoft/azure/test_utils.py
index ad5e3cbc2d..5a081441ca 100644
--- a/tests/providers/microsoft/azure/test_utils.py
+++ b/tests/providers/microsoft/azure/test_utils.py
@@ -17,11 +17,16 @@
from __future__ import annotations
+from typing import Any
from unittest import mock
import pytest
-from airflow.providers.microsoft.azure.utils import
AzureIdentityCredentialAdapter, get_field
+from airflow.providers.microsoft.azure.utils import (
+ AzureIdentityCredentialAdapter,
+ add_managed_identity_connection_widgets,
+ get_field,
+)
MODULE = "airflow.providers.microsoft.azure.utils"
@@ -62,6 +67,16 @@ def test_get_field_non_prefixed(input, expected):
assert value == expected
+def test_add_managed_identity_connection_widgets():
+ def test_func() -> dict[str, Any]:
+ return {}
+
+ widgets = add_managed_identity_connection_widgets(test_func)()
+
+ assert "managed_identity_client_id" in widgets
+ assert "workload_identity_tenant_id" in widgets
+
+
class TestAzureIdentityCredentialAdapter:
@mock.patch(f"{MODULE}.PipelineRequest")
@mock.patch(f"{MODULE}.BearerTokenCredentialPolicy")