This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2b011b28ad Make DefaultAzureCredential in AzureBaseHook configuration 
(#35051)
2b011b28ad is described below

commit 2b011b28adf95ec8c686cdb69630c28b74049cf4
Author: Wei Lee <[email protected]>
AuthorDate: Tue Oct 31 02:01:27 2023 +0900

    Make DefaultAzureCredential in AzureBaseHook configuration (#35051)
    
    * feat(provider/microsoft): make managed_identity configurable in base 
azure hook
    
    * test(providers/microsoft): add test case for verifying calling 
DefaultAzureCredential with user provided identity in 
AzureIdentityCredentialAdapter
    
    * docs(microsoft/azure): update azure base hook doc for managed identity args
    
    * refactor(provider/microsoft): extract get_default_azure_credential 
function
---
 .../providers/microsoft/azure/hooks/base_azure.py  |  8 +++++
 airflow/providers/microsoft/azure/utils.py         | 39 +++++++++++++++++++---
 .../connections/azure.rst                          | 12 +++++--
 tests/providers/microsoft/azure/test_utils.py      | 21 ++++++++++++
 4 files changed, 74 insertions(+), 6 deletions(-)

diff --git a/airflow/providers/microsoft/azure/hooks/base_azure.py 
b/airflow/providers/microsoft/azure/hooks/base_azure.py
index 54130b3b2e..a58b6cae1a 100644
--- a/airflow/providers/microsoft/azure/hooks/base_azure.py
+++ b/airflow/providers/microsoft/azure/hooks/base_azure.py
@@ -54,6 +54,12 @@ class AzureBaseHook(BaseHook):
         return {
             "tenantId": StringField(lazy_gettext("Azure Tenant ID"), 
widget=BS3TextFieldWidget()),
             "subscriptionId": StringField(lazy_gettext("Azure Subscription 
ID"), widget=BS3TextFieldWidget()),
+            "managed_identity_client_id": StringField(
+                lazy_gettext("Managed Identity Client ID"), 
widget=BS3TextFieldWidget()
+            ),
+            "workload_identity_tenant_id": StringField(
+                lazy_gettext("Workload Identity Tenant ID"), 
widget=BS3TextFieldWidget()
+            ),
         }
 
     @staticmethod
@@ -79,6 +85,8 @@ class AzureBaseHook(BaseHook):
                 "password": "secret (token credentials auth)",
                 "tenantId": "tenantId (token credentials auth)",
                 "subscriptionId": "subscriptionId (token credentials auth)",
+                "managed_identity_client_id": "Managed Identity Client ID",
+                "workload_identity_tenant_id": "Workload Identity Tenant ID",
             },
         }
 
diff --git a/airflow/providers/microsoft/azure/utils.py 
b/airflow/providers/microsoft/azure/utils.py
index 5afc2a48ca..f63580e958 100644
--- a/airflow/providers/microsoft/azure/utils.py
+++ b/airflow/providers/microsoft/azure/utils.py
@@ -51,6 +51,24 @@ def get_field(*, conn_id: str, conn_type: str, extras: dict, 
field_name: str):
     return ret
 
 
+def get_default_azure_credential(
+    managed_identity_client_id: str | None, workload_identity_tenant_id: str | 
None
+) -> DefaultAzureCredential:
+    """Get DefaultAzureCredential based on provided arguments.
+
+    If managed_identity_client_id and workload_identity_tenant_id are 
provided, this function returns
+    DefaultAzureCredential with managed identity.
+    """
+    if managed_identity_client_id and workload_identity_tenant_id:
+        return DefaultAzureCredential(
+            managed_identity_client_id=managed_identity_client_id,
+            workload_identity_tenant_id=workload_identity_tenant_id,
+            additionally_allowed_tenants=[workload_identity_tenant_id],
+        )
+    else:
+        return DefaultAzureCredential()
+
+
 class AzureIdentityCredentialAdapter(BasicTokenAuthentication):
     """Adapt azure-identity credentials for backward compatibility.
 
@@ -60,15 +78,28 @@ class 
AzureIdentityCredentialAdapter(BasicTokenAuthentication):
     Check 
https://stackoverflow.com/questions/63384092/exception-attributeerror-defaultazurecredential-object-has-no-attribute-sig
     """
 
-    def __init__(self, credential=None, 
resource_id="https://management.azure.com/.default", **kwargs):
+    def __init__(
+        self,
+        credential=None,
+        resource_id="https://management.azure.com/.default",
+        *,
+        managed_identity_client_id: str | None = None,
+        workload_identity_tenant_id: str | None = None,
+        **kwargs,
+    ):
         """Adapt azure-identity credentials for backward compatibility.
 
         :param credential: Any azure-identity credential 
(DefaultAzureCredential by default)
-        :param str resource_id: The scope to use to get the token (default ARM)
+        :param resource_id: The scope to use to get the token (default ARM)
+        :param managed_identity_client_id: The client ID of a user-assigned 
managed identity.
+            If provided with `workload_identity_tenant_id`, they'll pass to 
``DefaultAzureCredential``.
+        :param workload_identity_tenant_id: ID of the application's Microsoft 
Entra tenant.
+            Also called its "directory" ID.
+            If provided with `managed_identity_client_id`, they'll pass to 
``DefaultAzureCredential``.
         """
-        super().__init__(None)
+        super().__init__(None)  # type: ignore[arg-type]
         if credential is None:
-            credential = DefaultAzureCredential()
+            credential = 
get_default_azure_credential(managed_identity_client_id, 
workload_identity_tenant_id)
         self._policy = BearerTokenCredentialPolicy(credential, resource_id, 
**kwargs)
 
     def _make_request(self):
diff --git 
a/docs/apache-airflow-providers-microsoft-azure/connections/azure.rst 
b/docs/apache-airflow-providers-microsoft-azure/connections/azure.rst
index 0b82d14667..3c1e8d4e20 100644
--- a/docs/apache-airflow-providers-microsoft-azure/connections/azure.rst
+++ b/docs/apache-airflow-providers-microsoft-azure/connections/azure.rst
@@ -19,6 +19,7 @@
 
 .. _howto/connection:azure:
 
+
 Microsoft Azure Connection
 ==========================
 
@@ -27,14 +28,15 @@ The Microsoft Azure connection type enables the Azure 
Integrations.
 Authenticating to Azure
 -----------------------
 
-There are four ways to connect to Azure using Airflow.
+There are five ways to connect to Azure using Airflow.
 
 1. Use `token credentials`_
    i.e. add specific credentials (client_id, secret, tenant) and subscription 
id to the Airflow connection.
 2. Use a `JSON file`_
 3. Use a `JSON dictionary`_
    i.e. add a key config directly into the Airflow connection.
-4. Fallback on `DefaultAzureCredential`_.
+4. Use managed identity by providing ``managed_identity_client_id`` and 
``workload_identity_tenant_id``.
+5. Fallback on `DefaultAzureCredential`_.
    This includes a mechanism to try different options to authenticate: Managed 
System Identity, environment variables, authentication through Azure CLI and 
etc.
    ``subscriptionId`` is required in this authentication mechanism.
 
@@ -71,6 +73,8 @@ Extra (optional)
       It specifies the path to the json file that contains the authentication 
information.
     * ``key_json``: If set, it uses the *JSON dictionary* authentication 
mechanism.
       It specifies the json that contains the authentication information.
+    * ``managed_identity_client_id``: The client ID of a user-assigned 
managed identity. If provided together with ``workload_identity_tenant_id``, both are passed 
to ``DefaultAzureCredential``.
+    * ``workload_identity_tenant_id``: ID of the application's Microsoft Entra 
tenant. Also called its "directory" ID. If provided together with 
``managed_identity_client_id``, both are passed to ``DefaultAzureCredential``.
 
     The entire extra column can be left out to fall back on 
DefaultAzureCredential_.
 
@@ -90,3 +94,7 @@ For example:
 .. _JSON file: 
https://docs.microsoft.com/en-us/azure/developer/python/azure-sdk-authenticate?tabs=cmd#authenticate-with-a-json-file
 .. _JSON dictionary: 
https://docs.microsoft.com/en-us/azure/developer/python/azure-sdk-authenticate?tabs=cmd#authenticate-with-a-json-dictionary
 .. _DefaultAzureCredential: 
https://docs.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python#defaultazurecredential
+
+.. spelling:word-list::
+
+    Entra
diff --git a/tests/providers/microsoft/azure/test_utils.py 
b/tests/providers/microsoft/azure/test_utils.py
index eefce54e93..ad5e3cbc2d 100644
--- a/tests/providers/microsoft/azure/test_utils.py
+++ b/tests/providers/microsoft/azure/test_utils.py
@@ -75,3 +75,24 @@ class TestAzureIdentityCredentialAdapter:
 
         adapter.signed_session()
         assert adapter.token == {"access_token": "token"}
+
+    @mock.patch(f"{MODULE}.PipelineRequest")
+    @mock.patch(f"{MODULE}.BearerTokenCredentialPolicy")
+    @mock.patch(f"{MODULE}.DefaultAzureCredential")
+    def test_init_with_identity(self, mock_default_azure_credential, 
mock_policy, mock_request):
+        mock_request.return_value.http_request.headers = {"Authorization": 
"Bearer token"}
+
+        adapter = AzureIdentityCredentialAdapter(
+            managed_identity_client_id="managed_identity_client_id",
+            workload_identity_tenant_id="workload_identity_tenant_id",
+            additionally_allowed_tenants=["workload_identity_tenant_id"],
+        )
+        mock_default_azure_credential.assert_called_once_with(
+            managed_identity_client_id="managed_identity_client_id",
+            workload_identity_tenant_id="workload_identity_tenant_id",
+            additionally_allowed_tenants=["workload_identity_tenant_id"],
+        )
+        mock_policy.assert_called_once()
+
+        adapter.signed_session()
+        assert adapter.token == {"access_token": "token"}

Reply via email to