This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new faf32539d6 Optimise Airflow DB backend usage in Azure Provider (#33750)
faf32539d6 is described below

commit faf32539d6a1be2bfba1b97e72e4508fb6896af6
Author: Andrey Anshin <[email protected]>
AuthorDate: Sat Aug 26 10:15:18 2023 +0400

    Optimise Airflow DB backend usage in Azure Provider (#33750)
---
 airflow/providers/microsoft/azure/hooks/adx.py     |  15 +-
 airflow/providers/microsoft/azure/hooks/batch.py   |  23 +-
 .../microsoft/azure/hooks/container_instance.py    |   6 +-
 .../microsoft/azure/hooks/container_registry.py    |  10 +-
 .../providers/microsoft/azure/hooks/data_lake.py   |  15 +-
 airflow/providers/microsoft/azure/hooks/wasb.py    |   7 +-
 .../microsoft/azure/log/wasb_task_handler.py       |   1 -
 .../providers/microsoft/azure/operators/batch.py   |   3 +-
 .../microsoft/azure/operators/data_factory.py      |   7 +-
 .../providers/microsoft/azure/operators/synapse.py |   9 +-
 .../microsoft/azure/sensors/data_factory.py        |   7 +-
 docs/spelling_wordlist.txt                         |   1 +
 tests/providers/microsoft/azure/hooks/test_adx.py  | 289 +++++-----
 tests/providers/microsoft/azure/hooks/test_asb.py  |  42 +-
 .../microsoft/azure/hooks/test_azure_batch.py      |  35 +-
 .../azure/hooks/test_azure_container_instance.py   |  17 +-
 .../azure/hooks/test_azure_container_registry.py   |  15 +-
 .../azure/hooks/test_azure_container_volume.py     |  31 +-
 .../microsoft/azure/hooks/test_azure_cosmos.py     |  18 +-
 .../azure/hooks/test_azure_data_factory.py         | 232 ++++----
 .../microsoft/azure/hooks/test_azure_data_lake.py  |  14 +-
 .../microsoft/azure/hooks/test_azure_fileshare.py  |  34 +-
 .../microsoft/azure/hooks/test_azure_synapse.py    |  83 ++-
 .../microsoft/azure/hooks/test_base_azure.py       |  72 +--
 tests/providers/microsoft/azure/hooks/test_wasb.py | 595 +++++++++------------
 .../microsoft/azure/operators/test_azure_batch.py  |  33 +-
 .../microsoft/azure/operators/test_azure_cosmos.py |  16 +-
 .../azure/operators/test_azure_data_factory.py     |  10 +-
 .../azure/operators/test_azure_synapse.py          |  21 +-
 tests/providers/microsoft/conftest.py              |  68 +++
 30 files changed, 872 insertions(+), 857 deletions(-)

diff --git a/airflow/providers/microsoft/azure/hooks/adx.py b/airflow/providers/microsoft/azure/hooks/adx.py
index f2cfd1a6c7..53da21e396 100644
--- a/airflow/providers/microsoft/azure/hooks/adx.py
+++ b/airflow/providers/microsoft/azure/hooks/adx.py
@@ -26,6 +26,7 @@ This module contains Azure Data Explorer hook.
 from __future__ import annotations
 
 import warnings
+from functools import cached_property
 from typing import Any
 
 from azure.identity import DefaultAzureCredential
@@ -76,8 +77,8 @@ class AzureDataExplorerHook(BaseHook):
     conn_type = "azure_data_explorer"
     hook_name = "Azure Data Explorer"
 
-    @staticmethod
-    def get_connection_form_widgets() -> dict[str, Any]:
+    @classmethod
+    def get_connection_form_widgets(cls) -> dict[str, Any]:
         """Returns connection widgets to add to connection form."""
         from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget
         from flask_babel import lazy_gettext
@@ -94,8 +95,8 @@ class AzureDataExplorerHook(BaseHook):
             ),
         }
 
-    @staticmethod
-    def get_ui_field_behaviour() -> dict[str, Any]:
+    @classmethod
+    def get_ui_field_behaviour(cls) -> dict[str, Any]:
         """Returns custom field behaviour."""
         return {
             "hidden_fields": ["schema", "port", "extra"],
@@ -116,7 +117,11 @@ class AzureDataExplorerHook(BaseHook):
     def __init__(self, azure_data_explorer_conn_id: str = default_conn_name) -> None:
         super().__init__()
         self.conn_id = azure_data_explorer_conn_id
-        self.connection = self.get_conn()  # todo: make this a property, or just delete
+
+    @cached_property
+    def connection(self) -> KustoClient:
+        """Return a KustoClient object (cached)."""
+        return self.get_conn()
 
     def get_conn(self) -> KustoClient:
         """Return a KustoClient object."""
diff --git a/airflow/providers/microsoft/azure/hooks/batch.py b/airflow/providers/microsoft/azure/hooks/batch.py
index deca28216d..594725c0da 100644
--- a/airflow/providers/microsoft/azure/hooks/batch.py
+++ b/airflow/providers/microsoft/azure/hooks/batch.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 
 import time
 from datetime import timedelta
+from functools import cached_property
 from typing import Any
 
 from azure.batch import BatchServiceClient, batch_auth, models as batch_models
@@ -26,7 +27,6 @@ from azure.batch.models import JobAddParameter, PoolAddParameter, TaskAddParamet
 
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
-from airflow.models import Connection
 from airflow.providers.microsoft.azure.utils import AzureIdentityCredentialAdapter, get_field
 from airflow.utils import timezone
 
@@ -52,8 +52,8 @@ class AzureBatchHook(BaseHook):
             field_name=name,
         )
 
-    @staticmethod
-    def get_connection_form_widgets() -> dict[str, Any]:
+    @classmethod
+    def get_connection_form_widgets(cls) -> dict[str, Any]:
         """Returns connection widgets to add to connection form."""
         from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
         from flask_babel import lazy_gettext
@@ -63,8 +63,8 @@ class AzureBatchHook(BaseHook):
             "account_url": StringField(lazy_gettext("Batch Account URL"), widget=BS3TextFieldWidget()),
         }
 
