This is an automated email from the ASF dual-hosted git repository.

pankaj pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 539797fdfb Add DefaultAzureCredential support to 
AzureContainerRegistryHook (#33825)
539797fdfb is described below

commit 539797fdfb2e0b2aca82376095e74edaad775439
Author: Wei Lee <[email protected]>
AuthorDate: Wed Aug 30 17:59:53 2023 +0800

    Add DefaultAzureCredential support to AzureContainerRegistryHook (#33825)
    
    * feat(providers/microsoft): add DefaultAzureCredential support to 
AzureContainerRegistryHook
    
    * feat(providers/microsoft): pin azure-mgmt-containerregistry to >= 8.0.0
    
    * docs(providers/microsoft): update connection documentation for Azure 
Container Registry connection
---
 .../microsoft/azure/hooks/container_registry.py    | 44 +++++++++++++++++++++-
 airflow/providers/microsoft/azure/provider.yaml    |  2 +
 .../connections/acr.rst                            | 10 +++++
 generated/provider_dependencies.json               |  1 +
 .../azure/hooks/test_azure_container_registry.py   | 38 +++++++++++++++++++
 5 files changed, 94 insertions(+), 1 deletion(-)

diff --git a/airflow/providers/microsoft/azure/hooks/container_registry.py 
b/airflow/providers/microsoft/azure/hooks/container_registry.py
index c1217e3a86..2b9383e5d3 100644
--- a/airflow/providers/microsoft/azure/hooks/container_registry.py
+++ b/airflow/providers/microsoft/azure/hooks/container_registry.py
@@ -21,9 +21,12 @@ from __future__ import annotations
 from functools import cached_property
 from typing import Any
 
+from azure.identity import DefaultAzureCredential
 from azure.mgmt.containerinstance.models import ImageRegistryCredential
+from azure.mgmt.containerregistry import ContainerRegistryManagementClient
 
 from airflow.hooks.base import BaseHook
+from airflow.providers.microsoft.azure.utils import get_field
 
 
 class AzureContainerRegistryHook(BaseHook):
@@ -40,6 +43,24 @@ class AzureContainerRegistryHook(BaseHook):
     conn_type = "azure_container_registry"
     hook_name = "Azure Container Registry"
 
+    @staticmethod
+    def get_connection_form_widgets() -> dict[str, Any]:
+        """Returns connection widgets to add to connection form."""
+        from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
+        from flask_babel import lazy_gettext
+        from wtforms import StringField
+
+        return {
+            "subscription_id": StringField(
+                lazy_gettext("Subscription ID (optional)"),
+                widget=BS3TextFieldWidget(),
+            ),
+            "resource_group": StringField(
+                lazy_gettext("Resource group name (optional)"),
+                widget=BS3TextFieldWidget(),
+            ),
+        }
+
     @classmethod
     def get_ui_field_behaviour(cls) -> dict[str, Any]:
         """Returns custom field behaviour."""
@@ -54,6 +75,8 @@ class AzureContainerRegistryHook(BaseHook):
                 "login": "private registry username",
                 "password": "private registry password",
                 "host": "docker image registry server",
+                "subscription_id": "Subscription id (required for Azure AD 
authentication)",
+                "resource_group": "Resource group name (required for Azure AD 
authentication)",
             },
         }
 
@@ -61,10 +84,29 @@ class AzureContainerRegistryHook(BaseHook):
         super().__init__()
         self.conn_id = conn_id
 
+    def _get_field(self, extras, name):
+        return get_field(
+            conn_id=self.conn_id,
+            conn_type=self.conn_type,
+            extras=extras,
+            field_name=name,
+        )
+
     @cached_property
     def connection(self) -> ImageRegistryCredential:
         return self.get_conn()
 
     def get_conn(self) -> ImageRegistryCredential:
         conn = self.get_connection(self.conn_id)
