This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new faf32539d6 Optimise Airflow DB backend usage in Azure Provider (#33750)
faf32539d6 is described below
commit faf32539d6a1be2bfba1b97e72e4508fb6896af6
Author: Andrey Anshin <[email protected]>
AuthorDate: Sat Aug 26 10:15:18 2023 +0400
Optimise Airflow DB backend usage in Azure Provider (#33750)
---
airflow/providers/microsoft/azure/hooks/adx.py | 15 +-
airflow/providers/microsoft/azure/hooks/batch.py | 23 +-
.../microsoft/azure/hooks/container_instance.py | 6 +-
.../microsoft/azure/hooks/container_registry.py | 10 +-
.../providers/microsoft/azure/hooks/data_lake.py | 15 +-
airflow/providers/microsoft/azure/hooks/wasb.py | 7 +-
.../microsoft/azure/log/wasb_task_handler.py | 1 -
.../providers/microsoft/azure/operators/batch.py | 3 +-
.../microsoft/azure/operators/data_factory.py | 7 +-
.../providers/microsoft/azure/operators/synapse.py | 9 +-
.../microsoft/azure/sensors/data_factory.py | 7 +-
docs/spelling_wordlist.txt | 1 +
tests/providers/microsoft/azure/hooks/test_adx.py | 289 +++++-----
tests/providers/microsoft/azure/hooks/test_asb.py | 42 +-
.../microsoft/azure/hooks/test_azure_batch.py | 35 +-
.../azure/hooks/test_azure_container_instance.py | 17 +-
.../azure/hooks/test_azure_container_registry.py | 15 +-
.../azure/hooks/test_azure_container_volume.py | 31 +-
.../microsoft/azure/hooks/test_azure_cosmos.py | 18 +-
.../azure/hooks/test_azure_data_factory.py | 232 ++++----
.../microsoft/azure/hooks/test_azure_data_lake.py | 14 +-
.../microsoft/azure/hooks/test_azure_fileshare.py | 34 +-
.../microsoft/azure/hooks/test_azure_synapse.py | 83 ++-
.../microsoft/azure/hooks/test_base_azure.py | 72 +--
tests/providers/microsoft/azure/hooks/test_wasb.py | 595 +++++++++------------
.../microsoft/azure/operators/test_azure_batch.py | 33 +-
.../microsoft/azure/operators/test_azure_cosmos.py | 16 +-
.../azure/operators/test_azure_data_factory.py | 10 +-
.../azure/operators/test_azure_synapse.py | 21 +-
tests/providers/microsoft/conftest.py | 68 +++
30 files changed, 872 insertions(+), 857 deletions(-)
diff --git a/airflow/providers/microsoft/azure/hooks/adx.py
b/airflow/providers/microsoft/azure/hooks/adx.py
index f2cfd1a6c7..53da21e396 100644
--- a/airflow/providers/microsoft/azure/hooks/adx.py
+++ b/airflow/providers/microsoft/azure/hooks/adx.py
@@ -26,6 +26,7 @@ This module contains Azure Data Explorer hook.
from __future__ import annotations
import warnings
+from functools import cached_property
from typing import Any
from azure.identity import DefaultAzureCredential
@@ -76,8 +77,8 @@ class AzureDataExplorerHook(BaseHook):
conn_type = "azure_data_explorer"
hook_name = "Azure Data Explorer"
- @staticmethod
- def get_connection_form_widgets() -> dict[str, Any]:
+ @classmethod
+ def get_connection_form_widgets(cls) -> dict[str, Any]:
"""Returns connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget,
BS3TextFieldWidget
from flask_babel import lazy_gettext
@@ -94,8 +95,8 @@ class AzureDataExplorerHook(BaseHook):
),
}
- @staticmethod
- def get_ui_field_behaviour() -> dict[str, Any]:
+ @classmethod
+ def get_ui_field_behaviour(cls) -> dict[str, Any]:
"""Returns custom field behaviour."""
return {
"hidden_fields": ["schema", "port", "extra"],
@@ -116,7 +117,11 @@ class AzureDataExplorerHook(BaseHook):
def __init__(self, azure_data_explorer_conn_id: str = default_conn_name)
-> None:
super().__init__()
self.conn_id = azure_data_explorer_conn_id
- self.connection = self.get_conn() # todo: make this a property, or
just delete
+
+ @cached_property
+ def connection(self) -> KustoClient:
+ """Return a KustoClient object (cached)."""
+ return self.get_conn()
def get_conn(self) -> KustoClient:
"""Return a KustoClient object."""
diff --git a/airflow/providers/microsoft/azure/hooks/batch.py
b/airflow/providers/microsoft/azure/hooks/batch.py
index deca28216d..594725c0da 100644
--- a/airflow/providers/microsoft/azure/hooks/batch.py
+++ b/airflow/providers/microsoft/azure/hooks/batch.py
@@ -19,6 +19,7 @@ from __future__ import annotations
import time
from datetime import timedelta
+from functools import cached_property
from typing import Any
from azure.batch import BatchServiceClient, batch_auth, models as batch_models
@@ -26,7 +27,6 @@ from azure.batch.models import JobAddParameter,
PoolAddParameter, TaskAddParamet
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
-from airflow.models import Connection
from airflow.providers.microsoft.azure.utils import
AzureIdentityCredentialAdapter, get_field
from airflow.utils import timezone
@@ -52,8 +52,8 @@ class AzureBatchHook(BaseHook):
field_name=name,
)
- @staticmethod
- def get_connection_form_widgets() -> dict[str, Any]:
+ @classmethod
+ def get_connection_form_widgets(cls) -> dict[str, Any]:
"""Returns connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
from flask_babel import lazy_gettext
@@ -63,8 +63,8 @@ class AzureBatchHook(BaseHook):
"account_url": StringField(lazy_gettext("Batch Account URL"),
widget=BS3TextFieldWidget()),
}
- @staticmethod
- def get_ui_field_behaviour() -> dict[str, Any]:
+ @classmethod
+ def get_ui_field_behaviour(cls) -> dict[str, Any]:
"""Returns custom field behaviour."""
return {
"hidden_fields": ["schema", "port", "host", "extra"],
@@ -77,20 +77,19 @@ class AzureBatchHook(BaseHook):
def __init__(self, azure_batch_conn_id: str = default_conn_name) -> None:
super().__init__()
self.conn_id = azure_batch_conn_id
- self.connection = self.get_conn()
- def _connection(self) -> Connection:
- """Get connected to Azure Batch service."""
- conn = self.get_connection(self.conn_id)
- return conn
+ @cached_property
+ def connection(self) -> BatchServiceClient:
+ """Get the Batch client connection (cached)."""
+ return self.get_conn()
- def get_conn(self):
+ def get_conn(self) -> BatchServiceClient:
"""
Get the Batch client connection.
:return: Azure Batch client
"""
- conn = self._connection()
+ conn = self.get_connection(self.conn_id)
batch_account_url = self._get_field(conn.extra_dejson, "account_url")
if not batch_account_url:
diff --git a/airflow/providers/microsoft/azure/hooks/container_instance.py
b/airflow/providers/microsoft/azure/hooks/container_instance.py
index 9a0d0ec210..8fc845bf13 100644
--- a/airflow/providers/microsoft/azure/hooks/container_instance.py
+++ b/airflow/providers/microsoft/azure/hooks/container_instance.py
@@ -18,6 +18,7 @@
from __future__ import annotations
import warnings
+from functools import cached_property
from azure.mgmt.containerinstance import ContainerInstanceManagementClient
from azure.mgmt.containerinstance.models import ContainerGroup
@@ -47,7 +48,10 @@ class AzureContainerInstanceHook(AzureBaseHook):
def __init__(self, azure_conn_id: str = default_conn_name) -> None:
super().__init__(sdk_client=ContainerInstanceManagementClient,
conn_id=azure_conn_id)
- self.connection = self.get_conn()
+
+ @cached_property
+ def connection(self):
+ return self.get_conn()
def create_or_update(self, resource_group: str, name: str,
container_group: ContainerGroup) -> None:
"""
diff --git a/airflow/providers/microsoft/azure/hooks/container_registry.py
b/airflow/providers/microsoft/azure/hooks/container_registry.py
index a3298117cc..c1217e3a86 100644
--- a/airflow/providers/microsoft/azure/hooks/container_registry.py
+++ b/airflow/providers/microsoft/azure/hooks/container_registry.py
@@ -18,6 +18,7 @@
"""Hook for Azure Container Registry."""
from __future__ import annotations
+from functools import cached_property
from typing import Any
from azure.mgmt.containerinstance.models import ImageRegistryCredential
@@ -39,8 +40,8 @@ class AzureContainerRegistryHook(BaseHook):
conn_type = "azure_container_registry"
hook_name = "Azure Container Registry"
- @staticmethod
- def get_ui_field_behaviour() -> dict[str, Any]:
+ @classmethod
+ def get_ui_field_behaviour(cls) -> dict[str, Any]:
"""Returns custom field behaviour."""
return {
"hidden_fields": ["schema", "port", "extra"],
@@ -59,7 +60,10 @@ class AzureContainerRegistryHook(BaseHook):
def __init__(self, conn_id: str = "azure_registry") -> None:
super().__init__()
self.conn_id = conn_id
- self.connection = self.get_conn()
+
+ @cached_property
+ def connection(self) -> ImageRegistryCredential:
+ return self.get_conn()
def get_conn(self) -> ImageRegistryCredential:
conn = self.get_connection(self.conn_id)
diff --git a/airflow/providers/microsoft/azure/hooks/data_lake.py
b/airflow/providers/microsoft/azure/hooks/data_lake.py
index 95ef4c6cc2..3849727e86 100644
--- a/airflow/providers/microsoft/azure/hooks/data_lake.py
+++ b/airflow/providers/microsoft/azure/hooks/data_lake.py
@@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations
+from functools import cached_property
from typing import Any
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
@@ -256,8 +257,8 @@ class AzureDataLakeStorageV2Hook(BaseHook):
conn_type = "adls"
hook_name = "Azure Date Lake Storage V2"
- @staticmethod
- def get_connection_form_widgets() -> dict[str, Any]:
+ @classmethod
+ def get_connection_form_widgets(cls) -> dict[str, Any]:
"""Returns connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget,
BS3TextFieldWidget
from flask_babel import lazy_gettext
@@ -272,8 +273,8 @@ class AzureDataLakeStorageV2Hook(BaseHook):
),
}
- @staticmethod
- def get_ui_field_behaviour() -> dict[str, Any]:
+ @classmethod
+ def get_ui_field_behaviour(cls) -> dict[str, Any]:
"""Returns custom field behaviour."""
return {
"hidden_fields": ["schema", "port"],
@@ -296,7 +297,11 @@ class AzureDataLakeStorageV2Hook(BaseHook):
super().__init__()
self.conn_id = adls_conn_id
self.public_read = public_read
- self.service_client = self.get_conn()
+
+ @cached_property
+ def service_client(self) -> DataLakeServiceClient:
+ """Return the DataLakeServiceClient object (cached)."""
+ return self.get_conn()
def get_conn(self) -> DataLakeServiceClient: # type: ignore[override]
"""Return the DataLakeServiceClient object."""
diff --git a/airflow/providers/microsoft/azure/hooks/wasb.py
b/airflow/providers/microsoft/azure/hooks/wasb.py
index 95d4764b67..55c1aba086 100644
--- a/airflow/providers/microsoft/azure/hooks/wasb.py
+++ b/airflow/providers/microsoft/azure/hooks/wasb.py
@@ -27,6 +27,7 @@ from __future__ import annotations
import logging
import os
+from functools import cached_property
from typing import Any, Union
from urllib.parse import urlparse
@@ -123,7 +124,6 @@ class WasbHook(BaseHook):
super().__init__()
self.conn_id = wasb_conn_id
self.public_read = public_read
- self.blob_service_client: BlobServiceClient = self.get_conn()
logger =
logging.getLogger("azure.core.pipeline.policies.http_logging_policy")
try:
@@ -142,6 +142,11 @@ class WasbHook(BaseHook):
return extra_dict[field_name] or None
return extra_dict.get(f"{prefix}{field_name}") or None
+ @cached_property
+ def blob_service_client(self) -> BlobServiceClient:
+ """Return the BlobServiceClient object (cached)."""
+ return self.get_conn()
+
def get_conn(self) -> BlobServiceClient:
"""Return the BlobServiceClient object."""
conn = self.get_connection(self.conn_id)
diff --git a/airflow/providers/microsoft/azure/log/wasb_task_handler.py
b/airflow/providers/microsoft/azure/log/wasb_task_handler.py
index 97a8af5ae1..21e96f1003 100644
--- a/airflow/providers/microsoft/azure/log/wasb_task_handler.py
+++ b/airflow/providers/microsoft/azure/log/wasb_task_handler.py
@@ -67,7 +67,6 @@ class WasbTaskHandler(FileTaskHandler, LoggingMixin):
self.wasb_container = wasb_container
self.remote_base = wasb_log_folder
self.log_relative_path = ""
- self._hook = None
self.closed = False
self.upload_on_close = True
self.delete_local_copy = (
diff --git a/airflow/providers/microsoft/azure/operators/batch.py
b/airflow/providers/microsoft/azure/operators/batch.py
index e26f56dd6e..06122ed1f9 100644
--- a/airflow/providers/microsoft/azure/operators/batch.py
+++ b/airflow/providers/microsoft/azure/operators/batch.py
@@ -179,7 +179,8 @@ class AzureBatchOperator(BaseOperator):
self.should_delete_pool = should_delete_pool
@cached_property
- def hook(self):
+ def hook(self) -> AzureBatchHook:
+ """Create and return an AzureBatchHook (cached)."""
return self.get_hook()
def _check_inputs(self) -> Any:
diff --git a/airflow/providers/microsoft/azure/operators/data_factory.py
b/airflow/providers/microsoft/azure/operators/data_factory.py
index d6b4592e35..12962e5610 100644
--- a/airflow/providers/microsoft/azure/operators/data_factory.py
+++ b/airflow/providers/microsoft/azure/operators/data_factory.py
@@ -18,6 +18,7 @@ from __future__ import annotations
import time
import warnings
+from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence
from airflow.configuration import conf
@@ -159,8 +160,12 @@ class AzureDataFactoryRunPipelineOperator(BaseOperator):
self.check_interval = check_interval
self.deferrable = deferrable
+ @cached_property
+ def hook(self) -> AzureDataFactoryHook:
+ """Create and return an AzureDataFactoryHook (cached)."""
+ return
AzureDataFactoryHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id)
+
def execute(self, context: Context) -> None:
- self.hook =
AzureDataFactoryHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id)
self.log.info("Executing the %s pipeline.", self.pipeline_name)
response = self.hook.run_pipeline(
pipeline_name=self.pipeline_name,
diff --git a/airflow/providers/microsoft/azure/operators/synapse.py
b/airflow/providers/microsoft/azure/operators/synapse.py
index b9d97704c5..dd6dda5555 100644
--- a/airflow/providers/microsoft/azure/operators/synapse.py
+++ b/airflow/providers/microsoft/azure/operators/synapse.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+from functools import cached_property
from typing import TYPE_CHECKING, Sequence
from azure.synapse.spark.models import SparkBatchJobOptions
@@ -73,10 +74,12 @@ class AzureSynapseRunSparkBatchOperator(BaseOperator):
self.timeout = timeout
self.check_interval = check_interval
+ @cached_property
+ def hook(self):
+ """Create and return an AzureSynapseHook (cached)."""
+ return
AzureSynapseHook(azure_synapse_conn_id=self.azure_synapse_conn_id,
spark_pool=self.spark_pool)
+
def execute(self, context: Context) -> None:
- self.hook = AzureSynapseHook(
- azure_synapse_conn_id=self.azure_synapse_conn_id,
spark_pool=self.spark_pool
- )
self.log.info("Executing the Synapse spark job.")
response = self.hook.run_spark_job(payload=self.payload)
self.log.info(response)
diff --git a/airflow/providers/microsoft/azure/sensors/data_factory.py
b/airflow/providers/microsoft/azure/sensors/data_factory.py
index 4caa26a99d..91bc4072c0 100644
--- a/airflow/providers/microsoft/azure/sensors/data_factory.py
+++ b/airflow/providers/microsoft/azure/sensors/data_factory.py
@@ -18,6 +18,7 @@ from __future__ import annotations
import warnings
from datetime import timedelta
+from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence
from airflow.configuration import conf
@@ -72,8 +73,12 @@ class
AzureDataFactoryPipelineRunStatusSensor(BaseSensorOperator):
self.deferrable = deferrable
+ @cached_property
+ def hook(self):
+ """Create and return an AzureDataFactoryHook (cached)."""
+ return
AzureDataFactoryHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id)
+
def poke(self, context: Context) -> bool:
- self.hook =
AzureDataFactoryHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id)
pipeline_run_status = self.hook.get_pipeline_run_status(
run_id=self.run_id,
resource_group_name=self.resource_group_name,
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index ec01067b12..eb80939a4c 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -154,6 +154,7 @@ BaseView
BaseXCom
bashrc
batchGet
+BatchServiceClient
bc
bcc
bdist
diff --git a/tests/providers/microsoft/azure/hooks/test_adx.py
b/tests/providers/microsoft/azure/hooks/test_adx.py
index 4cf61c440b..2268f8b8a6 100644
--- a/tests/providers/microsoft/azure/hooks/test_adx.py
+++ b/tests/providers/microsoft/azure/hooks/test_adx.py
@@ -17,10 +17,7 @@
# under the License.
from __future__ import annotations
-import json
-import os
from unittest import mock
-from unittest.mock import patch
import pytest
from azure.kusto.data import ClientRequestProperties, KustoClient,
KustoConnectionStringBuilder
@@ -29,196 +26,220 @@ from pytest import param
from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.microsoft.azure.hooks.adx import AzureDataExplorerHook
-from airflow.utils import db
-from airflow.utils.session import create_session
from tests.test_utils.providers import get_provider_min_airflow_version
ADX_TEST_CONN_ID = "adx_test_connection_id"
class TestAzureDataExplorerHook:
- def teardown_method(self):
- with create_session() as session:
- session.query(Connection).filter(Connection.conn_id ==
ADX_TEST_CONN_ID).delete()
-
- def test_conn_missing_method(self):
- db.merge_conn(
+ @pytest.mark.parametrize(
+ "mocked_connection",
+ [
Connection(
- conn_id=ADX_TEST_CONN_ID,
+ conn_id="missing_method",
conn_type="azure_data_explorer",
login="client_id",
password="client secret",
host="https://help.kusto.windows.net",
- extra=json.dumps({}),
+ extra={},
)
- )
- with pytest.raises(AirflowException) as ctx:
- AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID)
- assert "is missing: `data_explorer__auth_method`" in str(ctx.value)
+ ],
+ indirect=True,
+ )
+ def test_conn_missing_method(self, mocked_connection):
+ hook =
AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id)
+ error_pattern = "is missing: `auth_method`"
+ with pytest.raises(AirflowException, match=error_pattern):
+ assert hook.get_conn()
+ with pytest.raises(AirflowException, match=error_pattern):
+ assert hook.connection
- def test_conn_unknown_method(self):
- db.merge_conn(
+ @pytest.mark.parametrize(
+ "mocked_connection",
+ [
Connection(
- conn_id=ADX_TEST_CONN_ID,
+ conn_id="unknown_method",
conn_type="azure_data_explorer",
login="client_id",
password="client secret",
host="https://help.kusto.windows.net",
- extra=json.dumps({"auth_method": "AAD_OTHER"}),
- )
- )
- with pytest.raises(AirflowException) as ctx:
- AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID)
- assert "Unknown authentication method: AAD_OTHER" in str(ctx.value)
+ extra={"auth_method": "AAD_OTHER"},
+ ),
+ ],
+ indirect=True,
+ )
+ def test_conn_unknown_method(self, mocked_connection):
+ hook =
AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id)
+ error_pattern = "Unknown authentication method: AAD_OTHER"
+ with pytest.raises(AirflowException, match=error_pattern):
+ assert hook.get_conn()
+ with pytest.raises(AirflowException, match=error_pattern):
+ assert hook.connection
- def test_conn_missing_cluster(self):
- db.merge_conn(
+ @pytest.mark.parametrize(
+ "mocked_connection",
+ [
Connection(
- conn_id=ADX_TEST_CONN_ID,
+ conn_id="missing_cluster",
conn_type="azure_data_explorer",
login="client_id",
password="client secret",
- extra=json.dumps({}),
- )
- )
- with pytest.raises(AirflowException) as ctx:
- AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID)
- assert "Host connection option is required" in str(ctx.value)
+ extra={},
+ ),
+ ],
+ indirect=True,
+ )
+ def test_conn_missing_cluster(self, mocked_connection):
+ hook =
AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id)
+ error_pattern = "Host connection option is required"
+ with pytest.raises(AirflowException, match=error_pattern):
+ assert hook.get_conn()
+ with pytest.raises(AirflowException, match=error_pattern):
+ assert hook.connection
- @mock.patch.object(KustoClient, "__init__")
- def test_conn_method_aad_creds(self, mock_init):
- mock_init.return_value = None
- db.merge_conn(
+ @pytest.mark.parametrize(
+ "mocked_connection",
+ [
Connection(
- conn_id=ADX_TEST_CONN_ID,
+ conn_id="method_aad_creds",
conn_type="azure_data_explorer",
login="client_id",
password="client secret",
host="https://help.kusto.windows.net",
- extra=json.dumps(
- {
- "tenant": "tenant",
- "auth_method": "AAD_CREDS",
- }
- ),
+ extra={"tenant": "tenant", "auth_method": "AAD_CREDS"},
)
- )
- AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID)
+ ],
+ indirect=True,
+ )
+ @mock.patch.object(KustoClient, "__init__")
+ def test_conn_method_aad_creds(self, mock_init, mocked_connection):
+ mock_init.return_value = None
+
AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id).get_conn()
assert mock_init.called_with(
KustoConnectionStringBuilder.with_aad_user_password_authentication(
"https://help.kusto.windows.net", "client_id", "client
secret", "tenant"
)
)
-
@mock.patch("azure.identity._credentials.environment.ClientSecretCredential")
- def test_conn_method_token_creds(self, mock1):
- db.merge_conn(
+ @pytest.mark.parametrize(
+ "mocked_connection",
+ [
Connection(
- conn_id=ADX_TEST_CONN_ID,
+ conn_id="method_token_creds",
conn_type="azure_data_explorer",
host="https://help.kusto.windows.net",
- extra=json.dumps(
- {
- "auth_method": "AZURE_TOKEN_CRED",
- }
- ),
- )
+ extra={
+ "auth_method": "AZURE_TOKEN_CRED",
+ },
+ ),
+ ],
+ indirect=True,
+ )
+
@mock.patch("azure.identity._credentials.environment.ClientSecretCredential")
+ def test_conn_method_token_creds(self, mock1, mocked_connection,
monkeypatch):
+ hook =
AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id)
+
+ monkeypatch.setenv("AZURE_TENANT_ID", "tenant")
+ monkeypatch.setenv("AZURE_CLIENT_ID", "client")
+ monkeypatch.setenv("AZURE_CLIENT_SECRET", "secret")
+
+ assert hook.connection._kcsb.data_source ==
"https://help.kusto.windows.net"
+ mock1.assert_called_once_with(
+ tenant_id="tenant",
+ client_id="client",
+ client_secret="secret",
+ authority="https://login.microsoftonline.com",
)
- with patch.dict(
- in_dict=os.environ,
- values={
- "AZURE_TENANT_ID": "tenant",
- "AZURE_CLIENT_ID": "client",
- "AZURE_CLIENT_SECRET": "secret",
- },
- ):
- hook =
AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID)
- assert hook.connection._kcsb.data_source ==
"https://help.kusto.windows.net"
- mock1.assert_called_once_with(
- tenant_id="tenant",
- client_id="client",
- client_secret="secret",
- authority="https://login.microsoftonline.com",
- )
- @mock.patch.object(KustoClient, "__init__")
- def test_conn_method_aad_app(self, mock_init):
- mock_init.return_value = None
- db.merge_conn(
+ @pytest.mark.parametrize(
+ "mocked_connection",
+ [
Connection(
- conn_id=ADX_TEST_CONN_ID,
+ conn_id="method_aad_app",
conn_type="azure_data_explorer",
login="app_id",
password="app key",
host="https://help.kusto.windows.net",
- extra=json.dumps(
- {
- "tenant": "tenant",
- "auth_method": "AAD_APP",
- }
- ),
+ extra={
+ "tenant": "tenant",
+ "auth_method": "AAD_APP",
+ },
)
- )
- AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID)
+ ],
+ indirect=True,
+ )
+ @mock.patch.object(KustoClient, "__init__")
+ def test_conn_method_aad_app(self, mock_init, mocked_connection):
+ mock_init.return_value = None
+
AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id).get_conn()
assert mock_init.called_with(
KustoConnectionStringBuilder.with_aad_application_key_authentication(
"https://help.kusto.windows.net", "app_id", "app key", "tenant"
)
)
- @mock.patch.object(KustoClient, "__init__")
- def test_conn_method_aad_app_cert(self, mock_init):
- mock_init.return_value = None
- db.merge_conn(
+ @pytest.mark.parametrize(
+ "mocked_connection",
+ [
Connection(
- conn_id=ADX_TEST_CONN_ID,
+ conn_id="method_aad_app",
conn_type="azure_data_explorer",
- login="client_id",
+ login="app_id",
+ password="app key",
host="https://help.kusto.windows.net",
- extra=json.dumps(
- {
- "tenant": "tenant",
- "auth_method": "AAD_APP_CERT",
- "certificate": "PEM",
- "thumbprint": "thumbprint",
- }
- ),
+ extra={
+ "tenant": "tenant",
+ "auth_method": "AAD_APP",
+ },
)
- )
- AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID)
+ ],
+ indirect=True,
+ )
+ @mock.patch.object(KustoClient, "__init__")
+ def test_conn_method_aad_app_cert(self, mock_init, mocked_connection):
+ mock_init.return_value = None
+
AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id).get_conn()
assert mock_init.called_with(
KustoConnectionStringBuilder.with_aad_application_certificate_authentication(
"https://help.kusto.windows.net", "client_id", "PEM",
"thumbprint", "tenant"
)
)
- @mock.patch.object(KustoClient, "__init__")
- def test_conn_method_aad_device(self, mock_init):
- mock_init.return_value = None
- db.merge_conn(
+ @pytest.mark.parametrize(
+ "mocked_connection",
+ [
Connection(
conn_id=ADX_TEST_CONN_ID,
conn_type="azure_data_explorer",
host="https://help.kusto.windows.net",
- extra=json.dumps({"auth_method": "AAD_DEVICE"}),
+ extra={"auth_method": "AAD_DEVICE"},
)
- )
- AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID)
+ ],
+ indirect=True,
+ )
+ @mock.patch.object(KustoClient, "__init__")
+ def test_conn_method_aad_device(self, mock_init, mocked_connection):
+ mock_init.return_value = None
+
AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id).get_conn()
assert mock_init.called_with(
KustoConnectionStringBuilder.with_aad_device_authentication("https://help.kusto.windows.net")
)
- @mock.patch.object(KustoClient, "execute")
- def test_run_query(self, mock_execute):
- mock_execute.return_value = None
- db.merge_conn(
+ @pytest.mark.parametrize(
+ "mocked_connection",
+ [
Connection(
conn_id=ADX_TEST_CONN_ID,
conn_type="azure_data_explorer",
host="https://help.kusto.windows.net",
- extra=json.dumps({"auth_method": "AAD_DEVICE"}),
+ extra={"auth_method": "AAD_DEVICE"},
)
- )
+ ],
+ indirect=True,
+ )
+ @mock.patch.object(KustoClient, "execute")
+ def test_run_query(self, mock_execute, mocked_connection):
+ mock_execute.return_value = None
hook =
AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID)
hook.run_query("Database", "Logs | schema", options={"option1":
"option_value"})
properties = ClientRequestProperties()
@@ -246,7 +267,7 @@ class TestAzureDataExplorerHook:
)
@pytest.mark.parametrize(
- "uri",
+ "mocked_connection",
[
param(
"a://usr:pw@host?extra__azure_data_explorer__tenant=my-tenant"
@@ -255,24 +276,28 @@ class TestAzureDataExplorerHook:
),
param("a://usr:pw@host?tenant=my-tenant&auth_method=AAD_APP",
id="no-prefix"),
],
+ indirect=True,
)
- def test_backcompat_prefix_works(self, uri):
- with patch.dict(os.environ, AIRFLOW_CONN_MY_CONN=uri):
- hook =
AzureDataExplorerHook(azure_data_explorer_conn_id="my_conn") # get_conn is
called in init
+ def test_backcompat_prefix_works(self, mocked_connection):
+ hook =
AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id)
assert hook.connection._kcsb.data_source == "host"
assert hook.connection._kcsb.application_client_id == "usr"
assert hook.connection._kcsb.application_key == "pw"
assert hook.connection._kcsb.authority_id == "my-tenant"
- def test_backcompat_prefix_both_causes_warning(self):
- with patch.dict(
- in_dict=os.environ,
-
AIRFLOW_CONN_MY_CONN="a://usr:pw@host?tenant=my-tenant&auth_method=AAD_APP"
- "&extra__azure_data_explorer__auth_method=AAD_APP",
- ):
- with pytest.warns(Warning, match="Using value for `auth_method`"):
- hook =
AzureDataExplorerHook(azure_data_explorer_conn_id="my_conn")
- assert hook.connection._kcsb.data_source == "host"
- assert hook.connection._kcsb.application_client_id == "usr"
- assert hook.connection._kcsb.application_key == "pw"
- assert hook.connection._kcsb.authority_id == "my-tenant"
+ @pytest.mark.parametrize(
+ "mocked_connection",
+ [
+ (
+ "a://usr:pw@host?tenant=my-tenant&auth_method=AAD_APP"
+ "&extra__azure_data_explorer__auth_method=AAD_APP"
+ )
+ ],
+ indirect=True,
+ )
+ def test_backcompat_prefix_both_causes_warning(self, mocked_connection):
+ hook =
AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id)
+ assert hook.connection._kcsb.data_source == "host"
+ assert hook.connection._kcsb.application_client_id == "usr"
+ assert hook.connection._kcsb.application_key == "pw"
+ assert hook.connection._kcsb.authority_id == "my-tenant"
diff --git a/tests/providers/microsoft/azure/hooks/test_asb.py
b/tests/providers/microsoft/azure/hooks/test_asb.py
index a9a3851561..5f626d6c29 100644
--- a/tests/providers/microsoft/azure/hooks/test_asb.py
+++ b/tests/providers/microsoft/azure/hooks/test_asb.py
@@ -34,22 +34,23 @@ MESSAGE_LIST = [f"{MESSAGE} {n}" for n in range(0, 10)]
class TestAdminClientHook:
- def setup_class(self) -> None:
- self.queue_name: str = "test_queue"
- self.conn_id: str = "azure_service_bus_default"
+ @pytest.fixture(autouse=True)
+ def setup_test_cases(self, create_mock_connection):
+ self.queue_name = "test_queue"
+ self.conn_id = "azure_service_bus_default"
self.connection_string = (
"Endpoint=sb://test-service-bus-provider.servicebus.windows.net/;"
"SharedAccessKeyName=Test;SharedAccessKey=1234566acbc"
)
- self.mock_conn = Connection(
- conn_id="azure_service_bus_default",
- conn_type="azure_service_bus",
- schema=self.connection_string,
+ self.mock_conn = create_mock_connection(
+ Connection(
+ conn_id=self.conn_id,
+ conn_type="azure_service_bus",
+ schema=self.connection_string,
+ )
)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.AdminClientHook.get_connection")
- def test_get_conn(self, mock_connection):
- mock_connection.return_value = self.mock_conn
+ def test_get_conn(self):
hook = AdminClientHook(azure_service_bus_conn_id=self.conn_id)
assert isinstance(hook.get_conn(), ServiceBusAdministrationClient)
@@ -124,26 +125,27 @@ class TestAdminClientHook:
class TestMessageHook:
- def setup_class(self) -> None:
- self.queue_name: str = "test_queue"
- self.conn_id: str = "azure_service_bus_default"
+ @pytest.fixture(autouse=True)
+ def setup_test_cases(self, create_mock_connection):
+ self.queue_name = "test_queue"
+ self.conn_id = "azure_service_bus_default"
self.connection_string = (
"Endpoint=sb://test-service-bus-provider.servicebus.windows.net/;"
"SharedAccessKeyName=Test;SharedAccessKey=1234566acbc"
)
- self.conn = Connection(
- conn_id="azure_service_bus_default",
- conn_type="azure_service_bus",
- schema=self.connection_string,
+ self.mock_conn = create_mock_connection(
+ Connection(
+ conn_id=self.conn_id,
+ conn_type="azure_service_bus",
+ schema=self.connection_string,
+ )
)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_connection")
- def test_get_service_bus_message_conn(self, mock_connection):
+ def test_get_service_bus_message_conn(self):
"""
Test get_conn() function and check whether the get_conn() function
returns value
is instance of ServiceBusClient
"""
- mock_connection.return_value = self.conn
hook = MessageHook(azure_service_bus_conn_id=self.conn_id)
assert isinstance(hook.get_conn(), ServiceBusClient)
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_batch.py
b/tests/providers/microsoft/azure/hooks/test_azure_batch.py
index a3a421f5a0..cd5ab10134 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_batch.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_batch.py
@@ -17,22 +17,21 @@
# under the License.
from __future__ import annotations
-import json
from unittest import mock
from unittest.mock import PropertyMock
+import pytest
from azure.batch import BatchServiceClient, models as batch_models
from airflow.models import Connection
from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook
-from airflow.utils import db
MODULE = "airflow.providers.microsoft.azure.hooks.batch"
class TestAzureBatchHook:
- # set up the test environment
- def setup_method(self):
+ @pytest.fixture(autouse=True)
+ def setup_test_cases(self, create_mock_connections):
# set up the test variable
self.test_vm_conn_id = "test_azure_batch_vm"
self.test_cloud_conn_id = "test_azure_batch_cloud"
@@ -47,27 +46,27 @@ class TestAzureBatchHook:
self.test_cloud_os_version = "test-version"
self.test_node_agent_sku = "test-node-agent-sku"
- # connect with vm configuration
- db.merge_conn(
+ create_mock_connections(
+ # connect with vm configuration
Connection(
conn_id=self.test_vm_conn_id,
- conn_type="azure_batch",
- extra=json.dumps({"account_url": self.test_account_url}),
- )
- )
- # connect with cloud service
- db.merge_conn(
+ conn_type="azure-batch",
+ extra={"account_url": self.test_account_url},
+ ),
+ # connect with cloud service
Connection(
conn_id=self.test_cloud_conn_id,
- conn_type="azure_batch",
- extra=json.dumps({"account_url": self.test_account_url}),
- )
+ conn_type="azure-batch",
+ extra={"account_url": self.test_account_url},
+ ),
)
def test_connection_and_client(self):
hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id)
- assert isinstance(hook._connection(), Connection)
assert isinstance(hook.get_conn(), BatchServiceClient)
+ conn = hook.connection
+ assert isinstance(conn, BatchServiceClient)
+ assert hook.connection is conn, "`connection` property should be
cached"
@mock.patch(f"{MODULE}.batch_auth.SharedKeyCredentials")
@mock.patch(f"{MODULE}.AzureIdentityCredentialAdapter")
@@ -195,7 +194,7 @@ class TestAzureBatchHook:
@mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient")
def test_connection_success(self, mock_batch):
hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id)
- hook.get_conn().job.return_value = {}
+ hook.connection.job.return_value = {}
status, msg = hook.test_connection()
assert status is True
assert msg == "Successfully connected to Azure Batch."
@@ -203,7 +202,7 @@ class TestAzureBatchHook:
@mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient")
def test_connection_failure(self, mock_batch):
hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id)
- hook.get_conn().job.list =
PropertyMock(side_effect=Exception("Authentication failed."))
+ hook.connection.job.list =
PropertyMock(side_effect=Exception("Authentication failed."))
status, msg = hook.test_connection()
assert status is False
assert msg == "Authentication failed."
diff --git
a/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py
b/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py
index a10dde9c65..786df4eb16 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py
@@ -17,9 +17,9 @@
# under the License.
from __future__ import annotations
-import json
from unittest.mock import patch
+import pytest
from azure.mgmt.containerinstance.models import (
Container,
ContainerGroup,
@@ -30,27 +30,26 @@ from azure.mgmt.containerinstance.models import (
from airflow.models import Connection
from airflow.providers.microsoft.azure.hooks.container_instance import
AzureContainerInstanceHook
-from airflow.utils import db
class TestAzureContainerInstanceHook:
- def setup_method(self):
- db.merge_conn(
+ @pytest.fixture(autouse=True)
+ def setup_test_cases(self, create_mock_connection):
+ mock_connection = create_mock_connection(
Connection(
conn_id="azure_container_instance_test",
conn_type="azure_container_instances",
login="login",
password="key",
- extra=json.dumps({"tenantId": "tenant_id", "subscriptionId":
"subscription_id"}),
+ extra={"tenantId": "tenant_id", "subscriptionId":
"subscription_id"},
)
)
-
self.resources =
ResourceRequirements(requests=ResourceRequests(memory_in_gb="4", cpu="1"))
- with patch(
+ self.hook =
AzureContainerInstanceHook(azure_conn_id=mock_connection.conn_id)
+ with
patch("azure.mgmt.containerinstance.ContainerInstanceManagementClient"), patch(
"azure.common.credentials.ServicePrincipalCredentials.__init__",
autospec=True, return_value=None
):
- with
patch("azure.mgmt.containerinstance.ContainerInstanceManagementClient"):
- self.hook =
AzureContainerInstanceHook(azure_conn_id="azure_container_instance_test")
+ yield
@patch("azure.mgmt.containerinstance.models.ContainerGroup")
@patch("azure.mgmt.containerinstance.operations.ContainerGroupsOperations.create_or_update")
diff --git
a/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py
b/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py
index 91ae933d2b..38f326d298 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py
@@ -17,14 +17,16 @@
# under the License.
from __future__ import annotations
+import pytest
+
from airflow.models import Connection
from airflow.providers.microsoft.azure.hooks.container_registry import
AzureContainerRegistryHook
-from airflow.utils import db
class TestAzureContainerRegistryHook:
- def test_get_conn(self):
- db.merge_conn(
+ @pytest.mark.parametrize(
+ "mocked_connection",
+ [
Connection(
conn_id="azure_container_registry",
conn_type="azure_container_registry",
@@ -32,8 +34,11 @@ class TestAzureContainerRegistryHook:
password="password",
host="test.cr",
)
- )
- hook = AzureContainerRegistryHook(conn_id="azure_container_registry")
+ ],
+ indirect=True,
+ )
+ def test_get_conn(self, mocked_connection):
+ hook = AzureContainerRegistryHook(conn_id=mocked_connection.conn_id)
assert hook.connection is not None
assert hook.connection.username == "myuser"
assert hook.connection.password == "password"
diff --git
a/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py
b/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py
index 3dcbfd9a89..b4c7b8d1c7 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py
@@ -17,22 +17,25 @@
# under the License.
from __future__ import annotations
-import json
+import pytest
from airflow.models import Connection
from airflow.providers.microsoft.azure.hooks.container_volume import
AzureContainerVolumeHook
-from airflow.utils import db
from tests.test_utils.providers import get_provider_min_airflow_version
class TestAzureContainerVolumeHook:
- def test_get_file_volume(self):
- db.merge_conn(
+ @pytest.mark.parametrize(
+ "mocked_connection",
+ [
Connection(
conn_id="azure_container_test_connection", conn_type="wasb",
login="login", password="key"
)
- )
- hook =
AzureContainerVolumeHook(azure_container_volume_conn_id="azure_container_test_connection")
+ ],
+ indirect=True,
+ )
+ def test_get_file_volume(self, mocked_connection):
+ hook =
AzureContainerVolumeHook(azure_container_volume_conn_id=mocked_connection.conn_id)
volume = hook.get_file_volume(
mount_name="mount", share_name="share",
storage_account_name="storage", read_only=True
)
@@ -43,19 +46,21 @@ class TestAzureContainerVolumeHook:
assert volume.azure_file.storage_account_name == "storage"
assert volume.azure_file.read_only is True
- def test_get_file_volume_connection_string(self):
- db.merge_conn(
+ @pytest.mark.parametrize(
+ "mocked_connection",
+ [
Connection(
conn_id="azure_container_test_connection_connection_string",
conn_type="wasb",
login="login",
password="key",
- extra=json.dumps({"connection_string": "a=b;AccountKey=1"}),
+ extra={"connection_string": "a=b;AccountKey=1"},
)
- )
- hook = AzureContainerVolumeHook(
-
azure_container_volume_conn_id="azure_container_test_connection_connection_string"
- )
+ ],
+ indirect=True,
+ )
+ def test_get_file_volume_connection_string(self, mocked_connection):
+ hook =
AzureContainerVolumeHook(azure_container_volume_conn_id=mocked_connection.conn_id)
volume = hook.get_file_volume(
mount_name="mount", share_name="share",
storage_account_name="storage", read_only=True
)
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
index af649f1d6e..f63b8e8dbd 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import json
import logging
import uuid
from unittest import mock
@@ -29,14 +28,14 @@ from azure.cosmos.cosmos_client import CosmosClient
from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.microsoft.azure.hooks.cosmos import AzureCosmosDBHook
-from airflow.utils import db
from tests.test_utils.providers import get_provider_min_airflow_version
class TestAzureCosmosDbHook:
# Set up an environment to test with
- def setup_method(self):
+ @pytest.fixture(autouse=True)
+ def setup_test_cases(self, create_mock_connection):
# set up some test variables
self.test_end_point = "https://test_endpoint:443"
self.test_master_key = "magic_test_key"
@@ -44,25 +43,22 @@ class TestAzureCosmosDbHook:
self.test_collection_name = "test_collection_name"
self.test_database_default = "test_database_default"
self.test_collection_default = "test_collection_default"
- db.merge_conn(
+ create_mock_connection(
Connection(
conn_id="azure_cosmos_test_key_id",
conn_type="azure_cosmos",
login=self.test_end_point,
password=self.test_master_key,
- extra=json.dumps(
- {
- "database_name": self.test_database_default,
- "collection_name": self.test_collection_default,
- }
- ),
+ extra={
+ "database_name": self.test_database_default,
+ "collection_name": self.test_collection_default,
+ },
)
)
@mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient",
autospec=True)
def test_client(self, mock_cosmos):
hook =
AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
- assert hook._conn is None
assert isinstance(hook.get_conn(), CosmosClient)
@mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
index 4ca2eb4920..63a22614dc 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
@@ -16,7 +16,6 @@
# under the License.
from __future__ import annotations
-import json
import os
from unittest import mock
from unittest.mock import MagicMock, PropertyMock, patch
@@ -25,9 +24,9 @@ import pytest
from azure.identity import ClientSecretCredential, DefaultAzureCredential
from azure.mgmt.datafactory.aio import DataFactoryManagementClient
from azure.mgmt.datafactory.models import FactoryListResponse
-from pytest import fixture, param
+from pytest import param
-from airflow import AirflowException
+from airflow.exceptions import AirflowException
from airflow.models.connection import Connection
from airflow.providers.microsoft.azure.hooks.data_factory import (
AzureDataFactoryAsyncHook,
@@ -37,7 +36,6 @@ from airflow.providers.microsoft.azure.hooks.data_factory
import (
get_field,
provide_targeted_factory,
)
-from airflow.utils import db
DEFAULT_RESOURCE_GROUP = "defaultResourceGroup"
AZURE_DATA_FACTORY_CONN_ID = "azure_data_factory_default"
@@ -59,66 +57,60 @@ NAME = "testName"
ID = "testId"
-def setup_module():
- connection_client_secret = Connection(
- conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
- conn_type="azure_data_factory",
- login="clientId",
- password="clientSecret",
- extra=json.dumps(
- {
[email protected](autouse=True)
+def setup_connections(create_mock_connections):
+ create_mock_connections(
+ # connection_client_secret
+ Connection(
+ conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
+ conn_type="azure_data_factory",
+ login="clientId",
+ password="clientSecret",
+ extra={
"tenantId": "tenantId",
"subscriptionId": "subscriptionId",
"resource_group_name": DEFAULT_RESOURCE_GROUP,
"factory_name": DEFAULT_FACTORY,
- }
+ },
),
- )
- connection_default_credential = Connection(
- conn_id=DEFAULT_CONNECTION_DEFAULT_CREDENTIAL,
- conn_type="azure_data_factory",
- extra=json.dumps(
- {
+ # connection_default_credential
+ Connection(
+ conn_id=DEFAULT_CONNECTION_DEFAULT_CREDENTIAL,
+ conn_type="azure_data_factory",
+ extra={
"subscriptionId": "subscriptionId",
"resource_group_name": DEFAULT_RESOURCE_GROUP,
"factory_name": DEFAULT_FACTORY,
- }
+ },
),
- )
- connection_missing_subscription_id = Connection(
- conn_id="azure_data_factory_missing_subscription_id",
- conn_type="azure_data_factory",
- login="clientId",
- password="clientSecret",
- extra=json.dumps(
- {
+ Connection(
+ # connection_missing_subscription_id
+ conn_id="azure_data_factory_missing_subscription_id",
+ conn_type="azure_data_factory",
+ login="clientId",
+ password="clientSecret",
+ extra={
"tenantId": "tenantId",
"resource_group_name": DEFAULT_RESOURCE_GROUP,
"factory_name": DEFAULT_FACTORY,
- }
+ },
),
- )
- connection_missing_tenant_id = Connection(
- conn_id="azure_data_factory_missing_tenant_id",
- conn_type="azure_data_factory",
- login="clientId",
- password="clientSecret",
- extra=json.dumps(
- {
+ # connection_missing_tenant_id
+ Connection(
+ conn_id="azure_data_factory_missing_tenant_id",
+ conn_type="azure_data_factory",
+ login="clientId",
+ password="clientSecret",
+ extra={
"subscriptionId": "subscriptionId",
"resource_group_name": DEFAULT_RESOURCE_GROUP,
"factory_name": DEFAULT_FACTORY,
- }
+ },
),
)
- db.merge_conn(connection_client_secret)
- db.merge_conn(connection_default_credential)
- db.merge_conn(connection_missing_subscription_id)
- db.merge_conn(connection_missing_tenant_id)
-
-@fixture
[email protected]
def hook():
client =
AzureDataFactoryHook(azure_data_factory_conn_id=DEFAULT_CONNECTION_CLIENT_SECRET)
client._conn = MagicMock(
@@ -799,7 +791,7 @@ class TestAzureDataFactoryAsyncHook:
Test get_pipeline_run function without passing the resource name to
check the decorator function and
raise exception
"""
- mock_connection = Connection(extra=json.dumps({"factory_name":
DATAFACTORY_NAME}))
+ mock_connection = Connection(extra={"factory_name": DATAFACTORY_NAME})
mock_get_connection.return_value = mock_connection
mock_conn.return_value.pipeline_runs.get.return_value =
mock_pipeline_run
hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
@@ -807,98 +799,98 @@ class TestAzureDataFactoryAsyncHook:
await hook.get_pipeline_run(RUN_ID, None, DATAFACTORY_NAME)
@pytest.mark.asyncio
-
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection")
- async def test_get_async_conn(self, mock_connection):
- """"""
- mock_conn = Connection(
- conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
- conn_type="azure_data_factory",
- login="clientId",
- password="clientSecret",
- extra=json.dumps(
- {
+ @pytest.mark.parametrize(
+ "mocked_connection",
+ [
+ Connection(
+ conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
+ conn_type="azure_data_factory",
+ login="clientId",
+ password="clientSecret",
+ extra={
"tenantId": "tenantId",
"subscriptionId": "subscriptionId",
"resource_group_name": RESOURCE_GROUP_NAME,
"factory_name": DATAFACTORY_NAME,
- }
- ),
- )
- mock_connection.return_value = mock_conn
- hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
+ },
+ )
+ ],
+ indirect=True,
+ )
+ async def test_get_async_conn(self, mocked_connection):
+ """"""
+ hook = AzureDataFactoryAsyncHook(mocked_connection.conn_id)
response = await hook.get_async_conn()
assert isinstance(response, DataFactoryManagementClient)
@pytest.mark.asyncio
-
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection")
- async def test_get_async_conn_without_login_id(self, mock_connection):
- """Test get_async_conn function without login id"""
- mock_conn = Connection(
- conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
- conn_type="azure_data_factory",
- extra=json.dumps(
- {
+ @pytest.mark.parametrize(
+ "mocked_connection",
+ [
+ Connection(
+ conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
+ conn_type="azure_data_factory",
+ extra={
"tenantId": "tenantId",
"subscriptionId": "subscriptionId",
"resource_group_name": RESOURCE_GROUP_NAME,
"factory_name": DATAFACTORY_NAME,
- }
+ },
),
- )
- mock_connection.return_value = mock_conn
- hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
+ ],
+ indirect=True,
+ )
+ async def test_get_async_conn_without_login_id(self, mocked_connection):
+ """Test get_async_conn function without login id"""
+ hook = AzureDataFactoryAsyncHook(mocked_connection.conn_id)
response = await hook.get_async_conn()
assert isinstance(response, DataFactoryManagementClient)
@pytest.mark.asyncio
@pytest.mark.parametrize(
- "mock_connection_params",
+ "mocked_connection",
[
- {
- "tenantId": "tenantId",
- "resource_group_name": RESOURCE_GROUP_NAME,
- "factory_name": DATAFACTORY_NAME,
- }
+ Connection(
+ conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
+ conn_type="azure_data_factory",
+ login="clientId",
+ password="clientSecret",
+ extra={
+ "tenantId": "tenantId",
+ "resource_group_name": RESOURCE_GROUP_NAME,
+ "factory_name": DATAFACTORY_NAME,
+ },
+ )
],
+ indirect=True,
)
-
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection")
- async def test_get_async_conn_key_error_subscription_id(self,
mock_connection, mock_connection_params):
+ async def test_get_async_conn_key_error_subscription_id(self,
mocked_connection):
"""Test get_async_conn function when subscription_id is missing in the
connection"""
- mock_conn = Connection(
- conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
- conn_type="azure_data_factory",
- login="clientId",
- password="clientSecret",
- extra=json.dumps(mock_connection_params),
- )
- mock_connection.return_value = mock_conn
- hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
+ hook = AzureDataFactoryAsyncHook(mocked_connection.conn_id)
with pytest.raises(ValueError):
await hook.get_async_conn()
@pytest.mark.asyncio
@pytest.mark.parametrize(
- "mock_connection_params",
+ "mocked_connection",
[
- {
- "subscriptionId": "subscriptionId",
- "resource_group_name": RESOURCE_GROUP_NAME,
- "factory_name": DATAFACTORY_NAME,
- },
+ Connection(
+ conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
+ conn_type="azure_data_factory",
+ login="clientId",
+ password="clientSecret",
+ extra={
+ "subscriptionId": "subscriptionId",
+ "resource_group_name": RESOURCE_GROUP_NAME,
+ "factory_name": DATAFACTORY_NAME,
+ },
+ )
],
+ indirect=True,
)
-
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection")
- async def test_get_async_conn_key_error_tenant_id(self, mock_connection,
mock_connection_params):
+ async def test_get_async_conn_key_error_tenant_id(self, mocked_connection):
"""Test get_async_conn function when tenant id is missing in the
connection"""
- mock_conn = Connection(
- conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
- conn_type="azure_data_factory",
- login="clientId",
- password="clientSecret",
- extra=json.dumps(mock_connection_params),
- )
- mock_connection.return_value = mock_conn
- hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
+ hook = AzureDataFactoryAsyncHook(mocked_connection.conn_id)
with pytest.raises(ValueError):
await hook.get_async_conn()
@@ -907,14 +899,12 @@ class TestAzureDataFactoryAsyncHook:
mock_conn = Connection(
conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
conn_type="azure_data_factory",
- extra=json.dumps(
- {
- "tenantId": "tenantId",
- "subscriptionId": "subscriptionId",
- "resource_group_name": RESOURCE_GROUP_NAME,
- "factory_name": DATAFACTORY_NAME,
- }
- ),
+ extra={
+ "tenantId": "tenantId",
+ "subscriptionId": "subscriptionId",
+ "resource_group_name": RESOURCE_GROUP_NAME,
+ "factory_name": DATAFACTORY_NAME,
+ },
)
extras = mock_conn.extra_dejson
assert get_field(extras, "tenantId", strict=True) == "tenantId"
@@ -929,14 +919,12 @@ class TestAzureDataFactoryAsyncHook:
mock_conn = Connection(
conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
conn_type="azure_data_factory",
- extra=json.dumps(
- {
- "tenantId": "tenantId",
- "subscriptionId": "subscriptionId",
- "resource_group_name": RESOURCE_GROUP_NAME,
- "factory_name": DATAFACTORY_NAME,
- }
- ),
+ extra={
+ "tenantId": "tenantId",
+ "subscriptionId": "subscriptionId",
+ "resource_group_name": RESOURCE_GROUP_NAME,
+ "factory_name": DATAFACTORY_NAME,
+ },
)
extras = mock_conn.extra_dejson
assert get_field(extras, "tenantId", strict=True) == "tenantId"
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py
b/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py
index 122949beac..f5e2e8be5c 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import json
from unittest import mock
from unittest.mock import PropertyMock
@@ -26,18 +25,18 @@ from azure.storage.filedatalake._models import
FileSystemProperties
from airflow.models import Connection
from airflow.providers.microsoft.azure.hooks.data_lake import
AzureDataLakeStorageV2Hook
-from airflow.utils import db
class TestAzureDataLakeHook:
- def setup_method(self):
- db.merge_conn(
+ @pytest.fixture(autouse=True)
+ def setup_connections(self, create_mock_connections):
+ create_mock_connections(
Connection(
conn_id="adl_test_key",
conn_type="azure_data_lake",
login="client_id",
password="client secret",
- extra=json.dumps({"tenant": "tenant", "account_name":
"accountname"}),
+ extra={"tenant": "tenant", "account_name": "accountname"},
)
)
@@ -58,9 +57,10 @@ class TestAzureDataLakeHook:
def test_check_for_blob(self, mock_lib, mock_filesystem):
from airflow.providers.microsoft.azure.hooks.data_lake import
AzureDataLakeHook
+ mocked_glob = mock_filesystem.return_value.glob
hook = AzureDataLakeHook(azure_data_lake_conn_id="adl_test_key")
hook.check_for_file("file_path")
- mock_filesystem.glob.called
+ mocked_glob.assert_called()
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.multithread.ADLUploader",
autospec=True)
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.lib",
autospec=True)
@@ -140,7 +140,7 @@ class TestAzureDataLakeHook:
class TestAzureDataLakeStorageV2Hook:
def setup_class(self) -> None:
- self.conn_id: str = "adls_conn_id"
+ self.conn_id: str = "adls_conn_id1"
self.file_system_name = "test_file_system"
self.directory_name = "test_directory"
self.file_name = "test_file_name"
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py
b/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py
index 1529eadcd9..a99ca9e90b 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py
@@ -25,7 +25,6 @@ and password (=Storage account key), or login and SAS token
in the extra field
"""
from __future__ import annotations
-import json
import os
from unittest import mock
from unittest.mock import patch
@@ -36,51 +35,36 @@ from pytest import param
from airflow.models import Connection
from airflow.providers.microsoft.azure.hooks.fileshare import
AzureFileShareHook
-from airflow.utils import db
class TestAzureFileshareHook:
- def setup_method(self):
- db.merge_conn(
+ @pytest.fixture(autouse=True)
+ def setup_test_cases(self, create_mock_connections):
+ create_mock_connections(
Connection(
conn_id="azure_fileshare_test_key",
conn_type="azure_file_share",
login="login",
password="key",
- )
- )
- db.merge_conn(
+ ),
Connection(
conn_id="azure_fileshare_extras",
conn_type="azure_fileshare",
login="login",
- extra=json.dumps(
- {
- "sas_token": "token",
- "protocol": "http",
- }
- ),
- )
- )
- db.merge_conn(
+ extra={"sas_token": "token", "protocol": "http"},
+ ),
# Neither password nor sas_token present
Connection(
conn_id="azure_fileshare_missing_credentials",
conn_type="azure_fileshare",
login="login",
- )
- )
- db.merge_conn(
+ ),
Connection(
conn_id="azure_fileshare_extras_wrong",
conn_type="azure_fileshare",
login="login",
- extra=json.dumps(
- {
- "wrong_key": "token",
- }
- ),
- )
+ extra={"wrong_key": "token"},
+ ),
)
def test_key_and_connection(self):
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_synapse.py
b/tests/providers/microsoft/azure/hooks/test_azure_synapse.py
index 3a82efd812..b63dacc6da 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_synapse.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_synapse.py
@@ -16,7 +16,6 @@
# under the License.
from __future__ import annotations
-import json
from unittest.mock import MagicMock, patch
import pytest
@@ -26,7 +25,6 @@ from pytest import fixture
from airflow.models.connection import Connection
from airflow.providers.microsoft.azure.hooks.synapse import AzureSynapseHook,
AzureSynapseSparkBatchRunStatus
-from airflow.utils import db
DEFAULT_SPARK_POOL = "defaultSparkPool"
@@ -42,60 +40,45 @@ ID = "testId"
JOB_ID = 1
-def setup_module():
- connection_client_secret = Connection(
- conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
- conn_type="azure_synapse",
- host="https://testsynapse.dev.azuresynapse.net",
- login="clientId",
- password="clientSecret",
- extra=json.dumps(
- {
- "tenantId": "tenantId",
- "subscriptionId": "subscriptionId",
- }
[email protected](autouse=True)
+def setup_connections(create_mock_connections):
+ create_mock_connections(
+ # connection_client_secret
+ Connection(
+ conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
+ conn_type="azure_synapse",
+ host="https://testsynapse.dev.azuresynapse.net",
+ login="clientId",
+ password="clientSecret",
+ extra={"tenantId": "tenantId", "subscriptionId": "subscriptionId"},
),
- )
- connection_default_credential = Connection(
- conn_id=DEFAULT_CONNECTION_DEFAULT_CREDENTIAL,
- conn_type="azure_synapse",
- host="https://testsynapse.dev.azuresynapse.net",
- extra=json.dumps(
- {
- "subscriptionId": "subscriptionId",
- }
+ # connection_default_credential
+ Connection(
+ conn_id=DEFAULT_CONNECTION_DEFAULT_CREDENTIAL,
+ conn_type="azure_synapse",
+ host="https://testsynapse.dev.azuresynapse.net",
+ extra={"subscriptionId": "subscriptionId"},
),
- )
- connection_missing_subscription_id = Connection(
- conn_id="azure_synapse_missing_subscription_id",
- conn_type="azure_synapse",
- host="https://testsynapse.dev.azuresynapse.net",
- login="clientId",
- password="clientSecret",
- extra=json.dumps(
- {
- "tenantId": "tenantId",
- }
+ Connection(
+ # connection_missing_subscription_id
+ conn_id="azure_synapse_missing_subscription_id",
+ conn_type="azure_synapse",
+ host="https://testsynapse.dev.azuresynapse.net",
+ login="clientId",
+ password="clientSecret",
+ extra={"tenantId": "tenantId"},
),
- )
- connection_missing_tenant_id = Connection(
- conn_id="azure_synapse_missing_tenant_id",
- conn_type="azure_synapse",
- host="https://testsynapse.dev.azuresynapse.net",
- login="clientId",
- password="clientSecret",
- extra=json.dumps(
- {
- "subscriptionId": "subscriptionId",
- }
+ # connection_missing_tenant_id
+ Connection(
+ conn_id="azure_synapse_missing_tenant_id",
+ conn_type="azure_synapse",
+ host="https://testsynapse.dev.azuresynapse.net",
+ login="clientId",
+ password="clientSecret",
+ extra={"subscriptionId": "subscriptionId"},
),
)
- db.merge_conn(connection_client_secret)
- db.merge_conn(connection_default_credential)
- db.merge_conn(connection_missing_subscription_id)
- db.merge_conn(connection_missing_tenant_id)
-
@fixture
def hook():
diff --git a/tests/providers/microsoft/azure/hooks/test_base_azure.py
b/tests/providers/microsoft/azure/hooks/test_base_azure.py
index 7b587c4121..53e2614a69 100644
--- a/tests/providers/microsoft/azure/hooks/test_base_azure.py
+++ b/tests/providers/microsoft/azure/hooks/test_base_azure.py
@@ -18,63 +18,71 @@ from __future__ import annotations
from unittest.mock import Mock, patch
+import pytest
+
from airflow.models import Connection
from airflow.providers.microsoft.azure.hooks.base_azure import AzureBaseHook
class TestBaseAzureHook:
-
@patch("airflow.providers.microsoft.azure.hooks.base_azure.get_client_from_auth_file")
- @patch(
-
"airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook.get_connection",
- return_value=Connection(conn_id="azure_default", extra='{ "key_path":
"key_file.json" }'),
+ @pytest.mark.parametrize(
+ "mocked_connection",
+ [Connection(conn_id="azure_default", extra={"key_path":
"key_file.json"})],
+ indirect=True,
)
- def test_get_conn_with_key_path(self, mock_connection,
mock_get_client_from_auth_file):
+
@patch("airflow.providers.microsoft.azure.hooks.base_azure.get_client_from_auth_file")
+ def test_get_conn_with_key_path(self, mock_get_client_from_auth_file,
mocked_connection):
+ mock_get_client_from_auth_file.return_value = "foo-bar"
mock_sdk_client = Mock()
auth_sdk_client = AzureBaseHook(mock_sdk_client).get_conn()
mock_get_client_from_auth_file.assert_called_once_with(
- client_class=mock_sdk_client,
auth_path=mock_connection.return_value.extra_dejson["key_path"]
+ client_class=mock_sdk_client,
auth_path=mocked_connection.extra_dejson["key_path"]
)
- assert auth_sdk_client == mock_get_client_from_auth_file.return_value
+ assert auth_sdk_client == "foo-bar"
-
@patch("airflow.providers.microsoft.azure.hooks.base_azure.get_client_from_json_dict")
- @patch(
-
"airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook.get_connection",
- return_value=Connection(conn_id="azure_default", extra='{ "key_json":
{ "test": "test" } }'),
+ @pytest.mark.parametrize(
+ "mocked_connection",
+ [Connection(conn_id="azure_default", extra={"key_json": {"test":
"test"}})],
+ indirect=True,
)
- def test_get_conn_with_key_json(self, mock_connection,
mock_get_client_from_json_dict):
+
@patch("airflow.providers.microsoft.azure.hooks.base_azure.get_client_from_json_dict")
+ def test_get_conn_with_key_json(self, mock_get_client_from_json_dict,
mocked_connection):
mock_sdk_client = Mock()
-
+ mock_get_client_from_json_dict.return_value = "foo-bar"
auth_sdk_client = AzureBaseHook(mock_sdk_client).get_conn()
mock_get_client_from_json_dict.assert_called_once_with(
- client_class=mock_sdk_client,
config_dict=mock_connection.return_value.extra_dejson["key_json"]
+ client_class=mock_sdk_client,
config_dict=mocked_connection.extra_dejson["key_json"]
)
- assert auth_sdk_client == mock_get_client_from_json_dict.return_value
+ assert auth_sdk_client == "foo-bar"
@patch("airflow.providers.microsoft.azure.hooks.base_azure.ServicePrincipalCredentials")
- @patch(
-
"airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook.get_connection",
- return_value=Connection(
- conn_id="azure_default",
- login="my_login",
- password="my_password",
- extra='{ "tenantId": "my_tenant", "subscriptionId":
"my_subscription" }',
- ),
+ @pytest.mark.parametrize(
+ "mocked_connection",
+ [
+ Connection(
+ conn_id="azure_default",
+ login="my_login",
+ password="my_password",
+ extra={"tenantId": "my_tenant", "subscriptionId":
"my_subscription"},
+ )
+ ],
+ indirect=True,
)
- def test_get_conn_with_credentials(self, mock_connection, mock_spc):
- mock_sdk_client = Mock()
-
+ def test_get_conn_with_credentials(self, mock_spc, mocked_connection):
+ mock_sdk_client = Mock(return_value="spam-egg")
+ mock_spc.return_value = "foo-bar"
auth_sdk_client = AzureBaseHook(mock_sdk_client).get_conn()
mock_spc.assert_called_once_with(
- client_id=mock_connection.return_value.login,
- secret=mock_connection.return_value.password,
- tenant=mock_connection.return_value.extra_dejson["tenantId"],
+ client_id=mocked_connection.login,
+ secret=mocked_connection.password,
+ tenant=mocked_connection.extra_dejson["tenantId"],
)
mock_sdk_client.assert_called_once_with(
- credentials=mock_spc.return_value,
-
subscription_id=mock_connection.return_value.extra_dejson["subscriptionId"],
+ credentials="foo-bar",
+ subscription_id=mocked_connection.extra_dejson["subscriptionId"],
)
- assert auth_sdk_client == mock_sdk_client.return_value
+ assert auth_sdk_client == "spam-egg"
diff --git a/tests/providers/microsoft/azure/hooks/test_wasb.py
b/tests/providers/microsoft/azure/hooks/test_wasb.py
index 1f48a7b011..06ef4eedfb 100644
--- a/tests/providers/microsoft/azure/hooks/test_wasb.py
+++ b/tests/providers/microsoft/azure/hooks/test_wasb.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import json
from unittest import mock
import pytest
@@ -34,269 +33,262 @@ CONN_STRING = (
)
ACCESS_KEY_STRING = "AccountName=name;skdkskd"
+PROXIES = {"http": "http_proxy_uri", "https": "https_proxy_uri"}
+
+
[email protected]
+def mocked_blob_service_client():
+ with
mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") as
m:
+ yield m
+
+
[email protected]
+def mocked_default_azure_credential():
+ with
mock.patch("airflow.providers.microsoft.azure.hooks.wasb.DefaultAzureCredential")
as m:
+ yield m
+
+
[email protected]
+def mocked_client_secret_credential():
+ with
mock.patch("airflow.providers.microsoft.azure.hooks.wasb.ClientSecretCredential")
as m:
+ yield m
class TestWasbHook:
- def setup_method(self):
+ @pytest.fixture(autouse=True)
+ def setup_method(self, create_mock_connections):
self.login = "login"
self.wasb_test_key = "wasb_test_key"
self.connection_type = "wasb"
- self.connection_string_id = "azure_test_connection_string"
- self.shared_key_conn_id = "azure_shared_key_test"
- self.shared_key_conn_id_without_host =
"azure_shared_key_test_wihout_host"
- self.ad_conn_id = "azure_AD_test"
+ self.azure_test_connection_string = "azure_test_connection_string"
+ self.azure_shared_key_test = "azure_shared_key_test"
+ self.ad_conn_id = "ad_conn_id"
self.sas_conn_id = "sas_token_id"
- self.extra__wasb__sas_conn_id = "extra__sas_token_id"
- self.http_sas_conn_id = "http_sas_token_id"
- self.extra__wasb__http_sas_conn_id = "extra__http_sas_token_id"
+ self.extra__wasb__sas_conn_id = "extra__wasb__sas_conn_id"
+ self.http_sas_conn_id = "http_sas_conn_id"
+ self.extra__wasb__http_sas_conn_id = "extra__wasb__http_sas_conn_id"
self.public_read_conn_id = "pub_read_id"
self.public_read_conn_id_without_host = "pub_read_id_without_host"
- self.managed_identity_conn_id = "managed_identity"
+ self.managed_identity_conn_id = "managed_identity_conn_id"
self.authority = "https://test_authority.com"
- self.proxies = {"http": "http_proxy_uri", "https": "https_proxy_uri"}
+ self.proxies = PROXIES
self.client_secret_auth_config = {
"proxies": self.proxies,
"connection_verify": False,
"authority": self.authority,
}
- self.connection_map = {
- self.wasb_test_key: Connection(
+
+ conns = create_mock_connections(
+ Connection(
conn_id="wasb_test_key",
conn_type=self.connection_type,
login=self.login,
password="key",
),
- self.public_read_conn_id: Connection(
+ Connection(
conn_id=self.public_read_conn_id,
conn_type=self.connection_type,
host="https://accountname.blob.core.windows.net",
- extra=json.dumps({"proxies": self.proxies}),
+ extra={"proxies": self.proxies},
),
- self.public_read_conn_id_without_host: Connection(
+ Connection(
conn_id=self.public_read_conn_id_without_host,
conn_type=self.connection_type,
login=self.login,
- extra=json.dumps({"proxies": self.proxies}),
+ extra={"proxies": self.proxies},
),
- self.connection_string_id: Connection(
- conn_id=self.connection_string_id,
+ Connection(
+ conn_id=self.azure_test_connection_string,
conn_type=self.connection_type,
- extra=json.dumps({"connection_string": CONN_STRING, "proxies":
self.proxies}),
+ extra={"connection_string": CONN_STRING, "proxies":
self.proxies},
),
- self.shared_key_conn_id: Connection(
- conn_id=self.shared_key_conn_id,
+ Connection(
+ conn_id=self.azure_shared_key_test,
conn_type=self.connection_type,
host="https://accountname.blob.core.windows.net",
- extra=json.dumps({"shared_access_key": "token", "proxies":
self.proxies}),
+ extra={"shared_access_key": "token", "proxies": self.proxies},
),
- self.shared_key_conn_id_without_host: Connection(
- conn_id=self.shared_key_conn_id_without_host,
- conn_type=self.connection_type,
- login=self.login,
- extra=json.dumps({"shared_access_key": "token", "proxies":
self.proxies}),
- ),
- self.ad_conn_id: Connection(
+ Connection(
conn_id=self.ad_conn_id,
conn_type=self.connection_type,
host="conn_host",
login="appID",
password="appsecret",
- extra=json.dumps(
- {
- "tenant_id": "token",
- "proxies": self.proxies,
- "client_secret_auth_config":
self.client_secret_auth_config,
- }
- ),
+ extra={
+ "tenant_id": "token",
+ "proxies": self.proxies,
+ "client_secret_auth_config":
self.client_secret_auth_config,
+ },
),
- self.managed_identity_conn_id: Connection(
+ Connection(
conn_id=self.managed_identity_conn_id,
conn_type=self.connection_type,
- extra=json.dumps({"proxies": self.proxies}),
+ extra={"proxies": self.proxies},
),
- self.sas_conn_id: Connection(
- conn_id=self.sas_conn_id,
+ Connection(
+ conn_id="sas_conn_id",
conn_type=self.connection_type,
login=self.login,
- extra=json.dumps({"sas_token": "token", "proxies":
self.proxies}),
+ extra={"sas_token": "token", "proxies": self.proxies},
),
- self.extra__wasb__sas_conn_id: Connection(
+ Connection(
conn_id=self.extra__wasb__sas_conn_id,
conn_type=self.connection_type,
login=self.login,
- extra=json.dumps({"extra__wasb__sas_token": "token",
"proxies": self.proxies}),
+ extra={"extra__wasb__sas_token": "token", "proxies":
self.proxies},
),
- self.http_sas_conn_id: Connection(
+ Connection(
conn_id=self.http_sas_conn_id,
conn_type=self.connection_type,
- extra=json.dumps(
- {"sas_token": "https://login.blob.core.windows.net/token",
"proxies": self.proxies}
- ),
+ extra={"sas_token":
"https://login.blob.core.windows.net/token", "proxies": self.proxies},
),
- self.extra__wasb__http_sas_conn_id: Connection(
+ Connection(
conn_id=self.extra__wasb__http_sas_conn_id,
conn_type=self.connection_type,
- extra=json.dumps(
- {
- "extra__wasb__sas_token":
"https://login.blob.core.windows.net/token",
- "proxies": self.proxies,
- }
- ),
+ extra={
+ "extra__wasb__sas_token":
"https://login.blob.core.windows.net/token",
+ "proxies": self.proxies,
+ },
),
- }
-
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_key(self, mock_get_conn, mock_blob_service_client):
- conn = self.connection_map[self.wasb_test_key]
- mock_get_conn.return_value = conn
- WasbHook(wasb_conn_id=self.wasb_test_key)
- assert mock_blob_service_client.call_args == mock.call(
- account_url=f"https://{self.login}.blob.core.windows.net/",
- credential=conn.password,
+ )
+ self.connection_map = {conn.conn_id: conn for conn in conns}
+
+ def test_key(self, mocked_blob_service_client):
+ hook = WasbHook(wasb_conn_id=self.wasb_test_key)
+ mocked_blob_service_client.assert_not_called() # Not expected during
initialisation
+ hook.get_conn()
+ mocked_blob_service_client.assert_called_once_with(
+ account_url=f"https://{self.login}.blob.core.windows.net/",
credential="key"
)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_public_read(self, mock_get_conn, mock_blob_service_client):
- conn = self.connection_map[self.public_read_conn_id]
- mock_get_conn.return_value = conn
- WasbHook(wasb_conn_id=self.public_read_conn_id, public_read=True)
- assert mock_blob_service_client.call_args == mock.call(
- account_url=conn.host,
- proxies=conn.extra_dejson["proxies"],
+ def test_public_read(self, mocked_blob_service_client):
+ WasbHook(wasb_conn_id=self.public_read_conn_id,
public_read=True).get_conn()
+ mocked_blob_service_client.assert_called_once_with(
+ account_url="https://accountname.blob.core.windows.net",
proxies=self.proxies
)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_connection_string(self, mock_get_conn, mock_blob_service_client):
- conn = self.connection_map[self.connection_string_id]
- mock_get_conn.return_value = conn
- WasbHook(wasb_conn_id=self.connection_string_id)
-
mock_blob_service_client.from_connection_string.assert_called_once_with(
- CONN_STRING,
- proxies=conn.extra_dejson["proxies"],
- connection_string=conn.extra_dejson["connection_string"],
+ def test_connection_string(self, mocked_blob_service_client):
+ WasbHook(wasb_conn_id=self.azure_test_connection_string).get_conn()
+
mocked_blob_service_client.from_connection_string.assert_called_once_with(
+ CONN_STRING, proxies=self.proxies, connection_string=CONN_STRING
)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_shared_key_connection(self, mock_get_conn,
mock_blob_service_client):
- conn = self.connection_map[self.shared_key_conn_id]
- mock_get_conn.return_value = conn
- WasbHook(wasb_conn_id=self.shared_key_conn_id)
- mock_blob_service_client.assert_called_once_with(
- account_url=conn.host,
- credential=conn.extra_dejson["shared_access_key"],
- proxies=conn.extra_dejson["proxies"],
- shared_access_key=conn.extra_dejson["shared_access_key"],
+ def test_shared_key_connection(self, mocked_blob_service_client):
+ WasbHook(wasb_conn_id=self.azure_shared_key_test).get_conn()
+ mocked_blob_service_client.assert_called_once_with(
+ account_url="https://accountname.blob.core.windows.net",
+ credential="token",
+ proxies=self.proxies,
+ shared_access_key="token",
)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.DefaultAzureCredential")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_managed_identity(self, mock_get_conn, mock_credential,
mock_blob_service_client):
- conn = self.connection_map[self.managed_identity_conn_id]
- mock_get_conn.return_value = conn
- WasbHook(wasb_conn_id=self.managed_identity_conn_id)
- mock_blob_service_client.assert_called_once_with(
- account_url=f"https://{conn.login}.blob.core.windows.net/",
- credential=mock_credential.return_value,
- proxies=conn.extra_dejson["proxies"],
+ def test_managed_identity(self, mocked_default_azure_credential,
mocked_blob_service_client):
+ mocked_default_azure_credential.return_value = "foo-bar"
+ WasbHook(wasb_conn_id=self.managed_identity_conn_id).get_conn()
+ mocked_blob_service_client.assert_called_once_with(
+ account_url="https://None.blob.core.windows.net/",
+ credential="foo-bar",
+ proxies=self.proxies,
)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.ClientSecretCredential")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_azure_directory_connection(self, mock_get_conn, mock_credential,
mock_blob_service_client):
- conn = self.connection_map[self.ad_conn_id]
- mock_get_conn.return_value = conn
- WasbHook(wasb_conn_id=self.ad_conn_id)
- mock_credential.assert_called_once_with(
- conn.extra_dejson["tenant_id"],
- conn.login,
- conn.password,
+ def test_azure_directory_connection(self, mocked_client_secret_credential,
mocked_blob_service_client):
+ mocked_client_secret_credential.return_value = "spam-egg"
+ WasbHook(wasb_conn_id=self.ad_conn_id).get_conn()
+ mocked_client_secret_credential.assert_called_once_with(
+ "token",
+ "appID",
+ "appsecret",
proxies=self.client_secret_auth_config["proxies"],
connection_verify=self.client_secret_auth_config["connection_verify"],
authority=self.client_secret_auth_config["authority"],
)
- mock_blob_service_client.assert_called_once_with(
- account_url=f"https://{conn.login}.blob.core.windows.net/",
- credential=mock_credential.return_value,
- tenant_id=conn.extra_dejson["tenant_id"],
- proxies=conn.extra_dejson["proxies"],
+ mocked_blob_service_client.assert_called_once_with(
+ account_url="https://appID.blob.core.windows.net/",
+ credential="spam-egg",
+ tenant_id="token",
+ proxies=self.proxies,
)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.DefaultAzureCredential")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_active_directory_ID_used_as_host(self, mock_get_conn,
mock_credential, mock_blob_service_client):
- mock_get_conn.return_value = Connection(
- conn_id="testconn",
- conn_type=self.connection_type,
- login="testaccountname",
- host="testaccountID",
- )
- WasbHook(wasb_conn_id="testconn")
- assert mock_blob_service_client.call_args == mock.call(
+ @pytest.mark.parametrize(
+ "mocked_connection",
+ [
+ Connection(
+ conn_id="testconn",
+ conn_type="wasb",
+ login="testaccountname",
+ host="testaccountID",
+ )
+ ],
+ indirect=True,
+ )
+ def test_active_directory_id_used_as_host(
+ self, mocked_connection, mocked_default_azure_credential,
mocked_blob_service_client
+ ):
+ mocked_default_azure_credential.return_value = "fake-credential"
+ WasbHook(wasb_conn_id="testconn").get_conn()
+ mocked_blob_service_client.assert_called_once_with(
account_url="https://testaccountname.blob.core.windows.net/",
- credential=mock_credential.return_value,
+ credential="fake-credential",
)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_sas_token_provided_and_active_directory_ID_used_as_host(
- self, mock_get_conn, mock_blob_service_client
+ @pytest.mark.parametrize(
+ "mocked_connection",
+ [
+ Connection(
+ conn_id="testconn",
+ conn_type="wasb",
+ login="testaccountname",
+ host="testaccountID",
+ extra={"sas_token": "SAStoken"},
+ )
+ ],
+ indirect=True,
+ )
+ def test_sas_token_provided_and_active_directory_id_used_as_host(
+ self, mocked_connection, mocked_blob_service_client
):
- mock_get_conn.return_value = Connection(
- conn_id="testconn",
- conn_type=self.connection_type,
- login="testaccountname",
- host="testaccountID",
- extra=json.dumps({"sas_token": "SAStoken"}),
- )
- WasbHook(wasb_conn_id="testconn")
- assert mock_blob_service_client.call_args == mock.call(
+ WasbHook(wasb_conn_id="testconn").get_conn()
+ mocked_blob_service_client.assert_called_once_with(
account_url="https://testaccountname.blob.core.windows.net/SAStoken",
sas_token="SAStoken",
)
@pytest.mark.parametrize(
- argnames="conn_id_str",
- argvalues=[
- "wasb_test_key",
- "shared_key_conn_id_without_host",
- "public_read_conn_id_without_host",
+ "mocked_connection",
+ [
+ pytest.param(
+ Connection(
+ conn_type="wasb",
+ login="foo",
+ extra={"shared_access_key": "token", "proxies": PROXIES},
+ ),
+ id="shared-key-without-host",
+ ),
+ pytest.param(
+ Connection(conn_type="wasb", login="foo", extra={"proxies":
PROXIES}),
+ id="public-read-without-host",
+ ),
],
+ indirect=True,
)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.DefaultAzureCredential")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
def test_account_url_without_host(
- self, mock_get_conn, mock_credential, mock_blob_service_client,
conn_id_str
+ self, mocked_connection, mocked_blob_service_client,
mocked_default_azure_credential
):
- conn_id = self.__getattribute__(conn_id_str)
- connection = self.connection_map[conn_id]
- mock_get_conn.return_value = connection
- WasbHook(wasb_conn_id=conn_id)
- if conn_id_str == "wasb_test_key":
- mock_blob_service_client.assert_called_once_with(
-
account_url=f"https://{connection.login}.blob.core.windows.net/",
- credential=connection.password,
- )
- elif conn_id_str == "shared_key_conn_id_without_host":
- mock_blob_service_client.assert_called_once_with(
-
account_url=f"https://{connection.login}.blob.core.windows.net/",
- credential=connection.extra_dejson["shared_access_key"],
- proxies=connection.extra_dejson["proxies"],
- shared_access_key=connection.extra_dejson["shared_access_key"],
+ mocked_default_azure_credential.return_value = "default-creds"
+ WasbHook(wasb_conn_id=mocked_connection.conn_id).get_conn()
+ if "shared_access_key" in mocked_connection.extra_dejson:
+ mocked_blob_service_client.assert_called_once_with(
+
account_url=f"https://{mocked_connection.login}.blob.core.windows.net/",
+ credential=mocked_connection.extra_dejson["shared_access_key"],
+ proxies=mocked_connection.extra_dejson["proxies"],
+
shared_access_key=mocked_connection.extra_dejson["shared_access_key"],
)
else:
- mock_blob_service_client.assert_called_once_with(
-
account_url=f"https://{connection.login}.blob.core.windows.net/",
- credential=mock_credential.return_value,
- proxies=connection.extra_dejson["proxies"],
+ mocked_blob_service_client.assert_called_once_with(
+
account_url=f"https://{mocked_connection.login}.blob.core.windows.net/",
+ credential="default-creds",
+ proxies=mocked_connection.extra_dejson["proxies"],
)
@pytest.mark.parametrize(
@@ -308,25 +300,22 @@ class TestWasbHook:
("extra__wasb__http_sas_conn_id", "extra__wasb__sas_token"),
],
)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_sas_token_connection(self, mock_get_conn, conn_id_str, extra_key):
- conn_id = self.__getattribute__(conn_id_str)
- mock_get_conn.return_value = self.connection_map[conn_id]
- hook = WasbHook(wasb_conn_id=conn_id)
+ def test_sas_token_connection(self, conn_id_str, extra_key):
+ hook = WasbHook(wasb_conn_id=conn_id_str)
conn = hook.get_conn()
hook_conn = hook.get_connection(hook.conn_id)
sas_token = hook_conn.extra_dejson[extra_key]
assert isinstance(conn, BlobServiceClient)
assert conn.url.startswith("https://")
if hook_conn.login:
- assert conn.url.__contains__(hook_conn.login)
+ assert hook_conn.login in conn.url
assert conn.url.endswith(sas_token + "/")
@pytest.mark.parametrize(
argnames="conn_id_str",
argvalues=[
- "connection_string_id",
- "shared_key_conn_id",
+ "azure_test_connection_string",
+ "azure_shared_key_test",
"ad_conn_id",
"managed_identity_conn_id",
"sas_conn_id",
@@ -335,27 +324,17 @@ class TestWasbHook:
"extra__wasb__http_sas_conn_id",
],
)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_connection_extra_arguments(self, mock_get_conn, conn_id_str):
- conn_id = self.__getattribute__(conn_id_str)
- mock_get_conn.return_value = self.connection_map[conn_id]
- hook = WasbHook(wasb_conn_id=conn_id)
- conn = hook.get_conn()
+ def test_connection_extra_arguments(self, conn_id_str):
+ conn = WasbHook(wasb_conn_id=conn_id_str).get_conn()
assert conn._config.proxy_policy.proxies == self.proxies
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_connection_extra_arguments_public_read(self, mock_get_conn):
- conn_id = self.public_read_conn_id
- mock_get_conn.return_value = self.connection_map[conn_id]
- hook = WasbHook(wasb_conn_id=conn_id, public_read=True)
+ def test_connection_extra_arguments_public_read(self):
+ hook = WasbHook(wasb_conn_id=self.public_read_conn_id,
public_read=True)
conn = hook.get_conn()
assert conn._config.proxy_policy.proxies == self.proxies
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_extra_client_secret_auth_config_ad_connection(self,
mock_get_conn):
- mock_get_conn.return_value = self.connection_map[self.ad_conn_id]
- conn_id = self.ad_conn_id
- hook = WasbHook(wasb_conn_id=conn_id)
+ def test_extra_client_secret_auth_config_ad_connection(self):
+ hook = WasbHook(wasb_conn_id=self.ad_conn_id)
conn = hook.get_conn()
assert conn.credential._authority == self.authority
@@ -371,79 +350,62 @@ class TestWasbHook:
("testhost.blob.net", "testhost.blob.net"),
],
)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
def test_proper_account_url_update(
- self, mock_get_conn, mock_blob_service_client, provided_host,
expected_host
+ self, mocked_blob_service_client, provided_host, expected_host,
create_mock_connection
):
- mock_get_conn.return_value = Connection(
- conn_id="test_conn",
- conn_type=self.connection_type,
- password="testpass",
- login="accountlogin",
- host=provided_host,
+ conn = create_mock_connection(
+ Connection(
+ conn_type=self.connection_type,
+ password="testpass",
+ login="accountlogin",
+ host=provided_host,
+ )
)
- WasbHook(wasb_conn_id=self.shared_key_conn_id)
-
mock_blob_service_client.assert_called_once_with(account_url=expected_host,
credential="testpass")
-
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_check_for_blob(self, mock_get_conn, mock_service):
- mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
- hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ WasbHook(wasb_conn_id=conn.conn_id).get_conn()
+
mocked_blob_service_client.assert_called_once_with(account_url=expected_host,
credential="testpass")
+
+ def test_check_for_blob(self, mocked_blob_service_client):
+ hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
assert hook.check_for_blob(container_name="mycontainer",
blob_name="myblob")
- mock_blob_client = mock_service.return_value.get_blob_client
+ mock_blob_client =
mocked_blob_service_client.return_value.get_blob_client
mock_blob_client.assert_called_once_with(container="mycontainer",
blob="myblob")
mock_blob_client.return_value.get_blob_properties.assert_called()
- @mock.patch.object(WasbHook, "get_blobs_list")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_check_for_prefix(self, mock_get_conn, get_blobs_list):
- mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
- get_blobs_list.return_value = ["blobs"]
- hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ @mock.patch.object(WasbHook, "get_blobs_list", return_value=["blobs"])
+ def test_check_for_prefix(self, get_blobs_list):
+ hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
assert hook.check_for_prefix("container", "prefix", timeout=3)
get_blobs_list.assert_called_once_with(container_name="container",
prefix="prefix", timeout=3)
- @mock.patch.object(WasbHook, "get_blobs_list")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_check_for_prefix_empty(self, mock_get_conn, get_blobs_list):
- mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
- get_blobs_list.return_value = []
- hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ @mock.patch.object(WasbHook, "get_blobs_list", return_value=[])
+ def test_check_for_prefix_empty(self, get_blobs_list):
+ hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
assert not hook.check_for_prefix("container", "prefix", timeout=3)
get_blobs_list.assert_called_once_with(container_name="container",
prefix="prefix", timeout=3)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_get_blobs_list(self, mock_get_conn, mock_service):
- mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
- hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ def test_get_blobs_list(self, mocked_blob_service_client):
+ hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
hook.get_blobs_list(container_name="mycontainer", prefix="my",
include=None, delimiter="/")
-
mock_service.return_value.get_container_client.assert_called_once_with("mycontainer")
-
mock_service.return_value.get_container_client.return_value.walk_blobs.assert_called_once_with(
+ mock_container_client =
mocked_blob_service_client.return_value.get_container_client
+ mock_container_client.assert_called_once_with("mycontainer")
+ mock_container_client.return_value.walk_blobs.assert_called_once_with(
name_starts_with="my", include=None, delimiter="/"
)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_get_blobs_list_recursive(self, mock_get_conn, mock_service):
- mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
- hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ def test_get_blobs_list_recursive(self, mocked_blob_service_client):
+ hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
hook.get_blobs_list_recursive(
container_name="mycontainer", prefix="test", include=None,
endswith="file_extension"
)
-
mock_service.return_value.get_container_client.assert_called_once_with("mycontainer")
-
mock_service.return_value.get_container_client.return_value.list_blobs.assert_called_once_with(
+ mock_container_client =
mocked_blob_service_client.return_value.get_container_client
+ mock_container_client.assert_called_once_with("mycontainer")
+ mock_container_client.return_value.list_blobs.assert_called_once_with(
name_starts_with="test", include=None
)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_get_blobs_list_recursive_endswith(self, mock_get_conn,
mock_service):
- mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
- hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
-
mock_service.return_value.get_container_client.return_value.list_blobs.return_value
= [
+ def test_get_blobs_list_recursive_endswith(self,
mocked_blob_service_client):
+ hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
+
mocked_blob_service_client.return_value.get_container_client.return_value.list_blobs.return_value
= [
BlobProperties(name="test/abc.py"),
BlobProperties(name="test/inside_test/abc.py"),
BlobProperties(name="test/abc.csv"),
@@ -455,11 +417,9 @@ class TestWasbHook:
@pytest.mark.parametrize(argnames="create_container", argvalues=[True,
False])
@mock.patch.object(WasbHook, "upload")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_load_file(self, mock_get_conn, mock_upload, create_container):
- mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
+ def test_load_file(self, mock_upload, create_container):
with mock.patch("builtins.open", mock.mock_open(read_data="data")):
- hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
hook.load_file("path", "container", "blob", create_container,
max_connections=1)
mock_upload.assert_called_with(
@@ -472,10 +432,8 @@ class TestWasbHook:
@pytest.mark.parametrize(argnames="create_container", argvalues=[True,
False])
@mock.patch.object(WasbHook, "upload")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_load_string(self, mock_get_conn, mock_upload, create_container):
- mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
- hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ def test_load_string(self, mock_upload, create_container):
+ hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
hook.load_string("big string", "container", "blob", create_container,
max_connections=1)
mock_upload.assert_called_once_with(
container_name="container",
@@ -486,30 +444,22 @@ class TestWasbHook:
)
@mock.patch.object(WasbHook, "download")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_get_file(self, mock_get_conn, mock_download):
- mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
+ def test_get_file(self, mock_download):
with mock.patch("builtins.open", mock.mock_open(read_data="data")):
- hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
hook.get_file("path", "container", "blob", max_connections=1)
mock_download.assert_called_once_with(container_name="container",
blob_name="blob", max_connections=1)
mock_download.return_value.readall.assert_called()
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
@mock.patch.object(WasbHook, "download")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_read_file(self, mock_get_conn, mock_download, mock_service):
- mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
- hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ def test_read_file(self, mock_download, mocked_blob_service_client):
+ hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
hook.read_file("container", "blob", max_connections=1)
mock_download.assert_called_once_with("container", "blob",
max_connections=1)
@pytest.mark.parametrize(argnames="create_container", argvalues=[True,
False])
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_upload(self, mock_get_conn, mock_service, create_container):
- mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
- hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ def test_upload(self, mocked_blob_service_client, create_container):
+ hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
hook.upload(
container_name="mycontainer",
blob_name="myblob",
@@ -518,80 +468,61 @@ class TestWasbHook:
blob_type="BlockBlob",
length=4,
)
- mock_blob_client = mock_service.return_value.get_blob_client
+ mock_blob_client =
mocked_blob_service_client.return_value.get_blob_client
mock_blob_client.assert_called_once_with(container="mycontainer",
blob="myblob")
mock_blob_client.return_value.upload_blob.assert_called_once_with(b"mydata",
"BlockBlob", length=4)
- mock_container_client = mock_service.return_value.get_container_client
+ mock_container_client =
mocked_blob_service_client.return_value.get_container_client
if create_container:
mock_container_client.assert_called_with("mycontainer")
else:
mock_container_client.assert_not_called()
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_download(self, mock_get_conn, mock_service):
- mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
- blob_client = mock_service.return_value.get_blob_client
- hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ def test_download(self, mocked_blob_service_client):
+ blob_client = mocked_blob_service_client.return_value.get_blob_client
+ hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
hook.download(container_name="mycontainer", blob_name="myblob",
offset=2, length=4)
blob_client.assert_called_once_with(container="mycontainer",
blob="myblob")
blob_client.return_value.download_blob.assert_called_once_with(offset=2,
length=4)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_get_container_client(self, mock_get_conn, mock_service):
- mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
- hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ def test_get_container_client(self, mocked_blob_service_client):
+ hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
hook._get_container_client("mycontainer")
-
mock_service.return_value.get_container_client.assert_called_once_with("mycontainer")
+
mocked_blob_service_client.return_value.get_container_client.assert_called_once_with("mycontainer")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_get_blob_client(self, mock_get_conn, mock_service):
- mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id]
- hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ def test_get_blob_client(self, mocked_blob_service_client):
+ hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
hook._get_blob_client(container_name="mycontainer", blob_name="myblob")
- mock_instance = mock_service.return_value.get_blob_client
+ mock_instance = mocked_blob_service_client.return_value.get_blob_client
mock_instance.assert_called_once_with(container="mycontainer", blob="myblob")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_create_container(self, mock_get_conn, mock_service):
- mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id]
- hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ def test_create_container(self, mocked_blob_service_client):
+ hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
hook.create_container(container_name="mycontainer")
- mock_instance = mock_service.return_value.get_container_client
+ mock_instance = mocked_blob_service_client.return_value.get_container_client
mock_instance.assert_called_once_with("mycontainer")
mock_instance.return_value.create_container.assert_called()
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_delete_container(self, mock_get_conn, mock_service):
- mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id]
- hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ def test_delete_container(self, mocked_blob_service_client):
+ hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
hook.delete_container("mycontainer")
- mock_service.return_value.get_container_client.assert_called_once_with("mycontainer")
- mock_service.return_value.get_container_client.return_value.delete_container.assert_called()
+ mocked_container_client = mocked_blob_service_client.return_value.get_container_client
+ mocked_container_client.assert_called_once_with("mycontainer")
+ mocked_container_client.return_value.delete_container.assert_called()
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
@mock.patch.object(WasbHook, "delete_blobs")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_delete_single_blob(self, mock_get_conn, delete_blobs, mock_service):
- mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id]
- hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ def test_delete_single_blob(self, delete_blobs, mocked_blob_service_client):
+ hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
hook.delete_file("container", "blob", is_prefix=False)
delete_blobs.assert_called_once_with("container", "blob")
@mock.patch.object(WasbHook, "delete_blobs")
@mock.patch.object(WasbHook, "get_blobs_list")
@mock.patch.object(WasbHook, "check_for_blob")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_delete_multiple_blobs(self, mock_get_conn, mock_check, mock_get_blobslist, mock_delete_blobs):
- mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id]
+ def test_delete_multiple_blobs(self, mock_check, mock_get_blobslist, mock_delete_blobs):
mock_check.return_value = False
mock_get_blobslist.return_value = ["blob_prefix/blob1", "blob_prefix/blob2"]
- hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
hook.delete_file("container", "blob_prefix", is_prefix=True)
mock_get_blobslist.assert_called_once_with("container",
prefix="blob_prefix", delimiter="")
mock_delete_blobs.assert_any_call(
@@ -604,14 +535,10 @@ class TestWasbHook:
@mock.patch.object(WasbHook, "delete_blobs")
@mock.patch.object(WasbHook, "get_blobs_list")
@mock.patch.object(WasbHook, "check_for_blob")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_delete_more_than_256_blobs(
- self, mock_get_conn, mock_check, mock_get_blobslist, mock_delete_blobs
- ):
- mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id]
+ def test_delete_more_than_256_blobs(self, mock_check, mock_get_blobslist, mock_delete_blobs):
mock_check.return_value = False
mock_get_blobslist.return_value = [f"blob_prefix/blob{i}" for i in
range(300)]
- hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
hook.delete_file("container", "blob_prefix", is_prefix=True)
mock_get_blobslist.assert_called_once_with("container",
prefix="blob_prefix", delimiter="")
# The maximum number of blobs that can be deleted in a single request is 256 using the underlying
@@ -620,33 +547,26 @@ class TestWasbHook:
# `ContainerClient.delete_blobs()` in this test.
assert mock_delete_blobs.call_count == 2
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
@mock.patch.object(WasbHook, "get_blobs_list")
@mock.patch.object(WasbHook, "check_for_blob")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_delete_nonexisting_blob_fails(self, mock_get_conn, mock_check, mock_getblobs, mock_service):
+ def test_delete_nonexisting_blob_fails(self, mock_check, mock_getblobs, mocked_blob_service_client):
mock_getblobs.return_value = []
mock_check.return_value = False
with pytest.raises(Exception) as ctx:
- hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
hook.delete_file("container", "nonexisting_blob", is_prefix=False, ignore_if_missing=False)
assert isinstance(ctx.value, AirflowException)
@mock.patch.object(WasbHook, "get_blobs_list")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_delete_multiple_nonexisting_blobs_fails(self, mock_get_conn, mock_getblobs):
- mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id]
+ def test_delete_multiple_nonexisting_blobs_fails(self, mock_getblobs):
mock_getblobs.return_value = []
with pytest.raises(Exception) as ctx:
- hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
hook.delete_file("container", "nonexisting_blob_prefix", is_prefix=True, ignore_if_missing=False)
assert isinstance(ctx.value, AirflowException)
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_connection_success(self, mock_get_conn, mock_service):
- mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id]
- hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ def test_connection_success(self, mocked_blob_service_client):
+ hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
hook.get_conn().get_account_information().return_value = {
"sku_name": "Standard_RAGRS",
"account_kind": "StorageV2",
@@ -656,11 +576,8 @@ class TestWasbHook:
assert status is True
assert msg == "Successfully connected to Azure Blob Storage."
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
- def test_connection_failure(self, mock_get_conn, mock_service):
- mock_get_conn.return_value = self.connection_map[self.shared_key_conn_id]
- hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ def test_connection_failure(self, mocked_blob_service_client):
+ hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
hook.get_conn().get_account_information = mock.PropertyMock(
side_effect=Exception("Authentication failed.")
)
diff --git a/tests/providers/microsoft/azure/operators/test_azure_batch.py
b/tests/providers/microsoft/azure/operators/test_azure_batch.py
index e920f7c1d9..e70db80913 100644
--- a/tests/providers/microsoft/azure/operators/test_azure_batch.py
+++ b/tests/providers/microsoft/azure/operators/test_azure_batch.py
@@ -26,7 +26,6 @@ from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook
from airflow.providers.microsoft.azure.operators.batch import
AzureBatchOperator
-from airflow.utils import db
TASK_ID = "MyDag"
BATCH_POOL_ID = "MyPool"
@@ -40,11 +39,20 @@ FORMULA = """$curTime = time();
$TargetDedicated = $isWorkingWeekdayHour ? 20:10;"""
[email protected]
+def mocked_batch_service_client():
+ with mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient") as m:
+ yield m
+
+
class TestAzureBatchOperator:
# set up the test environment
- @mock.patch("airflow.providers.microsoft.azure.hooks.batch.AzureBatchHook")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient")
- def setup_method(self, method, mock_batch, mock_hook):
+ @pytest.fixture(autouse=True)
+ def setup_test_cases(self, mocked_batch_service_client, create_mock_connections):
+ # set up mocked Azure Batch client
+ self.batch_client = mock.MagicMock(name="FakeBatchServiceClient")
+ mocked_batch_service_client.return_value = self.batch_client
+
# set up the test variable
self.test_vm_conn_id = "test_azure_batch_vm2"
self.test_cloud_conn_id = "test_azure_batch_cloud2"
@@ -59,22 +67,21 @@ class TestAzureBatchOperator:
self.test_cloud_os_version = "test-version"
self.test_node_agent_sku = "test-node-agent-sku"
- # connect with vm configuration
- db.merge_conn(
+ create_mock_connections(
+ # connect with vm configuration
Connection(
conn_id=self.test_vm_conn_id,
conn_type="azure_batch",
extra=json.dumps({"account_url": self.test_account_url}),
- )
- )
- # connect with cloud service
- db.merge_conn(
+ ),
+ # connect with cloud service
Connection(
conn_id=self.test_cloud_conn_id,
conn_type="azure_batch",
extra=json.dumps({"account_url": self.test_account_url}),
- )
+ ),
)
+
self.operator = AzureBatchOperator(
task_id=TASK_ID,
batch_pool_id=BATCH_POOL_ID,
@@ -159,10 +166,6 @@ class TestAzureBatchOperator:
target_dedicated_nodes=1,
timeout=2,
)
- self.batch_client = mock_batch.return_value
- self.mock_instance = mock_hook.return_value
- assert self.batch_client == self.operator.hook.connection
- assert self.batch_client == self.operator2_pass.hook.connection
@mock.patch.object(AzureBatchHook, "wait_for_all_node_state")
def test_execute_without_failures(self, wait_mock):
diff --git a/tests/providers/microsoft/azure/operators/test_azure_cosmos.py
b/tests/providers/microsoft/azure/operators/test_azure_cosmos.py
index 57e5097006..599e24889e 100644
--- a/tests/providers/microsoft/azure/operators/test_azure_cosmos.py
+++ b/tests/providers/microsoft/azure/operators/test_azure_cosmos.py
@@ -17,33 +17,35 @@
# under the License.
from __future__ import annotations
-import json
import uuid
from unittest import mock
+import pytest
+
from airflow.models import Connection
from airflow.providers.microsoft.azure.operators.cosmos import
AzureCosmosInsertDocumentOperator
-from airflow.utils import db
class TestAzureCosmosDbHook:
# Set up an environment to test with
- def setup_method(self):
+ @pytest.fixture(autouse=True)
+ def setup_test_cases(self, create_mock_connection):
# set up some test variables
self.test_end_point = "https://test_endpoint:443"
self.test_master_key = "magic_test_key"
self.test_database_name = "test_database_name"
self.test_collection_name = "test_collection_name"
- db.merge_conn(
+ create_mock_connection(
Connection(
conn_id="azure_cosmos_test_key_id",
conn_type="azure_cosmos",
login=self.test_end_point,
password=self.test_master_key,
- extra=json.dumps(
- {"database_name": self.test_database_name, "collection_name": self.test_collection_name}
- ),
+ extra={
+ "database_name": self.test_database_name,
+ "collection_name": self.test_collection_name,
+ },
)
)
diff --git
a/tests/providers/microsoft/azure/operators/test_azure_data_factory.py
b/tests/providers/microsoft/azure/operators/test_azure_data_factory.py
index 2a545cd137..e4dd61c1d2 100644
--- a/tests/providers/microsoft/azure/operators/test_azure_data_factory.py
+++ b/tests/providers/microsoft/azure/operators/test_azure_data_factory.py
@@ -16,7 +16,6 @@
# under the License.
from __future__ import annotations
-import json
from unittest import mock
from unittest.mock import MagicMock, patch
@@ -35,7 +34,7 @@ from airflow.providers.microsoft.azure.hooks.data_factory
import (
)
from airflow.providers.microsoft.azure.operators.data_factory import
AzureDataFactoryRunPipelineOperator
from airflow.providers.microsoft.azure.triggers.data_factory import
AzureDataFactoryTrigger
-from airflow.utils import db, timezone
+from airflow.utils import timezone
from airflow.utils.types import DagRunType
DEFAULT_DATE = timezone.datetime(2021, 1, 1)
@@ -60,7 +59,8 @@ AZ_PIPELINE_RUN_ID = "7f8c6c72-c093-11ec-a83d-0242ac120007"
class TestAzureDataFactoryRunPipelineOperator:
- def setup_method(self):
+ @pytest.fixture(autouse=True)
+ def setup_test_cases(self, create_mock_connection):
self.mock_ti = MagicMock()
self.mock_context = {"ti": self.mock_ti}
self.config = {
@@ -73,13 +73,13 @@ class TestAzureDataFactoryRunPipelineOperator:
"timeout": 3,
}
- db.merge_conn(
+ create_mock_connection(
Connection(
conn_id="azure_data_factory_test",
conn_type="azure_data_factory",
login="client-id",
password="client-secret",
- extra=json.dumps(CONN_EXTRAS),
+ extra=CONN_EXTRAS,
)
)
diff --git a/tests/providers/microsoft/azure/operators/test_azure_synapse.py
b/tests/providers/microsoft/azure/operators/test_azure_synapse.py
index c43b11ef7b..233e1c57fd 100644
--- a/tests/providers/microsoft/azure/operators/test_azure_synapse.py
+++ b/tests/providers/microsoft/azure/operators/test_azure_synapse.py
@@ -16,7 +16,6 @@
# under the License.
from __future__ import annotations
-import json
from unittest import mock
from unittest.mock import MagicMock
@@ -24,7 +23,7 @@ import pytest
from airflow.models import Connection
from airflow.providers.microsoft.azure.operators.synapse import
AzureSynapseRunSparkBatchOperator
-from airflow.utils import db, timezone
+from airflow.utils import timezone
DEFAULT_DATE = timezone.datetime(2021, 1, 1)
SUBSCRIPTION_ID = "my-subscription-id"
@@ -39,7 +38,8 @@ JOB_RUN_RESPONSE = {"id": 123}
class TestAzureSynapseRunSparkBatchOperator:
- def setup_method(self):
+ @pytest.fixture(autouse=True)
+ def setup_test_cases(self, create_mock_connection):
self.mock_ti = MagicMock()
self.mock_context = {"ti": self.mock_ti}
self.config = {
@@ -50,22 +50,21 @@ class TestAzureSynapseRunSparkBatchOperator:
"timeout": 3,
}
- db.merge_conn(
+ create_mock_connection(
Connection(
conn_id=AZURE_SYNAPSE_CONN_ID,
conn_type="azure_synapse",
host="https://synapsetest.net",
login="client-id",
password="client-secret",
- extra=json.dumps(CONN_EXTRAS),
+ extra=CONN_EXTRAS,
)
)
@mock.patch("airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook.get_job_run_status")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook.get_conn")
@mock.patch("airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook.run_spark_job")
def test_azure_synapse_run_spark_batch_operator_success(
- self, mock_run_spark_job, mock_conn, mock_get_job_run_status
+ self, mock_run_spark_job, mock_get_job_run_status
):
mock_get_job_run_status.return_value = "success"
mock_run_spark_job.return_value = MagicMock(**JOB_RUN_RESPONSE)
@@ -76,11 +75,8 @@ class TestAzureSynapseRunSparkBatchOperator:
assert op.job_id == JOB_RUN_RESPONSE["id"]
@mock.patch("airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook.get_job_run_status")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook.get_conn")
@mock.patch("airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook.run_spark_job")
- def test_azure_synapse_run_spark_batch_operator_error(
- self, mock_run_spark_job, mock_conn, mock_get_job_run_status
- ):
+ def test_azure_synapse_run_spark_batch_operator_error(self, mock_run_spark_job, mock_get_job_run_status):
mock_get_job_run_status.return_value = "error"
mock_run_spark_job.return_value = MagicMock(**JOB_RUN_RESPONSE)
op = AzureSynapseRunSparkBatchOperator(
@@ -93,11 +89,10 @@ class TestAzureSynapseRunSparkBatchOperator:
op.execute(context=self.mock_context)
@mock.patch("airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook.get_job_run_status")
-
@mock.patch("airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook.get_conn")
@mock.patch("airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook.run_spark_job")
@mock.patch("airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook.cancel_job_run")
def test_azure_synapse_run_spark_batch_operator_on_kill(
- self, mock_cancel_job_run, mock_run_spark_job, mock_conn, mock_get_job_run_status
+ self, mock_cancel_job_run, mock_run_spark_job, mock_get_job_run_status
):
mock_get_job_run_status.return_value = "success"
mock_run_spark_job.return_value = MagicMock(**JOB_RUN_RESPONSE)
diff --git a/tests/providers/microsoft/conftest.py
b/tests/providers/microsoft/conftest.py
new file mode 100644
index 0000000000..aa75c95203
--- /dev/null
+++ b/tests/providers/microsoft/conftest.py
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import string
+from random import choices
+from typing import TypeVar
+
+import pytest
+
+from airflow.models import Connection
+
+T = TypeVar("T", dict, str, Connection)
+
+
[email protected]
+def create_mock_connection(monkeypatch):
+ """Helper fixture for create test connection."""
+
+ def wrapper(conn: T, conn_id: str | None = None):
+ conn_id = conn_id or "test_conn_" + "".join(choices(string.ascii_lowercase + string.digits, k=6))
+ if isinstance(conn, dict):
+ conn = Connection.from_json(conn)
+ elif isinstance(conn, str):
+ conn = Connection(uri=conn)
+
+ if not isinstance(conn, Connection):
+ raise TypeError(
+ f"Fixture expected either JSON, URI or Connection type, but got {type(conn).__name__}"
+ )
+ if not conn.conn_id:
+ conn.conn_id = conn_id
+
+ monkeypatch.setenv(f"AIRFLOW_CONN_{conn.conn_id.upper()}", conn.get_uri())
+ return conn
+
+ return wrapper
+
+
[email protected]
+def create_mock_connections(create_mock_connection):
+ """Helper fixture for create multiple test connections."""
+
+ def wrapper(*conns: T):
+ return list(map(create_mock_connection, conns))
+
+ return wrapper
+
+
[email protected]
+def mocked_connection(request, create_mock_connection):
+ """Helper indirect fixture for create test connection."""
+ return create_mock_connection(request.param)