This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 959e63eb1ee Revert "update AzureBaseHook to return credentials that supports get_token me…" (#56223)
959e63eb1ee is described below

commit 959e63eb1eed297daa7388be5663b3d66e07ec6f
Author: Jens Scheffler <[email protected]>
AuthorDate: Mon Sep 29 22:45:38 2025 +0200

    Revert "update AzureBaseHook to return credentials that supports get_token me…" (#56223)
    
    This reverts commit d5ded852187a8559aef45692dac1ce9853b91ad5.
---
 .../microsoft/azure/docs/connections/azure.rst     |  2 -
 .../providers/microsoft/azure/hooks/base_azure.py  | 92 ++++------------------
 .../unit/microsoft/azure/hooks/test_base_azure.py  | 90 +--------------------
 3 files changed, 17 insertions(+), 167 deletions(-)

diff --git a/providers/microsoft/azure/docs/connections/azure.rst b/providers/microsoft/azure/docs/connections/azure.rst
index abe425067eb..f8d111fd34a 100644
--- a/providers/microsoft/azure/docs/connections/azure.rst
+++ b/providers/microsoft/azure/docs/connections/azure.rst
@@ -74,8 +74,6 @@ Extra (optional)
       It specifies the json that contains the authentication information.
     * ``managed_identity_client_id``:  The client ID of a user-assigned 
managed identity. If provided with ``workload_identity_tenant_id``, they'll 
pass to DefaultAzureCredential_.
     * ``workload_identity_tenant_id``: ID of the application's Microsoft Entra 
tenant. Also called its "directory" ID. If provided with 
``managed_identity_client_id``, they'll pass to DefaultAzureCredential_.
-    * ``use_azure_identity_object``: If set to true, it will use credential of 
newer type: ClientSecretCredential or DefaultAzureCredential instead of 
ServicePrincipalCredentials or AzureIdentityCredentialAdapter.
-      These newer credentials support get_token method which can be used to 
generate OAuth token with custom scope.
 
     The entire extra column can be left out to fall back on 
DefaultAzureCredential_.
 
diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/base_azure.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/base_azure.py
index 56933216760..2a59234e0a4 100644
--- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/base_azure.py
+++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/base_azure.py
@@ -16,25 +16,18 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any
+from typing import Any
 
 from azure.common.client_factory import get_client_from_auth_file, 
get_client_from_json_dict
 from azure.common.credentials import ServicePrincipalCredentials
-from azure.identity import ClientSecretCredential, DefaultAzureCredential
 
 from airflow.exceptions import AirflowException
 from airflow.providers.microsoft.azure.utils import (
     AzureIdentityCredentialAdapter,
     add_managed_identity_connection_widgets,
-    get_sync_default_azure_credential,
 )
 from airflow.providers.microsoft.azure.version_compat import BaseHook
 
-if TYPE_CHECKING:
-    from azure.core.credentials import AccessToken
-
-    from airflow.sdk import Connection
-
 
 class AzureBaseHook(BaseHook):
     """
@@ -92,7 +85,7 @@ class AzureBaseHook(BaseHook):
             },
         }
 
-    def __init__(self, sdk_client: Any = None, conn_id: str = "azure_default"):
+    def __init__(self, sdk_client: Any, conn_id: str = "azure_default"):
         self.sdk_client = sdk_client
         self.conn_id = conn_id
         super().__init__()
@@ -103,9 +96,8 @@ class AzureBaseHook(BaseHook):
 
         :return: the authenticated client.
         """
-        if not self.sdk_client:
-            raise ValueError("`sdk_client` must be provided to AzureBaseHook 
to use `get_conn` method.")
         conn = self.get_connection(self.conn_id)
+        tenant = conn.extra_dejson.get("tenantId")
         subscription_id = conn.extra_dejson.get("subscriptionId")
         key_path = conn.extra_dejson.get("key_path")
         if key_path:
@@ -119,74 +111,22 @@ class AzureBaseHook(BaseHook):
             self.log.info("Getting connection using a JSON config.")
             return get_client_from_json_dict(client_class=self.sdk_client, 
config_dict=key_json)
 
-        credentials = self.get_credential(conn=conn)
-
-        return self.sdk_client(
-            credentials=credentials,
-            subscription_id=subscription_id,
-        )
-
-    def get_credential(self, *, conn: Connection | None = None) -> Any:
-        """
-        Get Azure credential object for the connection.
-
-        Azure Identity based credential object (``ClientSecretCredential``, 
``DefaultAzureCredential``) can be used to get OAuth token using ``get_token`` 
method.
-        Older Credential objects (``ServicePrincipalCredentials``, 
``AzureIdentityCredentialAdapter``) are supported for backward compatibility.
-
-        :return: The Azure credential object
-        """
-        if not conn:
-            conn = self.get_connection(self.conn_id)
-        tenant = conn.extra_dejson.get("tenantId")
-        credential: (
-            ServicePrincipalCredentials
-            | AzureIdentityCredentialAdapter
-            | ClientSecretCredential
-            | DefaultAzureCredential
-        )
+        credentials: ServicePrincipalCredentials | 
AzureIdentityCredentialAdapter
         if all([conn.login, conn.password, tenant]):
-            credential = self._get_client_secret_credential(conn)
-        else:
-            credential = self._get_default_azure_credential(conn)
-        return credential
-
-    def _get_client_secret_credential(self, conn: Connection):
-        self.log.info("Getting credentials using specific credentials and 
subscription_id.")
-        extra_dejson = conn.extra_dejson
-        tenant = extra_dejson.get("tenantId")
-        use_azure_identity_object = 
extra_dejson.get("use_azure_identity_object", False)
-        if use_azure_identity_object:
-            return ClientSecretCredential(
-                client_id=conn.login,  # type: ignore[arg-type]
-                client_secret=conn.password,  # type: ignore[arg-type]
-                tenant_id=tenant,  # type: ignore[arg-type]
+            self.log.info("Getting connection using specific credentials and 
subscription_id.")
+            credentials = ServicePrincipalCredentials(
+                client_id=conn.login, secret=conn.password, tenant=tenant
             )
-        return ServicePrincipalCredentials(client_id=conn.login, 
secret=conn.password, tenant=tenant)
-
-    def _get_default_azure_credential(self, conn: Connection):
-        self.log.info("Using DefaultAzureCredential as credential")
-        extra_dejson = conn.extra_dejson
-        managed_identity_client_id = 
extra_dejson.get("managed_identity_client_id")
-        workload_identity_tenant_id = 
extra_dejson.get("workload_identity_tenant_id")
-        use_azure_identity_object = 
extra_dejson.get("use_azure_identity_object", False)
-        if use_azure_identity_object:
-            return get_sync_default_azure_credential(
+        else:
+            self.log.info("Using DefaultAzureCredential as credential")
+            managed_identity_client_id = 
conn.extra_dejson.get("managed_identity_client_id")
+            workload_identity_tenant_id = 
conn.extra_dejson.get("workload_identity_tenant_id")
+            credentials = AzureIdentityCredentialAdapter(
                 managed_identity_client_id=managed_identity_client_id,
                 workload_identity_tenant_id=workload_identity_tenant_id,
             )
-        return AzureIdentityCredentialAdapter(
-            managed_identity_client_id=managed_identity_client_id,
-            workload_identity_tenant_id=workload_identity_tenant_id,
-        )
 
-    def get_token(self, *scopes, **kwargs) -> AccessToken:
-        """Request an access token for `scopes`."""
-        credential = self.get_credential()
-        if isinstance(credential, AzureIdentityCredentialAdapter) or 
isinstance(
-            credential, AzureIdentityCredentialAdapter
-        ):
-            raise AttributeError(
-                "The azure credential does not support get_token method. "
-                "Please set `use_azure_identity_object: True` in the 
connection extra field to use credential that support get_token method."
-            )
-        return credential.get_token(*scopes, **kwargs)
+        return self.sdk_client(
+            credentials=credentials,
+            subscription_id=subscription_id,
+        )
diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_base_azure.py b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_base_azure.py
index 98df1a4c08f..89881eae165 100644
--- a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_base_azure.py
+++ b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_base_azure.py
@@ -31,7 +31,6 @@ if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true":
     Connection = MagicMock()  # type: ignore[misc]
 
 MODULE = "airflow.providers.microsoft.azure.hooks.base_azure"
-UTILS = "airflow.providers.microsoft.azure.utils"
 
 
 class TestBaseAzureHook:
@@ -112,7 +111,7 @@ class TestBaseAzureHook:
         indirect=True,
     )
     @patch("azure.common.credentials.ServicePrincipalCredentials")
-    @patch(f"{MODULE}.AzureIdentityCredentialAdapter")
+    
@patch("airflow.providers.microsoft.azure.hooks.base_azure.AzureIdentityCredentialAdapter")
     def test_get_conn_fallback_to_azure_identity_credential_adapter(
         self,
         mock_credential_adapter,
@@ -134,90 +133,3 @@ class TestBaseAzureHook:
             credentials=mock_credential,
             subscription_id="subscription_id",
         )
-
-    @patch(f"{MODULE}.ClientSecretCredential")
-    @pytest.mark.parametrize(
-        "mocked_connection",
-        [
-            Connection(
-                conn_id="azure_default",
-                login="my_login",
-                password="my_password",
-                extra={"tenantId": "my_tenant", "use_azure_identity_object": 
True},
-            ),
-        ],
-        indirect=True,
-    )
-    def test_get_credential_with_client_secret(self, mock_spc, 
mocked_connection):
-        mock_spc.return_value = "foo-bar"
-        cred = AzureBaseHook().get_credential()
-
-        mock_spc.assert_called_once_with(
-            client_id=mocked_connection.login,
-            client_secret=mocked_connection.password,
-            tenant_id=mocked_connection.extra_dejson["tenantId"],
-        )
-        assert cred == "foo-bar"
-
-    @patch(f"{UTILS}.DefaultAzureCredential")
-    @pytest.mark.parametrize(
-        "mocked_connection",
-        [
-            Connection(
-                conn_id="azure_default",
-                extra={"use_azure_identity_object": True},
-            ),
-        ],
-        indirect=True,
-    )
-    def test_get_credential_with_azure_default_credential(self, mock_spc, 
mocked_connection):
-        mock_spc.return_value = "foo-bar"
-        cred = AzureBaseHook().get_credential()
-
-        mock_spc.assert_called_once_with()
-        assert cred == "foo-bar"
-
-    @patch(f"{UTILS}.DefaultAzureCredential")
-    @pytest.mark.parametrize(
-        "mocked_connection",
-        [
-            Connection(
-                conn_id="azure_default",
-                extra={
-                    "managed_identity_client_id": "test_client_id",
-                    "workload_identity_tenant_id": "test_tenant_id",
-                    "use_azure_identity_object": True,
-                },
-            ),
-        ],
-        indirect=True,
-    )
-    def test_get_credential_with_azure_default_credential_with_extra(self, 
mock_spc, mocked_connection):
-        mock_spc.return_value = "foo-bar"
-        cred = AzureBaseHook().get_credential()
-
-        mock_spc.assert_called_once_with(
-            
managed_identity_client_id=mocked_connection.extra_dejson.get("managed_identity_client_id"),
-            
workload_identity_tenant_id=mocked_connection.extra_dejson.get("workload_identity_tenant_id"),
-            
additionally_allowed_tenants=[mocked_connection.extra_dejson.get("workload_identity_tenant_id")],
-        )
-        assert cred == "foo-bar"
-
-    @patch(f"{UTILS}.DefaultAzureCredential")
-    @pytest.mark.parametrize(
-        "mocked_connection",
-        [
-            Connection(
-                conn_id="azure_default",
-                extra={"use_azure_identity_object": True},
-            ),
-        ],
-        indirect=True,
-    )
-    def test_get_token_with_azure_default_credential(self, mock_spc, 
mocked_connection):
-        mock_spc.get_token.return_value = "new-token"
-        scope = "custom_scope"
-        token = AzureBaseHook().get_token(scope)
-
-        mock_spc.assert_called_once_with()
-        assert token == "new-token"

Reply via email to