This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
     new 1c9a6609f3  Adding MSGraphOperator in Microsoft Azure provider (#38111)
1c9a6609f3 is described below

commit 1c9a6609f36a6fabddfd6d3858cca049d4088668
Author: David Blain <i...@dabla.be>
AuthorDate: Sun Apr 14 21:39:56 2024 +0200

    Adding MSGraphOperator in Microsoft Azure provider (#38111)

    * refactor: Initial commit contains the new MSGraphOperator

    * refactor: Extracted common method into Base class for patching airflow connection and request adapter + make multiple patches into one context manager Python 3.8 compatible

    * refactor: Refactored some typing issues related to msgraph

    * refactor: Added some docstrings and fixed additional typing issues

    * refactor: Fixed more static checks

    * refactor: Added license on top of test serializer and fixed import

    * Revert "refactor: Added license on top of test serializer and fixed import"

      This reverts commit 04d6b85494b0b9d3973564d4ac5abb718ac32cc7.

    * refactor: Added license on top of serializer files and fixed additional static checks

    * refactor: Added new line at end of json test files

    * refactor: Try fixing docstrings on operator and serializer

    * refactor: Replaced NoneType with None

    * refactor: Made type unions Python 3.8 compatible

    * refactor: Reformatted some files to comply with static checks formatting

    * refactor: Reformatted base to comply with static checks formatting

    * refactor: Added msgraph-core dependency to provider.yaml

    * refactor: Added msgraph integration info to provider.yaml

    * refactor: Added init in resources

    * fix: Fixed typing of response_handler

    * refactor: Added assertions on conn_id, tenant_id, client_id and client_secret

    * refactor: Fixed some static checks

    * Revert "refactor: Added assertions on conn_id, tenant_id, client_id and client_secret"

      This reverts commit 88aa7dccd95b98585872ae9eb5cd339162a06bb9.

    * refactor: Changed imports in hook as we don't use mockito anymore we don't need the module before constructor

    * refactor: Renamed test methods

    * refactor: Replace List type with list

    * refactor: Moved docstring as one line

    * refactor: Fixed typing for tests and added test for response_handler

    * refactor: Refactored tests

    * fix: Fixed MS Graph logo filename

    * refactor: Fixed additional static checks remarks

    * refactor: Added white line in type checking block

    * refactor: Added msgraph-core dependency to provider_dependencies.json

    * refactor: Updated docstring on response handler

    * refactor: Moved ResponseHandler and Serializer to triggers module

    * docs: Added documentation on how to use the MSGraphAsyncOperator

    * docs: Fixed END tag in examples

    * refactor: Removed docstring from CallableResponseHandler

    * refactor: Ignore UP031 Use format specifiers instead of percent format as this is not possible here the way the DAG is evaluated in Airflow (due to XCom's)

    * Revert "refactor: Removed docstring from CallableResponseHandler"

      This reverts commit 6a14ebe01936ca31ab188ab0fcbb40ba1960c3ba.

    * refactor: Simplified docstring on CallableResponseHandler

    * refactor: Updated provider.yaml to add reference of msgraph to how-to-guide

    * refactor: Updated docstrings on operator and trigger

    * refactor: Fixed additional static checks

    * refactor: Ignore UP031 Use format specifiers instead of percent format as this is not possible here the way the DAG is evaluated in Airflow (due to XCom's)

    * refactor: Added param to docstring ResponseHandler

    * refactor: Updated pyproject.toml as main

    * refactor: Reformatted docstrings in trigger

    * refactor: Removed unused serialization module

    * fix: Fixed execution of consecutive tasks in execute_operator method

    * refactor: Added customizable pagination_function parameter to Operator and made operator PowerBI compatible

    * refactor: Reformatted operator and trigger

    * refactor: Added check if query_parameters is not None

    * refactor: Removed typing of top and odata_count

    * refactor: Ignore type for tenant_id (this is an issue in the ClientSecretCredential class)

    * refactor: Changed docstring on MSGraphTrigger

    * refactor: Changed docstring on MSGraphTrigger

    * refactor: Added docstring to handle_response_async method

    * refactor: Fixed docstring to imperative for handle_response_async method

    * refactor: Try quoting Sharepoint so it doesn't get spell checked

    * refactor: Try double quoting Sharepoint so it doesn't get spell checked

    * refactor: Always get a new event loop and close it after test is done

    * refactor: Reordered imports from contextlib

    * refactor: Added Sharepoint to spelling_wordlist.txt

    * refactor: Removed connection-type for KiotaRequestAdapterHook

    * refactor: Refactored encoded_query_parameters

    * refactor: Suppress ImportError

    * refactor: Added return type to paginate method

    * refactor: Updated paging_function type in MSGraphAsyncOperator

    * refactor: Pass the method name from method reference instead of hard coded string which is re-factor friendly

    * refactor: Changed return type of paginate method

    * refactor: Added MSGraphSensor which easily allows us to poll PowerBI statuses

    * refactor: Moved BytesIO and Context to type checking block for MSGraphSensor

    * refactor: Added noqa check on pull_execute_complete method of MSGraphOperator

    * fix: Fixed test_serialize of TestMSGraphTrigger

    * refactor: Added docstring to MSGraphSensor and updated the docstring of the MSGraphAsyncOperator

    * refactor: Reformatted docstring of MSGraphSensor

    * refactor: Added white line at end of status.json file to keep static check happy

    * refactor: Removed timeout parameter from constructor MSGraphSensor as it is already defined in the BaseSensorOperator

    * fix: Added missing return for async_poke in MSGraphSensor

    * Revert "refactor: Added noqa check on pull_execute_complete method of MSGraphOperator"

      This reverts commit ca6f92cae94edeed190df1c2c807324e510bbae3.

    * refactor: Reorganised imports on MSGraphSensor

    * refactor: Reformatted TestMSGraphSensor

    * refactor: Added MSGraph sensor integration name in provider.yaml

    * refactor: Updated apache-airflow version to at least 2.7.0 in provider.yaml of microsoft-azure provider

    * refactor: Exclude microsoft-azure from compatibility check with airflow 2.6.0 as version 2.7.0 will at least be required

    * refactor: Also updated the apache-airflow dependency version from 2.6.0 to 2.7.0 for microsoft-azure provider in provider_dependencies.json

    * refactor: Reformatted global_constants.py

    * refactor: Add logging statements for proxies and authority related stuff

    * fix: Fixed exclusion of microsoft.azure dependency in global_constants.py

    * refactor: Some Azure related imports should be ignored when running Airflow 2.6.0 or lower

    * refactor: Import of ADLSListOperator should be ignored when running Airflow 2.6.0 or lower

    * refactor: Moved optional provider imports that should be ignored when running Airflow 2.6.0 or lower at top of file

    * refactor: Fixed the event loop closed issue when executing long running tests on the MSGraphOperator

    * refactor: Extracted reusable mock_context method

    * refactor: Moved import of Session into type checking block

    * refactor: Updated the TestMSGraphSensor

    * refactor: Reformatted the mock_context method

    * refactor: Try implementing cached connections on MSGraphTrigger

    * docs: Added example for the MSGraphSensor and additional examples on how you can use the operator for PowerBI

    * Revert "refactor: Try implementing cached connections on MSGraphTrigger"

      This reverts commit 693975eb8dbf8a2982f3b3d05a4d385b312009f9.

    * fix: Fixed serialization of event payload as xcom_value for the MSGraphSensor

    * refactor: TestMSGraphAsyncOperator should be allowed to run as a db test

    * Revert "refactor: TestMSGraphAsyncOperator should be allowed to run as a db test"

      This reverts commit c7a06dbab1c516e9ffb67ad279f425c640d7a851.

    * refactor: TestMSGraphAsyncOperator should be allowed to run as a db test

    * refactor: Also added result_processor to MSGraphSensor

    * refactor: Fixed template_fields in operator, trigger and sensor

    ---------

    Co-authored-by: David Blain <david.bl...@infrabel.be>
---
 .../amazon/aws/transfers/azure_blob_to_s3.py       |   8 +-
 .../google/cloud/transfers/adls_to_gcs.py          |  10 +-
 .../google/cloud/transfers/azure_blob_to_gcs.py    |   8 +-
 .../cloud/transfers/azure_fileshare_to_gcs.py      |   8 +-
 airflow/providers/microsoft/azure/hooks/msgraph.py | 208 ++++++++++++++
 .../providers/microsoft/azure/operators/msgraph.py | 292 +++++++++++++++++++
 airflow/providers/microsoft/azure/provider.yaml    |  21 +-
 .../providers/microsoft/azure/sensors/msgraph.py   | 163 +++++++++++
 .../providers/microsoft/azure/triggers/msgraph.py  | 316 +++++++++++++++++++++
 dev/breeze/src/airflow_breeze/global_constants.py  |   4 +-
 .../operators/msgraph.rst                          |  74 +++++
 .../sensors/msgraph.rst                            |  42 +++
 .../azure/Microsoft-Graph-API.png                  | Bin 0 -> 3784 bytes
 docs/spelling_wordlist.txt                         |   1 +
 generated/provider_dependencies.json               |   5 +-
 tests/providers/microsoft/azure/base.py            | 121 ++++++++
 .../microsoft/azure/hooks/test_msgraph.py          |  79 ++++++
 .../microsoft/azure/operators/test_msgraph.py      | 129 +++++++++
 .../microsoft/azure/resources}/__init__.py         |   0
 .../providers/microsoft/azure/resources/dummy.pdf  | Bin 0 -> 13264 bytes
 .../microsoft/azure/resources/next_users.json      |   1 +
 .../microsoft/azure/resources/status.json          |   1 +
 .../providers/microsoft/azure/resources/users.json |   1 +
 .../microsoft/azure/sensors/test_msgraph.py        |  51 ++++
 .../microsoft/azure/triggers/test_msgraph.py       | 192 +++++++++++++
 tests/providers/microsoft/conftest.py              | 107 ++++++-
 .../providers/microsoft/azure/example_msgraph.py   |  61 ++++
 .../providers/microsoft/azure/example_powerbi.py   |  80 ++++++
 28 files changed, 1973 insertions(+), 10 deletions(-)

diff --git a/airflow/providers/amazon/aws/transfers/azure_blob_to_s3.py b/airflow/providers/amazon/aws/transfers/azure_blob_to_s3.py
index d62931a765..9af93e212b 100644
--- a/airflow/providers/amazon/aws/transfers/azure_blob_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/azure_blob_to_s3.py
@@ -23,7 +23,13 @@ from typing import TYPE_CHECKING, Sequence
 
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
-from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
+
+try:
+    from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
+except ModuleNotFoundError as e:
+    from airflow.exceptions import AirflowOptionalProviderFeatureException
+
+    raise AirflowOptionalProviderFeatureException(e)
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
diff --git a/airflow/providers/google/cloud/transfers/adls_to_gcs.py b/airflow/providers/google/cloud/transfers/adls_to_gcs.py
index 7abbd9a9c3..f11b6aa881 100644
--- a/airflow/providers/google/cloud/transfers/adls_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/adls_to_gcs.py
@@ -24,8 +24,14 @@ from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, Sequence
 
 from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url
-from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook
-from airflow.providers.microsoft.azure.operators.adls import ADLSListOperator
+
+try:
+    from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook
+    from airflow.providers.microsoft.azure.operators.adls import ADLSListOperator
+except ModuleNotFoundError as e:
+    from airflow.exceptions import AirflowOptionalProviderFeatureException
+
+    raise AirflowOptionalProviderFeatureException(e)
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
diff --git a/airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py b/airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py
index 8ba6f2d6eb..1da9e82c09 100644
--- a/airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py
@@ -22,7 +22,13 @@ from typing import TYPE_CHECKING, Sequence
 
 from airflow.models import BaseOperator
 from airflow.providers.google.cloud.hooks.gcs import GCSHook
-from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
+
+try:
+    from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
+except ModuleNotFoundError as e:
+    from airflow.exceptions import AirflowOptionalProviderFeatureException
+
+    raise AirflowOptionalProviderFeatureException(e)
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
diff --git a/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py b/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py
index 9ba6129791..cca318001c 100644
--- a/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py
@@ -24,7 +24,13 @@ from typing import TYPE_CHECKING, Sequence
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.models import BaseOperator
 from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url, gcs_object_is_directory
-from airflow.providers.microsoft.azure.hooks.fileshare import AzureFileShareHook
+
+try:
+    from airflow.providers.microsoft.azure.hooks.fileshare import AzureFileShareHook
+except ModuleNotFoundError as e:
+    from airflow.exceptions import AirflowOptionalProviderFeatureException
+
+    raise AirflowOptionalProviderFeatureException(e)
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
diff --git a/airflow/providers/microsoft/azure/hooks/msgraph.py b/airflow/providers/microsoft/azure/hooks/msgraph.py
new file mode 100644
index 0000000000..7fcc328f86
--- /dev/null
+++ b/airflow/providers/microsoft/azure/hooks/msgraph.py
@@ -0,0 +1,208 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
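[Editor's note: the guarded imports added to the transfer modules above all follow the same pattern. The following is a standalone sketch of that pattern, runnable without Airflow installed; the helper name `import_optional_feature` and the local exception class are stand-ins for Airflow's `AirflowOptionalProviderFeatureException`, not part of the commit.]

```python
import importlib


class OptionalProviderFeatureException(Exception):
    """Local stand-in for Airflow's AirflowOptionalProviderFeatureException."""


def import_optional_feature(module: str, name: str):
    """Import ``name`` from ``module``, translating a missing optional
    distribution into a dedicated exception, as the guarded imports above do."""
    try:
        return getattr(importlib.import_module(module), name)
    except ModuleNotFoundError as e:
        # Re-raise so Airflow can report a missing optional provider
        # instead of surfacing a bare ModuleNotFoundError to the user.
        raise OptionalProviderFeatureException(e) from e
```

This lets a provider module fail fast with a clear "optional feature not installed" error rather than a generic import failure.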
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+from urllib.parse import urljoin, urlparse
+
+import httpx
+from azure.identity import ClientSecretCredential
+from httpx import Timeout
+from kiota_authentication_azure.azure_identity_authentication_provider import (
+    AzureIdentityAuthenticationProvider,
+)
+from kiota_http.httpx_request_adapter import HttpxRequestAdapter
+from msgraph_core import GraphClientFactory
+from msgraph_core._enums import APIVersion, NationalClouds
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base import BaseHook
+
+if TYPE_CHECKING:
+    from kiota_abstractions.request_adapter import RequestAdapter
+
+    from airflow.models import Connection
+
+
+class KiotaRequestAdapterHook(BaseHook):
+    """
+    A Microsoft Graph API interaction hook, a Wrapper around KiotaRequestAdapter.
+
+    https://github.com/microsoftgraph/msgraph-sdk-python-core
+
+    :param conn_id: The HTTP Connection ID to run the trigger against.
+    :param timeout: The HTTP timeout being used by the KiotaRequestAdapter (default is None).
+        When no timeout is specified or set to None then no HTTP timeout is applied on each request.
+    :param proxies: A Dict defining the HTTP proxies to be used (default is None).
+    :param api_version: The API version of the Microsoft Graph API to be used (default is v1).
+        You can pass an enum named APIVersion which has 2 possible members v1 and beta,
+        or you can pass a string as "v1.0" or "beta".
+ """ + + cached_request_adapters: dict[str, tuple[APIVersion, RequestAdapter]] = {} + default_conn_name: str = "msgraph_default" + + def __init__( + self, + conn_id: str = default_conn_name, + timeout: float | None = None, + proxies: dict | None = None, + api_version: APIVersion | str | None = None, + ): + super().__init__() + self.conn_id = conn_id + self.timeout = timeout + self.proxies = proxies + self._api_version = self.resolve_api_version_from_value(api_version) + + @property + def api_version(self) -> APIVersion: + self.get_conn() # Make sure config has been loaded through get_conn to have correct api version! + return self._api_version + + @staticmethod + def resolve_api_version_from_value( + api_version: APIVersion | str, default: APIVersion | None = None + ) -> APIVersion: + if isinstance(api_version, APIVersion): + return api_version + return next( + filter(lambda version: version.value == api_version, APIVersion), + default, + ) + + def get_api_version(self, config: dict) -> APIVersion: + if self._api_version is None: + return self.resolve_api_version_from_value( + api_version=config.get("api_version"), default=APIVersion.v1 + ) + return self._api_version + + @staticmethod + def get_host(connection: Connection) -> str: + if connection.schema and connection.host: + return f"{connection.schema}://{connection.host}" + return NationalClouds.Global.value + + @staticmethod + def format_no_proxy_url(url: str) -> str: + if "://" not in url: + url = f"all://{url}" + return url + + @classmethod + def to_httpx_proxies(cls, proxies: dict) -> dict: + proxies = proxies.copy() + if proxies.get("http"): + proxies["http://"] = proxies.pop("http") + if proxies.get("https"): + proxies["https://"] = proxies.pop("https") + if proxies.get("no"): + for url in proxies.pop("no", "").split(","): + proxies[cls.format_no_proxy_url(url.strip())] = None + return proxies + + @classmethod + def to_msal_proxies(cls, authority: str | None, proxies: dict): + if authority: + no_proxies = 
proxies.get("no") + if no_proxies: + for url in no_proxies.split(","): + domain_name = urlparse(url).path.replace("*", "") + if authority.endswith(domain_name): + return None + return proxies + + def get_conn(self) -> RequestAdapter: + if not self.conn_id: + raise AirflowException("Failed to create the KiotaRequestAdapterHook. No conn_id provided!") + + api_version, request_adapter = self.cached_request_adapters.get(self.conn_id, (None, None)) + + if not request_adapter: + connection = self.get_connection(conn_id=self.conn_id) + client_id = connection.login + client_secret = connection.password + config = connection.extra_dejson if connection.extra else {} + tenant_id = config.get("tenant_id") + api_version = self.get_api_version(config) + host = self.get_host(connection) + base_url = config.get("base_url", urljoin(host, api_version.value)) + authority = config.get("authority") + proxies = self.proxies or config.get("proxies", {}) + msal_proxies = self.to_msal_proxies(authority=authority, proxies=proxies) + httpx_proxies = self.to_httpx_proxies(proxies=proxies) + scopes = config.get("scopes", ["https://graph.microsoft.com/.default"]) + verify = config.get("verify", True) + trust_env = config.get("trust_env", False) + disable_instance_discovery = config.get("disable_instance_discovery", False) + allowed_hosts = (config.get("allowed_hosts", authority) or "").split(",") + + self.log.info( + "Creating Microsoft Graph SDK client %s for conn_id: %s", + api_version.value, + self.conn_id, + ) + self.log.info("Host: %s", host) + self.log.info("Base URL: %s", base_url) + self.log.info("Tenant id: %s", tenant_id) + self.log.info("Client id: %s", client_id) + self.log.info("Client secret: %s", client_secret) + self.log.info("API version: %s", api_version.value) + self.log.info("Scope: %s", scopes) + self.log.info("Verify: %s", verify) + self.log.info("Timeout: %s", self.timeout) + self.log.info("Trust env: %s", trust_env) + self.log.info("Authority: %s", authority) + 
self.log.info("Disable instance discovery: %s", disable_instance_discovery) + self.log.info("Allowed hosts: %s", allowed_hosts) + self.log.info("Proxies: %s", proxies) + self.log.info("MSAL Proxies: %s", msal_proxies) + self.log.info("HTTPX Proxies: %s", httpx_proxies) + credentials = ClientSecretCredential( + tenant_id=tenant_id, # type: ignore + client_id=connection.login, + client_secret=connection.password, + authority=authority, + proxies=msal_proxies, + disable_instance_discovery=disable_instance_discovery, + connection_verify=verify, + ) + http_client = GraphClientFactory.create_with_default_middleware( + api_version=api_version, + client=httpx.AsyncClient( + proxies=httpx_proxies, + timeout=Timeout(timeout=self.timeout), + verify=verify, + trust_env=trust_env, + ), + host=host, + ) + auth_provider = AzureIdentityAuthenticationProvider( + credentials=credentials, + scopes=scopes, + allowed_hosts=allowed_hosts, + ) + request_adapter = HttpxRequestAdapter( + authentication_provider=auth_provider, + http_client=http_client, + base_url=base_url, + ) + self.cached_request_adapters[self.conn_id] = (api_version, request_adapter) + self._api_version = api_version + return request_adapter diff --git a/airflow/providers/microsoft/azure/operators/msgraph.py b/airflow/providers/microsoft/azure/operators/msgraph.py new file mode 100644 index 0000000000..6411f9cc4a --- /dev/null +++ b/airflow/providers/microsoft/azure/operators/msgraph.py @@ -0,0 +1,292 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from copy import deepcopy +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Sequence, +) + +from airflow.exceptions import AirflowException, TaskDeferred +from airflow.models import BaseOperator +from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook +from airflow.providers.microsoft.azure.triggers.msgraph import ( + MSGraphTrigger, + ResponseSerializer, +) +from airflow.utils.xcom import XCOM_RETURN_KEY + +if TYPE_CHECKING: + from io import BytesIO + + from kiota_abstractions.request_adapter import ResponseType + from kiota_abstractions.request_information import QueryParams + from kiota_abstractions.response_handler import NativeResponseType + from kiota_abstractions.serialization import ParsableFactory + from msgraph_core import APIVersion + + from airflow.utils.context import Context + + +class MSGraphAsyncOperator(BaseOperator): + """ + A Microsoft Graph API operator which allows you to execute REST call to the Microsoft Graph API. + + https://learn.microsoft.com/en-us/graph/use-the-api + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:MSGraphAsyncOperator` + + :param url: The url being executed on the Microsoft Graph API (templated). + :param response_type: The expected return type of the response as a string. Possible value are: `bytes`, + `str`, `int`, `float`, `bool` and `datetime` (default is None). 
+    :param response_handler: Function to convert the native HTTPX response returned by the hook (default is
+        lambda response, error_map: response.json()). The default expression will convert the native response
+        to JSON. If response_type parameter is specified, then the response_handler will be ignored.
+    :param method: The HTTP method being used to do the REST call (default is GET).
+    :param conn_id: The HTTP Connection ID to run the operator against (templated).
+    :param key: The key that will be used to store `XCom's` ("return_value" is default).
+    :param timeout: The HTTP timeout being used by the `KiotaRequestAdapter` (default is None).
+        When no timeout is specified or set to None then there is no HTTP timeout on each request.
+    :param proxies: A dict defining the HTTP proxies to be used (default is None).
+    :param api_version: The API version of the Microsoft Graph API to be used (default is v1).
+        You can pass an enum named APIVersion which has 2 possible members v1 and beta,
+        or you can pass a string as `v1.0` or `beta`.
+    :param result_processor: Function to further process the response from MS Graph API
+        (default is lambda: context, response: response). When the response returned by the
+        `KiotaRequestAdapterHook` are bytes, then those will be base64 encoded into a string.
+    :param serializer: Class which handles response serialization (default is ResponseSerializer).
+        Bytes will be base64 encoded into a string, so it can be stored as an XCom.
+ """ + + template_fields: Sequence[str] = ( + "url", + "response_type", + "path_parameters", + "url_template", + "query_parameters", + "headers", + "data", + "conn_id", + ) + + def __init__( + self, + *, + url: str, + response_type: ResponseType | None = None, + response_handler: Callable[ + [NativeResponseType, dict[str, ParsableFactory | None] | None], Any + ] = lambda response, error_map: response.json(), + path_parameters: dict[str, Any] | None = None, + url_template: str | None = None, + method: str = "GET", + query_parameters: dict[str, QueryParams] | None = None, + headers: dict[str, str] | None = None, + data: dict[str, Any] | str | BytesIO | None = None, + conn_id: str = KiotaRequestAdapterHook.default_conn_name, + key: str = XCOM_RETURN_KEY, + timeout: float | None = None, + proxies: dict | None = None, + api_version: APIVersion | None = None, + pagination_function: Callable[[MSGraphAsyncOperator, dict], tuple[str, dict]] | None = None, + result_processor: Callable[[Context, Any], Any] = lambda context, result: result, + serializer: type[ResponseSerializer] = ResponseSerializer, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.url = url + self.response_type = response_type + self.response_handler = response_handler + self.path_parameters = path_parameters + self.url_template = url_template + self.method = method + self.query_parameters = query_parameters + self.headers = headers + self.data = data + self.conn_id = conn_id + self.key = key + self.timeout = timeout + self.proxies = proxies + self.api_version = api_version + self.pagination_function = pagination_function or self.paginate + self.result_processor = result_processor + self.serializer: ResponseSerializer = serializer() + self.results: list[Any] | None = None + + def execute(self, context: Context) -> None: + self.log.info("Executing url '%s' as '%s'", self.url, self.method) + self.defer( + trigger=MSGraphTrigger( + url=self.url, + response_type=self.response_type, + 
path_parameters=self.path_parameters, + url_template=self.url_template, + method=self.method, + query_parameters=self.query_parameters, + headers=self.headers, + data=self.data, + conn_id=self.conn_id, + timeout=self.timeout, + proxies=self.proxies, + api_version=self.api_version, + serializer=type(self.serializer), + ), + method_name=self.execute_complete.__name__, + ) + + def execute_complete( + self, + context: Context, + event: dict[Any, Any] | None = None, + ) -> Any: + """ + Execute callback when MSGraphTrigger finishes execution. + + This method gets executed automatically when MSGraphTrigger completes its execution. + """ + self.log.debug("context: %s", context) + + if event: + self.log.info("%s completed with %s: %s", self.task_id, event.get("status"), event) + + if event.get("status") == "failure": + raise AirflowException(event.get("message")) + + response = event.get("response") + + self.log.info("response: %s", response) + + if response: + response = self.serializer.deserialize(response) + + self.log.debug("deserialize response: %s", response) + + result = self.result_processor(context, response) + + self.log.debug("processed response: %s", result) + + event["response"] = result + + try: + self.trigger_next_link(response, method_name=self.pull_execute_complete.__name__) + except TaskDeferred as exception: + self.append_result( + result=result, + append_result_as_list_if_absent=True, + ) + self.push_xcom(context=context, value=self.results) + raise exception + + self.append_result(result=result) + self.log.debug("results: %s", self.results) + + return self.results + return None + + def append_result( + self, + result: Any, + append_result_as_list_if_absent: bool = False, + ): + self.log.debug("value: %s", result) + + if isinstance(self.results, list): + if isinstance(result, list): + self.results.extend(result) + else: + self.results.append(result) + else: + if append_result_as_list_if_absent: + if isinstance(result, list): + self.results = result + 
else: + self.results = [result] + else: + self.results = result + + def push_xcom(self, context: Context, value) -> None: + self.log.debug("do_xcom_push: %s", self.do_xcom_push) + if self.do_xcom_push: + self.log.info("Pushing XCom with key '%s': %s", self.key, value) + self.xcom_push(context=context, key=self.key, value=value) + + def pull_execute_complete(self, context: Context, event: dict[Any, Any] | None = None) -> Any: + self.results = list( + self.xcom_pull( + context=context, + task_ids=self.task_id, + dag_id=self.dag_id, + key=self.key, + ) + or [] + ) + self.log.info( + "Pulled XCom with task_id '%s' and dag_id '%s' and key '%s': %s", + self.task_id, + self.dag_id, + self.key, + self.results, + ) + return self.execute_complete(context, event) + + @staticmethod + def paginate(operator: MSGraphAsyncOperator, response: dict) -> tuple[Any, dict[str, Any] | None]: + odata_count = response.get("@odata.count") + if odata_count and operator.query_parameters: + query_parameters = deepcopy(operator.query_parameters) + top = query_parameters.get("$top") + odata_count = response.get("@odata.count") + + if top and odata_count: + if len(response.get("value", [])) == top: + skip = ( + sum(map(lambda result: len(result["value"]), operator.results)) + top + if operator.results + else top + ) + query_parameters["$skip"] = skip + return operator.url, query_parameters + return response.get("@odata.nextLink"), operator.query_parameters + + def trigger_next_link(self, response, method_name="execute_complete") -> None: + if isinstance(response, dict): + url, query_parameters = self.pagination_function(self, response) + + self.log.debug("url: %s", url) + self.log.debug("query_parameters: %s", query_parameters) + + if url: + self.defer( + trigger=MSGraphTrigger( + url=url, + query_parameters=query_parameters, + response_type=self.response_type, + response_handler=self.response_handler, + conn_id=self.conn_id, + timeout=self.timeout, + proxies=self.proxies, + 
api_version=self.api_version, + serializer=type(self.serializer), + ), + method_name=method_name, + ) diff --git a/airflow/providers/microsoft/azure/provider.yaml b/airflow/providers/microsoft/azure/provider.yaml index 2ddc479b63..f1fa058f86 100644 --- a/airflow/providers/microsoft/azure/provider.yaml +++ b/airflow/providers/microsoft/azure/provider.yaml @@ -76,7 +76,7 @@ versions: - 1.0.0 dependencies: - - apache-airflow>=2.6.0 + - apache-airflow>=2.7.0 - adlfs>=2023.10.0 - azure-batch>=8.0.0 - azure-cosmos>=4.6.0 @@ -98,6 +98,7 @@ dependencies: - azure-mgmt-datafactory>=2.0.0 - azure-mgmt-containerregistry>=8.0.0 - azure-mgmt-containerinstance>=9.0.0 + - msgraph-core>=1.0.0 devel-dependencies: - pywinrm @@ -164,6 +165,12 @@ integrations: external-doc-url: https://azure.microsoft.com/en-us/products/storage/data-lake-storage/ logo: /integration-logos/azure/Data Lake Storage.svg tags: [azure] + - integration-name: Microsoft Graph API + external-doc-url: https://learn.microsoft.com/en-us/graph/use-the-api/ + logo: /integration-logos/azure/Microsoft-Graph-API.png + how-to-guide: + - /docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst + tags: [azure] operators: - integration-name: Microsoft Azure Data Lake Storage @@ -193,6 +200,9 @@ operators: - integration-name: Microsoft Azure Synapse python-modules: - airflow.providers.microsoft.azure.operators.synapse + - integration-name: Microsoft Graph API + python-modules: + - airflow.providers.microsoft.azure.operators.msgraph sensors: - integration-name: Microsoft Azure Cosmos DB @@ -204,6 +214,9 @@ sensors: - integration-name: Microsoft Azure Data Factory python-modules: - airflow.providers.microsoft.azure.sensors.data_factory + - integration-name: Microsoft Graph API + python-modules: + - airflow.providers.microsoft.azure.sensors.msgraph filesystems: - airflow.providers.microsoft.azure.fs.adls @@ -247,6 +260,9 @@ hooks: - integration-name: Microsoft Azure Synapse python-modules: - 
airflow.providers.microsoft.azure.hooks.synapse + - integration-name: Microsoft Graph API + python-modules: + - airflow.providers.microsoft.azure.hooks.msgraph triggers: - integration-name: Microsoft Azure Data Factory @@ -255,6 +271,9 @@ triggers: - integration-name: Microsoft Azure Blob Storage python-modules: - airflow.providers.microsoft.azure.triggers.wasb + - integration-name: Microsoft Graph API + python-modules: + - airflow.providers.microsoft.azure.triggers.msgraph transfers: - source-integration-name: Local diff --git a/airflow/providers/microsoft/azure/sensors/msgraph.py b/airflow/providers/microsoft/azure/sensors/msgraph.py new file mode 100644 index 0000000000..ffbf244dbe --- /dev/null +++ b/airflow/providers/microsoft/azure/sensors/msgraph.py @@ -0,0 +1,163 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import asyncio +import json +from typing import TYPE_CHECKING, Any, Callable, Sequence + +from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook +from airflow.providers.microsoft.azure.triggers.msgraph import MSGraphTrigger, ResponseSerializer +from airflow.sensors.base import BaseSensorOperator, PokeReturnValue + +if TYPE_CHECKING: + from io import BytesIO + + from kiota_abstractions.request_information import QueryParams + from kiota_abstractions.response_handler import NativeResponseType + from kiota_abstractions.serialization import ParsableFactory + from kiota_http.httpx_request_adapter import ResponseType + from msgraph_core import APIVersion + + from airflow.triggers.base import TriggerEvent + from airflow.utils.context import Context + + +def default_event_processor(context: Context, event: TriggerEvent) -> bool: + if event.payload["status"] == "success": + return json.loads(event.payload["response"])["status"] == "Succeeded" + return False + + +class MSGraphSensor(BaseSensorOperator): + """ + A Microsoft Graph API sensor which allows you to poll an async REST call to the Microsoft Graph API. + + :param url: The url being executed on the Microsoft Graph API (templated). + :param response_type: The expected return type of the response as a string. Possible values are: `bytes`, + `str`, `int`, `float`, `bool` and `datetime` (default is None). + :param response_handler: Function to convert the native HTTPX response returned by the hook (default is + lambda response, error_map: response.json()). The default expression will convert the native response + to JSON. If the response_type parameter is specified, then the response_handler will be ignored. + :param method: The HTTP method being used to do the REST call (default is GET). + :param conn_id: The HTTP Connection ID to run the operator against (templated). + :param proxies: A dict defining the HTTP proxies to be used (default is None).
+ :param api_version: The API version of the Microsoft Graph API to be used (default is v1). + You can pass an enum named APIVersion which has two possible members, v1 and beta, + or you can pass a string as `v1.0` or `beta`. + :param event_processor: Function which checks the response from the MS Graph API (default is the + `default_event_processor` method) and returns a boolean. When the result is True, the sensor + will stop poking, otherwise it will continue poking until the result is True or it times out. + :param result_processor: Function to further process the response from the MS Graph API + (default is lambda context, response: response). When the response returned by the + `KiotaRequestAdapterHook` is bytes, it will be base64-encoded into a string. + :param serializer: Class which handles response serialization (default is ResponseSerializer). + Bytes will be base64-encoded into a string, so it can be stored as an XCom. + """ + + template_fields: Sequence[str] = ( + "url", + "response_type", + "path_parameters", + "url_template", + "query_parameters", + "headers", + "data", + "conn_id", + ) + + def __init__( + self, + url: str, + response_type: ResponseType | None = None, + response_handler: Callable[ + [NativeResponseType, dict[str, ParsableFactory | None] | None], Any + ] = lambda response, error_map: response.json(), + path_parameters: dict[str, Any] | None = None, + url_template: str | None = None, + method: str = "GET", + query_parameters: dict[str, QueryParams] | None = None, + headers: dict[str, str] | None = None, + data: dict[str, Any] | str | BytesIO | None = None, + conn_id: str = KiotaRequestAdapterHook.default_conn_name, + proxies: dict | None = None, + api_version: APIVersion | None = None, + event_processor: Callable[[Context, TriggerEvent], bool] = default_event_processor, + result_processor: Callable[[Context, Any], Any] = lambda context, result: result, + serializer: type[ResponseSerializer] = ResponseSerializer, + **kwargs, + ): +
super().__init__(**kwargs) + self.url = url + self.response_type = response_type + self.response_handler = response_handler + self.path_parameters = path_parameters + self.url_template = url_template + self.method = method + self.query_parameters = query_parameters + self.headers = headers + self.data = data + self.conn_id = conn_id + self.proxies = proxies + self.api_version = api_version + self.event_processor = event_processor + self.result_processor = result_processor + self.serializer = serializer() + + @property + def trigger(self): + return MSGraphTrigger( + url=self.url, + response_type=self.response_type, + response_handler=self.response_handler, + path_parameters=self.path_parameters, + url_template=self.url_template, + method=self.method, + query_parameters=self.query_parameters, + headers=self.headers, + data=self.data, + conn_id=self.conn_id, + timeout=self.timeout, + proxies=self.proxies, + api_version=self.api_version, + serializer=type(self.serializer), + ) + + async def async_poke(self, context: Context) -> bool | PokeReturnValue: + self.log.info("Sensor triggered") + + async for event in self.trigger.run(): + self.log.debug("event: %s", event) + + is_done = self.event_processor(context, event) + + self.log.debug("is_done: %s", is_done) + + response = self.serializer.deserialize(event.payload["response"]) + + self.log.debug("deserialize event: %s", response) + + result = self.result_processor(context, response) + + self.log.debug("result: %s", result) + + return PokeReturnValue(is_done=is_done, xcom_value=result) + return PokeReturnValue(is_done=True) + + def poke(self, context) -> bool | PokeReturnValue: + return asyncio.run(self.async_poke(context)) diff --git a/airflow/providers/microsoft/azure/triggers/msgraph.py b/airflow/providers/microsoft/azure/triggers/msgraph.py new file mode 100644 index 0000000000..c0e5ee85a0 --- /dev/null +++ b/airflow/providers/microsoft/azure/triggers/msgraph.py @@ -0,0 +1,316 @@ +# +# Licensed to the Apache Software 
Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import json +import locale +from base64 import b64encode +from contextlib import suppress +from datetime import datetime +from io import BytesIO +from json import JSONDecodeError +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Callable, + Sequence, +) +from urllib.parse import quote +from uuid import UUID + +import pendulum +from kiota_abstractions.api_error import APIError +from kiota_abstractions.method import Method +from kiota_abstractions.request_information import RequestInformation +from kiota_abstractions.response_handler import ResponseHandler +from kiota_http.middleware.options import ResponseHandlerOption + +from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook +from airflow.triggers.base import BaseTrigger, TriggerEvent +from airflow.utils.module_loading import import_string + +if TYPE_CHECKING: + from kiota_abstractions.request_adapter import RequestAdapter + from kiota_abstractions.request_information import QueryParams + from kiota_abstractions.response_handler import NativeResponseType + from kiota_abstractions.serialization import ParsableFactory + from kiota_http.httpx_request_adapter import ResponseType + from 
msgraph_core import APIVersion + + +class ResponseSerializer: + """ResponseSerializer serializes the response as a string.""" + + def __init__(self, encoding: str | None = None): + self.encoding = encoding or locale.getpreferredencoding() + + def serialize(self, response) -> str | None: + def convert(value) -> str | None: + if value is not None: + if isinstance(value, UUID): + return str(value) + if isinstance(value, pendulum.DateTime): + return value.to_iso8601_string() # Adjust the format as needed + if isinstance(value, datetime): + return value.isoformat() + raise TypeError(f"Object of type {type(value)} is not JSON serializable!") + return None + + if response is not None: + if isinstance(response, bytes): + return b64encode(response).decode(self.encoding) + with suppress(JSONDecodeError): + return json.dumps(response, default=convert) + return response + return None + + def deserialize(self, response) -> Any: + if isinstance(response, str): + with suppress(JSONDecodeError): + response = json.loads(response) + return response + + +class CallableResponseHandler(ResponseHandler): + """ + CallableResponseHandler executes the passed callable_function with response as parameter. + + :param callable_function: Function that is applied to the response. + """ + + def __init__( + self, + callable_function: Callable[[NativeResponseType, dict[str, ParsableFactory | None] | None], Any], + ): + self.callable_function = callable_function + + async def handle_response_async( + self, response: NativeResponseType, error_map: dict[str, ParsableFactory | None] | None = None + ) -> Any: + """ + Invoke this callback method when a response is received. + + :param response: The native response object. + :param error_map: The error dict to use in case of a failed request.
+ """ + return self.callable_function(response, error_map) + + +class MSGraphTrigger(BaseTrigger): + """ + A Microsoft Graph API trigger which allows you to execute an async REST call to the Microsoft Graph API. + + :param url: The url being executed on the Microsoft Graph API (templated). + :param response_type: The expected return type of the response as a string. Possible value are: `bytes`, + `str`, `int`, `float`, `bool` and `datetime` (default is None). + :param response_handler: Function to convert the native HTTPX response returned by the hook (default is + lambda response, error_map: response.json()). The default expression will convert the native response + to JSON. If response_type parameter is specified, then the response_handler will be ignored. + :param method: The HTTP method being used to do the REST call (default is GET). + :param conn_id: The HTTP Connection ID to run the operator against (templated). + :param timeout: The HTTP timeout being used by the `KiotaRequestAdapter` (default is None). + When no timeout is specified or set to None then there is no HTTP timeout on each request. + :param proxies: A dict defining the HTTP proxies to be used (default is None). + :param api_version: The API version of the Microsoft Graph API to be used (default is v1). + You can pass an enum named APIVersion which has 2 possible members v1 and beta, + or you can pass a string as `v1.0` or `beta`. + :param serializer: Class which handles response serialization (default is ResponseSerializer). + Bytes will be base64 encoded into a string, so it can be stored as an XCom. 
+ """ + + DEFAULT_HEADERS = {"Accept": "application/json;q=1"} + template_fields: Sequence[str] = ( + "url", + "response_type", + "path_parameters", + "url_template", + "query_parameters", + "headers", + "data", + "conn_id", + ) + + def __init__( + self, + url: str, + response_type: ResponseType | None = None, + response_handler: Callable[ + [NativeResponseType, dict[str, ParsableFactory | None] | None], Any + ] = lambda response, error_map: response.json(), + path_parameters: dict[str, Any] | None = None, + url_template: str | None = None, + method: str = "GET", + query_parameters: dict[str, QueryParams] | None = None, + headers: dict[str, str] | None = None, + data: dict[str, Any] | str | BytesIO | None = None, + conn_id: str = KiotaRequestAdapterHook.default_conn_name, + timeout: float | None = None, + proxies: dict | None = None, + api_version: APIVersion | None = None, + serializer: type[ResponseSerializer] = ResponseSerializer, + ): + super().__init__() + self.hook = KiotaRequestAdapterHook( + conn_id=conn_id, + timeout=timeout, + proxies=proxies, + api_version=api_version, + ) + self.url = url + self.response_type = response_type + self.response_handler = response_handler + self.path_parameters = path_parameters + self.url_template = url_template + self.method = method + self.query_parameters = query_parameters + self.headers = headers + self.data = data + self.serializer: ResponseSerializer = self.resolve_type(serializer, default=ResponseSerializer)() + + @classmethod + def resolve_type(cls, value: str | type, default) -> type: + if isinstance(value, str): + with suppress(ImportError): + return import_string(value) + return default + return value or default + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize the HttpTrigger arguments and classpath.""" + api_version = self.api_version.value if self.api_version else None + return ( + f"{self.__class__.__module__}.{self.__class__.__name__}", + { + "conn_id": self.conn_id, + "timeout": 
self.timeout, + "proxies": self.proxies, + "api_version": api_version, + "serializer": f"{self.serializer.__class__.__module__}.{self.serializer.__class__.__name__}", + "url": self.url, + "path_parameters": self.path_parameters, + "url_template": self.url_template, + "method": self.method, + "query_parameters": self.query_parameters, + "headers": self.headers, + "data": self.data, + "response_type": self.response_type, + }, + ) + + def get_conn(self) -> RequestAdapter: + return self.hook.get_conn() + + @property + def conn_id(self) -> str: + return self.hook.conn_id + + @property + def timeout(self) -> float | None: + return self.hook.timeout + + @property + def proxies(self) -> dict | None: + return self.hook.proxies + + @property + def api_version(self) -> APIVersion: + return self.hook.api_version + + async def run(self) -> AsyncIterator[TriggerEvent]: + """Make a series of asynchronous HTTP calls via a KiotaRequestAdapterHook.""" + try: + response = await self.execute() + + self.log.debug("response: %s", response) + + if response: + response_type = type(response) + + self.log.debug("response type: %s", response_type) + + yield TriggerEvent( + { + "status": "success", + "type": f"{response_type.__module__}.{response_type.__name__}", + "response": self.serializer.serialize(response), + } + ) + else: + yield TriggerEvent( + { + "status": "success", + "type": None, + "response": None, + } + ) + except Exception as e: + self.log.exception("An error occurred: %s", e) + yield TriggerEvent({"status": "failure", "message": str(e)}) + + def normalize_url(self) -> str | None: + if self.url.startswith("/"): + return self.url.replace("/", "", 1) + return self.url + + def encoded_query_parameters(self) -> dict: + if self.query_parameters: + return {quote(key): quote(str(value)) for key, value in self.query_parameters.items()} + return {} + + def request_information(self) -> RequestInformation: + request_information = RequestInformation() + request_information.path_parameters 
= self.path_parameters or {} + request_information.http_method = Method(self.method.strip().upper()) + request_information.query_parameters = self.encoded_query_parameters() + if self.url.startswith("http"): + request_information.url = self.url + elif request_information.query_parameters.keys(): + query = ",".join(request_information.query_parameters.keys()) + request_information.url_template = f"{{+baseurl}}/{self.normalize_url()}{{?{query}}}" + else: + request_information.url_template = f"{{+baseurl}}/{self.normalize_url()}" + if not self.response_type: + request_information.request_options[ResponseHandlerOption.get_key()] = ResponseHandlerOption( + response_handler=CallableResponseHandler(self.response_handler) + ) + headers = {**self.DEFAULT_HEADERS, **self.headers} if self.headers else self.DEFAULT_HEADERS + for header_name, header_value in headers.items(): + request_information.headers.try_add(header_name=header_name, header_value=header_value) + if isinstance(self.data, (BytesIO, bytes, str)): + request_information.content = self.data + elif self.data: + request_information.headers.try_add( + header_name=RequestInformation.CONTENT_TYPE_HEADER, header_value="application/json" + ) + request_information.content = json.dumps(self.data).encode("utf-8") + return request_information + + @staticmethod + def error_mapping() -> dict[str, ParsableFactory | None]: + return { + "4XX": APIError, + "5XX": APIError, + } + + async def execute(self) -> Any: + return await self.get_conn().send_primitive_async( + request_info=self.request_information(), + response_type=self.response_type, + error_map=self.error_mapping(), + ) diff --git a/dev/breeze/src/airflow_breeze/global_constants.py b/dev/breeze/src/airflow_breeze/global_constants.py index b527cafe3c..efc01b5885 100644 --- a/dev/breeze/src/airflow_breeze/global_constants.py +++ b/dev/breeze/src/airflow_breeze/global_constants.py @@ -473,7 +473,9 @@
BASE_PROVIDERS_COMPATIBILITY_CHECKS: list[dict[str, str]] = [ { "python-version": "3.8", "airflow-version": "2.6.0", - "remove-providers": _exclusion(["openlineage", "common.io", "cohere", "fab", "qdrant"]), + "remove-providers": _exclusion( + ["openlineage", "common.io", "cohere", "fab", "qdrant", "microsoft.azure"] + ), }, { "python-version": "3.8", diff --git a/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst b/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst new file mode 100644 index 0000000000..817b14f783 --- /dev/null +++ b/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst @@ -0,0 +1,74 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Microsoft Graph API Operators +============================= + +Prerequisite Tasks +^^^^^^^^^^^^^^^^^^ + +.. include:: /operators/_partials/prerequisite_tasks.rst + +.. _howto/operator:MSGraphAsyncOperator: + +MSGraphAsyncOperator +---------------------------------- +Use the +:class:`~airflow.providers.microsoft.azure.operators.msgraph.MSGraphAsyncOperator` to call Microsoft Graph API. + + +Below is an example of using this operator to get a Sharepoint site. + +.. 
exampleinclude:: /../../tests/system/providers/microsoft/azure/example_msgraph.py + :language: python + :dedent: 0 + :start-after: [START howto_operator_graph_site] + :end-before: [END howto_operator_graph_site] + +Below is an example of using this operator to get Sharepoint site pages. + +.. exampleinclude:: /../../tests/system/providers/microsoft/azure/example_msgraph.py + :language: python + :dedent: 0 + :start-after: [START howto_operator_graph_site_pages] + :end-before: [END howto_operator_graph_site_pages] + +Below is an example of using this operator to get PowerBI workspaces. + +.. exampleinclude:: /../../tests/system/providers/microsoft/azure/example_powerbi.py + :language: python + :dedent: 0 + :start-after: [START howto_operator_powerbi_workspaces] + :end-before: [END howto_operator_powerbi_workspaces] + +Below is an example of using this operator to get PowerBI workspaces info. + +.. exampleinclude:: /../../tests/system/providers/microsoft/azure/example_powerbi.py + :language: python + :dedent: 0 + :start-after: [START howto_operator_powerbi_workspaces_info] + :end-before: [END howto_operator_powerbi_workspaces_info] + + +Reference +--------- + +For further information, look at: + +* `Use the Microsoft Graph API <https://learn.microsoft.com/en-us/graph/use-the-api/>`__ +* `Using the Power BI REST APIs <https://learn.microsoft.com/en-us/rest/api/power-bi/>`__ diff --git a/docs/apache-airflow-providers-microsoft-azure/sensors/msgraph.rst b/docs/apache-airflow-providers-microsoft-azure/sensors/msgraph.rst new file mode 100644 index 0000000000..4ddad88f19 --- /dev/null +++ b/docs/apache-airflow-providers-microsoft-azure/sensors/msgraph.rst @@ -0,0 +1,42 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership.
The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Microsoft Graph API Sensors +============================= + +MSGraphSensor +------------- +Use the +:class:`~airflow.providers.microsoft.azure.sensors.msgraph.MSGraphSensor` to poll the Microsoft Graph API, for example the Power BI REST API. + + +Below is an example of using this sensor to poll the status of a PowerBI workspace. + +.. exampleinclude:: /../../tests/system/providers/microsoft/azure/example_powerbi.py + :language: python + :dedent: 0 + :start-after: [START howto_sensor_powerbi_scan_status] + :end-before: [END howto_sensor_powerbi_scan_status] + + +Reference +--------- + +For further information, look at: + +* `Using the Power BI REST APIs <https://learn.microsoft.com/en-us/rest/api/power-bi/>`__ diff --git a/docs/integration-logos/azure/Microsoft-Graph-API.png b/docs/integration-logos/azure/Microsoft-Graph-API.png new file mode 100644 index 0000000000..0724a1e09b Binary files /dev/null and b/docs/integration-logos/azure/Microsoft-Graph-API.png differ diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 48b4189e1a..dcd8641d80 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1444,6 +1444,7 @@ setted sftp SFTPClient sharded +Sharepoint shellcheck shellcmd shm diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 9315766f81..841f376467 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -680,7
+680,7 @@ "deps": [ "adal>=1.2.7", "adlfs>=2023.10.0", - "apache-airflow>=2.6.0", + "apache-airflow>=2.7.0", "azure-batch>=8.0.0", "azure-cosmos>=4.6.0", "azure-datalake-store>=0.0.45", @@ -699,7 +699,8 @@ "azure-storage-file-datalake>=12.9.1", "azure-storage-file-share", "azure-synapse-artifacts>=0.17.0", - "azure-synapse-spark" + "azure-synapse-spark", + "msgraph-core>=1.0.0" ], "devel-deps": [ "pywinrm" diff --git a/tests/providers/microsoft/azure/base.py b/tests/providers/microsoft/azure/base.py new file mode 100644 index 0000000000..4cda62858e --- /dev/null +++ b/tests/providers/microsoft/azure/base.py @@ -0,0 +1,121 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import asyncio +from contextlib import contextmanager +from copy import deepcopy +from datetime import datetime +from typing import TYPE_CHECKING, Any, Iterable +from unittest.mock import patch + +from kiota_http.httpx_request_adapter import HttpxRequestAdapter + +from airflow.exceptions import TaskDeferred +from airflow.models import Operator, TaskInstance +from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook +from airflow.utils.session import NEW_SESSION +from airflow.utils.xcom import XCOM_RETURN_KEY +from tests.providers.microsoft.conftest import get_airflow_connection, mock_context + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + + from airflow.triggers.base import BaseTrigger, TriggerEvent + + +class MockedTaskInstance(TaskInstance): + values = {} + + def xcom_pull( + self, + task_ids: Iterable[str] | str | None = None, + dag_id: str | None = None, + key: str = XCOM_RETURN_KEY, + include_prior_dates: bool = False, + session: Session = NEW_SESSION, + *, + map_indexes: Iterable[int] | int | None = None, + default: Any | None = None, + ) -> Any: + self.task_id = task_ids + self.dag_id = dag_id + return self.values.get(f"{task_ids}_{dag_id}_{key}") + + def xcom_push( + self, + key: str, + value: Any, + execution_date: datetime | None = None, + session: Session = NEW_SESSION, + ) -> None: + self.values[f"{self.task_id}_{self.dag_id}_{key}"] = value + + +class Base: + def teardown_method(self, method): + KiotaRequestAdapterHook.cached_request_adapters.clear() + MockedTaskInstance.values.clear() + + @contextmanager + def patch_hook_and_request_adapter(self, response): + with patch( + "airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection + ), patch.object(HttpxRequestAdapter, "get_http_response_message") as mock_get_http_response: + if isinstance(response, Exception): + mock_get_http_response.side_effect = response + else: + mock_get_http_response.return_value 
= response + yield + + @staticmethod + async def _run_trigger(trigger: BaseTrigger) -> list[TriggerEvent]: + events = [] + async for event in trigger.run(): + events.append(event) + return events + + def run_trigger(self, trigger: BaseTrigger) -> list[TriggerEvent]: + return asyncio.run(self._run_trigger(trigger)) + + def execute_operator(self, operator: Operator) -> tuple[Any, Any]: + context = mock_context(task=operator) + return asyncio.run(self.deferrable_operator(context, operator)) + + async def deferrable_operator(self, context, operator): + result = None + triggered_events = [] + try: + result = operator.execute(context=context) + except TaskDeferred as deferred: + task = deferred + + while task: + events = await self._run_trigger(task.trigger) + + if not events: + break + + triggered_events.extend(deepcopy(events)) + + try: + method = getattr(operator, task.method_name) + result = method(context=context, event=next(iter(events)).payload) + task = None + except TaskDeferred as exception: + task = exception + return result, triggered_events diff --git a/tests/providers/microsoft/azure/hooks/test_msgraph.py b/tests/providers/microsoft/azure/hooks/test_msgraph.py new file mode 100644 index 0000000000..1c1046e1fa --- /dev/null +++ b/tests/providers/microsoft/azure/hooks/test_msgraph.py @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import patch + +from kiota_http.httpx_request_adapter import HttpxRequestAdapter +from msgraph_core import APIVersion, NationalClouds + +from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook +from tests.providers.microsoft.conftest import get_airflow_connection, mock_connection + + +class TestKiotaRequestAdapterHook: + def test_get_conn(self): + with patch( + "airflow.hooks.base.BaseHook.get_connection", + side_effect=get_airflow_connection, + ): + hook = KiotaRequestAdapterHook(conn_id="msgraph_api") + actual = hook.get_conn() + + assert isinstance(actual, HttpxRequestAdapter) + assert actual.base_url == "https://graph.microsoft.com/v1.0" + + def test_api_version(self): + with patch( + "airflow.hooks.base.BaseHook.get_connection", + side_effect=get_airflow_connection, + ): + hook = KiotaRequestAdapterHook(conn_id="msgraph_api") + + assert hook.api_version == APIVersion.v1 + + def test_get_api_version_when_empty_config_dict(self): + with patch( + "airflow.hooks.base.BaseHook.get_connection", + side_effect=get_airflow_connection, + ): + hook = KiotaRequestAdapterHook(conn_id="msgraph_api") + actual = hook.get_api_version({}) + + assert actual == APIVersion.v1 + + def test_get_api_version_when_api_version_in_config_dict(self): + with patch( + "airflow.hooks.base.BaseHook.get_connection", + side_effect=get_airflow_connection, + ): + hook = KiotaRequestAdapterHook(conn_id="msgraph_api") + actual = hook.get_api_version({"api_version": "beta"}) + + assert actual == 
APIVersion.beta + + def test_get_host_when_connection_has_scheme_and_host(self): + connection = mock_connection(schema="https", host="graph.microsoft.de") + actual = KiotaRequestAdapterHook.get_host(connection) + + assert actual == NationalClouds.Germany.value + + def test_get_host_when_connection_has_no_scheme_or_host(self): + connection = mock_connection() + actual = KiotaRequestAdapterHook.get_host(connection) + + assert actual == NationalClouds.Global.value diff --git a/tests/providers/microsoft/azure/operators/test_msgraph.py b/tests/providers/microsoft/azure/operators/test_msgraph.py new file mode 100644 index 0000000000..b7520d7315 --- /dev/null +++ b/tests/providers/microsoft/azure/operators/test_msgraph.py @@ -0,0 +1,129 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import json +import locale +from base64 import b64encode + +import pytest + +from airflow.exceptions import AirflowException +from airflow.providers.microsoft.azure.operators.msgraph import MSGraphAsyncOperator +from airflow.triggers.base import TriggerEvent +from tests.providers.microsoft.azure.base import Base +from tests.providers.microsoft.conftest import load_file, load_json, mock_json_response, mock_response + + +class TestMSGraphAsyncOperator(Base): + @pytest.mark.db_test + def test_execute(self): + users = load_json("resources", "users.json") + next_users = load_json("resources", "next_users.json") + response = mock_json_response(200, users, next_users) + + with self.patch_hook_and_request_adapter(response): + operator = MSGraphAsyncOperator( + task_id="users_delta", + conn_id="msgraph_api", + url="users", + result_processor=lambda context, result: result.get("value"), + ) + + results, events = self.execute_operator(operator) + + assert len(results) == 30 + assert results == users.get("value") + next_users.get("value") + assert len(events) == 2 + assert isinstance(events[0], TriggerEvent) + assert events[0].payload["status"] == "success" + assert events[0].payload["type"] == "builtins.dict" + assert events[0].payload["response"] == json.dumps(users) + assert isinstance(events[1], TriggerEvent) + assert events[1].payload["status"] == "success" + assert events[1].payload["type"] == "builtins.dict" + assert events[1].payload["response"] == json.dumps(next_users) + + @pytest.mark.db_test + def test_execute_when_do_xcom_push_is_false(self): + users = load_json("resources", "users.json") + users.pop("@odata.nextLink") + response = mock_json_response(200, users) + + with self.patch_hook_and_request_adapter(response): + operator = MSGraphAsyncOperator( + task_id="users_delta", + conn_id="msgraph_api", + url="users/delta", + do_xcom_push=False, + ) + + results, events = self.execute_operator(operator) + + assert 
isinstance(results, dict) + assert len(events) == 1 + assert isinstance(events[0], TriggerEvent) + assert events[0].payload["status"] == "success" + assert events[0].payload["type"] == "builtins.dict" + assert events[0].payload["response"] == json.dumps(users) + + @pytest.mark.db_test + def test_execute_when_an_exception_occurs(self): + with self.patch_hook_and_request_adapter(AirflowException()): + operator = MSGraphAsyncOperator( + task_id="users_delta", + conn_id="msgraph_api", + url="users/delta", + do_xcom_push=False, + ) + + with pytest.raises(AirflowException): + self.execute_operator(operator) + + @pytest.mark.db_test + def test_execute_when_response_is_bytes(self): + content = load_file("resources", "dummy.pdf", mode="rb", encoding=None) + base64_encoded_content = b64encode(content).decode(locale.getpreferredencoding()) + drive_id = "82f9d24d-6891-4790-8b6d-f1b2a1d0ca22" + response = mock_response(200, content) + + with self.patch_hook_and_request_adapter(response): + operator = MSGraphAsyncOperator( + task_id="drive_item_content", + conn_id="msgraph_api", + response_type="bytes", + url=f"/drives/{drive_id}/root/content", + ) + + results, events = self.execute_operator(operator) + + assert results == base64_encoded_content + assert len(events) == 1 + assert isinstance(events[0], TriggerEvent) + assert events[0].payload["status"] == "success" + assert events[0].payload["type"] == "builtins.bytes" + assert events[0].payload["response"] == base64_encoded_content + + def test_template_fields(self): + operator = MSGraphAsyncOperator( + task_id="drive_item_content", + conn_id="msgraph_api", + url="users/delta", + ) + + for template_field in MSGraphAsyncOperator.template_fields: + getattr(operator, template_field) diff --git a/airflow/providers/microsoft/azure/serialization/__init__.py b/tests/providers/microsoft/azure/resources/__init__.py similarity index 100% rename from airflow/providers/microsoft/azure/serialization/__init__.py rename to 
tests/providers/microsoft/azure/resources/__init__.py diff --git a/tests/providers/microsoft/azure/resources/dummy.pdf b/tests/providers/microsoft/azure/resources/dummy.pdf new file mode 100644 index 0000000000..774c2ea70c Binary files /dev/null and b/tests/providers/microsoft/azure/resources/dummy.pdf differ diff --git a/tests/providers/microsoft/azure/resources/next_users.json b/tests/providers/microsoft/azure/resources/next_users.json new file mode 100644 index 0000000000..3a88cf08b2 --- /dev/null +++ b/tests/providers/microsoft/azure/resources/next_users.json @@ -0,0 +1 @@ +{"@odata.context": "https://graph.microsoft.com/v1.0/$metadata#users(displayName,description,mailNickname)", "value": [{"@odata.type": "#microsoft.graph.user", "@odata.id": "https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/c52a9941-e5cb-49cb-9972-6f45cd7cd447/Microsoft.DirectoryServices.User", "displayName": "Leonardo DiCaprio", "mailNickname": "LeoD"}, {"@odata.type": "#microsoft.graph.user", "@odata.id": "https://graph.microsoft.com/v2/c34aa217-8e78- [...] 
diff --git a/tests/providers/microsoft/azure/resources/status.json b/tests/providers/microsoft/azure/resources/status.json new file mode 100644 index 0000000000..6bff9e29af --- /dev/null +++ b/tests/providers/microsoft/azure/resources/status.json @@ -0,0 +1 @@ +{"id": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef", "createdDateTime": "2024-04-10T15:05:17.357", "status": "Succeeded"} diff --git a/tests/providers/microsoft/azure/resources/users.json b/tests/providers/microsoft/azure/resources/users.json new file mode 100644 index 0000000000..617e6f8420 --- /dev/null +++ b/tests/providers/microsoft/azure/resources/users.json @@ -0,0 +1 @@ +{"@odata.nextLink": "https://graph.microsoft.com/v1.0/users/delta()?$skiptoken=qLMhmnoTon81CQ1VVWyx9MNESqEIMKNpEWZWAfnn5F7tBNFuSgWh_pXZOweu67nEThGR0yQewi_a3Ixe75S6PoB8pdllphCEev0fMe5Uc1lWMtn3byOS8_OPTzPGZIZ17x-dVyxaE_4I55YyLJ0cgBxg8wsBrkYgaNE9vy5Su2HeCKxJODDQk4zRgP8QGo0pZatReTpqisVbrW5Gl1H_Xgy4lhenv1SmoRcBQtWBa5iAh-MURoaTo7i0kQjFhH6SCrkjBkfkRFVy9dafOOt2Owbxfn5hKGfEnfmG0RBmgdUsZPgX-ap0mjjf7PjExoxMek4CDnb8Yv737oGkh9C_G0XTJGeGxPBbkD-w4SaQookde4yxOzceAw1MuamBy63uJdbXt1ul61tDvfPwrJVHq99FxGU1n [...] diff --git a/tests/providers/microsoft/azure/sensors/test_msgraph.py b/tests/providers/microsoft/azure/sensors/test_msgraph.py new file mode 100644 index 0000000000..50fd2474ab --- /dev/null +++ b/tests/providers/microsoft/azure/sensors/test_msgraph.py @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.providers.microsoft.azure.sensors.msgraph import MSGraphSensor +from tests.providers.microsoft.azure.base import Base +from tests.providers.microsoft.conftest import load_json, mock_context, mock_json_response + + +class TestMSGraphSensor(Base): + def test_execute(self): + status = load_json("resources", "status.json") + response = mock_json_response(200, status) + + with self.patch_hook_and_request_adapter(response): + sensor = MSGraphSensor( + task_id="check_workspaces_status", + conn_id="powerbi", + url="myorg/admin/workspaces/scanStatus/{scanId}", + path_parameters={"scanId": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"}, + result_processor=lambda context, result: result["id"], + timeout=350.0, + ) + actual = sensor.execute(context=mock_context(task=sensor)) + + assert isinstance(actual, str) + assert actual == "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef" + + def test_template_fields(self): + sensor = MSGraphSensor( + task_id="check_workspaces_status", + conn_id="powerbi", + url="myorg/admin/workspaces/scanStatus/{scanId}", + ) + + for template_field in MSGraphSensor.template_fields: + getattr(sensor, template_field) diff --git a/tests/providers/microsoft/azure/triggers/test_msgraph.py b/tests/providers/microsoft/azure/triggers/test_msgraph.py new file mode 100644 index 0000000000..900d0875cd --- /dev/null +++ b/tests/providers/microsoft/azure/triggers/test_msgraph.py @@ -0,0 +1,192 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +import json +import locale +from base64 import b64decode, b64encode +from datetime import datetime +from unittest.mock import patch +from uuid import uuid4 + +import pendulum + +from airflow.exceptions import AirflowException +from airflow.providers.microsoft.azure.triggers.msgraph import ( + CallableResponseHandler, + MSGraphTrigger, + ResponseSerializer, +) +from airflow.triggers.base import TriggerEvent +from tests.providers.microsoft.azure.base import Base +from tests.providers.microsoft.conftest import ( + get_airflow_connection, + load_file, + load_json, + mock_json_response, + mock_response, +) + + +class TestMSGraphTrigger(Base): + def test_run_when_valid_response(self): + users = load_json("resources", "users.json") + response = mock_json_response(200, users) + + with self.patch_hook_and_request_adapter(response): + trigger = MSGraphTrigger("users/delta", conn_id="msgraph_api") + actual = self.run_trigger(trigger) + + assert len(actual) == 1 + assert isinstance(actual[0], TriggerEvent) + assert actual[0].payload["status"] == "success" + assert actual[0].payload["type"] == "builtins.dict" + assert actual[0].payload["response"] == json.dumps(users) + + def test_run_when_response_is_none(self): + response = 
mock_json_response(200) + + with self.patch_hook_and_request_adapter(response): + trigger = MSGraphTrigger("users/delta", conn_id="msgraph_api") + actual = self.run_trigger(trigger) + + assert len(actual) == 1 + assert isinstance(actual[0], TriggerEvent) + assert actual[0].payload["status"] == "success" + assert actual[0].payload["type"] is None + assert actual[0].payload["response"] is None + + def test_run_when_response_cannot_be_converted_to_json(self): + with self.patch_hook_and_request_adapter(AirflowException()): + trigger = MSGraphTrigger("users/delta", conn_id="msgraph_api") + actual = next(iter(self.run_trigger(trigger))) + + assert isinstance(actual, TriggerEvent) + assert actual.payload["status"] == "failure" + assert actual.payload["message"] == "" + + def test_run_when_response_is_bytes(self): + content = load_file("resources", "dummy.pdf", mode="rb", encoding=None) + base64_encoded_content = b64encode(content).decode(locale.getpreferredencoding()) + response = mock_response(200, content) + + with self.patch_hook_and_request_adapter(response): + url = ( + "https://graph.microsoft.com/v1.0/me/drive/items/1b30fecf-4330-4899-b249-104c2afaf9ed/content" + ) + trigger = MSGraphTrigger(url, response_type="bytes", conn_id="msgraph_api") + actual = next(iter(self.run_trigger(trigger))) + + assert isinstance(actual, TriggerEvent) + assert actual.payload["status"] == "success" + assert actual.payload["type"] == "builtins.bytes" + assert isinstance(actual.payload["response"], str) + assert actual.payload["response"] == base64_encoded_content + + def test_serialize(self): + with patch( + "airflow.hooks.base.BaseHook.get_connection", + side_effect=get_airflow_connection, + ): + url = "https://graph.microsoft.com/v1.0/me/drive/items" + trigger = MSGraphTrigger(url, response_type="bytes", conn_id="msgraph_api") + + actual = trigger.serialize() + + assert isinstance(actual, tuple) + assert actual[0] == "airflow.providers.microsoft.azure.triggers.msgraph.MSGraphTrigger" 
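The trigger tests above all go through a `run_trigger` helper on the `Base` test class, which presumably drains the trigger's async `run()` generator into a list of events that the assertions can inspect. A hedged, stdlib-only sketch of that collection pattern — `TriggerEvent` here is a minimal stand-in for `airflow.triggers.base.TriggerEvent`, and the helper is an illustration, not the actual `Base.run_trigger` implementation:

```python
import asyncio


class TriggerEvent:
    """Minimal stand-in for airflow.triggers.base.TriggerEvent."""

    def __init__(self, payload):
        self.payload = payload


async def trigger_run():
    # A deferrable trigger's run() is an async generator yielding events
    yield TriggerEvent({"status": "success", "type": "builtins.dict"})


def run_trigger(generator):
    # Drain the async generator into a list so tests can assert on the events
    async def collect():
        return [event async for event in generator]

    return asyncio.run(collect())


events = run_trigger(trigger_run())
```

This is why the tests can write `actual = self.run_trigger(trigger)` and then index into `actual[0].payload` synchronously, even though the trigger itself is asynchronous.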
+ assert actual[1] == { + "url": "https://graph.microsoft.com/v1.0/me/drive/items", + "path_parameters": None, + "url_template": None, + "method": "GET", + "query_parameters": None, + "headers": None, + "data": None, + "response_type": "bytes", + "conn_id": "msgraph_api", + "timeout": None, + "proxies": None, + "api_version": "v1.0", + "serializer": "airflow.providers.microsoft.azure.triggers.msgraph.ResponseSerializer", + } + + def test_template_fields(self): + trigger = MSGraphTrigger("users/delta", response_type="bytes", conn_id="msgraph_api") + + for template_field in MSGraphTrigger.template_fields: + getattr(trigger, template_field) + + +class TestResponseHandler: + def test_handle_response_async(self): + users = load_json("resources", "users.json") + response = mock_json_response(200, users) + + actual = asyncio.run( + CallableResponseHandler(lambda response, error_map: response.json()).handle_response_async( + response, None + ) + ) + + assert isinstance(actual, dict) + assert actual == users + + +class TestResponseSerializer: + def test_serialize_when_bytes_then_base64_encoded(self): + response = load_file("resources", "dummy.pdf", mode="rb", encoding=None) + content = b64encode(response).decode(locale.getpreferredencoding()) + + actual = ResponseSerializer().serialize(response) + + assert isinstance(actual, str) + assert actual == content + + def test_serialize_when_dict_with_uuid_datatime_and_pendulum_then_json(self): + id = uuid4() + response = { + "id": id, + "creationDate": datetime(2024, 2, 5), + "modificationTime": pendulum.datetime(2024, 2, 5), + } + + actual = ResponseSerializer().serialize(response) + + assert isinstance(actual, str) + assert ( + actual + == f'{{"id": "{id}", "creationDate": "2024-02-05T00:00:00", "modificationTime": "2024-02-05T00:00:00+00:00"}}' + ) + + def test_deserialize_when_json(self): + response = load_file("resources", "users.json") + + actual = ResponseSerializer().deserialize(response) + + assert isinstance(actual, 
dict) + assert actual == load_json("resources", "users.json") + + def test_deserialize_when_base64_encoded_string(self): + content = load_file("resources", "dummy.pdf", mode="rb", encoding=None) + response = b64encode(content).decode(locale.getpreferredencoding()) + + actual = ResponseSerializer().deserialize(response) + + assert actual == response + assert b64decode(actual) == content diff --git a/tests/providers/microsoft/conftest.py b/tests/providers/microsoft/conftest.py index bcf5aa65fe..78d8748a89 100644 --- a/tests/providers/microsoft/conftest.py +++ b/tests/providers/microsoft/conftest.py @@ -17,14 +17,22 @@ from __future__ import annotations +import json import random import string -from typing import TypeVar +from os.path import dirname, join +from typing import TYPE_CHECKING, Any, Iterable, TypeVar +from unittest.mock import MagicMock import pytest +from httpx import Response +from msgraph_core import APIVersion from airflow.models import Connection +if TYPE_CHECKING: + from sqlalchemy.orm import Session + T = TypeVar("T", dict, str, Connection) @@ -68,3 +76,100 @@ def create_mock_connections(create_mock_connection): def mocked_connection(request, create_mock_connection): """Helper indirect fixture for create test connection.""" return create_mock_connection(request.param) + + +def mock_connection(schema: str | None = None, host: str | None = None) -> Connection: + connection = MagicMock(spec=Connection) + connection.schema = schema + connection.host = host + return connection + + +def mock_json_response(status_code, *contents) -> Response: + response = MagicMock(spec=Response) + response.status_code = status_code + if contents: + contents = list(contents) + response.json.side_effect = lambda: contents.pop(0) + else: + response.json.return_value = None + return response + + +def mock_response(status_code, content: Any = None) -> Response: + response = MagicMock(spec=Response) + response.status_code = status_code + response.content = content + return 
response + + +def mock_context(task): + from datetime import datetime + + from airflow.models import TaskInstance + from airflow.utils.session import NEW_SESSION + from airflow.utils.state import TaskInstanceState + from airflow.utils.xcom import XCOM_RETURN_KEY + + class MockedTaskInstance(TaskInstance): + def __init__(self): + super().__init__(task=task, run_id="run_id", state=TaskInstanceState.RUNNING) + self.values = {} + + def xcom_pull( + self, + task_ids: Iterable[str] | str | None = None, + dag_id: str | None = None, + key: str = XCOM_RETURN_KEY, + include_prior_dates: bool = False, + session: Session = NEW_SESSION, + *, + map_indexes: Iterable[int] | int | None = None, + default: Any | None = None, + ) -> Any: + self.task_id = task_ids + self.dag_id = dag_id + return self.values.get(f"{task_ids}_{dag_id}_{key}") + + def xcom_push( + self, + key: str, + value: Any, + execution_date: datetime | None = None, + session: Session = NEW_SESSION, + ) -> None: + self.values[f"{self.task_id}_{self.dag_id}_{key}"] = value + + return {"ti": MockedTaskInstance()} + + +def load_json(*locations: Iterable[str]): + with open(join(dirname(__file__), "azure", join(*locations)), encoding="utf-8") as file: + return json.load(file) + + +def load_file(*locations: Iterable[str], mode="r", encoding="utf-8"): + with open(join(dirname(__file__), "azure", join(*locations)), mode=mode, encoding=encoding) as file: + return file.read() + + +def get_airflow_connection( + conn_id: str, + login: str = "client_id", + password: str = "client_secret", + tenant_id: str = "tenant-id", + proxies: (dict, None) = None, + api_version: APIVersion = APIVersion.v1, +): + from airflow.models import Connection + + return Connection( + schema="https", + conn_id=conn_id, + conn_type="http", + host="graph.microsoft.com", + port="80", + login=login, + password=password, + extra={"tenant_id": tenant_id, "api_version": api_version.value, "proxies": proxies or {}}, + ) diff --git 
a/tests/system/providers/microsoft/azure/example_msgraph.py b/tests/system/providers/microsoft/azure/example_msgraph.py new file mode 100644 index 0000000000..5ff7ba6f88 --- /dev/null +++ b/tests/system/providers/microsoft/azure/example_msgraph.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
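The new `mock_json_response` helper in `conftest.py` above simulates Microsoft Graph pagination by giving the mock's `json` method a `side_effect` that pops one payload per call — the first call returns the first page, the next call the page behind `@odata.nextLink`, and so on. A minimal stdlib sketch of that trick, using a hypothetical `Response` stand-in in place of `httpx.Response`:

```python
from unittest.mock import MagicMock


class Response:
    # Hypothetical stand-in for httpx.Response, used only as a mock spec
    status_code = 0

    def json(self):
        ...


def mock_json_response(status_code, *contents):
    response = MagicMock(spec=Response)
    response.status_code = status_code
    if contents:
        pages = list(contents)
        # Each json() call consumes the next page, emulating nextLink paging
        response.json.side_effect = lambda: pages.pop(0)
    else:
        response.json.return_value = None
    return response


resp = mock_json_response(200, {"value": [1]}, {"value": [2]})
first, second = resp.json(), resp.json()
```

Passing the mock a `spec` keeps the test honest: accessing an attribute that the real response type does not have raises `AttributeError` instead of silently returning another mock.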
+from __future__ import annotations + +from datetime import datetime + +from airflow import models +from airflow.providers.microsoft.azure.operators.msgraph import MSGraphAsyncOperator + +DAG_ID = "example_sharepoint_site" + +with models.DAG( + DAG_ID, + start_date=datetime(2021, 1, 1), + schedule=None, + tags=["example"], +) as dag: + # [START howto_operator_graph_site] + site_task = MSGraphAsyncOperator( + task_id="news_site", + conn_id="msgraph_api", + url="sites/850v1v.sharepoint.com:/sites/news", + result_processor=lambda context, response: response["id"].split(",")[1], # only keep site_id + ) + # [END howto_operator_graph_site] + + # [START howto_operator_graph_site_pages] + site_pages_task = MSGraphAsyncOperator( + task_id="news_pages", + conn_id="msgraph_api", + api_version="beta", + url=("sites/%s/pages" % "{{ ti.xcom_pull(task_ids='news_site') }}"), # noqa: UP031 + ) + # [END howto_operator_graph_site_pages] + + site_task >> site_pages_task + + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) diff --git a/tests/system/providers/microsoft/azure/example_powerbi.py b/tests/system/providers/microsoft/azure/example_powerbi.py new file mode 100644 index 0000000000..cbee9a62af --- /dev/null +++ b/tests/system/providers/microsoft/azure/example_powerbi.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import datetime + +from airflow import models +from airflow.providers.microsoft.azure.operators.msgraph import MSGraphAsyncOperator +from airflow.providers.microsoft.azure.sensors.msgraph import MSGraphSensor + +DAG_ID = "example_powerbi" + +with models.DAG( + DAG_ID, + start_date=datetime(2021, 1, 1), + schedule=None, + tags=["example"], +) as dag: + # [START howto_operator_powerbi_workspaces] + workspaces_task = MSGraphAsyncOperator( + task_id="workspaces", + conn_id="powerbi", + url="myorg/admin/workspaces/modified", + result_processor=lambda context, response: list(map(lambda workspace: workspace["id"], response)), + ) + # [END howto_operator_powerbi_workspaces] + + # [START howto_operator_powerbi_workspaces_info] + workspaces_info_task = MSGraphAsyncOperator( + task_id="get_workspace_info", + conn_id="powerbi", + url="myorg/admin/workspaces/getInfo", + method="POST", + query_parameters={ + "lineage": True, + "datasourceDetails": True, + "datasetSchema": True, + "datasetExpressions": True, + "getArtifactUsers": True, + }, + data={"workspaces": workspaces_task.output}, + result_processor=lambda context, response: {"scanId": response["id"]}, + ) + # [END howto_operator_powerbi_workspaces_info] + + # [START howto_sensor_powerbi_scan_status] + check_workspace_status_task = MSGraphSensor.partial( + task_id="check_workspaces_status", + 
conn_id="powerbi_api", + url="myorg/admin/workspaces/scanStatus/{scanId}", + timeout=350.0, + ).expand(path_parameters=workspaces_info_task.output) + # [END howto_sensor_powerbi_scan_status] + + workspaces_task >> workspaces_info_task >> check_workspace_status_task + + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag)
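The PowerBI example above fans the scan-status sensor out with `MSGraphSensor.partial(...).expand(path_parameters=...)`: `partial` pins the arguments shared by every mapped task, and `expand` creates one task instance per element of the upstream XCom. A stdlib-only sketch of that fan-out semantics — purely illustrative, as Airflow's real dynamic task mapping happens at scheduling time, not as a plain loop:

```python
from functools import partial


def check_scan_status(conn_id: str, url: str, path_parameters: dict) -> str:
    # Stand-in for one mapped sensor instance: resolve the templated URL
    return url.format(**path_parameters)


# partial() pins the shared arguments, like MSGraphSensor.partial(...)
sensor = partial(
    check_scan_status,
    "powerbi",
    "myorg/admin/workspaces/scanStatus/{scanId}",
)

# expand() fans out one instance per element of the upstream result
scan_ids = [{"scanId": "scan-1"}, {"scanId": "scan-2"}]
mapped_urls = [sensor(path_parameters=p) for p in scan_ids]
```

In the DAG itself, `workspaces_info_task.output` is the XCom list that drives the fan-out, so the number of mapped sensor instances is only known at runtime.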