This is an automated email from the ASF dual-hosted git repository.
pankaj pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 23b15e6428 feat(providers/microsoft): add DefaultAzureCredential
support to AzureContainerVolumeHook (#33822)
23b15e6428 is described below
commit 23b15e64284261211cfbcb9eaaf76c0c6a0be547
Author: Wei Lee <[email protected]>
AuthorDate: Wed Aug 30 17:59:21 2023 +0800
feat(providers/microsoft): add DefaultAzureCredential support to
AzureContainerVolumeHook (#33822)
* feat(providers/microsoft): add DefaultAzureCredential support to
AzureContainerVolumeHook
* feat(providers/microsoft): pin azure-mgmt-storage to >= 16.0.0
* docs(providers/microsoft): update connection documentation for Azure
Container Volume connection
---
.../microsoft/azure/hooks/container_volume.py | 31 +++++++++++++--
airflow/providers/microsoft/azure/provider.yaml | 5 +--
.../connections/azure_container_volume.rst | 2 +
generated/provider_dependencies.json | 1 +
.../azure/hooks/test_azure_container_volume.py | 46 ++++++++++++++++++++++
5 files changed, 78 insertions(+), 7 deletions(-)
diff --git a/airflow/providers/microsoft/azure/hooks/container_volume.py
b/airflow/providers/microsoft/azure/hooks/container_volume.py
index fd8f3fa4c7..3aa58415a4 100644
--- a/airflow/providers/microsoft/azure/hooks/container_volume.py
+++ b/airflow/providers/microsoft/azure/hooks/container_volume.py
@@ -19,7 +19,9 @@ from __future__ import annotations
from typing import Any
+from azure.identity import DefaultAzureCredential
from azure.mgmt.containerinstance.models import AzureFileVolume, Volume
+from azure.mgmt.storage import StorageManagementClient
from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import get_field
@@ -54,14 +56,22 @@ class AzureContainerVolumeHook(BaseHook):
@staticmethod
def get_connection_form_widgets() -> dict[str, Any]:
"""Returns connection widgets to add to connection form."""
- from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget
+ from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget,
BS3TextFieldWidget
from flask_babel import lazy_gettext
- from wtforms import PasswordField
+ from wtforms import PasswordField, StringField
return {
"connection_string": PasswordField(
lazy_gettext("Blob Storage Connection String (optional)"),
widget=BS3PasswordFieldWidget()
),
+ "subscription_id": StringField(
+ lazy_gettext("Subscription ID (optional)"),
+ widget=BS3TextFieldWidget(),
+ ),
+ "resource_group": StringField(
+ lazy_gettext("Resource group name (optional)"),
+ widget=BS3TextFieldWidget(),
+ ),
}
@staticmethod
@@ -77,10 +87,12 @@ class AzureContainerVolumeHook(BaseHook):
"login": "client_id (token credentials auth)",
"password": "secret (token credentials auth)",
"connection_string": "connection string auth",
+ "subscription_id": "Subscription id (required for Azure AD
authentication)",
+ "resource_group": "Resource group name (required for Azure AD
authentication)",
},
}
- def get_storagekey(self) -> str:
+ def get_storagekey(self, *, storage_account_name: str | None = None) ->
str:
"""Get Azure File Volume storage key."""
conn = self.get_connection(self.conn_id)
extras = conn.extra_dejson
@@ -90,6 +102,17 @@ class AzureContainerVolumeHook(BaseHook):
key, value = keyvalue.split("=", 1)
if key == "AccountKey":
return value
+
+ subscription_id = self._get_field(extras, "subscription_id")
+ resource_group = self._get_field(extras, "resource_group")
+ if subscription_id and storage_account_name and resource_group:
+ credentials = DefaultAzureCredential()
+ storage_client = StorageManagementClient(credentials,
subscription_id)
+ storage_account_list_keys_result =
storage_client.storage_accounts.list_keys(
+ resource_group, storage_account_name
+ )
+ return
storage_account_list_keys_result.as_dict()["keys"][0]["value"]
+
return conn.password
def get_file_volume(
@@ -102,6 +125,6 @@ class AzureContainerVolumeHook(BaseHook):
share_name=share_name,
storage_account_name=storage_account_name,
read_only=read_only,
- storage_account_key=self.get_storagekey(),
+
storage_account_key=self.get_storagekey(storage_account_name=storage_account_name),
),
)
diff --git a/airflow/providers/microsoft/azure/provider.yaml
b/airflow/providers/microsoft/azure/provider.yaml
index 2c9868bd84..ea481df6b5 100644
--- a/airflow/providers/microsoft/azure/provider.yaml
+++ b/airflow/providers/microsoft/azure/provider.yaml
@@ -14,12 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
---
package-name: apache-airflow-providers-microsoft-azure
name: Microsoft Azure
description: |
- `Microsoft Azure <https://azure.microsoft.com/>`__
+ `Microsoft Azure <https://azure.microsoft.com/>`__
suspended: false
versions:
- 6.3.0
@@ -76,6 +75,7 @@ dependencies:
- azure-storage-blob>=12.14.0
- azure-storage-common>=2.1.0
- azure-storage-file>=2.1.0
+ - azure-mgmt-storage>=16.0.0
- azure-servicebus>=7.6.1
- azure-synapse-spark
- adal>=1.2.7
@@ -258,7 +258,6 @@ transfers:
how-to-guide:
/docs/apache-airflow-providers-microsoft-azure/transfer/azure_blob_to_gcs.rst
python-module:
airflow.providers.microsoft.azure.transfers.azure_blob_to_gcs
-
connection-types:
- hook-class-name:
airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook
connection-type: azure
diff --git
a/docs/apache-airflow-providers-microsoft-azure/connections/azure_container_volume.rst
b/docs/apache-airflow-providers-microsoft-azure/connections/azure_container_volume.rst
index e3fe8efe76..2490010d99 100644
---
a/docs/apache-airflow-providers-microsoft-azure/connections/azure_container_volume.rst
+++
b/docs/apache-airflow-providers-microsoft-azure/connections/azure_container_volume.rst
@@ -62,6 +62,8 @@ Extra (optional)
The following parameters are all optional:
* ``connection_string``: Connection string for use with connection string
authentication.
+ * ``subscription_id``: The ID of the subscription used for the initial
connection. This is needed for Azure Active Directory (Azure AD) authentication.
+ * ``resource_group``: Azure Resource Group Name under which the desired
Azure file volume resides. This is needed for Azure Active Directory (Azure AD)
authentication.
When specifying the connection in environment variable you should specify
it using URI syntax.
diff --git a/generated/provider_dependencies.json
b/generated/provider_dependencies.json
index f9bcf95be2..ed7578f178 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -564,6 +564,7 @@
"azure-mgmt-datafactory>=1.0.0,<2.0",
"azure-mgmt-datalake-store>=0.5.0",
"azure-mgmt-resource>=2.2.0",
+ "azure-mgmt-storage>=16.0.0",
"azure-servicebus>=7.6.1",
"azure-storage-blob>=12.14.0",
"azure-storage-common>=2.1.0",
diff --git
a/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py
b/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py
index b4c7b8d1c7..49a4ed9550 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py
@@ -17,6 +17,8 @@
# under the License.
from __future__ import annotations
+from unittest import mock
+
import pytest
from airflow.models import Connection
@@ -71,6 +73,48 @@ class TestAzureContainerVolumeHook:
assert volume.azure_file.storage_account_name == "storage"
assert volume.azure_file.read_only is True
+ @pytest.mark.parametrize(
+ "mocked_connection",
+ [
+ Connection(
+ conn_id="azure_container_volume_test_default_azure-credential",
+ conn_type="wasb",
+ login="",
+ password="",
+ extra={"subscription_id": "subscription_id", "resource_group":
"resource_group"},
+ )
+ ],
+ indirect=True,
+ )
+
@mock.patch("airflow.providers.microsoft.azure.hooks.container_volume.StorageManagementClient")
+
@mock.patch("airflow.providers.microsoft.azure.hooks.container_volume.DefaultAzureCredential")
+ def test_get_file_volume_default_azure_credential(
+ self, mocked_default_azure_credential, mocked_client, mocked_connection
+ ):
+
mocked_client.return_value.storage_accounts.list_keys.return_value.as_dict.return_value
= {
+ "keys": [
+ {
+ "key_name": "key1",
+ "value": "value",
+ "permissions": "FULL",
+ "creation_time": "2023-07-13T16:16:10.474107Z",
+ },
+ ]
+ }
+
+ hook =
AzureContainerVolumeHook(azure_container_volume_conn_id=mocked_connection.conn_id)
+ volume = hook.get_file_volume(
+ mount_name="mount", share_name="share",
storage_account_name="storage", read_only=True
+ )
+ assert volume is not None
+ assert volume.name == "mount"
+ assert volume.azure_file.share_name == "share"
+ assert volume.azure_file.storage_account_key == "value"
+ assert volume.azure_file.storage_account_name == "storage"
+ assert volume.azure_file.read_only is True
+
+ mocked_default_azure_credential.assert_called_with()
+
def test_get_ui_field_behaviour_placeholders(self):
"""
Check that ensure_prefixes decorator working properly
@@ -81,6 +125,8 @@ class TestAzureContainerVolumeHook:
"login",
"password",
"connection_string",
+ "subscription_id",
+ "resource_group",
]
if
get_provider_min_airflow_version("apache-airflow-providers-microsoft-azure") >=
(2, 5):
raise Exception(