This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 61d0093 Added sas_token var to BlobServiceClient return. Updated tests (#19234)
61d0093 is described below
commit 61d009305478e76e53aaf43ce07a181ebbd259d3
Author: Rocco Pascale <[email protected]>
AuthorDate: Wed Oct 27 06:50:15 2021 -0400
Added sas_token var to BlobServiceClient return. Updated tests (#19234)
---
airflow/providers/microsoft/azure/hooks/wasb.py | 2 +-
tests/providers/microsoft/azure/hooks/test_wasb.py | 44 ++++++++++++++++++++--
2 files changed, 42 insertions(+), 4 deletions(-)
diff --git a/airflow/providers/microsoft/azure/hooks/wasb.py b/airflow/providers/microsoft/azure/hooks/wasb.py
index bef2844..75dcf03 100644
--- a/airflow/providers/microsoft/azure/hooks/wasb.py
+++ b/airflow/providers/microsoft/azure/hooks/wasb.py
@@ -138,7 +138,7 @@ class WasbHook(BaseHook):
return BlobServiceClient(account_url=conn.host,
credential=token_credential)
sas_token = extra.get('sas_token') or
extra.get('extra__wasb__sas_token')
if sas_token and sas_token.startswith('https'):
- return BlobServiceClient(account_url=extra.get('sas_token'))
+ return BlobServiceClient(account_url=sas_token)
if sas_token and not sas_token.startswith('https'):
return
BlobServiceClient(account_url=f"https://{conn.login}.blob.core.windows.net/" +
sas_token)
diff --git a/tests/providers/microsoft/azure/hooks/test_wasb.py b/tests/providers/microsoft/azure/hooks/test_wasb.py
index 12f7462..ac8658e 100644
--- a/tests/providers/microsoft/azure/hooks/test_wasb.py
+++ b/tests/providers/microsoft/azure/hooks/test_wasb.py
@@ -25,6 +25,7 @@ from unittest import mock
import pytest
from azure.identity import ManagedIdentityCredential
from azure.storage.blob import BlobServiceClient
+from parameterized import parameterized
from airflow.exceptions import AirflowException
from airflow.models import Connection
@@ -47,6 +48,9 @@ class TestWasbHook(unittest.TestCase):
self.shared_key_conn_id = 'azure_shared_key_test'
self.ad_conn_id = 'azure_AD_test'
self.sas_conn_id = 'sas_token_id'
+ self.extra__wasb__sas_conn_id = 'extra__sas_token_id'
+ self.http_sas_conn_id = 'http_sas_token_id'
+ self.extra__wasb__http_sas_conn_id = 'extra__http_sas_token_id'
self.public_read_conn_id = 'pub_read_id'
self.managed_identity_conn_id = 'managed_identity'
@@ -95,6 +99,27 @@ class TestWasbHook(unittest.TestCase):
extra=json.dumps({'sas_token': 'token'}),
)
)
+ db.merge_conn(
+ Connection(
+ conn_id=self.extra__wasb__sas_conn_id,
+ conn_type=self.connection_type,
+ extra=json.dumps({'extra__wasb__sas_token': 'token'}),
+ )
+ )
+ db.merge_conn(
+ Connection(
+ conn_id=self.http_sas_conn_id,
+ conn_type=self.connection_type,
+ extra=json.dumps({'sas_token':
'https://login.blob.core.windows.net/token'}),
+ )
+ )
+ db.merge_conn(
+ Connection(
+ conn_id=self.extra__wasb__http_sas_conn_id,
+ conn_type=self.connection_type,
+ extra=json.dumps({'extra__wasb__sas_token':
'https://login.blob.core.windows.net/token'}),
+ )
+ )
def test_key(self):
hook = WasbHook(wasb_conn_id='wasb_test_key')
@@ -119,9 +144,22 @@ class TestWasbHook(unittest.TestCase):
self.assertIsInstance(hook.get_conn(), BlobServiceClient)
self.assertIsInstance(hook.get_conn().credential,
ManagedIdentityCredential)
- def test_sas_token_connection(self):
- hook = WasbHook(wasb_conn_id=self.sas_conn_id)
- assert isinstance(hook.get_conn(), BlobServiceClient)
+ @parameterized.expand(
+ [
+ ('sas_conn_id', 'sas_token'),
+ ('extra__wasb__sas_conn_id', 'extra__wasb__sas_token'),
+ ('http_sas_conn_id', 'sas_token'),
+ ('extra__wasb__http_sas_conn_id', 'extra__wasb__sas_token'),
+ ],
+ )
+ def test_sas_token_connection(self, conn_id_str, extra_key):
+ conn_id = self.__getattribute__(conn_id_str)
+ hook = WasbHook(wasb_conn_id=conn_id)
+ conn = hook.get_conn()
+ hook_conn = hook.get_connection(hook.conn_id)
+ sas_token = hook_conn.extra_dejson[extra_key]
+ assert isinstance(conn, BlobServiceClient)
+ assert conn.url.endswith(sas_token + '/')
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
def test_check_for_blob(self, mock_service):