This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 40fd35ca444 fix: `KiotaRequestAdapterHook` make sure proxy config 
parameter is parsed correctly, even if it's a string or json (#46145)
40fd35ca444 is described below

commit 40fd35ca444740d6788e489ac7493b39937f8d23
Author: David Blain <[email protected]>
AuthorDate: Mon Feb 3 05:33:59 2025 +0100

    fix: `KiotaRequestAdapterHook` make sure proxy config parameter is parsed 
correctly, even if it's a string or json (#46145)
    
    * refactor: Make sure proxy is configured correctly, even if it's a string 
or json
    
    ---------
    
    Co-authored-by: David Blain <[email protected]>
---
 .../providers/microsoft/azure/hooks/msgraph.py     | 48 +++++++++---
 .../tests/microsoft/azure/hooks/test_msgraph.py    | 87 +++++++++++++++++++++-
 providers/tests/microsoft/conftest.py              |  6 ++
 3 files changed, 125 insertions(+), 16 deletions(-)

diff --git a/providers/src/airflow/providers/microsoft/azure/hooks/msgraph.py 
b/providers/src/airflow/providers/microsoft/azure/hooks/msgraph.py
index 1754d04b103..46d651670e4 100644
--- a/providers/src/airflow/providers/microsoft/azure/hooks/msgraph.py
+++ b/providers/src/airflow/providers/microsoft/azure/hooks/msgraph.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 import json
+from ast import literal_eval
 from contextlib import suppress
 from http import HTTPStatus
 from io import BytesIO
@@ -43,7 +44,12 @@ from kiota_serialization_text.text_parse_node_factory import 
TextParseNodeFactor
 from msgraph_core import APIVersion, GraphClientFactory
 from msgraph_core._enums import NationalClouds
 
-from airflow.exceptions import AirflowBadRequest, AirflowException, 
AirflowNotFoundException
+from airflow.exceptions import (
+    AirflowBadRequest,
+    AirflowConfigException,
+    AirflowException,
+    AirflowNotFoundException,
+)
 from airflow.hooks.base import BaseHook
 
 if TYPE_CHECKING:
@@ -212,19 +218,20 @@ class KiotaRequestAdapterHook(BaseHook):
 
     @classmethod
     def to_httpx_proxies(cls, proxies: dict) -> dict:
-        proxies = proxies.copy()
-        if proxies.get("http"):
-            proxies["http://"] = AsyncHTTPTransport(proxy=proxies.pop("http"))
-        if proxies.get("https"):
-            proxies["https://"] = AsyncHTTPTransport(proxy=proxies.pop("https"))
-        if proxies.get("no"):
-            for url in proxies.pop("no", "").split(","):
-                proxies[cls.format_no_proxy_url(url.strip())] = None
+        if proxies:
+            proxies = proxies.copy()
+            if proxies.get("http"):
+                proxies["http://"] = AsyncHTTPTransport(proxy=proxies.pop("http"))
+            if proxies.get("https"):
+                proxies["https://"] = AsyncHTTPTransport(proxy=proxies.pop("https"))
+            if proxies.get("no"):
+                for url in proxies.pop("no", "").split(","):
+                    proxies[cls.format_no_proxy_url(url.strip())] = None
         return proxies
 
-    def to_msal_proxies(self, authority: str | None, proxies: dict):
+    def to_msal_proxies(self, authority: str | None, proxies: dict) -> dict | 
None:
         self.log.debug("authority: %s", authority)
-        if authority:
+        if authority and proxies:
             no_proxies = proxies.get("no")
             self.log.debug("no_proxies: %s", no_proxies)
             if no_proxies:
@@ -251,7 +258,7 @@ class KiotaRequestAdapterHook(BaseHook):
             host = self.get_host(connection)
             base_url = config.get("base_url", urljoin(host, api_version))
             authority = config.get("authority")
-            proxies = self.proxies or config.get("proxies", {})
+            proxies = self.get_proxies(config)
             httpx_proxies = self.to_httpx_proxies(proxies=proxies)
             scopes = config.get("scopes", self.scopes)
             if isinstance(scopes, str):
