This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2b101e2377 Feature: Added event_handler parameter in
MSGraphAsyncOperator (#42539)
2b101e2377 is described below
commit 2b101e2377f8d49a46aca6c219e4b38ee099a98d
Author: David Blain <[email protected]>
AuthorDate: Tue Oct 15 02:41:36 2024 +0200
Feature: Added event_handler parameter in MSGraphAsyncOperator (#42539)
* refactor: Added event_handler parameter to MSGraphAsyncOperator to allow
overriding the default event handler
* docs: Added docstring for event_handler parameter in MSGraphAsyncOperator
* refactor: Fixed TestMSGraphAsyncOperator
* refactor: Check if event is not None
* refactor: Register TextParseNodeFactory and JsonParseNodeFactory so that
error responses are parsed correctly by the RequestAdapter
* refactor: Reorganized import for TestMSGraphAsyncOperator
* refactor: Added missing kiota-serialization packages to the azure provider dependencies
* refactor: Updated provider dependencies
* refactor: Reorganized import of TestKiotaRequestAdapterHook
* refactor: Downgraded the kiota JSON serialization package version
* refactor: Updated provider dependencies
* refactor: Put import of Context in TYPE_CHECKING block
* refactor: Fixed lookup of tenant-id
* refactor: Pinned kiota serialization dependencies to 1.0.0 to avoid
pendulum dependency conflicts, for backward compatibility
* refactor: Updated provider dependencies
* refactor: Fixed import of test_utils in test_dag_run
---------
Co-authored-by: David Blain <[email protected]>
---
generated/provider_dependencies.json | 2 +
.../providers/microsoft/azure/hooks/msgraph.py | 7 ++++
.../providers/microsoft/azure/operators/msgraph.py | 18 +++++++--
.../providers/microsoft/azure/provider.yaml | 2 +
.../tests/microsoft/azure/hooks/test_msgraph.py | 45 +++++++++++++++++++++-
.../microsoft/azure/operators/test_msgraph.py | 30 +++++++++++++++
6 files changed, 99 insertions(+), 5 deletions(-)
diff --git a/generated/provider_dependencies.json
b/generated/provider_dependencies.json
index 4d921fc1fb..8efdd5eae7 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -828,6 +828,8 @@
"azure-synapse-artifacts>=0.17.0",
"azure-synapse-spark>=0.2.0",
"microsoft-kiota-http>=1.3.0,!=1.3.4",
+ "microsoft-kiota-serialization-json==1.0.0",
+ "microsoft-kiota-serialization-text==1.0.0",
"msgraph-core>=1.0.0"
],
"devel-deps": [
diff --git a/providers/src/airflow/providers/microsoft/azure/hooks/msgraph.py
b/providers/src/airflow/providers/microsoft/azure/hooks/msgraph.py
index 61e555f4ca..4ab3aaf3ba 100644
--- a/providers/src/airflow/providers/microsoft/azure/hooks/msgraph.py
+++ b/providers/src/airflow/providers/microsoft/azure/hooks/msgraph.py
@@ -32,11 +32,14 @@ from kiota_abstractions.api_error import APIError
from kiota_abstractions.method import Method
from kiota_abstractions.request_information import RequestInformation
from kiota_abstractions.response_handler import ResponseHandler
+from kiota_abstractions.serialization import ParseNodeFactoryRegistry
from kiota_authentication_azure.azure_identity_authentication_provider import (
AzureIdentityAuthenticationProvider,
)
from kiota_http.httpx_request_adapter import HttpxRequestAdapter
from kiota_http.middleware.options import ResponseHandlerOption
+from kiota_serialization_json.json_parse_node_factory import
JsonParseNodeFactory
+from kiota_serialization_text.text_parse_node_factory import
TextParseNodeFactory
from msgraph_core import APIVersion, GraphClientFactory
from msgraph_core._enums import NationalClouds
@@ -249,8 +252,12 @@ class KiotaRequestAdapterHook(BaseHook):
scopes=scopes,
allowed_hosts=allowed_hosts,
)
+ parse_node_factory = ParseNodeFactoryRegistry()
+ parse_node_factory.CONTENT_TYPE_ASSOCIATED_FACTORIES["text/plain"]
= TextParseNodeFactory()
+
parse_node_factory.CONTENT_TYPE_ASSOCIATED_FACTORIES["application/json"] =
JsonParseNodeFactory()
request_adapter = HttpxRequestAdapter(
authentication_provider=auth_provider,
+ parse_node_factory=parse_node_factory,
http_client=http_client,
base_url=base_url,
)
diff --git
a/providers/src/airflow/providers/microsoft/azure/operators/msgraph.py
b/providers/src/airflow/providers/microsoft/azure/operators/msgraph.py
index b3d14b14a5..0d187ebd51 100644
--- a/providers/src/airflow/providers/microsoft/azure/operators/msgraph.py
+++ b/providers/src/airflow/providers/microsoft/azure/operators/msgraph.py
@@ -44,6 +44,14 @@ if TYPE_CHECKING:
from airflow.utils.context import Context
+def default_event_handler(context: Context, event: dict[Any, Any] | None =
None) -> Any:
+ if event:
+ if event.get("status") == "failure":
+ raise AirflowException(event.get("message"))
+
+ return event.get("response")
+
+
class MSGraphAsyncOperator(BaseOperator):
"""
A Microsoft Graph API operator which allows you to execute REST call to
the Microsoft Graph API.
@@ -69,6 +77,9 @@ class MSGraphAsyncOperator(BaseOperator):
:param result_processor: Function to further process the response from MS
Graph API
(default is lambda: context, response: response). When the response
returned by the
`KiotaRequestAdapterHook` are bytes, then those will be base64 encoded
into a string.
+ :param event_handler: Function to process the event returned from
`MSGraphTrigger`. By default, when the
+ event returned by the `MSGraphTrigger` has a failed status, an
AirflowException is being raised with
+ the message from the event, otherwise the response from the event
payload is returned.
:param serializer: Class which handles response serialization (default is
ResponseSerializer).
Bytes will be base64 encoded into a string, so it can be stored as an
XCom.
"""
@@ -102,6 +113,7 @@ class MSGraphAsyncOperator(BaseOperator):
api_version: APIVersion | str | None = None,
pagination_function: Callable[[MSGraphAsyncOperator, dict, Context],
tuple[str, dict]] | None = None,
result_processor: Callable[[Context, Any], Any] = lambda context,
result: result,
+ event_handler: Callable[[Context, dict[Any, Any] | None], Any] | None
= None,
serializer: type[ResponseSerializer] = ResponseSerializer,
**kwargs: Any,
):
@@ -121,6 +133,7 @@ class MSGraphAsyncOperator(BaseOperator):
self.api_version = api_version
self.pagination_function = pagination_function or self.paginate
self.result_processor = result_processor
+ self.event_handler = event_handler or default_event_handler
self.serializer: ResponseSerializer = serializer()
def execute(self, context: Context) -> None:
@@ -158,10 +171,7 @@ class MSGraphAsyncOperator(BaseOperator):
if event:
self.log.debug("%s completed with %s: %s", self.task_id,
event.get("status"), event)
- if event.get("status") == "failure":
- raise AirflowException(event.get("message"))
-
- response = event.get("response")
+ response = self.event_handler(context, event)
self.log.debug("response: %s", response)
diff --git a/providers/src/airflow/providers/microsoft/azure/provider.yaml
b/providers/src/airflow/providers/microsoft/azure/provider.yaml
index cf0b3f75ef..c4831a641b 100644
--- a/providers/src/airflow/providers/microsoft/azure/provider.yaml
+++ b/providers/src/airflow/providers/microsoft/azure/provider.yaml
@@ -111,6 +111,8 @@ dependencies:
# msgraph-core has transient import failures with microsoft-kiota-http==1.3.4
# See https://github.com/microsoftgraph/msgraph-sdk-python-core/issues/706
- microsoft-kiota-http>=1.3.0,!=1.3.4
+ - microsoft-kiota-serialization-json==1.0.0
+ - microsoft-kiota-serialization-text==1.0.0
devel-dependencies:
- pywinrm
diff --git a/providers/tests/microsoft/azure/hooks/test_msgraph.py
b/providers/tests/microsoft/azure/hooks/test_msgraph.py
index 0ecad98548..aff5d0226a 100644
--- a/providers/tests/microsoft/azure/hooks/test_msgraph.py
+++ b/providers/tests/microsoft/azure/hooks/test_msgraph.py
@@ -19,11 +19,15 @@ from __future__ import annotations
import asyncio
from json import JSONDecodeError
from typing import TYPE_CHECKING
-from unittest.mock import patch
+from unittest.mock import Mock, patch
import pytest
+from httpx import Response
from kiota_http.httpx_request_adapter import HttpxRequestAdapter
+from kiota_serialization_json.json_parse_node import JsonParseNode
+from kiota_serialization_text.text_parse_node import TextParseNode
from msgraph_core import APIVersion, NationalClouds
+from opentelemetry.trace import Span
from airflow.exceptions import AirflowBadRequest, AirflowException,
AirflowNotFoundException
from airflow.providers.microsoft.azure.hooks.msgraph import (
@@ -175,6 +179,45 @@ class TestKiotaRequestAdapterHook:
assert actual == {"%24expand":
"reports,users,datasets,dataflows,dashboards", "%24top": 5000}
+ @pytest.mark.asyncio
+ async def test_throw_failed_responses_with_text_plain_content_type(self):
+ with patch(
+ "airflow.hooks.base.BaseHook.get_connection",
+ side_effect=get_airflow_connection,
+ ):
+ hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+ response = Mock(spec=Response)
+ response.headers = {"content-type": "text/plain"}
+ response.status_code = 429
+ response.content = b"TenantThrottleThresholdExceeded"
+ response.is_success = False
+ span = Mock(spec=Span)
+
+ actual = await hook.get_conn().get_root_parse_node(response, span,
span)
+
+ assert isinstance(actual, TextParseNode)
+ assert actual.get_str_value() == "TenantThrottleThresholdExceeded"
+
+ @pytest.mark.asyncio
+ async def
test_throw_failed_responses_with_application_json_content_type(self):
+ with patch(
+ "airflow.hooks.base.BaseHook.get_connection",
+ side_effect=get_airflow_connection,
+ ):
+ hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+ response = Mock(spec=Response)
+ response.headers = {"content-type": "application/json"}
+ response.status_code = 429
+ response.content = b'{"error": {"code":
"TenantThrottleThresholdExceeded"}}'
+ response.is_success = False
+ span = Mock(spec=Span)
+
+ actual = await hook.get_conn().get_root_parse_node(response, span,
span)
+
+ assert isinstance(actual, JsonParseNode)
+ error_code =
actual.get_child_node("error").get_child_node("code").get_str_value()
+ assert error_code == "TenantThrottleThresholdExceeded"
+
class TestResponseHandler:
def test_default_response_handler_when_json(self):
diff --git a/providers/tests/microsoft/azure/operators/test_msgraph.py
b/providers/tests/microsoft/azure/operators/test_msgraph.py
index 372152fe97..fe404e48e6 100644
--- a/providers/tests/microsoft/azure/operators/test_msgraph.py
+++ b/providers/tests/microsoft/azure/operators/test_msgraph.py
@@ -19,6 +19,7 @@ from __future__ import annotations
import json
import locale
from base64 import b64encode
+from typing import TYPE_CHECKING, Any
import pytest
@@ -35,6 +36,9 @@ from providers.tests.microsoft.conftest import (
mock_response,
)
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
+
class TestMSGraphAsyncOperator(Base):
@pytest.mark.db_test
@@ -101,6 +105,32 @@ class TestMSGraphAsyncOperator(Base):
with pytest.raises(AirflowException):
self.execute_operator(operator)
+ @pytest.mark.db_test
+ def test_execute_when_an_exception_occurs_on_custom_event_handler(self):
+ with self.patch_hook_and_request_adapter(AirflowException("An error
occurred")):
+
+ def custom_event_handler(context: Context, event: dict[Any, Any] |
None = None):
+ if event:
+ if event.get("status") == "failure":
+ return None
+
+ return event.get("response")
+
+ operator = MSGraphAsyncOperator(
+ task_id="users_delta",
+ conn_id="msgraph_api",
+ url="users/delta",
+ event_handler=custom_event_handler,
+ )
+
+ results, events = self.execute_operator(operator)
+
+ assert not results
+ assert len(events) == 1
+ assert isinstance(events[0], TriggerEvent)
+ assert events[0].payload["status"] == "failure"
+ assert events[0].payload["message"] == "An error occurred"
+
@pytest.mark.db_test
def test_execute_when_response_is_bytes(self):
content = load_file("resources", "dummy.pdf", mode="rb", encoding=None)