This is an automated email from the ASF dual-hosted git repository.
jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 89404530aa Refactor WasbHook tests (#32922)
89404530aa is described below
commit 89404530aa02d49ffe724a3ab653d5c9d687dd00
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Sat Jul 29 00:00:16 2023 +0100
Refactor WasbHook tests (#32922)
The existing tests for the different ways WasbHook can connect to Azure are
not true unit tests: they only check that a connection could be established,
not how it was established.
This refactor turns them into proper unit tests by mocking BlobServiceClient
and asserting the exact arguments it is constructed with for each connection
type.
It also stops writing the test connections to the database before each test;
instead, WasbHook.get_connection is mocked to return them from an in-memory
map.
---
tests/providers/microsoft/azure/hooks/test_wasb.py | 326 ++++++++++++++-------
1 file changed, 214 insertions(+), 112 deletions(-)
diff --git a/tests/providers/microsoft/azure/hooks/test_wasb.py
b/tests/providers/microsoft/azure/hooks/test_wasb.py
index 9e9ce9a0d0..5deca1b805 100644
--- a/tests/providers/microsoft/azure/hooks/test_wasb.py
+++ b/tests/providers/microsoft/azure/hooks/test_wasb.py
@@ -21,14 +21,12 @@ import json
from unittest import mock
import pytest
-from azure.identity import ClientSecretCredential, DefaultAzureCredential
from azure.storage.blob import BlobServiceClient
from azure.storage.blob._models import BlobProperties
from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
-from airflow.utils import db
from tests.test_utils.providers import get_provider_min_airflow_version,
object_exists
# connection_string has a format
@@ -63,56 +61,43 @@ class TestWasbHook:
"connection_verify": False,
"authority": self.authority,
}
-
- db.merge_conn(
- Connection(
- conn_id=self.wasb_test_key,
+ self.connection_map = {
+ self.wasb_test_key: Connection(
+ conn_id="wasb_test_key",
conn_type=self.connection_type,
login=self.login,
password="key",
- )
- )
- db.merge_conn(
- Connection(
+ ),
+ self.public_read_conn_id: Connection(
conn_id=self.public_read_conn_id,
conn_type=self.connection_type,
host="https://accountname.blob.core.windows.net",
extra=json.dumps({"proxies": self.proxies}),
- )
- )
- db.merge_conn(
- Connection(
+ ),
+ self.public_read_conn_id_without_host: Connection(
conn_id=self.public_read_conn_id_without_host,
conn_type=self.connection_type,
login=self.login,
extra=json.dumps({"proxies": self.proxies}),
- )
- )
- db.merge_conn(
- Connection(
+ ),
+ self.connection_string_id: Connection(
conn_id=self.connection_string_id,
conn_type=self.connection_type,
extra=json.dumps({"connection_string": CONN_STRING, "proxies":
self.proxies}),
- )
- )
- db.merge_conn(
- Connection(
+ ),
+ self.shared_key_conn_id: Connection(
conn_id=self.shared_key_conn_id,
conn_type=self.connection_type,
host="https://accountname.blob.core.windows.net",
extra=json.dumps({"shared_access_key": "token", "proxies":
self.proxies}),
- )
- )
- db.merge_conn(
- Connection(
+ ),
+ self.shared_key_conn_id_without_host: Connection(
conn_id=self.shared_key_conn_id_without_host,
conn_type=self.connection_type,
login=self.login,
extra=json.dumps({"shared_access_key": "token", "proxies":
self.proxies}),
- )
- )
- db.merge_conn(
- Connection(
+ ),
+ self.ad_conn_id: Connection(
conn_id=self.ad_conn_id,
conn_type=self.connection_type,
host="conn_host",
@@ -125,42 +110,32 @@ class TestWasbHook:
"client_secret_auth_config":
self.client_secret_auth_config,
}
),
- )
- )
- db.merge_conn(
- Connection(
+ ),
+ self.managed_identity_conn_id: Connection(
conn_id=self.managed_identity_conn_id,
conn_type=self.connection_type,
extra=json.dumps({"proxies": self.proxies}),
- )
- )
- db.merge_conn(
- Connection(
+ ),
+ self.sas_conn_id: Connection(
conn_id=self.sas_conn_id,
conn_type=self.connection_type,
login=self.login,
extra=json.dumps({"sas_token": "token", "proxies":
self.proxies}),
- )
- )
- db.merge_conn(
- Connection(
+ ),
+ self.extra__wasb__sas_conn_id: Connection(
conn_id=self.extra__wasb__sas_conn_id,
conn_type=self.connection_type,
login=self.login,
extra=json.dumps({"extra__wasb__sas_token": "token",
"proxies": self.proxies}),
- )
- )
- db.merge_conn(
- Connection(
+ ),
+ self.http_sas_conn_id: Connection(
conn_id=self.http_sas_conn_id,
conn_type=self.connection_type,
extra=json.dumps(
{"sas_token": "https://login.blob.core.windows.net/token",
"proxies": self.proxies}
),
- )
- )
- db.merge_conn(
- Connection(
+ ),
+ self.extra__wasb__http_sas_conn_id: Connection(
conn_id=self.extra__wasb__http_sas_conn_id,
conn_type=self.connection_type,
extra=json.dumps(
@@ -169,49 +144,102 @@ class TestWasbHook:
"proxies": self.proxies,
}
),
- )
- )
+ ),
+ }
- def test_key(self):
- hook = WasbHook(wasb_conn_id="wasb_test_key")
- assert hook.conn_id == "wasb_test_key"
- assert isinstance(hook.blob_service_client, BlobServiceClient)
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_key(self, mock_get_conn, mock_blob_service_client):
+ conn = self.connection_map[self.wasb_test_key]
+ mock_get_conn.return_value = conn
+ WasbHook(wasb_conn_id=self.wasb_test_key)
+ assert mock_blob_service_client.call_args == mock.call(
+ account_url=f"https://{self.login}.blob.core.windows.net/",
+ credential=conn.password,
+ )
- def test_public_read(self):
- hook = WasbHook(wasb_conn_id=self.public_read_conn_id,
public_read=True)
- assert isinstance(hook.get_conn(), BlobServiceClient)
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_public_read(self, mock_get_conn, mock_blob_service_client):
+ conn = self.connection_map[self.public_read_conn_id]
+ mock_get_conn.return_value = conn
+ WasbHook(wasb_conn_id=self.public_read_conn_id, public_read=True)
+ assert mock_blob_service_client.call_args == mock.call(
+ account_url=conn.host,
+ proxies=conn.extra_dejson["proxies"],
+ )
- def test_connection_string(self):
- hook = WasbHook(wasb_conn_id=self.connection_string_id)
- assert hook.conn_id == self.connection_string_id
- assert isinstance(hook.get_conn(), BlobServiceClient)
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_connection_string(self, mock_get_conn, mock_blob_service_client):
+ conn = self.connection_map[self.connection_string_id]
+ mock_get_conn.return_value = conn
+ WasbHook(wasb_conn_id=self.connection_string_id)
+
mock_blob_service_client.from_connection_string.assert_called_once_with(
+ CONN_STRING,
+ proxies=conn.extra_dejson["proxies"],
+ connection_string=conn.extra_dejson["connection_string"],
+ )
- def test_shared_key_connection(self):
- hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
- assert isinstance(hook.get_conn(), BlobServiceClient)
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_shared_key_connection(self, mock_get_conn,
mock_blob_service_client):
+ conn = self.connection_map[self.shared_key_conn_id]
+ mock_get_conn.return_value = conn
+ WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ mock_blob_service_client.assert_called_once_with(
+ account_url=conn.host,
+ credential=conn.extra_dejson["shared_access_key"],
+ proxies=conn.extra_dejson["proxies"],
+ shared_access_key=conn.extra_dejson["shared_access_key"],
+ )
- def test_managed_identity(self):
- hook = WasbHook(wasb_conn_id=self.managed_identity_conn_id)
- assert isinstance(hook.get_conn(), BlobServiceClient)
- assert isinstance(hook.get_conn().credential, DefaultAzureCredential)
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.DefaultAzureCredential")
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_managed_identity(self, mock_get_conn, mock_credential,
mock_blob_service_client):
+ conn = self.connection_map[self.managed_identity_conn_id]
+ mock_get_conn.return_value = conn
+ WasbHook(wasb_conn_id=self.managed_identity_conn_id)
+ mock_blob_service_client.assert_called_once_with(
+ account_url=f"https://{conn.login}.blob.core.windows.net/",
+ credential=mock_credential.return_value,
+ proxies=conn.extra_dejson["proxies"],
+ )
- def test_azure_directory_connection(self):
- hook = WasbHook(wasb_conn_id=self.ad_conn_id)
- assert isinstance(hook.get_conn(), BlobServiceClient)
- assert isinstance(hook.get_conn().credential, ClientSecretCredential)
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.ClientSecretCredential")
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_azure_directory_connection(self, mock_get_conn, mock_credential,
mock_blob_service_client):
+ conn = self.connection_map[self.ad_conn_id]
+ mock_get_conn.return_value = conn
+ WasbHook(wasb_conn_id=self.ad_conn_id)
+ mock_credential.assert_called_once_with(
+ conn.extra_dejson["tenant_id"],
+ conn.login,
+ conn.password,
+ proxies=self.client_secret_auth_config["proxies"],
+
connection_verify=self.client_secret_auth_config["connection_verify"],
+ authority=self.client_secret_auth_config["authority"],
+ )
+ mock_blob_service_client.assert_called_once_with(
+ account_url=conn.host,
+ credential=mock_credential.return_value,
+ tenant_id=conn.extra_dejson["tenant_id"],
+ proxies=conn.extra_dejson["proxies"],
+ )
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.DefaultAzureCredential")
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
def test_active_directory_ID_used_as_host(self, mock_get_conn,
mock_credential, mock_blob_service_client):
- hook = WasbHook(wasb_conn_id="testconn")
mock_get_conn.return_value = Connection(
conn_id="testconn",
conn_type=self.connection_type,
login="testaccountname",
host="testaccountID",
)
- hook.get_conn()
+ WasbHook(wasb_conn_id="testconn")
assert mock_blob_service_client.call_args == mock.call(
account_url="https://testaccountname.blob.core.windows.net/",
credential=mock_credential.return_value,
@@ -222,7 +250,6 @@ class TestWasbHook:
def test_sas_token_provided_and_active_directory_ID_used_as_host(
self, mock_get_conn, mock_blob_service_client
):
- hook = WasbHook(wasb_conn_id="testconn")
mock_get_conn.return_value = Connection(
conn_id="testconn",
conn_type=self.connection_type,
@@ -230,7 +257,7 @@ class TestWasbHook:
host="testaccountID",
extra=json.dumps({"sas_token": "SAStoken"}),
)
- hook.get_conn()
+ WasbHook(wasb_conn_id="testconn")
assert mock_blob_service_client.call_args == mock.call(
account_url="https://testaccountname.blob.core.windows.net/SAStoken",
sas_token="SAStoken",
@@ -244,14 +271,34 @@ class TestWasbHook:
"public_read_conn_id_without_host",
],
)
- def test_account_url_without_host(self, conn_id_str):
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.DefaultAzureCredential")
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_account_url_without_host(
+ self, mock_get_conn, mock_credential, mock_blob_service_client,
conn_id_str
+ ):
conn_id = self.__getattribute__(conn_id_str)
- hook = WasbHook(wasb_conn_id=conn_id)
- hook_conn = hook.get_connection(hook.conn_id)
- conn = hook.get_conn()
- assert conn.url.startswith("https://")
- assert conn.url.__contains__(hook_conn.login)
- assert conn.url.endswith(".blob.core.windows.net/")
+ connection = self.connection_map[conn_id]
+ mock_get_conn.return_value = connection
+ WasbHook(wasb_conn_id=conn_id)
+ if conn_id_str == "wasb_test_key":
+ mock_blob_service_client.assert_called_once_with(
+
account_url=f"https://{connection.login}.blob.core.windows.net/",
+ credential=connection.password,
+ )
+ elif conn_id_str == "shared_key_conn_id_without_host":
+ mock_blob_service_client.assert_called_once_with(
+
account_url=f"https://{connection.login}.blob.core.windows.net/",
+ credential=connection.extra_dejson["shared_access_key"],
+ proxies=connection.extra_dejson["proxies"],
+ shared_access_key=connection.extra_dejson["shared_access_key"],
+ )
+ else:
+ mock_blob_service_client.assert_called_once_with(
+
account_url=f"https://{connection.login}.blob.core.windows.net/",
+ credential=mock_credential.return_value,
+ proxies=connection.extra_dejson["proxies"],
+ )
@pytest.mark.parametrize(
argnames="conn_id_str, extra_key",
@@ -262,8 +309,10 @@ class TestWasbHook:
("extra__wasb__http_sas_conn_id", "extra__wasb__sas_token"),
],
)
- def test_sas_token_connection(self, conn_id_str, extra_key):
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_sas_token_connection(self, mock_get_conn, conn_id_str, extra_key):
conn_id = self.__getattribute__(conn_id_str)
+ mock_get_conn.return_value = self.connection_map[conn_id]
hook = WasbHook(wasb_conn_id=conn_id)
conn = hook.get_conn()
hook_conn = hook.get_connection(hook.conn_id)
@@ -287,26 +336,34 @@ class TestWasbHook:
"extra__wasb__http_sas_conn_id",
],
)
- def test_connection_extra_arguments(self, conn_id_str):
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_connection_extra_arguments(self, mock_get_conn, conn_id_str):
conn_id = self.__getattribute__(conn_id_str)
+ mock_get_conn.return_value = self.connection_map[conn_id]
hook = WasbHook(wasb_conn_id=conn_id)
conn = hook.get_conn()
assert conn._config.proxy_policy.proxies == self.proxies
- def test_connection_extra_arguments_public_read(self):
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_connection_extra_arguments_public_read(self, mock_get_conn):
conn_id = self.public_read_conn_id
+ mock_get_conn.return_value = self.connection_map[conn_id]
hook = WasbHook(wasb_conn_id=conn_id, public_read=True)
conn = hook.get_conn()
assert conn._config.proxy_policy.proxies == self.proxies
- def test_extra_client_secret_auth_config_ad_connection(self):
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_extra_client_secret_auth_config_ad_connection(self,
mock_get_conn):
+ mock_get_conn.return_value = self.connection_map[self.ad_conn_id]
conn_id = self.ad_conn_id
hook = WasbHook(wasb_conn_id=conn_id)
conn = hook.get_conn()
assert conn.credential._authority == self.authority
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
- def test_check_for_blob(self, mock_service):
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_check_for_blob(self, mock_get_conn, mock_service):
+ mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
assert hook.check_for_blob(container_name="mycontainer",
blob_name="myblob")
mock_blob_client = mock_service.return_value.get_blob_client
@@ -314,21 +371,27 @@ class TestWasbHook:
mock_blob_client.return_value.get_blob_properties.assert_called()
@mock.patch.object(WasbHook, "get_blobs_list")
- def test_check_for_prefix(self, get_blobs_list):
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_check_for_prefix(self, mock_get_conn, get_blobs_list):
+ mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
get_blobs_list.return_value = ["blobs"]
hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
assert hook.check_for_prefix("container", "prefix", timeout=3)
get_blobs_list.assert_called_once_with(container_name="container",
prefix="prefix", timeout=3)
@mock.patch.object(WasbHook, "get_blobs_list")
- def test_check_for_prefix_empty(self, get_blobs_list):
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_check_for_prefix_empty(self, mock_get_conn, get_blobs_list):
+ mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
get_blobs_list.return_value = []
hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
assert not hook.check_for_prefix("container", "prefix", timeout=3)
get_blobs_list.assert_called_once_with(container_name="container",
prefix="prefix", timeout=3)
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
- def test_get_blobs_list(self, mock_service):
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_get_blobs_list(self, mock_get_conn, mock_service):
+ mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
hook.get_blobs_list(container_name="mycontainer", prefix="my",
include=None, delimiter="/")
mock_service.return_value.get_container_client.assert_called_once_with("mycontainer")
@@ -337,7 +400,9 @@ class TestWasbHook:
)
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
- def test_get_blobs_list_recursive(self, mock_service):
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_get_blobs_list_recursive(self, mock_get_conn, mock_service):
+ mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
hook.get_blobs_list_recursive(
container_name="mycontainer", prefix="test", include=None,
endswith="file_extension"
@@ -348,7 +413,9 @@ class TestWasbHook:
)
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
- def test_get_blobs_list_recursive_endswith(self, mock_service):
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_get_blobs_list_recursive_endswith(self, mock_get_conn,
mock_service):
+ mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
mock_service.return_value.get_container_client.return_value.list_blobs.return_value
= [
BlobProperties(name="test/abc.py"),
@@ -362,7 +429,9 @@ class TestWasbHook:
@pytest.mark.parametrize(argnames="create_container", argvalues=[True,
False])
@mock.patch.object(WasbHook, "upload")
- def test_load_file(self, mock_upload, create_container):
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_load_file(self, mock_get_conn, mock_upload, create_container):
+ mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
with mock.patch("builtins.open", mock.mock_open(read_data="data")):
hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
hook.load_file("path", "container", "blob", create_container,
max_connections=1)
@@ -377,7 +446,9 @@ class TestWasbHook:
@pytest.mark.parametrize(argnames="create_container", argvalues=[True,
False])
@mock.patch.object(WasbHook, "upload")
- def test_load_string(self, mock_upload, create_container):
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_load_string(self, mock_get_conn, mock_upload, create_container):
+ mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
hook.load_string("big string", "container", "blob", create_container,
max_connections=1)
mock_upload.assert_called_once_with(
@@ -389,7 +460,9 @@ class TestWasbHook:
)
@mock.patch.object(WasbHook, "download")
- def test_get_file(self, mock_download):
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_get_file(self, mock_get_conn, mock_download):
+ mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
with mock.patch("builtins.open", mock.mock_open(read_data="data")):
hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
hook.get_file("path", "container", "blob", max_connections=1)
@@ -398,14 +471,18 @@ class TestWasbHook:
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
@mock.patch.object(WasbHook, "download")
- def test_read_file(self, mock_download, mock_service):
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_read_file(self, mock_get_conn, mock_download, mock_service):
+ mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
hook.read_file("container", "blob", max_connections=1)
mock_download.assert_called_once_with("container", "blob",
max_connections=1)
@pytest.mark.parametrize(argnames="create_container", argvalues=[True,
False])
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
- def test_upload(self, mock_service, create_container):
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_upload(self, mock_get_conn, mock_service, create_container):
+ mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
hook.upload(
container_name="mycontainer",
@@ -426,7 +503,9 @@ class TestWasbHook:
mock_container_client.assert_not_called()
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
- def test_download(self, mock_service):
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_download(self, mock_get_conn, mock_service):
+ mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
blob_client = mock_service.return_value.get_blob_client
hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
hook.download(container_name="mycontainer", blob_name="myblob",
offset=2, length=4)
@@ -434,20 +513,26 @@ class TestWasbHook:
blob_client.return_value.download_blob.assert_called_once_with(offset=2,
length=4)
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
- def test_get_container_client(self, mock_service):
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_get_container_client(self, mock_get_conn, mock_service):
+ mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
hook._get_container_client("mycontainer")
mock_service.return_value.get_container_client.assert_called_once_with("mycontainer")
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
- def test_get_blob_client(self, mock_service):
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_get_blob_client(self, mock_get_conn, mock_service):
+ mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
hook._get_blob_client(container_name="mycontainer", blob_name="myblob")
mock_instance = mock_service.return_value.get_blob_client
mock_instance.assert_called_once_with(container="mycontainer",
blob="myblob")
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
- def test_create_container(self, mock_service):
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_create_container(self, mock_get_conn, mock_service):
+ mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
hook.create_container(container_name="mycontainer")
mock_instance = mock_service.return_value.get_container_client
@@ -455,7 +540,9 @@ class TestWasbHook:
mock_instance.return_value.create_container.assert_called()
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
- def test_delete_container(self, mock_service):
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_delete_container(self, mock_get_conn, mock_service):
+ mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
hook.delete_container("mycontainer")
mock_service.return_value.get_container_client.assert_called_once_with("mycontainer")
@@ -463,7 +550,9 @@ class TestWasbHook:
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
@mock.patch.object(WasbHook, "delete_blobs")
- def test_delete_single_blob(self, delete_blobs, mock_service):
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_delete_single_blob(self, mock_get_conn, delete_blobs,
mock_service):
+ mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
hook.delete_file("container", "blob", is_prefix=False)
delete_blobs.assert_called_once_with("container", "blob")
@@ -471,7 +560,9 @@ class TestWasbHook:
@mock.patch.object(WasbHook, "delete_blobs")
@mock.patch.object(WasbHook, "get_blobs_list")
@mock.patch.object(WasbHook, "check_for_blob")
- def test_delete_multiple_blobs(self, mock_check, mock_get_blobslist,
mock_delete_blobs):
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_delete_multiple_blobs(self, mock_get_conn, mock_check,
mock_get_blobslist, mock_delete_blobs):
+ mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
mock_check.return_value = False
mock_get_blobslist.return_value = ["blob_prefix/blob1",
"blob_prefix/blob2"]
hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
@@ -487,7 +578,11 @@ class TestWasbHook:
@mock.patch.object(WasbHook, "delete_blobs")
@mock.patch.object(WasbHook, "get_blobs_list")
@mock.patch.object(WasbHook, "check_for_blob")
- def test_delete_more_than_256_blobs(self, mock_check, mock_get_blobslist,
mock_delete_blobs):
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_delete_more_than_256_blobs(
+ self, mock_get_conn, mock_check, mock_get_blobslist, mock_delete_blobs
+ ):
+ mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
mock_check.return_value = False
mock_get_blobslist.return_value = [f"blob_prefix/blob{i}" for i in
range(300)]
hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
@@ -502,7 +597,8 @@ class TestWasbHook:
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
@mock.patch.object(WasbHook, "get_blobs_list")
@mock.patch.object(WasbHook, "check_for_blob")
- def test_delete_nonexisting_blob_fails(self, mock_check, mock_getblobs,
mock_service):
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_delete_nonexisting_blob_fails(self, mock_get_conn, mock_check,
mock_getblobs, mock_service):
mock_getblobs.return_value = []
mock_check.return_value = False
with pytest.raises(Exception) as ctx:
@@ -511,7 +607,9 @@ class TestWasbHook:
assert isinstance(ctx.value, AirflowException)
@mock.patch.object(WasbHook, "get_blobs_list")
- def test_delete_multiple_nonexisting_blobs_fails(self, mock_getblobs):
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_delete_multiple_nonexisting_blobs_fails(self, mock_get_conn,
mock_getblobs):
+ mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
mock_getblobs.return_value = []
with pytest.raises(Exception) as ctx:
hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
@@ -519,7 +617,9 @@ class TestWasbHook:
assert isinstance(ctx.value, AirflowException)
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
- def test_connection_success(self, mock_service):
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_connection_success(self, mock_get_conn, mock_service):
+ mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
hook.get_conn().get_account_information().return_value = {
"sku_name": "Standard_RAGRS",
@@ -531,7 +631,9 @@ class TestWasbHook:
assert msg == "Successfully connected to Azure Blob Storage."
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
- def test_connection_failure(self, mock_service):
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook.get_connection")
+ def test_connection_failure(self, mock_get_conn, mock_service):
+ mock_get_conn.return_value =
self.connection_map[self.shared_key_conn_id]
hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
hook.get_conn().get_account_information = mock.PropertyMock(
side_effect=Exception("Authentication failed.")