This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 61d0093  Added sas_token var to BlobServiceClient return. Updated tests (#19234)
61d0093 is described below

commit 61d009305478e76e53aaf43ce07a181ebbd259d3
Author: Rocco Pascale <[email protected]>
AuthorDate: Wed Oct 27 06:50:15 2021 -0400

    Added sas_token var to BlobServiceClient return. Updated tests (#19234)
---
 airflow/providers/microsoft/azure/hooks/wasb.py    |  2 +-
 tests/providers/microsoft/azure/hooks/test_wasb.py | 44 ++++++++++++++++++++--
 2 files changed, 42 insertions(+), 4 deletions(-)

diff --git a/airflow/providers/microsoft/azure/hooks/wasb.py b/airflow/providers/microsoft/azure/hooks/wasb.py
index bef2844..75dcf03 100644
--- a/airflow/providers/microsoft/azure/hooks/wasb.py
+++ b/airflow/providers/microsoft/azure/hooks/wasb.py
@@ -138,7 +138,7 @@ class WasbHook(BaseHook):
             return BlobServiceClient(account_url=conn.host, credential=token_credential)
         sas_token = extra.get('sas_token') or extra.get('extra__wasb__sas_token')
         if sas_token and sas_token.startswith('https'):
-            return BlobServiceClient(account_url=extra.get('sas_token'))
+            return BlobServiceClient(account_url=sas_token)
         if sas_token and not sas_token.startswith('https'):
             return BlobServiceClient(account_url=f"https://{conn.login}.blob.core.windows.net/" + sas_token)
 
diff --git a/tests/providers/microsoft/azure/hooks/test_wasb.py b/tests/providers/microsoft/azure/hooks/test_wasb.py
index 12f7462..ac8658e 100644
--- a/tests/providers/microsoft/azure/hooks/test_wasb.py
+++ b/tests/providers/microsoft/azure/hooks/test_wasb.py
@@ -25,6 +25,7 @@ from unittest import mock
 import pytest
 from azure.identity import ManagedIdentityCredential
 from azure.storage.blob import BlobServiceClient
+from parameterized import parameterized
 
 from airflow.exceptions import AirflowException
 from airflow.models import Connection
@@ -47,6 +48,9 @@ class TestWasbHook(unittest.TestCase):
         self.shared_key_conn_id = 'azure_shared_key_test'
         self.ad_conn_id = 'azure_AD_test'
         self.sas_conn_id = 'sas_token_id'
+        self.extra__wasb__sas_conn_id = 'extra__sas_token_id'
+        self.http_sas_conn_id = 'http_sas_token_id'
+        self.extra__wasb__http_sas_conn_id = 'extra__http_sas_token_id'
         self.public_read_conn_id = 'pub_read_id'
         self.managed_identity_conn_id = 'managed_identity'
 
@@ -95,6 +99,27 @@ class TestWasbHook(unittest.TestCase):
                 extra=json.dumps({'sas_token': 'token'}),
             )
         )
+        db.merge_conn(
+            Connection(
+                conn_id=self.extra__wasb__sas_conn_id,
+                conn_type=self.connection_type,
+                extra=json.dumps({'extra__wasb__sas_token': 'token'}),
+            )
+        )
+        db.merge_conn(
+            Connection(
+                conn_id=self.http_sas_conn_id,
+                conn_type=self.connection_type,
+                extra=json.dumps({'sas_token': 'https://login.blob.core.windows.net/token'}),
+            )
+        )
+        db.merge_conn(
+            Connection(
+                conn_id=self.extra__wasb__http_sas_conn_id,
+                conn_type=self.connection_type,
+                extra=json.dumps({'extra__wasb__sas_token': 'https://login.blob.core.windows.net/token'}),
+            )
+        )
 
     def test_key(self):
         hook = WasbHook(wasb_conn_id='wasb_test_key')
@@ -119,9 +144,22 @@ class TestWasbHook(unittest.TestCase):
         self.assertIsInstance(hook.get_conn(), BlobServiceClient)
         self.assertIsInstance(hook.get_conn().credential, ManagedIdentityCredential)
 
-    def test_sas_token_connection(self):
-        hook = WasbHook(wasb_conn_id=self.sas_conn_id)
-        assert isinstance(hook.get_conn(), BlobServiceClient)
+    @parameterized.expand(
+        [
+            ('sas_conn_id', 'sas_token'),
+            ('extra__wasb__sas_conn_id', 'extra__wasb__sas_token'),
+            ('http_sas_conn_id', 'sas_token'),
+            ('extra__wasb__http_sas_conn_id', 'extra__wasb__sas_token'),
+        ],
+    )
+    def test_sas_token_connection(self, conn_id_str, extra_key):
+        conn_id = self.__getattribute__(conn_id_str)
+        hook = WasbHook(wasb_conn_id=conn_id)
+        conn = hook.get_conn()
+        hook_conn = hook.get_connection(hook.conn_id)
+        sas_token = hook_conn.extra_dejson[extra_key]
+        assert isinstance(conn, BlobServiceClient)
+        assert conn.url.endswith(sas_token + '/')
 
     
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
     def test_check_for_blob(self, mock_service):

Reply via email to