This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 76c2ade2c6 feat(providers/microsoft): add DefaultAzureCredential to 
data_lake (#33433)
76c2ade2c6 is described below

commit 76c2ade2c63abc3677b8fcd59af6f8779b613be7
Author: Wei Lee <[email protected]>
AuthorDate: Mon Aug 28 18:39:59 2023 +0800

    feat(providers/microsoft): add DefaultAzureCredential to data_lake (#33433)
---
 .../providers/microsoft/azure/hooks/data_lake.py   | 31 ++++++++-----
 .../microsoft/azure/hooks/test_azure_data_lake.py  | 53 ++++++++++++++++++++++
 2 files changed, 73 insertions(+), 11 deletions(-)

diff --git a/airflow/providers/microsoft/azure/hooks/data_lake.py 
b/airflow/providers/microsoft/azure/hooks/data_lake.py
index 3849727e86..0b344a41ad 100644
--- a/airflow/providers/microsoft/azure/hooks/data_lake.py
+++ b/airflow/providers/microsoft/azure/hooks/data_lake.py
@@ -18,7 +18,7 @@
 from __future__ import annotations
 
 from functools import cached_property
-from typing import Any
+from typing import Any, Union
 
 from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
 from azure.datalake.store import core, lib, multithread
@@ -34,7 +34,9 @@ from azure.storage.filedatalake import (
 
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
-from airflow.providers.microsoft.azure.utils import get_field
+from airflow.providers.microsoft.azure.utils import 
AzureIdentityCredentialAdapter, get_field
+
+Credentials = Union[ClientSecretCredential, AzureIdentityCredentialAdapter]
 
 
 class AzureDataLakeHook(BaseHook):
@@ -110,9 +112,14 @@ class AzureDataLakeHook(BaseHook):
             conn = self.get_connection(self.conn_id)
             extras = conn.extra_dejson
             self.account_name = self._get_field(extras, "account_name")
+
+            credential: Credentials
             tenant = self._get_field(extras, "tenant")
-            adl_creds = lib.auth(tenant_id=tenant, 
client_secret=conn.password, client_id=conn.login)
-            self._conn = core.AzureDLFileSystem(adl_creds, 
store_name=self.account_name)
+            if tenant:
+                credential = lib.auth(tenant_id=tenant, 
client_secret=conn.password, client_id=conn.login)
+            else:
+                credential = AzureIdentityCredentialAdapter()
+            self._conn = core.AzureDLFileSystem(credential, 
store_name=self.account_name)
             self._conn.connect()
         return self._conn
 
@@ -313,20 +320,22 @@ class AzureDataLakeStorageV2Hook(BaseHook):
             # connection_string auth takes priority
             return 
DataLakeServiceClient.from_connection_string(connection_string, **extra)
 
+        credential: Credentials
         tenant = self._get_field(extra, "tenant_id")
         if tenant:
             # use Active Directory auth
             app_id = conn.login
             app_secret = conn.password
-            token_credential = ClientSecretCredential(tenant, app_id, 
app_secret)
-            return DataLakeServiceClient(
-                account_url=f"https://{conn.host}.dfs.core.windows.net", 
credential=token_credential, **extra
-            )
+            credential = ClientSecretCredential(tenant, app_id, app_secret)
+        elif conn.password:
+            credential = conn.password
+        else:
+            credential = AzureIdentityCredentialAdapter()
 
-        # otherwise, use key auth
-        credential = conn.password
         return DataLakeServiceClient(
-            account_url=f"https://{conn.host}.dfs.core.windows.net", 
credential=credential, **extra
+            account_url=f"https://{conn.host}.dfs.core.windows.net",
+            credential=credential,  # type: ignore[arg-type]
+            **extra,
         )
 
     def _get_field(self, extra_dict, field_name):
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py 
b/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py
index f5e2e8be5c..e1412ca4a4 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py
@@ -26,6 +26,34 @@ from azure.storage.filedatalake._models import 
FileSystemProperties
 from airflow.models import Connection
 from airflow.providers.microsoft.azure.hooks.data_lake import 
AzureDataLakeStorageV2Hook
 
+MODULE = "airflow.providers.microsoft.azure.hooks.data_lake"
+
+
[email protected]
+def connection_without_tenant(create_mock_connections):
+    create_mock_connections(
+        Connection(
+            conn_id="adl_test_key_without_tenant",
+            conn_type="azure_data_lake",
+            login="client_id",
+            password="client secret",
+            extra={"account_name": "accountname"},
+        )
+    )
+
+
[email protected]
+def connection(create_mock_connections):
+    create_mock_connections(
+        Connection(
+            conn_id="adl_test_key",
+            conn_type="azure_data_lake",
+            login="client_id",
+            password="client secret",
+            extra={"tenant": "tenant", "account_name": "accountname"},
+        )
+    )
+
 
 class TestAzureDataLakeHook:
     @pytest.fixture(autouse=True)
@@ -52,6 +80,26 @@ class TestAzureDataLakeHook:
         assert isinstance(hook.get_conn(), core.AzureDLFileSystem)
         assert mock_lib.auth.called
 
+    @pytest.mark.usefixtures("connection_without_tenant")
+    @mock.patch(f"{MODULE}.lib")
+    @mock.patch(f"{MODULE}.AzureIdentityCredentialAdapter")
+    def 
test_fallback_to_azure_identity_credential_adppter_when_tenant_is_not_provided(
+        self,
+        mock_azure_identity_credential_adapter,
+        mock_datalake_store_lib,
+    ):
+        from azure.datalake.store import core
+
+        from airflow.providers.microsoft.azure.hooks.data_lake import 
AzureDataLakeHook
+
+        hook = 
AzureDataLakeHook(azure_data_lake_conn_id="adl_test_key_without_tenant")
+        assert hook._conn is None
+        assert hook.conn_id == "adl_test_key_without_tenant"
+        assert isinstance(hook.get_conn(), core.AzureDLFileSystem)
+        assert mock_azure_identity_credential_adapter.called
+        assert not mock_datalake_store_lib.auth.called
+
+    @pytest.mark.usefixtures("connection")
     
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.core.AzureDLFileSystem",
 autospec=True)
     @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.lib", 
autospec=True)
     def test_check_for_blob(self, mock_lib, mock_filesystem):
@@ -62,6 +110,7 @@ class TestAzureDataLakeHook:
         hook.check_for_file("file_path")
         mocked_glob.assert_called()
 
+    @pytest.mark.usefixtures("connection")
     
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.multithread.ADLUploader",
 autospec=True)
     @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.lib", 
