This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 4ea1ebf56c8 Added paginated_run method to KiotaRequestAdapterHook in
MSGraph (#57536)
4ea1ebf56c8 is described below
commit 4ea1ebf56c8417bebe989891fd57b48a815dcf02
Author: David Blain <[email protected]>
AuthorDate: Mon Dec 8 20:35:18 2025 +0000
Added paginated_run method to KiotaRequestAdapterHook in MSGraph (#57536)
* refactor: Added paginated_run method to KiotaRequestAdapterHook
* refactor: Make execution of pagination function dynamic in
KiotaRequestAdapterHook
* refactor: Refactored PaginationCallable as protocol in
KiotaRequestAdapterHook
* Revert "refactor: Refactored PaginationCallable as protocol in
KiotaRequestAdapterHook"
This reverts commit 1885132835a1523d9a4cf800c637c40cc928936c.
* refactor: Ignore type with execute_callable
* refactor: Fixed default_pagination in MSGraphAsyncOperator
* fix: Make the url parameter of default_pagination optional
* fix: Fixed unit test related to pagination with operator
* Revert "fix: Fixed unit test related to pagination with operator"
This reverts commit 2ff2fa181486d44bde67ad80de47cc6877cb1030.
* refactor: Replace default_pagination back to static paginate method in
MSGraphAsyncOperator
* refactor: Allow overriding proxy settings in MSGraph hook with empty dict
so that proxy settings from connection are ignored
* refactor: Reformatted tests
* refactor: Make sure host starts with http or https
* refactor: Fixed typing
* refactor: Fixed proxy tests
* refactor: Expect the deprecation warnings raised by the get_conn method in tests
* refactor: Reformatted proxies tests
---
.../providers/microsoft/azure/hooks/msgraph.py | 159 +++++++++++++++++----
.../providers/microsoft/azure/operators/msgraph.py | 21 +--
.../unit/microsoft/azure/hooks/test_msgraph.py | 99 +++++++++++++
3 files changed, 241 insertions(+), 38 deletions(-)
diff --git
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py
index 0fdf426e477..84698c3d6a4 100644
---
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py
+++
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py
@@ -18,9 +18,11 @@
from __future__ import annotations
import asyncio
+import inspect
import json
import warnings
from ast import literal_eval
+from collections.abc import Callable
from contextlib import suppress
from http import HTTPStatus
from io import BytesIO
@@ -58,6 +60,32 @@ if TYPE_CHECKING:
from airflow.providers.common.compat.sdk import Connection
+PaginationCallable = Callable[..., tuple[str, dict[str, Any] | None]]
+
+
def execute_callable(func: Callable, *args: Any, **kwargs: Any) -> Any:
    """
    Dynamically call a function by matching its signature to the provided args/kwargs.

    Keyword arguments that ``func`` cannot accept by keyword are dropped, unless ``func``
    accepts ``**kwargs``, in which case all of them are forwarded. This lets callers offer a
    superset of context to user-supplied callbacks that only declare what they need.

    :param func: The callable to invoke.
    :param args: Positional arguments, forwarded as-is.
    :param kwargs: Candidate keyword arguments, filtered against ``func``'s signature.
    :return: Whatever ``func`` returns.
    :raises TypeError: If the (filtered) arguments cannot be bound to ``func``'s signature,
        e.g. because a required parameter is missing.
    """
    sig = inspect.signature(func)
    params = sig.parameters.values()
    accepts_kwargs = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params)

    if accepts_kwargs:
        filtered_kwargs = kwargs
    else:
        # Only forward names that can actually be bound by keyword: ``*args``-style and
        # positional-only parameters would make ``bind`` fail on a same-named keyword.
        keyword_names = {
            p.name
            for p in params
            if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
        }
        filtered_kwargs = {k: v for k, v in kwargs.items() if k in keyword_names}

    try:
        sig.bind(*args, **filtered_kwargs)
    except TypeError as err:
        # functools.partial objects and callable instances have no __name__.
        func_name = getattr(func, "__name__", repr(func))
        raise TypeError(
            f"Failed to bind arguments to function {func_name}: {err}\n"
            f"Expected parameters: {list(sig.parameters.keys())}\n"
            f"Provided kwargs: {list(kwargs.keys())}"
        ) from err

    return func(*args, **filtered_kwargs)
+
+
class DefaultResponseHandler(ResponseHandler):
"""DefaultResponseHandler returns JSON payload or content in bytes or
response headers."""
@@ -122,7 +150,7 @@ class KiotaRequestAdapterHook(BaseHook):
conn_id: str = default_conn_name,
timeout: float | None = None,
proxies: dict | None = None,
- host: str = NationalClouds.Global.value,
+ host: str | None = None,
scopes: str | list[str] | None = None,
api_version: APIVersion | str | None = None,
):
@@ -198,8 +226,12 @@ class KiotaRequestAdapterHook(BaseHook):
) # type: ignore
def get_host(self, connection: Connection) -> str:
    """
    Return the base host (``scheme://host``) used to reach Microsoft Graph.

    Precedence:

    1. The ``host`` passed to the hook. If it has no ``http://``/``https://`` prefix, the
       connection's schema is prepended (``https`` when the connection has no schema).
    2. The connection's ``schema`` and ``host`` when both are set.
    3. ``NationalClouds.Global.value`` otherwise.

    :param connection: The Airflow connection configured for this hook.
    :return: The host including its scheme.
    """
    if not self.host:
        if connection.schema and connection.host:
            return f"{connection.schema}://{connection.host}"
        return NationalClouds.Global.value
    # A tuple prefix matches either scheme; hosts that already carry one are returned as-is.
    if not self.host.startswith(("http://", "https://")):
        return f"{connection.schema or 'https'}://{self.host}"
    return self.host