-        return ImageRegistryCredential(server=conn.host, username=conn.login, 
password=conn.password)
+        password = conn.password
+        if not password:
+            extras = conn.extra_dejson
+            subscription_id = self._get_field(extras, "subscription_id")
+            resource_group = self._get_field(extras, "resource_group")
+            client = ContainerRegistryManagementClient(
+                credential=DefaultAzureCredential(), 
subscription_id=subscription_id
+            )
+            credentials = client.registries.list_credentials(resource_group, 
conn.login).as_dict()
+            password = credentials["passwords"][0]["value"]
+
+        return ImageRegistryCredential(server=conn.host, username=conn.login, 
password=password)
diff --git a/airflow/providers/microsoft/azure/provider.yaml 
b/airflow/providers/microsoft/azure/provider.yaml
index ea481df6b5..40673fb1eb 100644
--- a/airflow/providers/microsoft/azure/provider.yaml
+++ b/airflow/providers/microsoft/azure/provider.yaml
@@ -81,6 +81,8 @@ dependencies:
   - adal>=1.2.7
   - azure-storage-file-datalake>=12.9.1
   - azure-kusto-data>=4.1.0
+
+  - azure-mgmt-containerregistry>=8.0.0
   # TODO: upgrade to newer versions of all the below libraries.
   #   See issue https://github.com/apache/airflow/issues/30199
   - azure-mgmt-containerinstance>=7.0.0,<9.0.0
diff --git a/docs/apache-airflow-providers-microsoft-azure/connections/acr.rst 
b/docs/apache-airflow-providers-microsoft-azure/connections/acr.rst
index 3944519998..ff4c6dfd65 100644
--- a/docs/apache-airflow-providers-microsoft-azure/connections/acr.rst
+++ b/docs/apache-airflow-providers-microsoft-azure/connections/acr.rst
@@ -50,6 +50,16 @@ Password
 Host
     Specify the Image Registry Server used for the initial connection.
 
+Subscription ID (optional)
+    Specify the ID of the Azure subscription that contains the container registry.
+    Required for Azure Active Directory (Azure AD) authentication, which is used
+    when no password is provided.
+    Use extra param ``subscription_id`` to pass in the Azure subscription ID.
+
+Resource Group Name (optional)
+    Specify the name of the Azure resource group under which the desired Azure container registry resides.
+    Required for Azure Active Directory (Azure AD) authentication, which is used
+    when no password is provided; in that case Login is used as the registry name.
+    Use extra param ``resource_group`` to pass in the resource group name.
+
 When specifying the connection in environment variable you should specify
 it using URI syntax.
 
diff --git a/generated/provider_dependencies.json 
b/generated/provider_dependencies.json
index ed7578f178..8098cdd629 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -560,6 +560,7 @@
       "azure-keyvault-secrets>=4.1.0",
       "azure-kusto-data>=4.1.0",
       "azure-mgmt-containerinstance>=7.0.0,<9.0.0",
+      "azure-mgmt-containerregistry>=8.0.0",
       "azure-mgmt-cosmosdb",
       "azure-mgmt-datafactory>=1.0.0,<2.0",
       "azure-mgmt-datalake-store>=0.5.0",
diff --git 
a/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py 
b/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py
index 38f326d298..a2b0635749 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py
@@ -17,6 +17,8 @@
 # under the License.
 from __future__ import annotations
 
+from unittest import mock
+
 import pytest
 
 from airflow.models import Connection
@@ -43,3 +45,39 @@ class TestAzureContainerRegistryHook:
         assert hook.connection.username == "myuser"
         assert hook.connection.password == "password"
         assert hook.connection.server == "test.cr"
+
+    @pytest.mark.parametrize(
+        "mocked_connection",
+        [
+            Connection(
+                conn_id="azure_container_registry",
+                conn_type="azure_container_registry",
+                login="myuser",
+                password="",
+                host="test.cr",
+                extra={"subscription_id": "subscription_id", "resource_group": 
"resource_group"},
+            )
+        ],
+        indirect=True,
+    )
+    @mock.patch(
+        
"airflow.providers.microsoft.azure.hooks.container_registry.ContainerRegistryManagementClient"
+    )
+    
@mock.patch("airflow.providers.microsoft.azure.hooks.container_registry.DefaultAzureCredential")
+    def test_get_conn_with_default_azure_credential(
+        self, mocked_default_azure_credential, mocked_client, mocked_connection
+    ):
+        
mocked_client.return_value.registries.list_credentials.return_value.as_dict.return_value
 = {
+            "username": "myuser",
+            "passwords": [
+                {"name": "password", "value": "password"},
+            ],
+        }
+
+        hook = AzureContainerRegistryHook(conn_id=mocked_connection.conn_id)
+        assert hook.connection is not None
+        assert hook.connection.username == "myuser"
+        assert hook.connection.password == "password"
+        assert hook.connection.server == "test.cr"
+
+        mocked_default_azure_credential.assert_called_with()

Reply via email to