This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 46ee1c2c8d Fix where account url is build if not provided using login 
(account name) (#32082)
46ee1c2c8d is described below

commit 46ee1c2c8d3d0e5793f42fd10bcd80150caa538b
Author: Akash Sharma <[email protected]>
AuthorDate: Wed Jun 28 04:30:11 2023 +0530

    Fix where account url is build if not provided using login (account name) 
(#32082)
---
 airflow/providers/microsoft/azure/hooks/wasb.py    | 69 +++++++++++-----------
 .../connections/wasb.rst                           | 22 ++++---
 tests/providers/microsoft/azure/hooks/test_wasb.py | 53 ++++++++++++++++-
 3 files changed, 102 insertions(+), 42 deletions(-)

diff --git a/airflow/providers/microsoft/azure/hooks/wasb.py 
b/airflow/providers/microsoft/azure/hooks/wasb.py
index fbad1627be..c1d615810b 100644
--- a/airflow/providers/microsoft/azure/hooks/wasb.py
+++ b/airflow/providers/microsoft/azure/hooks/wasb.py
@@ -126,7 +126,7 @@ class WasbHook(BaseHook):
     def get_ui_field_behaviour() -> dict[str, Any]:
         """Returns custom field behaviour."""
         return {
-            "hidden_fields": ["schema", "port", "extra"],
+            "hidden_fields": ["schema", "port"],
             "relabeling": {
                 "login": "Blob Storage Login (optional)",
                 "password": "Blob Storage Key (optional)",
@@ -140,6 +140,7 @@ class WasbHook(BaseHook):
                 "tenant_id": "tenant",
                 "shared_access_key": "shared access key",
                 "sas_token": "account url or token",
+                "extra": "additional options for use with 
ClientSecretCredential or DefaultAzureCredential",
             },
         }
 
@@ -176,22 +177,11 @@ class WasbHook(BaseHook):
         extra = conn.extra_dejson or {}
         client_secret_auth_config = extra.pop("client_secret_auth_config", {})
 
-        if self.public_read:
-            # Here we use anonymous public read
-            # more info
-            # 
https://docs.microsoft.com/en-us/azure/storage/blobs/storage-manage-access-to-resources
-            return BlobServiceClient(account_url=conn.host, **extra)
-
         connection_string = self._get_field(extra, "connection_string")
         if connection_string:
             # connection_string auth takes priority
             return BlobServiceClient.from_connection_string(connection_string, 
**extra)
 
-        shared_access_key = self._get_field(extra, "shared_access_key")
-        if shared_access_key:
-            # using shared access key
-            return BlobServiceClient(account_url=conn.host, 
credential=shared_access_key, **extra)
-
         tenant = self._get_field(extra, "tenant_id")
         if tenant:
             # use Active Directory auth
@@ -200,14 +190,25 @@ class WasbHook(BaseHook):
             token_credential = ClientSecretCredential(tenant, app_id, 
app_secret, **client_secret_auth_config)
             return BlobServiceClient(account_url=conn.host, 
credential=token_credential, **extra)
 
+        account_url = conn.host if conn.host else 
f"https://{conn.login}.blob.core.windows.net/"
+
+        if self.public_read:
+            # Here we use anonymous public read
+            # more info
+            # 
https://docs.microsoft.com/en-us/azure/storage/blobs/storage-manage-access-to-resources
+            return BlobServiceClient(account_url=account_url, **extra)
+
+        shared_access_key = self._get_field(extra, "shared_access_key")
+        if shared_access_key:
+            # using shared access key
+            return BlobServiceClient(account_url=account_url, 
credential=shared_access_key, **extra)
+
         sas_token = self._get_field(extra, "sas_token")
         if sas_token:
             if sas_token.startswith("https"):
                 return BlobServiceClient(account_url=sas_token, **extra)
             else:
-                return BlobServiceClient(
-                    
account_url=f"https://{conn.login}.blob.core.windows.net/{sas_token}", **extra
-                )
+                return 
BlobServiceClient(account_url=f"{account_url}/{sas_token}", **extra)
 
         # Fall back to old auth (password) or use managed identity if not 
provided.
         credential = conn.password
@@ -215,7 +216,7 @@ class WasbHook(BaseHook):
             credential = DefaultAzureCredential()
             self.log.info("Using DefaultAzureCredential as credential")
         return BlobServiceClient(
-            account_url=f"https://{conn.login}.blob.core.windows.net/",
+            account_url=account_url,
             credential=credential,
             **extra,
         )
@@ -545,13 +546,6 @@ class WasbAsyncHook(WasbHook):
         extra = conn.extra_dejson or {}
         client_secret_auth_config = extra.pop("client_secret_auth_config", {})
 
-        if self.public_read:
-            # Here we use anonymous public read
-            # more info
-            # 
https://docs.microsoft.com/en-us/azure/storage/blobs/storage-manage-access-to-resources
-            self.blob_service_client = 
AsyncBlobServiceClient(account_url=conn.host, **extra)
-            return self.blob_service_client
-
         connection_string = self._get_field(extra, "connection_string")
         if connection_string:
             # connection_string auth takes priority
@@ -560,14 +554,6 @@ class WasbAsyncHook(WasbHook):
             )
             return self.blob_service_client
 
-        shared_access_key = self._get_field(extra, "shared_access_key")
-        if shared_access_key:
-            # using shared access key
-            self.blob_service_client = AsyncBlobServiceClient(
-                account_url=conn.host, credential=shared_access_key, **extra
-            )
-            return self.blob_service_client
-
         tenant = self._get_field(extra, "tenant_id")
         if tenant:
             # use Active Directory auth
@@ -581,13 +567,30 @@ class WasbAsyncHook(WasbHook):
             )
             return self.blob_service_client
 