autospec=True)
     def test_upload_file(self, mock_lib, mock_uploader):
@@ -86,6 +135,7 @@ class TestAzureDataLakeHook:
             blocksize=4194304,
         )
 
+    @pytest.mark.usefixtures("connection")
     
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.multithread.ADLDownloader",
 autospec=True)
     @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.lib", 
autospec=True)
     def test_download_file(self, mock_lib, mock_downloader):
@@ -110,6 +160,7 @@ class TestAzureDataLakeHook:
             blocksize=4194304,
         )
 
+    @pytest.mark.usefixtures("connection")
     
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.core.AzureDLFileSystem",
 autospec=True)
     @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.lib", 
autospec=True)
     def test_list_glob(self, mock_lib, mock_fs):
@@ -119,6 +170,7 @@ class TestAzureDataLakeHook:
         hook.list("file_path/*")
         mock_fs.return_value.glob.assert_called_once_with("file_path/*")
 
+    @pytest.mark.usefixtures("connection")
     
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.core.AzureDLFileSystem",
 autospec=True)
     @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.lib", 
autospec=True)
     def test_list_walk(self, mock_lib, mock_fs):
@@ -128,6 +180,7 @@ class TestAzureDataLakeHook:
         hook.list("file_path/some_folder/")
         
mock_fs.return_value.walk.assert_called_once_with("file_path/some_folder/")
 
+    @pytest.mark.usefixtures("connection")
     
@mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.core.AzureDLFileSystem",
 autospec=True)
     @mock.patch("airflow.providers.microsoft.azure.hooks.data_lake.lib", 
autospec=True)
     def test_remove(self, mock_lib, mock_fs):

Reply via email to