def get_base_url(self, host: str, api_version: str, config: dict) -> str:
@@ -216,7 +248,7 @@ class KiotaRequestAdapterHook(BaseHook):
return url
@classmethod
- def to_httpx_proxies(cls, proxies: dict) -> dict:
+ def to_httpx_proxies(cls, proxies: dict | None) -> dict | None:
if proxies:
proxies = proxies.copy()
if proxies.get("http"):
@@ -226,9 +258,10 @@ class KiotaRequestAdapterHook(BaseHook):
if proxies.get("no"):
for url in proxies.pop("no", "").split(","):
proxies[cls.format_no_proxy_url(url.strip())] = None
- return proxies
+ return proxies
+ return None
- def to_msal_proxies(self, authority: str | None, proxies: dict) -> dict |
None:
+ def to_msal_proxies(self, authority: str | None, proxies: dict | None) ->
dict | None:
self.log.debug("authority: %s", authority)
if authority and proxies:
no_proxies = proxies.get("no")
@@ -240,7 +273,8 @@ class KiotaRequestAdapterHook(BaseHook):
self.log.debug("domain_name: %s", domain_name)
if authority.endswith(domain_name):
return None
- return proxies
+ return proxies
+ return None
def _build_request_adapter(self, connection) -> tuple[str, RequestAdapter]:
client_id = connection.login
@@ -361,22 +395,24 @@ class KiotaRequestAdapterHook(BaseHook):
self.api_version = api_version
return request_adapter
def get_proxies(self, config: dict) -> dict | None:
    """
    Resolve the proxy settings to use for this hook.

    Proxies passed to the hook take precedence even when empty, so passing ``{}`` ignores
    the proxies found in ``config``. Otherwise ``config["proxies"]`` is used. String values
    are deserialized, first as JSON and then as a Python literal.

    :param config: Mapping to read ``proxies`` from (presumably the connection extras).
    :return: The proxies dict, or ``None`` when no proxies are configured.
    :raises AirflowConfigException: If the resolved proxies are not a dict.
    """
    if self.proxies is None:
        candidate = config.get("proxies", {})
    else:
        candidate = self.proxies

    if not candidate:
        return None

    if isinstance(candidate, str):
        # TODO: Once provider depends on Airflow 2.10 or higher code below won't be needed anymore as
        #       we could then use the get_extra_dejson method on the connection which deserializes
        #       nested json. Make sure to use connection.get_extra_dejson(nested=True) instead of
        #       connection.extra_dejson.
        with suppress(JSONDecodeError):
            candidate = json.loads(candidate)
        if isinstance(candidate, str):
            with suppress(Exception):
                candidate = literal_eval(candidate)

    if isinstance(candidate, dict):
        return candidate
    raise AirflowConfigException(
        f"Proxies must be of type dict, got {type(candidate).__name__} instead!"
    )
