This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4ea1ebf56c8 Added paginated_run method to KiotaRequestAdapterHook in MSGraph (#57536)
4ea1ebf56c8 is described below

commit 4ea1ebf56c8417bebe989891fd57b48a815dcf02
Author: David Blain <[email protected]>
AuthorDate: Mon Dec 8 20:35:18 2025 +0000

    Added paginated_run method to KiotaRequestAdapterHook in MSGraph (#57536)
    
    * refactor: Added paginated_run method to KiotaRequestAdapterHook
    
    * refactor: Make execution of pagination function dynamic in KiotaRequestAdapterHook
    
    * refactor: Refactored PaginationCallable as protocol in KiotaRequestAdapterHook
    
    * Revert "refactor: Refactored PaginationCallable as protocol in KiotaRequestAdapterHook"
    
    This reverts commit 1885132835a1523d9a4cf800c637c40cc928936c.
    
    * refactor: Ignore type with execute_callable
    
    * refactor: Fixed default_pagination in MSGraphAsyncOperator
    
    * fix: url of default_pagination should be optional
    
    * fix: Fixed unit test related to pagination with operator
    
    * Revert "fix: Fixed unit test related to pagination with operator"
    
    This reverts commit 2ff2fa181486d44bde67ad80de47cc6877cb1030.
    
    * refactor: Replace default_pagination back to static paginate method in MSGraphAsyncOperator
    
    * refactor: Allow overriding proxy settings in MSGraph hook with an empty dict so that proxy settings from the connection are ignored
    
    * refactor: Reformatted tests
    
    * refactor: Make sure host starts with http or https
    
    * refactor: Fixed typing
    
    * refactor: Fixed proxy tests
    
    * refactor: Accept warnings raised by get_conn method
    
    * refactor: Reformatted proxies tests
---
 .../providers/microsoft/azure/hooks/msgraph.py     | 159 +++++++++++++++++----
 .../providers/microsoft/azure/operators/msgraph.py |  21 +--
 .../unit/microsoft/azure/hooks/test_msgraph.py     |  99 +++++++++++++
 3 files changed, 241 insertions(+), 38 deletions(-)

diff --git 
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py
 
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py
index 0fdf426e477..84698c3d6a4 100644
--- 
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py
+++ 
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py
@@ -18,9 +18,11 @@
 from __future__ import annotations
 
 import asyncio
+import inspect
 import json
 import warnings
 from ast import literal_eval
+from collections.abc import Callable
 from contextlib import suppress
 from http import HTTPStatus
 from io import BytesIO
@@ -58,6 +60,32 @@ if TYPE_CHECKING:
     from airflow.providers.common.compat.sdk import Connection
 
 
+PaginationCallable = Callable[..., tuple[str, dict[str, Any] | None]]
+
+
+def execute_callable(func: Callable, *args: Any, **kwargs: Any) -> Any:
+    """Dynamically call a function by matching its signature to provided 
args/kwargs."""
+    sig = inspect.signature(func)
+    accepts_kwargs = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in 
sig.parameters.values())
+
+    if not accepts_kwargs:
+        # Only pass arguments the function explicitly declares
+        filtered_kwargs = {k: v for k, v in kwargs.items() if k in 
sig.parameters}
+    else:
+        filtered_kwargs = kwargs
+
+    try:
+        sig.bind(*args, **filtered_kwargs)
+    except TypeError as err:
+        raise TypeError(
+            f"Failed to bind arguments to function {func.__name__}: {err}\n"
+            f"Expected parameters: {list(sig.parameters.keys())}\n"
+            f"Provided kwargs: {list(kwargs.keys())}"
+        ) from err
+
+    return func(*args, **filtered_kwargs)
+
+
 class DefaultResponseHandler(ResponseHandler):
     """DefaultResponseHandler returns JSON payload or content in bytes or 
response headers."""
 
@@ -122,7 +150,7 @@ class KiotaRequestAdapterHook(BaseHook):
         conn_id: str = default_conn_name,
         timeout: float | None = None,
         proxies: dict | None = None,
-        host: str = NationalClouds.Global.value,
+        host: str | None = None,
         scopes: str | list[str] | None = None,
         api_version: APIVersion | str | None = None,
     ):
@@ -198,8 +226,12 @@ class KiotaRequestAdapterHook(BaseHook):
         )  # type: ignore
 
     def get_host(self, connection: Connection) -> str:
-        if connection.schema and connection.host:
-            return f"{connection.schema}://{connection.host}"
+        if not self.host:
+            if connection.schema and connection.host:
+                return f"{connection.schema}://{connection.host}"
+            return NationalClouds.Global.value
+        if not self.host.startswith("http://";) or not 
self.host.startswith("https://";):
+            return f"{connection.schema}://{self.host}"
         return self.host
 
     def get_base_url(self, host: str, api_version: str, config: dict) -> str:
@@ -216,7 +248,7 @@ class KiotaRequestAdapterHook(BaseHook):
         return url
 
     @classmethod
-    def to_httpx_proxies(cls, proxies: dict) -> dict:
+    def to_httpx_proxies(cls, proxies: dict | None) -> dict | None:
         if proxies:
             proxies = proxies.copy()
             if proxies.get("http"):
@@ -226,9 +258,10 @@ class KiotaRequestAdapterHook(BaseHook):
             if proxies.get("no"):
                 for url in proxies.pop("no", "").split(","):
                     proxies[cls.format_no_proxy_url(url.strip())] = None
-        return proxies
+            return proxies
+        return None
 
-    def to_msal_proxies(self, authority: str | None, proxies: dict) -> dict | 
None:
+    def to_msal_proxies(self, authority: str | None, proxies: dict | None) -> 
dict | None:
         self.log.debug("authority: %s", authority)
         if authority and proxies:
             no_proxies = proxies.get("no")
@@ -240,7 +273,8 @@ class KiotaRequestAdapterHook(BaseHook):
                     self.log.debug("domain_name: %s", domain_name)
                     if authority.endswith(domain_name):
                         return None
-        return proxies
+            return proxies
+        return None
 
     def _build_request_adapter(self, connection) -> tuple[str, RequestAdapter]:
         client_id = connection.login
@@ -361,22 +395,24 @@ class KiotaRequestAdapterHook(BaseHook):
         self.api_version = api_version
         return request_adapter
 
-    def get_proxies(self, config: dict) -> dict:
-        proxies = self.proxies or config.get("proxies", {})
-        if isinstance(proxies, str):
-            # TODO: Once provider depends on Airflow 2.10 or higher code below 
won't be needed anymore as
-            #       we could then use the get_extra_dejson method on the 
connection which deserializes
-            #       nested json. Make sure to use 
connection.get_extra_dejson(nested=True) instead of
-            #       connection.extra_dejson.
-            with suppress(JSONDecodeError):
-                proxies = json.loads(proxies)
-            with suppress(Exception):
-                proxies = literal_eval(proxies)
-        if not isinstance(proxies, dict):
-            raise AirflowConfigException(
-                f"Proxies must be of type dict, got {type(proxies).__name__} 
instead!"
-            )
-        return proxies
+    def get_proxies(self, config: dict) -> dict | None:
+        proxies = self.proxies if self.proxies is not None else 
config.get("proxies", {})
+        if proxies:
+            if isinstance(proxies, str):
+                # TODO: Once provider depends on Airflow 2.10 or higher code 
below won't be needed anymore as
+                #       we could then use the get_extra_dejson method on the 
connection which deserializes
+                #       nested json. Make sure to use 
connection.get_extra_dejson(nested=True) instead of
+                #       connection.extra_dejson.
+                with suppress(JSONDecodeError):
+                    proxies = json.loads(proxies)
+                with suppress(Exception):
+                    proxies = literal_eval(proxies)
+            if not isinstance(proxies, dict):
+                raise AirflowConfigException(
+                    f"Proxies must be of type dict, got 
{type(proxies).__name__} instead!"
+                )
+            return proxies
+        return None
 
     def get_credentials(
         self,
@@ -385,7 +421,7 @@ class KiotaRequestAdapterHook(BaseHook):
         config,
         authority: str | None,
         verify: bool,
-        proxies: dict,
+        proxies: dict | None,
     ) -> ClientCredentialBase:
         tenant_id = config.get("tenant_id") or config.get("tenantId")
         certificate_path = config.get("certificate_path")
@@ -428,6 +464,27 @@ class KiotaRequestAdapterHook(BaseHook):
         except Exception as e:
             return False, str(e)
 
