This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 71fec4e661a [Providers/HTTP] Add adapter parameter to HttpHook to allow custom requests adapters (#44302)
71fec4e661a is described below

commit 71fec4e661ae4858dfe0f3797124c68b87aa13b2
Author: Jiao, Hsu <[email protected]>
AuthorDate: Tue Dec 3 18:44:17 2024 +0800

    [Providers/HTTP] Add adapter parameter to HttpHook to allow custom requests adapters (#44302)
    
    * feat(http-hook): add adapter parameter to HttpHook and enhance get_conn
    
    - Added `adapter` parameter to `HttpHook` to allow custom HTTP adapters.
    - Modified `get_conn` to support mounting custom adapters or using TCPKeepAliveAdapter by default.
    - Added comprehensive tests to validate the functionality of the `adapter` parameter and its integration with `get_conn`.
    - Ensured all new tests pass and maintain compatibility with existing functionality.
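The adapter choice that `get_conn` ends up making can be illustrated without Airflow or requests installed. This is a hedged, stdlib-only sketch: `StubSession` and `mount_adapters` are hypothetical names, the adapters are placeholder objects rather than real `HTTPAdapter`/`TCPKeepAliveAdapter` instances, and the branching is copied from `_configure_session_from_mount_adapters` in the diff below. In the hook itself, `keep_alive_adapter` is only created when `tcp_keep_alive=True` and no `adapter` was passed.

```python
from __future__ import annotations

from urllib.parse import urlparse


class StubSession:
    """Hypothetical stand-in for requests.Session that only records mounts."""

    def __init__(self) -> None:
        self.adapters: dict[str, object] = {}

    def mount(self, prefix: str, adapter: object) -> None:
        self.adapters[prefix] = adapter


def mount_adapters(session, base_url, adapter=None, keep_alive_adapter=None):
    """Branching copied from HttpHook._configure_session_from_mount_adapters."""
    scheme = urlparse(base_url).scheme
    if not scheme:
        raise ValueError(f"Cannot mount adapters: {base_url} has no scheme")
    if adapter:
        # A user-supplied adapter is mounted only for the base URL's scheme.
        session.mount(f"{scheme}://", adapter)
    elif keep_alive_adapter:
        # The default keep-alive adapter covers both HTTP and HTTPS.
        session.mount("http://", keep_alive_adapter)
        session.mount("https://", keep_alive_adapter)
    return session


custom = object()  # placeholder for an HTTPAdapter instance
s = mount_adapters(StubSession(), "https://api.example.com", adapter=custom)
print(sorted(s.adapters))  # ['https://']
```

With a real `requests.Session`, a default `HTTPAdapter` is already mounted on both `http://` and `https://` when the session is created, so a custom adapter mounted for one scheme leaves the other scheme on that default.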
    
    * fix(http_hook): Update docstring and remove redundant TCPKeepAliveAdapter
    
    - Added missing `adapter` parameter description to the HttpHook class docstring.
    - Removed redundant instantiation of `TCPKeepAliveAdapter` in the `run` method since it's already instantiated in `get_conn`.
    
    * fix(http_hook): improve get_conn session setup and TCP adapter logic
    
    - Ensured proper mounting of TCP Keep-Alive adapter when enabled.
    - Improved handling of connection extras for cleaner session configuration.
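How `get_conn` turns connection extras into session settings can be sketched as a pure function. This is a minimal sketch, assuming a plain dict in place of `connection.extra_dejson` and a hardcoded `DEFAULT_REDIRECT_LIMIT = 30` standing in for the constant imported from `requests.models`; `split_extras` is a hypothetical helper, not part of the hook. Known keys are popped into session attributes, `timeout` and `allow_redirects` are discarded, and whatever remains becomes default headers, which the `headers` argument of `get_conn` then overrides. The returned headers hold only those entries, not the defaults that `requests.Session` adds on its own.

```python
from __future__ import annotations

from typing import Any

# Assumption: same value as requests.models.DEFAULT_REDIRECT_LIMIT, hardcoded
# here so the sketch runs without requests installed.
DEFAULT_REDIRECT_LIMIT = 30


def split_extras(
    extra: dict[str, Any], headers: dict[str, str] | None = None
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Hypothetical helper mirroring the pop order of _configure_session_from_extra."""
    extra = dict(extra)  # copy so the caller's dict is not mutated
    extra.pop("timeout", None)  # only accepted per request, not on Session
    extra.pop("allow_redirects", None)  # Session only has max_redirects
    settings = {
        # The inner pop always runs, so "proxy" is consumed even if "proxies" exists.
        "proxies": extra.pop("proxies", extra.pop("proxy", {})),
        "stream": extra.pop("stream", False),
        "verify": extra.pop("verify", extra.pop("verify_ssl", True)),
        "cert": extra.pop("cert", None),
        "max_redirects": extra.pop("max_redirects", DEFAULT_REDIRECT_LIMIT),
        "trust_env": extra.pop("trust_env", True),
    }
    merged = dict(extra)  # leftover keys become default headers
    if headers:
        merged.update(headers)  # explicit get_conn(headers=...) values win
    return settings, merged


settings, hdrs = split_extras({"verify_ssl": False, "X-Api-Key": "abc"}, {"Accept": "application/json"})
print(settings["verify"], hdrs)  # False {'X-Api-Key': 'abc', 'Accept': 'application/json'}
```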
    
    * feat(http): update get_conn logic and corresponding tests (#44302)
    
    Aligned the `get_conn` method with the adjustments specified in #44302,
    including refined handling of headers. Optimized and updated test cases
    to ensure compatibility and maintain robust test coverage.
    
    * refactor(http_hook): simplify HttpHook by reverting BaseAdapter to HTTPAdapter
    
    - Changed the `adapter` parameter to accept only `HTTPAdapter` instead of `BaseAdapter`.
    - Strengthened `_set_base_url` validation to ensure base_url is constructed with stricter conditions.
    - Adjusted `_mount_adapters` to improve maintainability.
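The stricter base-URL rule in `_set_base_url` can be checked in isolation with `urllib.parse.urlparse`, the same function the hook imports. `build_base_url` is a hypothetical standalone copy of the method's logic, offered as a sketch. One detail visible in the diff below: the port is appended only when the host carries no scheme, whereas the removed code appended `conn.port` in both cases, so a host of `https://example.com` with port 8080 now yields `https://example.com`.

```python
from __future__ import annotations

from urllib.parse import urlparse


def build_base_url(host: str | None, schema: str | None, port: int | None) -> str:
    """Hypothetical standalone copy of HttpHook._set_base_url."""
    host = host or ""
    schema = schema or "http"  # schema defaults to HTTP
    if "://" in host:
        base_url = host  # host already carries a scheme; port is not appended
    else:
        base_url = f"{schema}://{host}" if host else f"{schema}://"
        if port:
            base_url = f"{base_url}:{port}"
    if not urlparse(base_url).scheme:
        raise ValueError(f"Invalid base URL: Missing scheme in {base_url}")
    return base_url


print(build_base_url("example.com", "https", 8443))  # https://example.com:8443
print(build_base_url("https://example.com", None, 8080))  # https://example.com
```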
    
    * Merge: new main
    
    * refactor: improve function naming and add type annotations
    
    - Changed the function prefix from `_set` to `_configure_session_from` to enhance readability and better reflect its purpose.
    - Added static type annotations for input parameters and return values.
    - Included comments to document the design rationale following coding standards.
    - Improved error message: replaced generic text with detailed and actionable messages.
    
    * fix: simplify the change of session
    
    - Reassigned the local `session` variable from the return value of each session-configuration helper.
    
    * fix: Adjust response format.
    
    * fix: simplify the logic
    
    * fix(hook): ensure default HTTPAdapter in HttpHook init
    
    The `adapter` parameter in `HttpHook` was previously required to be explicitly
    set to an instance of `HTTPAdapter`. This commit modifies the `__init__`
    method to assign a default `HTTPAdapter` when no adapter is provided.
    
    Changes:
    - Removed type checks for `adapter`, as default initialization guarantees correctness.
    - Improved code readability and reduced potential runtime errors.
    
    No functional changes beyond defaulting `adapter` to `HTTPAdapter`.
    
    * feat(http_hook): add support for custom adapter in initialization
    
    Refactored `HttpHook` to support a custom `HTTPAdapter` through the `adapter` parameter. If no adapter is provided, it defaults to `TCPKeepAliveAdapter` when `tcp_keep_alive=True`.
    
    Test: Added `test_custom_adapter` to verify correct adapter mounting.
    
    * fix: CI image checks / Static checks
    
    - Adjust the length of each line of code.
    
    * fix: Adjust indent style
    
    - Reformatted the `assert isinstance(...)` calls to follow PEP 8.
    
    * fix: ruff error about `from requests.adapters import HTTPAdapter`
    
    ---------
    
    Co-authored-by: jiao <[email protected]>
---
 providers/src/airflow/providers/http/hooks/http.py | 126 +++++++++++++--------
 providers/tests/http/hooks/test_http.py            |  16 ++-
 2 files changed, 95 insertions(+), 47 deletions(-)

diff --git a/providers/src/airflow/providers/http/hooks/http.py b/providers/src/airflow/providers/http/hooks/http.py
index 05b432626b8..a179739275e 100644
--- a/providers/src/airflow/providers/http/hooks/http.py
+++ b/providers/src/airflow/providers/http/hooks/http.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 
 import asyncio
 from typing import TYPE_CHECKING, Any, Callable
+from urllib.parse import urlparse
 
 import aiohttp
 import requests
@@ -34,6 +35,7 @@ from airflow.hooks.base import BaseHook
 
 if TYPE_CHECKING:
     from aiohttp.client_reqrep import ClientResponse
+    from requests.adapters import HTTPAdapter
 
     from airflow.models import Connection
 
@@ -54,6 +56,7 @@ class HttpHook(BaseHook):
         API url i.e https://www.google.com/ and optional authentication credentials. Default
         headers can also be specified in the Extra field in json format.
     :param auth_type: The auth type for the service
+    :param adapter: An optional instance of `requests.adapters.HTTPAdapter` to mount for the session.
     :param tcp_keep_alive: Enable TCP Keep Alive for the connection.
     :param tcp_keep_alive_idle: The TCP Keep Alive Idle parameter (corresponds to ``socket.TCP_KEEPIDLE``).
     :param tcp_keep_alive_count: The TCP Keep Alive count parameter (corresponds to ``socket.TCP_KEEPCNT``)
@@ -76,6 +79,7 @@ class HttpHook(BaseHook):
         tcp_keep_alive_idle: int = 120,
         tcp_keep_alive_count: int = 20,
         tcp_keep_alive_interval: int = 30,
+        adapter: HTTPAdapter | None = None,
     ) -> None:
         super().__init__()
         self.http_conn_id = http_conn_id
@@ -83,10 +87,17 @@ class HttpHook(BaseHook):
         self.base_url: str = ""
         self._retry_obj: Callable[..., Any]
         self._auth_type: Any = auth_type
-        self.tcp_keep_alive = tcp_keep_alive
-        self.keep_alive_idle = tcp_keep_alive_idle
-        self.keep_alive_count = tcp_keep_alive_count
-        self.keep_alive_interval = tcp_keep_alive_interval
+
+        # If no adapter is provided, use TCPKeepAliveAdapter (default behavior)
+        self.adapter = adapter
+        if tcp_keep_alive and adapter is None:
+            self.keep_alive_adapter = TCPKeepAliveAdapter(
+                idle=tcp_keep_alive_idle,
+                count=tcp_keep_alive_count,
+                interval=tcp_keep_alive_interval,
+            )
+        else:
+            self.keep_alive_adapter = None
 
     @property
     def auth_type(self):
@@ -102,47 +113,76 @@ class HttpHook(BaseHook):
         """
         Create a Requests HTTP session.
 
-        :param headers: additional headers to be passed through as a dictionary
+        :param headers: Additional headers to be passed through as a dictionary.
+        :return: A configured requests.Session object.
         """
         session = requests.Session()
-
-        if self.http_conn_id:
-            conn = self.get_connection(self.http_conn_id)
-
-            if conn.host and "://" in conn.host:
-                self.base_url = conn.host
-            else:
-                # schema defaults to HTTP
-                schema = conn.schema if conn.schema else "http"
-                host = conn.host if conn.host else ""
-                self.base_url = f"{schema}://{host}"
-
-            if conn.port:
-                self.base_url += f":{conn.port}"
-            if conn.login:
-                session.auth = self.auth_type(conn.login, conn.password)
-            elif self._auth_type:
-                session.auth = self.auth_type()
-            if conn.extra:
-                extra = conn.extra_dejson
-                extra.pop(
-                    "timeout", None
-                )  # ignore this as timeout is only accepted in request method of Session
-                extra.pop("allow_redirects", None)  # ignore this as only max_redirects is accepted in Session
-                session.proxies = extra.pop("proxies", extra.pop("proxy", {}))
-                session.stream = extra.pop("stream", False)
-                session.verify = extra.pop("verify", extra.pop("verify_ssl", True))
-                session.cert = extra.pop("cert", None)
-                session.max_redirects = extra.pop("max_redirects", DEFAULT_REDIRECT_LIMIT)
-                session.trust_env = extra.pop("trust_env", True)
-
-                try:
-                    session.headers.update(extra)
-                except TypeError:
-                    self.log.warning("Connection to %s has invalid extra field.", conn.host)
+        connection = self.get_connection(self.http_conn_id)
+        self._set_base_url(connection)
+        session = self._configure_session_from_auth(session, connection)
+        if connection.extra:
+            session = self._configure_session_from_extra(session, connection)
+        session = self._configure_session_from_mount_adapters(session)
         if headers:
             session.headers.update(headers)
+        return session
+
+    def _set_base_url(self, connection: Connection) -> None:
+        host = connection.host or ""
+        schema = connection.schema or "http"
+        # RFC 3986 (https://www.rfc-editor.org/rfc/rfc3986.html#page-16)
+        if "://" in host:
+            self.base_url = host
+        else:
+            self.base_url = f"{schema}://{host}" if host else f"{schema}://"
+            if connection.port:
+                self.base_url = f"{self.base_url}:{connection.port}"
+        parsed = urlparse(self.base_url)
+        if not parsed.scheme:
+            raise ValueError(f"Invalid base URL: Missing scheme in {self.base_url}")
+
+    def _configure_session_from_auth(
+        self, session: requests.Session, connection: Connection
+    ) -> requests.Session:
+        session.auth = self._extract_auth(connection)
+        return session
+
+    def _extract_auth(self, connection: Connection) -> Any | None:
+        if connection.login:
+            return self.auth_type(connection.login, connection.password)
+        elif self._auth_type:
+            return self.auth_type()
+        return None
+
+    def _configure_session_from_extra(
+        self, session: requests.Session, connection: Connection
+    ) -> requests.Session:
+        extra = connection.extra_dejson
+        extra.pop("timeout", None)
+        extra.pop("allow_redirects", None)
+        session.proxies = extra.pop("proxies", extra.pop("proxy", {}))
+        session.stream = extra.pop("stream", False)
+        session.verify = extra.pop("verify", extra.pop("verify_ssl", True))
+        session.cert = extra.pop("cert", None)
+        session.max_redirects = extra.pop("max_redirects", DEFAULT_REDIRECT_LIMIT)
+        session.trust_env = extra.pop("trust_env", True)
+        try:
+            session.headers.update(extra)
+        except TypeError:
+            self.log.warning("Connection to %s has invalid extra field.", connection.host)
+        return session
 
+    def _configure_session_from_mount_adapters(self, session: requests.Session) -> requests.Session:
+        scheme = urlparse(self.base_url).scheme
+        if not scheme:
+            raise ValueError(
+                f"Cannot mount adapters: {self.base_url} does not include a valid scheme (http or https)."
+            )
+        if self.adapter:
+            session.mount(f"{scheme}://", self.adapter)
+        elif self.keep_alive_adapter:
+            session.mount("http://", self.keep_alive_adapter)
+            session.mount("https://", self.keep_alive_adapter)
         return session
 
     def run(
@@ -171,11 +211,6 @@ class HttpHook(BaseHook):
 
         url = self.url_from_endpoint(endpoint)
 
-        if self.tcp_keep_alive:
-            keep_alive_adapter = TCPKeepAliveAdapter(
-                idle=self.keep_alive_idle, count=self.keep_alive_count, interval=self.keep_alive_interval
-            )
-            session.mount(url, keep_alive_adapter)
         if self.method == "GET":
             # GET uses params
             req = requests.Request(self.method, url, params=data, headers=headers, **request_kwargs)
@@ -467,5 +502,4 @@ class HttpAsyncHook(BaseHook):
         if exception.status == 413:
             # don't retry for payload Too Large
             return False
-
         return exception.status >= 500
diff --git a/providers/tests/http/hooks/test_http.py b/providers/tests/http/hooks/test_http.py
index e09fd2d034e..bd381a7155b 100644
--- a/providers/tests/http/hooks/test_http.py
+++ b/providers/tests/http/hooks/test_http.py
@@ -29,7 +29,7 @@ import pytest
 import requests
 import tenacity
 from aioresponses import aioresponses
-from requests.adapters import Response
+from requests.adapters import HTTPAdapter, Response
 from requests.auth import AuthBase, HTTPBasicAuth
 from requests.models import DEFAULT_REDIRECT_LIMIT
 
@@ -536,6 +536,20 @@ class TestHttpHook:
         hook.base_url = base_url
         assert hook.url_from_endpoint(endpoint) == expected_url
 
+    def test_custom_adapter(self):
+        with mock.patch(
+            "airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection_with_port
+        ):
+            custom_adapter = HTTPAdapter()
+            hook = HttpHook(method="GET", adapter=custom_adapter)
+            session = hook.get_conn()
+            assert isinstance(
+                session.adapters["http://"], type(custom_adapter)
+            ), "Custom HTTP adapter not correctly mounted"
+            assert isinstance(
+                session.adapters["https://"], type(custom_adapter)
+            ), "Custom HTTPS adapter not correctly mounted"
+
 
 class TestHttpAsyncHook:
     @pytest.mark.asyncio