def get_credentials(
self,
@@ -385,7 +421,7 @@ class KiotaRequestAdapterHook(BaseHook):
config,
authority: str | None,
verify: bool,
- proxies: dict,
+ proxies: dict | None,
) -> ClientCredentialBase:
tenant_id = config.get("tenant_id") or config.get("tenantId")
certificate_path = config.get("certificate_path")
@@ -428,6 +464,27 @@ class KiotaRequestAdapterHook(BaseHook):
except Exception as e:
return False, str(e)
@staticmethod
def default_pagination(
    response: dict,
    url: str | None = None,
    query_parameters: dict[str, Any] | None = None,
    responses: Callable[[], list[dict[str, Any]] | None] = lambda: [],
) -> tuple[Any, dict[str, Any] | None]:
    """
    Determine the next page to request after a Microsoft Graph response.

    * ``$top``/``$skip`` paging: if the response has ``@odata.count``, the query parameters
      contain ``$top`` and the current page holds exactly ``$top`` items, the same ``url`` is
      requested again. Its ``$skip`` is set to the number of items in the previous pages
      (``responses()``) plus ``$top``.
    * Otherwise the response's ``@odata.nextLink`` is returned. It is ``None`` when absent,
      which ends pagination.

    :param response: The response of the request that was just executed.
    :param url: The url of the request that was just executed.
    :param query_parameters: The query parameters of that request. Never mutated; a new dict
        is returned when ``$skip`` has to be set.
    :param responses: Zero-argument callable returning the responses of the pages fetched
        *before* the current one (may return ``None`` when unknown).
    :return: ``(next_url, next_query_parameters)``; ``next_url`` is ``None`` when done.
    """
    if not isinstance(response, dict):
        return None, query_parameters

    odata_count = response.get("@odata.count")
    top = query_parameters.get("$top") if query_parameters else None

    if odata_count and top and len(response.get("value", [])) == top:
        previous = responses()
        skip = sum(len(page.get("value", [])) for page in previous) + top if previous else top
        # Return a copy so the caller's (or the operator's) query parameters are left untouched.
        return url, {**query_parameters, "$skip": skip}

    return response.get("@odata.nextLink"), query_parameters
