This is an automated email from the ASF dual-hosted git repository. eladkal pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new ff1e3a6f617 Added support for certificate authentication with MSGraphAsyncOperator (#45935) ff1e3a6f617 is described below commit ff1e3a6f617bb14b6007daff6ed8296d4a78c20d Author: David Blain <i...@dabla.be> AuthorDate: Sun Jan 26 07:42:02 2025 +0100 Added support for certificate authentication with MSGraphAsyncOperator (#45935) * refactor: Added support for certificate authentication in KiotaRequestAdapterHook * refactor: Fixed label for allowed hosts in MS Graph connection form --------- Co-authored-by: David Blain <david.bl...@infrabel.be> --- .../connections/images/msgraph.png | Bin 0 -> 71754 bytes .../connections/msgraph.rst | 135 +++++++++++++++++++++ .../providers/microsoft/azure/hooks/msgraph.py | 83 ++++++++++--- .../providers/microsoft/azure/operators/msgraph.py | 4 + .../providers/microsoft/azure/sensors/msgraph.py | 4 + .../providers/microsoft/azure/triggers/msgraph.py | 4 + .../tests/microsoft/azure/hooks/test_msgraph.py | 31 +++++ .../tests/microsoft/azure/triggers/test_msgraph.py | 9 +- 8 files changed, 247 insertions(+), 23 deletions(-) diff --git a/docs/apache-airflow-providers-microsoft-azure/connections/images/msgraph.png b/docs/apache-airflow-providers-microsoft-azure/connections/images/msgraph.png new file mode 100644 index 00000000000..36c36fa7831 Binary files /dev/null and b/docs/apache-airflow-providers-microsoft-azure/connections/images/msgraph.png differ diff --git a/docs/apache-airflow-providers-microsoft-azure/connections/msgraph.rst b/docs/apache-airflow-providers-microsoft-azure/connections/msgraph.rst new file mode 100644 index 00000000000..8c817416162 --- /dev/null +++ b/docs/apache-airflow-providers-microsoft-azure/connections/msgraph.rst @@ -0,0 +1,135 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. 
See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + + +.. _howto/connection:msgraph: + +Microsoft Graph API Connection +============================== + +The Microsoft Graph API connection type enables Microsoft Graph API Integrations. + +The :class:`~airflow.providers.microsoft.azure.hooks.msgraph.KiotaRequestAdapterHook` and :class:`~airflow.providers.microsoft.azure.operators.msgraph.MSGraphAsyncOperator` require a connection of type ``msgraph`` to authenticate with the Microsoft Graph API. + +Authenticating to Microsoft Graph API +------------------------------------- + +1. Use `token credentials + <https://docs.microsoft.com/en-us/azure/developer/python/azure-sdk-authenticate?tabs=cmd#authenticate-with-token-credentials>`_ + i.e. add specific credentials (client_id, client_secret, tenant_id) to the Airflow connection. + +Default Connection IDs +---------------------- + +All hooks and operators related to Microsoft Graph API use ``msgraph_default`` by default. + +Configuring the Connection +-------------------------- + +Client ID + Specify the ``client_id`` used for the initial connection. + This is needed for the *token credentials* authentication mechanism. + + +Client Secret + Specify the ``client_secret`` used for the initial connection. + This is needed for the *token credentials* authentication mechanism unless a certificate is used, in which case this field holds the certificate password (if the certificate is password-protected).
+ + +Tenant ID + Specify the ``tenant_id`` used for the initial connection. + This is needed for the *token credentials* authentication mechanism. + + +API Version + Specify the ``api_version`` used for the initial connection. + Default value is ``v1.0``. + + +Authority + The ``authority`` parameter defines the endpoint (or tenant) that MSAL uses to authenticate requests. + It determines which identity provider will handle authentication. + Default value is ``login.microsoftonline.com``. + + +Scopes + The ``scopes`` parameter specifies the permissions or access rights that your application is requesting for a connection. + These permissions define what resources or data your application can access on behalf of the user or application. Multiple scopes can be given as a comma-separated string. + Default value is ``https://graph.microsoft.com/.default``. + + +Certificate path + The ``certificate_path`` parameter specifies the path of the certificate file. + The ``certificate_path`` and ``certificate_data`` parameters are mutually exclusive and cannot be used together. + Default value is None. + + +Certificate data + The ``certificate_data`` parameter specifies the certificate contents as a string. + The ``certificate_path`` and ``certificate_data`` parameters are mutually exclusive and cannot be used together. + Default value is None. + + +Disable instance discovery + The ``disable_instance_discovery`` parameter determines whether MSAL should validate and discover Azure AD endpoints dynamically during runtime. + Default value is False (i.e. instance discovery remains enabled). + + +Allowed hosts + The ``allowed_hosts`` parameter is used to define a list of acceptable hosts that the authentication provider will trust when making requests. Multiple hosts can be given as a comma-separated string; when it is not set, the ``authority`` value is used. + This parameter is particularly useful for enhancing security and controlling which endpoints the authentication provider interacts with.
+ + +Proxies + The ``proxies`` parameter is used to define a dict for the ``http`` and ``https`` schemes; the ``no`` key can be used to define hosts that should bypass the proxy. + Default value is None. + + +Verify + The ``verify`` parameter specifies whether SSL certificates should be verified when making HTTPS requests. + By default, the ``verify`` parameter is set to True. This means that the `httpx <https://www.python-httpx.org>`_ library will verify the SSL certificate presented by the server to ensure: + + - The certificate is valid and trusted. + - The certificate matches the hostname of the server. + - The certificate has not expired or been revoked. + + Setting ``verify`` to False disables SSL certificate verification. This is typically used in development or testing environments when working with self-signed certificates or servers without valid certificates. + + +Trust environment + The ``trust_env`` parameter determines whether or not the library should use environment variables for configuration when making HTTP/HTTPS requests. + The connection form sets ``trust_env`` to True by default, in which case the `httpx <https://www.python-httpx.org>`_ library will automatically trust and use environment variables for proxy configuration, SSL settings, and authentication. If ``trust_env`` is absent from the connection extras, the hook falls back to False. + + +Base URL + The ``base_url`` parameter allows you to override the default base URL used to make its requests, namely ``https://graph.microsoft.com/``. + This can be useful if you want to use the MSGraphAsyncOperator to call other Microsoft REST APIs like SharePoint or Power BI. + Default value is None. + + +.. raw:: html + + <div align="center" style="padding-bottom:10px"> + <img src="images/msgraph.png" + alt="Microsoft Graph API connection form"> + </div> + + +..
spelling:word-list:: + + Entra diff --git a/providers/src/airflow/providers/microsoft/azure/hooks/msgraph.py b/providers/src/airflow/providers/microsoft/azure/hooks/msgraph.py index f01fa1c5858..1754d04b103 100644 --- a/providers/src/airflow/providers/microsoft/azure/hooks/msgraph.py +++ b/providers/src/airflow/providers/microsoft/azure/hooks/msgraph.py @@ -26,7 +26,7 @@ from typing import TYPE_CHECKING, Any from urllib.parse import quote, urljoin, urlparse import httpx -from azure.identity import ClientSecretCredential +from azure.identity import CertificateCredential, ClientSecretCredential from httpx import AsyncHTTPTransport, Timeout from kiota_abstractions.api_error import APIError from kiota_abstractions.method import Method @@ -47,6 +47,7 @@ from airflow.exceptions import AirflowBadRequest, AirflowException, AirflowNotFo from airflow.hooks.base import BaseHook if TYPE_CHECKING: + from azure.identity._internal.client_credential_base import ClientCredentialBase from kiota_abstractions.request_adapter import RequestAdapter from kiota_abstractions.request_information import QueryParams from kiota_abstractions.response_handler import NativeResponseType @@ -107,6 +108,7 @@ class KiotaRequestAdapterHook(BaseHook): """ DEFAULT_HEADERS = {"Accept": "application/json;q=1"} + DEFAULT_SCOPE = "https://graph.microsoft.com/.default" cached_request_adapters: dict[str, tuple[APIVersion, RequestAdapter]] = {} conn_type: str = "msgraph" conn_name_attr: str = "conn_id" @@ -119,7 +121,7 @@ class KiotaRequestAdapterHook(BaseHook): timeout: float | None = None, proxies: dict | None = None, host: str = NationalClouds.Global.value, - scopes: list[str] | None = None, + scopes: str | list[str] | None = None, api_version: APIVersion | str | None = None, ): super().__init__() @@ -127,7 +129,10 @@ class KiotaRequestAdapterHook(BaseHook): self.timeout = timeout self.proxies = proxies self.host = host - self.scopes = scopes or ["https://graph.microsoft.com/.default"] + if 
isinstance(scopes, str): + self.scopes = [scopes] + else: + self.scopes = scopes or [self.DEFAULT_SCOPE] self._api_version = self.resolve_api_version_from_value(api_version) @classmethod @@ -140,20 +145,21 @@ class KiotaRequestAdapterHook(BaseHook): return { "tenant_id": StringField(lazy_gettext("Tenant ID"), widget=BS3TextFieldWidget()), "api_version": StringField( - lazy_gettext("API Version"), widget=BS3TextFieldWidget(), default="v1.0" + lazy_gettext("API Version"), widget=BS3TextFieldWidget(), default=APIVersion.v1.value ), "authority": StringField(lazy_gettext("Authority"), widget=BS3TextFieldWidget()), + "certificate_path": StringField(lazy_gettext("Certificate path"), widget=BS3TextFieldWidget()), + "certificate_data": StringField(lazy_gettext("Certificate data"), widget=BS3TextFieldWidget()), "scopes": StringField( lazy_gettext("Scopes"), widget=BS3TextFieldWidget(), - default="https://graph.microsoft.com/.default", + default=cls.DEFAULT_SCOPE, ), "disable_instance_discovery": BooleanField( lazy_gettext("Disable instance discovery"), default=False ), - "allowed_hosts": StringField(lazy_gettext("Allowed"), widget=BS3TextFieldWidget()), + "allowed_hosts": StringField(lazy_gettext("Allowed hosts"), widget=BS3TextFieldWidget()), "proxies": StringField(lazy_gettext("Proxies"), widget=BS3TextAreaFieldWidget()), - "stream": BooleanField(lazy_gettext("Stream"), default=False), "verify": BooleanField(lazy_gettext("Verify"), default=True), "trust_env": BooleanField(lazy_gettext("Trust environment"), default=True), "base_url": StringField(lazy_gettext("Base URL"), widget=BS3TextFieldWidget()), @@ -241,18 +247,17 @@ class KiotaRequestAdapterHook(BaseHook): client_id = connection.login client_secret = connection.password config = connection.extra_dejson if connection.extra else {} - tenant_id = config.get("tenant_id") or config.get("tenantId") api_version = self.get_api_version(config) host = self.get_host(connection) base_url = config.get("base_url", urljoin(host, 
api_version)) authority = config.get("authority") proxies = self.proxies or config.get("proxies", {}) - msal_proxies = self.to_msal_proxies(authority=authority, proxies=proxies) httpx_proxies = self.to_httpx_proxies(proxies=proxies) scopes = config.get("scopes", self.scopes) + if isinstance(scopes, str): + scopes = scopes.split(",") verify = config.get("verify", True) trust_env = config.get("trust_env", False) - disable_instance_discovery = config.get("disable_instance_discovery", False) allowed_hosts = (config.get("allowed_hosts", authority) or "").split(",") self.log.info( @@ -262,7 +267,6 @@ class KiotaRequestAdapterHook(BaseHook): ) self.log.info("Host: %s", host) self.log.info("Base URL: %s", base_url) - self.log.info("Tenant id: %s", tenant_id) self.log.info("Client id: %s", client_id) self.log.info("Client secret: %s", client_secret) self.log.info("API version: %s", api_version) @@ -271,19 +275,16 @@ class KiotaRequestAdapterHook(BaseHook): self.log.info("Timeout: %s", self.timeout) self.log.info("Trust env: %s", trust_env) self.log.info("Authority: %s", authority) - self.log.info("Disable instance discovery: %s", disable_instance_discovery) self.log.info("Allowed hosts: %s", allowed_hosts) self.log.info("Proxies: %s", proxies) - self.log.info("MSAL Proxies: %s", msal_proxies) self.log.info("HTTPX Proxies: %s", httpx_proxies) - credentials = ClientSecretCredential( - tenant_id=tenant_id, # type: ignore - client_id=connection.login, - client_secret=connection.password, + credentials = self.get_credentials( + login=connection.login, + password=connection.password, + config=config, authority=authority, - proxies=msal_proxies, - disable_instance_discovery=disable_instance_discovery, - connection_verify=verify, + verify=verify, + proxies=proxies, ) http_client = GraphClientFactory.create_with_default_middleware( api_version=api_version, # type: ignore @@ -313,6 +314,48 @@ class KiotaRequestAdapterHook(BaseHook): self._api_version = api_version return 
request_adapter + def get_credentials( + self, + login: str | None, + password: str | None, + config, + authority: str | None, + verify: bool, + proxies: dict, + ) -> ClientCredentialBase: + tenant_id = config.get("tenant_id") or config.get("tenantId") + certificate_path = config.get("certificate_path") + certificate_data = config.get("certificate_data") + disable_instance_discovery = config.get("disable_instance_discovery", False) + msal_proxies = self.to_msal_proxies(authority=authority, proxies=proxies) + self.log.info("Tenant id: %s", tenant_id) + self.log.info("Certificate path: %s", certificate_path) + self.log.info("Certificate data: %s", certificate_data is not None) + self.log.info("Authority: %s", authority) + self.log.info("Disable instance discovery: %s", disable_instance_discovery) + self.log.info("MSAL Proxies: %s", msal_proxies) + if certificate_path or certificate_data: + return CertificateCredential( + tenant_id=tenant_id, # type: ignore + client_id=login, # type: ignore + password=password, + certificate_path=certificate_path, + certificate_data=certificate_data.encode() if certificate_data else None, + authority=authority, + proxies=msal_proxies, + disable_instance_discovery=disable_instance_discovery, + connection_verify=verify, + ) + return ClientSecretCredential( + tenant_id=tenant_id, # type: ignore + client_id=login, # type: ignore + client_secret=password, # type: ignore + authority=authority, + proxies=msal_proxies, + disable_instance_discovery=disable_instance_discovery, + connection_verify=verify, + ) + def test_connection(self): """Test HTTP Connection.""" try: diff --git a/providers/src/airflow/providers/microsoft/azure/operators/msgraph.py b/providers/src/airflow/providers/microsoft/azure/operators/msgraph.py index 9a3fc197d44..7aefd2971d4 100644 --- a/providers/src/airflow/providers/microsoft/azure/operators/msgraph.py +++ b/providers/src/airflow/providers/microsoft/azure/operators/msgraph.py @@ -71,6 +71,7 @@ class 
MSGraphAsyncOperator(BaseOperator): :param timeout: The HTTP timeout being used by the `KiotaRequestAdapter` (default is None). When no timeout is specified or set to None then there is no HTTP timeout on each request. :param proxies: A dict defining the HTTP proxies to be used (default is None). + :param scopes: The scopes to be used (default is ["https://graph.microsoft.com/.default"]). :param api_version: The API version of the Microsoft Graph API to be used (default is v1). You can pass an enum named APIVersion which has 2 possible members v1 and beta, or you can pass a string as `v1.0` or `beta`. @@ -110,6 +111,7 @@ class MSGraphAsyncOperator(BaseOperator): key: str = XCOM_RETURN_KEY, timeout: float | None = None, proxies: dict | None = None, + scopes: str | list[str] | None = None, api_version: APIVersion | str | None = None, pagination_function: Callable[[MSGraphAsyncOperator, dict, Context], tuple[str, dict]] | None = None, result_processor: Callable[[Context, Any], Any] = lambda context, result: result, @@ -130,6 +132,7 @@ class MSGraphAsyncOperator(BaseOperator): self.key = key self.timeout = timeout self.proxies = proxies + self.scopes = scopes self.api_version = api_version self.pagination_function = pagination_function or self.paginate self.result_processor = result_processor @@ -150,6 +153,7 @@ class MSGraphAsyncOperator(BaseOperator): conn_id=self.conn_id, timeout=self.timeout, proxies=self.proxies, + scopes=self.scopes, api_version=self.api_version, serializer=type(self.serializer), ), diff --git a/providers/src/airflow/providers/microsoft/azure/sensors/msgraph.py b/providers/src/airflow/providers/microsoft/azure/sensors/msgraph.py index 6b5622e2d7a..ecad1a34f16 100644 --- a/providers/src/airflow/providers/microsoft/azure/sensors/msgraph.py +++ b/providers/src/airflow/providers/microsoft/azure/sensors/msgraph.py @@ -47,6 +47,7 @@ class MSGraphSensor(BaseSensorOperator): :param method: The HTTP method being used to do the REST call (default is GET). 
:param conn_id: The HTTP Connection ID to run the operator against (templated). :param proxies: A dict defining the HTTP proxies to be used (default is None). + :param scopes: The scopes to be used (default is ["https://graph.microsoft.com/.default"]). :param api_version: The API version of the Microsoft Graph API to be used (default is v1). You can pass an enum named APIVersion which has 2 possible members v1 and beta, or you can pass a string as `v1.0` or `beta`. @@ -83,6 +84,7 @@ class MSGraphSensor(BaseSensorOperator): data: dict[str, Any] | str | BytesIO | None = None, conn_id: str = KiotaRequestAdapterHook.default_conn_name, proxies: dict | None = None, + scopes: str | list[str] | None = None, api_version: APIVersion | str | None = None, event_processor: Callable[[Context, Any], bool] = lambda context, e: e.get("status") == "Succeeded", result_processor: Callable[[Context, Any], Any] = lambda context, result: result, @@ -101,6 +103,7 @@ class MSGraphSensor(BaseSensorOperator): self.data = data self.conn_id = conn_id self.proxies = proxies + self.scopes = scopes self.api_version = api_version self.event_processor = event_processor self.result_processor = result_processor @@ -120,6 +123,7 @@ class MSGraphSensor(BaseSensorOperator): conn_id=self.conn_id, timeout=self.timeout, proxies=self.proxies, + scopes=self.scopes, api_version=self.api_version, serializer=type(self.serializer), ), diff --git a/providers/src/airflow/providers/microsoft/azure/triggers/msgraph.py b/providers/src/airflow/providers/microsoft/azure/triggers/msgraph.py index 076f2f493ea..4006ee6c3c0 100644 --- a/providers/src/airflow/providers/microsoft/azure/triggers/msgraph.py +++ b/providers/src/airflow/providers/microsoft/azure/triggers/msgraph.py @@ -90,6 +90,7 @@ class MSGraphTrigger(BaseTrigger): :param timeout: The HTTP timeout being used by the `KiotaRequestAdapter` (default is None). When no timeout is specified or set to None then there is no HTTP timeout on each request. 
:param proxies: A dict defining the HTTP proxies to be used (default is None). + :param scopes: The scopes to be used (default is ["https://graph.microsoft.com/.default"]). :param api_version: The API version of the Microsoft Graph API to be used (default is v1). You can pass an enum named APIVersion which has 2 possible members v1 and beta, or you can pass a string as `v1.0` or `beta`. @@ -121,6 +122,7 @@ class MSGraphTrigger(BaseTrigger): conn_id: str = KiotaRequestAdapterHook.default_conn_name, timeout: float | None = None, proxies: dict | None = None, + scopes: str | list[str] | None = None, api_version: APIVersion | str | None = None, serializer: type[ResponseSerializer] = ResponseSerializer, ): @@ -129,6 +131,7 @@ class MSGraphTrigger(BaseTrigger): conn_id=conn_id, timeout=timeout, proxies=proxies, + scopes=scopes, api_version=api_version, ) self.url = url @@ -157,6 +160,7 @@ class MSGraphTrigger(BaseTrigger): "conn_id": self.conn_id, "timeout": self.timeout, "proxies": self.proxies, + "scopes": self.hook.scopes, "api_version": self.api_version, "serializer": f"{self.serializer.__class__.__module__}.{self.serializer.__class__.__name__}", "url": self.url, diff --git a/providers/tests/microsoft/azure/hooks/test_msgraph.py b/providers/tests/microsoft/azure/hooks/test_msgraph.py index aff5d0226a1..3dbf8b9bf64 100644 --- a/providers/tests/microsoft/azure/hooks/test_msgraph.py +++ b/providers/tests/microsoft/azure/hooks/test_msgraph.py @@ -86,6 +86,37 @@ class TestKiotaRequestAdapterHook: assert isinstance(actual, HttpxRequestAdapter) assert actual.base_url == "https://api.fabric.microsoft.com/v1" + def test_scopes_when_default(self): + with patch( + "airflow.hooks.base.BaseHook.get_connection", + side_effect=get_airflow_connection, + ): + hook = KiotaRequestAdapterHook(conn_id="msgraph_api") + + assert hook.scopes == [KiotaRequestAdapterHook.DEFAULT_SCOPE] + + def test_scopes_when_passed_as_string(self): + with patch( + 
"airflow.hooks.base.BaseHook.get_connection", + side_effect=get_airflow_connection, + ): + hook = KiotaRequestAdapterHook( + conn_id="msgraph_api", scopes="https://microsoft.sharepoint.com/.default" + ) + + assert hook.scopes == ["https://microsoft.sharepoint.com/.default"] + + def test_scopes_when_passed_as_list(self): + with patch( + "airflow.hooks.base.BaseHook.get_connection", + side_effect=get_airflow_connection, + ): + hook = KiotaRequestAdapterHook( + conn_id="msgraph_api", scopes=["https://microsoft.sharepoint.com/.default"] + ) + + assert hook.scopes == ["https://microsoft.sharepoint.com/.default"] + def test_api_version(self): with patch( "airflow.hooks.base.BaseHook.get_connection", diff --git a/providers/tests/microsoft/azure/triggers/test_msgraph.py b/providers/tests/microsoft/azure/triggers/test_msgraph.py index 0784d8d8317..ce5a554fe1d 100644 --- a/providers/tests/microsoft/azure/triggers/test_msgraph.py +++ b/providers/tests/microsoft/azure/triggers/test_msgraph.py @@ -24,8 +24,10 @@ from unittest.mock import patch from uuid import uuid4 import pendulum +from msgraph_core import APIVersion from airflow.exceptions import AirflowException +from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook from airflow.providers.microsoft.azure.triggers.msgraph import ( MSGraphTrigger, ResponseSerializer, @@ -108,7 +110,7 @@ class TestMSGraphTrigger(Base): actual = trigger.serialize() assert isinstance(actual, tuple) - assert actual[0] == "airflow.providers.microsoft.azure.triggers.msgraph.MSGraphTrigger" + assert actual[0] == f"{MSGraphTrigger.__module__}.{MSGraphTrigger.__name__}" assert actual[1] == { "url": "https://graph.microsoft.com/v1.0/me/drive/items", "path_parameters": None, @@ -121,8 +123,9 @@ class TestMSGraphTrigger(Base): "conn_id": "msgraph_api", "timeout": None, "proxies": None, - "api_version": "v1.0", - "serializer": "airflow.providers.microsoft.azure.triggers.msgraph.ResponseSerializer", + "scopes": 
[KiotaRequestAdapterHook.DEFAULT_SCOPE], + "api_version": APIVersion.v1.value, + "serializer": f"{ResponseSerializer.__module__}.{ResponseSerializer.__name__}", } def test_template_fields(self):