This is an automated email from the ASF dual-hosted git repository.

dabla pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9928bb3292e Use async versions of CertificateCredential and ClientSecretCredential in KiotaRequestAdapterHook (#68375)
9928bb3292e is described below

commit 9928bb3292e9cda8e54f0cf624061b0798f521d4
Author: David Blain <[email protected]>
AuthorDate: Thu Jun 18 08:54:36 2026 +0200

    Use async versions of CertificateCredential and ClientSecretCredential in KiotaRequestAdapterHook (#68375)
    
    * refactor: Use async versions of CertificateCredential and ClientSecretCredential to avoid blocking the event loop, especially when used concurrently
    
    * refactor: Fix stale cached request adapter causing "HTTP transport has already been closed" errors
    
    * refactor: Invalidate cached request adapters which have closed session in KiotaRequestAdapterHook
---
 .../providers/microsoft/azure/hooks/msgraph.py     |  52 +++++--
 .../unit/microsoft/azure/hooks/test_msgraph.py     | 160 +++++++++++++++++----
 2 files changed, 166 insertions(+), 46 deletions(-)

diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py
index 32e21f61ebe..2f3bf3a4030 100644
--- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py
+++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py
@@ -31,7 +31,7 @@ from typing import TYPE_CHECKING, Any, cast
 from urllib.parse import quote, urljoin, urlparse
 
 import httpx
-from azure.identity import CertificateCredential, ClientSecretCredential
+from azure.identity.aio import CertificateCredential, ClientSecretCredential
 from httpx import AsyncHTTPTransport, Response, Timeout
 from kiota_abstractions.api_error import APIError
 from kiota_abstractions.method import Method
@@ -50,18 +50,16 @@ from msgraph_core._enums import NationalClouds
 
 from airflow.exceptions import AirflowBadRequest, AirflowConfigException, AirflowProviderDeprecationWarning
 from airflow.providers.common.compat.connection import get_async_connection
-from airflow.providers.common.compat.sdk import AirflowException, AirflowNotFoundException, BaseHook
+from airflow.providers.common.compat.sdk import AirflowException, AirflowNotFoundException, BaseHook, redact
 
 if TYPE_CHECKING:
-    from azure.identity._internal.client_credential_base import ClientCredentialBase
+    from azure.core.credentials_async import AsyncTokenCredential
     from kiota_abstractions.request_adapter import RequestAdapter
     from kiota_abstractions.response_handler import NativeResponseType
     from kiota_abstractions.serialization import ParsableFactory
 
     from airflow.providers.common.compat.sdk import Connection
 
-from airflow.providers.common.compat.sdk import redact
-
 PaginationCallable = Callable[..., tuple[str, dict[str, Any] | None]]
 
 
@@ -366,7 +364,6 @@ class KiotaRequestAdapterHook(BaseHook):
             http_client=http_client,
             base_url=base_url,
         )
-        self.cached_request_adapters[self.conn_id] = (api_version, request_adapter)
         return api_version, request_adapter
 
     def get_conn(self) -> RequestAdapter:
@@ -374,7 +371,7 @@ class KiotaRequestAdapterHook(BaseHook):
         Initiate a new RequestAdapter connection.
 
         .. warning::
-           This method is deprecated.
+           This method is deprecated. Use :meth:`get_async_conn` instead.
         """
         if not self.conn_id:
             raise AirflowException("Failed to create the KiotaRequestAdapterHook. No conn_id provided!")
@@ -390,9 +387,15 @@ class KiotaRequestAdapterHook(BaseHook):
         if not request_adapter:
             connection = self.get_connection(conn_id=self.conn_id)
             api_version, request_adapter = self._build_request_adapter(connection)
+            self.cached_request_adapters[self.conn_id] = (api_version, request_adapter)
         self.api_version = api_version
         return request_adapter
 
+    @staticmethod
+    def _is_http_client_closed(request_adapter: RequestAdapter) -> bool:
+        """Return True when the underlying httpx AsyncClient has been closed."""
+        return cast("HttpxRequestAdapter", request_adapter)._http_client.is_closed
+
     async def get_async_conn(self) -> RequestAdapter:
         """Initiate a new RequestAdapter connection asynchronously."""
         if not self.conn_id:
@@ -400,9 +403,19 @@ class KiotaRequestAdapterHook(BaseHook):
 
         api_version, request_adapter = 
self.cached_request_adapters.get(self.conn_id, (None, None))
 
+        if request_adapter and self._is_http_client_closed(request_adapter):
+            self.log.warning(
+                "Cached request adapter for conn_id '%s' has a closed HTTP client. Rebuilding.",
+                self.conn_id,
+            )
+            self.cached_request_adapters.pop(self.conn_id, None)
+            request_adapter = None
+
         if not request_adapter:
             connection = await get_async_connection(conn_id=self.conn_id)
             api_version, request_adapter = self._build_request_adapter(connection)
+            self.cached_request_adapters[self.conn_id] = (api_version, request_adapter)
+
         self.api_version = api_version
         return request_adapter
 
@@ -433,7 +446,7 @@ class KiotaRequestAdapterHook(BaseHook):
         authority: str | None,
         verify: bool,
         proxies: dict | None,
-    ) -> ClientCredentialBase:
+    ) -> AsyncTokenCredential:
         tenant_id = config.get("tenant_id") or config.get("tenantId")
         certificate_path = config.get("certificate_path")
         certificate_data = config.get("certificate_data")
@@ -582,16 +595,25 @@ class KiotaRequestAdapterHook(BaseHook):
     async def send_request(self, request_info: RequestInformation, response_type: str | None = None):
         conn = await self.get_async_conn()
 
-        if response_type:
-            return await conn.send_primitive_async(
+        try:
+            if response_type:
+                return await conn.send_primitive_async(
+                    request_info=request_info,
+                    response_type=response_type,
+                    error_map=self.error_mapping(),
+                )
+            return await conn.send_no_response_content_async(
                 request_info=request_info,
-                response_type=response_type,
                 error_map=self.error_mapping(),
             )
-        return await conn.send_no_response_content_async(
-            request_info=request_info,
-            error_map=self.error_mapping(),
-        )
+        except Exception as e:
+            self.log.warning(
+                "Request failed for conn_id '%s': %s. Invalidating cached request adapter.",
+                self.conn_id,
+                e,
+            )
+            self.cached_request_adapters.pop(self.conn_id, None)
+            raise
 
     def request_information(
         self,
diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py
index 96443b5a67d..dab96656d43 100644
--- a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py
+++ b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py
@@ -18,10 +18,11 @@ from __future__ import annotations
 
 import asyncio
 import inspect
+from contextlib import AbstractAsyncContextManager
 from json import JSONDecodeError
 from os.path import dirname
-from typing import TYPE_CHECKING, cast
-from unittest.mock import Mock, patch
+from typing import cast
+from unittest.mock import AsyncMock, Mock, patch
 
 import pytest
 from httpx import Response
@@ -52,31 +53,8 @@ from unit.microsoft.azure.test_utils import (
     patch_hook_and_request_adapter,
 )
 
-if TYPE_CHECKING:
-    from azure.identity._internal.msal_credentials import MsalCredential
-    from kiota_abstractions.authentication import BaseBearerTokenAuthenticationProvider
-    from kiota_abstractions.request_adapter import RequestAdapter
-    from kiota_authentication_azure.azure_identity_access_token_provider import (
-        AzureIdentityAccessTokenProvider,
-    )
-
 
 class TestKiotaRequestAdapterHook:
-    @staticmethod
-    def assert_tenant_id(request_adapter: RequestAdapter, expected_tenant_id: str):
-        adapter: HttpxRequestAdapter = cast("HttpxRequestAdapter", request_adapter)
-        auth_provider: BaseBearerTokenAuthenticationProvider = cast(
-            "BaseBearerTokenAuthenticationProvider",
-            adapter._authentication_provider,
-        )
-        access_token_provider: AzureIdentityAccessTokenProvider = cast(
-            "AzureIdentityAccessTokenProvider",
-            auth_provider.access_token_provider,
-        )
-        credentials: MsalCredential = cast("MsalCredential", access_token_provider._credentials)
-        tenant_id = credentials._tenant_id
-        assert tenant_id == expected_tenant_id
-
     def test_get_conn(self):
         with patch_hook():
             hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
@@ -276,10 +254,15 @@ class TestKiotaRequestAdapterHook:
     @pytest.mark.asyncio
     async def test_tenant_id(self):
         with patch_hook():
-            hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
-            actual = await hook.get_async_conn()
+            with patch(
+                "airflow.providers.microsoft.azure.hooks.msgraph.ClientSecretCredential",
+                autospec=True,
+            ) as mock_credential_cls:
+                hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+                await hook.get_async_conn()
 
-            self.assert_tenant_id(actual, "tenant-id")
+                mock_credential_cls.assert_called_once()
+                assert mock_credential_cls.call_args.kwargs.get("tenant_id") == "tenant-id"
 
     @pytest.mark.asyncio
     async def test_azure_tenant_id(self):
@@ -289,10 +272,15 @@ class TestKiotaRequestAdapterHook:
                 azure_tenant_id="azure-tenant-id",
             )
         ):
-            hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
-            actual = await hook.get_async_conn()
+            with patch(
+                "airflow.providers.microsoft.azure.hooks.msgraph.ClientSecretCredential",
+                autospec=True,
+            ) as mock_credential_cls:
+                hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+                await hook.get_async_conn()
 
-            self.assert_tenant_id(actual, "azure-tenant-id")
+                mock_credential_cls.assert_called_once()
+                assert mock_credential_cls.call_args.kwargs.get("tenant_id") == "azure-tenant-id"
 
     @pytest.mark.asyncio
     async def test_proxies(self):
@@ -472,6 +460,116 @@ class TestKiotaRequestAdapterHook:
 
         assert result == proxies
 
+    def test_get_credentials_returns_async_client_secret_credential(self):
+        """get_credentials must return an async context manager (azure.identity.aio credential)."""
+        hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+        config = {"tenant_id": "tenant-id"}
+
+        credentials = hook.get_credentials(
+            login="client_id",
+            password="client_secret",
+            config=config,
+            authority=None,
+            verify=True,
+            proxies=None,
+        )
+
+        assert isinstance(credentials, AbstractAsyncContextManager)
+
+    def test_get_credentials_returns_async_certificate_credential(self):
+        """get_credentials must return an async context manager when certificate_data is set."""
+        import datetime
+
+        from cryptography import x509
+        from cryptography.hazmat.primitives import hashes, serialization
+        from cryptography.hazmat.primitives.asymmetric import rsa
+        from cryptography.x509.oid import NameOID
+
+        private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
+        name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "test")])
+        cert = (
+            x509.CertificateBuilder()
+            .subject_name(name)
+            .issuer_name(name)
+            .public_key(private_key.public_key())
+            .serial_number(x509.random_serial_number())
+            .not_valid_before(datetime.datetime.now(datetime.timezone.utc))
+            .not_valid_after(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1))
+            .sign(private_key, hashes.SHA256())
+        )
+        pem = private_key.private_bytes(
+            serialization.Encoding.PEM,
+            serialization.PrivateFormat.TraditionalOpenSSL,
+            serialization.NoEncryption(),
+        ) + cert.public_bytes(serialization.Encoding.PEM)
+
+        hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+        config = {
+            "tenant_id": "tenant-id",
+            "certificate_data": pem.decode(),
+        }
+
+        credentials = hook.get_credentials(
+            login="client_id",
+            password=None,
+            config=config,
+            authority=None,
+            verify=True,
+            proxies=None,
+        )
+
+        assert isinstance(credentials, AbstractAsyncContextManager)
+
+    @pytest.mark.asyncio
+    async def test_get_async_conn_uses_async_credentials(self):
+        """get_async_conn must build a request adapter backed by async credentials."""
+        with patch_hook():
+            hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+            request_adapter = await hook.get_async_conn()
+
+            adapter: HttpxRequestAdapter = cast("HttpxRequestAdapter", request_adapter)
+            # Reach into the auth provider chain to retrieve the underlying credential object.
+            access_token_provider = adapter._authentication_provider.access_token_provider
+            credentials = access_token_provider._credentials
+
+            assert isinstance(credentials, AbstractAsyncContextManager)
+
+    @pytest.mark.asyncio
+    async def 
test_get_async_conn_rebuilds_adapter_when_http_client_is_closed(self):
+        """get_async_conn evicts and rebuilds the adapter when the cached HTTP client is already closed."""
+        with patch_hook():
+            hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+
+            stale_adapter = Mock(spec=HttpxRequestAdapter)
+            stale_adapter._http_client = Mock(is_closed=True)
+            hook.cached_request_adapters[hook.conn_id] = (hook.api_version, stale_adapter)
+
+            fresh_adapter = Mock(spec=HttpxRequestAdapter)
+            fresh_adapter._http_client = Mock(is_closed=False)
+
+            with patch.object(hook, "_build_request_adapter", return_value=("v1.0", fresh_adapter)):
+                result = await hook.get_async_conn()
+
+            assert result is fresh_adapter
+            assert hook.cached_request_adapters[hook.conn_id] == ("v1.0", fresh_adapter)
+
+    @pytest.mark.asyncio
+    async def 
test_send_request_invalidates_cache_and_raises_on_any_error(self):
+        """send_request evicts the cached adapter and re-raises on any request error."""
+        with patch_hook():
+            hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+
+            adapter = Mock(spec=HttpxRequestAdapter)
+            adapter._http_client = Mock(is_closed=False)
+            adapter.send_no_response_content_async = AsyncMock(side_effect=RuntimeError("some error"))
+            hook.cached_request_adapters[hook.conn_id] = (hook.api_version, adapter)
+
+            with pytest.raises(RuntimeError, match="some error"):
+                await hook.run(url="users")
+
+            adapter.send_no_response_content_async.assert_called_once()
+            assert hook.conn_id not in hook.cached_request_adapters
+
 
 class TestKiotaRequestAdapterHookProtocol:
     """Test protocol handling in KiotaRequestAdapterHook."""

Reply via email to