@@ -314,6 +321,23 @@ class KiotaRequestAdapterHook(BaseHook):
         self._api_version = api_version
         return request_adapter
 
+    def get_proxies(self, config: dict) -> dict:
+        proxies = self.proxies or config.get("proxies", {})
+        if isinstance(proxies, str):
+            # TODO: Once provider depends on Airflow 2.10 or higher code below 
won't be needed anymore as
+            #       we could then use the get_extra_dejson method on the 
connection which deserializes
+            #       nested json. Make sure to use 
connection.get_extra_dejson(nested=True) instead of
+            #       connection.extra_dejson.
+            with suppress(JSONDecodeError):
+                proxies = json.loads(proxies)
+            with suppress(Exception):
+                proxies = literal_eval(proxies)
+        if not isinstance(proxies, dict):
+            raise AirflowConfigException(
+                f"Proxies must be of type dict, got {type(proxies).__name__} 
instead!"
+            )
+        return proxies
+
     def get_credentials(
         self,
         login: str | None,
diff --git a/providers/tests/microsoft/azure/hooks/test_msgraph.py 
b/providers/tests/microsoft/azure/hooks/test_msgraph.py
index 3dbf8b9bf64..c3621a4ec6f 100644
--- a/providers/tests/microsoft/azure/hooks/test_msgraph.py
+++ b/providers/tests/microsoft/azure/hooks/test_msgraph.py
@@ -17,19 +17,27 @@
 from __future__ import annotations
 
 import asyncio
+import inspect
 from json import JSONDecodeError
 from typing import TYPE_CHECKING
 from unittest.mock import Mock, patch
 
 import pytest
 from httpx import Response
+from httpx._utils import URLPattern
 from kiota_http.httpx_request_adapter import HttpxRequestAdapter
 from kiota_serialization_json.json_parse_node import JsonParseNode
 from kiota_serialization_text.text_parse_node import TextParseNode
 from msgraph_core import APIVersion, NationalClouds
 from opentelemetry.trace import Span
 
-from airflow.exceptions import AirflowBadRequest, AirflowException, 
AirflowNotFoundException
+from airflow.exceptions import (
+    AirflowBadRequest,
+    AirflowConfigException,
+    AirflowException,
+    AirflowNotFoundException,
+    AirflowProviderDeprecationWarning,
+)
 from airflow.providers.microsoft.azure.hooks.msgraph import (
     DefaultResponseHandler,
     KiotaRequestAdapterHook,
@@ -43,15 +51,13 @@ from providers.tests.microsoft.conftest import (
     mock_json_response,
     mock_response,
 )
+from tests_common.test_utils.providers import get_provider_min_airflow_version
 
 if TYPE_CHECKING:
     from kiota_abstractions.request_adapter import RequestAdapter
 
 
 class TestKiotaRequestAdapterHook:
-    def setup_method(self):
-        KiotaRequestAdapterHook.cached_request_adapters.clear()
-
     @staticmethod
     def assert_tenant_id(request_adapter: RequestAdapter, expected_tenant_id: 
str):
         assert isinstance(request_adapter, HttpxRequestAdapter)
@@ -86,6 +92,61 @@ class TestKiotaRequestAdapterHook:
             assert isinstance(actual, HttpxRequestAdapter)
             assert actual.base_url == "https://api.fabric.microsoft.com/v1"
 
+    def test_get_conn_with_proxies_as_string(self):
+        connection = lambda conn_id: get_airflow_connection(
+            conn_id=conn_id,
+            host="api.fabric.microsoft.com",
+            api_version="v1",
+            proxies="{'http': 'http://proxy:80', 'https': 'https://proxy:80'}",
+        )
+
+        with patch(
+            "airflow.hooks.base.BaseHook.get_connection",
+            side_effect=connection,
+        ):
+            hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+            actual = hook.get_conn()
+
+            assert isinstance(actual, HttpxRequestAdapter)
+            assert actual._http_client._mounts.get(URLPattern("http://"))
+            assert actual._http_client._mounts.get(URLPattern("https://"))
+
+    def test_get_conn_with_proxies_as_invalid_string(self):
+        connection = lambda conn_id: get_airflow_connection(
+            conn_id=conn_id,
+            host="api.fabric.microsoft.com",
+            api_version="v1",
+            proxies='["http://proxy:80", "https://proxy:80"]',
+        )
+
+        with patch(
+            "airflow.hooks.base.BaseHook.get_connection",
+            side_effect=connection,
+        ):
+            hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+
+            with pytest.raises(AirflowConfigException):
+                hook.get_conn()
+
+    def test_get_conn_with_proxies_as_json(self):
+        connection = lambda conn_id: get_airflow_connection(
+            conn_id=conn_id,
+            host="api.fabric.microsoft.com",
+            api_version="v1",
+            proxies='{"http": "http://proxy:80", "https": "https://proxy:80"}',
+        )
+
+        with patch(
+            "airflow.hooks.base.BaseHook.get_connection",
+            side_effect=connection,
+        ):
+            hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+            actual = hook.get_conn()
+
+            assert isinstance(actual, HttpxRequestAdapter)
+            assert actual._http_client._mounts.get(URLPattern("http://"))
+            assert actual._http_client._mounts.get(URLPattern("https://"))
+
     def test_scopes_when_default(self):
         with patch(
             "airflow.hooks.base.BaseHook.get_connection",
@@ -301,3 +362,21 @@ class TestResponseHandler:
 
         with pytest.raises(AirflowException):
             
asyncio.run(DefaultResponseHandler().handle_response_async(response, None))
+
+    @pytest.mark.db_test
+    def 
test_when_provider_min_airflow_version_is_2_10_or_higher_remove_obsolete_code(self):
+        """
+        Once this test starts failing due to the fact that the minimum Airflow 
version is now 2.10.0 or higher
+        for this provider, you should remove the obsolete code in the 
get_proxies method of the
+        KiotaRequestAdapterHook and remove this test.  This test was added to 
make sure to not forget to
+        remove the fallback code for backward compatibility with Airflow 2.9.x 
which isn't need anymore once
+        this provider depends on Airflow 2.10.0 or higher.
+        """
+        min_airflow_version = 
get_provider_min_airflow_version("apache-airflow-providers-microsoft-azure")
+
+        # Check if the current Airflow version is 2.10.0 or higher
+        if min_airflow_version[0] >= 3 or (min_airflow_version[0] >= 2 and 
min_airflow_version[1] >= 10):
+            method_source = 
inspect.getsource(KiotaRequestAdapterHook.get_proxies)
+            raise AirflowProviderDeprecationWarning(
+                f"Check TODO's to remove obsolete code in get_proxies 
method:\n\r\n\r\t\t\t{method_source}"
+            )
diff --git a/providers/tests/microsoft/conftest.py 
b/providers/tests/microsoft/conftest.py
index e0ec3172cfa..606c2b2ca1c 100644
--- a/providers/tests/microsoft/conftest.py
+++ b/providers/tests/microsoft/conftest.py
@@ -32,6 +32,7 @@ from httpx import Headers, Response
 from msgraph_core import APIVersion
 
 from airflow.models import Connection
+from airflow.providers.microsoft.azure.hooks.msgraph import 
KiotaRequestAdapterHook
 from airflow.providers.microsoft.azure.hooks.powerbi import PowerBIHook
 
 T = TypeVar("T", dict, str, Connection)
@@ -176,6 +177,11 @@ def get_airflow_connection(
     )
 
 
+@pytest.fixture(autouse=True)
+def clear_cache():
+    KiotaRequestAdapterHook.cached_request_adapters.clear()
+
+
 @pytest.fixture
 def powerbi_hook():
     return PowerBIHook(**{"conn_id": "powerbi_conn_id", "timeout": 3, 
"api_version": "v1.0"})

Reply via email to