This is an automated email from the ASF dual-hosted git repository.

pankaj pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 23b15e6428 feat(providers/microsoft): add DefaultAzureCredential 
support to AzureContainerVolumeHook (#33822)
23b15e6428 is described below

commit 23b15e64284261211cfbcb9eaaf76c0c6a0be547
Author: Wei Lee <[email protected]>
AuthorDate: Wed Aug 30 17:59:21 2023 +0800

    feat(providers/microsoft): add DefaultAzureCredential support to 
AzureContainerVolumeHook (#33822)
    
    * feat(providers/microsoft): add DefaultAzureCredential support to 
AzureContainerVolumeHook
    
    * feat(providers/microsoft): pin azure-mgmt-storage to >= 16.0.0
    
    * docs(providers/microsoft): update connection documentation for Azure 
Container Volume connection
---
 .../microsoft/azure/hooks/container_volume.py      | 31 +++++++++++++--
 airflow/providers/microsoft/azure/provider.yaml    |  5 +--
 .../connections/azure_container_volume.rst         |  2 +
 generated/provider_dependencies.json               |  1 +
 .../azure/hooks/test_azure_container_volume.py     | 46 ++++++++++++++++++++++
 5 files changed, 78 insertions(+), 7 deletions(-)

diff --git a/airflow/providers/microsoft/azure/hooks/container_volume.py 
b/airflow/providers/microsoft/azure/hooks/container_volume.py
index fd8f3fa4c7..3aa58415a4 100644
--- a/airflow/providers/microsoft/azure/hooks/container_volume.py
+++ b/airflow/providers/microsoft/azure/hooks/container_volume.py
@@ -19,7 +19,9 @@ from __future__ import annotations
 
 from typing import Any
 
+from azure.identity import DefaultAzureCredential
 from azure.mgmt.containerinstance.models import AzureFileVolume, Volume
+from azure.mgmt.storage import StorageManagementClient
 
 from airflow.hooks.base import BaseHook
 from airflow.providers.microsoft.azure.utils import get_field
@@ -54,14 +56,22 @@ class AzureContainerVolumeHook(BaseHook):
     @staticmethod
     def get_connection_form_widgets() -> dict[str, Any]:
         """Returns connection widgets to add to connection form."""
-        from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget
+        from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, 
BS3TextFieldWidget
         from flask_babel import lazy_gettext
-        from wtforms import PasswordField
+        from wtforms import PasswordField, StringField
 
         return {
             "connection_string": PasswordField(
                 lazy_gettext("Blob Storage Connection String (optional)"), 
widget=BS3PasswordFieldWidget()
             ),
+            "subscription_id": StringField(
+                lazy_gettext("Subscription ID (optional)"),
+                widget=BS3TextFieldWidget(),
+            ),
+            "resource_group": StringField(
+                lazy_gettext("Resource group name (optional)"),
+                widget=BS3TextFieldWidget(),
+            ),
         }
 
     @staticmethod
@@ -77,10 +87,12 @@ class AzureContainerVolumeHook(BaseHook):
                 "login": "client_id (token credentials auth)",
                 "password": "secret (token credentials auth)",
                 "connection_string": "connection string auth",
+                "subscription_id": "Subscription id (required for Azure AD 
authentication)",
+                "resource_group": "Resource group name (required for Azure AD 
authentication)",
             },
         }
 
-    def get_storagekey(self) -> str:
+    def get_storagekey(self, *, storage_account_name: str | None = None) -> 
str:
         """Get Azure File Volume storage key."""
         conn = self.get_connection(self.conn_id)
         extras = conn.extra_dejson
@@ -90,6 +102,17 @@ class AzureContainerVolumeHook(BaseHook):
                 key, value = keyvalue.split("=", 1)
                 if key == "AccountKey":
                     return value
+
+        subscription_id = self._get_field(extras, "subscription_id")
+        resource_group = self._get_field(extras, "resource_group")
+        if subscription_id and storage_account_name and resource_group:
+            credentials = DefaultAzureCredential()
+            storage_client = StorageManagementClient(credentials, 
subscription_id)
+            storage_account_list_keys_result = 
storage_client.storage_accounts.list_keys(
+                resource_group, storage_account_name
+            )
+            return 
storage_account_list_keys_result.as_dict()["keys"][0]["value"]
+
         return conn.password
 
     def get_file_volume(
@@ -102,6 +125,6 @@ class AzureContainerVolumeHook(BaseHook):
                 share_name=share_name,
                 storage_account_name=storage_account_name,
                 read_only=read_only,
-                storage_account_key=self.get_storagekey(),
+                
storage_account_key=self.get_storagekey(storage_account_name=storage_account_name),
             ),
         )
diff --git a/airflow/providers/microsoft/azure/provider.yaml 
b/airflow/providers/microsoft/azure/provider.yaml
index 2c9868bd84..ea481df6b5 100644
--- a/airflow/providers/microsoft/azure/provider.yaml
+++ b/airflow/providers/microsoft/azure/provider.yaml
@@ -14,12 +14,11 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
 ---
 package-name: apache-airflow-providers-microsoft-azure
 name: Microsoft Azure
 description: |
