This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 959e63eb1ee Revert "update AzureBaseHook to return credentials that
supports get_token me…" (#56223)
959e63eb1ee is described below
commit 959e63eb1eed297daa7388be5663b3d66e07ec6f
Author: Jens Scheffler <[email protected]>
AuthorDate: Mon Sep 29 22:45:38 2025 +0200
Revert "update AzureBaseHook to return credentials that supports get_token
me…" (#56223)
This reverts commit d5ded852187a8559aef45692dac1ce9853b91ad5.
---
.../microsoft/azure/docs/connections/azure.rst | 2 -
.../providers/microsoft/azure/hooks/base_azure.py | 92 ++++------------------
.../unit/microsoft/azure/hooks/test_base_azure.py | 90 +--------------------
3 files changed, 17 insertions(+), 167 deletions(-)
diff --git a/providers/microsoft/azure/docs/connections/azure.rst b/providers/microsoft/azure/docs/connections/azure.rst
index abe425067eb..f8d111fd34a 100644
--- a/providers/microsoft/azure/docs/connections/azure.rst
+++ b/providers/microsoft/azure/docs/connections/azure.rst
@@ -74,8 +74,6 @@ Extra (optional)
It specifies the json that contains the authentication information.
* ``managed_identity_client_id``: The client ID of a user-assigned
managed identity. If provided with ``workload_identity_tenant_id``, they'll
pass to DefaultAzureCredential_.
* ``workload_identity_tenant_id``: ID of the application's Microsoft Entra
tenant. Also called its "directory" ID. If provided with
``managed_identity_client_id``, they'll pass to DefaultAzureCredential_.
- * ``use_azure_identity_object``: If set to true, it will use credential of
newer type: ClientSecretCredential or DefaultAzureCredential instead of
ServicePrincipalCredentials or AzureIdentityCredentialAdapter.
- These newer credentials support get_token method which can be used to
generate OAuth token with custom scope.
The entire extra column can be left out to fall back on
DefaultAzureCredential_.
diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/base_azure.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/base_azure.py
index 56933216760..2a59234e0a4 100644
--- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/base_azure.py
+++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/base_azure.py
@@ -16,25 +16,18 @@
# under the License.
from __future__ import annotations
-from typing import TYPE_CHECKING, Any
+from typing import Any
from azure.common.client_factory import get_client_from_auth_file,
get_client_from_json_dict
from azure.common.credentials import ServicePrincipalCredentials
-from azure.identity import ClientSecretCredential, DefaultAzureCredential
from airflow.exceptions import AirflowException
from airflow.providers.microsoft.azure.utils import (
AzureIdentityCredentialAdapter,
add_managed_identity_connection_widgets,
- get_sync_default_azure_credential,
)
from airflow.providers.microsoft.azure.version_compat import BaseHook
-if TYPE_CHECKING:
- from azure.core.credentials import AccessToken
-
- from airflow.sdk import Connection
-
class AzureBaseHook(BaseHook):
"""
@@ -92,7 +85,7 @@ class AzureBaseHook(BaseHook):
},
}
- def __init__(self, sdk_client: Any = None, conn_id: str = "azure_default"):
+ def __init__(self, sdk_client: Any, conn_id: str = "azure_default"):
self.sdk_client = sdk_client
self.conn_id = conn_id
super().__init__()
@@ -103,9 +96,8 @@ class AzureBaseHook(BaseHook):
:return: the authenticated client.
"""
- if not self.sdk_client:
- raise ValueError("`sdk_client` must be provided to AzureBaseHook
to use `get_conn` method.")
conn = self.get_connection(self.conn_id)
+ tenant = conn.extra_dejson.get("tenantId")
subscription_id = conn.extra_dejson.get("subscriptionId")
key_path = conn.extra_dejson.get("key_path")
if key_path:
@@ -119,74 +111,22 @@ class AzureBaseHook(BaseHook):
self.log.info("Getting connection using a JSON config.")
return get_client_from_json_dict(client_class=self.sdk_client,
config_dict=key_json)
- credentials = self.get_credential(conn=conn)
-
- return self.sdk_client(
- credentials=credentials,
- subscription_id=subscription_id,
- )
-
- def get_credential(self, *, conn: Connection | None = None) -> Any:
- """
- Get Azure credential object for the connection.
-
- Azure Identity based credential object (``ClientSecretCredential``,
``DefaultAzureCredential``) can be used to get OAuth token using ``get_token``
method.
- Older Credential objects (``ServicePrincipalCredentials``,
``AzureIdentityCredentialAdapter``) are supported for backward compatibility.
-
- :return: The Azure credential object
- """
- if not conn:
- conn = self.get_connection(self.conn_id)
- tenant = conn.extra_dejson.get("tenantId")
- credential: (
- ServicePrincipalCredentials
- | AzureIdentityCredentialAdapter
- | ClientSecretCredential
- | DefaultAzureCredential
- )
+ credentials: ServicePrincipalCredentials |
AzureIdentityCredentialAdapter
if all([conn.login, conn.password, tenant]):
- credential = self._get_client_secret_credential(conn)
- else:
- credential = self._get_default_azure_credential(conn)
- return credential
-
- def _get_client_secret_credential(self, conn: Connection):
- self.log.info("Getting credentials using specific credentials and
subscription_id.")
- extra_dejson = conn.extra_dejson
- tenant = extra_dejson.get("tenantId")
- use_azure_identity_object =
extra_dejson.get("use_azure_identity_object", False)
- if use_azure_identity_object:
- return ClientSecretCredential(
- client_id=conn.login, # type: ignore[arg-type]
- client_secret=conn.password, # type: ignore[arg-type]
- tenant_id=tenant, # type: ignore[arg-type]
+ self.log.info("Getting connection using specific credentials and
subscription_id.")
+ credentials = ServicePrincipalCredentials(
+ client_id=conn.login, secret=conn.password, tenant=tenant
)
- return ServicePrincipalCredentials(client_id=conn.login,
secret=conn.password, tenant=tenant)
-
- def _get_default_azure_credential(self, conn: Connection):
- self.log.info("Using DefaultAzureCredential as credential")
- extra_dejson = conn.extra_dejson
- managed_identity_client_id =
extra_dejson.get("managed_identity_client_id")
- workload_identity_tenant_id =
extra_dejson.get("workload_identity_tenant_id")
- use_azure_identity_object =
extra_dejson.get("use_azure_identity_object", False)
- if use_azure_identity_object:
- return get_sync_default_azure_credential(
+ else:
+ self.log.info("Using DefaultAzureCredential as credential")
+ managed_identity_client_id =
conn.extra_dejson.get("managed_identity_client_id")
+ workload_identity_tenant_id =
conn.extra_dejson.get("workload_identity_tenant_id")
+ credentials = AzureIdentityCredentialAdapter(
managed_identity_client_id=managed_identity_client_id,
workload_identity_tenant_id=workload_identity_tenant_id,
)
- return AzureIdentityCredentialAdapter(
- managed_identity_client_id=managed_identity_client_id,
- workload_identity_tenant_id=workload_identity_tenant_id,
- )
- def get_token(self, *scopes, **kwargs) -> AccessToken:
- """Request an access token for `scopes`."""
- credential = self.get_credential()
- if isinstance(credential, AzureIdentityCredentialAdapter) or
isinstance(
- credential, AzureIdentityCredentialAdapter
- ):
- raise AttributeError(
- "The azure credential does not support get_token method. "
- "Please set `use_azure_identity_object: True` in the
connection extra field to use credential that support get_token method."
- )
- return credential.get_token(*scopes, **kwargs)
+ return self.sdk_client(
+ credentials=credentials,
+ subscription_id=subscription_id,
+ )
diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_base_azure.py b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_base_azure.py
index 98df1a4c08f..89881eae165 100644
--- a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_base_azure.py
+++ b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_base_azure.py
@@ -31,7 +31,6 @@ if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true":
Connection = MagicMock() # type: ignore[misc]
MODULE = "airflow.providers.microsoft.azure.hooks.base_azure"
-UTILS = "airflow.providers.microsoft.azure.utils"
class TestBaseAzureHook:
@@ -112,7 +111,7 @@ class TestBaseAzureHook:
indirect=True,
)
@patch("azure.common.credentials.ServicePrincipalCredentials")
- @patch(f"{MODULE}.AzureIdentityCredentialAdapter")
+
@patch("airflow.providers.microsoft.azure.hooks.base_azure.AzureIdentityCredentialAdapter")
def test_get_conn_fallback_to_azure_identity_credential_adapter(
self,
mock_credential_adapter,
@@ -134,90 +133,3 @@ class TestBaseAzureHook:
credentials=mock_credential,
subscription_id="subscription_id",
)
-
- @patch(f"{MODULE}.ClientSecretCredential")
- @pytest.mark.parametrize(
- "mocked_connection",
- [
- Connection(
- conn_id="azure_default",
- login="my_login",
- password="my_password",
- extra={"tenantId": "my_tenant", "use_azure_identity_object":
True},
- ),
- ],
- indirect=True,
- )
- def test_get_credential_with_client_secret(self, mock_spc,
mocked_connection):
- mock_spc.return_value = "foo-bar"
- cred = AzureBaseHook().get_credential()
-
- mock_spc.assert_called_once_with(
- client_id=mocked_connection.login,
- client_secret=mocked_connection.password,
- tenant_id=mocked_connection.extra_dejson["tenantId"],
- )
- assert cred == "foo-bar"
-
- @patch(f"{UTILS}.DefaultAzureCredential")
- @pytest.mark.parametrize(
- "mocked_connection",
- [
- Connection(
- conn_id="azure_default",
- extra={"use_azure_identity_object": True},
- ),
- ],
- indirect=True,
- )
- def test_get_credential_with_azure_default_credential(self, mock_spc,
mocked_connection):
- mock_spc.return_value = "foo-bar"
- cred = AzureBaseHook().get_credential()
-
- mock_spc.assert_called_once_with()
- assert cred == "foo-bar"
-
- @patch(f"{UTILS}.DefaultAzureCredential")
- @pytest.mark.parametrize(
- "mocked_connection",
- [
- Connection(
- conn_id="azure_default",
- extra={
- "managed_identity_client_id": "test_client_id",
- "workload_identity_tenant_id": "test_tenant_id",
- "use_azure_identity_object": True,
- },
- ),
- ],
- indirect=True,
- )
- def test_get_credential_with_azure_default_credential_with_extra(self,
mock_spc, mocked_connection):
- mock_spc.return_value = "foo-bar"
- cred = AzureBaseHook().get_credential()
-
- mock_spc.assert_called_once_with(
-
managed_identity_client_id=mocked_connection.extra_dejson.get("managed_identity_client_id"),
-
workload_identity_tenant_id=mocked_connection.extra_dejson.get("workload_identity_tenant_id"),
-
additionally_allowed_tenants=[mocked_connection.extra_dejson.get("workload_identity_tenant_id")],
- )
- assert cred == "foo-bar"
-
- @patch(f"{UTILS}.DefaultAzureCredential")
- @pytest.mark.parametrize(
- "mocked_connection",
- [
- Connection(
- conn_id="azure_default",
- extra={"use_azure_identity_object": True},
- ),
- ],
- indirect=True,
- )
- def test_get_token_with_azure_default_credential(self, mock_spc,
mocked_connection):
- mock_spc.get_token.return_value = "new-token"
- scope = "custom_scope"
- token = AzureBaseHook().get_token(scope)
-
- mock_spc.assert_called_once_with()
- assert token == "new-token"