This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 47edfe9a22 Add HttpHookAsync for deferrable implementation (#29038)
47edfe9a22 is described below
commit 47edfe9a22d1c521e49de3bed87bc332a48c0a80
Author: Ankit Chaurasia <[email protected]>
AuthorDate: Tue Feb 14 19:21:14 2023 +0545
Add HttpHookAsync for deferrable implementation (#29038)
This PR donates HttpAsyncHook, a hook that interacts with HTTP servers using
Python async. It was originally developed in the astronomer-providers repo and
is being contributed to Apache Airflow.
---
airflow/providers/http/hooks/http.py | 163 +++++++++++++++++++++++++++++++-
airflow/providers/http/provider.yaml | 2 +
generated/provider_dependencies.json | 2 +
setup.py | 1 +
tests/providers/http/hooks/test_http.py | 95 ++++++++++++++++++-
5 files changed, 261 insertions(+), 2 deletions(-)
diff --git a/airflow/providers/http/hooks/http.py b/airflow/providers/http/hooks/http.py
index 6c37818ccb..235a7f8a2e 100644
--- a/airflow/providers/http/hooks/http.py
+++ b/airflow/providers/http/hooks/http.py
@@ -17,16 +17,23 @@
# under the License.
from __future__ import annotations
-from typing import Any, Callable
+import asyncio
+from typing import TYPE_CHECKING, Any, Callable
+import aiohttp
import requests
import tenacity
+from aiohttp import ClientResponseError
+from asgiref.sync import sync_to_async
from requests.auth import HTTPBasicAuth
from requests_toolbelt.adapters.socket_options import TCPKeepAliveAdapter
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
+if TYPE_CHECKING:
+ from aiohttp.client_reqrep import ClientResponse
+
class HttpHook(BaseHook):
"""
@@ -246,3 +253,157 @@ class HttpHook(BaseHook):
return True, "Connection successfully tested"
except Exception as e:
return False, str(e)
+
+
+class HttpAsyncHook(BaseHook):
+    """
+    Interact with HTTP servers using Python Async.
+
+    :param method: the API method to be called
+    :param http_conn_id: http connection id that has the base
+        API url i.e https://www.google.com/ and optional authentication credentials. Default
+        headers can also be specified in the Extra field in json format.
+    :param auth_type: The auth type for the service; it is called as
+        ``auth_type(login, password)`` when the connection defines a login
+    :param retry_limit: maximum number of attempts for a request that fails with a
+        retryable status code; must be at least 1
+    :param retry_delay: number of seconds to wait between two consecutive attempts
+    """
+
+    conn_name_attr = "http_conn_id"
+    default_conn_name = "http_default"
+    conn_type = "http"
+    hook_name = "HTTP"
+
+    def __init__(
+        self,
+        method: str = "POST",
+        http_conn_id: str = default_conn_name,
+        auth_type: Any = aiohttp.BasicAuth,
+        retry_limit: int = 3,
+        retry_delay: float = 1.0,
+    ) -> None:
+        super().__init__()
+        self.http_conn_id = http_conn_id
+        self.method = method.upper()
+        # Populated from the connection inside ``run`` (when ``http_conn_id`` is set).
+        self.base_url: str = ""
+        self.auth_type: Any = auth_type
+        if retry_limit < 1:
+            raise ValueError("Retry limit must be greater than equal to 1")
+        self.retry_limit = retry_limit
+        self.retry_delay = retry_delay
+
+    async def run(
+        self,
+        endpoint: str | None = None,
+        data: dict[str, Any] | str | None = None,
+        headers: dict[str, Any] | None = None,
+        extra_options: dict[str, Any] | None = None,
+    ) -> "ClientResponse":
+        r"""
+        Performs an asynchronous HTTP request call.
+
+        :param endpoint: the endpoint to be called i.e. resource/v1/query?
+        :param data: payload to be uploaded or request parameters. It is sent as the
+            JSON body for POST, PUT and PATCH requests and as query parameters for GET.
+        :param headers: additional headers to be passed through as a dictionary; they are
+            merged over the headers defined in the connection's Extra field
+        :param extra_options: Additional keyword arguments forwarded to the aiohttp request
+            call, i.e. ``session.get(url, ..., **extra_options)``. It must not contain
+            ``json``, ``params``, ``headers`` or ``auth``, which this method already sets.
+        :return: the successful response. Its body has already been read, so it can still
+            be consumed (``read()``/``text()``/``json()``) after the client session closes.
+        :raises AirflowException: if the HTTP method is not supported, if the request
+            fails with a non-retryable status, or if all ``retry_limit`` attempts fail
+        """
+        extra_options = extra_options or {}
+
+        # headers may be passed through directly or in the "extra" field in the connection
+        # definition
+        _headers = {}
+        auth = None
+
+        if self.http_conn_id:
+            conn = await sync_to_async(self.get_connection)(self.http_conn_id)
+
+            if conn.host and "://" in conn.host:
+                self.base_url = conn.host
+            else:
+                # schema defaults to HTTP
+                schema = conn.schema if conn.schema else "http"
+                host = conn.host if conn.host else ""
+                self.base_url = schema + "://" + host
+
+            if conn.port:
+                self.base_url = self.base_url + ":" + str(conn.port)
+            if conn.login:
+                auth = self.auth_type(conn.login, conn.password)
+            if conn.extra:
+                try:
+                    _headers.update(conn.extra_dejson)
+                except TypeError:
+                    self.log.warning("Connection to %s has invalid extra field.", conn.host)
+        if headers:
+            # Caller-supplied headers take precedence over the connection's ones.
+            _headers.update(headers)
+
+        if self.base_url and not self.base_url.endswith("/") and endpoint and not endpoint.startswith("/"):
+            url = self.base_url + "/" + endpoint
+        else:
+            url = (self.base_url or "") + (endpoint or "")
+
+        async with aiohttp.ClientSession() as session:
+            if self.method == "GET":
+                request_func = session.get
+            elif self.method == "POST":
+                request_func = session.post
+            elif self.method == "PATCH":
+                request_func = session.patch
+            elif self.method == "HEAD":
+                request_func = session.head
+            elif self.method == "PUT":
+                request_func = session.put
+            elif self.method == "DELETE":
+                request_func = session.delete
+            elif self.method == "OPTIONS":
+                request_func = session.options
+            else:
+                raise AirflowException(f"Unexpected HTTP Method: {self.method}")
+
+            attempt_num = 1
+            while True:
+                response = await request_func(
+                    url,
+                    json=data if self.method in ("POST", "PUT", "PATCH") else None,
+                    params=data if self.method == "GET" else None,
+                    # Send the merged connection + caller headers, not only the caller's.
+                    headers=_headers,
+                    auth=auth,
+                    **extra_options,
+                )
+                try:
+                    response.raise_for_status()
+                except ClientResponseError as e:
+                    self.log.warning(
+                        "[Try %d of %d] Request to %s failed.",
+                        attempt_num,
+                        self.retry_limit,
+                        url,
+                    )
+                    if not self._retryable_error_async(e) or attempt_num == self.retry_limit:
+                        self.log.exception("HTTP error with status: %s", e.status)
+                        # Either the error is not worth retrying (typically a client-side
+                        # mistake) or every allowed attempt has been used: give up.
+                        raise AirflowException(f"{e.status}:{e.message}") from e
+
+                    attempt_num += 1
+                    await asyncio.sleep(self.retry_delay)
+                else:
+                    # Buffer the body while the session (and its connection) is still
+                    # open; the response object is returned after the session closes.
+                    await response.read()
+                    return response
+
+    def _retryable_error_async(self, exception: ClientResponseError) -> bool:
+        """
+        Determine whether a failed response might succeed on a subsequent attempt.
+
+        Only responses with a status code >= 500 are considered retryable. Every
+        4xx status, including 429 Too Many Requests and 413 Payload Too Large, is
+        treated as final.
+
+        Connection errors and timeouts are not ``ClientResponseError`` instances and
+        are never passed to this method, so they are not retried.
+        """
+        return exception.status >= 500
diff --git a/airflow/providers/http/provider.yaml b/airflow/providers/http/provider.yaml
index 4c9f6c3e77..3294ac56ec 100644
--- a/airflow/providers/http/provider.yaml
+++ b/airflow/providers/http/provider.yaml
@@ -42,6 +42,8 @@ dependencies:
# release it as a requirement for airflow
- requests>=2.26.0
- requests_toolbelt
+ - aiohttp
+ - asgiref
integrations:
- integration-name: Hypertext Transfer Protocol (HTTP)
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index b3fde3a3a8..4c3c62fe9a 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -415,6 +415,8 @@
},
"http": {
"deps": [
+ "aiohttp",
+ "asgiref",
"requests>=2.26.0",
"requests_toolbelt"
],
diff --git a/setup.py b/setup.py
index 59397c1196..450062cb7f 100644
--- a/setup.py
+++ b/setup.py
@@ -404,6 +404,7 @@ devel_only = [
"twine",
"wheel",
"yamllint",
+ "aioresponses",
]
diff --git a/tests/providers/http/hooks/test_http.py b/tests/providers/http/hooks/test_http.py
index 03bf09075c..820c45c5a3 100644
--- a/tests/providers/http/hooks/test_http.py
+++ b/tests/providers/http/hooks/test_http.py
@@ -18,6 +18,7 @@
from __future__ import annotations
import json
+import logging
import os
from collections import OrderedDict
from http import HTTPStatus
@@ -26,11 +27,12 @@ from unittest import mock
import pytest
import requests
import tenacity
+from aioresponses import aioresponses
from requests.adapters import Response
from airflow.exceptions import AirflowException
from airflow.models import Connection
-from airflow.providers.http.hooks.http import HttpHook
+from airflow.providers.http.hooks.http import HttpAsyncHook, HttpHook
def get_airflow_connection(unused_conn_id=None):
@@ -392,3 +394,94 @@ class TestKeepAlive:
send_email_test = mock.Mock()
+
+
+@pytest.fixture
+def aioresponse():
+    """
+    Yield an ``aioresponses`` mocker that intercepts aiohttp requests during a test.
+
+    This comes from a mock library specific to the aiohttp package:
+    https://github.com/pnuckowski/aioresponses
+    """
+    with aioresponses() as async_response:
+        yield async_response
+
+
+@pytest.mark.asyncio
+async def test_do_api_call_async_non_retryable_error(aioresponse):
+    """Test api call asynchronously with non retryable error."""
+    hook = HttpAsyncHook(method="GET")
+    # 400 is below 500, so the hook must fail on the first attempt.
+    aioresponse.get("http://httpbin.org/non_existent_endpoint", status=400)
+
+    with pytest.raises(AirflowException) as exc, mock.patch.dict(
+        "os.environ",
+        AIRFLOW_CONN_HTTP_DEFAULT="http://httpbin.org/",
+    ):
+        await hook.run(endpoint="non_existent_endpoint")
+
+    assert str(exc.value) == "400:Bad Request"
+
+
+@pytest.mark.asyncio
+async def test_do_api_call_async_retryable_error(caplog, aioresponse):
+    """Test api call asynchronously with retryable error."""
+    caplog.set_level(logging.WARNING, logger="airflow.providers.http.hooks.http")
+    hook = HttpAsyncHook(method="GET")
+    # Every attempt must get a 500; without a registered mock, aioresponses raises a
+    # connection error instead of returning a retryable HTTP status.
+    aioresponse.get("http://httpbin.org/non_existent_endpoint", status=500, repeat=True)
+
+    with pytest.raises(AirflowException) as exc, mock.patch.dict(
+        "os.environ",
+        AIRFLOW_CONN_HTTP_DEFAULT="http://httpbin.org/",
+    ):
+        await hook.run(endpoint="non_existent_endpoint")
+
+    assert str(exc.value) == "500:Internal Server Error"
+    assert "[Try 3 of 3] Request to http://httpbin.org/non_existent_endpoint failed" in caplog.text
+
+
+@pytest.mark.asyncio
+async def test_do_api_call_async_unknown_method():
+    """Test api call asynchronously for unknown method."""
+    hook = HttpAsyncHook(method="NOPE")
+    json = {
+        "existing_cluster_id": "xxxx-xxxxxx-xxxxxx",
+    }
+
+    # The connection is resolved before the method is validated, so define it here
+    # instead of relying on an ``http_default`` row being present in the metadata DB.
+    with pytest.raises(AirflowException) as exc, mock.patch.dict(
+        "os.environ",
+        AIRFLOW_CONN_HTTP_DEFAULT="http://httpbin.org/",
+    ):
+        await hook.run(endpoint="non_existent_endpoint", data=json)
+
+    assert str(exc.value) == "Unexpected HTTP Method: NOPE"
+
+
+@pytest.mark.asyncio
+async def test_async_post_request(aioresponse):
+    """Test api call asynchronously for POST request."""
+    hook = HttpAsyncHook()
+
+    aioresponse.post(
+        "http://test:8080/v1/test",
+        status=200,
+        payload='{"status":{"status": 200}}',
+        reason="OK",
+    )
+
+    with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection):
+        resp = await hook.run("v1/test")
+        assert resp.status == 200
+
+
+@pytest.mark.asyncio
+async def test_async_post_request_with_error_code(aioresponse):
+    """Test api call asynchronously for POST request with error."""
+    hook = HttpAsyncHook()
+
+    aioresponse.post(
+        "http://test:8080/v1/test",
+        status=418,
+        payload='{"status":{"status": 418}}',
+        reason="I am teapot",
+    )
+
+    with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection):
+        with pytest.raises(AirflowException):
+            await hook.run("v1/test")