+    @staticmethod
+    def default_pagination(
+        response: dict,
+        url: str | None = None,
+        query_parameters: dict[str, Any] | None = None,
+        responses: Callable[[], list[dict[str, Any]] | None] = lambda: [],
+    ) -> tuple[Any, dict[str, Any] | None]:
+        if isinstance(response, dict):
+            odata_count = response.get("@odata.count")
+            if odata_count and query_parameters:
+                top = query_parameters.get("$top")
+
+                if top and odata_count:
+                    if len(response.get("value", [])) == top:
+                        results = responses()
+                        skip = sum([len(result["value"]) for result in 
results]) + top if results else top  # type: ignore
+                        query_parameters["$skip"] = skip
+                        return url, query_parameters
+            return response.get("@odata.nextLink"), query_parameters
+        return None, query_parameters
+
     async def run(
         self,
         url: str = "",
@@ -457,6 +514,60 @@ class KiotaRequestAdapterHook(BaseHook):
 
         return response
 
+    async def paginated_run(
+        self,
+        url: str = "",
+        response_type: str | None = None,
+        path_parameters: dict[str, Any] | None = None,
+        method: str = "GET",
+        query_parameters: dict[str, Any] | None = None,
+        headers: dict[str, str] | None = None,
+        data: dict[str, Any] | str | BytesIO | None = None,
+        pagination_function: PaginationCallable | None = None,
+    ):
+        if pagination_function is None:
+            pagination_function = self.default_pagination
+
+        responses: list[dict] = []
+
+        async def run(
+            url: str = "",
+            query_parameters: dict[str, Any] | None = None,
+        ):
+            while url:
+                response = await self.run(
+                    url=url,
+                    response_type=response_type,
+                    path_parameters=path_parameters,
+                    method=method,
+                    query_parameters=query_parameters,
+                    headers=headers,
+                    data=data,
+                )
+
+                if response:
+                    responses.append(response)
+
+                    if pagination_function:
+                        url, query_parameters = execute_callable(
+                            pagination_function,
+                            response=response,
+                            url=url,
+                            response_type=response_type,
+                            path_parameters=path_parameters,
+                            method=method,
+                            query_parameters=query_parameters,
+                            headers=headers,
+                            data=data,
+                            responses=lambda: responses,
+                        )
+                else:
+                    break
+
+        await run(url=url, query_parameters=query_parameters)
+
+        return responses
+
     async def send_request(self, request_info: RequestInformation, 
response_type: str | None = None):
         conn = await self.get_async_conn()
 
diff --git 
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py
 
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py
index ea7197e8aa8..2449ac1a1fd 100644
--- 
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py
+++ 
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py
@@ -20,7 +20,6 @@ from __future__ import annotations
 import warnings
 from collections.abc import Callable, Sequence
 from contextlib import suppress
-from copy import deepcopy
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -247,7 +246,7 @@ class MSGraphAsyncOperator(BaseOperator):
     @classmethod
     def append_result(
         cls,
-        results: Any,
+        results: list[Any],
         result: Any,
         append_result_as_list_if_absent: bool = False,
     ) -> list[Any]:
@@ -312,18 +311,12 @@ class MSGraphAsyncOperator(BaseOperator):
     def paginate(
         operator: MSGraphAsyncOperator, response: dict, **context
     ) -> tuple[Any, dict[str, Any] | None]:
-        odata_count = response.get("@odata.count")
-        if odata_count and operator.query_parameters:
-            query_parameters = deepcopy(operator.query_parameters)
-            top = query_parameters.get("$top")
-
-            if top and odata_count:
-                if len(response.get("value", [])) == top and context:
-                    results = operator.pull_xcom(context)
-                    skip = sum([len(result["value"]) for result in results]) + 
top if results else top
-                    query_parameters["$skip"] = skip
-                    return operator.url, query_parameters
-        return response.get("@odata.nextLink"), operator.query_parameters
+        return KiotaRequestAdapterHook.default_pagination(
+            response=response,
+            url=operator.url,
+            query_parameters=operator.query_parameters,
+            responses=lambda: operator.pull_xcom(context),
+        )
 
     def trigger_next_link(self, response, method_name: str, context: Context) 
-> None:
         if isinstance(response, dict):
diff --git 
a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py 
b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py
index 1162f506745..f4d8bba7d8f 100644
--- a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py
+++ b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py
@@ -38,6 +38,7 @@ from airflow.providers.common.compat.sdk import 
AirflowException, AirflowNotFoun
 from airflow.providers.microsoft.azure.hooks.msgraph import (
     DefaultResponseHandler,
     KiotaRequestAdapterHook,
+    execute_callable,
 )
 
 from tests_common.test_utils.file_loading import load_file_from_resources, 
load_json_from_resources
@@ -48,6 +49,7 @@ from unit.microsoft.azure.test_utils import (
     mock_json_response,
     mock_response,
     patch_hook,
+    patch_hook_and_request_adapter,
 )
 
 if TYPE_CHECKING:
@@ -232,6 +234,45 @@ class TestKiotaRequestAdapterHook:
 
             assert actual == NationalClouds.Global.value
 
+    def 
test_get_host_when_connection_has_no_scheme_or_host_but_hook_overrides_host(self):
+        with patch_hook():
+            hook = KiotaRequestAdapterHook(
+                conn_id="msgraph_api", 
host="wabi-north-europe-o-primary-redirect.analysis.windows.net"
+            )
+            connection = mock_connection(schema="https", 
host=NationalClouds.Global.value)
+            actual = hook.get_host(connection)
+
+            assert actual == 
"https://wabi-north-europe-o-primary-redirect.analysis.windows.net";
+
+    def test_execute_callable(self):
+        response = load_json_from_resources(dirname(__file__), "..", 
"resources", "users.json")
+
+        url, query_parameters = execute_callable(
+            KiotaRequestAdapterHook.default_pagination,
+            response=response,
+        )
+
+        assert url == response["@odata.nextLink"]
+        assert not query_parameters
+
+    def test_execute_callable_with_additional_parameters(self):
+        response = load_json_from_resources(dirname(__file__), "..", 
"resources", "users.json")
+
+        url, query_parameters = execute_callable(
+            KiotaRequestAdapterHook.default_pagination,
+            response=response,
+            url="users",
+            query_parameters={},
+            data=None,
+        )
+
+        assert url == response["@odata.nextLink"]
+        assert query_parameters == {}
+
+    def test_execute_callable_when_required_parameter_is_missing(self):
+        with pytest.raises(TypeError):
+            execute_callable(KiotaRequestAdapterHook.default_pagination)
+
     @pytest.mark.asyncio
     async def test_tenant_id(self):
         with patch_hook():
@@ -253,6 +294,36 @@ class TestKiotaRequestAdapterHook:
 
             self.assert_tenant_id(actual, "azure-tenant-id")
 
+    @pytest.mark.asyncio
+    async def test_proxies(self):
+        with patch_hook(
+            side_effect=lambda conn_id: get_airflow_connection(
+                conn_id=conn_id,
+                proxies={"http": "http://proxy:80";, "https": 
"https://proxy:80"},
+            )
+        ):
+            hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+
+            with pytest.warns(AirflowProviderDeprecationWarning):
+                actual = hook.get_conn()
+
+            assert actual._http_client._mounts
+
+    @pytest.mark.asyncio
+    async def test_proxies_override_with_empty_dict(self):
+        with patch_hook(
+            side_effect=lambda conn_id: get_airflow_connection(
+                conn_id=conn_id,
+                proxies={"http": "http://proxy:80";, "https": 
"https://proxy:80"},
+            )
+        ):
+            hook = KiotaRequestAdapterHook(conn_id="msgraph_api", proxies={})
+
+            with pytest.warns(AirflowProviderDeprecationWarning):
+                actual = hook.get_conn()
+
+            assert not actual._http_client._mounts
+
     def test_encoded_query_parameters(self):
         actual = KiotaRequestAdapterHook.encoded_query_parameters(
             query_parameters={"$expand": 
"reports,users,datasets,dataflows,dashboards", "$top": 5000},
@@ -313,6 +384,34 @@ class TestKiotaRequestAdapterHook:
             error_code = 
actual.get_child_node("error").get_child_node("code").get_str_value()
             assert error_code == "TenantThrottleThresholdExceeded"
 
+    @pytest.mark.asyncio
+    async def test_run(self):
+        users = load_json_from_resources(dirname(__file__), "..", "resources", 
"users.json")
+        next_users = load_json_from_resources(dirname(__file__), "..", 
"resources", "next_users.json")
+        response = mock_json_response(200, users, next_users)
+
+        with patch_hook_and_request_adapter(response):
+            hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+
+            actual = await hook.run(url="users")
+
+            assert isinstance(actual, dict)
+            assert actual == users
+
+    @pytest.mark.asyncio
+    async def test_paginated_run(self):
+        users = load_json_from_resources(dirname(__file__), "..", "resources", 
"users.json")
+        next_users = load_json_from_resources(dirname(__file__), "..", 
"resources", "next_users.json")
+        response = mock_json_response(200, users, next_users)
+
+        with patch_hook_and_request_adapter(response):
+            hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+
+            actual = await hook.paginated_run(url="users")
+
+            assert isinstance(actual, list)
+            assert actual == [users, next_users]
+
 
 class TestResponseHandler:
     def test_default_response_handler_when_json(self):

Reply via email to