This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0d2011a3a52 Support user-assigned managed identity for Azure VM auth (#66072)
0d2011a3a52 is described below

commit 0d2011a3a5267d34ff06448533d2c4718227237c
Author: Shen YuDong <[email protected]>
AuthorDate: Wed May 6 17:33:38 2026 +0800

    Support user-assigned managed identity for Azure VM auth (#66072)
    
    * Databricks: support user-assigned managed identity for Azure VM auth
    
    This fixes an issue where Databricks connections could not specify
    a user-assigned managed identity when Airflow runs on Azure VM.
    
    Ref: #65588
    
    * fix static check
---
 .../databricks/docs/connections/databricks.rst     |  1 +
 .../providers/databricks/hooks/databricks_base.py  | 13 +++++++--
 .../unit/databricks/hooks/test_databricks_base.py  | 33 ++++++++++++++++++++++
 3 files changed, 45 insertions(+), 2 deletions(-)

diff --git a/providers/databricks/docs/connections/databricks.rst b/providers/databricks/docs/connections/databricks.rst
index 634f61f2d4d..630526266be 100644
--- a/providers/databricks/docs/connections/databricks.rst
+++ b/providers/databricks/docs/connections/databricks.rst
@@ -99,6 +99,7 @@ Extra (optional)
     * ``use_default_azure_credential``: required boolean flag to specify if the `DefaultAzureCredential` class should be used to retrieve a AAD token. For example, this can be used when authenticating with workload identity within an Azure Kubernetes Service cluster. Note that this option can't be set together with the `use_azure_managed_identity` parameter.
     * ``azure_resource_id``: optional Resource ID of the Azure Databricks workspace (required if managed identity isn't
       a user inside workspace)
+    * ``azure_managed_identity_client_id``: optional client ID of the user-assigned managed identity. This parameter is only required if you're using a user-assigned managed identity. If not specified, the hook will attempt to authenticate using a system-assigned managed identity.
 
     The following parameters are necessary if using authentication with Kubernetes OIDC token federation:
 
diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py
index fe9d2245705..7ab51fc7dfa 100644
--- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py
+++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py
@@ -111,6 +111,7 @@ class BaseDatabricksHook(BaseHook):
         "host",
         "use_azure_managed_identity",
         DEFAULT_AZURE_CREDENTIAL_SETTING_KEY,
+        "azure_managed_identity_client_id",
         "azure_ad_endpoint",
         "azure_resource_id",
         "azure_tenant_id",
@@ -340,7 +341,12 @@ class BaseDatabricksHook(BaseHook):
             for attempt in self._get_retry_object():
                 with attempt:
                     if self.databricks_conn.extra_dejson.get("use_azure_managed_identity", False):
-                        token = ManagedIdentityCredential().get_token(f"{resource}/.default")
+                        client_id = self.databricks_conn.extra_dejson.get(
+                            "azure_managed_identity_client_id", None
+                        )
+                        token = ManagedIdentityCredential(client_id=client_id).get_token(
+                            f"{resource}/.default"
+                        )
                     else:
                         credential = ClientSecretCredential(
                             client_id=self._get_connection_attr("login"),
@@ -387,7 +393,10 @@ class BaseDatabricksHook(BaseHook):
             async for attempt in self._a_get_retry_object():
                 with attempt:
                     if self.databricks_conn.extra_dejson.get("use_azure_managed_identity", False):
-                        async with AsyncManagedIdentityCredential() as credential:
+                        client_id = self.databricks_conn.extra_dejson.get(
+                            "azure_managed_identity_client_id", None
+                        )
+                        async with AsyncManagedIdentityCredential(client_id=client_id) as credential:
                             token = await credential.get_token(f"{resource}/.default")
                     else:
                         async with AsyncClientSecretCredential(
diff --git a/providers/databricks/tests/unit/databricks/hooks/test_databricks_base.py b/providers/databricks/tests/unit/databricks/hooks/test_databricks_base.py
index 8c8794ebda9..088c3fc27bc 100644
--- a/providers/databricks/tests/unit/databricks/hooks/test_databricks_base.py
+++ b/providers/databricks/tests/unit/databricks/hooks/test_databricks_base.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+import json
 from datetime import datetime, timedelta
 from unittest import mock
 
@@ -466,6 +467,38 @@ class TestBaseDatabricksHook:
             mock_get_aad_token.assert_called_once_with(DEFAULT_DATABRICKS_SCOPE)
             mock_log_debug.assert_called_once_with("Using AAD Token for managed identity.")
 
+    @mock.patch("azure.identity.ManagedIdentityCredential")
+    def test_get_aad_token_with_managed_identity_client_id(
+        self,
+        mock_credential,
+    ):
+        conn = Connection(
+            host="example.databricks.com",
+            extra=json.dumps(
+                {
+                    "use_azure_managed_identity": True,
+                    "azure_managed_identity_client_id": "cli-id-abc",
+                }
+            ),
+        )
+
+        token_mock = mock.Mock()
+        token_mock.token = "the-token"
+        token_mock.expires_on = 8888888888
+        mock_credential.return_value.get_token.return_value = token_mock
+
+        hook = BaseDatabricksHook()
+        hook.databricks_conn = conn
+        hook.oauth_tokens = {}
+        hook._get_retry_object = lambda: [
+            mock.Mock(__enter__=lambda s: None, __exit__=lambda s, a, b, c: None)
+        ]
+
+        token = hook._get_aad_token("https://databricks.azure.com")
+
+        assert token == "the-token"
+        mock_credential.assert_called_once_with(client_id="cli-id-abc")
+
     @mock.patch(
         "airflow.providers.databricks.hooks.databricks_base.BaseDatabricksHook.databricks_conn",
         new_callable=mock.PropertyMock,

Reply via email to