This is an automated email from the ASF dual-hosted git repository.

basph pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a0a3b8a50fb Enable workload identity authentication for the Databricks 
provider (#41639)
a0a3b8a50fb is described below

commit a0a3b8a50fb051770330f64adef4400fb514f3e1
Author: Bas <[email protected]>
AuthorDate: Wed Nov 6 18:26:27 2024 +0100

    Enable workload identity authentication for the Databricks provider (#41639)
    
    * work on test
    
    * work on test
    
    * Cleanup code
    
    * add kubernetes check
    
    * cleanup
    
    * work
    
    * add async functionality
    
    * work
    
    * Update documentation
    
    * Cleanup imports
    
    * Revert settings file
    
    * cleanup comments
    
    * Add back #
    
    * Fix tests
    
    * fix docs
    
    * Work with callables
    
    * Ruff format
    
    * Revert "Work with callables"
    
    This reverts commit 4a6aef481d07acd71491c0cf0f505b675c9652dc.
    
    * Fix typo
    
    * add param
    
    * Update airflow/providers/databricks/hooks/databricks_base.py
    
    Co-authored-by: Bas Harenslak <[email protected]>
    
    * Update airflow/providers/databricks/hooks/databricks_base.py
    
    Co-authored-by: Bas Harenslak <[email protected]>
    
    * Update airflow/providers/databricks/hooks/databricks_base.py
    
    Co-authored-by: Bas Harenslak <[email protected]>
    
    * remove unused import
    
    * change spelling
    
    * add aks to wordlist
    
    * add aks
    
    * Move env mock inside test.
    
    * Move hook inside test.
    
    * Fix merge conflicts.
    
    ---------
    
    Co-authored-by: Bas van Driel <[email protected]>
    Co-authored-by: Bas Harenslak <[email protected]>
    Co-authored-by: Julian de Ruiter 
<[email protected]>
---
 .../connections/databricks.rst                     |  1 +
 docs/spelling_wordlist.txt                         |  1 +
 .../providers/databricks/hooks/databricks_base.py  | 98 ++++++++++++++++++++++
 .../test_databricks_azure_workload_identity.py     | 90 ++++++++++++++++++++
 ...est_databricks_azure_workload_identity_async.py | 89 ++++++++++++++++++++
 5 files changed, 279 insertions(+)

diff --git 
a/docs/apache-airflow-providers-databricks/connections/databricks.rst 
b/docs/apache-airflow-providers-databricks/connections/databricks.rst
index 7045d4a3ae2..2d92bb67eba 100644
--- a/docs/apache-airflow-providers-databricks/connections/databricks.rst
+++ b/docs/apache-airflow-providers-databricks/connections/databricks.rst
@@ -87,6 +87,7 @@ Extra (optional)
 
     * ``use_azure_managed_identity``: required boolean flag to specify if 
managed identity needs to be used instead of
       service principal
+    * ``use_default_azure_credential``: required boolean flag to specify if 
the ``DefaultAzureCredential`` class should be used to retrieve an AAD token. 
For example, this can be used when authenticating with workload identity 
within an Azure Kubernetes Service cluster. Note that this option can't be set 
together with the ``use_azure_managed_identity`` parameter.
     * ``azure_resource_id``: optional Resource ID of the Azure Databricks 
workspace (required if managed identity isn't
       a user inside workspace)
 
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index d2c1540337e..0362ce82bdf 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -32,6 +32,7 @@ airbyte
 AirflowException
 airflowignore
 ajax
+aks
 AlertApi
 alertPolicies
 Alibaba
diff --git 
a/providers/src/airflow/providers/databricks/hooks/databricks_base.py 
b/providers/src/airflow/providers/databricks/hooks/databricks_base.py
index 8a4a7335a43..f804f64a326 100644
--- a/providers/src/airflow/providers/databricks/hooks/databricks_base.py
+++ b/providers/src/airflow/providers/databricks/hooks/databricks_base.py
@@ -65,6 +65,8 @@ AZURE_MANAGEMENT_ENDPOINT = 
"https://management.core.windows.net/";
 DEFAULT_DATABRICKS_SCOPE = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d"
 OIDC_TOKEN_SERVICE_URL = "{}/oidc/v1/token"
 
+DEFAULT_AZURE_CREDENTIAL_SETTING_KEY = "use_default_azure_credential"
+
 
 class BaseDatabricksHook(BaseHook):
     """
@@ -89,6 +91,7 @@ class BaseDatabricksHook(BaseHook):
         "token",
         "host",
         "use_azure_managed_identity",
+        DEFAULT_AZURE_CREDENTIAL_SETTING_KEY,
         "azure_ad_endpoint",
         "azure_resource_id",
         "azure_tenant_id",
@@ -376,6 +379,94 @@ class BaseDatabricksHook(BaseHook):
 
         return jsn["access_token"]
 
+    def _get_aad_token_for_default_az_credential(self, resource: str) -> str:
+        """
+        Get AAD token for given resource for workload identity.
+
+        Supports managed identity or service principal auth.
+        :param resource: resource to issue token to
+        :return: AAD token, or raise an exception
+        """
+        aad_token = self.oauth_tokens.get(resource)
+        if aad_token and self._is_oauth_token_valid(aad_token):
+            return aad_token["access_token"]
+
+        self.log.info("Existing AAD token is expired, or going to expire soon. 
Refreshing...")
+        try:
+            from azure.identity import DefaultAzureCredential
+
+            for attempt in self._get_retry_object():
+                with attempt:
+                    # This only works in an Azure Kubernetes Service Cluster 
given the following environment variables:
+                    # AZURE_TENANT_ID, AZURE_CLIENT_ID, 
AZURE_FEDERATED_TOKEN_FILE
+                    #
+                    # While there is a WorkloadIdentityCredential class, the 
below class is advised by Microsoft
+                    # 
https://learn.microsoft.com/en-us/azure/aks/workload-identity-overview
+                    token = 
DefaultAzureCredential().get_token(f"{resource}/.default")
+
+                    jsn = {
+                        "access_token": token.token,
+                        "token_type": "Bearer",
+                        "expires_on": token.expires_on,
+                    }
+                    self._is_oauth_token_valid(jsn)
+                    self.oauth_tokens[resource] = jsn
+                    break
+        except ImportError as e:
+            raise AirflowOptionalProviderFeatureException(e)
+        except RetryError:
+            raise AirflowException(f"API requests to Azure failed 
{self.retry_limit} times. Giving up.")
+        except requests_exceptions.HTTPError as e:
+            msg = f"Response: {e.response.content.decode()}, Status Code: 
{e.response.status_code}"
+            raise AirflowException(msg)
+
+        return token.token
+
+    async def _a_get_aad_token_for_default_az_credential(self, resource: str) 
-> str:
+        """
+        Get AAD token for given resource for workload identity.
+
+        Supports managed identity or service principal auth.
+        :param resource: resource to issue token to
+        :return: AAD token, or raise an exception
+        """
+        aad_token = self.oauth_tokens.get(resource)
+        if aad_token and self._is_oauth_token_valid(aad_token):
+            return aad_token["access_token"]
+
+        self.log.info("Existing AAD token is expired, or going to expire soon. 
Refreshing...")
+        try:
+            from azure.identity.aio import (
+                DefaultAzureCredential as AsyncDefaultAzureCredential,
+            )
+
+            for attempt in self._get_retry_object():
+                with attempt:
+                    # This only works in an Azure Kubernetes Service Cluster 
given the following environment variables:
+                    # AZURE_TENANT_ID, AZURE_CLIENT_ID, 
AZURE_FEDERATED_TOKEN_FILE
+                    #
+                    # While there is a WorkloadIdentityCredential class, the 
below class is advised by Microsoft
+                    # 
https://learn.microsoft.com/en-us/azure/aks/workload-identity-overview
+                    token = await 
AsyncDefaultAzureCredential().get_token(f"{resource}/.default")
+
+                    jsn = {
+                        "access_token": token.token,
+                        "token_type": "Bearer",
+                        "expires_on": token.expires_on,
+                    }
+                    self._is_oauth_token_valid(jsn)
+                    self.oauth_tokens[resource] = jsn
+                    break
+        except ImportError as e:
+            raise AirflowOptionalProviderFeatureException(e)
+        except RetryError:
+            raise AirflowException(f"API requests to Azure failed 
{self.retry_limit} times. Giving up.")
+        except requests_exceptions.HTTPError as e:
+            msg = f"Response: {e.response.content.decode()}, Status Code: 
{e.response.status_code}"
+            raise AirflowException(msg)
+
+        return token.token
+
     def _get_aad_headers(self) -> dict:
         """
         Fill AAD headers if necessary (SPN is outside of the workspace).
@@ -476,6 +567,9 @@ class BaseDatabricksHook(BaseHook):
             self.log.debug("Using AAD Token for managed identity.")
             self._check_azure_metadata_service()
             return self._get_aad_token(DEFAULT_DATABRICKS_SCOPE)
+        elif 
self.databricks_conn.extra_dejson.get(DEFAULT_AZURE_CREDENTIAL_SETTING_KEY, 
False):
+            self.log.debug("Using default Azure Credential authentication.")
+            return 
self._get_aad_token_for_default_az_credential(DEFAULT_DATABRICKS_SCOPE)
         elif self.databricks_conn.extra_dejson.get("service_principal_oauth", 
False):
             if self.databricks_conn.login == "" or 
self.databricks_conn.password == "":
                 raise AirflowException("Service Principal credentials aren't 
provided")
@@ -504,6 +598,10 @@ class BaseDatabricksHook(BaseHook):
             self.log.debug("Using AAD Token for managed identity.")
             await self._a_check_azure_metadata_service()
             return await self._a_get_aad_token(DEFAULT_DATABRICKS_SCOPE)
+        elif 
self.databricks_conn.extra_dejson.get(DEFAULT_AZURE_CREDENTIAL_SETTING_KEY, 
False):
+            self.log.debug("Using AzureDefaultCredential for authentication.")
+
+            return await 
self._a_get_aad_token_for_default_az_credential(DEFAULT_DATABRICKS_SCOPE)
         elif self.databricks_conn.extra_dejson.get("service_principal_oauth", 
False):
             if self.databricks_conn.login == "" or 
self.databricks_conn.password == "":
                 raise AirflowException("Service Principal credentials aren't 
provided")
diff --git 
a/providers/tests/databricks/hooks/test_databricks_azure_workload_identity.py 
b/providers/tests/databricks/hooks/test_databricks_azure_workload_identity.py
new file mode 100644
index 00000000000..6a57a1340cb
--- /dev/null
+++ 
b/providers/tests/databricks/hooks/test_databricks_azure_workload_identity.py
@@ -0,0 +1,90 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import json
+import os
+from unittest import mock
+
+import pytest
+import tenacity
+from azure.core.credentials import AccessToken
+
+from airflow.models import Connection
+from airflow.providers.databricks.hooks.databricks import DatabricksHook
+from airflow.providers.databricks.hooks.databricks_base import 
DEFAULT_AZURE_CREDENTIAL_SETTING_KEY
+from airflow.utils.session import provide_session
+
+
+def create_successful_response_mock(content):
+    response = mock.MagicMock()
+    response.json.return_value = content
+    response.status_code = 200
+    return response
+
+
+def create_aad_token_for_resource() -> AccessToken:
+    return AccessToken(expires_on=1575500666, token="sample-token")
+
+
+HOST = "xx.cloud.databricks.com"
+DEFAULT_CONN_ID = "databricks_default"
+DEFAULT_RETRY_NUMBER = 3
+DEFAULT_RETRY_ARGS = dict(
+    wait=tenacity.wait_none(),
+    stop=tenacity.stop_after_attempt(DEFAULT_RETRY_NUMBER),
+)
+
+
[email protected]_test
+class TestDatabricksHookAadTokenWorkloadIdentity:
+    _hook: DatabricksHook
+
+    @provide_session
+    def setup_method(self, method, session=None):
+        conn = session.query(Connection).filter(Connection.conn_id == 
DEFAULT_CONN_ID).first()
+        conn.host = HOST
+        conn.extra = json.dumps(
+            {
+                DEFAULT_AZURE_CREDENTIAL_SETTING_KEY: True,
+            }
+        )
+        session.commit()
+
+        # This will use the default connection id (databricks_default)
+        self._hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
+
+    @mock.patch.dict(
+        os.environ,
+        {
+            "AZURE_CLIENT_ID": "fake-client-id",
+            "AZURE_TENANT_ID": "fake-tenant-id",
+            "AZURE_FEDERATED_TOKEN_FILE": "/badpath",
+            "KUBERNETES_SERVICE_HOST": "fakeip",
+        },
+    )
+    @mock.patch(
+        "azure.identity.DefaultAzureCredential.get_token", 
return_value=create_aad_token_for_resource()
+    )
+    
@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests.get")
+    def test_one(self, requests_mock, get_token_mock: mock.MagicMock):
+        requests_mock.return_value = create_successful_response_mock({"jobs": 
[]})
+
+        result = self._hook.list_jobs()
+
+        assert result == []
diff --git 
a/providers/tests/databricks/hooks/test_databricks_azure_workload_identity_async.py
 
b/providers/tests/databricks/hooks/test_databricks_azure_workload_identity_async.py
new file mode 100644
index 00000000000..0a43262acc4
--- /dev/null
+++ 
b/providers/tests/databricks/hooks/test_databricks_azure_workload_identity_async.py
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import json
+import os
+from unittest import mock
+
+import pytest
+import tenacity
+from azure.core.credentials import AccessToken
+
+from airflow.models import Connection
+from airflow.providers.databricks.hooks.databricks import DatabricksHook
+from airflow.providers.databricks.hooks.databricks_base import 
DEFAULT_AZURE_CREDENTIAL_SETTING_KEY
+from airflow.utils.session import provide_session
+
+
+def create_successful_response_mock(content):
+    response = mock.MagicMock()
+    response.json.return_value = content
+    response.status_code = 200
+    return response
+
+
+def create_aad_token_for_resource() -> AccessToken:
+    return AccessToken(expires_on=1575500666, token="sample-token")
+
+
+HOST = "xx.cloud.databricks.com"
+DEFAULT_CONN_ID = "databricks_default"
+DEFAULT_RETRY_NUMBER = 3
+DEFAULT_RETRY_ARGS = dict(
+    wait=tenacity.wait_none(),
+    stop=tenacity.stop_after_attempt(DEFAULT_RETRY_NUMBER),
+)
+
+
[email protected]_test
+class TestDatabricksHookAadTokenWorkloadIdentityAsync:
+    @provide_session
+    def setup_method(self, method, session=None):
+        conn = session.query(Connection).filter(Connection.conn_id == 
DEFAULT_CONN_ID).first()
+        conn.host = HOST
+        conn.extra = json.dumps(
+            {
+                DEFAULT_AZURE_CREDENTIAL_SETTING_KEY: True,
+            }
+        )
+        session.commit()
+
+    @pytest.mark.asyncio
+    @mock.patch(
+        "azure.identity.aio.DefaultAzureCredential.get_token", 
return_value=create_aad_token_for_resource()
+    )
+    
@mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get")
+    async def test_one(self, requests_mock, get_token_mock: mock.MagicMock):
+        with mock.patch.dict(
+            os.environ,
+            {
+                "AZURE_CLIENT_ID": "fake-client-id",
+                "AZURE_TENANT_ID": "fake-tenant-id",
+                "AZURE_FEDERATED_TOKEN_FILE": "/badpath",
+                "KUBERNETES_SERVICE_HOST": "fakeip",
+            },
+        ):
+            
requests_mock.return_value.__aenter__.return_value.json.side_effect = 
mock.AsyncMock(
+                side_effect=[{"data": 1}]
+            )
+
+            async with DatabricksHook(retry_args=DEFAULT_RETRY_ARGS) as hook:
+                result = await hook.a_get_run_output(0)
+
+            assert result == {"data": 1}

Reply via email to