This is an automated email from the ASF dual-hosted git repository.
dabla pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 49f48e3bfbf Refactor HttpAsyncHook to support session-based async HTTP
operations and simplify LivyAsyncHook (#60458)
49f48e3bfbf is described below
commit 49f48e3bfbf5ad0062852b96599d06c02229e540
Author: David Blain <[email protected]>
AuthorDate: Sat Mar 14 23:11:23 2026 +0100
Refactor HttpAsyncHook to support session-based async HTTP operations and
simplify LivyAsyncHook (#60458)
* refactor: Refactored HttpAsyncHook to easily support session based run
operations
* fix: Fixed import of LoggingMixin
* refactor: LivyAsyncHook now reuses logic from HttpAsyncHook which is more
DRY
* refactor: Reformatted HttpAsyncHook
* refactor: Fixed possible None types for merged_headers
* refactor: Changed type of _retryable_error_async method
* refactor: Removed unused import
* refactor: Moved SessionConfig inside AsyncHttpSession
* refactor: Reformatted run method of HttpAsyncHook
* refactor: Removed unused import from LivyHook module
* Revert "refactor: Moved SessionConfig inside AsyncHttpSession"
This reverts commit f9c503d4b6aecfbda636d25daf0d6fa59ddd3dd4.
* refactor: Added docstring for retry_limit and retry_delay parameters
* refactor: Reformatted docstring in _retryable_error_async method
* refactor: Added docstring for SessionConfig and AsyncHttpSession
* refactor: Added warning logging when run attempt fails
* refactor: Refactored run_method of LivyAsyncHook
* refactor: Refactored unit tests for LivyAsyncHook
* refactor: Reformatted AsyncHttpSession
* refactor: Reformatted run_method of LivyAsyncHook
* refactor: Escape aiohttp.ClientSession in docstring of session
contextmanager in HttpAsyncHook
* refactor: Also take into account extra_options from connection when
building AsyncHttpSession
* refactor: Fixed mocking of test_run_method_success
* refactor: Removed unused imports
* refactor: Reorganized imports
* refactor: Run method of LivyAsyncHook must internally use session from
HttpAsyncHook so it doesn't rely on the error handling of the HttpAsyncHook run
method
* refactor: Escape reserved words in HttpAsyncHook
* refactor: Mock get_async_connection in TestLivyAsyncHook
* refactor: Mock get_async_connection in TestLivyAsyncHook should be
patched on http hook module
* refactor: Mock get_async_connection in TestLivyAsyncHook should be
patched on http hook module
* refactor: Make sure get_async_connection is mocked with real Connection
* refactor: Reformatted Livy unit test
* refactor: Add get_async_connection mock in
test_run_put_method_with_type_error
* refactor: Make sure http provider dependency is set to next release when
livy provider is released
* refactor: Added TODO on asgiref dependency as it can probably be removed,
since it will be resolved transitively through the common-compat provider
* refactor: Removed asgiref dependency in livy provider
* refactor: Removed asgiref reference from docs
* refactor: Fixed assertion of Connection type in test_build_get_hook of
TestLivyAsyncHook
* refactor: Don't need to assert connections anymore in test_build_get_hook
of TestLivyAsyncHook
---------
Co-authored-by: David Blain <[email protected]>
---
providers/apache/livy/docs/index.rst | 1 -
providers/apache/livy/pyproject.toml | 3 +-
.../airflow/providers/apache/livy/hooks/livy.py | 128 ++------
.../livy/tests/unit/apache/livy/hooks/test_livy.py | 213 ++++++-------
.../http/src/airflow/providers/http/hooks/http.py | 339 +++++++++++++++------
5 files changed, 370 insertions(+), 314 deletions(-)
diff --git a/providers/apache/livy/docs/index.rst
b/providers/apache/livy/docs/index.rst
index 058669470b5..cc3d26534ed 100644
--- a/providers/apache/livy/docs/index.rst
+++ b/providers/apache/livy/docs/index.rst
@@ -103,7 +103,6 @@ PIP package Version required
``apache-airflow-providers-http`` ``>=5.1.0``
``apache-airflow-providers-common-compat`` ``>=1.12.0``
``aiohttp`` ``>=3.9.2``
-``asgiref`` ``>=2.3.0``
========================================== ==================
Cross provider package dependencies
diff --git a/providers/apache/livy/pyproject.toml
b/providers/apache/livy/pyproject.toml
index 3000a45eb05..9a013b15034 100644
--- a/providers/apache/livy/pyproject.toml
+++ b/providers/apache/livy/pyproject.toml
@@ -59,10 +59,9 @@ requires-python = ">=3.10"
# After you modify the dependencies, and rebuild your Breeze CI image with
``breeze ci-image build``
dependencies = [
"apache-airflow>=2.11.0",
- "apache-airflow-providers-http>=5.1.0",
+ "apache-airflow-providers-http>=5.1.0", # use next version
"apache-airflow-providers-common-compat>=1.12.0",
"aiohttp>=3.9.2",
- "asgiref>=2.3.0",
]
[dependency-groups]
diff --git
a/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py
b/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py
index 9de3582de11..e9f8e94c748 100644
--- a/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py
+++ b/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py
@@ -16,24 +16,18 @@
# under the License.
from __future__ import annotations
-import asyncio
import json
import re
from collections.abc import Sequence
from enum import Enum
-from typing import TYPE_CHECKING, Any
+from typing import Any
-import aiohttp
import requests
from aiohttp import ClientResponseError
-from airflow.providers.common.compat.connection import get_async_connection
from airflow.providers.common.compat.sdk import AirflowException
from airflow.providers.http.hooks.http import HttpAsyncHook, HttpHook
-if TYPE_CHECKING:
- from airflow.models import Connection
-
class BatchState(Enum):
"""Batch session states."""
@@ -502,101 +496,10 @@ class LivyAsyncHook(HttpAsyncHook):
self.extra_options = extra_options or {}
self.endpoint_prefix = sanitize_endpoint_prefix(endpoint_prefix)
- async def _do_api_call_async(
- self,
- endpoint: str | None = None,
- data: dict[str, Any] | str | None = None,
- headers: dict[str, Any] | None = None,
- extra_options: dict[str, Any] | None = None,
- ) -> Any:
- """
- Perform an asynchronous HTTP request call.
-
- :param endpoint: the endpoint to be called i.e. resource/v1/query?
- :param data: payload to be uploaded or request parameters
- :param headers: additional headers to be passed through as a dictionary
- :param extra_options: Additional kwargs to pass when creating a
request.
- For example, ``run(json=obj)`` is passed as
``aiohttp.ClientSession().get(json=obj)``
- """
- extra_options = extra_options or {}
-
- # headers may be passed through directly or in the "extra" field in
the connection
- # definition
- _headers = {}
- auth = None
-
- if self.http_conn_id:
- conn = await get_async_connection(self.http_conn_id)
-
- self.base_url = self._generate_base_url(conn) # type:
ignore[arg-type]
- if conn.login:
- auth = self.auth_type(conn.login, conn.password)
- if conn.extra:
- try:
- _headers.update(conn.extra_dejson)
- except TypeError:
- self.log.warning("Connection to %s has invalid extra
field.", conn.host)
- if headers:
- _headers.update(headers)
-
- if self.base_url and not self.base_url.endswith("/") and endpoint and
not endpoint.startswith("/"):
- url = self.base_url + "/" + endpoint
- else:
- url = (self.base_url or "") + (endpoint or "")
-
- async with aiohttp.ClientSession() as session:
- if self.method == "GET":
- request_func = session.get
- elif self.method == "POST":
- request_func = session.post
- elif self.method == "PATCH":
- request_func = session.patch
- else:
- return {"Response": f"Unexpected HTTP Method: {self.method}",
"status": "error"}
-
- for attempt_num in range(1, 1 + self.retry_limit):
- response = await request_func(
- url,
- json=data if self.method in ("POST", "PATCH") else None,
- params=data if self.method == "GET" else None,
- headers=_headers or None,
- auth=auth,
- **extra_options,
- )
- try:
- response.raise_for_status()
- return await response.json()
- except ClientResponseError as e:
- self.log.warning(
- "[Try %d of %d] Request to %s failed.",
- attempt_num,
- self.retry_limit,
- url,
- )
- if not self._retryable_error_async(e) or attempt_num ==
self.retry_limit:
- self.log.exception("HTTP error, status code: %s",
e.status)
- # In this case, the user probably made a mistake.
- # Don't retry.
- return {"Response": {e.message}, "Status Code":
{e.status}, "status": "error"}
-
- await asyncio.sleep(self.retry_delay)
-
- def _generate_base_url(self, conn: Connection) -> str:
- if conn.host and "://" in conn.host:
- base_url: str = conn.host
- else:
- # schema defaults to HTTP
- schema = conn.schema if conn.schema else "http"
- host = conn.host if conn.host else ""
- base_url = f"{schema}://{host}"
- if conn.port:
- base_url = f"{base_url}:{conn.port}"
- return base_url
-
async def run_method(
self,
endpoint: str,
- method: str = "GET",
+ method: str | None = None,
data: Any | None = None,
headers: dict[str, Any] | None = None,
) -> Any:
@@ -609,16 +512,29 @@ class LivyAsyncHook(HttpAsyncHook):
:param headers: headers
:return: http response
"""
- if method not in ("GET", "POST", "PUT", "DELETE", "HEAD"):
+ method = method or self.method
+ if method not in {"GET", "PATCH", "POST", "PUT", "DELETE", "HEAD"}:
return {"status": "error", "response": f"Invalid http method
{method}"}
- back_method = self.method
- self.method = method
+ endpoint = (
+ f"{self.endpoint_prefix}/{endpoint}"
+ if self.endpoint_prefix and endpoint
+ else endpoint or self.endpoint_prefix
+ )
+
try:
- result = await self._do_api_call_async(endpoint, data, headers,
self.extra_options)
- finally:
- self.method = back_method
- return {"status": "success", "response": result}
+ async with self.session() as session:
+ response = await session.run(
+ endpoint=endpoint,
+ data=data,
+ headers={**self._def_headers, **self.extra_headers,
**(headers or {})},
+ extra_options=self.extra_options,
+ )
+
+ result = await response.json()
+ return {"status": "success", "response": result}
+ except ClientResponseError as e:
+ return {"Response": {e.message}, "Status Code": {e.status},
"status": "error"}
async def get_batch_state(self, session_id: int | str) -> Any:
"""
diff --git a/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py
b/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py
index 3a819c38064..90bdcad8866 100644
--- a/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py
+++ b/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py
@@ -592,159 +592,163 @@ class TestLivyAsyncHook:
assert log_dump == {"id": 1, "log": ["mock_log_1", "mock_log_2"]}
@pytest.mark.asyncio
-
@mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook._do_api_call_async")
- async def test_run_method_success(self, mock_do_api_call_async):
+ @mock.patch("airflow.providers.http.hooks.http.aiohttp.ClientSession")
+ @mock.patch(
+ "airflow.providers.common.compat.connection.get_async_connection",
+ return_value=Connection(
+ conn_id=LIVY_CONN_ID,
+ conn_type="http",
+ host="http://host",
+ port=80,
+ ),
+ )
+ async def test_run_method_success(self, mock_get_connection, mock_session):
"""Asserts the run_method for success response."""
- mock_do_api_call_async.return_value = {"status": "error", "response":
{"id": 1}}
+ mock_session.return_value.__aenter__.return_value.post = AsyncMock()
+
mock_session.return_value.__aenter__.return_value.post.return_value.json =
AsyncMock(
+ return_value={"id": 1}
+ )
hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID)
response = await hook.run_method("localhost", "GET")
- assert response["status"] == "success"
+ assert response == {"status": "success", "response": {"id": 1}}
@pytest.mark.asyncio
-
@mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook._do_api_call_async")
- async def test_run_method_error(self, mock_do_api_call_async):
+ async def test_run_method_error(self):
"""Asserts the run_method for error response."""
- mock_do_api_call_async.return_value = {"status": "error", "response":
{"id": 1}}
hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID)
response = await hook.run_method("localhost", "abc")
assert response == {"status": "error", "response": "Invalid http
method abc"}
@pytest.mark.asyncio
-
@mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession")
-
@mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection")
- async def test_do_api_call_async_post_method_with_success(self,
mock_get_connection, mock_session):
- """Asserts the _do_api_call_async for success response for POST
method."""
-
- async def mock_fun(arg1, arg2, arg3, arg4):
- return {"status": "success"}
-
- mock_session.return_value.__aexit__.return_value = mock_fun
+ @mock.patch("airflow.providers.http.hooks.http.aiohttp.ClientSession")
+ @mock.patch(
+ "airflow.providers.common.compat.connection.get_async_connection",
+ return_value=Connection(
+ conn_id=LIVY_CONN_ID,
+ conn_type="http",
+ host="http://host",
+ port=80,
+ ),
+ )
+ async def test_run_post_method_with_success(self, mock_get_connection,
mock_session):
+ """Asserts the run_method for success response for POST method."""
mock_session.return_value.__aenter__.return_value.post = AsyncMock()
mock_session.return_value.__aenter__.return_value.post.return_value.json =
AsyncMock(
- return_value={"status": "success"}
+ return_value={"hello": "world"}
)
- GET_RUN_ENDPOINT = "api/jobs/runs/get"
hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID)
- hook.http_conn_id = mock_get_connection
- hook.http_conn_id.host = "https://localhost"
- hook.http_conn_id.login = "login"
- hook.http_conn_id.password = "PASSWORD"
- response = await hook._do_api_call_async(GET_RUN_ENDPOINT)
- assert response == {"status": "success"}
+ response = await hook.run_method("api/jobs/runs/get")
+ assert response["status"] == "success"
+ assert response["response"] == {"hello": "world"}
@pytest.mark.asyncio
-
@mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession")
-
@mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection")
- async def test_do_api_call_async_get_method_with_success(self,
mock_get_connection, mock_session):
- """Asserts the _do_api_call_async for GET method."""
-
- async def mock_fun(arg1, arg2, arg3, arg4):
- return {"status": "success"}
-
- mock_session.return_value.__aexit__.return_value = mock_fun
+ @mock.patch("airflow.providers.http.hooks.http.aiohttp.ClientSession")
+ @mock.patch(
+ "airflow.providers.common.compat.connection.get_async_connection",
+ return_value=Connection(
+ conn_id=LIVY_CONN_ID,
+ conn_type="http",
+ host="http://host",
+ port=80,
+ ),
+ )
+ async def test_run_get_method_with_success(self, mock_get_connection,
mock_session):
+ """Asserts the run_method for GET method."""
mock_session.return_value.__aenter__.return_value.get = AsyncMock()
mock_session.return_value.__aenter__.return_value.get.return_value.json =
AsyncMock(
- return_value={"status": "success"}
+ return_value={"hello": "world"}
)
- GET_RUN_ENDPOINT = "api/jobs/runs/get"
hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID)
hook.method = "GET"
- hook.http_conn_id = mock_get_connection
- hook.http_conn_id.host = "test.com"
- hook.http_conn_id.login = "login"
- hook.http_conn_id.password = "PASSWORD"
- hook.http_conn_id.extra_dejson = ""
- response = await hook._do_api_call_async(GET_RUN_ENDPOINT)
- assert response == {"status": "success"}
+ response = await hook.run_method("api/jobs/runs/get")
+ assert response["status"] == "success"
+ assert response["response"] == {"hello": "world"}
@pytest.mark.asyncio
-
@mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession")
-
@mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection")
- async def test_do_api_call_async_patch_method_with_success(self,
mock_get_connection, mock_session):
- """Asserts the _do_api_call_async for PATCH method."""
-
- async def mock_fun(arg1, arg2, arg3, arg4):
- return {"status": "success"}
-
- mock_session.return_value.__aexit__.return_value = mock_fun
+ @mock.patch("airflow.providers.http.hooks.http.aiohttp.ClientSession")
+ @mock.patch(
+ "airflow.providers.common.compat.connection.get_async_connection",
+ return_value=Connection(
+ conn_id=LIVY_CONN_ID,
+ conn_type="http",
+ host="http://host",
+ port=80,
+ ),
+ )
+ async def test_run_patch_method_with_success(self, mock_get_connection,
mock_session):
+ """Asserts the run_method for PATCH method."""
mock_session.return_value.__aenter__.return_value.patch = AsyncMock()
mock_session.return_value.__aenter__.return_value.patch.return_value.json =
AsyncMock(
- return_value={"status": "success"}
+ return_value={"hello": "world"}
)
- GET_RUN_ENDPOINT = "api/jobs/runs/get"
hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID)
hook.method = "PATCH"
- hook.http_conn_id = mock_get_connection
- hook.http_conn_id.host = "test.com"
- hook.http_conn_id.login = "login"
- hook.http_conn_id.password = "PASSWORD"
- hook.http_conn_id.extra_dejson = ""
- response = await hook._do_api_call_async(GET_RUN_ENDPOINT)
- assert response == {"status": "success"}
+ response = await hook.run_method("api/jobs/runs/get")
+ assert response["status"] == "success"
+ assert response["response"] == {"hello": "world"}
@pytest.mark.asyncio
-
@mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession")
-
@mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection")
- async def test_do_api_call_async_unexpected_method_error(self,
mock_get_connection, mock_session):
- """Asserts the _do_api_call_async for unexpected method error"""
- GET_RUN_ENDPOINT = "api/jobs/runs/get"
+ @mock.patch(
+ "airflow.providers.common.compat.connection.get_async_connection",
+ return_value=Connection(
+ conn_id=LIVY_CONN_ID,
+ conn_type="http",
+ host="http://host",
+ port=80,
+ ),
+ )
+ async def test_run_unexpected_method_with_success(self,
mock_get_connection):
+ """Asserts the run_method for unexpected method error"""
hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID)
hook.method = "abc"
- hook.http_conn_id = mock_get_connection
- hook.http_conn_id.host = "test.com"
- hook.http_conn_id.login = "login"
- hook.http_conn_id.password = "PASSWORD"
- hook.http_conn_id.extra_dejson = ""
- response = await hook._do_api_call_async(endpoint=GET_RUN_ENDPOINT,
headers={})
- assert response == {"Response": "Unexpected HTTP Method: abc",
"status": "error"}
+ response = await hook.run_method(endpoint="api/jobs/runs/get",
headers={})
+ assert response == {"response": "Invalid http method abc", "status":
"error"}
@pytest.mark.asyncio
-
@mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession")
-
@mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection")
- async def test_do_api_call_async_with_type_error(self,
mock_get_connection, mock_session):
- """Asserts the _do_api_call_async for TypeError."""
+ @mock.patch(
+ "airflow.providers.common.compat.connection.get_async_connection",
+ return_value=Connection(
+ conn_id=LIVY_CONN_ID,
+ conn_type="http",
+ host="http://host",
+ port=80,
+ ),
+ )
+ async def test_run_put_method_with_type_error(self, mock_get_connection):
+ """Asserts the run_method for TypeError."""
async def mock_fun(arg1, arg2, arg3, arg4):
return {"random value"}
- mock_session.return_value.__aexit__.return_value = mock_fun
-
mock_session.return_value.__aenter__.return_value.patch.return_value.json.return_value
= {}
hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID)
hook.method = "PATCH"
- hook.retry_limit = 1
- hook.retry_delay = 1
- hook.http_conn_id = mock_get_connection
with pytest.raises(TypeError):
- await hook._do_api_call_async(endpoint="", data="test",
headers=mock_fun, extra_options=mock_fun)
+ await hook.run_method(endpoint="api/jobs/runs/get", data="test",
headers=mock_fun)
@pytest.mark.asyncio
-
@mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession")
-
@mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection")
- async def test_do_api_call_async_with_client_response_error(self,
mock_get_connection, mock_session):
- """Asserts the _do_api_call_async for Client Response Error."""
-
- async def mock_fun(arg1, arg2, arg3, arg4):
- return {"random value"}
+ @mock.patch("airflow.providers.http.hooks.http.aiohttp.ClientSession")
+ @mock.patch(
+ "airflow.providers.common.compat.connection.get_async_connection",
+ return_value=Connection(
+ conn_id=LIVY_CONN_ID,
+ conn_type="http",
+ host="http://host",
+ port=80,
+ ),
+ )
+ async def test_run_method_with_client_response_error(self,
mock_get_connection, mock_session):
+ """Asserts the run_method for Client Response Error."""
- mock_session.return_value.__aexit__.return_value = mock_fun
- mock_session.return_value.__aenter__.return_value.patch = AsyncMock()
-
mock_session.return_value.__aenter__.return_value.patch.return_value.json.side_effect
= (
- ClientResponseError(
+ mock_session.return_value.__aenter__.return_value.patch = AsyncMock(
+ side_effect=ClientResponseError(
request_info=RequestInfo(url="example.com", method="PATCH",
headers=multidict.CIMultiDict()),
status=500,
history=[],
)
)
- GET_RUN_ENDPOINT = ""
hook = LivyAsyncHook(livy_conn_id="livy_default")
hook.method = "PATCH"
- hook.base_url = ""
- hook.http_conn_id = mock_get_connection
- hook.http_conn_id.host = "test.com"
- hook.http_conn_id.login = "login"
- hook.http_conn_id.password = "PASSWORD"
- hook.http_conn_id.extra_dejson = ""
- response = await hook._do_api_call_async(GET_RUN_ENDPOINT)
+ response = await hook.run_method("")
assert response["status"] == "error"
@pytest.fixture
@@ -764,7 +768,8 @@ class TestLivyAsyncHook:
create_connection_without_db(Connection(conn_id="missing_host",
conn_type="http", port=1234))
create_connection_without_db(Connection(conn_id="invalid_uri",
uri="http://invalid_uri:4321"))
- def test_build_get_hook(self, setup_livy_conn):
+ @pytest.mark.asyncio
+ async def test_build_get_hook(self, setup_livy_conn):
connection_url_mapping = {
# id, expected
"default_port": "http://host",
@@ -776,8 +781,8 @@ class TestLivyAsyncHook:
for conn_id, expected in connection_url_mapping.items():
hook = LivyAsyncHook(livy_conn_id=conn_id)
- response_conn = hook.get_connection(conn_id=conn_id)
- assert hook._generate_base_url(response_conn) == expected
+ async with hook.session() as session:
+ assert session.base_url == expected
def test_build_body(self):
# minimal request
diff --git a/providers/http/src/airflow/providers/http/hooks/http.py
b/providers/http/src/airflow/providers/http/hooks/http.py
index ed137a651c4..240bd24d06c 100644
--- a/providers/http/src/airflow/providers/http/hooks/http.py
+++ b/providers/http/src/airflow/providers/http/hooks/http.py
@@ -18,22 +18,25 @@
from __future__ import annotations
import copy
-from collections.abc import Callable
+from collections.abc import AsyncGenerator, Awaitable, Callable
+from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any, cast
from urllib.parse import urlparse
import aiohttp
import tenacity
from aiohttp import ClientResponseError
+from pydantic import BaseModel
from requests import PreparedRequest, Request, Response, Session
from requests.auth import HTTPBasicAuth
from requests.exceptions import ConnectionError, HTTPError
from requests.models import DEFAULT_REDIRECT_LIMIT
from requests_toolbelt.adapters.socket_options import TCPKeepAliveAdapter
+from tenacity import retry_if_exception
-from airflow.providers.common.compat.connection import get_async_connection
from airflow.providers.common.compat.sdk import AirflowException, BaseHook
from airflow.providers.http.exceptions import HttpErrorException,
HttpMethodException
+from airflow.utils.log.logging_mixin import LoggingMixin
if TYPE_CHECKING:
from aiohttp.client_reqrep import ClientResponse
@@ -95,6 +98,28 @@ def _process_extra_options_from_connection(
return conn_extra_options, passed_extra_options
+def _retryable_error_async(exception: BaseException) -> bool:
+ """
+ Determine whether an exception may successful on a subsequent attempt.
+
+ It considers the following to be retryable:
+ - requests_exceptions.ConnectionError
+ - requests_exceptions.Timeout
+ - anything with a status code >= 500
+
+ Most retryable errors are covered by status code >= 500.
+ """
+ if not isinstance(exception, ClientResponseError):
+ return False
+ if exception.status == 429:
+ # don't retry for too Many Requests
+ return False
+ if exception.status == 413:
+ # don't retry for payload Too Large
+ return False
+ return exception.status >= 500
+
+
class HttpHook(BaseHook):
"""
Interact with HTTP servers.
@@ -399,6 +424,132 @@ class HttpHook(BaseHook):
return False, str(e)
+class SessionConfig(BaseModel):
+ """Configuration container for an asynchronous HTTP session."""
+
+ base_url: str
+ headers: dict[str, Any] | None = None
+ auth: aiohttp.BasicAuth | None = None
+ extra_options: dict[str, Any] | None = None
+
+
+class AsyncHttpSession(LoggingMixin):
+ """
+ Wrapper around an ``aiohttp.ClientSession`` providing a session bound
``HttpAsyncHook``.
+
+ This class binds an asynchronous HTTP client session to an
``HttpAsyncHook`` and applies connection
+ configuration, authentication, headers, and retry logic consistently
across requests. A single
+ ``AsyncHttpSession`` instance is intended to be used for multiple HTTP
calls within the same logical session.
+
+ :param hook: The ``HttpAsyncHook`` instance that owns this session and
provides connection-level behavior
+ such as retries and logging.
+ :param request: A callable used to perform the underlying HTTP request.
This is typically a bound
+ ``aiohttp.ClientSession`` request method.
+ :param config: Resolved session configuration containing base URL,
headers, and authentication settings.
+ """
+
+ def __init__(
+ self,
+ hook: HttpAsyncHook,
+ request: Callable[..., Awaitable[ClientResponse]],
+ config: SessionConfig,
+ ) -> None:
+ super().__init__()
+ self._hook = hook
+ self._request = request
+ self.config = config
+
+ @property
+ def http_conn_id(self) -> str:
+ return self._hook.http_conn_id
+
+ @property
+ def base_url(self) -> str:
+ return self.config.base_url
+
+ @property
+ def method(self) -> str:
+ return self._hook.method
+
+ @property
+ def retry_limit(self) -> int:
+ return self._hook.retry_limit
+
+ @property
+ def retry_delay(self) -> float:
+ return self._hook.retry_delay
+
+ @property
+ def headers(self) -> dict[str, Any] | None:
+ return self.config.headers
+
+ @property
+ def extra_options(self) -> dict[str, Any] | None:
+ return self.config.extra_options
+
+ @property
+ def auth(self) -> aiohttp.BasicAuth | None:
+ return self.config.auth
+
+ async def run(
+ self,
+ endpoint: str | None = None,
+ data: dict[str, Any] | str | None = None,
+ json: dict[str, Any] | str | None = None,
+ headers: dict[str, Any] | None = None,
+ extra_options: dict[str, Any] | None = None,
+ ) -> ClientResponse:
+ """
+ Perform an asynchronous HTTP request call.
+
+ :param endpoint: Endpoint to be called, i.e. ``resource/v1/query?``.
+ :param data: Payload to be uploaded or request parameters.
+ :param json: Payload to be uploaded as JSON.
+ :param headers: Additional headers to be passed through as a dict.
+ :param extra_options: Additional kwargs to pass when creating a
request.
+ For example, ``run(json=obj)`` is passed as
+ ``aiohttp.ClientSession().get(json=obj)``.
+ """
+ from tenacity import AsyncRetrying, stop_after_attempt, wait_fixed
+
+ url = _url_from_endpoint(self.base_url, endpoint)
+ merged_headers = {**(self.headers or {}), **(headers or {})}
+ extra_options = {**(self.extra_options or {}), **(extra_options or {})}
+
+ async def request_func() -> ClientResponse:
+ response = await self._request(
+ url,
+ params=data if self.method == "GET" else None,
+ data=data if self.method in {"POST", "PUT", "PATCH"} else None,
+ json=json,
+ headers=merged_headers,
+ auth=self.auth,
+ **extra_options,
+ )
+ response.raise_for_status()
+ return response
+
+ async for attempt in AsyncRetrying(
+ stop=stop_after_attempt(self.retry_limit),
+ wait=wait_fixed(self.retry_delay),
+ retry=retry_if_exception(_retryable_error_async),
+ reraise=True,
+ ):
+ with attempt:
+ try:
+ return await request_func()
+ except ClientResponseError as e:
+ self.log.warning(
+ "[Try %d of %d] Request to %s failed.",
+ attempt.retry_state.attempt_number,
+ self.retry_limit,
+ url,
+ )
+ raise e
+
+ raise NotImplementedError # should not reach this, but makes mypy
happy
+
+
class HttpAsyncHook(BaseHook):
"""
Interact with HTTP servers asynchronously.
@@ -408,6 +559,8 @@ class HttpAsyncHook(BaseHook):
API url i.e https://www.google.com/ and optional authentication
credentials. Default
headers can also be specified in the Extra field in json format.
:param auth_type: The auth type for the service
+ :param retry_limit: Maximum number of times to retry this job if it fails
(default is 3)
+ :param retry_delay: Delay between retry attempts (default is 1.0)
"""
conn_name_attr = "http_conn_id"
@@ -429,13 +582,82 @@ class HttpAsyncHook(BaseHook):
self._retry_obj: Callable[..., Any]
self.auth_type: Any = auth_type
if retry_limit < 1:
- raise ValueError("Retry limit must be greater than equal to 1")
+ raise ValueError("Retry limit must be greater or equal to 1")
self.retry_limit = retry_limit
self.retry_delay = retry_delay
+ self._config: SessionConfig | None = None
+
+ def _get_request_func(self, session: aiohttp.ClientSession) ->
Callable[..., Any]:
+ method = self.method
+ if method == "GET":
+ return session.get
+ if method == "POST":
+ return session.post
+ if method == "PATCH":
+ return session.patch
+ if method == "HEAD":
+ return session.head
+ if method == "PUT":
+ return session.put
+ if method == "DELETE":
+ return session.delete
+ if method == "OPTIONS":
+ return session.options
+ raise HttpMethodException(f"Unexpected HTTP Method: {method}")
+
+ async def config(self) -> SessionConfig:
+ if not self._config:
+ from airflow.providers.common.compat.connection import
get_async_connection
+
+ base_url: str = self.base_url
+ auth: aiohttp.BasicAuth | None = None
+ headers: dict[str, Any] = {}
+ extra_options: dict[str, Any] = {}
+
+ if self.http_conn_id:
+ conn = await get_async_connection(conn_id=self.http_conn_id)
+
+ if conn.host and "://" in conn.host:
+ base_url = conn.host
+ else:
+ schema = conn.schema or "http"
+ base_url = f"{schema}://{conn.host or ''}"
+
+ if conn.port:
+ base_url += f":{conn.port}"
+
+ if conn.login:
+ auth = self.auth_type(conn.login, conn.password)
+
+ if conn.extra:
+ conn_extra_options, extra_options =
_process_extra_options_from_connection(
+ conn=conn, extra_options={}
+ )
+ headers.update(conn_extra_options)
+
+ self._config = SessionConfig(
+ base_url=base_url,
+ headers=headers,
+ auth=auth,
+ extra_options=extra_options,
+ )
+ return self._config
+
+ @asynccontextmanager
+ async def session(self) -> AsyncGenerator[AsyncHttpSession, None]:
+ """
+ Create an ``AsyncHttpSession`` bound to a single
``aiohttp.ClientSession``.
+
+ Airflow connection resolution happens exactly once here.
+ """
+ async with aiohttp.ClientSession() as session:
+ request = self._get_request_func(session=session)
+ config = await self.config()
+ yield AsyncHttpSession(hook=self, request=request, config=config)
async def run(
self,
- session: aiohttp.ClientSession,
+ session: aiohttp.ClientSession | None = None,
endpoint: str | None = None,
data: dict[str, Any] | str | None = None,
json: dict[str, Any] | str | None = None,
@@ -445,6 +667,7 @@ class HttpAsyncHook(BaseHook):
"""
Perform an asynchronous HTTP request call.
+ :param session: ``aiohttp.ClientSession``
:param endpoint: Endpoint to be called, i.e. ``resource/v1/query?``.
:param data: Payload to be uploaded or request parameters.
:param json: Payload to be uploaded as JSON.
@@ -453,103 +676,17 @@ class HttpAsyncHook(BaseHook):
For example, ``run(json=obj)`` is passed as
``aiohttp.ClientSession().get(json=obj)``.
"""
- extra_options = extra_options or {}
-
- # headers may be passed through directly or in the "extra" field in
the connection
- # definition
- _headers = {}
- auth = None
-
- if self.http_conn_id:
- conn = await get_async_connection(self.http_conn_id)
-
- if conn.host and "://" in conn.host:
- self.base_url = conn.host
- else:
- # schema defaults to HTTP
- schema = conn.schema if conn.schema else "http"
- host = conn.host if conn.host else ""
- self.base_url = schema + "://" + host
-
- if conn.port:
- self.base_url += f":{conn.port}"
- if conn.login:
- auth = self.auth_type(conn.login, conn.password)
- if conn.extra:
- conn_extra_options, extra_options =
_process_extra_options_from_connection(
- conn=conn, extra_options=extra_options
+ try:
+ if session is not None:
+ request = self._get_request_func(session=session)
+ config = await self.config()
+ return await AsyncHttpSession(hook=self, request=request,
config=config).run(
+ endpoint=endpoint, data=data, json=json, headers=headers,
extra_options=extra_options
)
- try:
- _headers.update(conn_extra_options)
- except TypeError:
- self.log.warning("Connection to %s has invalid extra
field.", conn.host)
- if headers:
- _headers.update(headers)
-
- url = _url_from_endpoint(self.base_url, endpoint)
-
- if self.method == "GET":
- request_func = session.get
- elif self.method == "POST":
- request_func = session.post
- elif self.method == "PATCH":
- request_func = session.patch
- elif self.method == "HEAD":
- request_func = session.head
- elif self.method == "PUT":
- request_func = session.put
- elif self.method == "DELETE":
- request_func = session.delete
- elif self.method == "OPTIONS":
- request_func = session.options
- else:
- raise HttpMethodException(f"Unexpected HTTP Method: {self.method}")
-
- for attempt in range(1, 1 + self.retry_limit):
- response = await request_func(
- url,
- params=data if self.method == "GET" else None,
- data=data if self.method in ("POST", "PUT", "PATCH") else None,
- json=json,
- headers=_headers,
- auth=auth,
- **extra_options,
- )
- try:
- response.raise_for_status()
- except ClientResponseError as e:
- self.log.warning(
- "[Try %d of %d] Request to %s failed.",
- attempt,
- self.retry_limit,
- url,
+ async with self.session() as http:
+ return await http.run(
+ endpoint=endpoint, data=data, json=json, headers=headers,
extra_options=extra_options
)
- if not self._retryable_error_async(e) or attempt ==
self.retry_limit:
- self.log.exception("HTTP error with status: %s", e.status)
- # In this case, the user probably made a mistake.
- # Don't retry.
- raise HttpErrorException(f"{e.status}:{e.message}")
- else:
- return response
-
- raise NotImplementedError # should not reach this, but makes mypy
happy
-
- def _retryable_error_async(self, exception: ClientResponseError) -> bool:
- """
- Determine whether an exception may successful on a subsequent attempt.
-
- It considers the following to be retryable:
- - requests_exceptions.ConnectionError
- - requests_exceptions.Timeout
- - anything with a status code >= 500
-
- Most retryable errors are covered by status code >= 500.
- """
- if exception.status == 429:
- # don't retry for too Many Requests
- return False
- if exception.status == 413:
- # don't retry for payload Too Large
- return False
- return exception.status >= 500
+ except ClientResponseError as e:
+ raise HttpErrorException(f"{e.status}:{e.message}")