This is an automated email from the ASF dual-hosted git repository.
basph pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a0a3b8a50fb Enable workload identity authentication for the Databricks
provider (#41639)
a0a3b8a50fb is described below
commit a0a3b8a50fb051770330f64adef4400fb514f3e1
Author: Bas <[email protected]>
AuthorDate: Wed Nov 6 18:26:27 2024 +0100
Enable workload identity authentication for the Databricks provider (#41639)
* work on test
* work on test
* Cleanup code
* add kubernetes check
* cleanup
* work
* add async functionality
* work
* Update documentation
* Cleanup imports
* Revert settings file
* cleanup comments
* Add back #
* Fix tests
* fix docs
* Work with callables
* Ruff format
* Revert "Work with callables"
This reverts commit 4a6aef481d07acd71491c0cf0f505b675c9652dc.
* Fix typo
* add param
* Update airflow/providers/databricks/hooks/databricks_base.py
Co-authored-by: Bas Harenslak <[email protected]>
* Update airflow/providers/databricks/hooks/databricks_base.py
Co-authored-by: Bas Harenslak <[email protected]>
* Update airflow/providers/databricks/hooks/databricks_base.py
Co-authored-by: Bas Harenslak <[email protected]>
* remove unused import
* change spelling
* add aks to wordlist
* add aks
* Move env mock inside test.
* Move hook inside test.
* Fix merge conflicts.
---------
Co-authored-by: Bas van Driel <[email protected]>
Co-authored-by: Bas Harenslak <[email protected]>
Co-authored-by: Julian de Ruiter
<[email protected]>
---
.../connections/databricks.rst | 1 +
docs/spelling_wordlist.txt | 1 +
.../providers/databricks/hooks/databricks_base.py | 98 ++++++++++++++++++++++
.../test_databricks_azure_workload_identity.py | 90 ++++++++++++++++++++
...est_databricks_azure_workload_identity_async.py | 89 ++++++++++++++++++++
5 files changed, 279 insertions(+)
diff --git
a/docs/apache-airflow-providers-databricks/connections/databricks.rst
b/docs/apache-airflow-providers-databricks/connections/databricks.rst
index 7045d4a3ae2..2d92bb67eba 100644
--- a/docs/apache-airflow-providers-databricks/connections/databricks.rst
+++ b/docs/apache-airflow-providers-databricks/connections/databricks.rst
@@ -87,6 +87,7 @@ Extra (optional)
* ``use_azure_managed_identity``: required boolean flag to specify if
managed identity needs to be used instead of
service principal
+ * ``use_default_azure_credential``: required boolean flag to specify if
the `DefaultAzureCredential` class should be used to retrieve an AAD token. For
example, this can be used when authenticating with workload identity within an
Azure Kubernetes Service cluster. Note that this option can't be set together
with the `use_azure_managed_identity` parameter.
* ``azure_resource_id``: optional Resource ID of the Azure Databricks
workspace (required if managed identity isn't
a user inside workspace)
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index d2c1540337e..0362ce82bdf 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -32,6 +32,7 @@ airbyte
AirflowException
airflowignore
ajax
+aks
AlertApi
alertPolicies
Alibaba
diff --git
a/providers/src/airflow/providers/databricks/hooks/databricks_base.py
b/providers/src/airflow/providers/databricks/hooks/databricks_base.py
index 8a4a7335a43..f804f64a326 100644
--- a/providers/src/airflow/providers/databricks/hooks/databricks_base.py
+++ b/providers/src/airflow/providers/databricks/hooks/databricks_base.py
@@ -65,6 +65,8 @@ AZURE_MANAGEMENT_ENDPOINT =
"https://management.core.windows.net/"
DEFAULT_DATABRICKS_SCOPE = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d"
OIDC_TOKEN_SERVICE_URL = "{}/oidc/v1/token"
+DEFAULT_AZURE_CREDENTIAL_SETTING_KEY = "use_default_azure_credential"
+
class BaseDatabricksHook(BaseHook):
"""
@@ -89,6 +91,7 @@ class BaseDatabricksHook(BaseHook):
"token",
"host",
"use_azure_managed_identity",
+ DEFAULT_AZURE_CREDENTIAL_SETTING_KEY,
"azure_ad_endpoint",
"azure_resource_id",
"azure_tenant_id",
@@ -376,6 +379,94 @@ class BaseDatabricksHook(BaseHook):
return jsn["access_token"]
+ def _get_aad_token_for_default_az_credential(self, resource: str) -> str:
+ """
+ Get AAD token for given resource for workload identity.
+
+ Supports managed identity or service principal auth.
+ :param resource: resource to issue token to
+ :return: AAD token, or raise an exception
+ """
+ aad_token = self.oauth_tokens.get(resource)
+ if aad_token and self._is_oauth_token_valid(aad_token):
+ return aad_token["access_token"]
+
+ self.log.info("Existing AAD token is expired, or going to expire soon.
Refreshing...")
+ try:
+ from azure.identity import DefaultAzureCredential
+
+ for attempt in self._get_retry_object():
+ with attempt:
+ # This only works in an Azure Kubernetes Service Cluster
given the following environment variables:
+ # AZURE_TENANT_ID, AZURE_CLIENT_ID,
AZURE_FEDERATED_TOKEN_FILE
+ #
+ # While there is a WorkloadIdentityCredential class, the
below class is advised by Microsoft
+ #
https://learn.microsoft.com/en-us/azure/aks/workload-identity-overview
+ token =
DefaultAzureCredential().get_token(f"{resource}/.default")
+
+ jsn = {
+ "access_token": token.token,
+ "token_type": "Bearer",
+ "expires_on": token.expires_on,
+ }
+ self._is_oauth_token_valid(jsn)
+ self.oauth_tokens[resource] = jsn
+ break
+ except ImportError as e:
+ raise AirflowOptionalProviderFeatureException(e)
+ except RetryError:
+ raise AirflowException(f"API requests to Azure failed
{self.retry_limit} times. Giving up.")
+ except requests_exceptions.HTTPError as e:
+ msg = f"Response: {e.response.content.decode()}, Status Code:
{e.response.status_code}"
+ raise AirflowException(msg)
+
+ return token.token
+
+ async def _a_get_aad_token_for_default_az_credential(self, resource: str)
-> str:
+ """
+ Get AAD token for given resource for workload identity.
+
+ Supports managed identity or service principal auth.
+ :param resource: resource to issue token to
+ :return: AAD token, or raise an exception
+ """
+ aad_token = self.oauth_tokens.get(resource)
+ if aad_token and self._is_oauth_token_valid(aad_token):
+ return aad_token["access_token"]
+
+ self.log.info("Existing AAD token is expired, or going to expire soon.
Refreshing...")
+ try:
+ from azure.identity.aio import (
+ DefaultAzureCredential as AsyncDefaultAzureCredential,
+ )
+
+ for attempt in self._get_retry_object():
+ with attempt:
+ # This only works in an Azure Kubernetes Service Cluster
given the following environment variables:
+ # AZURE_TENANT_ID, AZURE_CLIENT_ID,
AZURE_FEDERATED_TOKEN_FILE
+ #
+ # While there is a WorkloadIdentityCredential class, the
below class is advised by Microsoft
+ #
https://learn.microsoft.com/en-us/azure/aks/workload-identity-overview
+ token = await
AsyncDefaultAzureCredential().get_token(f"{resource}/.default")
+
+ jsn = {
+ "access_token": token.token,
+ "token_type": "Bearer",
+ "expires_on": token.expires_on,
+ }
+ self._is_oauth_token_valid(jsn)
+ self.oauth_tokens[resource] = jsn
+ break
+ except ImportError as e:
+ raise AirflowOptionalProviderFeatureException(e)
+ except RetryError:
+ raise AirflowException(f"API requests to Azure failed
{self.retry_limit} times. Giving up.")
+ except requests_exceptions.HTTPError as e:
+ msg = f"Response: {e.response.content.decode()}, Status Code:
{e.response.status_code}"
+ raise AirflowException(msg)
+
+ return token.token
+
def _get_aad_headers(self) -> dict:
"""
Fill AAD headers if necessary (SPN is outside of the workspace).
@@ -476,6 +567,9 @@ class BaseDatabricksHook(BaseHook):
self.log.debug("Using AAD Token for managed identity.")
self._check_azure_metadata_service()
return self._get_aad_token(DEFAULT_DATABRICKS_SCOPE)
+ elif
self.databricks_conn.extra_dejson.get(DEFAULT_AZURE_CREDENTIAL_SETTING_KEY,
False):
+ self.log.debug("Using default Azure Credential authentication.")
+ return
self._get_aad_token_for_default_az_credential(DEFAULT_DATABRICKS_SCOPE)
elif self.databricks_conn.extra_dejson.get("service_principal_oauth",
False):
if self.databricks_conn.login == "" or
self.databricks_conn.password == "":
raise AirflowException("Service Principal credentials aren't
provided")
@@ -504,6 +598,10 @@ class BaseDatabricksHook(BaseHook):
self.log.debug("Using AAD Token for managed identity.")
await self._a_check_azure_metadata_service()
return await self._a_get_aad_token(DEFAULT_DATABRICKS_SCOPE)
+ elif
self.databricks_conn.extra_dejson.get(DEFAULT_AZURE_CREDENTIAL_SETTING_KEY,
False):
+ self.log.debug("Using AzureDefaultCredential for authentication.")
+
+ return await
self._a_get_aad_token_for_default_az_credential(DEFAULT_DATABRICKS_SCOPE)
elif self.databricks_conn.extra_dejson.get("service_principal_oauth",
False):
if self.databricks_conn.login == "" or
self.databricks_conn.password == "":
raise AirflowException("Service Principal credentials aren't
provided")
diff --git
a/providers/tests/databricks/hooks/test_databricks_azure_workload_identity.py
b/providers/tests/databricks/hooks/test_databricks_azure_workload_identity.py
new file mode 100644
index 00000000000..6a57a1340cb
--- /dev/null
+++
b/providers/tests/databricks/hooks/test_databricks_azure_workload_identity.py
@@ -0,0 +1,90 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import json
+import os
+from unittest import mock
+
+import pytest
+import tenacity
+from azure.core.credentials import AccessToken
+
+from airflow.models import Connection
+from airflow.providers.databricks.hooks.databricks import DatabricksHook
+from airflow.providers.databricks.hooks.databricks_base import
DEFAULT_AZURE_CREDENTIAL_SETTING_KEY
+from airflow.utils.session import provide_session
+
+
+def create_successful_response_mock(content):
+ response = mock.MagicMock()
+ response.json.return_value = content
+ response.status_code = 200
+ return response
+
+
+def create_aad_token_for_resource() -> AccessToken:
+ return AccessToken(expires_on=1575500666, token="sample-token")
+
+
+HOST = "xx.cloud.databricks.com"
+DEFAULT_CONN_ID = "databricks_default"
+DEFAULT_RETRY_NUMBER = 3
+DEFAULT_RETRY_ARGS = dict(
+ wait=tenacity.wait_none(),
+ stop=tenacity.stop_after_attempt(DEFAULT_RETRY_NUMBER),
+)
+
+
[email protected]_test
+class TestDatabricksHookAadTokenWorkloadIdentity:
+ _hook: DatabricksHook
+
+ @provide_session
+ def setup_method(self, method, session=None):
+ conn = session.query(Connection).filter(Connection.conn_id ==
DEFAULT_CONN_ID).first()
+ conn.host = HOST
+ conn.extra = json.dumps(
+ {
+ DEFAULT_AZURE_CREDENTIAL_SETTING_KEY: True,
+ }
+ )
+ session.commit()
+
+ # This will use the default connection id (databricks_default)
+ self._hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
+
+ @mock.patch.dict(
+ os.environ,
+ {
+ "AZURE_CLIENT_ID": "fake-client-id",
+ "AZURE_TENANT_ID": "fake-tenant-id",
+ "AZURE_FEDERATED_TOKEN_FILE": "/badpath",
+ "KUBERNETES_SERVICE_HOST": "fakeip",
+ },
+ )
+ @mock.patch(
+ "azure.identity.DefaultAzureCredential.get_token",
return_value=create_aad_token_for_resource()
+ )
+
@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests.get")
+ def test_one(self, requests_mock, get_token_mock: mock.MagicMock):
+ requests_mock.return_value = create_successful_response_mock({"jobs":
[]})
+
+ result = self._hook.list_jobs()
+
+ assert result == []
diff --git
a/providers/tests/databricks/hooks/test_databricks_azure_workload_identity_async.py
b/providers/tests/databricks/hooks/test_databricks_azure_workload_identity_async.py
new file mode 100644
index 00000000000..0a43262acc4
--- /dev/null
+++
b/providers/tests/databricks/hooks/test_databricks_azure_workload_identity_async.py
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import json
+import os
+from unittest import mock
+
+import pytest
+import tenacity
+from azure.core.credentials import AccessToken
+
+from airflow.models import Connection
+from airflow.providers.databricks.hooks.databricks import DatabricksHook
+from airflow.providers.databricks.hooks.databricks_base import
DEFAULT_AZURE_CREDENTIAL_SETTING_KEY
+from airflow.utils.session import provide_session
+
+
+def create_successful_response_mock(content):
+ response = mock.MagicMock()
+ response.json.return_value = content
+ response.status_code = 200
+ return response
+
+
+def create_aad_token_for_resource() -> AccessToken:
+ return AccessToken(expires_on=1575500666, token="sample-token")
+
+
+HOST = "xx.cloud.databricks.com"
+DEFAULT_CONN_ID = "databricks_default"
+DEFAULT_RETRY_NUMBER = 3
+DEFAULT_RETRY_ARGS = dict(
+ wait=tenacity.wait_none(),
+ stop=tenacity.stop_after_attempt(DEFAULT_RETRY_NUMBER),
+)
+
+
[email protected]_test
+class TestDatabricksHookAadTokenWorkloadIdentityAsync:
+ @provide_session
+ def setup_method(self, method, session=None):
+ conn = session.query(Connection).filter(Connection.conn_id ==
DEFAULT_CONN_ID).first()
+ conn.host = HOST
+ conn.extra = json.dumps(
+ {
+ DEFAULT_AZURE_CREDENTIAL_SETTING_KEY: True,
+ }
+ )
+ session.commit()
+
+ @pytest.mark.asyncio
+ @mock.patch(
+ "azure.identity.aio.DefaultAzureCredential.get_token",
return_value=create_aad_token_for_resource()
+ )
+
@mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get")
+ async def test_one(self, requests_mock, get_token_mock: mock.MagicMock):
+ with mock.patch.dict(
+ os.environ,
+ {
+ "AZURE_CLIENT_ID": "fake-client-id",
+ "AZURE_TENANT_ID": "fake-tenant-id",
+ "AZURE_FEDERATED_TOKEN_FILE": "/badpath",
+ "KUBERNETES_SERVICE_HOST": "fakeip",
+ },
+ ):
+
requests_mock.return_value.__aenter__.return_value.json.side_effect =
mock.AsyncMock(
+ side_effect=[{"data": 1}]
+ )
+
+ async with DatabricksHook(retry_args=DEFAULT_RETRY_ARGS) as hook:
+ result = await hook.a_get_run_output(0)
+
+ assert result == {"data": 1}