+        account_url = conn.host if conn.host else 
f"https://{conn.login}.blob.core.windows.net/"
+
+        if self.public_read:
+            # Here we use anonymous public read
+            # more info
+            # 
https://docs.microsoft.com/en-us/azure/storage/blobs/storage-manage-access-to-resources
+            self.blob_service_client = 
AsyncBlobServiceClient(account_url=account_url, **extra)
+            return self.blob_service_client
+
+        shared_access_key = self._get_field(extra, "shared_access_key")
+        if shared_access_key:
+            # using shared access key
+            self.blob_service_client = AsyncBlobServiceClient(
+                account_url=account_url, credential=shared_access_key, **extra
+            )
+            return self.blob_service_client
+
         sas_token = self._get_field(extra, "sas_token")
         if sas_token:
             if sas_token.startswith("https"):
                 self.blob_service_client = 
AsyncBlobServiceClient(account_url=sas_token, **extra)
             else:
                 self.blob_service_client = AsyncBlobServiceClient(
-                    
account_url=f"https://{conn.login}.blob.core.windows.net/{sas_token}", **extra
+                    account_url=f"{account_url}/{sas_token}", **extra
                 )
             return self.blob_service_client
 
@@ -597,7 +600,7 @@ class WasbAsyncHook(WasbHook):
             credential = AsyncDefaultAzureCredential()
             self.log.info("Using DefaultAzureCredential as credential")
         self.blob_service_client = AsyncBlobServiceClient(
-            account_url=f"https://{conn.login}.blob.core.windows.net/",
+            account_url=account_url,
             credential=credential,
             **extra,
         )
diff --git a/docs/apache-airflow-providers-microsoft-azure/connections/wasb.rst 
b/docs/apache-airflow-providers-microsoft-azure/connections/wasb.rst
index ce057e1592..8efdeef362 100644
--- a/docs/apache-airflow-providers-microsoft-azure/connections/wasb.rst
+++ b/docs/apache-airflow-providers-microsoft-azure/connections/wasb.rst
@@ -54,23 +54,31 @@ Configuring the Connection
 --------------------------
 
 Login (optional)
-    Specify the login used for azure blob storage. For use with Shared Key 
Credential and SAS Token authentication.
+    Specify the login used for Azure Blob Storage. Strictly needed for Active 
Directory (token) authentication as Service principal credential. Optional for 
the rest if host (account url) is specified.
 
 Password (optional)
-    Specify the password used for azure blob storage. For use with
+    Specify the password used for Azure Blob Storage. For use with
     Active Directory (token credential) and shared key authentication.
 
 Host (optional)
-    Specify the account url for anonymous public read, Active Directory, 
shared access key authentication.
+    Specify the account url for Azure Blob Storage. Strictly needed for Active 
Directory (token) authentication as Service principal credential. Optional for 
the rest if login (account name) is specified.
+
+Blob Storage Connection String (optional)
+    Connection string for use with connection string authentication.
+
+Blob Storage Shared Access Key (optional)
+    Specify the shared access key. Needed only for shared access key 
authentication.
+
+SAS Token (optional)
+    SAS Token for use with SAS Token authentication.
+
+Tenant Id (Active Directory Auth) (optional)
+    Specify the tenant to use. Required only for Active Directory (token) 
authentication.
 
 Extra (optional)
     Specify the extra parameters (as json dictionary) that can be used in 
Azure connection.
     The following parameters are all optional:
 
-    * ``tenant_id``: Specify the tenant to use. Needed for Active Directory 
(token) authentication.
-    * ``shared_access_key``: Specify the shared access key. Needed for shared 
access key authentication.
-    * ``connection_string``: Connection string for use with connection string 
authentication.
-    * ``sas_token``: SAS Token for use with SAS Token authentication.
     * ``client_secret_auth_config``: Extra config to pass while authenticating 
as a service principal using `ClientSecretCredential 
<https://learn.microsoft.com/en-in/python/api/azure-identity/azure.identity.clientsecretcredential?view=azure-python>`_
 
 When specifying the connection in environment variable you should specify
