This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 40fd35ca444 fix: `KiotaRequestAdapterHook` make sure proxy config
parameter is parsed correctly, even if it's a string or json (#46145)
40fd35ca444 is described below
commit 40fd35ca444740d6788e489ac7493b39937f8d23
Author: David Blain <[email protected]>
AuthorDate: Mon Feb 3 05:33:59 2025 +0100
fix: `KiotaRequestAdapterHook` make sure proxy config parameter is parsed
correctly, even if it's a string or json (#46145)
* refactor: Make sure proxy is configured correctly, even if it's a string
or json
---------
Co-authored-by: David Blain <[email protected]>
---
.../providers/microsoft/azure/hooks/msgraph.py | 48 +++++++++---
.../tests/microsoft/azure/hooks/test_msgraph.py | 87 +++++++++++++++++++++-
providers/tests/microsoft/conftest.py | 6 ++
3 files changed, 125 insertions(+), 16 deletions(-)
diff --git a/providers/src/airflow/providers/microsoft/azure/hooks/msgraph.py
b/providers/src/airflow/providers/microsoft/azure/hooks/msgraph.py
index 1754d04b103..46d651670e4 100644
--- a/providers/src/airflow/providers/microsoft/azure/hooks/msgraph.py
+++ b/providers/src/airflow/providers/microsoft/azure/hooks/msgraph.py
@@ -18,6 +18,7 @@
from __future__ import annotations
import json
+from ast import literal_eval
from contextlib import suppress
from http import HTTPStatus
from io import BytesIO
@@ -43,7 +44,12 @@ from kiota_serialization_text.text_parse_node_factory import
TextParseNodeFactor
from msgraph_core import APIVersion, GraphClientFactory
from msgraph_core._enums import NationalClouds
-from airflow.exceptions import AirflowBadRequest, AirflowException,
AirflowNotFoundException
+from airflow.exceptions import (
+ AirflowBadRequest,
+ AirflowConfigException,
+ AirflowException,
+ AirflowNotFoundException,
+)
from airflow.hooks.base import BaseHook
if TYPE_CHECKING:
@@ -212,19 +218,20 @@ class KiotaRequestAdapterHook(BaseHook):
@classmethod
def to_httpx_proxies(cls, proxies: dict) -> dict:
- proxies = proxies.copy()
- if proxies.get("http"):
- proxies["http://"] = AsyncHTTPTransport(proxy=proxies.pop("http"))
- if proxies.get("https"):
- proxies["https://"] =
AsyncHTTPTransport(proxy=proxies.pop("https"))
- if proxies.get("no"):
- for url in proxies.pop("no", "").split(","):
- proxies[cls.format_no_proxy_url(url.strip())] = None
+ if proxies:
+ proxies = proxies.copy()
+ if proxies.get("http"):
+ proxies["http://"] =
AsyncHTTPTransport(proxy=proxies.pop("http"))
+ if proxies.get("https"):
+ proxies["https://"] =
AsyncHTTPTransport(proxy=proxies.pop("https"))
+ if proxies.get("no"):
+ for url in proxies.pop("no", "").split(","):
+ proxies[cls.format_no_proxy_url(url.strip())] = None
return proxies
- def to_msal_proxies(self, authority: str | None, proxies: dict):
+ def to_msal_proxies(self, authority: str | None, proxies: dict) -> dict |
None:
self.log.debug("authority: %s", authority)
- if authority:
+ if authority and proxies:
no_proxies = proxies.get("no")
self.log.debug("no_proxies: %s", no_proxies)
if no_proxies:
@@ -251,7 +258,7 @@ class KiotaRequestAdapterHook(BaseHook):
host = self.get_host(connection)
base_url = config.get("base_url", urljoin(host, api_version))
authority = config.get("authority")
- proxies = self.proxies or config.get("proxies", {})
+ proxies = self.get_proxies(config)
httpx_proxies = self.to_httpx_proxies(proxies=proxies)
scopes = config.get("scopes", self.scopes)
if isinstance(scopes, str):
@@ -314,6 +321,23 @@ class KiotaRequestAdapterHook(BaseHook):
self._api_version = api_version
return request_adapter
+ def get_proxies(self, config: dict) -> dict:
+ proxies = self.proxies or config.get("proxies", {})
+ if isinstance(proxies, str):
+ # TODO: Once provider depends on Airflow 2.10 or higher code below
won't be needed anymore as
+ # we could then use the get_extra_dejson method on the
connection which deserializes
+ # nested json. Make sure to use
connection.get_extra_dejson(nested=True) instead of
+ # connection.extra_dejson.
+ with suppress(JSONDecodeError):
+ proxies = json.loads(proxies)
+ with suppress(Exception):
+ proxies = literal_eval(proxies)
+ if not isinstance(proxies, dict):
+ raise AirflowConfigException(
+ f"Proxies must be of type dict, got {type(proxies).__name__}
instead!"
+ )
+ return proxies
+
def get_credentials(
self,
login: str | None,
diff --git a/providers/tests/microsoft/azure/hooks/test_msgraph.py
b/providers/tests/microsoft/azure/hooks/test_msgraph.py
index 3dbf8b9bf64..c3621a4ec6f 100644
--- a/providers/tests/microsoft/azure/hooks/test_msgraph.py
+++ b/providers/tests/microsoft/azure/hooks/test_msgraph.py
@@ -17,19 +17,27 @@
from __future__ import annotations
import asyncio
+import inspect
from json import JSONDecodeError
from typing import TYPE_CHECKING
from unittest.mock import Mock, patch
import pytest
from httpx import Response
+from httpx._utils import URLPattern
from kiota_http.httpx_request_adapter import HttpxRequestAdapter
from kiota_serialization_json.json_parse_node import JsonParseNode
from kiota_serialization_text.text_parse_node import TextParseNode
from msgraph_core import APIVersion, NationalClouds
from opentelemetry.trace import Span
-from airflow.exceptions import AirflowBadRequest, AirflowException,
AirflowNotFoundException
+from airflow.exceptions import (
+ AirflowBadRequest,
+ AirflowConfigException,
+ AirflowException,
+ AirflowNotFoundException,
+ AirflowProviderDeprecationWarning,
+)
from airflow.providers.microsoft.azure.hooks.msgraph import (
DefaultResponseHandler,
KiotaRequestAdapterHook,
@@ -43,15 +51,13 @@ from providers.tests.microsoft.conftest import (
mock_json_response,
mock_response,
)
+from tests_common.test_utils.providers import get_provider_min_airflow_version
if TYPE_CHECKING:
from kiota_abstractions.request_adapter import RequestAdapter
class TestKiotaRequestAdapterHook:
- def setup_method(self):
- KiotaRequestAdapterHook.cached_request_adapters.clear()
-
@staticmethod
def assert_tenant_id(request_adapter: RequestAdapter, expected_tenant_id:
str):
assert isinstance(request_adapter, HttpxRequestAdapter)
@@ -86,6 +92,61 @@ class TestKiotaRequestAdapterHook:
assert isinstance(actual, HttpxRequestAdapter)
assert actual.base_url == "https://api.fabric.microsoft.com/v1"
+ def test_get_conn_with_proxies_as_string(self):
+ connection = lambda conn_id: get_airflow_connection(
+ conn_id=conn_id,
+ host="api.fabric.microsoft.com",
+ api_version="v1",
+ proxies="{'http': 'http://proxy:80', 'https': 'https://proxy:80'}",
+ )
+
+ with patch(
+ "airflow.hooks.base.BaseHook.get_connection",
+ side_effect=connection,
+ ):
+ hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+ actual = hook.get_conn()
+
+ assert isinstance(actual, HttpxRequestAdapter)
+ assert actual._http_client._mounts.get(URLPattern("http://"))
+ assert actual._http_client._mounts.get(URLPattern("https://"))
+
+ def test_get_conn_with_proxies_as_invalid_string(self):
+ connection = lambda conn_id: get_airflow_connection(
+ conn_id=conn_id,
+ host="api.fabric.microsoft.com",
+ api_version="v1",
+ proxies='["http://proxy:80", "https://proxy:80"]',
+ )
+
+ with patch(
+ "airflow.hooks.base.BaseHook.get_connection",
+ side_effect=connection,
+ ):
+ hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+
+ with pytest.raises(AirflowConfigException):
+ hook.get_conn()
+
+ def test_get_conn_with_proxies_as_json(self):
+ connection = lambda conn_id: get_airflow_connection(
+ conn_id=conn_id,
+ host="api.fabric.microsoft.com",
+ api_version="v1",
+ proxies='{"http": "http://proxy:80", "https": "https://proxy:80"}',
+ )
+
+ with patch(
+ "airflow.hooks.base.BaseHook.get_connection",
+ side_effect=connection,
+ ):
+ hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+ actual = hook.get_conn()
+
+ assert isinstance(actual, HttpxRequestAdapter)
+ assert actual._http_client._mounts.get(URLPattern("http://"))
+ assert actual._http_client._mounts.get(URLPattern("https://"))
+
def test_scopes_when_default(self):
with patch(
"airflow.hooks.base.BaseHook.get_connection",
@@ -301,3 +362,21 @@ class TestResponseHandler:
with pytest.raises(AirflowException):
asyncio.run(DefaultResponseHandler().handle_response_async(response, None))
+
+ @pytest.mark.db_test
+ def
test_when_provider_min_airflow_version_is_2_10_or_higher_remove_obsolete_code(self):
+ """
+ Once this test starts failing due to the fact that the minimum Airflow
version is now 2.10.0 or higher
+ for this provider, you should remove the obsolete code in the
get_proxies method of the
+ KiotaRequestAdapterHook and remove this test. This test was added to
make sure to not forget to
+ remove the fallback code for backward compatibility with Airflow 2.9.x
which isn't need anymore once
+ this provider depends on Airflow 2.10.0 or higher.
+ """
+ min_airflow_version =
get_provider_min_airflow_version("apache-airflow-providers-microsoft-azure")
+
+ # Check if the current Airflow version is 2.10.0 or higher
+ if min_airflow_version[0] >= 3 or (min_airflow_version[0] >= 2 and
min_airflow_version[1] >= 10):
+ method_source =
inspect.getsource(KiotaRequestAdapterHook.get_proxies)
+ raise AirflowProviderDeprecationWarning(
+ f"Check TODO's to remove obsolete code in get_proxies
method:\n\r\n\r\t\t\t{method_source}"
+ )
diff --git a/providers/tests/microsoft/conftest.py
b/providers/tests/microsoft/conftest.py
index e0ec3172cfa..606c2b2ca1c 100644
--- a/providers/tests/microsoft/conftest.py
+++ b/providers/tests/microsoft/conftest.py
@@ -32,6 +32,7 @@ from httpx import Headers, Response
from msgraph_core import APIVersion
from airflow.models import Connection
+from airflow.providers.microsoft.azure.hooks.msgraph import
KiotaRequestAdapterHook
from airflow.providers.microsoft.azure.hooks.powerbi import PowerBIHook
T = TypeVar("T", dict, str, Connection)
@@ -176,6 +177,11 @@ def get_airflow_connection(
)
+@pytest.fixture(autouse=True)
+def clear_cache():
+ KiotaRequestAdapterHook.cached_request_adapters.clear()
+
+
@pytest.fixture
def powerbi_hook():
return PowerBIHook(**{"conn_id": "powerbi_conn_id", "timeout": 3,
"api_version": "v1.0"})