This is an automated email from the ASF dual-hosted git repository.

jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c2d02b450a Allow default requests parameters like proxy to be defined 
in extra options field of an Airflow HTTP Connection (#36733)
c2d02b450a is described below

commit c2d02b450a1836ba777dc5557ac1773161cbc5ea
Author: David Blain <[email protected]>
AuthorDate: Sat Jan 13 18:49:51 2024 +0100

    Allow default requests parameters like proxy to be defined in extra options 
field of an Airflow HTTP Connection (#36733)
    
    * refactor: Pop non-header-related parameters (request parameters used by the 
SimpleHttpOperator) from the Connection extra_options to avoid an InvalidHeader 
exception while instantiating the requests Session
    
    * refactor: Forgot to assign the popped non-header-related parameters to the 
instantiated requests Session as default values
    
    * refactor: Also use the extra options from connections when using an 
AsyncHttpHook
    
    * docs: Updated the HTTP Connection documentation concerning the optional 
Extra field
    
    * refactor: Fixed static checks on test http module
    
    * refactor: Also allow the definition of timeout as a request parameter in 
extra_options and added async test for AsyncHttpOperator
    
    * refactor: Fixed some formatting to make static checks happy
    
    * refactor: Removed indentation from Extras section
    
    * refactor: Refactored different tests for the 
process_extra_options_from_connection into one test as suggested by aritra24
    
    * refactor: Fixed formatting of get_airflow_connection_with_extra
    
    * refactor: Moved import of Connection under type check
    
    * refactor: Reformatted http hook
    
    ---------
    
    Co-authored-by: David Blain <[email protected]>
---
 airflow/providers/http/hooks/http.py               |  43 ++++++-
 .../connections/http.rst                           |  10 +-
 tests/providers/http/hooks/test_http.py            | 139 +++++++++++++++++++--
 3 files changed, 181 insertions(+), 11 deletions(-)

diff --git a/airflow/providers/http/hooks/http.py 
b/airflow/providers/http/hooks/http.py
index fc19d0b102..7b98ec25df 100644
--- a/airflow/providers/http/hooks/http.py
+++ b/airflow/providers/http/hooks/http.py
@@ -26,6 +26,7 @@ import tenacity
 from aiohttp import ClientResponseError
 from asgiref.sync import sync_to_async
 from requests.auth import HTTPBasicAuth
+from requests.models import DEFAULT_REDIRECT_LIMIT
 from requests_toolbelt.adapters.socket_options import TCPKeepAliveAdapter
 
 from airflow.exceptions import AirflowException
@@ -34,6 +35,8 @@ from airflow.hooks.base import BaseHook
 if TYPE_CHECKING:
     from aiohttp.client_reqrep import ClientResponse
 
+    from airflow.models import Connection
+
 
 class HttpHook(BaseHook):
     """Interact with HTTP servers.
@@ -113,8 +116,19 @@ class HttpHook(BaseHook):
             elif self._auth_type:
                 session.auth = self.auth_type()
             if conn.extra:
+                extra = conn.extra_dejson
+                extra.pop(
+                    "timeout", None
+                )  # ignore this as timeout is only accepted in request method 
of Session
+                extra.pop("allow_redirects", None)  # ignore this as only 
max_redirects is accepted in Session
+                session.proxies = extra.pop("proxies", extra.pop("proxy", {}))
+                session.stream = extra.pop("stream", False)
+                session.verify = extra.pop("verify", extra.pop("verify_ssl", 
True))
+                session.cert = extra.pop("cert", None)
+                session.max_redirects = extra.pop("max_redirects", 
DEFAULT_REDIRECT_LIMIT)
+
                 try:
-                    session.headers.update(conn.extra_dejson)
+                    session.headers.update(extra)
                 except TypeError:
                     self.log.warning("Connection to %s has invalid extra 
field.", conn.host)
         if headers:
@@ -336,8 +350,10 @@ class HttpAsyncHook(BaseHook):
             if conn.login:
                 auth = self.auth_type(conn.login, conn.password)
             if conn.extra:
+                extra = self._process_extra_options_from_connection(conn=conn, 
extra_options=extra_options)
+
                 try:
-                    _headers.update(conn.extra_dejson)
+                    _headers.update(extra)
                 except TypeError:
                     self.log.warning("Connection to %s has invalid extra 
field.", conn.host)
         if headers:
@@ -395,6 +411,29 @@ class HttpAsyncHook(BaseHook):
             else:
                 raise NotImplementedError  # should not reach this, but makes 
mypy happy
 
+    @classmethod
+    def _process_extra_options_from_connection(cls, conn: Connection, 
extra_options: dict) -> dict:
+        extra = conn.extra_dejson  # split extras: request kwargs go into extra_options, the rest is returned as headers
+        extra.pop("stream", None)  # dropped so it is not sent as a header; not forwarded to extra_options
+        extra.pop("cert", None)  # NOTE(review): these pops raise TypeError for a non-object extra, before the caller's try/except
+        proxies = extra.pop("proxies", extra.pop("proxy", None))  # inner pop runs eagerly: both keys removed, "proxies" wins
+        timeout = extra.pop("timeout", None)
+        verify_ssl = extra.pop("verify", extra.pop("verify_ssl", None))  # same eager pop: "verify" wins over "verify_ssl"
+        allow_redirects = extra.pop("allow_redirects", None)
+        max_redirects = extra.pop("max_redirects", None)
+
+        if proxies is not None and "proxy" not in extra_options:  # values already in extra_options take precedence
+            extra_options["proxy"] = proxies  # NOTE(review): aiohttp's proxy= seems to take one URL, not a {scheme: url} dict; confirm
+        if timeout is not None and "timeout" not in extra_options:
+            extra_options["timeout"] = timeout
+        if verify_ssl is not None and "verify_ssl" not in extra_options:
+            extra_options["verify_ssl"] = verify_ssl
+        if allow_redirects is not None and "allow_redirects" not in 
extra_options:
+            extra_options["allow_redirects"] = allow_redirects
+        if max_redirects is not None and "max_redirects" not in extra_options:
+            extra_options["max_redirects"] = max_redirects
+        return extra  # extra_options is updated in place; the keys left here are used as headers by the caller
+
     def _retryable_error_async(self, exception: ClientResponseError) -> bool:
         """Determine whether an exception may successful on a subsequent 
attempt.
 
diff --git a/docs/apache-airflow-providers-http/connections/http.rst 
b/docs/apache-airflow-providers-http/connections/http.rst
index 41856cefee..6f1decdec9 100644
--- a/docs/apache-airflow-providers-http/connections/http.rst
+++ b/docs/apache-airflow-providers-http/connections/http.rst
@@ -54,7 +54,15 @@ Schema (optional)
     Specify the service type etc: http/https.
 
 Extras (optional)
-    Specify headers in json format.
+    Specify headers and default request parameters in JSON format.
+    The following default request parameters are taken into account:
+
+    * ``stream``
+    * ``cert``
+    * ``proxies`` or ``proxy``
+    * ``verify`` or ``verify_ssl``
+    * ``allow_redirects``
+    * ``max_redirects``
 
 When specifying the connection in environment variable you should specify
 it using URI syntax.
diff --git a/tests/providers/http/hooks/test_http.py 
b/tests/providers/http/hooks/test_http.py
index 617009d575..7b093c66bb 100644
--- a/tests/providers/http/hooks/test_http.py
+++ b/tests/providers/http/hooks/test_http.py
@@ -31,6 +31,7 @@ import tenacity
 from aioresponses import aioresponses
 from requests.adapters import Response
 from requests.auth import AuthBase, HTTPBasicAuth
+from requests.models import DEFAULT_REDIRECT_LIMIT
 
 from airflow.exceptions import AirflowException
 from airflow.models import Connection
@@ -46,18 +47,23 @@ def aioresponse():
         yield async_response
 
 
-def get_airflow_connection(unused_conn_id=None):
-    return Connection(conn_id="http_default", conn_type="http", 
host="test:8080/", extra='{"bearer": "test"}')
+def get_airflow_connection(conn_id: str = "http_default"):
+    return Connection(conn_id=conn_id, conn_type="http", host="test:8080/", 
extra='{"bearer": "test"}')
 
 
-def get_airflow_connection_with_port(unused_conn_id=None):
-    return Connection(conn_id="http_default", conn_type="http", 
host="test.com", port=1234)
+def get_airflow_connection_with_extra(extra: dict):
+    def inner(conn_id: str = "http_default"):
+        return Connection(conn_id=conn_id, conn_type="http", 
host="test:8080/", extra=json.dumps(extra))
 
+    return inner
 
-def get_airflow_connection_with_login_and_password(unused_conn_id=None):
-    return Connection(
-        conn_id="http_default", conn_type="http", host="test.com", 
login="username", password="pass"
-    )
+
+def get_airflow_connection_with_port(conn_id: str = "http_default"):
+    return Connection(conn_id=conn_id, conn_type="http", host="test.com", 
port=1234)
+
+
+def get_airflow_connection_with_login_and_password(conn_id: str = 
"http_default"):
+    return Connection(conn_id=conn_id, conn_type="http", host="test.com", 
login="username", password="pass")
 
 
 class TestHttpHook:
@@ -119,6 +125,64 @@ class TestHttpHook:
             assert dict(conn.headers, **json.loads(expected_conn.extra)) == 
conn.headers
             assert conn.headers.get("bearer") == "test"
 
+    def test_hook_ignore_max_redirects_from_extra_field_as_header(self):
+        airflow_connection = get_airflow_connection_with_extra(extra={"bearer": "test", "max_redirects": 3})
+        with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=airflow_connection):
+            expected_conn = airflow_connection()
+            conn = self.get_hook.get_conn()
+            assert dict(conn.headers, **json.loads(expected_conn.extra)) != conn.headers
+            assert conn.headers.get("bearer") == "test"
+            assert conn.headers.get("max_redirects") is None  # popped from extras, so never sent as a header
+            assert conn.proxies == {}
+            assert conn.stream is False
+            assert conn.verify is True
+            assert conn.cert is None
+            assert conn.max_redirects == 3
+
+    def test_hook_ignore_proxies_from_extra_field_as_header(self):
+        airflow_connection = get_airflow_connection_with_extra(
+            extra={"bearer": "test", "proxies": {"http": "http://proxy:80", "https": "https://proxy:80"}}
+        )
+        with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=airflow_connection):
+            expected_conn = airflow_connection()
+            conn = self.get_hook.get_conn()
+            assert dict(conn.headers, **json.loads(expected_conn.extra)) != conn.headers
+            assert conn.headers.get("bearer") == "test"
+            assert conn.headers.get("proxies") is None  # moved to Session.proxies instead of the headers
+            assert conn.proxies == {"http": "http://proxy:80", "https": "https://proxy:80"}
+            assert conn.stream is False
+            assert conn.verify is True
+            assert conn.cert is None
+            assert conn.max_redirects == DEFAULT_REDIRECT_LIMIT
+
+    def test_hook_ignore_verify_from_extra_field_as_header(self):
+        airflow_connection = get_airflow_connection_with_extra(extra={"bearer": "test", "verify": False})
+        with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=airflow_connection):
+            expected_conn = airflow_connection()
+            conn = self.get_hook.get_conn()
+            assert dict(conn.headers, **json.loads(expected_conn.extra)) != conn.headers
+            assert conn.headers.get("bearer") == "test"
+            assert conn.headers.get("verify") is None  # moved to Session.verify instead of the headers
+            assert conn.proxies == {}
+            assert conn.stream is False
+            assert conn.verify is False
+            assert conn.cert is None
+            assert conn.max_redirects == DEFAULT_REDIRECT_LIMIT
+
+    def test_hook_ignore_cert_from_extra_field_as_header(self):
+        airflow_connection = get_airflow_connection_with_extra(extra={"bearer": "test", "cert": "cert.crt"})
+        with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=airflow_connection):
+            expected_conn = airflow_connection()
+            conn = self.get_hook.get_conn()
+            assert dict(conn.headers, **json.loads(expected_conn.extra)) != conn.headers
+            assert conn.headers.get("bearer") == "test"
+            assert conn.headers.get("cert") is None  # moved to Session.cert instead of the headers
+            assert conn.proxies == {}
+            assert conn.stream is False
+            assert conn.verify is True
+            assert conn.cert == "cert.crt"
+            assert conn.max_redirects == DEFAULT_REDIRECT_LIMIT
+
     @mock.patch("requests.Request")
     def test_hook_with_method_in_lowercase(self, mock_requests):
         from requests.exceptions import InvalidURL, MissingSchema
