This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a61f393ec4 Implemented MSGraphSensor as a deferrable sensor (#39304)
a61f393ec4 is described below
commit a61f393ec4361499fcef9f2854668db85b852ec0
Author: David Blain <[email protected]>
AuthorDate: Sun May 5 09:06:36 2024 +0200
Implemented MSGraphSensor as a deferrable sensor (#39304)
* refactor: Implemented a default response handler method and added a test for
when a JSON decode error occurs
* refactor: Reformatted some code to comply with static checks
* refactor: Changed logging level to debug for printing the response in the
operator
* docs: Added an example of how to refresh a PowerBI dataset using the
MSGraphAsyncOperator
* refactor: Changed some info logging statements to debug
* refactor: Changed some info logging statements to debug
* fix: Fixed mock_json_response
* refactor: Return content if response is not a JSON
* refactor: Make sure the operator passes the response_handler to the
triggerer
* refactor: Use get() instead of indexing with brackets (__getitem__), as the
payload might not have a response key if the call isn't done
* refactor: If the event has status failure, the sensor should stop the
async poke
* refactor: Changed default_event_processor as not all responses have the
status key present
* refactor: Changed default_event_processor as not all responses have the
status key present
* refactor: Removed response_handler parameter as a lambda cannot be
serialized by MSGraphTrigger
* refactor: Changed some logging statements
* refactor: Updated PowerBI dataset refresh example
* refactor: Fixed 2 static check errors
* refactor: Refactored MSGraphSensor as a real async sensor
* refactor: Changed logging level of sensor statements back to debug
* refactor: Fixed 2 static checks
* refactor: Changed docstring hook
* refactor: Put docstring on one line
---------
Co-authored-by: David Blain <[email protected]>
---
airflow/providers/microsoft/azure/hooks/msgraph.py | 41 ++++---
.../providers/microsoft/azure/operators/msgraph.py | 15 +--
.../providers/microsoft/azure/sensors/msgraph.py | 118 ++++++++++++---------
.../providers/microsoft/azure/triggers/msgraph.py | 11 --
.../operators/msgraph.rst | 8 ++
.../microsoft/azure/hooks/test_msgraph.py | 58 ++++++----
.../microsoft/azure/sensors/test_msgraph.py | 17 ++-
tests/providers/microsoft/conftest.py | 12 ++-
.../providers/microsoft/azure/example_powerbi.py | 29 +++++
9 files changed, 180 insertions(+), 129 deletions(-)
diff --git a/airflow/providers/microsoft/azure/hooks/msgraph.py
b/airflow/providers/microsoft/azure/hooks/msgraph.py
index 84b2252bd2..56abfa155d 100644
--- a/airflow/providers/microsoft/azure/hooks/msgraph.py
+++ b/airflow/providers/microsoft/azure/hooks/msgraph.py
@@ -18,9 +18,11 @@
from __future__ import annotations
import json
+from contextlib import suppress
from http import HTTPStatus
from io import BytesIO
-from typing import TYPE_CHECKING, Any, Callable
+from json import JSONDecodeError
+from typing import TYPE_CHECKING, Any
from urllib.parse import quote, urljoin, urlparse
import httpx
@@ -51,18 +53,17 @@ if TYPE_CHECKING:
from airflow.models import Connection
-class CallableResponseHandler(ResponseHandler):
- """
- CallableResponseHandler executes the passed callable_function with
response as parameter.
-
- param callable_function: Function that is applied to the response.
- """
+class DefaultResponseHandler(ResponseHandler):
+ """DefaultResponseHandler returns JSON payload or content in bytes or
response headers."""
- def __init__(
- self,
- callable_function: Callable[[NativeResponseType, dict[str,
ParsableFactory | None] | None], Any],
- ):
- self.callable_function = callable_function
+ @staticmethod
+ def get_value(response: NativeResponseType) -> Any:
+ with suppress(JSONDecodeError):
+ return response.json()
+ content = response.content
+ if not content:
+ return {key: value for key, value in response.headers.items()}
+ return content
async def handle_response_async(
self, response: NativeResponseType, error_map: dict[str,
ParsableFactory | None] | None = None
@@ -73,7 +74,7 @@ class CallableResponseHandler(ResponseHandler):
param response: The type of the native response object.
param error_map: The error dict to use in case of a failed request.
"""
- value = self.callable_function(response, error_map)
+ value = self.get_value(response)
if response.status_code not in {200, 201, 202, 204, 302}:
message = value or response.reason_phrase
status_code = HTTPStatus(response.status_code)
@@ -269,20 +270,18 @@ class KiotaRequestAdapterHook(BaseHook):
self,
url: str = "",
response_type: ResponseType | None = None,
- response_handler: Callable[
- [NativeResponseType, dict[str, ParsableFactory | None] | None], Any
- ] = lambda response, error_map: response.json(),
path_parameters: dict[str, Any] | None = None,
method: str = "GET",
query_parameters: dict[str, QueryParams] | None = None,
headers: dict[str, str] | None = None,
data: dict[str, Any] | str | BytesIO | None = None,
):
+ self.log.info("Executing url '%s' as '%s'", url, method)
+
response = await self.get_conn().send_primitive_async(
request_info=self.request_information(
url=url,
response_type=response_type,
- response_handler=response_handler,
path_parameters=path_parameters,
method=method,
query_parameters=query_parameters,
@@ -293,7 +292,7 @@ class KiotaRequestAdapterHook(BaseHook):
error_map=self.error_mapping(),
)
- self.log.debug("response: %s", response)
+ self.log.info("response: %s", response)
return response
@@ -301,9 +300,6 @@ class KiotaRequestAdapterHook(BaseHook):
self,
url: str,
response_type: ResponseType | None = None,
- response_handler: Callable[
- [NativeResponseType, dict[str, ParsableFactory | None] | None], Any
- ] = lambda response, error_map: response.json(),
path_parameters: dict[str, Any] | None = None,
method: str = "GET",
query_parameters: dict[str, QueryParams] | None = None,
@@ -323,12 +319,11 @@ class KiotaRequestAdapterHook(BaseHook):
request_information.url_template =
f"{{+baseurl}}/{self.normalize_url(url)}"
if not response_type:
request_information.request_options[ResponseHandlerOption.get_key()] =
ResponseHandlerOption(
- response_handler=CallableResponseHandler(response_handler)
+ response_handler=DefaultResponseHandler()
)
headers = {**self.DEFAULT_HEADERS, **headers} if headers else
self.DEFAULT_HEADERS
for header_name, header_value in headers.items():
request_information.headers.try_add(header_name=header_name,
header_value=header_value)
- self.log.info("data: %s", data)
if isinstance(data, BytesIO) or isinstance(data, bytes) or
isinstance(data, str):
request_information.content = data
elif data:
diff --git a/airflow/providers/microsoft/azure/operators/msgraph.py
b/airflow/providers/microsoft/azure/operators/msgraph.py
index 6411f9cc4a..39ca32d2b6 100644
--- a/airflow/providers/microsoft/azure/operators/msgraph.py
+++ b/airflow/providers/microsoft/azure/operators/msgraph.py
@@ -39,8 +39,6 @@ if TYPE_CHECKING:
from kiota_abstractions.request_adapter import ResponseType
from kiota_abstractions.request_information import QueryParams
- from kiota_abstractions.response_handler import NativeResponseType
- from kiota_abstractions.serialization import ParsableFactory
from msgraph_core import APIVersion
from airflow.utils.context import Context
@@ -59,9 +57,6 @@ class MSGraphAsyncOperator(BaseOperator):
:param url: The url being executed on the Microsoft Graph API (templated).
:param response_type: The expected return type of the response as a
string. Possible value are: `bytes`,
`str`, `int`, `float`, `bool` and `datetime` (default is None).
- :param response_handler: Function to convert the native HTTPX response
returned by the hook (default is
- lambda response, error_map: response.json()). The default expression
will convert the native response
- to JSON. If response_type parameter is specified, then the
response_handler will be ignored.
:param method: The HTTP method being used to do the REST call (default is
GET).
:param conn_id: The HTTP Connection ID to run the operator against
(templated).
:param key: The key that will be used to store `XCom's` ("return_value" is
default).
@@ -94,9 +89,6 @@ class MSGraphAsyncOperator(BaseOperator):
*,
url: str,
response_type: ResponseType | None = None,
- response_handler: Callable[
- [NativeResponseType, dict[str, ParsableFactory | None] | None], Any
- ] = lambda response, error_map: response.json(),
path_parameters: dict[str, Any] | None = None,
url_template: str | None = None,
method: str = "GET",
@@ -116,7 +108,6 @@ class MSGraphAsyncOperator(BaseOperator):
super().__init__(**kwargs)
self.url = url
self.response_type = response_type
- self.response_handler = response_handler
self.path_parameters = path_parameters
self.url_template = url_template
self.method = method
@@ -134,7 +125,6 @@ class MSGraphAsyncOperator(BaseOperator):
self.results: list[Any] | None = None
def execute(self, context: Context) -> None:
- self.log.info("Executing url '%s' as '%s'", self.url, self.method)
self.defer(
trigger=MSGraphTrigger(
url=self.url,
@@ -167,14 +157,14 @@ class MSGraphAsyncOperator(BaseOperator):
self.log.debug("context: %s", context)
if event:
- self.log.info("%s completed with %s: %s", self.task_id,
event.get("status"), event)
+ self.log.debug("%s completed with %s: %s", self.task_id,
event.get("status"), event)
if event.get("status") == "failure":
raise AirflowException(event.get("message"))
response = event.get("response")
- self.log.info("response: %s", response)
+ self.log.debug("response: %s", response)
if response:
response = self.serializer.deserialize(response)
@@ -281,7 +271,6 @@ class MSGraphAsyncOperator(BaseOperator):
url=url,
query_parameters=query_parameters,
response_type=self.response_type,
- response_handler=self.response_handler,
conn_id=self.conn_id,
timeout=self.timeout,
proxies=self.proxies,
diff --git a/airflow/providers/microsoft/azure/sensors/msgraph.py
b/airflow/providers/microsoft/azure/sensors/msgraph.py
index ffbf244dbe..3e1b10cbeb 100644
--- a/airflow/providers/microsoft/azure/sensors/msgraph.py
+++ b/airflow/providers/microsoft/azure/sensors/msgraph.py
@@ -17,33 +17,25 @@
# under the License.
from __future__ import annotations
-import asyncio
-import json
from typing import TYPE_CHECKING, Any, Callable, Sequence
+from airflow.exceptions import AirflowException
from airflow.providers.microsoft.azure.hooks.msgraph import
KiotaRequestAdapterHook
from airflow.providers.microsoft.azure.triggers.msgraph import MSGraphTrigger,
ResponseSerializer
-from airflow.sensors.base import BaseSensorOperator, PokeReturnValue
+from airflow.sensors.base import BaseSensorOperator
+from airflow.triggers.temporal import TimeDeltaTrigger
if TYPE_CHECKING:
+ from datetime import timedelta
from io import BytesIO
from kiota_abstractions.request_information import QueryParams
- from kiota_abstractions.response_handler import NativeResponseType
- from kiota_abstractions.serialization import ParsableFactory
from kiota_http.httpx_request_adapter import ResponseType
from msgraph_core import APIVersion
- from airflow.triggers.base import TriggerEvent
from airflow.utils.context import Context
-def default_event_processor(context: Context, event: TriggerEvent) -> bool:
- if event.payload["status"] == "success":
- return json.loads(event.payload["response"])["status"] == "Succeeded"
- return False
-
-
class MSGraphSensor(BaseSensorOperator):
"""
A Microsoft Graph API sensor which allows you to poll an async REST call
to the Microsoft Graph API.
@@ -51,9 +43,6 @@ class MSGraphSensor(BaseSensorOperator):
:param url: The url being executed on the Microsoft Graph API (templated).
:param response_type: The expected return type of the response as a
string. Possible value are: `bytes`,
`str`, `int`, `float`, `bool` and `datetime` (default is None).
- :param response_handler: Function to convert the native HTTPX response
returned by the hook (default is
- lambda response, error_map: response.json()). The default expression
will convert the native response
- to JSON. If response_type parameter is specified, then the
response_handler will be ignored.
:param method: The HTTP method being used to do the REST call (default is
GET).
:param conn_id: The HTTP Connection ID to run the operator against
(templated).
:param proxies: A dict defining the HTTP proxies to be used (default is
None).
@@ -85,9 +74,6 @@ class MSGraphSensor(BaseSensorOperator):
self,
url: str,
response_type: ResponseType | None = None,
- response_handler: Callable[
- [NativeResponseType, dict[str, ParsableFactory | None] | None], Any
- ] = lambda response, error_map: response.json(),
path_parameters: dict[str, Any] | None = None,
url_template: str | None = None,
method: str = "GET",
@@ -97,15 +83,15 @@ class MSGraphSensor(BaseSensorOperator):
conn_id: str = KiotaRequestAdapterHook.default_conn_name,
proxies: dict | None = None,
api_version: APIVersion | None = None,
- event_processor: Callable[[Context, TriggerEvent], bool] =
default_event_processor,
+ event_processor: Callable[[Context, Any], bool] = lambda context, e:
e.get("status") == "Succeeded",
result_processor: Callable[[Context, Any], Any] = lambda context,
result: result,
serializer: type[ResponseSerializer] = ResponseSerializer,
+ retry_delay: timedelta | float = 60,
**kwargs,
):
- super().__init__(**kwargs)
+ super().__init__(retry_delay=retry_delay, **kwargs)
self.url = url
self.response_type = response_type
- self.response_handler = response_handler
self.path_parameters = path_parameters
self.url_template = url_template
self.method = method
@@ -119,45 +105,73 @@ class MSGraphSensor(BaseSensorOperator):
self.result_processor = result_processor
self.serializer = serializer()
- @property
- def trigger(self):
- return MSGraphTrigger(
- url=self.url,
- response_type=self.response_type,
- response_handler=self.response_handler,
- path_parameters=self.path_parameters,
- url_template=self.url_template,
- method=self.method,
- query_parameters=self.query_parameters,
- headers=self.headers,
- data=self.data,
- conn_id=self.conn_id,
- timeout=self.timeout,
- proxies=self.proxies,
- api_version=self.api_version,
- serializer=type(self.serializer),
+ def execute(self, context: Context):
+ self.defer(
+ trigger=MSGraphTrigger(
+ url=self.url,
+ response_type=self.response_type,
+ path_parameters=self.path_parameters,
+ url_template=self.url_template,
+ method=self.method,
+ query_parameters=self.query_parameters,
+ headers=self.headers,
+ data=self.data,
+ conn_id=self.conn_id,
+ timeout=self.timeout,
+ proxies=self.proxies,
+ api_version=self.api_version,
+ serializer=type(self.serializer),
+ ),
+ method_name=self.execute_complete.__name__,
)
- async def async_poke(self, context: Context) -> bool | PokeReturnValue:
- self.log.info("Sensor triggered")
+ def retry_execute(
+ self,
+ context: Context,
+ ) -> Any:
+ self.execute(context=context)
+
+ def execute_complete(
+ self,
+ context: Context,
+ event: dict[Any, Any] | None = None,
+ ) -> Any:
+ """
+ Execute callback when MSGraphSensor finishes execution.
+
+ This method gets executed automatically when MSGraphTrigger completes
its execution.
+ """
+ self.log.debug("context: %s", context)
+
+ if event:
+ self.log.debug("%s completed with %s: %s", self.task_id,
event.get("status"), event)
+
+ if event.get("status") == "failure":
+ raise AirflowException(event.get("message"))
+
+ response = event.get("response")
+
+ self.log.debug("response: %s", response)
- async for event in self.trigger.run():
- self.log.debug("event: %s", event)
+ if response:
+ response = self.serializer.deserialize(response)
- is_done = self.event_processor(context, event)
+ self.log.debug("deserialize response: %s", response)
- self.log.debug("is_done: %s", is_done)
+ is_done = self.event_processor(context, response)
- response = self.serializer.deserialize(event.payload["response"])
+ self.log.debug("is_done: %s", is_done)
- self.log.debug("deserialize event: %s", response)
+ if is_done:
+ result = self.result_processor(context, response)
- result = self.result_processor(context, response)
+ self.log.debug("processed response: %s", result)
- self.log.debug("result: %s", result)
+ return result
- return PokeReturnValue(is_done=is_done, xcom_value=result)
- return PokeReturnValue(is_done=True)
+ self.defer(
+ trigger=TimeDeltaTrigger(self.retry_delay),
+ method_name=self.retry_execute.__name__,
+ )
- def poke(self, context) -> bool | PokeReturnValue:
- return asyncio.run(self.async_poke(context))
+ return None
diff --git a/airflow/providers/microsoft/azure/triggers/msgraph.py
b/airflow/providers/microsoft/azure/triggers/msgraph.py
index 1848f969f8..4b9ccb7a71 100644
--- a/airflow/providers/microsoft/azure/triggers/msgraph.py
+++ b/airflow/providers/microsoft/azure/triggers/msgraph.py
@@ -27,7 +27,6 @@ from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
- Callable,
Sequence,
)
from uuid import UUID
@@ -43,8 +42,6 @@ if TYPE_CHECKING:
from kiota_abstractions.request_adapter import RequestAdapter
from kiota_abstractions.request_information import QueryParams
- from kiota_abstractions.response_handler import NativeResponseType
- from kiota_abstractions.serialization import ParsableFactory
from kiota_http.httpx_request_adapter import ResponseType
from msgraph_core import APIVersion
@@ -89,9 +86,6 @@ class MSGraphTrigger(BaseTrigger):
:param url: The url being executed on the Microsoft Graph API (templated).
:param response_type: The expected return type of the response as a
string. Possible value are: `bytes`,
`str`, `int`, `float`, `bool` and `datetime` (default is None).
- :param response_handler: Function to convert the native HTTPX response
returned by the hook (default is
- lambda response, error_map: response.json()). The default expression
will convert the native response
- to JSON. If response_type parameter is specified, then the
response_handler will be ignored.
:param method: The HTTP method being used to do the REST call (default is
GET).
:param conn_id: The HTTP Connection ID to run the operator against
(templated).
:param timeout: The HTTP timeout being used by the `KiotaRequestAdapter`
(default is None).
@@ -119,9 +113,6 @@ class MSGraphTrigger(BaseTrigger):
self,
url: str,
response_type: ResponseType | None = None,
- response_handler: Callable[
- [NativeResponseType, dict[str, ParsableFactory | None] | None], Any
- ] = lambda response, error_map: response.json(),
path_parameters: dict[str, Any] | None = None,
url_template: str | None = None,
method: str = "GET",
@@ -143,7 +134,6 @@ class MSGraphTrigger(BaseTrigger):
)
self.url = url
self.response_type = response_type
- self.response_handler = response_handler
self.path_parameters = path_parameters
self.url_template = url_template
self.method = method
@@ -207,7 +197,6 @@ class MSGraphTrigger(BaseTrigger):
response = await self.hook.run(
url=self.url,
response_type=self.response_type,
- response_handler=self.response_handler,
path_parameters=self.path_parameters,
method=self.method,
query_parameters=self.query_parameters,
diff --git
a/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst
b/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst
index 817b14f783..342bf54276 100644
--- a/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst
+++ b/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst
@@ -64,6 +64,14 @@ Below is an example of using this operator to get PowerBI
workspaces info.
:start-after: [START howto_operator_powerbi_workspaces_info]
:end-before: [END howto_operator_powerbi_workspaces_info]
+Below is an example of using this operator to refresh a PowerBI dataset.
+
+.. exampleinclude::
/../../tests/system/providers/microsoft/azure/example_powerbi.py
+ :language: python
+ :dedent: 0
+ :start-after: [START howto_operator_powerbi_refresh_dataset]
+ :end-before: [END howto_operator_powerbi_refresh_dataset]
+
Reference
---------
diff --git a/tests/providers/microsoft/azure/hooks/test_msgraph.py
b/tests/providers/microsoft/azure/hooks/test_msgraph.py
index 9d2db07acf..71d280a197 100644
--- a/tests/providers/microsoft/azure/hooks/test_msgraph.py
+++ b/tests/providers/microsoft/azure/hooks/test_msgraph.py
@@ -17,6 +17,7 @@
from __future__ import annotations
import asyncio
+from json import JSONDecodeError
from unittest.mock import patch
import pytest
@@ -24,12 +25,17 @@ from kiota_http.httpx_request_adapter import
HttpxRequestAdapter
from msgraph_core import APIVersion, NationalClouds
from airflow.exceptions import AirflowBadRequest, AirflowException,
AirflowNotFoundException
-from airflow.providers.microsoft.azure.hooks.msgraph import
CallableResponseHandler, KiotaRequestAdapterHook
+from airflow.providers.microsoft.azure.hooks.msgraph import (
+ DefaultResponseHandler,
+ KiotaRequestAdapterHook,
+)
from tests.providers.microsoft.conftest import (
get_airflow_connection,
+ load_file,
load_json,
mock_connection,
mock_json_response,
+ mock_response,
)
@@ -95,45 +101,53 @@ class TestKiotaRequestAdapterHook:
class TestResponseHandler:
- def test_handle_response_async_when_ok(self):
+ def test_default_response_handler_when_json(self):
users = load_json("resources", "users.json")
response = mock_json_response(200, users)
- actual = asyncio.run(
- CallableResponseHandler(lambda response, error_map:
response.json()).handle_response_async(
- response, None
- )
- )
+ actual =
asyncio.run(DefaultResponseHandler().handle_response_async(response, None))
assert isinstance(actual, dict)
assert actual == users
+ def test_default_response_handler_when_not_json(self):
+ response = mock_json_response(200, JSONDecodeError("", "", 0))
+
+ actual =
asyncio.run(DefaultResponseHandler().handle_response_async(response, None))
+
+ assert actual == {}
+
+ def test_default_response_handler_when_content(self):
+ users = load_file("resources", "users.json").encode()
+ response = mock_response(200, users)
+
+ actual =
asyncio.run(DefaultResponseHandler().handle_response_async(response, None))
+
+ assert isinstance(actual, bytes)
+ assert actual == users
+
+ def test_default_response_handler_when_no_content_but_headers(self):
+ response = mock_response(200, headers={"RequestId":
"ffb6096e-d409-4826-aaeb-b5d4b165dc4d"})
+
+ actual =
asyncio.run(DefaultResponseHandler().handle_response_async(response, None))
+
+ assert isinstance(actual, dict)
+ assert actual["requestid"] == "ffb6096e-d409-4826-aaeb-b5d4b165dc4d"
+
def test_handle_response_async_when_bad_request(self):
response = mock_json_response(400, {})
with pytest.raises(AirflowBadRequest):
- asyncio.run(
- CallableResponseHandler(lambda response, error_map:
response.json()).handle_response_async(
- response, None
- )
- )
+
asyncio.run(DefaultResponseHandler().handle_response_async(response, None))
def test_handle_response_async_when_not_found(self):
response = mock_json_response(404, {})
with pytest.raises(AirflowNotFoundException):
- asyncio.run(
- CallableResponseHandler(lambda response, error_map:
response.json()).handle_response_async(
- response, None
- )
- )
+
asyncio.run(DefaultResponseHandler().handle_response_async(response, None))
def test_handle_response_async_when_internal_server_error(self):
response = mock_json_response(500, {})
with pytest.raises(AirflowException):
- asyncio.run(
- CallableResponseHandler(lambda response, error_map:
response.json()).handle_response_async(
- response, None
- )
- )
+
asyncio.run(DefaultResponseHandler().handle_response_async(response, None))
diff --git a/tests/providers/microsoft/azure/sensors/test_msgraph.py
b/tests/providers/microsoft/azure/sensors/test_msgraph.py
index 50fd2474ab..e257984aff 100644
--- a/tests/providers/microsoft/azure/sensors/test_msgraph.py
+++ b/tests/providers/microsoft/azure/sensors/test_msgraph.py
@@ -16,9 +16,12 @@
# under the License.
from __future__ import annotations
+import json
+
from airflow.providers.microsoft.azure.sensors.msgraph import MSGraphSensor
+from airflow.triggers.base import TriggerEvent
from tests.providers.microsoft.azure.base import Base
-from tests.providers.microsoft.conftest import load_json, mock_context,
mock_json_response
+from tests.providers.microsoft.conftest import load_json, mock_json_response
class TestMSGraphSensor(Base):
@@ -35,10 +38,16 @@ class TestMSGraphSensor(Base):
result_processor=lambda context, result: result["id"],
timeout=350.0,
)
- actual = sensor.execute(context=mock_context(task=sensor))
- assert isinstance(actual, str)
- assert actual == "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"
+ results, events = self.execute_operator(sensor)
+
+ assert isinstance(results, str)
+ assert results == "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"
+ assert len(events) == 1
+ assert isinstance(events[0], TriggerEvent)
+ assert events[0].payload["status"] == "success"
+ assert events[0].payload["type"] == "builtins.dict"
+ assert events[0].payload["response"] == json.dumps(status)
def test_template_fields(self):
sensor = MSGraphSensor(
diff --git a/tests/providers/microsoft/conftest.py
b/tests/providers/microsoft/conftest.py
index dfba931023..aa3c48c5d7 100644
--- a/tests/providers/microsoft/conftest.py
+++ b/tests/providers/microsoft/conftest.py
@@ -20,12 +20,13 @@ from __future__ import annotations
import json
import random
import string
+from json import JSONDecodeError
from os.path import dirname, join
from typing import TYPE_CHECKING, Any, Iterable, TypeVar
from unittest.mock import MagicMock
import pytest
-from httpx import Response
+from httpx import Headers, Response
from msgraph_core import APIVersion
from airflow.models import Connection
@@ -89,18 +90,21 @@ def mock_connection(schema: str | None = None, host: str |
None = None) -> Conne
def mock_json_response(status_code, *contents) -> Response:
response = MagicMock(spec=Response)
response.status_code = status_code
+ response.headers = Headers({})
+ response.content = b""
if contents:
- contents = list(contents)
- response.json.side_effect = lambda: contents.pop(0)
+ response.json.side_effect = list(contents)
else:
response.json.return_value = None
return response
-def mock_response(status_code, content: Any = None) -> Response:
+def mock_response(status_code, content: Any = None, headers: dict | None =
None) -> Response:
response = MagicMock(spec=Response)
response.status_code = status_code
+ response.headers = Headers(headers or {})
response.content = content
+ response.json.side_effect = JSONDecodeError("", "", 0)
return response
diff --git a/tests/system/providers/microsoft/azure/example_powerbi.py
b/tests/system/providers/microsoft/azure/example_powerbi.py
index cbee9a62af..0a1bfde54a 100644
--- a/tests/system/providers/microsoft/azure/example_powerbi.py
+++ b/tests/system/providers/microsoft/azure/example_powerbi.py
@@ -66,7 +66,36 @@ with models.DAG(
).expand(path_parameters=workspaces_info_task.output)
# [END howto_sensor_powerbi_scan_status]
+ # [START howto_operator_powerbi_refresh_dataset]
+ refresh_dataset_task = MSGraphAsyncOperator(
+ task_id="refresh_dataset",
+ conn_id="powerbi_api",
+ url="myorg/groups/{workspaceId}/datasets/{datasetId}/refreshes",
+ method="POST",
+ path_parameters={
+ "workspaceId": "9a7e14c6-9a7d-4b4c-b0f2-799a85e60a51",
+ "datasetId": "ffb6096e-d409-4826-aaeb-b5d4b165dc4d",
+ },
+ data={"type": "full"}, # Needed for enhanced refresh
+ result_processor=lambda context, response: response["requestid"],
+ )
+
+ refresh_dataset_history_task = MSGraphSensor(
+ task_id="refresh_dataset_history",
+ conn_id="powerbi_api",
+
url="myorg/groups/{workspaceId}/datasets/{datasetId}/refreshes/{refreshId}",
+ path_parameters={
+ "workspaceId": "9a7e14c6-9a7d-4b4c-b0f2-799a85e60a51",
+ "datasetId": "ffb6096e-d409-4826-aaeb-b5d4b165dc4d",
+ "refreshId": refresh_dataset_task.output,
+ },
+ timeout=350.0,
+ event_processor=lambda context, event: event["status"] == "Completed",
+ )
+ # [END howto_operator_powerbi_refresh_dataset]
+
workspaces_task >> workspaces_info_task >> check_workspace_status_task
+ refresh_dataset_task >> refresh_dataset_history_task
from tests.system.utils.watcher import watcher