+
async def run(
self,
url: str = "",
@@ -457,6 +514,60 @@ class KiotaRequestAdapterHook(BaseHook):
return response
async def paginated_run(
    self,
    url: str = "",
    response_type: str | None = None,
    path_parameters: dict[str, Any] | None = None,
    method: str = "GET",
    query_parameters: dict[str, Any] | None = None,
    headers: dict[str, str] | None = None,
    data: dict[str, Any] | str | BytesIO | None = None,
    pagination_function: PaginationCallable | None = None,
) -> list[dict]:
    """
    Execute a request and keep requesting pages until the pagination function signals the end.

    After every request, ``pagination_function`` (``self.default_pagination`` when omitted) is
    invoked through ``execute_callable``. It therefore only receives the keyword arguments it
    declares out of ``response``, ``url``, ``response_type``, ``path_parameters``, ``method``,
    ``query_parameters``, ``headers``, ``data`` and ``responses``. It must return
    ``(next_url, next_query_parameters)``; a falsy ``next_url`` stops the loop.

    ``responses`` is a zero-argument callable returning the pages fetched *before* the current
    one. ``default_pagination`` computes ``$skip`` as "items in responses() + $top", so
    including the current page there would skip a whole page.

    :param url: The url of the first request.
    :param response_type: Forwarded to :meth:`run`.
    :param path_parameters: Forwarded to :meth:`run`.
    :param method: HTTP method, forwarded to :meth:`run`.
    :param query_parameters: Query parameters of the first request; subsequent ones come from
        the pagination function.
    :param headers: Forwarded to :meth:`run`.
    :param data: Forwarded to :meth:`run`.
    :param pagination_function: Callable determining the next url and query parameters.
    :return: All truthy responses, in the order they were fetched.
    """
    if pagination_function is None:
        pagination_function = self.default_pagination

    responses: list[dict] = []

    while url:
        response = await self.run(
            url=url,
            response_type=response_type,
            path_parameters=path_parameters,
            method=method,
            query_parameters=query_parameters,
            headers=headers,
            data=data,
        )

        # Snapshot of the earlier pages, taken before recording the current one (see docstring).
        previous_responses = list(responses)
        if response:
            responses.append(response)

        url, query_parameters = execute_callable(
            pagination_function,
            response=response,
            url=url,
            response_type=response_type,
            path_parameters=path_parameters,
            method=method,
            query_parameters=query_parameters,
            headers=headers,
            data=data,
            # Bind the snapshot eagerly so a stored callable cannot observe later pages.
            responses=lambda previous=previous_responses: previous,
        )

    return responses
+
async def send_request(self, request_info: RequestInformation,
response_type: str | None = None):
conn = await self.get_async_conn()
diff --git
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py
index ea7197e8aa8..2449ac1a1fd 100644
---
a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py
+++
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py
@@ -20,7 +20,6 @@ from __future__ import annotations
import warnings
from collections.abc import Callable, Sequence
from contextlib import suppress
-from copy import deepcopy
from typing import (
TYPE_CHECKING,
Any,
@@ -247,7 +246,7 @@ class MSGraphAsyncOperator(BaseOperator):
@classmethod
def append_result(
cls,
- results: Any,
+ results: list[Any],
result: Any,
append_result_as_list_if_absent: bool = False,
) -> list[Any]:
@@ -312,18 +311,12 @@ class MSGraphAsyncOperator(BaseOperator):
def paginate(
    operator: MSGraphAsyncOperator, response: dict, **context
) -> tuple[Any, dict[str, Any] | None]:
    """
    Default pagination function of ``MSGraphAsyncOperator``.

    Delegates to ``KiotaRequestAdapterHook.default_pagination``. It passes the operator's
    ``url``, a copy of its ``query_parameters``, and the results returned by
    ``operator.pull_xcom(context)`` as the previously fetched pages.

    :param operator: The operator instance being paginated.
    :param response: The response of the request that was just executed.
    :param context: The Airflow task context, used to pull previous results.
    :return: ``(next_url, next_query_parameters)``; ``next_url`` is ``None`` when done.
    """
    # Hand over a copy so pagination can never alter the operator's own query parameters
    # (e.g. by adding ``$skip``).
    query_parameters = (
        dict(operator.query_parameters) if operator.query_parameters is not None else None
    )
    return KiotaRequestAdapterHook.default_pagination(
        response=response,
        url=operator.url,
        query_parameters=query_parameters,
        # Previous results can only be pulled when a task context is available.
        responses=lambda: operator.pull_xcom(context) if context else None,
    )
def trigger_next_link(self, response, method_name: str, context: Context)
-> None:
if isinstance(response, dict):
diff --git
a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py
b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py
index 1162f506745..f4d8bba7d8f 100644
--- a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py
+++ b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py
@@ -38,6 +38,7 @@ from airflow.providers.common.compat.sdk import
AirflowException, AirflowNotFoun
from airflow.providers.microsoft.azure.hooks.msgraph import (
DefaultResponseHandler,
KiotaRequestAdapterHook,
+ execute_callable,
)
from tests_common.test_utils.file_loading import load_file_from_resources,
load_json_from_resources
@@ -48,6 +49,7 @@ from unit.microsoft.azure.test_utils import (
mock_json_response,
mock_response,
patch_hook,
+ patch_hook_and_request_adapter,
)
if TYPE_CHECKING:
@@ -232,6 +234,45 @@ class TestKiotaRequestAdapterHook:
assert actual == NationalClouds.Global.value
def test_get_host_when_connection_has_no_scheme_or_host_but_hook_overrides_host(self):
    # NOTE(review): despite the test name, the mocked connection does define a schema and a
    # host; what is verified is that the hook-level ``host`` wins and borrows the schema.
    override = "wabi-north-europe-o-primary-redirect.analysis.windows.net"
    with patch_hook():
        hook = KiotaRequestAdapterHook(conn_id="msgraph_api", host=override)
        connection = mock_connection(schema="https", host=NationalClouds.Global.value)
        actual = hook.get_host(connection)

    assert actual == "https://wabi-north-europe-o-primary-redirect.analysis.windows.net"
+
def test_execute_callable(self):
    users = load_json_from_resources(dirname(__file__), "..", "resources", "users.json")

    # Only ``response`` is provided; the remaining parameters fall back to their defaults.
    next_url, next_query_parameters = execute_callable(
        KiotaRequestAdapterHook.default_pagination, response=users
    )

    assert next_url == users["@odata.nextLink"]
    assert not next_query_parameters
+
def test_execute_callable_with_additional_parameters(self):
    users = load_json_from_resources(dirname(__file__), "..", "resources", "users.json")
    # ``data`` is not a parameter of default_pagination and has to be filtered out.
    extra_kwargs = {"url": "users", "query_parameters": {}, "data": None}

    next_url, next_query_parameters = execute_callable(
        KiotaRequestAdapterHook.default_pagination, response=users, **extra_kwargs
    )

    assert next_url == users["@odata.nextLink"]
    assert next_query_parameters == {}
+
def test_execute_callable_when_required_parameter_is_missing(self):
    # ``response`` is required by default_pagination. Match the message so only the descriptive
    # TypeError raised by execute_callable satisfies the test, not any unrelated TypeError.
    with pytest.raises(TypeError, match="Failed to bind arguments to function default_pagination"):
        execute_callable(KiotaRequestAdapterHook.default_pagination)
+
@pytest.mark.asyncio
async def test_tenant_id(self):
with patch_hook():
@@ -253,6 +294,36 @@ class TestKiotaRequestAdapterHook:
self.assert_tenant_id(actual, "azure-tenant-id")
@pytest.mark.asyncio
async def test_proxies(self):
    def connection_with_proxies(conn_id):
        return get_airflow_connection(
            conn_id=conn_id,
            proxies={"http": "http://proxy:80", "https": "https://proxy:80"},
        )

    with patch_hook(side_effect=connection_with_proxies):
        hook = KiotaRequestAdapterHook(conn_id="msgraph_api")

        with pytest.warns(AirflowProviderDeprecationWarning):
            request_adapter = hook.get_conn()

        # NOTE(review): relies on the private ``_mounts`` attribute of the underlying http client.
        assert request_adapter._http_client._mounts
+
@pytest.mark.asyncio
async def test_proxies_override_with_empty_dict(self):
    def connection_with_proxies(conn_id):
        return get_airflow_connection(
            conn_id=conn_id,
            proxies={"http": "http://proxy:80", "https": "https://proxy:80"},
        )

    with patch_hook(side_effect=connection_with_proxies):
        # An explicit empty dict must override (i.e. disable) the connection's proxies.
        hook = KiotaRequestAdapterHook(conn_id="msgraph_api", proxies={})

        with pytest.warns(AirflowProviderDeprecationWarning):
            request_adapter = hook.get_conn()

        # NOTE(review): relies on the private ``_mounts`` attribute of the underlying http client.
        assert not request_adapter._http_client._mounts
+
def test_encoded_query_parameters(self):
actual = KiotaRequestAdapterHook.encoded_query_parameters(
query_parameters={"$expand":
"reports,users,datasets,dataflows,dashboards", "$top": 5000},
@@ -313,6 +384,34 @@ class TestKiotaRequestAdapterHook:
error_code =
actual.get_child_node("error").get_child_node("code").get_str_value()
assert error_code == "TenantThrottleThresholdExceeded"
@pytest.mark.asyncio
async def test_run(self):
    resources = (dirname(__file__), "..", "resources")
    users = load_json_from_resources(*resources, "users.json")
    next_users = load_json_from_resources(*resources, "next_users.json")

    with patch_hook_and_request_adapter(mock_json_response(200, users, next_users)):
        hook = KiotaRequestAdapterHook(conn_id="msgraph_api")

        # A plain run only fetches the first page.
        actual = await hook.run(url="users")

        assert isinstance(actual, dict)
        assert actual == users
+
@pytest.mark.asyncio
async def test_paginated_run(self):
    resources = (dirname(__file__), "..", "resources")
    users = load_json_from_resources(*resources, "users.json")
    next_users = load_json_from_resources(*resources, "next_users.json")

    with patch_hook_and_request_adapter(mock_json_response(200, users, next_users)):
        hook = KiotaRequestAdapterHook(conn_id="msgraph_api")

        # Pagination follows the next link of the first page and collects both pages.
        actual = await hook.paginated_run(url="users")

        assert isinstance(actual, list)
        assert actual == [users, next_users]
+
class TestResponseHandler:
def test_default_response_handler_when_json(self):