-    `Microsoft Azure <https://azure.microsoft.com/>`__
+  `Microsoft Azure <https://azure.microsoft.com/>`__
 suspended: false
 versions:
   - 6.3.0
@@ -76,6 +75,7 @@ dependencies:
   - azure-storage-blob>=12.14.0
   - azure-storage-common>=2.1.0
   - azure-storage-file>=2.1.0
+  - azure-mgmt-storage>=16.0.0
   - azure-servicebus>=7.6.1
   - azure-synapse-spark
   - adal>=1.2.7
@@ -258,7 +258,6 @@ transfers:
     how-to-guide: 
/docs/apache-airflow-providers-microsoft-azure/transfer/azure_blob_to_gcs.rst
     python-module: 
airflow.providers.microsoft.azure.transfers.azure_blob_to_gcs
 
-
 connection-types:
   - hook-class-name: 
airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook
     connection-type: azure
diff --git 
a/docs/apache-airflow-providers-microsoft-azure/connections/azure_container_volume.rst
 
b/docs/apache-airflow-providers-microsoft-azure/connections/azure_container_volume.rst
index e3fe8efe76..2490010d99 100644
--- 
a/docs/apache-airflow-providers-microsoft-azure/connections/azure_container_volume.rst
+++ 
b/docs/apache-airflow-providers-microsoft-azure/connections/azure_container_volume.rst
@@ -62,6 +62,8 @@ Extra (optional)
     The following parameters are all optional:
 
     * ``connection_string``: Connection string for use with connection string 
authentication.
+    * ``subscription_id``: The ID of the subscription used for the initial 
connection. This is needed for Azure Active Directory (Azure AD) authentication.
+    * ``resource_group``: Azure Resource Group Name under which the desired 
Azure file volume resides. This is needed for Azure Active Directory (Azure AD) 
authentication.
 
 When specifying the connection in environment variable you should specify
 it using URI syntax.
diff --git a/generated/provider_dependencies.json 
b/generated/provider_dependencies.json
index f9bcf95be2..ed7578f178 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -564,6 +564,7 @@
       "azure-mgmt-datafactory>=1.0.0,<2.0",
       "azure-mgmt-datalake-store>=0.5.0",
       "azure-mgmt-resource>=2.2.0",
+      "azure-mgmt-storage>=16.0.0",
       "azure-servicebus>=7.6.1",
       "azure-storage-blob>=12.14.0",
       "azure-storage-common>=2.1.0",
diff --git 
a/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py 
b/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py
index b4c7b8d1c7..49a4ed9550 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py
@@ -17,6 +17,8 @@
 # under the License.
 from __future__ import annotations
 
+from unittest import mock
+
 import pytest
 
 from airflow.models import Connection
@@ -71,6 +73,48 @@ class TestAzureContainerVolumeHook:
         assert volume.azure_file.storage_account_name == "storage"
         assert volume.azure_file.read_only is True
 
+    @pytest.mark.parametrize(
+        "mocked_connection",
+        [
+            Connection(
+                conn_id="azure_container_volume_test_default_azure-credential",
+                conn_type="wasb",
+                login="",
+                password="",
+                extra={"subscription_id": "subscription_id", "resource_group": 
"resource_group"},
+            )
+        ],
+        indirect=True,
+    )
+    
@mock.patch("airflow.providers.microsoft.azure.hooks.container_volume.StorageManagementClient")
+    
@mock.patch("airflow.providers.microsoft.azure.hooks.container_volume.DefaultAzureCredential")
+    def test_get_file_volume_default_azure_credential(
+        self, mocked_default_azure_credential, mocked_client, mocked_connection
+    ):
+        
mocked_client.return_value.storage_accounts.list_keys.return_value.as_dict.return_value
 = {
+            "keys": [
+                {
+                    "key_name": "key1",
+                    "value": "value",
+                    "permissions": "FULL",
+                    "creation_time": "2023-07-13T16:16:10.474107Z",
+                },
+            ]
+        }
+
+        hook = 
AzureContainerVolumeHook(azure_container_volume_conn_id=mocked_connection.conn_id)
+        volume = hook.get_file_volume(
+            mount_name="mount", share_name="share", 
storage_account_name="storage", read_only=True
+        )
+        assert volume is not None
+        assert volume.name == "mount"
+        assert volume.azure_file.share_name == "share"
+        assert volume.azure_file.storage_account_key == "value"
+        assert volume.azure_file.storage_account_name == "storage"
+        assert volume.azure_file.read_only is True
+
+        mocked_default_azure_credential.assert_called_with()
+
     def test_get_ui_field_behaviour_placeholders(self):
         """
         Check that ensure_prefixes decorator working properly
@@ -81,6 +125,8 @@ class TestAzureContainerVolumeHook:
             "login",
             "password",
             "connection_string",
+            "subscription_id",
+            "resource_group",
         ]
         if 
get_provider_min_airflow_version("apache-airflow-providers-microsoft-azure") >= 
(2, 5):
             raise Exception(

Reply via email to