This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a61f393ec4 Implemented MSGraphSensor as a deferrable sensor (#39304)
a61f393ec4 is described below

commit a61f393ec4361499fcef9f2854668db85b852ec0
Author: David Blain <[email protected]>
AuthorDate: Sun May 5 09:06:36 2024 +0200

    Implemented MSGraphSensor as a deferrable sensor (#39304)
    
    * refactor: Implemented default response handler method and added a test for when a JSON decode error occurs
    
    * refactor: Reformatted some code to comply with static checks
    
    * refactor: Changed logging level to debug for printing the response in the operator
    
    * docs: Added an example of how to refresh a PowerBI dataset using the MSGraphAsyncOperator
    
    * refactor: Changed some info logging statements to debug
    
    * refactor: Changed some info logging statements to debug
    
    * fix: Fixed mock_json_response
    
    * refactor: Return content if response is not a JSON
    
    * refactor: Made sure the operator passes the response_handler to the triggerer
    
    * refactor: Use get() instead of __getitem__ (square brackets), as the payload may not have a response key if the call isn't done
    
    * refactor: If the event has status failure, the sensor should stop the async poke
    
    * refactor: Changed default_event_processor because not all responses have the status key present
    
    * refactor: Changed default_event_processor because not all responses have the status key present
    
    * refactor: Removed the response_handler parameter, as a lambda cannot be serialized by MSGraphTrigger
    
    * refactor: Changed some logging statements
    
    * refactor: Updated PowerBI dataset refresh example
    
    * refactor: Fixed 2 static check errors
    
    * refactor: Refactored MSGraphSensor as a real async sensor
    
    * refactor: Changed logging level of sensor statements back to debug
    
    * refactor: Fixed 2 static checks
    
    * refactor: Changed docstring hook
    
    * refactor: Put docstring on one line
    
    ---------
    
    Co-authored-by: David Blain <[email protected]>
---
 airflow/providers/microsoft/azure/hooks/msgraph.py |  41 ++++---
 .../providers/microsoft/azure/operators/msgraph.py |  15 +--
 .../providers/microsoft/azure/sensors/msgraph.py   | 118 ++++++++++++---------
 .../providers/microsoft/azure/triggers/msgraph.py  |  11 --
 .../operators/msgraph.rst                          |   8 ++
 .../microsoft/azure/hooks/test_msgraph.py          |  58 ++++++----
 .../microsoft/azure/sensors/test_msgraph.py        |  17 ++-
 tests/providers/microsoft/conftest.py              |  12 ++-
 .../providers/microsoft/azure/example_powerbi.py   |  29 +++++
 9 files changed, 180 insertions(+), 129 deletions(-)

diff --git a/airflow/providers/microsoft/azure/hooks/msgraph.py 
b/airflow/providers/microsoft/azure/hooks/msgraph.py
index 84b2252bd2..56abfa155d 100644
--- a/airflow/providers/microsoft/azure/hooks/msgraph.py
+++ b/airflow/providers/microsoft/azure/hooks/msgraph.py
@@ -18,9 +18,11 @@
 from __future__ import annotations
 
 import json
+from contextlib import suppress
 from http import HTTPStatus
 from io import BytesIO
-from typing import TYPE_CHECKING, Any, Callable
+from json import JSONDecodeError
+from typing import TYPE_CHECKING, Any
 from urllib.parse import quote, urljoin, urlparse
 
 import httpx
@@ -51,18 +53,17 @@ if TYPE_CHECKING:
     from airflow.models import Connection
 
 
-class CallableResponseHandler(ResponseHandler):
-    """
-    CallableResponseHandler executes the passed callable_function with 
response as parameter.
-
-    param callable_function: Function that is applied to the response.
-    """
+class DefaultResponseHandler(ResponseHandler):
+    """DefaultResponseHandler returns JSON payload or content in bytes or 
response headers."""
 
-    def __init__(
-        self,
-        callable_function: Callable[[NativeResponseType, dict[str, 
ParsableFactory | None] | None], Any],
-    ):
-        self.callable_function = callable_function
+    @staticmethod
+    def get_value(response: NativeResponseType) -> Any:
+        with suppress(JSONDecodeError):
+            return response.json()
+        content = response.content
+        if not content:
+            return {key: value for key, value in response.headers.items()}
+        return content
 
     async def handle_response_async(
         self, response: NativeResponseType, error_map: dict[str, 
ParsableFactory | None] | None = None
@@ -73,7 +74,7 @@ class CallableResponseHandler(ResponseHandler):
         param response: The type of the native response object.
         param error_map: The error dict to use in case of a failed request.
         """
-        value = self.callable_function(response, error_map)
+        value = self.get_value(response)
         if response.status_code not in {200, 201, 202, 204, 302}:
             message = value or response.reason_phrase
             status_code = HTTPStatus(response.status_code)
@@ -269,20 +270,18 @@ class KiotaRequestAdapterHook(BaseHook):
         self,
         url: str = "",
         response_type: ResponseType | None = None,
-        response_handler: Callable[
-            [NativeResponseType, dict[str, ParsableFactory | None] | None], Any
-        ] = lambda response, error_map: response.json(),
         path_parameters: dict[str, Any] | None = None,
         method: str = "GET",
         query_parameters: dict[str, QueryParams] | None = None,
         headers: dict[str, str] | None = None,
         data: dict[str, Any] | str | BytesIO | None = None,
     ):
+        self.log.info("Executing url '%s' as '%s'", url, method)
+
         response = await self.get_conn().send_primitive_async(
             request_info=self.request_information(
                 url=url,
                 response_type=response_type,
-                response_handler=response_handler,
                 path_parameters=path_parameters,
                 method=method,
                 query_parameters=query_parameters,
@@ -293,7 +292,7 @@ class KiotaRequestAdapterHook(BaseHook):
             error_map=self.error_mapping(),
         )
 
-        self.log.debug("response: %s", response)
+        self.log.info("response: %s", response)
 
         return response
 
@@ -301,9 +300,6 @@ class KiotaRequestAdapterHook(BaseHook):
         self,
         url: str,
         response_type: ResponseType | None = None,
-        response_handler: Callable[
-            [NativeResponseType, dict[str, ParsableFactory | None] | None], Any
-        ] = lambda response, error_map: response.json(),
         path_parameters: dict[str, Any] | None = None,
         method: str = "GET",
         query_parameters: dict[str, QueryParams] | None = None,
@@ -323,12 +319,11 @@ class KiotaRequestAdapterHook(BaseHook):
             request_information.url_template = 
f"{{+baseurl}}/{self.normalize_url(url)}"
         if not response_type:
             
request_information.request_options[ResponseHandlerOption.get_key()] = 
ResponseHandlerOption(
-                response_handler=CallableResponseHandler(response_handler)
+                response_handler=DefaultResponseHandler()
             )
         headers = {**self.DEFAULT_HEADERS, **headers} if headers else 
self.DEFAULT_HEADERS
         for header_name, header_value in headers.items():
             request_information.headers.try_add(header_name=header_name, 
header_value=header_value)
-        self.log.info("data: %s", data)
         if isinstance(data, BytesIO) or isinstance(data, bytes) or 
isinstance(data, str):
             request_information.content = data
         elif data:
diff --git a/airflow/providers/microsoft/azure/operators/msgraph.py 
b/airflow/providers/microsoft/azure/operators/msgraph.py
index 6411f9cc4a..39ca32d2b6 100644
--- a/airflow/providers/microsoft/azure/operators/msgraph.py
+++ b/airflow/providers/microsoft/azure/operators/msgraph.py
@@ -39,8 +39,6 @@ if TYPE_CHECKING:
 
     from kiota_abstractions.request_adapter import ResponseType
     from kiota_abstractions.request_information import QueryParams
-    from kiota_abstractions.response_handler import NativeResponseType
-    from kiota_abstractions.serialization import ParsableFactory
     from msgraph_core import APIVersion
 
     from airflow.utils.context import Context
@@ -59,9 +57,6 @@ class MSGraphAsyncOperator(BaseOperator):
     :param url: The url being executed on the Microsoft Graph API (templated).
     :param response_type: The expected return type of the response as a 
string. Possible value are: `bytes`,
         `str`, `int`, `float`, `bool` and `datetime` (default is None).
-    :param response_handler: Function to convert the native HTTPX response 
returned by the hook (default is
-        lambda response, error_map: response.json()).  The default expression 
will convert the native response
-        to JSON.  If response_type parameter is specified, then the 
response_handler will be ignored.
     :param method: The HTTP method being used to do the REST call (default is 
GET).
     :param conn_id: The HTTP Connection ID to run the operator against 
(templated).
     :param key: The key that will be used to store `XCom's` ("return_value" is 
default).
@@ -94,9 +89,6 @@ class MSGraphAsyncOperator(BaseOperator):
         *,
         url: str,
         response_type: ResponseType | None = None,
-        response_handler: Callable[
-            [NativeResponseType, dict[str, ParsableFactory | None] | None], Any
-        ] = lambda response, error_map: response.json(),
         path_parameters: dict[str, Any] | None = None,
         url_template: str | None = None,
         method: str = "GET",
@@ -116,7 +108,6 @@ class MSGraphAsyncOperator(BaseOperator):
         super().__init__(**kwargs)
         self.url = url
         self.response_type = response_type
-        self.response_handler = response_handler
         self.path_parameters = path_parameters
         self.url_template = url_template
         self.method = method
@@ -134,7 +125,6 @@ class MSGraphAsyncOperator(BaseOperator):
         self.results: list[Any] | None = None
 
     def execute(self, context: Context) -> None:
-        self.log.info("Executing url '%s' as '%s'", self.url, self.method)
         self.defer(
             trigger=MSGraphTrigger(
                 url=self.url,
@@ -167,14 +157,14 @@ class MSGraphAsyncOperator(BaseOperator):
         self.log.debug("context: %s", context)
 
         if event:
-            self.log.info("%s completed with %s: %s", self.task_id, 
event.get("status"), event)
+            self.log.debug("%s completed with %s: %s", self.task_id, 
event.get("status"), event)
 
             if event.get("status") == "failure":
                 raise AirflowException(event.get("message"))
 
             response = event.get("response")
 
-            self.log.info("response: %s", response)
+            self.log.debug("response: %s", response)
 
             if response:
                 response = self.serializer.deserialize(response)
@@ -281,7 +271,6 @@ class MSGraphAsyncOperator(BaseOperator):
                         url=url,
                         query_parameters=query_parameters,
                         response_type=self.response_type,
-                        response_handler=self.response_handler,
                         conn_id=self.conn_id,
                         timeout=self.timeout,
                         proxies=self.proxies,
diff --git a/airflow/providers/microsoft/azure/sensors/msgraph.py 
b/airflow/providers/microsoft/azure/sensors/msgraph.py
index ffbf244dbe..3e1b10cbeb 100644
--- a/airflow/providers/microsoft/azure/sensors/msgraph.py
+++ b/airflow/providers/microsoft/azure/sensors/msgraph.py
@@ -17,33 +17,25 @@
 # under the License.
 from __future__ import annotations
 
-import asyncio
-import json
 from typing import TYPE_CHECKING, Any, Callable, Sequence
 
+from airflow.exceptions import AirflowException
 from airflow.providers.microsoft.azure.hooks.msgraph import 
KiotaRequestAdapterHook
 from airflow.providers.microsoft.azure.triggers.msgraph import MSGraphTrigger, 
ResponseSerializer
-from airflow.sensors.base import BaseSensorOperator, PokeReturnValue
+from airflow.sensors.base import BaseSensorOperator
+from airflow.triggers.temporal import TimeDeltaTrigger
 
 if TYPE_CHECKING:
+    from datetime import timedelta
     from io import BytesIO
 
     from kiota_abstractions.request_information import QueryParams
-    from kiota_abstractions.response_handler import NativeResponseType
-    from kiota_abstractions.serialization import ParsableFactory
     from kiota_http.httpx_request_adapter import ResponseType
     from msgraph_core import APIVersion
 
-    from airflow.triggers.base import TriggerEvent
     from airflow.utils.context import Context
 
 
-def default_event_processor(context: Context, event: TriggerEvent) -> bool:
-    if event.payload["status"] == "success":
-        return json.loads(event.payload["response"])["status"] == "Succeeded"
-    return False
-
-
 class MSGraphSensor(BaseSensorOperator):
     """
     A Microsoft Graph API sensor which allows you to poll an async REST call 
to the Microsoft Graph API.
@@ -51,9 +43,6 @@ class MSGraphSensor(BaseSensorOperator):
     :param url: The url being executed on the Microsoft Graph API (templated).
     :param response_type: The expected return type of the response as a 
string. Possible value are: `bytes`,
         `str`, `int`, `float`, `bool` and `datetime` (default is None).
-    :param response_handler: Function to convert the native HTTPX response 
returned by the hook (default is
-        lambda response, error_map: response.json()).  The default expression 
will convert the native response
-        to JSON.  If response_type parameter is specified, then the 
response_handler will be ignored.
     :param method: The HTTP method being used to do the REST call (default is 
GET).
     :param conn_id: The HTTP Connection ID to run the operator against 
(templated).
     :param proxies: A dict defining the HTTP proxies to be used (default is 
None).
@@ -85,9 +74,6 @@ class MSGraphSensor(BaseSensorOperator):
         self,
         url: str,
         response_type: ResponseType | None = None,
-        response_handler: Callable[
-            [NativeResponseType, dict[str, ParsableFactory | None] | None], Any
-        ] = lambda response, error_map: response.json(),
         path_parameters: dict[str, Any] | None = None,
         url_template: str | None = None,
         method: str = "GET",
@@ -97,15 +83,15 @@ class MSGraphSensor(BaseSensorOperator):
         conn_id: str = KiotaRequestAdapterHook.default_conn_name,
         proxies: dict | None = None,
         api_version: APIVersion | None = None,
-        event_processor: Callable[[Context, TriggerEvent], bool] = 
default_event_processor,
+        event_processor: Callable[[Context, Any], bool] = lambda context, e: 
e.get("status") == "Succeeded",
         result_processor: Callable[[Context, Any], Any] = lambda context, 
result: result,
         serializer: type[ResponseSerializer] = ResponseSerializer,
+        retry_delay: timedelta | float = 60,
         **kwargs,
     ):
-        super().__init__(**kwargs)
+        super().__init__(retry_delay=retry_delay, **kwargs)
         self.url = url
         self.response_type = response_type
-        self.response_handler = response_handler
         self.path_parameters = path_parameters
         self.url_template = url_template
         self.method = method
@@ -119,45 +105,73 @@ class MSGraphSensor(BaseSensorOperator):
         self.result_processor = result_processor
         self.serializer = serializer()
 
-    @property
-    def trigger(self):
-        return MSGraphTrigger(
-            url=self.url,
-            response_type=self.response_type,
-            response_handler=self.response_handler,
-            path_parameters=self.path_parameters,
-            url_template=self.url_template,
-            method=self.method,
-            query_parameters=self.query_parameters,
-            headers=self.headers,
-            data=self.data,
-            conn_id=self.conn_id,
-            timeout=self.timeout,
-            proxies=self.proxies,
-            api_version=self.api_version,
-            serializer=type(self.serializer),
+    def execute(self, context: Context):
+        self.defer(
+            trigger=MSGraphTrigger(
+                url=self.url,
+                response_type=self.response_type,
+                path_parameters=self.path_parameters,
+                url_template=self.url_template,
+                method=self.method,
+                query_parameters=self.query_parameters,
+                headers=self.headers,
+                data=self.data,
+                conn_id=self.conn_id,
+                timeout=self.timeout,
+                proxies=self.proxies,
+                api_version=self.api_version,
+                serializer=type(self.serializer),
+            ),
+            method_name=self.execute_complete.__name__,
         )
 
-    async def async_poke(self, context: Context) -> bool | PokeReturnValue:
-        self.log.info("Sensor triggered")
+    def retry_execute(
+        self,
+        context: Context,
+    ) -> Any:
+        self.execute(context=context)
+
+    def execute_complete(
+        self,
+        context: Context,
+        event: dict[Any, Any] | None = None,
+    ) -> Any:
+        """
+        Execute callback when MSGraphSensor finishes execution.
+
+        This method gets executed automatically when MSGraphTrigger completes 
its execution.
+        """
+        self.log.debug("context: %s", context)
+
+        if event:
+            self.log.debug("%s completed with %s: %s", self.task_id, 
event.get("status"), event)
+
+            if event.get("status") == "failure":
+                raise AirflowException(event.get("message"))
+
+            response = event.get("response")
+
+            self.log.debug("response: %s", response)
 
-        async for event in self.trigger.run():
-            self.log.debug("event: %s", event)
+            if response:
+                response = self.serializer.deserialize(response)
 
-            is_done = self.event_processor(context, event)
+                self.log.debug("deserialize response: %s", response)
 
-            self.log.debug("is_done: %s", is_done)
+                is_done = self.event_processor(context, response)
 
-            response = self.serializer.deserialize(event.payload["response"])
+                self.log.debug("is_done: %s", is_done)
 
-            self.log.debug("deserialize event: %s", response)
+                if is_done:
+                    result = self.result_processor(context, response)
 
-            result = self.result_processor(context, response)
+                    self.log.debug("processed response: %s", result)
 
-            self.log.debug("result: %s", result)
+                    return result
 
-            return PokeReturnValue(is_done=is_done, xcom_value=result)
-        return PokeReturnValue(is_done=True)
+                self.defer(
+                    trigger=TimeDeltaTrigger(self.retry_delay),
+                    method_name=self.retry_execute.__name__,
+                )
 
-    def poke(self, context) -> bool | PokeReturnValue:
-        return asyncio.run(self.async_poke(context))
+        return None
diff --git a/airflow/providers/microsoft/azure/triggers/msgraph.py 
b/airflow/providers/microsoft/azure/triggers/msgraph.py
index 1848f969f8..4b9ccb7a71 100644
--- a/airflow/providers/microsoft/azure/triggers/msgraph.py
+++ b/airflow/providers/microsoft/azure/triggers/msgraph.py
@@ -27,7 +27,6 @@ from typing import (
     TYPE_CHECKING,
     Any,
     AsyncIterator,
-    Callable,
     Sequence,
 )
 from uuid import UUID
@@ -43,8 +42,6 @@ if TYPE_CHECKING:
 
     from kiota_abstractions.request_adapter import RequestAdapter
     from kiota_abstractions.request_information import QueryParams
-    from kiota_abstractions.response_handler import NativeResponseType
-    from kiota_abstractions.serialization import ParsableFactory
     from kiota_http.httpx_request_adapter import ResponseType
     from msgraph_core import APIVersion
 
@@ -89,9 +86,6 @@ class MSGraphTrigger(BaseTrigger):
     :param url: The url being executed on the Microsoft Graph API (templated).
     :param response_type: The expected return type of the response as a 
string. Possible value are: `bytes`,
         `str`, `int`, `float`, `bool` and `datetime` (default is None).
-    :param response_handler: Function to convert the native HTTPX response 
returned by the hook (default is
-        lambda response, error_map: response.json()).  The default expression 
will convert the native response
-        to JSON.  If response_type parameter is specified, then the 
response_handler will be ignored.
     :param method: The HTTP method being used to do the REST call (default is 
GET).
     :param conn_id: The HTTP Connection ID to run the operator against 
(templated).
     :param timeout: The HTTP timeout being used by the `KiotaRequestAdapter` 
(default is None).
@@ -119,9 +113,6 @@ class MSGraphTrigger(BaseTrigger):
         self,
         url: str,
         response_type: ResponseType | None = None,
-        response_handler: Callable[
-            [NativeResponseType, dict[str, ParsableFactory | None] | None], Any
-        ] = lambda response, error_map: response.json(),
         path_parameters: dict[str, Any] | None = None,
         url_template: str | None = None,
         method: str = "GET",
@@ -143,7 +134,6 @@ class MSGraphTrigger(BaseTrigger):
         )
         self.url = url
         self.response_type = response_type
-        self.response_handler = response_handler
         self.path_parameters = path_parameters
         self.url_template = url_template
         self.method = method
@@ -207,7 +197,6 @@ class MSGraphTrigger(BaseTrigger):
             response = await self.hook.run(
                 url=self.url,
                 response_type=self.response_type,
-                response_handler=self.response_handler,
                 path_parameters=self.path_parameters,
                 method=self.method,
                 query_parameters=self.query_parameters,
diff --git 
a/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst 
b/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst
index 817b14f783..342bf54276 100644
--- a/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst
+++ b/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst
@@ -64,6 +64,14 @@ Below is an example of using this operator to get PowerBI 
workspaces info.
     :start-after: [START howto_operator_powerbi_workspaces_info]
     :end-before: [END howto_operator_powerbi_workspaces_info]
 
+Below is an example of using this operator to refresh a PowerBI dataset.
+
+.. exampleinclude:: 
/../../tests/system/providers/microsoft/azure/example_powerbi.py
+    :language: python
+    :dedent: 0
+    :start-after: [START howto_operator_powerbi_refresh_dataset]
+    :end-before: [END howto_operator_powerbi_refresh_dataset]
+
 
 Reference
 ---------
diff --git a/tests/providers/microsoft/azure/hooks/test_msgraph.py 
b/tests/providers/microsoft/azure/hooks/test_msgraph.py
index 9d2db07acf..71d280a197 100644
--- a/tests/providers/microsoft/azure/hooks/test_msgraph.py
+++ b/tests/providers/microsoft/azure/hooks/test_msgraph.py
@@ -17,6 +17,7 @@
 from __future__ import annotations
 
 import asyncio
+from json import JSONDecodeError
 from unittest.mock import patch
 
 import pytest
@@ -24,12 +25,17 @@ from kiota_http.httpx_request_adapter import 
HttpxRequestAdapter
 from msgraph_core import APIVersion, NationalClouds
 
 from airflow.exceptions import AirflowBadRequest, AirflowException, 
AirflowNotFoundException
-from airflow.providers.microsoft.azure.hooks.msgraph import 
CallableResponseHandler, KiotaRequestAdapterHook
+from airflow.providers.microsoft.azure.hooks.msgraph import (
+    DefaultResponseHandler,
+    KiotaRequestAdapterHook,
+)
 from tests.providers.microsoft.conftest import (
     get_airflow_connection,
+    load_file,
     load_json,
     mock_connection,
     mock_json_response,
+    mock_response,
 )
 
 
@@ -95,45 +101,53 @@ class TestKiotaRequestAdapterHook:
 
 
 class TestResponseHandler:
-    def test_handle_response_async_when_ok(self):
+    def test_default_response_handler_when_json(self):
         users = load_json("resources", "users.json")
         response = mock_json_response(200, users)
 
-        actual = asyncio.run(
-            CallableResponseHandler(lambda response, error_map: 
response.json()).handle_response_async(
-                response, None
-            )
-        )
+        actual = 
asyncio.run(DefaultResponseHandler().handle_response_async(response, None))
 
         assert isinstance(actual, dict)
         assert actual == users
 
+    def test_default_response_handler_when_not_json(self):
+        response = mock_json_response(200, JSONDecodeError("", "", 0))
+
+        actual = 
asyncio.run(DefaultResponseHandler().handle_response_async(response, None))
+
+        assert actual == {}
+
+    def test_default_response_handler_when_content(self):
+        users = load_file("resources", "users.json").encode()
+        response = mock_response(200, users)
+
+        actual = 
asyncio.run(DefaultResponseHandler().handle_response_async(response, None))
+
+        assert isinstance(actual, bytes)
+        assert actual == users
+
+    def test_default_response_handler_when_no_content_but_headers(self):
+        response = mock_response(200, headers={"RequestId": 
"ffb6096e-d409-4826-aaeb-b5d4b165dc4d"})
+
+        actual = 
asyncio.run(DefaultResponseHandler().handle_response_async(response, None))
+
+        assert isinstance(actual, dict)
+        assert actual["requestid"] == "ffb6096e-d409-4826-aaeb-b5d4b165dc4d"
+
     def test_handle_response_async_when_bad_request(self):
         response = mock_json_response(400, {})
 
         with pytest.raises(AirflowBadRequest):
-            asyncio.run(
-                CallableResponseHandler(lambda response, error_map: 
response.json()).handle_response_async(
-                    response, None
-                )
-            )
+            
asyncio.run(DefaultResponseHandler().handle_response_async(response, None))
 
     def test_handle_response_async_when_not_found(self):
         response = mock_json_response(404, {})
 
         with pytest.raises(AirflowNotFoundException):
-            asyncio.run(
-                CallableResponseHandler(lambda response, error_map: 
response.json()).handle_response_async(
-                    response, None
-                )
-            )
+            
asyncio.run(DefaultResponseHandler().handle_response_async(response, None))
 
     def test_handle_response_async_when_internal_server_error(self):
         response = mock_json_response(500, {})
 
         with pytest.raises(AirflowException):
-            asyncio.run(
-                CallableResponseHandler(lambda response, error_map: 
response.json()).handle_response_async(
-                    response, None
-                )
-            )
+            
asyncio.run(DefaultResponseHandler().handle_response_async(response, None))
diff --git a/tests/providers/microsoft/azure/sensors/test_msgraph.py 
b/tests/providers/microsoft/azure/sensors/test_msgraph.py
index 50fd2474ab..e257984aff 100644
--- a/tests/providers/microsoft/azure/sensors/test_msgraph.py
+++ b/tests/providers/microsoft/azure/sensors/test_msgraph.py
@@ -16,9 +16,12 @@
 # under the License.
 from __future__ import annotations
 
+import json
+
 from airflow.providers.microsoft.azure.sensors.msgraph import MSGraphSensor
+from airflow.triggers.base import TriggerEvent
 from tests.providers.microsoft.azure.base import Base
-from tests.providers.microsoft.conftest import load_json, mock_context, 
mock_json_response
+from tests.providers.microsoft.conftest import load_json, mock_json_response
 
 
 class TestMSGraphSensor(Base):
@@ -35,10 +38,16 @@ class TestMSGraphSensor(Base):
                 result_processor=lambda context, result: result["id"],
                 timeout=350.0,
             )
-            actual = sensor.execute(context=mock_context(task=sensor))
 
-            assert isinstance(actual, str)
-            assert actual == "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"
+            results, events = self.execute_operator(sensor)
+
+            assert isinstance(results, str)
+            assert results == "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"
+            assert len(events) == 1
+            assert isinstance(events[0], TriggerEvent)
+            assert events[0].payload["status"] == "success"
+            assert events[0].payload["type"] == "builtins.dict"
+            assert events[0].payload["response"] == json.dumps(status)
 
     def test_template_fields(self):
         sensor = MSGraphSensor(
diff --git a/tests/providers/microsoft/conftest.py 
b/tests/providers/microsoft/conftest.py
index dfba931023..aa3c48c5d7 100644
--- a/tests/providers/microsoft/conftest.py
+++ b/tests/providers/microsoft/conftest.py
@@ -20,12 +20,13 @@ from __future__ import annotations
 import json
 import random
 import string
+from json import JSONDecodeError
 from os.path import dirname, join
 from typing import TYPE_CHECKING, Any, Iterable, TypeVar
 from unittest.mock import MagicMock
 
 import pytest
-from httpx import Response
+from httpx import Headers, Response
 from msgraph_core import APIVersion
 
 from airflow.models import Connection
@@ -89,18 +90,21 @@ def mock_connection(schema: str | None = None, host: str | 
None = None) -> Conne
 def mock_json_response(status_code, *contents) -> Response:
     response = MagicMock(spec=Response)
     response.status_code = status_code
+    response.headers = Headers({})
+    response.content = b""
     if contents:
-        contents = list(contents)
-        response.json.side_effect = lambda: contents.pop(0)
+        response.json.side_effect = list(contents)
     else:
         response.json.return_value = None
     return response
 
 
-def mock_response(status_code, content: Any = None) -> Response:
+def mock_response(status_code, content: Any = None, headers: dict | None = 
None) -> Response:
     response = MagicMock(spec=Response)
     response.status_code = status_code
+    response.headers = Headers(headers or {})
     response.content = content
+    response.json.side_effect = JSONDecodeError("", "", 0)
     return response
 
 
diff --git a/tests/system/providers/microsoft/azure/example_powerbi.py 
b/tests/system/providers/microsoft/azure/example_powerbi.py
index cbee9a62af..0a1bfde54a 100644
--- a/tests/system/providers/microsoft/azure/example_powerbi.py
+++ b/tests/system/providers/microsoft/azure/example_powerbi.py
@@ -66,7 +66,36 @@ with models.DAG(
     ).expand(path_parameters=workspaces_info_task.output)
     # [END howto_sensor_powerbi_scan_status]
 
+    # [START howto_operator_powerbi_refresh_dataset]
+    refresh_dataset_task = MSGraphAsyncOperator(
+        task_id="refresh_dataset",
+        conn_id="powerbi_api",
+        url="myorg/groups/{workspaceId}/datasets/{datasetId}/refreshes",
+        method="POST",
+        path_parameters={
+            "workspaceId": "9a7e14c6-9a7d-4b4c-b0f2-799a85e60a51",
+            "datasetId": "ffb6096e-d409-4826-aaeb-b5d4b165dc4d",
+        },
+        data={"type": "full"},  # Needed for enhanced refresh
+        result_processor=lambda context, response: response["requestid"],
+    )
+
+    refresh_dataset_history_task = MSGraphSensor(
+        task_id="refresh_dataset_history",
+        conn_id="powerbi_api",
+        
url="myorg/groups/{workspaceId}/datasets/{datasetId}/refreshes/{refreshId}",
+        path_parameters={
+            "workspaceId": "9a7e14c6-9a7d-4b4c-b0f2-799a85e60a51",
+            "datasetId": "ffb6096e-d409-4826-aaeb-b5d4b165dc4d",
+            "refreshId": refresh_dataset_task.output,
+        },
+        timeout=350.0,
+        event_processor=lambda context, event: event["status"] == "Completed",
+    )
+    # [END howto_operator_powerbi_refresh_dataset]
+
     workspaces_task >> workspaces_info_task >> check_workspace_status_task
+    refresh_dataset_task >> refresh_dataset_history_task
 
     from tests.system.utils.watcher import watcher
 

Reply via email to