diff --git a/tests/providers/microsoft/azure/hooks/test_wasb.py 
b/tests/providers/microsoft/azure/hooks/test_wasb.py
index 6d837a2fa2..464db0f39f 100644
--- a/tests/providers/microsoft/azure/hooks/test_wasb.py
+++ b/tests/providers/microsoft/azure/hooks/test_wasb.py
@@ -40,16 +40,19 @@ ACCESS_KEY_STRING = "AccountName=name;skdkskd"
 
 class TestWasbHook:
     def setup_method(self):
-        db.merge_conn(Connection(conn_id="wasb_test_key", conn_type="wasb", 
login="login", password="key"))
+        self.login = "login"
+        self.wasb_test_key = "wasb_test_key"
         self.connection_type = "wasb"
         self.connection_string_id = "azure_test_connection_string"
         self.shared_key_conn_id = "azure_shared_key_test"
+        self.shared_key_conn_id_without_host = 
"azure_shared_key_test_wihout_host"
         self.ad_conn_id = "azure_AD_test"
         self.sas_conn_id = "sas_token_id"
         self.extra__wasb__sas_conn_id = "extra__sas_token_id"
         self.http_sas_conn_id = "http_sas_token_id"
         self.extra__wasb__http_sas_conn_id = "extra__http_sas_token_id"
         self.public_read_conn_id = "pub_read_id"
+        self.public_read_conn_id_without_host = "pub_read_id_without_host"
         self.managed_identity_conn_id = "managed_identity"
         self.authority = "https://test_authority.com"
 
@@ -60,6 +63,14 @@ class TestWasbHook:
             "authority": self.authority,
         }
 
+        db.merge_conn(
+            Connection(
+                conn_id=self.wasb_test_key,
+                conn_type=self.connection_type,
+                login=self.login,
+                password="key",
+            )
+        )
         db.merge_conn(
             Connection(
                 conn_id=self.public_read_conn_id,
@@ -68,7 +79,14 @@ class TestWasbHook:
                 extra=json.dumps({"proxies": self.proxies}),
             )
         )
-
+        db.merge_conn(
+            Connection(
+                conn_id=self.public_read_conn_id_without_host,
+                conn_type=self.connection_type,
+                login=self.login,
+                extra=json.dumps({"proxies": self.proxies}),
+            )
+        )
         db.merge_conn(
             Connection(
                 conn_id=self.connection_string_id,
@@ -84,6 +102,14 @@ class TestWasbHook:
                 extra=json.dumps({"shared_access_key": "token", "proxies": 
self.proxies}),
             )
         )
+        db.merge_conn(
+            Connection(
+                conn_id=self.shared_key_conn_id_without_host,
+                conn_type=self.connection_type,
+                login=self.login,
+                extra=json.dumps({"shared_access_key": "token", "proxies": 
self.proxies}),
+            )
+        )
         db.merge_conn(
             Connection(
                 conn_id=self.ad_conn_id,
@@ -111,6 +137,7 @@ class TestWasbHook:
             Connection(
                 conn_id=self.sas_conn_id,
                 conn_type=self.connection_type,
+                login=self.login,
                 extra=json.dumps({"sas_token": "token", "proxies": 
self.proxies}),
             )
         )
@@ -118,6 +145,7 @@ class TestWasbHook:
             Connection(
                 conn_id=self.extra__wasb__sas_conn_id,
                 conn_type=self.connection_type,
+                login=self.login,
                 extra=json.dumps({"extra__wasb__sas_token": "token", 
"proxies": self.proxies}),
             )
         )
@@ -171,6 +199,23 @@ class TestWasbHook:
         assert isinstance(hook.get_conn(), BlobServiceClient)
         assert isinstance(hook.get_conn().credential, ClientSecretCredential)
 
+    @pytest.mark.parametrize(
+        argnames="conn_id_str",
+        argvalues=[
+            "wasb_test_key",
+            "shared_key_conn_id_without_host",
+            "public_read_conn_id_without_host",
+        ],
+    )
+    def test_account_url_without_host(self, conn_id_str):
+        conn_id = self.__getattribute__(conn_id_str)
+        hook = WasbHook(wasb_conn_id=conn_id)
+        hook_conn = hook.get_connection(hook.conn_id)
+        conn = hook.get_conn()
+        assert conn.url.startswith("https://")
+        assert conn.url.__contains__(hook_conn.login)
+        assert conn.url.endswith(".blob.core.windows.net/")
+
     @pytest.mark.parametrize(
         argnames="conn_id_str, extra_key",
         argvalues=[
@@ -187,6 +232,9 @@ class TestWasbHook:
         hook_conn = hook.get_connection(hook.conn_id)
         sas_token = hook_conn.extra_dejson[extra_key]
         assert isinstance(conn, BlobServiceClient)
+        assert conn.url.startswith("https://")
+        if hook_conn.login:
+            assert conn.url.__contains__(hook_conn.login)
         assert conn.url.endswith(sas_token + "/")
 
     @pytest.mark.parametrize(
@@ -459,4 +507,5 @@ class TestWasbHook:
             "extra__wasb__tenant_id",
             "extra__wasb__shared_access_key",
             "extra__wasb__sas_token",
+            "extra",
         ]

Reply via email to