-    @staticmethod
-    def get_ui_field_behaviour() -> dict[str, Any]:
+    @classmethod
+    def get_ui_field_behaviour(cls) -> dict[str, Any]:
         """Returns custom field behaviour."""
         return {
             "hidden_fields": ["schema", "port", "host", "extra"],
@@ -77,20 +77,19 @@ class AzureBatchHook(BaseHook):
     def __init__(self, azure_batch_conn_id: str = default_conn_name) -> None:
         super().__init__()
         self.conn_id = azure_batch_conn_id
-        self.connection = self.get_conn()
 
-    def _connection(self) -> Connection:
-        """Get connected to Azure Batch service."""
-        conn = self.get_connection(self.conn_id)
-        return conn
+    @cached_property
+    def connection(self) -> BatchServiceClient:
+        """Get the Batch client connection (cached)."""
+        return self.get_conn()
 
-    def get_conn(self):
+    def get_conn(self) -> BatchServiceClient:
         """
         Get the Batch client connection.
 
         :return: Azure Batch client
         """
-        conn = self._connection()
+        conn = self.get_connection(self.conn_id)
 
         batch_account_url = self._get_field(conn.extra_dejson, "account_url")
         if not batch_account_url:
diff --git a/airflow/providers/microsoft/azure/hooks/container_instance.py b/airflow/providers/microsoft/azure/hooks/container_instance.py
index 9a0d0ec210..8fc845bf13 100644
--- a/airflow/providers/microsoft/azure/hooks/container_instance.py
+++ b/airflow/providers/microsoft/azure/hooks/container_instance.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 import warnings
+from functools import cached_property
 
 from azure.mgmt.containerinstance import ContainerInstanceManagementClient
 from azure.mgmt.containerinstance.models import ContainerGroup
@@ -47,7 +48,10 @@ class AzureContainerInstanceHook(AzureBaseHook):
 
     def __init__(self, azure_conn_id: str = default_conn_name) -> None:
         super().__init__(sdk_client=ContainerInstanceManagementClient, conn_id=azure_conn_id)
-        self.connection = self.get_conn()
+
+    @cached_property
+    def connection(self):
+        return self.get_conn()
 
     def create_or_update(self, resource_group: str, name: str, container_group: ContainerGroup) -> None:
         """
diff --git a/airflow/providers/microsoft/azure/hooks/container_registry.py b/airflow/providers/microsoft/azure/hooks/container_registry.py
index a3298117cc..c1217e3a86 100644
--- a/airflow/providers/microsoft/azure/hooks/container_registry.py
+++ b/airflow/providers/microsoft/azure/hooks/container_registry.py
@@ -18,6 +18,7 @@
 """Hook for Azure Container Registry."""
 from __future__ import annotations
 
+from functools import cached_property
 from typing import Any
 
 from azure.mgmt.containerinstance.models import ImageRegistryCredential
@@ -39,8 +40,8 @@ class AzureContainerRegistryHook(BaseHook):
     conn_type = "azure_container_registry"
     hook_name = "Azure Container Registry"
 
-    @staticmethod
-    def get_ui_field_behaviour() -> dict[str, Any]:
+    @classmethod
+    def get_ui_field_behaviour(cls) -> dict[str, Any]:
         """Returns custom field behaviour."""
         return {
             "hidden_fields": ["schema", "port", "extra"],
@@ -59,7 +60,10 @@ class AzureContainerRegistryHook(BaseHook):
     def __init__(self, conn_id: str = "azure_registry") -> None:
         super().__init__()
         self.conn_id = conn_id
-        self.connection = self.get_conn()
+
+    @cached_property
+    def connection(self) -> ImageRegistryCredential:
+        return self.get_conn()
 
     def get_conn(self) -> ImageRegistryCredential:
         conn = self.get_connection(self.conn_id)
diff --git a/airflow/providers/microsoft/azure/hooks/data_lake.py b/airflow/providers/microsoft/azure/hooks/data_lake.py
index 95ef4c6cc2..3849727e86 100644
--- a/airflow/providers/microsoft/azure/hooks/data_lake.py
+++ b/airflow/providers/microsoft/azure/hooks/data_lake.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+from functools import cached_property
 from typing import Any
 
 from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
@@ -256,8 +257,8 @@ class AzureDataLakeStorageV2Hook(BaseHook):
     conn_type = "adls"
     hook_name = "Azure Date Lake Storage V2"
 
-    @staticmethod
-    def get_connection_form_widgets() -> dict[str, Any]:
+    @classmethod
+    def get_connection_form_widgets(cls) -> dict[str, Any]:
         """Returns connection widgets to add to connection form."""
         from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget
         from flask_babel import lazy_gettext
@@ -272,8 +273,8 @@ class AzureDataLakeStorageV2Hook(BaseHook):
             ),
         }
 
-    @staticmethod
-    def get_ui_field_behaviour() -> dict[str, Any]:
+    @classmethod
+    def get_ui_field_behaviour(cls) -> dict[str, Any]:
         """Returns custom field behaviour."""
         return {
             "hidden_fields": ["schema", "port"],
@@ -296,7 +297,11 @@ class AzureDataLakeStorageV2Hook(BaseHook):
         super().__init__()
         self.conn_id = adls_conn_id
         self.public_read = public_read
-        self.service_client = self.get_conn()
+
+    @cached_property
+    def service_client(self) -> DataLakeServiceClient:
+        """Return the DataLakeServiceClient object (cached)."""
+        return self.get_conn()
 
     def get_conn(self) -> DataLakeServiceClient:  # type: ignore[override]
         """Return the DataLakeServiceClient object."""
diff --git a/airflow/providers/microsoft/azure/hooks/wasb.py b/airflow/providers/microsoft/azure/hooks/wasb.py
index 95d4764b67..55c1aba086 100644
--- a/airflow/providers/microsoft/azure/hooks/wasb.py
+++ b/airflow/providers/microsoft/azure/hooks/wasb.py
@@ -27,6 +27,7 @@ from __future__ import annotations
 
 import logging
 import os
+from functools import cached_property
 from typing import Any, Union
 from urllib.parse import urlparse
 
@@ -123,7 +124,6 @@ class WasbHook(BaseHook):
         super().__init__()
         self.conn_id = wasb_conn_id
         self.public_read = public_read
-        self.blob_service_client: BlobServiceClient = self.get_conn()
 
         logger = logging.getLogger("azure.core.pipeline.policies.http_logging_policy")
         try:
@@ -142,6 +142,11 @@ class WasbHook(BaseHook):
             return extra_dict[field_name] or None
         return extra_dict.get(f"{prefix}{field_name}") or None
 
+    @cached_property
+    def blob_service_client(self) -> BlobServiceClient:
+        """Return the BlobServiceClient object (cached)."""
+        return self.get_conn()
+
     def get_conn(self) -> BlobServiceClient:
         """Return the BlobServiceClient object."""
         conn = self.get_connection(self.conn_id)
diff --git a/airflow/providers/microsoft/azure/log/wasb_task_handler.py b/airflow/providers/microsoft/azure/log/wasb_task_handler.py
index 97a8af5ae1..21e96f1003 100644
--- a/airflow/providers/microsoft/azure/log/wasb_task_handler.py
+++ b/airflow/providers/microsoft/azure/log/wasb_task_handler.py
@@ -67,7 +67,6 @@ class WasbTaskHandler(FileTaskHandler, LoggingMixin):
         self.wasb_container = wasb_container
         self.remote_base = wasb_log_folder
         self.log_relative_path = ""
-        self._hook = None
         self.closed = False
         self.upload_on_close = True
         self.delete_local_copy = (
diff --git a/airflow/providers/microsoft/azure/operators/batch.py b/airflow/providers/microsoft/azure/operators/batch.py
index e26f56dd6e..06122ed1f9 100644
--- a/airflow/providers/microsoft/azure/operators/batch.py
+++ b/airflow/providers/microsoft/azure/operators/batch.py
@@ -179,7 +179,8 @@ class AzureBatchOperator(BaseOperator):
         self.should_delete_pool = should_delete_pool
 
     @cached_property
-    def hook(self):
+    def hook(self) -> AzureBatchHook:
+        """Create and return an AzureBatchHook (cached)."""
         return self.get_hook()
 
     def _check_inputs(self) -> Any:
diff --git a/airflow/providers/microsoft/azure/operators/data_factory.py b/airflow/providers/microsoft/azure/operators/data_factory.py
index d6b4592e35..12962e5610 100644
--- a/airflow/providers/microsoft/azure/operators/data_factory.py
+++ b/airflow/providers/microsoft/azure/operators/data_factory.py
@@ -18,6 +18,7 @@ from __future__ import annotations
 
 import time
 import warnings
+from functools import cached_property
 from typing import TYPE_CHECKING, Any, Sequence
 
 from airflow.configuration import conf
@@ -159,8 +160,12 @@ class AzureDataFactoryRunPipelineOperator(BaseOperator):
         self.check_interval = check_interval
         self.deferrable = deferrable
 
+    @cached_property
+    def hook(self) -> AzureDataFactoryHook:
+        """Create and return an AzureDataFactoryHook (cached)."""
+        return AzureDataFactoryHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id)
+
     def execute(self, context: Context) -> None:
-        self.hook = AzureDataFactoryHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id)
         self.log.info("Executing the %s pipeline.", self.pipeline_name)
         response = self.hook.run_pipeline(
             pipeline_name=self.pipeline_name,
diff --git a/airflow/providers/microsoft/azure/operators/synapse.py b/airflow/providers/microsoft/azure/operators/synapse.py
index b9d97704c5..dd6dda5555 100644
--- a/airflow/providers/microsoft/azure/operators/synapse.py
+++ b/airflow/providers/microsoft/azure/operators/synapse.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+from functools import cached_property
 from typing import TYPE_CHECKING, Sequence
 
 from azure.synapse.spark.models import SparkBatchJobOptions
@@ -73,10 +74,12 @@ class AzureSynapseRunSparkBatchOperator(BaseOperator):
         self.timeout = timeout
         self.check_interval = check_interval
 
+    @cached_property
+    def hook(self):
+        """Create and return an AzureSynapseHook (cached)."""
+        return AzureSynapseHook(azure_synapse_conn_id=self.azure_synapse_conn_id, spark_pool=self.spark_pool)
+
     def execute(self, context: Context) -> None:
-        self.hook = AzureSynapseHook(
-            azure_synapse_conn_id=self.azure_synapse_conn_id, spark_pool=self.spark_pool
-        )
         self.log.info("Executing the Synapse spark job.")
         response = self.hook.run_spark_job(payload=self.payload)
         self.log.info(response)
diff --git a/airflow/providers/microsoft/azure/sensors/data_factory.py b/airflow/providers/microsoft/azure/sensors/data_factory.py
index 4caa26a99d..91bc4072c0 100644
--- a/airflow/providers/microsoft/azure/sensors/data_factory.py
+++ b/airflow/providers/microsoft/azure/sensors/data_factory.py
@@ -18,6 +18,7 @@ from __future__ import annotations
 
 import warnings
 from datetime import timedelta
+from functools import cached_property
 from typing import TYPE_CHECKING, Any, Sequence
 
 from airflow.configuration import conf
@@ -72,8 +73,12 @@ class AzureDataFactoryPipelineRunStatusSensor(BaseSensorOperator):
 
         self.deferrable = deferrable
 
+    @cached_property
+    def hook(self):
+        """Create and return an AzureDataFactoryHook (cached)."""
+        return AzureDataFactoryHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id)
+
     def poke(self, context: Context) -> bool:
-        self.hook = AzureDataFactoryHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id)
         pipeline_run_status = self.hook.get_pipeline_run_status(
             run_id=self.run_id,
             resource_group_name=self.resource_group_name,
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index ec01067b12..eb80939a4c 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -154,6 +154,7 @@ BaseView
 BaseXCom
 bashrc
 batchGet
+BatchServiceClient
 bc
 bcc
 bdist
diff --git a/tests/providers/microsoft/azure/hooks/test_adx.py b/tests/providers/microsoft/azure/hooks/test_adx.py
index 4cf61c440b..2268f8b8a6 100644
--- a/tests/providers/microsoft/azure/hooks/test_adx.py
+++ b/tests/providers/microsoft/azure/hooks/test_adx.py
@@ -17,10 +17,7 @@
 # under the License.
 from __future__ import annotations
 
-import json
-import os
 from unittest import mock
-from unittest.mock import patch
 
 import pytest
 from azure.kusto.data import ClientRequestProperties, KustoClient, KustoConnectionStringBuilder
@@ -29,196 +26,220 @@ from pytest import param
 from airflow.exceptions import AirflowException
 from airflow.models import Connection
 from airflow.providers.microsoft.azure.hooks.adx import AzureDataExplorerHook
-from airflow.utils import db
-from airflow.utils.session import create_session
 from tests.test_utils.providers import get_provider_min_airflow_version
 
 ADX_TEST_CONN_ID = "adx_test_connection_id"
 
 
 class TestAzureDataExplorerHook:
-    def teardown_method(self):
-        with create_session() as session:
-            session.query(Connection).filter(Connection.conn_id == ADX_TEST_CONN_ID).delete()
-
-    def test_conn_missing_method(self):
-        db.merge_conn(
+    @pytest.mark.parametrize(
+        "mocked_connection",
+        [
             Connection(
-                conn_id=ADX_TEST_CONN_ID,
+                conn_id="missing_method",
                 conn_type="azure_data_explorer",
                 login="client_id",
                 password="client secret",
                 host="https://help.kusto.windows.net",
-                extra=json.dumps({}),
+                extra={},
             )
-        )
-        with pytest.raises(AirflowException) as ctx:
-            AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID)
-            assert "is missing: `data_explorer__auth_method`" in str(ctx.value)
+        ],
+        indirect=True,
+    )
+    def test_conn_missing_method(self, mocked_connection):
+        hook = AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id)
+        error_pattern = "is missing: `auth_method`"
+        with pytest.raises(AirflowException, match=error_pattern):
+            assert hook.get_conn()
+        with pytest.raises(AirflowException, match=error_pattern):
+            assert hook.connection
 
-    def test_conn_unknown_method(self):
-        db.merge_conn(
+    @pytest.mark.parametrize(
+        "mocked_connection",
+        [
             Connection(
-                conn_id=ADX_TEST_CONN_ID,
+                conn_id="unknown_method",
                 conn_type="azure_data_explorer",
                 login="client_id",
                 password="client secret",
                 host="https://help.kusto.windows.net",
-                extra=json.dumps({"auth_method": "AAD_OTHER"}),
-            )
-        )
-        with pytest.raises(AirflowException) as ctx:
-            AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID)
-        assert "Unknown authentication method: AAD_OTHER" in str(ctx.value)
+                extra={"auth_method": "AAD_OTHER"},
+            ),
+        ],
+        indirect=True,
+    )
+    def test_conn_unknown_method(self, mocked_connection):
+        hook = AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id)
+        error_pattern = "Unknown authentication method: AAD_OTHER"
+        with pytest.raises(AirflowException, match=error_pattern):
+            assert hook.get_conn()
+        with pytest.raises(AirflowException, match=error_pattern):
+            assert hook.connection
 
-    def test_conn_missing_cluster(self):
-        db.merge_conn(
+    @pytest.mark.parametrize(
+        "mocked_connection",
+        [
             Connection(
-                conn_id=ADX_TEST_CONN_ID,
+                conn_id="missing_cluster",
                 conn_type="azure_data_explorer",
                 login="client_id",
                 password="client secret",
-                extra=json.dumps({}),
-            )
-        )
-        with pytest.raises(AirflowException) as ctx:
-            AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID)
-        assert "Host connection option is required" in str(ctx.value)
+                extra={},
+            ),
+        ],
+        indirect=True,
+    )
+    def test_conn_missing_cluster(self, mocked_connection):
+        hook = AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id)
+        error_pattern = "Host connection option is required"
+        with pytest.raises(AirflowException, match=error_pattern):
+            assert hook.get_conn()
+        with pytest.raises(AirflowException, match=error_pattern):
+            assert hook.connection
 
-    @mock.patch.object(KustoClient, "__init__")
-    def test_conn_method_aad_creds(self, mock_init):
-        mock_init.return_value = None
-        db.merge_conn(
+    @pytest.mark.parametrize(
+        "mocked_connection",
+        [
             Connection(
-                conn_id=ADX_TEST_CONN_ID,
+                conn_id="method_aad_creds",
                 conn_type="azure_data_explorer",
                 login="client_id",
                 password="client secret",
                 host="https://help.kusto.windows.net",
-                extra=json.dumps(
-                    {
-                        "tenant": "tenant",
-                        "auth_method": "AAD_CREDS",
-                    }
-                ),
+                extra={"tenant": "tenant", "auth_method": "AAD_CREDS"},
             )
-        )
-        AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID)
+        ],
+        indirect=True,
+    )
+    @mock.patch.object(KustoClient, "__init__")
+    def test_conn_method_aad_creds(self, mock_init, mocked_connection):
+        mock_init.return_value = None
+        AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id).get_conn()
         assert mock_init.called_with(
             KustoConnectionStringBuilder.with_aad_user_password_authentication(
                 "https://help.kusto.windows.net", "client_id", "client secret", "tenant"
             )
         )
 
-    @mock.patch("azure.identity._credentials.environment.ClientSecretCredential")
-    def test_conn_method_token_creds(self, mock1):
-        db.merge_conn(
+    @pytest.mark.parametrize(
+        "mocked_connection",
+        [
             Connection(
-                conn_id=ADX_TEST_CONN_ID,
+                conn_id="method_token_creds",
                 conn_type="azure_data_explorer",
                 host="https://help.kusto.windows.net",
-                extra=json.dumps(
-                    {
-                        "auth_method": "AZURE_TOKEN_CRED",
-                    }
-                ),
-            )
+                extra={
+                    "auth_method": "AZURE_TOKEN_CRED",
+                },
+            ),
+        ],
+        indirect=True,
+    )
+    @mock.patch("azure.identity._credentials.environment.ClientSecretCredential")
+    def test_conn_method_token_creds(self, mock1, mocked_connection, monkeypatch):
+        hook = AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id)
+
+        monkeypatch.setenv("AZURE_TENANT_ID", "tenant")
+        monkeypatch.setenv("AZURE_CLIENT_ID", "client")
+        monkeypatch.setenv("AZURE_CLIENT_SECRET", "secret")
+
+        assert hook.connection._kcsb.data_source == "https://help.kusto.windows.net"
+        mock1.assert_called_once_with(
+            tenant_id="tenant",
+            client_id="client",
+            client_secret="secret",
+            authority="https://login.microsoftonline.com",
         )
-        with patch.dict(
-            in_dict=os.environ,
-            values={
-                "AZURE_TENANT_ID": "tenant",
-                "AZURE_CLIENT_ID": "client",
-                "AZURE_CLIENT_SECRET": "secret",
-            },
-        ):
-            hook = AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID)
-            assert hook.connection._kcsb.data_source == "https://help.kusto.windows.net"
-            mock1.assert_called_once_with(
-                tenant_id="tenant",
-                client_id="client",
-                client_secret="secret",
-                authority="https://login.microsoftonline.com",
-            )
 
-    @mock.patch.object(KustoClient, "__init__")
-    def test_conn_method_aad_app(self, mock_init):
-        mock_init.return_value = None
-        db.merge_conn(
+    @pytest.mark.parametrize(
+        "mocked_connection",
+        [
             Connection(
-                conn_id=ADX_TEST_CONN_ID,
+                conn_id="method_aad_app",
                 conn_type="azure_data_explorer",
                 login="app_id",
                 password="app key",
                 host="https://help.kusto.windows.net",
-                extra=json.dumps(
-                    {
-                        "tenant": "tenant",
-                        "auth_method": "AAD_APP",
-                    }
-                ),
+                extra={
+                    "tenant": "tenant",
+                    "auth_method": "AAD_APP",
+                },
             )
-        )
-        AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID)
+        ],
+        indirect=True,
+    )
+    @mock.patch.object(KustoClient, "__init__")
+    def test_conn_method_aad_app(self, mock_init, mocked_connection):
+        mock_init.return_value = None
+        AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id).get_conn()
         assert mock_init.called_with(
             KustoConnectionStringBuilder.with_aad_application_key_authentication(
                 "https://help.kusto.windows.net", "app_id", "app key", "tenant"
             )
         )
 
-    @mock.patch.object(KustoClient, "__init__")
-    def test_conn_method_aad_app_cert(self, mock_init):
-        mock_init.return_value = None
-        db.merge_conn(
+    @pytest.mark.parametrize(
+        "mocked_connection",
+        [
             Connection(
-                conn_id=ADX_TEST_CONN_ID,
+                conn_id="method_aad_app",
                 conn_type="azure_data_explorer",
-                login="client_id",
+                login="app_id",
+                password="app key",
                 host="https://help.kusto.windows.net",
-                extra=json.dumps(
-                    {
-                        "tenant": "tenant",
-                        "auth_method": "AAD_APP_CERT",
-                        "certificate": "PEM",
-                        "thumbprint": "thumbprint",
-                    }
-                ),
+                extra={
+                    "tenant": "tenant",
+                    "auth_method": "AAD_APP",
+                },
             )
-        )
-        AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID)
+        ],
+        indirect=True,
+    )
+    @mock.patch.object(KustoClient, "__init__")
+    def test_conn_method_aad_app_cert(self, mock_init, mocked_connection):
+        mock_init.return_value = None
+        AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id).get_conn()
         assert mock_init.called_with(
             KustoConnectionStringBuilder.with_aad_application_certificate_authentication(
                 "https://help.kusto.windows.net", "client_id", "PEM", "thumbprint", "tenant"
             )
         )
 
-    @mock.patch.object(KustoClient, "__init__")
-    def test_conn_method_aad_device(self, mock_init):
-        mock_init.return_value = None
-        db.merge_conn(
+    @pytest.mark.parametrize(
+        "mocked_connection",
+        [
             Connection(
                 conn_id=ADX_TEST_CONN_ID,
                 conn_type="azure_data_explorer",
                 host="https://help.kusto.windows.net",
-                extra=json.dumps({"auth_method": "AAD_DEVICE"}),
+                extra={"auth_method": "AAD_DEVICE"},
             )
-        )
-        AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID)
+        ],
+        indirect=True,
+    )
+    @mock.patch.object(KustoClient, "__init__")
+    def test_conn_method_aad_device(self, mock_init, mocked_connection):
+        mock_init.return_value = None
+        AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id).get_conn()
         assert mock_init.called_with(
             KustoConnectionStringBuilder.with_aad_device_authentication("https://help.kusto.windows.net")
         )
 
-    @mock.patch.object(KustoClient, "execute")
-    def test_run_query(self, mock_execute):
-        mock_execute.return_value = None
-        db.merge_conn(
+    @pytest.mark.parametrize(
+        "mocked_connection",
+        [
             Connection(
                 conn_id=ADX_TEST_CONN_ID,
                 conn_type="azure_data_explorer",
                 host="https://help.kusto.windows.net",
-                extra=json.dumps({"auth_method": "AAD_DEVICE"}),
+                extra={"auth_method": "AAD_DEVICE"},
             )
-        )
+        ],
+        indirect=True,
+    )
+    @mock.patch.object(KustoClient, "execute")
+    def test_run_query(self, mock_execute, mocked_connection):
+        mock_execute.return_value = None
         hook = AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID)
         hook.run_query("Database", "Logs | schema", options={"option1": "option_value"})
         properties = ClientRequestProperties()
@@ -246,7 +267,7 @@ class TestAzureDataExplorerHook:
             )
 
     @pytest.mark.parametrize(
-        "uri",
+        "mocked_connection",
         [
             param(
                 "a://usr:pw@host?extra__azure_data_explorer__tenant=my-tenant"
@@ -255,24 +276,28 @@ class TestAzureDataExplorerHook:
             ),
             param("a://usr:pw@host?tenant=my-tenant&auth_method=AAD_APP", id="no-prefix"),
         ],
+        indirect=True,
     )
-    def test_backcompat_prefix_works(self, uri):
-        with patch.dict(os.environ, AIRFLOW_CONN_MY_CONN=uri):
-            hook = AzureDataExplorerHook(azure_data_explorer_conn_id="my_conn")  # get_conn is called in init
+    def test_backcompat_prefix_works(self, mocked_connection):
+        hook = AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id)
         assert hook.connection._kcsb.data_source == "host"
         assert hook.connection._kcsb.application_client_id == "usr"
         assert hook.connection._kcsb.application_key == "pw"
         assert hook.connection._kcsb.authority_id == "my-tenant"
 
-    def test_backcompat_prefix_both_causes_warning(self):
-        with patch.dict(
-            in_dict=os.environ,
-            AIRFLOW_CONN_MY_CONN="a://usr:pw@host?tenant=my-tenant&auth_method=AAD_APP"
-            "&extra__azure_data_explorer__auth_method=AAD_APP",
-        ):
-            with pytest.warns(Warning, match="Using value for `auth_method`"):
-                hook = AzureDataExplorerHook(azure_data_explorer_conn_id="my_conn")
-            assert hook.connection._kcsb.data_source == "host"
-            assert hook.connection._kcsb.application_client_id == "usr"
-            assert hook.connection._kcsb.application_key == "pw"
-            assert hook.connection._kcsb.authority_id == "my-tenant"
+    @pytest.mark.parametrize(
+        "mocked_connection",
+        [
+            (
+                "a://usr:pw@host?tenant=my-tenant&auth_method=AAD_APP"
+                "&extra__azure_data_explorer__auth_method=AAD_APP"
+            )
+        ],
+        indirect=True,
+    )
+    def test_backcompat_prefix_both_causes_warning(self, mocked_connection):
+        hook = AzureDataExplorerHook(azure_data_explorer_conn_id=mocked_connection.conn_id)
+        assert hook.connection._kcsb.data_source == "host"
+        assert hook.connection._kcsb.application_client_id == "usr"
+        assert hook.connection._kcsb.application_key == "pw"
+        assert hook.connection._kcsb.authority_id == "my-tenant"
diff --git a/tests/providers/microsoft/azure/hooks/test_asb.py b/tests/providers/microsoft/azure/hooks/test_asb.py
index a9a3851561..5f626d6c29 100644
--- a/tests/providers/microsoft/azure/hooks/test_asb.py
+++ b/tests/providers/microsoft/azure/hooks/test_asb.py
@@ -34,22 +34,23 @@ MESSAGE_LIST = [f"{MESSAGE} {n}" for n in range(0, 10)]
 
 
 class TestAdminClientHook:
-    def setup_class(self) -> None:
-        self.queue_name: str = "test_queue"
-        self.conn_id: str = "azure_service_bus_default"
+    @pytest.fixture(autouse=True)
+    def setup_test_cases(self, create_mock_connection):
+        self.queue_name = "test_queue"
+        self.conn_id = "azure_service_bus_default"
         self.connection_string = (
             "Endpoint=sb://test-service-bus-provider.servicebus.windows.net/;"
             "SharedAccessKeyName=Test;SharedAccessKey=1234566acbc"
         )
-        self.mock_conn = Connection(
-            conn_id="azure_service_bus_default",
-            conn_type="azure_service_bus",
-            schema=self.connection_string,
+        self.mock_conn = create_mock_connection(
+            Connection(
+                conn_id=self.conn_id,
+                conn_type="azure_service_bus",
+                schema=self.connection_string,
+            )
         )
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.asb.AdminClientHook.get_connection")
-    def test_get_conn(self, mock_connection):
-        mock_connection.return_value = self.mock_conn
+    def test_get_conn(self):
         hook = AdminClientHook(azure_service_bus_conn_id=self.conn_id)
         assert isinstance(hook.get_conn(), ServiceBusAdministrationClient)
 
@@ -124,26 +125,27 @@ class TestAdminClientHook:
 
 
 class TestMessageHook:
-    def setup_class(self) -> None:
-        self.queue_name: str = "test_queue"
-        self.conn_id: str = "azure_service_bus_default"
+    @pytest.fixture(autouse=True)
+    def setup_test_cases(self, create_mock_connection):
+        self.queue_name = "test_queue"
+        self.conn_id = "azure_service_bus_default"
         self.connection_string = (
             "Endpoint=sb://test-service-bus-provider.servicebus.windows.net/;"
             "SharedAccessKeyName=Test;SharedAccessKey=1234566acbc"
         )
-        self.conn = Connection(
-            conn_id="azure_service_bus_default",
-            conn_type="azure_service_bus",
-            schema=self.connection_string,
+        self.mock_conn = create_mock_connection(
+            Connection(
+                conn_id=self.conn_id,
+                conn_type="azure_service_bus",
+                schema=self.connection_string,
+            )
         )
 
-    @mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_connection")
-    def test_get_service_bus_message_conn(self, mock_connection):
+    def test_get_service_bus_message_conn(self):
         """
         Test get_conn() function and check whether the get_conn() function returns value
         is instance of ServiceBusClient
         """
-        mock_connection.return_value = self.conn
         hook = MessageHook(azure_service_bus_conn_id=self.conn_id)
         assert isinstance(hook.get_conn(), ServiceBusClient)
 
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_batch.py b/tests/providers/microsoft/azure/hooks/test_azure_batch.py
index a3a421f5a0..cd5ab10134 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_batch.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_batch.py
@@ -17,22 +17,21 @@
 # under the License.
 from __future__ import annotations
 
-import json
 from unittest import mock
 from unittest.mock import PropertyMock
 
+import pytest
 from azure.batch import BatchServiceClient, models as batch_models
 
 from airflow.models import Connection
 from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook
-from airflow.utils import db
 
 MODULE = "airflow.providers.microsoft.azure.hooks.batch"
 
 
 class TestAzureBatchHook:
-    # set up the test environment
-    def setup_method(self):
+    @pytest.fixture(autouse=True)
+    def setup_test_cases(self, create_mock_connections):
         # set up the test variable
         self.test_vm_conn_id = "test_azure_batch_vm"
         self.test_cloud_conn_id = "test_azure_batch_cloud"
@@ -47,27 +46,27 @@ class TestAzureBatchHook:
         self.test_cloud_os_version = "test-version"
         self.test_node_agent_sku = "test-node-agent-sku"
 
-        # connect with vm configuration
-        db.merge_conn(
+        create_mock_connections(
+            # connect with vm configuration
             Connection(
                 conn_id=self.test_vm_conn_id,
-                conn_type="azure_batch",
-                extra=json.dumps({"account_url": self.test_account_url}),
-            )
-        )
-        # connect with cloud service
-        db.merge_conn(
+                conn_type="azure-batch",
+                extra={"account_url": self.test_account_url},
+            ),
+            # connect with cloud service
             Connection(
                 conn_id=self.test_cloud_conn_id,
-                conn_type="azure_batch",
-                extra=json.dumps({"account_url": self.test_account_url}),
-            )
+                conn_type="azure-batch",
+                extra={"account_url": self.test_account_url},
+            ),
         )
 
     def test_connection_and_client(self):
         hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id)
-        assert isinstance(hook._connection(), Connection)
         assert isinstance(hook.get_conn(), BatchServiceClient)
+        conn = hook.connection
+        assert isinstance(conn, BatchServiceClient)
+        assert hook.connection is conn, "`connection` property should be cached"
 
     @mock.patch(f"{MODULE}.batch_auth.SharedKeyCredentials")
     @mock.patch(f"{MODULE}.AzureIdentityCredentialAdapter")
@@ -195,7 +194,7 @@ class TestAzureBatchHook:
     @mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient")
     def test_connection_success(self, mock_batch):
         hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id)
-        hook.get_conn().job.return_value = {}
+        hook.connection.job.return_value = {}
         status, msg = hook.test_connection()
         assert status is True
         assert msg == "Successfully connected to Azure Batch."
@@ -203,7 +202,7 @@ class TestAzureBatchHook:
     @mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient")
     def test_connection_failure(self, mock_batch):
         hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id)
-        hook.get_conn().job.list = PropertyMock(side_effect=Exception("Authentication failed."))
+        hook.connection.job.list = PropertyMock(side_effect=Exception("Authentication failed."))
         status, msg = hook.test_connection()
         assert status is False
         assert msg == "Authentication failed."
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py b/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py
index a10dde9c65..786df4eb16 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py
@@ -17,9 +17,9 @@
 # under the License.
 from __future__ import annotations
 
-import json
 from unittest.mock import patch
 
+import pytest
 from azure.mgmt.containerinstance.models import (
     Container,
     ContainerGroup,
@@ -30,27 +30,26 @@ from azure.mgmt.containerinstance.models import (
 
 from airflow.models import Connection
 from airflow.providers.microsoft.azure.hooks.container_instance import AzureContainerInstanceHook
-from airflow.utils import db
 
 
 class TestAzureContainerInstanceHook:
-    def setup_method(self):
-        db.merge_conn(
+    @pytest.fixture(autouse=True)
+    def setup_test_cases(self, create_mock_connection):
+        mock_connection = create_mock_connection(
             Connection(
                 conn_id="azure_container_instance_test",
                 conn_type="azure_container_instances",
                 login="login",
                 password="key",
-                extra=json.dumps({"tenantId": "tenant_id", "subscriptionId": "subscription_id"}),
+                extra={"tenantId": "tenant_id", "subscriptionId": "subscription_id"},
             )
         )
-
         self.resources = ResourceRequirements(requests=ResourceRequests(memory_in_gb="4", cpu="1"))
-        with patch(
+        self.hook = AzureContainerInstanceHook(azure_conn_id=mock_connection.conn_id)
+        with patch("azure.mgmt.containerinstance.ContainerInstanceManagementClient"), patch(
             "azure.common.credentials.ServicePrincipalCredentials.__init__", autospec=True, return_value=None
         ):
-            with patch("azure.mgmt.containerinstance.ContainerInstanceManagementClient"):
-                self.hook = AzureContainerInstanceHook(azure_conn_id="azure_container_instance_test")
+            yield
 
     @patch("azure.mgmt.containerinstance.models.ContainerGroup")
     @patch("azure.mgmt.containerinstance.operations.ContainerGroupsOperations.create_or_update")
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py b/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py
index 91ae933d2b..38f326d298 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py
@@ -17,14 +17,16 @@
 # under the License.
 from __future__ import annotations
 
+import pytest
+
 from airflow.models import Connection
 from airflow.providers.microsoft.azure.hooks.container_registry import AzureContainerRegistryHook
-from airflow.utils import db
 
 
 class TestAzureContainerRegistryHook:
-    def test_get_conn(self):
-        db.merge_conn(
+    @pytest.mark.parametrize(
+        "mocked_connection",
+        [
             Connection(
                 conn_id="azure_container_registry",
                 conn_type="azure_container_registry",
@@ -32,8 +34,11 @@ class TestAzureContainerRegistryHook:
                 password="password",
                 host="test.cr",
             )
-        )
-        hook = AzureContainerRegistryHook(conn_id="azure_container_registry")
+        ],
+        indirect=True,
+    )
+    def test_get_conn(self, mocked_connection):
+        hook = AzureContainerRegistryHook(conn_id=mocked_connection.conn_id)
         assert hook.connection is not None
         assert hook.connection.username == "myuser"
         assert hook.connection.password == "password"
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py b/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py
index 3dcbfd9a89..b4c7b8d1c7 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py
@@ -17,22 +17,25 @@
 # under the License.
 from __future__ import annotations
 
-import json
+import pytest
 
 from airflow.models import Connection
 from airflow.providers.microsoft.azure.hooks.container_volume import AzureContainerVolumeHook
-from airflow.utils import db
 from tests.test_utils.providers import get_provider_min_airflow_version
 
 
 class TestAzureContainerVolumeHook:
-    def test_get_file_volume(self):
-        db.merge_conn(
+    @pytest.mark.parametrize(
+        "mocked_connection",
+        [
             Connection(
                 conn_id="azure_container_test_connection", conn_type="wasb", login="login", password="key"
             )
-        )
-        hook = AzureContainerVolumeHook(azure_container_volume_conn_id="azure_container_test_connection")
+        ],
+        indirect=True,
+    )
+    def test_get_file_volume(self, mocked_connection):
+        hook = AzureContainerVolumeHook(azure_container_volume_conn_id=mocked_connection.conn_id)
         volume = hook.get_file_volume(
             mount_name="mount", share_name="share", storage_account_name="storage", read_only=True
         )
@@ -43,19 +46,21 @@ class TestAzureContainerVolumeHook:
         assert volume.azure_file.storage_account_name == "storage"
         assert volume.azure_file.read_only is True
 
-    def test_get_file_volume_connection_string(self):
-        db.merge_conn(
+    @pytest.mark.parametrize(
+        "mocked_connection",
+        [
             Connection(
                 conn_id="azure_container_test_connection_connection_string",
                 conn_type="wasb",
                 login="login",
                 password="key",
-                extra=json.dumps({"connection_string": "a=b;AccountKey=1"}),
+                extra={"connection_string": "a=b;AccountKey=1"},
             )
-        )
-        hook = AzureContainerVolumeHook(
-            azure_container_volume_conn_id="azure_container_test_connection_connection_string"
-        )
+        ],
+        indirect=True,
+    )
+    def test_get_file_volume_connection_string(self, mocked_connection):
+        hook = AzureContainerVolumeHook(azure_container_volume_conn_id=mocked_connection.conn_id)
         volume = hook.get_file_volume(
             mount_name="mount", share_name="share", storage_account_name="storage", read_only=True
         )
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
index af649f1d6e..f63b8e8dbd 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
@@ -17,7 +17,6 @@
 # under the License.
 from __future__ import annotations
 
-import json
 import logging
 import uuid
 from unittest import mock
@@ -29,14 +28,14 @@ from azure.cosmos.cosmos_client import CosmosClient
 from airflow.exceptions import AirflowException
 from airflow.models import Connection
 from airflow.providers.microsoft.azure.hooks.cosmos import AzureCosmosDBHook
-from airflow.utils import db
 from tests.test_utils.providers import get_provider_min_airflow_version
 
 
 class TestAzureCosmosDbHook:
 
     # Set up an environment to test with
-    def setup_method(self):
+    @pytest.fixture(autouse=True)
+    def setup_test_cases(self, create_mock_connection):
         # set up some test variables
         self.test_end_point = "https://test_endpoint:443"
         self.test_master_key = "magic_test_key"
@@ -44,25 +43,22 @@ class TestAzureCosmosDbHook:
         self.test_collection_name = "test_collection_name"
         self.test_database_default = "test_database_default"
         self.test_collection_default = "test_collection_default"
-        db.merge_conn(
+        create_mock_connection(
             Connection(
                 conn_id="azure_cosmos_test_key_id",
                 conn_type="azure_cosmos",
                 login=self.test_end_point,
                 password=self.test_master_key,
-                extra=json.dumps(
-                    {
-                        "database_name": self.test_database_default,
-                        "collection_name": self.test_collection_default,
-                    }
-                ),
+                extra={
+                    "database_name": self.test_database_default,
+                    "collection_name": self.test_collection_default,
+                },
             )
         )
 
     @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient", autospec=True)
     def test_client(self, mock_cosmos):
         hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
-        assert hook._conn is None
         assert isinstance(hook.get_conn(), CosmosClient)
 
     @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
index 4ca2eb4920..63a22614dc 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
@@ -16,7 +16,6 @@
 # under the License.
 from __future__ import annotations
 
-import json
 import os
 from unittest import mock
 from unittest.mock import MagicMock, PropertyMock, patch
@@ -25,9 +24,9 @@ import pytest
 from azure.identity import ClientSecretCredential, DefaultAzureCredential
 from azure.mgmt.datafactory.aio import DataFactoryManagementClient
 from azure.mgmt.datafactory.models import FactoryListResponse
-from pytest import fixture, param
+from pytest import param
 
-from airflow import AirflowException
+from airflow.exceptions import AirflowException
 from airflow.models.connection import Connection
 from airflow.providers.microsoft.azure.hooks.data_factory import (
     AzureDataFactoryAsyncHook,
@@ -37,7 +36,6 @@ from airflow.providers.microsoft.azure.hooks.data_factory import (
     get_field,
     provide_targeted_factory,
 )
-from airflow.utils import db
 
 DEFAULT_RESOURCE_GROUP = "defaultResourceGroup"
 AZURE_DATA_FACTORY_CONN_ID = "azure_data_factory_default"
@@ -59,66 +57,60 @@ NAME = "testName"
 ID = "testId"
 
 
-def setup_module():
-    connection_client_secret = Connection(
-        conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
-        conn_type="azure_data_factory",
-        login="clientId",
-        password="clientSecret",
-        extra=json.dumps(
-            {
+@pytest.fixture(autouse=True)
+def setup_connections(create_mock_connections):
+    create_mock_connections(
+        # connection_client_secret
+        Connection(
+            conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
+            conn_type="azure_data_factory",
+            login="clientId",
+            password="clientSecret",
+            extra={
                 "tenantId": "tenantId",
                 "subscriptionId": "subscriptionId",
                 "resource_group_name": DEFAULT_RESOURCE_GROUP,
                 "factory_name": DEFAULT_FACTORY,
-            }
+            },
         ),
-    )
-    connection_default_credential = Connection(
-        conn_id=DEFAULT_CONNECTION_DEFAULT_CREDENTIAL,
-        conn_type="azure_data_factory",
-        extra=json.dumps(
-            {
+        # connection_default_credential
+        Connection(
+            conn_id=DEFAULT_CONNECTION_DEFAULT_CREDENTIAL,
+            conn_type="azure_data_factory",
+            extra={
                 "subscriptionId": "subscriptionId",
                 "resource_group_name": DEFAULT_RESOURCE_GROUP,
                 "factory_name": DEFAULT_FACTORY,
-            }
+            },
         ),
-    )
-    connection_missing_subscription_id = Connection(
-        conn_id="azure_data_factory_missing_subscription_id",
-        conn_type="azure_data_factory",
-        login="clientId",
-        password="clientSecret",
-        extra=json.dumps(
-            {
+        Connection(
+            # connection_missing_subscription_id
+            conn_id="azure_data_factory_missing_subscription_id",
+            conn_type="azure_data_factory",
+            login="clientId",
+            password="clientSecret",
+            extra={
                 "tenantId": "tenantId",
                 "resource_group_name": DEFAULT_RESOURCE_GROUP,
                 "factory_name": DEFAULT_FACTORY,
-            }
+            },
         ),
-    )
-    connection_missing_tenant_id = Connection(
-        conn_id="azure_data_factory_missing_tenant_id",
-        conn_type="azure_data_factory",
-        login="clientId",
-        password="clientSecret",
-        extra=json.dumps(
-            {
+        # connection_missing_tenant_id
+        Connection(
+            conn_id="azure_data_factory_missing_tenant_id",
+            conn_type="azure_data_factory",
+            login="clientId",
+            password="clientSecret",
+            extra={
                 "subscriptionId": "subscriptionId",
                 "resource_group_name": DEFAULT_RESOURCE_GROUP,
                 "factory_name": DEFAULT_FACTORY,
-            }
+            },
         ),
     )
 
-    db.merge_conn(connection_client_secret)
-    db.merge_conn(connection_default_credential)
-    db.merge_conn(connection_missing_subscription_id)
-    db.merge_conn(connection_missing_tenant_id)
 
-
-@fixture
+@pytest.fixture
 def hook():
     client = AzureDataFactoryHook(azure_data_factory_conn_id=DEFAULT_CONNECTION_CLIENT_SECRET)
     client._conn = MagicMock(
@@ -799,7 +791,7 @@ class TestAzureDataFactoryAsyncHook:
         Test get_pipeline_run function without passing the resource name to check the decorator function and
         raise exception
         """
-        mock_connection = Connection(extra=json.dumps({"factory_name": DATAFACTORY_NAME}))
+        mock_connection = Connection(extra={"factory_name": DATAFACTORY_NAME})
         mock_get_connection.return_value = mock_connection
         mock_conn.return_value.pipeline_runs.get.return_value = mock_pipeline_run
         hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
@@ -807,98 +799,98 @@ class TestAzureDataFactoryAsyncHook:
             await hook.get_pipeline_run(RUN_ID, None, DATAFACTORY_NAME)
 
     @pytest.mark.asyncio
-    @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection")
-    async def test_get_async_conn(self, mock_connection):
-        """"""
-        mock_conn = Connection(
-            conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
-            conn_type="azure_data_factory",
-            login="clientId",
-            password="clientSecret",
-            extra=json.dumps(
-                {
+    @pytest.mark.parametrize(
+        "mocked_connection",
+        [
+            Connection(
+                conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
+                conn_type="azure_data_factory",
+                login="clientId",
+                password="clientSecret",
+                extra={
                     "tenantId": "tenantId",
                     "subscriptionId": "subscriptionId",
                     "resource_group_name": RESOURCE_GROUP_NAME,
                     "factory_name": DATAFACTORY_NAME,
-                }
-            ),
-        )
-        mock_connection.return_value = mock_conn
-        hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
+                },
+            )
+        ],
+        indirect=True,
+    )
+    async def test_get_async_conn(self, mocked_connection):
+        """"""
+        hook = AzureDataFactoryAsyncHook(mocked_connection.conn_id)
         response = await hook.get_async_conn()
         assert isinstance(response, DataFactoryManagementClient)
 
     @pytest.mark.asyncio
-    @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection")
-    async def test_get_async_conn_without_login_id(self, mock_connection):
-        """Test get_async_conn function without login id"""
-        mock_conn = Connection(
-            conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
-            conn_type="azure_data_factory",
-            extra=json.dumps(
-                {
+    @pytest.mark.parametrize(
+        "mocked_connection",
+        [
+            Connection(
+                conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
+                conn_type="azure_data_factory",
+                extra={
                     "tenantId": "tenantId",
                     "subscriptionId": "subscriptionId",
                     "resource_group_name": RESOURCE_GROUP_NAME,
                     "factory_name": DATAFACTORY_NAME,
-                }
+                },
             ),
-        )
-        mock_connection.return_value = mock_conn
-        hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
+        ],
+        indirect=True,
+    )
+    async def test_get_async_conn_without_login_id(self, mocked_connection):
+        """Test get_async_conn function without login id"""
+        hook = AzureDataFactoryAsyncHook(mocked_connection.conn_id)
         response = await hook.get_async_conn()
         assert isinstance(response, DataFactoryManagementClient)
 
     @pytest.mark.asyncio
     @pytest.mark.parametrize(
-        "mock_connection_params",
+        "mocked_connection",
         [
-            {
-                "tenantId": "tenantId",
-                "resource_group_name": RESOURCE_GROUP_NAME,
-                "factory_name": DATAFACTORY_NAME,
-            }
+            Connection(
+                conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
+                conn_type="azure_data_factory",
+                login="clientId",
+                password="clientSecret",
+                extra={
+                    "tenantId": "tenantId",
+                    "resource_group_name": RESOURCE_GROUP_NAME,
+                    "factory_name": DATAFACTORY_NAME,
+                },
+            )
         ],
+        indirect=True,
     )
-    @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection")
-    async def test_get_async_conn_key_error_subscription_id(self, mock_connection, mock_connection_params):
+    async def test_get_async_conn_key_error_subscription_id(self, mocked_connection):
         """Test get_async_conn function when subscription_id is missing in the connection"""
-        mock_conn = Connection(
-            conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
-            conn_type="azure_data_factory",
-            login="clientId",
-            password="clientSecret",
-            extra=json.dumps(mock_connection_params),
-        )
-        mock_connection.return_value = mock_conn
-        hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
+        hook = AzureDataFactoryAsyncHook(mocked_connection.conn_id)
         with pytest.raises(ValueError):
             await hook.get_async_conn()
 
     @pytest.mark.asyncio
     @pytest.mark.parametrize(
-        "mock_connection_params",
+        "mocked_connection",
         [
-            {
-                "subscriptionId": "subscriptionId",
-                "resource_group_name": RESOURCE_GROUP_NAME,
-                "factory_name": DATAFACTORY_NAME,
-            },
+            Connection(
+                conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
+                conn_type="azure_data_factory",
+                login="clientId",
+                password="clientSecret",
+                extra={
+                    "subscriptionId": "subscriptionId",
+                    "resource_group_name": RESOURCE_GROUP_NAME,
+                    "factory_name": DATAFACTORY_NAME,
+                },
+            )
         ],
+        indirect=True,
     )
-    @mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection")
-    async def test_get_async_conn_key_error_tenant_id(self, mock_connection, mock_connection_params):
+    async def test_get_async_conn_key_error_tenant_id(self, mocked_connection):
         """Test get_async_conn function when tenant id is missing in the connection"""
-        mock_conn = Connection(
-            conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
-            conn_type="azure_data_factory",
-            login="clientId",
-            password="clientSecret",
-            extra=json.dumps(mock_connection_params),
-        )
-        mock_connection.return_value = mock_conn
-        hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
+        hook = AzureDataFactoryAsyncHook(mocked_connection.conn_id)
         with pytest.raises(ValueError):
             await hook.get_async_conn()
 
@@ -907,14 +899,12 @@ class TestAzureDataFactoryAsyncHook:
         mock_conn = Connection(
             conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
             conn_type="azure_data_factory",
-            extra=json.dumps(
-                {
-                    "tenantId": "tenantId",
-                    "subscriptionId": "subscriptionId",
-                    "resource_group_name": RESOURCE_GROUP_NAME,
-                    "factory_name": DATAFACTORY_NAME,
-                }
-            ),
+            extra={
+                "tenantId": "tenantId",
+                "subscriptionId": "subscriptionId",
+                "resource_group_name": RESOURCE_GROUP_NAME,
+                "factory_name": DATAFACTORY_NAME,
+            },
         )
         extras = mock_conn.extra_dejson
         assert get_field(extras, "tenantId", strict=True) == "tenantId"
@@ -929,14 +919,12 @@ class TestAzureDataFactoryAsyncHook:
         mock_conn = Connection(
             conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
             conn_type="azure_data_factory",
-            extra=json.dumps(
-                {
-                    "tenantId": "tenantId",
-                    "subscriptionId": "subscriptionId",
-                    "resource_group_name": RESOURCE_GROUP_NAME,
-                    "factory_name": DATAFACTORY_NAME,
-                }
-            ),
+            extra={
+                "tenantId": "tenantId",
+                "subscriptionId": "subscriptionId",
+                "resource_group_name": RESOURCE_GROUP_NAME,
+                "factory_name": DATAFACTORY_NAME,
+            },
         )
         extras = mock_conn.extra_dejson
         assert get_field(extras, "tenantId", strict=True) == "tenantId"
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py b/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py
index 122949beac..f5e2e8be5c 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py
@@ -17,7 +17,6 @@
 # under the License.
 from __future__ import annotations
 
-import json
 from unittest import mock
 from unittest.mock import PropertyMock
 
@@ -26,18 +25,18 @@ from azure.storage.filedatalake._models import 
FileSystemProperties
 
 from airflow.models import Connection
 from airflow.providers.microsoft.azure.hooks.data_lake import 
AzureDataLakeStorageV2Hook
-from airflow.utils import db
 
 
 class TestAzureDataLakeHook:
-    def setup_method(self):
-        db.merge_conn(
+    @pytest.fixture(autouse=True)
+    def setup_connections(self, create_mock_connections):
+        create_mock_connections(
             Connection(
                 conn_id="adl_test_key",
                 conn_type="azure_data_lake",
                 login="client_id",
                 password="client secret",
-                extra=json.dumps({"tenant": "tenant", "account_name": 
"accountname"}),
+                extra={"tenant": "tenant", "account_name": "accountname"},
             )
         )
 
@@ -58,9 +57,10 @@ class TestAzureDataLakeHook:
     def test_check_for_blob(self, mock_lib, mock_filesystem):
         from airflow.providers.microsoft.azure.hooks.data_lake import 
AzureDataLakeHook
 
+        mocked_glob = mock_filesystem.return_value.glob
         hook = AzureDataLakeHook(azure_data_lake_conn_id="adl_test_key")
         hook.check_for_file("file_path")
-        mock_filesystem.glob.called
+        mocked_glob.assert_called()
 
     
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.multithread.ADLUploader",
 autospec=True)
     @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.lib", 
autospec=True)
@@ -140,7 +140,7 @@ class TestAzureDataLakeHook:
 
 class TestAzureDataLakeStorageV2Hook:
     def setup_class(self) -> None:
-        self.conn_id: str = "adls_conn_id"
+        self.conn_id: str = "adls_conn_id1"
         self.file_system_name = "test_file_system"
         self.directory_name = "test_directory"
         self.file_name = "test_file_name"
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py 
b/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py
index 1529eadcd9..a99ca9e90b 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py
@@ -25,7 +25,6 @@ and password (=Storage account key), or login and SAS token 
in the extra field
 """
 from __future__ import annotations
 
-import json
 import os
 from unittest import mock
 from unittest.mock import patch
@@ -36,51 +35,36 @@ from pytest import param
 
 from airflow.models import Connection
 from airflow.providers.microsoft.azure.hooks.fileshare import 
AzureFileShareHook
-from airflow.utils import db
 
 
 class TestAzureFileshareHook:
-    def setup_method(self):
-        db.merge_conn(
+    @pytest.fixture(autouse=True)
+    def setup_test_cases(self, create_mock_connections):
+        create_mock_connections(
             Connection(
                 conn_id="azure_fileshare_test_key",
                 conn_type="azure_file_share",
                 login="login",
                 password="key",
-            )
-        )
-        db.merge_conn(
+            ),
             Connection(
                 conn_id="azure_fileshare_extras",
                 conn_type="azure_fileshare",
                 login="login",
-                extra=json.dumps(
-                    {
-                        "sas_token": "token",
-                        "protocol": "http",
-                    }
-                ),
-            )
-        )
-        db.merge_conn(
+                extra={"sas_token": "token", "protocol": "http"},
+            ),
             # Neither password nor sas_token present
             Connection(
                 conn_id="azure_fileshare_missing_credentials",
                 conn_type="azure_fileshare",
                 login="login",
-            )
-        )
-        db.merge_conn(
+            ),
             Connection(
                 conn_id="azure_fileshare_extras_wrong",
                 conn_type="azure_fileshare",
                 login="login",
-                extra=json.dumps(
-                    {
-                        "wrong_key": "token",
-                    }
-                ),
-            )
+                extra={"wrong_key": "token"},
+            ),
         )
 
     def test_key_and_connection(self):
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_synapse.py 
b/tests/providers/microsoft/azure/hooks/test_azure_synapse.py
index 3a82efd812..b63dacc6da 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_synapse.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_synapse.py
@@ -16,7 +16,6 @@
 # under the License.
 from __future__ import annotations
 
-import json
 from unittest.mock import MagicMock, patch
 
 import pytest
@@ -26,7 +25,6 @@ from pytest import fixture
 
 from airflow.models.connection import Connection
 from airflow.providers.microsoft.azure.hooks.synapse import AzureSynapseHook, 
AzureSynapseSparkBatchRunStatus
-from airflow.utils import db
 
 DEFAULT_SPARK_POOL = "defaultSparkPool"
 
@@ -42,60 +40,45 @@ ID = "testId"
 JOB_ID = 1
 
 
-def setup_module():
-    connection_client_secret = Connection(
-        conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
-        conn_type="azure_synapse",
-        host="https://testsynapse.dev.azuresynapse.net",
-        login="clientId",
-        password="clientSecret",
-        extra=json.dumps(
-            {
-                "tenantId": "tenantId",
-                "subscriptionId": "subscriptionId",
-            }
+@pytest.fixture(autouse=True)
+def setup_connections(create_mock_connections):
+    create_mock_connections(
+        # connection_client_secret
+        Connection(
+            conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
+            conn_type="azure_synapse",
+            host="https://testsynapse.dev.azuresynapse.net",
+            login="clientId",
+            password="clientSecret",
+            extra={"tenantId": "tenantId", "subscriptionId": "subscriptionId"},
         ),
-    )
-    connection_default_credential = Connection(
-        conn_id=DEFAULT_CONNECTION_DEFAULT_CREDENTIAL,
-        conn_type="azure_synapse",
-        host="https://testsynapse.dev.azuresynapse.net",
-        extra=json.dumps(
-            {
-                "subscriptionId": "subscriptionId",
-            }
+        # connection_default_credential
+        Connection(
+            conn_id=DEFAULT_CONNECTION_DEFAULT_CREDENTIAL,
+            conn_type="azure_synapse",
+            host="https://testsynapse.dev.azuresynapse.net",
+            extra={"subscriptionId": "subscriptionId"},
         ),
-    )
-    connection_missing_subscription_id = Connection(
-        conn_id="azure_synapse_missing_subscription_id",
-        conn_type="azure_synapse",
-        host="https://testsynapse.dev.azuresynapse.net",
-        login="clientId",
-        password="clientSecret",
-        extra=json.dumps(
-            {
-                "tenantId": "tenantId",
-            }
+        Connection(
+            # connection_missing_subscription_id
+            conn_id="azure_synapse_missing_subscription_id",
+            conn_type="azure_synapse",
+            host="https://testsynapse.dev.azuresynapse.net",
+            login="clientId",
+            password="clientSecret",
+            extra={"tenantId": "tenantId"},
         ),
-    )
-    connection_missing_tenant_id = Connection(
-        conn_id="azure_synapse_missing_tenant_id",
-        conn_type="azure_synapse",
-        host="https://testsynapse.dev.azuresynapse.net",
-        login="clientId",
-        password="clientSecret",
-        extra=json.dumps(
-            {
-                "subscriptionId": "subscriptionId",
-            }
+        # connection_missing_tenant_id
+        Connection(
+            conn_id="azure_synapse_missing_tenant_id",
+            conn_type="azure_synapse",
+            host="https://testsynapse.dev.azuresynapse.net",
+            login="clientId",
+            password="clientSecret",
+            extra={"subscriptionId": "subscriptionId"},
         ),
     )
 
-    db.merge_conn(connection_client_secret)
-    db.merge_conn(connection_default_credential)
-    db.merge_conn(connection_missing_subscription_id)
-    db.merge_conn(connection_missing_tenant_id)
-
 
 @fixture
 def hook():
diff --git a/tests/providers/microsoft/azure/hooks/test_base_azure.py 
b/tests/providers/microsoft/azure/hooks/test_base_azure.py
index 7b587c4121..53e2614a69 100644
--- a/tests/providers/microsoft/azure/hooks/test_base_azure.py
+++ b/tests/providers/microsoft/azure/hooks/test_base_azure.py
@@ -18,63 +18,71 @@ from __future__ import annotations
 
 from unittest.mock import Mock, patch
 
+import pytest
+
 from airflow.models import Connection
 from airflow.providers.microsoft.azure.hooks.base_azure import AzureBaseHook
 
 
 class TestBaseAzureHook:
-    
@patch("airflow.providers.microsoft.azure.hooks.base_azure.get_client_from_auth_file")
-    @patch(
-        
"airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook.get_connection",
-        return_value=Connection(conn_id="azure_default", extra='{ "key_path": 
"key_file.json" }'),
+    @pytest.mark.parametrize(
+        "mocked_connection",
+        [Connection(conn_id="azure_default", extra={"key_path": 
"key_file.json"})],
+        indirect=True,
     )
-    def test_get_conn_with_key_path(self, mock_connection, 
mock_get_client_from_auth_file):
+    
@patch("airflow.providers.microsoft.azure.hooks.base_azure.get_client_from_auth_file")
+    def test_get_conn_with_key_path(self, mock_get_client_from_auth_file, 
mocked_connection):
+        mock_get_client_from_auth_file.return_value = "foo-bar"
         mock_sdk_client = Mock()
 
         auth_sdk_client = AzureBaseHook(mock_sdk_client).get_conn()
 
         mock_get_client_from_auth_file.assert_called_once_with(
-            client_class=mock_sdk_client, 
auth_path=mock_connection.return_value.extra_dejson["key_path"]
+            client_class=mock_sdk_client, 
auth_path=mocked_connection.extra_dejson["key_path"]
         )
-        assert auth_sdk_client == mock_get_client_from_auth_file.return_value
+        assert auth_sdk_client == "foo-bar"
 
-    
@patch("airflow.providers.microsoft.azure.hooks.base_azure.get_client_from_json_dict")
-    @patch(
-        
"airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook.get_connection",
-        return_value=Connection(conn_id="azure_default", extra='{ "key_json": 
{ "test": "test" } }'),
+    @pytest.mark.parametrize(
+        "mocked_connection",
+        [Connection(conn_id="azure_default", extra={"key_json": {"test": 
"test"}})],
+        indirect=True,
     )
-    def test_get_conn_with_key_json(self, mock_connection, 
mock_get_client_from_json_dict):
+    
@patch("airflow.providers.microsoft.azure.hooks.base_azure.get_client_from_json_dict")
+    def test_get_conn_with_key_json(self, mock_get_client_from_json_dict, 
mocked_connection):
         mock_sdk_client = Mock()
-
+        mock_get_client_from_json_dict.return_value = "foo-bar"
         auth_sdk_client = AzureBaseHook(mock_sdk_client).get_conn()
 
         mock_get_client_from_json_dict.assert_called_once_with(
-            client_class=mock_sdk_client, 
config_dict=mock_connection.return_value.extra_dejson["key_json"]
+            client_class=mock_sdk_client, 
config_dict=mocked_connection.extra_dejson["key_json"]
         )
-        assert auth_sdk_client == mock_get_client_from_json_dict.return_value
+        assert auth_sdk_client == "foo-bar"
 
     
@patch("airflow.providers.microsoft.azure.hooks.base_azure.ServicePrincipalCredentials")
-    @patch(
-        
"airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook.get_connection",
-        return_value=Connection(
-            conn_id="azure_default",
-            login="my_login",
-            password="my_password",
-            extra='{ "tenantId": "my_tenant", "subscriptionId": 
"my_subscription" }',
-        ),
+    @pytest.mark.parametrize(
+        "mocked_connection",
+        [
+            Connection(
+                conn_id="azure_default",
+                login="my_login",
+                password="my_password",
+                extra={"tenantId": "my_tenant", "subscriptionId": 
"my_subscription"},
+            )
+        ],
+        indirect=True,
     )
-    def test_get_conn_with_credentials(self, mock_connection, mock_spc):
-        mock_sdk_client = Mock()
-
+    def test_get_conn_with_credentials(self, mock_spc, mocked_connection):
+        mock_sdk_client = Mock(return_value="spam-egg")
+        mock_spc.return_value = "foo-bar"
         auth_sdk_client = AzureBaseHook(mock_sdk_client).get_conn()
 
         mock_spc.assert_called_once_with(
-            client_id=mock_connection.return_value.login,
-            secret=mock_connection.return_value.password,
-            tenant=mock_connection.return_value.extra_dejson["tenantId"],
+            client_id=mocked_connection.login,
+            secret=mocked_connection.password,
+            tenant=mocked_connection.extra_dejson["tenantId"],
         )
         mock_sdk_client.assert_called_once_with(
-            credentials=mock_spc.return_value,
-            
subscription_id=mock_connection.return_value.extra_dejson["subscriptionId"],
+            credentials="foo-bar",
+            subscription_id=mocked_connection.extra_dejson["subscriptionId"],
         )
-        assert auth_sdk_client == mock_sdk_client.return_value
+        assert auth_sdk_client == "spam-egg"
diff --git a/tests/providers/microsoft/azure/hooks/test_wasb.py 
b/tests/providers/microsoft/azure/hooks/test_wasb.py
index 1f48a7b011..06ef4eedfb 100644
--- a/tests/providers/microsoft/azure/hooks/test_wasb.py
+++ b/tests/providers/microsoft/azure/hooks/test_wasb.py
@@ -17,7 +17,6 @@
 # under the License.
 from __future__ import annotations
 
-import json
 from unittest import mock
 
 import pytest
@@ -34,269 +33,262 @@ CONN_STRING = (
 )
 
 ACCESS_KEY_STRING = "AccountName=name;skdkskd"
+PROXIES = {"http": "http_proxy_uri", "https": "https_proxy_uri"}
+
+
+@pytest.fixture
+def mocked_blob_service_client():
+    with 
mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") as 
m:
+        yield m
+
+
+@pytest.fixture
+def mocked_default_azure_credential():
+    with 
mock.patch("airflow.providers.microsoft.azure.hooks.wasb.DefaultAzureCredential")
 as m:
+        yield m
+
+
+@pytest.fixture
+def mocked_client_secret_credential():
+    with 
mock.patch("airflow.providers.microsoft.azure.hooks.wasb.ClientSecretCredential")
 as m:
+        yield m
 
 
 class TestWasbHook:
-    def setup_method(self):
+    @pytest.fixture(autouse=True)
+    def setup_method(self, create_mock_connections):
         self.login = "login"
         self.wasb_test_key = "wasb_test_key"
         self.connection_type = "wasb"
-        self.connection_string_id = "azure_test_connection_string"
-        self.shared_key_conn_id = "azure_shared_key_test"
-        self.shared_key_conn_id_without_host = 
"azure_shared_key_test_wihout_host"
-        self.ad_conn_id = "azure_AD_test"
+        self.azure_test_connection_string = "azure_test_connection_string"
+        self.azure_shared_key_test = "azure_shared_key_test"
+        self.ad_conn_id = "ad_conn_id"
         self.sas_conn_id = "sas_token_id"
-        self.extra__wasb__sas_conn_id = "extra__sas_token_id"
-        self.http_sas_conn_id = "http_sas_token_id"
-        self.extra__wasb__http_sas_conn_id = "extra__http_sas_token_id"
+        self.extra__wasb__sas_conn_id = "extra__wasb__sas_conn_id"
+        self.http_sas_conn_id = "http_sas_conn_id"
+        self.extra__wasb__http_sas_conn_id = "extra__wasb__http_sas_conn_id"
         self.public_read_conn_id = "pub_read_id"
         self.public_read_conn_id_without_host = "pub_read_id_without_host"
-        self.managed_identity_conn_id = "managed_identity"
+        self.managed_identity_conn_id = "managed_identity_conn_id"
         self.authority = "https://test_authority.com"
 
-        self.proxies = {"http": "http_proxy_uri", "https": "https_proxy_uri"}
+        self.proxies = PROXIES
         self.client_secret_auth_config = {
             "proxies": self.proxies,
             "connection_verify": False,
             "authority": self.authority,
         }
-        self.connection_map = {
-            self.wasb_test_key: Connection(
+
+        conns = create_mock_connections(
+            Connection(
                 conn_id="wasb_test_key",
                 conn_type=self.connection_type,
                 login=self.login,
                 password="key",
             ),
-            self.public_read_conn_id: Connection(
+            Connection(
                 conn_id=self.public_read_conn_id,
                 conn_type=self.connection_type,
                 host="https://accountname.blob.core.windows.net",
-                extra=json.dumps({"proxies": self.proxies}),
+                extra={"proxies": self.proxies},
             ),
-            self.public_read_conn_id_without_host: Connection(
+            Connection(
                 conn_id=self.public_read_conn_id_without_host,
                 conn_type=self.connection_type,
                 login=self.login,
-                extra=json.dumps({"proxies": self.proxies}),
+                extra={"proxies": self.proxies},
             ),
-            self.connection_string_id: Connection(
-                conn_id=self.connection_string_id,
+            Connection(
+                conn_id=self.azure_test_connection_string,
                 conn_type=self.connection_type,
-                extra=json.dumps({"connection_string": CONN_STRING, "proxies": 
self.proxies}),
+                extra={"connection_string": CONN_STRING, "proxies": 
self.proxies},
             ),
-            self.shared_key_conn_id: Connection(
-                conn_id=self.shared_key_conn_id,
+            Connection(
+                conn_id=self.azure_shared_key_test,
                 conn_type=self.connection_type,
                 host="https://accountname.blob.core.windows.net",
-                extra=json.dumps({"shared_access_key": "token", "proxies": 
self.proxies}),
+                extra={"shared_access_key": "token", "proxies": self.proxies},
             ),
-            self.shared_key_conn_id_without_host: Connection(
-                conn_id=self.shared_key_conn_id_without_host,
-                conn_type=self.connection_type,
-                login=self.login,
-                extra=json.dumps({"shared_access_key": "token", "proxies": 
self.proxies}),
-            ),
-            self.ad_conn_id: Connection(
+            Connection(
                 conn_id=self.ad_conn_id,
                 conn_type=self.connection_type,
                 host="conn_host",
                 login="appID",
                 password="appsecret",
-                extra=json.dumps(
-                    {
-                        "tenant_id": "token",
-                        "proxies": self.proxies,
-                        "client_secret_auth_config": 
self.client_secret_auth_config,
-                    }
-                ),
+                extra={
+                    "tenant_id": "token",
+                    "proxies": self.proxies,
+                    "client_secret_auth_config": 
self.client_secret_auth_config,
+                },
             ),
-            self.managed_identity_conn_id: Connection(
+            Connection(
                 conn_id=self.managed_identity_conn_id,
                 conn_type=self.connection_type,
-                extra=json.dumps({"proxies": self.proxies}),
+                extra={"proxies": self.proxies},
             ),
-            self.sas_conn_id: Connection(
-                conn_id=self.sas_conn_id,
+            Connection(
+                conn_id="sas_conn_id",
                 conn_type=self.connection_type,
                 login=self.login,
-                extra=json.dumps({"sas_token": "token", "proxies": 
self.proxies}),
+                extra={"sas_token": "token", "proxies": self.proxies},
             ),
-            self.extra__wasb__sas_conn_id: Connection(
+            Connection(
                 conn_id=self.extra__wasb__sas_conn_id,
                 conn_type=self.connection_type,
                 login=self.login,
-                extra=json.dumps({"extra__wasb__sas_token": "token", 
"proxies": self.proxies}),
+                extra={"extra__wasb__sas_token": "token", "proxies": 
self.proxies},
             ),
-            self.http_sas_conn_id: Connection(
+            Connection(
                 conn_id=self.http_sas_conn_id,
                 conn_type=self.connection_type,
-                extra=json.dumps(
-                    {"sas_token": "https://login.blob.core.windows.net/token";, 
"proxies": self.proxies}
-                ),
+                extra={"sas_token": 
"https://login.blob.core.windows.net/token";, "proxies": self.proxies},
             ),
-            self.extra__wasb__http_sas_conn_id: Connection(
+            Connection(
                 conn_id=self.extra__wasb__http_sas_conn_id,
                 conn_type=self.connection_type,
-                extra=json.dumps(
-                    {
-                        "extra__wasb__sas_token": 
"https://login.blob.core.windows.net/token";,
-                        "proxies": self.proxies,
-                    }
-                ),
+                extra={
+                    "extra__wasb__sas_token": 
"https://login.blob.core.windows.net/token";,
+                    "proxies": self.proxies,
+                },
             ),
-        }
-
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_key(self, mock_get_conn, mock_blob_service_client):
-        conn = self.connection_map[self.wasb_test_key]
-        mock_get_conn.return_value = conn
-        WasbHook(wasb_conn_id=self.wasb_test_key)
-        assert mock_blob_service_client.call_args == mock.call(
-            account_url=f"https://{self.login}.blob.core.windows.net/",
-            credential=conn.password,
+        )
+        self.connection_map = {conn.conn_id: conn for conn in conns}
+
+    def test_key(self, mocked_blob_service_client):
+        hook = WasbHook(wasb_conn_id=self.wasb_test_key)
+        mocked_blob_service_client.assert_not_called()  # Not expected during 
initialisation
+        hook.get_conn()
+        mocked_blob_service_client.assert_called_once_with(
+            account_url=f"https://{self.login}.blob.core.windows.net/", 
credential="key"
         )
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_public_read(self, mock_get_conn, mock_blob_service_client):
-        conn = self.connection_map[self.public_read_conn_id]
-        mock_get_conn.return_value = conn
-        WasbHook(wasb_conn_id=self.public_read_conn_id, public_read=True)
-        assert mock_blob_service_client.call_args == mock.call(
-            account_url=conn.host,
-            proxies=conn.extra_dejson["proxies"],
+    def test_public_read(self, mocked_blob_service_client):
+        WasbHook(wasb_conn_id=self.public_read_conn_id, 
public_read=True).get_conn()
+        mocked_blob_service_client.assert_called_once_with(
+            account_url="https://accountname.blob.core.windows.net", 
proxies=self.proxies
         )
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_connection_string(self, mock_get_conn, mock_blob_service_client):
-        conn = self.connection_map[self.connection_string_id]
-        mock_get_conn.return_value = conn
-        WasbHook(wasb_conn_id=self.connection_string_id)
-        
mock_blob_service_client.from_connection_string.assert_called_once_with(
-            CONN_STRING,
-            proxies=conn.extra_dejson["proxies"],
-            connection_string=conn.extra_dejson["connection_string"],
+    def test_connection_string(self, mocked_blob_service_client):
+        WasbHook(wasb_conn_id=self.azure_test_connection_string).get_conn()
+        
mocked_blob_service_client.from_connection_string.assert_called_once_with(
+            CONN_STRING, proxies=self.proxies, connection_string=CONN_STRING
         )
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_shared_key_connection(self, mock_get_conn, 
mock_blob_service_client):
-        conn = self.connection_map[self.shared_key_conn_id]
-        mock_get_conn.return_value = conn
-        WasbHook(wasb_conn_id=self.shared_key_conn_id)
-        mock_blob_service_client.assert_called_once_with(
-            account_url=conn.host,
-            credential=conn.extra_dejson["shared_access_key"],
-            proxies=conn.extra_dejson["proxies"],
-            shared_access_key=conn.extra_dejson["shared_access_key"],
+    def test_shared_key_connection(self, mocked_blob_service_client):
+        WasbHook(wasb_conn_id=self.azure_shared_key_test).get_conn()
+        mocked_blob_service_client.assert_called_once_with(
+            account_url="https://accountname.blob.core.windows.net",
+            credential="token",
+            proxies=self.proxies,
+            shared_access_key="token",
         )
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.DefaultAzureCredential")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_managed_identity(self, mock_get_conn, mock_credential, 
mock_blob_service_client):
-        conn = self.connection_map[self.managed_identity_conn_id]
-        mock_get_conn.return_value = conn
-        WasbHook(wasb_conn_id=self.managed_identity_conn_id)
-        mock_blob_service_client.assert_called_once_with(
-            account_url=f"https://{conn.login}.blob.core.windows.net/",
-            credential=mock_credential.return_value,
-            proxies=conn.extra_dejson["proxies"],
+    def test_managed_identity(self, mocked_default_azure_credential, 
mocked_blob_service_client):
+        mocked_default_azure_credential.return_value = "foo-bar"
+        WasbHook(wasb_conn_id=self.managed_identity_conn_id).get_conn()
+        mocked_blob_service_client.assert_called_once_with(
+            account_url="https://None.blob.core.windows.net/",
+            credential="foo-bar",
+            proxies=self.proxies,
         )
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.ClientSecretCredential")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_azure_directory_connection(self, mock_get_conn, mock_credential, 
mock_blob_service_client):
-        conn = self.connection_map[self.ad_conn_id]
-        mock_get_conn.return_value = conn
-        WasbHook(wasb_conn_id=self.ad_conn_id)
-        mock_credential.assert_called_once_with(
-            conn.extra_dejson["tenant_id"],
-            conn.login,
-            conn.password,
+    def test_azure_directory_connection(self, mocked_client_secret_credential, 
mocked_blob_service_client):
+        mocked_client_secret_credential.return_value = "spam-egg"
+        WasbHook(wasb_conn_id=self.ad_conn_id).get_conn()
+        mocked_client_secret_credential.assert_called_once_with(
+            "token",
+            "appID",
+            "appsecret",
             proxies=self.client_secret_auth_config["proxies"],
             
connection_verify=self.client_secret_auth_config["connection_verify"],
             authority=self.client_secret_auth_config["authority"],
         )
-        mock_blob_service_client.assert_called_once_with(
-            account_url=f"https://{conn.login}.blob.core.windows.net/",
-            credential=mock_credential.return_value,
-            tenant_id=conn.extra_dejson["tenant_id"],
-            proxies=conn.extra_dejson["proxies"],
+        mocked_blob_service_client.assert_called_once_with(
+            account_url="https://appID.blob.core.windows.net/",
+            credential="spam-egg",
+            tenant_id="token",
+            proxies=self.proxies,
         )
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.DefaultAzureCredential")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_active_directory_ID_used_as_host(self, mock_get_conn, 
mock_credential, mock_blob_service_client):
-        mock_get_conn.return_value = Connection(
-            conn_id="testconn",
-            conn_type=self.connection_type,
-            login="testaccountname",
-            host="testaccountID",
-        )
-        WasbHook(wasb_conn_id="testconn")
-        assert mock_blob_service_client.call_args == mock.call(
+    @pytest.mark.parametrize(
+        "mocked_connection",
+        [
+            Connection(
+                conn_id="testconn",
+                conn_type="wasb",
+                login="testaccountname",
+                host="testaccountID",
+            )
+        ],
+        indirect=True,
+    )
+    def test_active_directory_id_used_as_host(
+        self, mocked_connection, mocked_default_azure_credential, 
mocked_blob_service_client
+    ):
+        mocked_default_azure_credential.return_value = "fake-credential"
+        WasbHook(wasb_conn_id="testconn").get_conn()
+        mocked_blob_service_client.assert_called_once_with(
             account_url="https://testaccountname.blob.core.windows.net/",
-            credential=mock_credential.return_value,
+            credential="fake-credential",
         )
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_sas_token_provided_and_active_directory_ID_used_as_host(
-        self, mock_get_conn, mock_blob_service_client
+    @pytest.mark.parametrize(
+        "mocked_connection",
+        [
+            Connection(
+                conn_id="testconn",
+                conn_type="wasb",
+                login="testaccountname",
+                host="testaccountID",
+                extra={"sas_token": "SAStoken"},
+            )
+        ],
+        indirect=True,
+    )
+    def test_sas_token_provided_and_active_directory_id_used_as_host(
+        self, mocked_connection, mocked_blob_service_client
     ):
-        mock_get_conn.return_value = Connection(
-            conn_id="testconn",
-            conn_type=self.connection_type,
-            login="testaccountname",
-            host="testaccountID",
-            extra=json.dumps({"sas_token": "SAStoken"}),
-        )
-        WasbHook(wasb_conn_id="testconn")
-        assert mock_blob_service_client.call_args == mock.call(
+        WasbHook(wasb_conn_id="testconn").get_conn()
+        mocked_blob_service_client.assert_called_once_with(
             
account_url="https://testaccountname.blob.core.windows.net/SAStoken",
             sas_token="SAStoken",
         )
 
     @pytest.mark.parametrize(
-        argnames="conn_id_str",
-        argvalues=[
-            "wasb_test_key",
-            "shared_key_conn_id_without_host",
-            "public_read_conn_id_without_host",
+        "mocked_connection",
+        [
+            pytest.param(
+                Connection(
+                    conn_type="wasb",
+                    login="foo",
+                    extra={"shared_access_key": "token", "proxies": PROXIES},
+                ),
+                id="shared-key-without-host",
+            ),
+            pytest.param(
+                Connection(conn_type="wasb", login="foo", extra={"proxies": 
PROXIES}),
+                id="public-read-without-host",
+            ),
         ],
+        indirect=True,
     )
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.DefaultAzureCredential")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
     def test_account_url_without_host(
-        self, mock_get_conn, mock_credential, mock_blob_service_client, 
conn_id_str
+        self, mocked_connection, mocked_blob_service_client, 
mocked_default_azure_credential
     ):
-        conn_id = self.__getattribute__(conn_id_str)
-        connection = self.connection_map[conn_id]
-        mock_get_conn.return_value = connection
-        WasbHook(wasb_conn_id=conn_id)
-        if conn_id_str == "wasb_test_key":
-            mock_blob_service_client.assert_called_once_with(
-                
account_url=f"https://{connection.login}.blob.core.windows.net/";,
-                credential=connection.password,
-            )
-        elif conn_id_str == "shared_key_conn_id_without_host":
-            mock_blob_service_client.assert_called_once_with(
-                
account_url=f"https://{connection.login}.blob.core.windows.net/";,
-                credential=connection.extra_dejson["shared_access_key"],
-                proxies=connection.extra_dejson["proxies"],
-                shared_access_key=connection.extra_dejson["shared_access_key"],
+        mocked_default_azure_credential.return_value = "default-creds"
+        WasbHook(wasb_conn_id=mocked_connection.conn_id).get_conn()
+        if "shared_access_key" in mocked_connection.extra_dejson:
+            mocked_blob_service_client.assert_called_once_with(
+                
account_url=f"https://{mocked_connection.login}.blob.core.windows.net/";,
+                credential=mocked_connection.extra_dejson["shared_access_key"],
+                proxies=mocked_connection.extra_dejson["proxies"],
+                
shared_access_key=mocked_connection.extra_dejson["shared_access_key"],
             )
         else:
-            mock_blob_service_client.assert_called_once_with(
-                
account_url=f"https://{connection.login}.blob.core.windows.net/";,
-                credential=mock_credential.return_value,
-                proxies=connection.extra_dejson["proxies"],
+            mocked_blob_service_client.assert_called_once_with(
+                
account_url=f"https://{mocked_connection.login}.blob.core.windows.net/";,
+                credential="default-creds",
+                proxies=mocked_connection.extra_dejson["proxies"],
             )
 
     @pytest.mark.parametrize(
@@ -308,25 +300,22 @@ class TestWasbHook:
             ("extra__wasb__http_sas_conn_id", "extra__wasb__sas_token"),
         ],
     )
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_sas_token_connection(self, mock_get_conn, conn_id_str, extra_key):
-        conn_id = self.__getattribute__(conn_id_str)
-        mock_get_conn.return_value = self.connection_map[conn_id]
-        hook = WasbHook(wasb_conn_id=conn_id)
+    def test_sas_token_connection(self, conn_id_str, extra_key):
+        hook = WasbHook(wasb_conn_id=conn_id_str)
         conn = hook.get_conn()
         hook_conn = hook.get_connection(hook.conn_id)
         sas_token = hook_conn.extra_dejson[extra_key]
         assert isinstance(conn, BlobServiceClient)
         assert conn.url.startswith("https://";)
         if hook_conn.login:
-            assert conn.url.__contains__(hook_conn.login)
+            assert hook_conn.login in conn.url
         assert conn.url.endswith(sas_token + "/")
 
     @pytest.mark.parametrize(
         argnames="conn_id_str",
         argvalues=[
-            "connection_string_id",
-            "shared_key_conn_id",
+            "azure_test_connection_string",
+            "azure_shared_key_test",
             "ad_conn_id",
             "managed_identity_conn_id",
             "sas_conn_id",
@@ -335,27 +324,17 @@ class TestWasbHook:
             "extra__wasb__http_sas_conn_id",
         ],
     )
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_connection_extra_arguments(self, mock_get_conn, conn_id_str):
-        conn_id = self.__getattribute__(conn_id_str)
-        mock_get_conn.return_value = self.connection_map[conn_id]
-        hook = WasbHook(wasb_conn_id=conn_id)
-        conn = hook.get_conn()
+    def test_connection_extra_arguments(self, conn_id_str):
+        conn = WasbHook(wasb_conn_id=conn_id_str).get_conn()
         assert conn._config.proxy_policy.proxies == self.proxies
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_connection_extra_arguments_public_read(self, mock_get_conn):
-        conn_id = self.public_read_conn_id
-        mock_get_conn.return_value = self.connection_map[conn_id]
-        hook = WasbHook(wasb_conn_id=conn_id, public_read=True)
+    def test_connection_extra_arguments_public_read(self):
+        hook = WasbHook(wasb_conn_id=self.public_read_conn_id, 
public_read=True)
         conn = hook.get_conn()
         assert conn._config.proxy_policy.proxies == self.proxies
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_extra_client_secret_auth_config_ad_connection(self, 
mock_get_conn):
-        mock_get_conn.return_value = self.connection_map[self.ad_conn_id]
-        conn_id = self.ad_conn_id
-        hook = WasbHook(wasb_conn_id=conn_id)
+    def test_extra_client_secret_auth_config_ad_connection(self):
+        hook = WasbHook(wasb_conn_id=self.ad_conn_id)
         conn = hook.get_conn()
         assert conn.credential._authority == self.authority
 
@@ -371,79 +350,62 @@ class TestWasbHook:
             ("testhost.blob.net", "testhost.blob.net"),
         ],
     )
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
     def test_proper_account_url_update(
-        self, mock_get_conn, mock_blob_service_client, provided_host, 
expected_host
+        self, mocked_blob_service_client, provided_host, expected_host, 
create_mock_connection
     ):
-        mock_get_conn.return_value = Connection(
-            conn_id="test_conn",
-            conn_type=self.connection_type,
-            password="testpass",
-            login="accountlogin",
-            host=provided_host,
+        conn = create_mock_connection(
+            Connection(
+                conn_type=self.connection_type,
+                password="testpass",
+                login="accountlogin",
+                host=provided_host,
+            )
         )
-        WasbHook(wasb_conn_id=self.shared_key_conn_id)
-        
mock_blob_service_client.assert_called_once_with(account_url=expected_host, 
credential="testpass")
-
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_check_for_blob(self, mock_get_conn, mock_service):
-        mock_get_conn.return_value = 
self.connection_map[self.shared_key_conn_id]
-        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+        WasbHook(wasb_conn_id=conn.conn_id).get_conn()
+        
mocked_blob_service_client.assert_called_once_with(account_url=expected_host, 
credential="testpass")
+
+    def test_check_for_blob(self, mocked_blob_service_client):
+        hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
         assert hook.check_for_blob(container_name="mycontainer", 
blob_name="myblob")
-        mock_blob_client = mock_service.return_value.get_blob_client
+        mock_blob_client = 
mocked_blob_service_client.return_value.get_blob_client
         mock_blob_client.assert_called_once_with(container="mycontainer", 
blob="myblob")
         mock_blob_client.return_value.get_blob_properties.assert_called()
 
-    @mock.patch.object(WasbHook, "get_blobs_list")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_check_for_prefix(self, mock_get_conn, get_blobs_list):
-        mock_get_conn.return_value = 
self.connection_map[self.shared_key_conn_id]
-        get_blobs_list.return_value = ["blobs"]
-        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+    @mock.patch.object(WasbHook, "get_blobs_list", return_value=["blobs"])
+    def test_check_for_prefix(self, get_blobs_list):
+        hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
         assert hook.check_for_prefix("container", "prefix", timeout=3)
         get_blobs_list.assert_called_once_with(container_name="container", 
prefix="prefix", timeout=3)
 
-    @mock.patch.object(WasbHook, "get_blobs_list")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_check_for_prefix_empty(self, mock_get_conn, get_blobs_list):
-        mock_get_conn.return_value = 
self.connection_map[self.shared_key_conn_id]
-        get_blobs_list.return_value = []
-        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+    @mock.patch.object(WasbHook, "get_blobs_list", return_value=[])
+    def test_check_for_prefix_empty(self, get_blobs_list):
+        hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
         assert not hook.check_for_prefix("container", "prefix", timeout=3)
         get_blobs_list.assert_called_once_with(container_name="container", 
prefix="prefix", timeout=3)
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_get_blobs_list(self, mock_get_conn, mock_service):
-        mock_get_conn.return_value = 
self.connection_map[self.shared_key_conn_id]
-        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+    def test_get_blobs_list(self, mocked_blob_service_client):
+        hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
         hook.get_blobs_list(container_name="mycontainer", prefix="my", 
include=None, delimiter="/")
-        
mock_service.return_value.get_container_client.assert_called_once_with("mycontainer")
-        
mock_service.return_value.get_container_client.return_value.walk_blobs.assert_called_once_with(
+        mock_container_client = 
mocked_blob_service_client.return_value.get_container_client
+        mock_container_client.assert_called_once_with("mycontainer")
+        mock_container_client.return_value.walk_blobs.assert_called_once_with(
             name_starts_with="my", include=None, delimiter="/"
         )
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_get_blobs_list_recursive(self, mock_get_conn, mock_service):
-        mock_get_conn.return_value = 
self.connection_map[self.shared_key_conn_id]
-        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+    def test_get_blobs_list_recursive(self, mocked_blob_service_client):
+        hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
         hook.get_blobs_list_recursive(
             container_name="mycontainer", prefix="test", include=None, 
endswith="file_extension"
         )
-        
mock_service.return_value.get_container_client.assert_called_once_with("mycontainer")
-        
mock_service.return_value.get_container_client.return_value.list_blobs.assert_called_once_with(
+        mock_container_client = 
mocked_blob_service_client.return_value.get_container_client
+        mock_container_client.assert_called_once_with("mycontainer")
+        mock_container_client.return_value.list_blobs.assert_called_once_with(
             name_starts_with="test", include=None
         )
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_get_blobs_list_recursive_endswith(self, mock_get_conn, 
mock_service):
-        mock_get_conn.return_value = 
self.connection_map[self.shared_key_conn_id]
-        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
-        
mock_service.return_value.get_container_client.return_value.list_blobs.return_value
 = [
+    def test_get_blobs_list_recursive_endswith(self, 
mocked_blob_service_client):
+        hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
+        
mocked_blob_service_client.return_value.get_container_client.return_value.list_blobs.return_value
 = [
             BlobProperties(name="test/abc.py"),
             BlobProperties(name="test/inside_test/abc.py"),
             BlobProperties(name="test/abc.csv"),
@@ -455,11 +417,9 @@ class TestWasbHook:
 
     @pytest.mark.parametrize(argnames="create_container", argvalues=[True, 
False])
     @mock.patch.object(WasbHook, "upload")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_load_file(self, mock_get_conn, mock_upload, create_container):
-        mock_get_conn.return_value = 
self.connection_map[self.shared_key_conn_id]
+    def test_load_file(self, mock_upload, create_container):
         with mock.patch("builtins.open", mock.mock_open(read_data="data")):
-            hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+            hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
             hook.load_file("path", "container", "blob", create_container, 
max_connections=1)
 
         mock_upload.assert_called_with(
@@ -472,10 +432,8 @@ class TestWasbHook:
 
     @pytest.mark.parametrize(argnames="create_container", argvalues=[True, 
False])
     @mock.patch.object(WasbHook, "upload")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_load_string(self, mock_get_conn, mock_upload, create_container):
-        mock_get_conn.return_value = 
self.connection_map[self.shared_key_conn_id]
-        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+    def test_load_string(self, mock_upload, create_container):
+        hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
         hook.load_string("big string", "container", "blob", create_container, 
max_connections=1)
         mock_upload.assert_called_once_with(
             container_name="container",
@@ -486,30 +444,22 @@ class TestWasbHook:
         )
 
     @mock.patch.object(WasbHook, "download")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_get_file(self, mock_get_conn, mock_download):
-        mock_get_conn.return_value = 
self.connection_map[self.shared_key_conn_id]
+    def test_get_file(self, mock_download):
         with mock.patch("builtins.open", mock.mock_open(read_data="data")):
-            hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+            hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
             hook.get_file("path", "container", "blob", max_connections=1)
         mock_download.assert_called_once_with(container_name="container", 
blob_name="blob", max_connections=1)
         mock_download.return_value.readall.assert_called()
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
     @mock.patch.object(WasbHook, "download")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_read_file(self, mock_get_conn, mock_download, mock_service):
-        mock_get_conn.return_value = 
self.connection_map[self.shared_key_conn_id]
-        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+    def test_read_file(self, mock_download, mocked_blob_service_client):
+        hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
         hook.read_file("container", "blob", max_connections=1)
         mock_download.assert_called_once_with("container", "blob", 
max_connections=1)
 
     @pytest.mark.parametrize(argnames="create_container", argvalues=[True, 
False])
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_upload(self, mock_get_conn, mock_service, create_container):
-        mock_get_conn.return_value = 
self.connection_map[self.shared_key_conn_id]
-        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+    def test_upload(self, mocked_blob_service_client, create_container):
+        hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
         hook.upload(
             container_name="mycontainer",
             blob_name="myblob",
@@ -518,80 +468,61 @@ class TestWasbHook:
             blob_type="BlockBlob",
             length=4,
         )
-        mock_blob_client = mock_service.return_value.get_blob_client
+        mock_blob_client = 
mocked_blob_service_client.return_value.get_blob_client
         mock_blob_client.assert_called_once_with(container="mycontainer", 
blob="myblob")
         
mock_blob_client.return_value.upload_blob.assert_called_once_with(b"mydata", 
"BlockBlob", length=4)
 
-        mock_container_client = mock_service.return_value.get_container_client
+        mock_container_client = 
mocked_blob_service_client.return_value.get_container_client
         if create_container:
             mock_container_client.assert_called_with("mycontainer")
         else:
             mock_container_client.assert_not_called()
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_download(self, mock_get_conn, mock_service):
-        mock_get_conn.return_value = 
self.connection_map[self.shared_key_conn_id]
-        blob_client = mock_service.return_value.get_blob_client
-        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+    def test_download(self, mocked_blob_service_client):
+        blob_client = mocked_blob_service_client.return_value.get_blob_client
+        hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
         hook.download(container_name="mycontainer", blob_name="myblob", 
offset=2, length=4)
         blob_client.assert_called_once_with(container="mycontainer", 
blob="myblob")
         
blob_client.return_value.download_blob.assert_called_once_with(offset=2, 
length=4)
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_get_container_client(self, mock_get_conn, mock_service):
-        mock_get_conn.return_value = 
self.connection_map[self.shared_key_conn_id]
-        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+    def test_get_container_client(self, mocked_blob_service_client):
+        hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
         hook._get_container_client("mycontainer")
-        
mock_service.return_value.get_container_client.assert_called_once_with("mycontainer")
+        
mocked_blob_service_client.return_value.get_container_client.assert_called_once_with("mycontainer")
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_get_blob_client(self, mock_get_conn, mock_service):
-        mock_get_conn.return_value = 
self.connection_map[self.shared_key_conn_id]
-        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+    def test_get_blob_client(self, mocked_blob_service_client):
+        hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
         hook._get_blob_client(container_name="mycontainer", blob_name="myblob")
-        mock_instance = mock_service.return_value.get_blob_client
+        mock_instance = mocked_blob_service_client.return_value.get_blob_client
         mock_instance.assert_called_once_with(container="mycontainer", 
blob="myblob")
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_create_container(self, mock_get_conn, mock_service):
-        mock_get_conn.return_value = 
self.connection_map[self.shared_key_conn_id]
-        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+    def test_create_container(self, mocked_blob_service_client):
+        hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
         hook.create_container(container_name="mycontainer")
-        mock_instance = mock_service.return_value.get_container_client
+        mock_instance = 
mocked_blob_service_client.return_value.get_container_client
         mock_instance.assert_called_once_with("mycontainer")
         mock_instance.return_value.create_container.assert_called()
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_delete_container(self, mock_get_conn, mock_service):
-        mock_get_conn.return_value = 
self.connection_map[self.shared_key_conn_id]
-        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+    def test_delete_container(self, mocked_blob_service_client):
+        hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
         hook.delete_container("mycontainer")
-        
mock_service.return_value.get_container_client.assert_called_once_with("mycontainer")
-        
mock_service.return_value.get_container_client.return_value.delete_container.assert_called()
+        mocked_container_client = 
mocked_blob_service_client.return_value.get_container_client
+        mocked_container_client.assert_called_once_with("mycontainer")
+        mocked_container_client.return_value.delete_container.assert_called()
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
     @mock.patch.object(WasbHook, "delete_blobs")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_delete_single_blob(self, mock_get_conn, delete_blobs, 
mock_service):
-        mock_get_conn.return_value = 
self.connection_map[self.shared_key_conn_id]
-        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+    def test_delete_single_blob(self, delete_blobs, 
mocked_blob_service_client):
+        hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
         hook.delete_file("container", "blob", is_prefix=False)
         delete_blobs.assert_called_once_with("container", "blob")
 
     @mock.patch.object(WasbHook, "delete_blobs")
     @mock.patch.object(WasbHook, "get_blobs_list")
     @mock.patch.object(WasbHook, "check_for_blob")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_delete_multiple_blobs(self, mock_get_conn, mock_check, 
mock_get_blobslist, mock_delete_blobs):
-        mock_get_conn.return_value = 
self.connection_map[self.shared_key_conn_id]
+    def test_delete_multiple_blobs(self, mock_check, mock_get_blobslist, 
mock_delete_blobs):
         mock_check.return_value = False
         mock_get_blobslist.return_value = ["blob_prefix/blob1", 
"blob_prefix/blob2"]
-        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+        hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
         hook.delete_file("container", "blob_prefix", is_prefix=True)
         mock_get_blobslist.assert_called_once_with("container", 
prefix="blob_prefix", delimiter="")
         mock_delete_blobs.assert_any_call(
@@ -604,14 +535,10 @@ class TestWasbHook:
     @mock.patch.object(WasbHook, "delete_blobs")
     @mock.patch.object(WasbHook, "get_blobs_list")
     @mock.patch.object(WasbHook, "check_for_blob")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_delete_more_than_256_blobs(
-        self, mock_get_conn, mock_check, mock_get_blobslist, mock_delete_blobs
-    ):
-        mock_get_conn.return_value = 
self.connection_map[self.shared_key_conn_id]
+    def test_delete_more_than_256_blobs(self, mock_check, mock_get_blobslist, 
mock_delete_blobs):
         mock_check.return_value = False
         mock_get_blobslist.return_value = [f"blob_prefix/blob{i}" for i in 
range(300)]
-        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+        hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
         hook.delete_file("container", "blob_prefix", is_prefix=True)
         mock_get_blobslist.assert_called_once_with("container", 
prefix="blob_prefix", delimiter="")
         # The maximum number of blobs that can be deleted in a single request 
is 256 using the underlying
@@ -620,33 +547,26 @@ class TestWasbHook:
         # `ContainerClient.delete_blobs()` in this test.
         assert mock_delete_blobs.call_count == 2
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
     @mock.patch.object(WasbHook, "get_blobs_list")
     @mock.patch.object(WasbHook, "check_for_blob")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_delete_nonexisting_blob_fails(self, mock_get_conn, mock_check, 
mock_getblobs, mock_service):
+    def test_delete_nonexisting_blob_fails(self, mock_check, mock_getblobs, 
mocked_blob_service_client):
         mock_getblobs.return_value = []
         mock_check.return_value = False
         with pytest.raises(Exception) as ctx:
-            hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+            hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
             hook.delete_file("container", "nonexisting_blob", is_prefix=False, 
ignore_if_missing=False)
         assert isinstance(ctx.value, AirflowException)
 
     @mock.patch.object(WasbHook, "get_blobs_list")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_delete_multiple_nonexisting_blobs_fails(self, mock_get_conn, 
mock_getblobs):
-        mock_get_conn.return_value = 
self.connection_map[self.shared_key_conn_id]
+    def test_delete_multiple_nonexisting_blobs_fails(self, mock_getblobs):
         mock_getblobs.return_value = []
         with pytest.raises(Exception) as ctx:
-            hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+            hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
             hook.delete_file("container", "nonexisting_blob_prefix", 
is_prefix=True, ignore_if_missing=False)
         assert isinstance(ctx.value, AirflowException)
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_connection_success(self, mock_get_conn, mock_service):
-        mock_get_conn.return_value = 
self.connection_map[self.shared_key_conn_id]
-        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+    def test_connection_success(self, mocked_blob_service_client):
+        hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
         hook.get_conn().get_account_information().return_value = {
             "sku_name": "Standard_RAGRS",
             "account_kind": "StorageV2",
@@ -656,11 +576,8 @@ class TestWasbHook:
         assert status is True
         assert msg == "Successfully connected to Azure Blob Storage."
 
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
-    def test_connection_failure(self, mock_get_conn, mock_service):
-        mock_get_conn.return_value = 
self.connection_map[self.shared_key_conn_id]
-        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+    def test_connection_failure(self, mocked_blob_service_client):
+        hook = WasbHook(wasb_conn_id=self.azure_shared_key_test)
         hook.get_conn().get_account_information = mock.PropertyMock(
             side_effect=Exception("Authentication failed.")
         )
diff --git a/tests/providers/microsoft/azure/operators/test_azure_batch.py 
b/tests/providers/microsoft/azure/operators/test_azure_batch.py
index e920f7c1d9..e70db80913 100644
--- a/tests/providers/microsoft/azure/operators/test_azure_batch.py
+++ b/tests/providers/microsoft/azure/operators/test_azure_batch.py
@@ -26,7 +26,6 @@ from airflow.exceptions import AirflowException
 from airflow.models import Connection
 from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook
 from airflow.providers.microsoft.azure.operators.batch import 
AzureBatchOperator
-from airflow.utils import db
 
 TASK_ID = "MyDag"
 BATCH_POOL_ID = "MyPool"
@@ -40,11 +39,20 @@ FORMULA = """$curTime = time();
              $TargetDedicated = $isWorkingWeekdayHour ? 20:10;"""
 
 
[email protected]
+def mocked_batch_service_client():
+    with 
mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient") 
as m:
+        yield m
+
+
 class TestAzureBatchOperator:
     # set up the test environment
-    @mock.patch("airflow.providers.microsoft.azure.hooks.batch.AzureBatchHook")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient")
-    def setup_method(self, method, mock_batch, mock_hook):
+    @pytest.fixture(autouse=True)
+    def setup_test_cases(self, mocked_batch_service_client, 
create_mock_connections):
+        # set up mocked Azure Batch client
+        self.batch_client = mock.MagicMock(name="FakeBatchServiceClient")
+        mocked_batch_service_client.return_value = self.batch_client
+
         # set up the test variable
         self.test_vm_conn_id = "test_azure_batch_vm2"
         self.test_cloud_conn_id = "test_azure_batch_cloud2"
@@ -59,22 +67,21 @@ class TestAzureBatchOperator:
         self.test_cloud_os_version = "test-version"
         self.test_node_agent_sku = "test-node-agent-sku"
 
-        # connect with vm configuration
-        db.merge_conn(
+        create_mock_connections(
+            # connect with vm configuration
             Connection(
                 conn_id=self.test_vm_conn_id,
                 conn_type="azure_batch",
                 extra=json.dumps({"account_url": self.test_account_url}),
-            )
-        )
-        # connect with cloud service
-        db.merge_conn(
+            ),
+            # connect with cloud service
             Connection(
                 conn_id=self.test_cloud_conn_id,
                 conn_type="azure_batch",
                 extra=json.dumps({"account_url": self.test_account_url}),
-            )
+            ),
         )
+
         self.operator = AzureBatchOperator(
             task_id=TASK_ID,
             batch_pool_id=BATCH_POOL_ID,
@@ -159,10 +166,6 @@ class TestAzureBatchOperator:
             target_dedicated_nodes=1,
             timeout=2,
         )
-        self.batch_client = mock_batch.return_value
-        self.mock_instance = mock_hook.return_value
-        assert self.batch_client == self.operator.hook.connection
-        assert self.batch_client == self.operator2_pass.hook.connection
 
     @mock.patch.object(AzureBatchHook, "wait_for_all_node_state")
     def test_execute_without_failures(self, wait_mock):
diff --git a/tests/providers/microsoft/azure/operators/test_azure_cosmos.py 
b/tests/providers/microsoft/azure/operators/test_azure_cosmos.py
index 57e5097006..599e24889e 100644
--- a/tests/providers/microsoft/azure/operators/test_azure_cosmos.py
+++ b/tests/providers/microsoft/azure/operators/test_azure_cosmos.py
@@ -17,33 +17,35 @@
 # under the License.
 from __future__ import annotations
 
-import json
 import uuid
 from unittest import mock
 
+import pytest
+
 from airflow.models import Connection
 from airflow.providers.microsoft.azure.operators.cosmos import 
AzureCosmosInsertDocumentOperator
-from airflow.utils import db
 
 
 class TestAzureCosmosDbHook:
 
     # Set up an environment to test with
-    def setup_method(self):
+    @pytest.fixture(autouse=True)
+    def setup_test_cases(self, create_mock_connection):
         # set up some test variables
         self.test_end_point = "https://test_endpoint:443";
         self.test_master_key = "magic_test_key"
         self.test_database_name = "test_database_name"
         self.test_collection_name = "test_collection_name"
-        db.merge_conn(
+        create_mock_connection(
             Connection(
                 conn_id="azure_cosmos_test_key_id",
                 conn_type="azure_cosmos",
                 login=self.test_end_point,
                 password=self.test_master_key,
-                extra=json.dumps(
-                    {"database_name": self.test_database_name, 
"collection_name": self.test_collection_name}
-                ),
+                extra={
+                    "database_name": self.test_database_name,
+                    "collection_name": self.test_collection_name,
+                },
             )
         )
 
diff --git 
a/tests/providers/microsoft/azure/operators/test_azure_data_factory.py 
b/tests/providers/microsoft/azure/operators/test_azure_data_factory.py
index 2a545cd137..e4dd61c1d2 100644
--- a/tests/providers/microsoft/azure/operators/test_azure_data_factory.py
+++ b/tests/providers/microsoft/azure/operators/test_azure_data_factory.py
@@ -16,7 +16,6 @@
 # under the License.
 from __future__ import annotations
 
-import json
 from unittest import mock
 from unittest.mock import MagicMock, patch
 
@@ -35,7 +34,7 @@ from airflow.providers.microsoft.azure.hooks.data_factory 
import (
 )
 from airflow.providers.microsoft.azure.operators.data_factory import 
AzureDataFactoryRunPipelineOperator
 from airflow.providers.microsoft.azure.triggers.data_factory import 
AzureDataFactoryTrigger
-from airflow.utils import db, timezone
+from airflow.utils import timezone
 from airflow.utils.types import DagRunType
 
 DEFAULT_DATE = timezone.datetime(2021, 1, 1)
@@ -60,7 +59,8 @@ AZ_PIPELINE_RUN_ID = "7f8c6c72-c093-11ec-a83d-0242ac120007"
 
 
 class TestAzureDataFactoryRunPipelineOperator:
-    def setup_method(self):
+    @pytest.fixture(autouse=True)
+    def setup_test_cases(self, create_mock_connection):
         self.mock_ti = MagicMock()
         self.mock_context = {"ti": self.mock_ti}
         self.config = {
@@ -73,13 +73,13 @@ class TestAzureDataFactoryRunPipelineOperator:
             "timeout": 3,
         }
 
-        db.merge_conn(
+        create_mock_connection(
             Connection(
                 conn_id="azure_data_factory_test",
                 conn_type="azure_data_factory",
                 login="client-id",
                 password="client-secret",
-                extra=json.dumps(CONN_EXTRAS),
+                extra=CONN_EXTRAS,
             )
         )
 
diff --git a/tests/providers/microsoft/azure/operators/test_azure_synapse.py 
b/tests/providers/microsoft/azure/operators/test_azure_synapse.py
index c43b11ef7b..233e1c57fd 100644
--- a/tests/providers/microsoft/azure/operators/test_azure_synapse.py
+++ b/tests/providers/microsoft/azure/operators/test_azure_synapse.py
@@ -16,7 +16,6 @@
 # under the License.
 from __future__ import annotations
 
-import json
 from unittest import mock
 from unittest.mock import MagicMock
 
@@ -24,7 +23,7 @@ import pytest
 
 from airflow.models import Connection
 from airflow.providers.microsoft.azure.operators.synapse import 
AzureSynapseRunSparkBatchOperator
-from airflow.utils import db, timezone
+from airflow.utils import timezone
 
 DEFAULT_DATE = timezone.datetime(2021, 1, 1)
 SUBSCRIPTION_ID = "my-subscription-id"
@@ -39,7 +38,8 @@ JOB_RUN_RESPONSE = {"id": 123}
 
 
 class TestAzureSynapseRunSparkBatchOperator:
-    def setup_method(self):
+    @pytest.fixture(autouse=True)
+    def setup_test_cases(self, create_mock_connection):
         self.mock_ti = MagicMock()
         self.mock_context = {"ti": self.mock_ti}
         self.config = {
@@ -50,22 +50,21 @@ class TestAzureSynapseRunSparkBatchOperator:
             "timeout": 3,
         }
 
-        db.merge_conn(
+        create_mock_connection(
             Connection(
                 conn_id=AZURE_SYNAPSE_CONN_ID,
                 conn_type="azure_synapse",
                 host="https://synapsetest.net";,
                 login="client-id",
                 password="client-secret",
-                extra=json.dumps(CONN_EXTRAS),
+                extra=CONN_EXTRAS,
             )
         )
 
     
@mock.patch("airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook.get_job_run_status")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook.get_conn")
     
@mock.patch("airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook.run_spark_job")
     def test_azure_synapse_run_spark_batch_operator_success(
-        self, mock_run_spark_job, mock_conn, mock_get_job_run_status
+        self, mock_run_spark_job, mock_get_job_run_status
     ):
         mock_get_job_run_status.return_value = "success"
         mock_run_spark_job.return_value = MagicMock(**JOB_RUN_RESPONSE)
@@ -76,11 +75,8 @@ class TestAzureSynapseRunSparkBatchOperator:
         assert op.job_id == JOB_RUN_RESPONSE["id"]
 
     
@mock.patch("airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook.get_job_run_status")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook.get_conn")
     
@mock.patch("airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook.run_spark_job")
-    def test_azure_synapse_run_spark_batch_operator_error(
-        self, mock_run_spark_job, mock_conn, mock_get_job_run_status
-    ):
+    def test_azure_synapse_run_spark_batch_operator_error(self, 
mock_run_spark_job, mock_get_job_run_status):
         mock_get_job_run_status.return_value = "error"
         mock_run_spark_job.return_value = MagicMock(**JOB_RUN_RESPONSE)
         op = AzureSynapseRunSparkBatchOperator(
@@ -93,11 +89,10 @@ class TestAzureSynapseRunSparkBatchOperator:
             op.execute(context=self.mock_context)
 
     
@mock.patch("airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook.get_job_run_status")
-    
@mock.patch("airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook.get_conn")
     
@mock.patch("airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook.run_spark_job")
     
@mock.patch("airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook.cancel_job_run")
     def test_azure_synapse_run_spark_batch_operator_on_kill(
-        self, mock_cancel_job_run, mock_run_spark_job, mock_conn, 
mock_get_job_run_status
+        self, mock_cancel_job_run, mock_run_spark_job, mock_get_job_run_status
     ):
         mock_get_job_run_status.return_value = "success"
         mock_run_spark_job.return_value = MagicMock(**JOB_RUN_RESPONSE)
diff --git a/tests/providers/microsoft/conftest.py 
b/tests/providers/microsoft/conftest.py
new file mode 100644
index 0000000000..aa75c95203
--- /dev/null
+++ b/tests/providers/microsoft/conftest.py
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import string
+from random import choices
+from typing import TypeVar
+
+import pytest
+
+from airflow.models import Connection
+
+# Input types the connection fixtures below dispatch on: a dict (connection-JSON style),
+# a str (connection URI) or an ``airflow.models.Connection`` instance.
+# NOTE(review): as a *constrained* TypeVar, type checkers require every argument of a single
+# ``create_mock_connections(*conns)`` call to share one of these types; a Union alias would
+# allow mixing them if that is the intent.
+T = TypeVar("T", dict, str, Connection)
+
+
+@pytest.fixture
+def create_mock_connection(monkeypatch):
+    """Helper fixture for creating a test connection.
+
+    Returns a ``wrapper(conn, conn_id=None)`` callable. ``conn`` may be an
+    ``airflow.models.Connection``, a connection URI string, or a dict in the
+    Airflow connection-JSON format (an optional ``"conn_id"`` key is honoured).
+    If the resulting connection has no ``conn_id``, the ``conn_id`` argument is
+    used, falling back to a random ``test_conn_xxxxxx`` identifier.
+
+    The connection is published as an ``AIRFLOW_CONN_<CONN_ID>`` environment
+    variable (URI form) through ``monkeypatch``, so it is reverted at test
+    teardown, and the ``Connection`` object is returned.
+
+    Raises ``TypeError`` from the wrapper when ``conn`` is none of the supported types.
+    """
+
+    def wrapper(conn: T, conn_id: str | None = None) -> Connection:
+        if isinstance(conn, dict):
+            import json
+
+            # ``Connection.from_json`` expects a JSON *string*; handing it the dict itself
+            # fails with TypeError, so serialise first. ``conn_id`` is passed separately
+            # because it would otherwise clash with from_json's own ``conn_id`` argument.
+            payload = dict(conn)  # shallow copy: keep the caller's dict untouched
+            dict_conn_id = payload.pop("conn_id", None)
+            conn = Connection.from_json(json.dumps(payload), conn_id=dict_conn_id)
+        elif isinstance(conn, str):
+            conn = Connection(uri=conn)
+
+        if not isinstance(conn, Connection):
+            raise TypeError(
+                f"Fixture expected either JSON, URI or Connection type, but got {type(conn).__name__}"
+            )
+        if not conn.conn_id:
+            # Fall back to the explicit argument, then to a random identifier (generated only when needed).
+            conn.conn_id = conn_id or "test_conn_" + "".join(
+                choices(string.ascii_lowercase + string.digits, k=6)
+            )
+
+        monkeypatch.setenv(f"AIRFLOW_CONN_{conn.conn_id.upper()}", conn.get_uri())
+        return conn
+
+    return wrapper
+
+
+@pytest.fixture
+def create_mock_connections(create_mock_connection):
+    """Helper fixture for creating multiple test connections.
+
+    Returns a ``wrapper(*conns)`` callable that registers each argument through
+    ``create_mock_connection`` (so a connection without ``conn_id`` gets an
+    auto-generated one) and returns the resulting connections as a list, in
+    argument order.
+    """
+
+    def wrapper(*conns: T):
+        return [create_mock_connection(conn) for conn in conns]
+
+    return wrapper
+
+
+@pytest.fixture
+def mocked_connection(request, create_mock_connection):
+    """Helper indirect fixture for creating a test connection.
+
+    Meant to be parametrized indirectly, e.g.
+    ``@pytest.mark.parametrize("mocked_connection", [...], indirect=True)``:
+    the parameter value (``request.param``) is registered via
+    ``create_mock_connection`` and the resulting ``Connection`` is returned.
+    """
+    return create_mock_connection(request.param)


Reply via email to