@@ -525,3 +589,62 @@ class TestHttpAsyncHook:
                 assert all(
                     key in headers and headers[key] == value for key, value in 
connection_extra.items()
                 )
+
+    @pytest.mark.asyncio
+    async def test_async_request_uses_connection_extra_with_requests_parameters(self):
+        """Test that request parameters in the connection extra reach aiohttp as kwargs, not headers."""
+        connection_extra = {"bearer": "test"}
+        proxy = {"http": "http://proxy:80", "https": "https://proxy:80"}
+        airflow_connection = get_airflow_connection_with_extra(
+            extra={
+                **connection_extra,
+                **{
+                    "proxies": proxy,
+                    "timeout": 60,
+                    "verify": False,
+                    "allow_redirects": False,
+                    "max_redirects": 3,
+                },
+            }
+        )
+
+        with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=airflow_connection):
+            hook = HttpAsyncHook()
+            with mock.patch("aiohttp.ClientSession.post", new_callable=mock.AsyncMock) as mocked_function:
+                await hook.run("v1/test")
+                headers = mocked_function.call_args.kwargs.get("headers")
+                assert all(
+                    key in headers and headers[key] == value for key, value in connection_extra.items()
+                )
+                assert mocked_function.call_args.kwargs.get("proxy") == proxy
+                assert mocked_function.call_args.kwargs.get("timeout") == 60
+                assert mocked_function.call_args.kwargs.get("verify_ssl") is False
+                assert mocked_function.call_args.kwargs.get("allow_redirects") is False
+                assert mocked_function.call_args.kwargs.get("max_redirects") == 3
+
+    def test_process_extra_options_from_connection(self):
+        extra_options = {}
+        proxy = {"http": "http://proxy:80", "https": "https://proxy:80"}
+        conn = get_airflow_connection_with_extra(
+            extra={
+                "bearer": "test",
+                "stream": True,
+                "cert": "cert.crt",
+                "proxies": proxy,
+                "timeout": 60,
+                "verify": False,
+                "allow_redirects": False,
+                "max_redirects": 3,
+            }
+        )()
+
+        actual = HttpAsyncHook._process_extra_options_from_connection(conn=conn, extra_options=extra_options)
+
+        assert extra_options == {
+            "proxy": proxy,
+            "timeout": 60,
+            "verify_ssl": False,
+            "allow_redirects": False,
+            "max_redirects": 3,
+        }
+        assert actual == {"bearer": "test"}  # stream and cert are dropped; only header keys remain

Reply via email to