This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 47edfe9a22 Add HttpHookAsync for deferrable implementation (#29038)
47edfe9a22 is described below

commit 47edfe9a22d1c521e49de3bed87bc332a48c0a80
Author: Ankit Chaurasia <[email protected]>
AuthorDate: Tue Feb 14 19:21:14 2023 +0545

    Add HttpHookAsync for deferrable implementation (#29038)
    
    This PR donates HttpAsyncHook, a hook that interacts with HTTP servers
    using Python async, from the astronomer-providers repo (where it was
    developed) to Apache Airflow.
---
 airflow/providers/http/hooks/http.py    | 163 +++++++++++++++++++++++++++++++-
 airflow/providers/http/provider.yaml    |   2 +
 generated/provider_dependencies.json    |   2 +
 setup.py                                |   1 +
 tests/providers/http/hooks/test_http.py |  95 ++++++++++++++++++-
 5 files changed, 261 insertions(+), 2 deletions(-)

diff --git a/airflow/providers/http/hooks/http.py 
b/airflow/providers/http/hooks/http.py
index 6c37818ccb..235a7f8a2e 100644
--- a/airflow/providers/http/hooks/http.py
+++ b/airflow/providers/http/hooks/http.py
@@ -17,16 +17,23 @@
 # under the License.
 from __future__ import annotations
 
-from typing import Any, Callable
+import asyncio
+from typing import TYPE_CHECKING, Any, Callable
 
+import aiohttp
 import requests
 import tenacity
+from aiohttp import ClientResponseError
+from asgiref.sync import sync_to_async
 from requests.auth import HTTPBasicAuth
 from requests_toolbelt.adapters.socket_options import TCPKeepAliveAdapter
 
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
 
+if TYPE_CHECKING:
+    from aiohttp.client_reqrep import ClientResponse
+
 
 class HttpHook(BaseHook):
     """
@@ -246,3 +253,157 @@ class HttpHook(BaseHook):
             return True, "Connection successfully tested"
         except Exception as e:
             return False, str(e)
+
+
+class HttpAsyncHook(BaseHook):
+    """
+    Interact with HTTP servers using Python Async.
+
+    :param method: the API method to be called
+    :param http_conn_id: http connection id that has the base
+        API url i.e https://www.google.com/ and optional authentication credentials. Default
+        headers can also be specified in the Extra field in json format.
+    :param auth_type: The auth type for the service; called as ``auth_type(login, password)``
+        when the connection has a login.
+    :param retry_limit: Maximum number of attempts made for one request; must be at least 1.
+    :param retry_delay: Number of seconds to wait between two attempts.
+    """
+
+    conn_name_attr = "http_conn_id"
+    default_conn_name = "http_default"
+    conn_type = "http"
+    hook_name = "HTTP"
+
+    def __init__(
+        self,
+        method: str = "POST",
+        http_conn_id: str = default_conn_name,
+        auth_type: Any = aiohttp.BasicAuth,
+        retry_limit: int = 3,
+        retry_delay: float = 1.0,
+    ) -> None:
+        self.http_conn_id = http_conn_id
+        self.method = method.upper()
+        self.base_url: str = ""
+        self.auth_type: Any = auth_type
+        if retry_limit < 1:
+            raise ValueError("Retry limit must be greater than equal to 1")
+        self.retry_limit = retry_limit
+        self.retry_delay = retry_delay
+
+    async def run(
+        self,
+        endpoint: str | None = None,
+        data: dict[str, Any] | str | None = None,
+        headers: dict[str, Any] | None = None,
+        extra_options: dict[str, Any] | None = None,
+    ) -> ClientResponse:
+        """
+        Performs an asynchronous HTTP request call.
+
+        :param endpoint: the endpoint to be called i.e. resource/v1/query?
+        :param data: payload to be uploaded or request parameters. It is sent as the JSON body
+            for POST/PATCH and as query parameters for GET; other methods do not send it.
+        :param headers: additional headers to be passed through as a dictionary. They are
+            merged over the headers taken from the connection's Extra field.
+        :param extra_options: Additional kwargs to pass when creating a request.
+            For example, ``extra_options={"ssl": False}`` results in
+            ``aiohttp.ClientSession().get(..., ssl=False)``. It must not contain ``json``,
+            ``params``, ``headers`` or ``auth``, which are set by this method.
+        :return: the response of the first successful attempt, with its body already read so
+            it can still be consumed after the internal session has been closed.
+        :raises AirflowException: if the HTTP method is not supported, or the request fails
+            with a non-retryable status or keeps failing for ``retry_limit`` attempts.
+        """
+        extra_options = extra_options or {}
+
+        # headers may be passed through directly or in the "extra" field in the connection
+        # definition
+        _headers = {}
+        auth = None
+
+        if self.http_conn_id:
+            conn = await sync_to_async(self.get_connection)(self.http_conn_id)
+
+            if conn.host and "://" in conn.host:
+                self.base_url = conn.host
+            else:
+                # schema defaults to HTTP
+                schema = conn.schema if conn.schema else "http"
+                host = conn.host if conn.host else ""
+                self.base_url = schema + "://" + host
+
+            if conn.port:
+                self.base_url = self.base_url + ":" + str(conn.port)
+            if conn.login:
+                auth = self.auth_type(conn.login, conn.password)
+            if conn.extra:
+                try:
+                    _headers.update(conn.extra_dejson)
+                except TypeError:
+                    self.log.warning("Connection to %s has invalid extra field.", conn.host)
+        if headers:
+            _headers.update(headers)
+
+        if self.base_url and not self.base_url.endswith("/") and endpoint and not endpoint.startswith("/"):
+            url = self.base_url + "/" + endpoint
+        else:
+            url = (self.base_url or "") + (endpoint or "")
+
+        async with aiohttp.ClientSession() as session:
+            request_func = {
+                "GET": session.get,
+                "POST": session.post,
+                "PATCH": session.patch,
+                "HEAD": session.head,
+                "PUT": session.put,
+                "DELETE": session.delete,
+                "OPTIONS": session.options,
+            }.get(self.method)
+            if request_func is None:
+                raise AirflowException(f"Unexpected HTTP Method: {self.method}")
+
+            attempt_num = 1
+            while True:
+                response = await request_func(
+                    url,
+                    json=data if self.method in ("POST", "PATCH") else None,
+                    params=data if self.method == "GET" else None,
+                    # Send the merged headers (connection extras + explicit headers),
+                    # not only the ones passed to this call.
+                    headers=_headers,
+                    auth=auth,
+                    **extra_options,
+                )
+                try:
+                    response.raise_for_status()
+                except ClientResponseError as e:
+                    self.log.warning(
+                        "[Try %d of %d] Request to %s failed.",
+                        attempt_num,
+                        self.retry_limit,
+                        url,
+                    )
+                    if not self._retryable_error_async(e) or attempt_num == self.retry_limit:
+                        self.log.exception("HTTP error with status: %s", e.status)
+                        # Either the request itself is probably wrong (non-retryable status)
+                        # or the retry budget is exhausted: give up.
+                        raise AirflowException(f"{e.status}:{e.message}") from e
+                else:
+                    # The session (and its connections) is closed when this block exits,
+                    # so buffer the body now; later reads such as ``await response.json()``
+                    # are then served from the cached body.
+                    await response.read()
+                    return response
+
+                attempt_num += 1
+                await asyncio.sleep(self.retry_delay)
+
+    def _retryable_error_async(self, exception: ClientResponseError) -> bool:
+        """
+        Determines whether a request that failed with ``exception`` might succeed on a
+        subsequent attempt.
+
+        Only the HTTP status code is considered: server errors (status >= 500) are
+        retryable, every other status is not. In particular, 413 (Payload Too Large) and
+        429 (Too Many Requests) are never retried.
+
+        :param exception: the error raised by ``ClientResponse.raise_for_status``
+        :return: True if the request should be retried
+        """
+        return exception.status >= 500
diff --git a/airflow/providers/http/provider.yaml 
b/airflow/providers/http/provider.yaml
index 4c9f6c3e77..3294ac56ec 100644
--- a/airflow/providers/http/provider.yaml
+++ b/airflow/providers/http/provider.yaml
@@ -42,6 +42,8 @@ dependencies:
   # release it as a requirement for airflow
   - requests>=2.26.0
   - requests_toolbelt
+  - aiohttp
+  - asgiref
 
 integrations:
   - integration-name: Hypertext Transfer Protocol (HTTP)
diff --git a/generated/provider_dependencies.json 
b/generated/provider_dependencies.json
index b3fde3a3a8..4c3c62fe9a 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -415,6 +415,8 @@
   },
   "http": {
     "deps": [
+      "aiohttp",
+      "asgiref",
       "requests>=2.26.0",
       "requests_toolbelt"
     ],
diff --git a/setup.py b/setup.py
index 59397c1196..450062cb7f 100644
--- a/setup.py
+++ b/setup.py
@@ -404,6 +404,7 @@ devel_only = [
     "twine",
     "wheel",
     "yamllint",
+    "aioresponses",
 ]
 
 
diff --git a/tests/providers/http/hooks/test_http.py 
b/tests/providers/http/hooks/test_http.py
index 03bf09075c..820c45c5a3 100644
--- a/tests/providers/http/hooks/test_http.py
+++ b/tests/providers/http/hooks/test_http.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 import json
+import logging
 import os
 from collections import OrderedDict
 from http import HTTPStatus
@@ -26,11 +27,12 @@ from unittest import mock
 import pytest
 import requests
 import tenacity
+from aioresponses import aioresponses
 from requests.adapters import Response
 
 from airflow.exceptions import AirflowException
 from airflow.models import Connection
-from airflow.providers.http.hooks.http import HttpHook
+from airflow.providers.http.hooks.http import HttpAsyncHook, HttpHook
 
 
 def get_airflow_connection(unused_conn_id=None):
@@ -392,3 +394,94 @@ class TestKeepAlive:
 
 
 send_email_test = mock.Mock()
+
+
+@pytest.fixture
+def aioresponse():
+    """
+    Creates a mock async API response.
+    This comes from a mock library specific to the aiohttp package:
+    https://github.com/pnuckowski/aioresponses
+
+    Yields an ``aioresponses`` instance that intercepts aiohttp requests for the
+    duration of the test; register expected responses on it with ``.get``/``.post``/...
+    """
+    with aioresponses() as async_response:
+        yield async_response
+
+
+@pytest.mark.asyncio
+async def test_do_api_call_async_non_retryable_error(aioresponse):
+    """Test api call asynchronously with non retryable error."""
+    hook = HttpAsyncHook(method="GET")
+    # 400 is a client error, so the hook must fail on the first attempt.
+    aioresponse.get("http://httpbin.org/non_existent_endpoint", status=400)
+
+    with pytest.raises(AirflowException) as exc, mock.patch.dict(
+        "os.environ",
+        AIRFLOW_CONN_HTTP_DEFAULT="http://httpbin.org/",
+    ):
+        await hook.run(endpoint="non_existent_endpoint")
+
+    assert str(exc.value) == "400:Bad Request"
+
+
+@pytest.mark.asyncio
+async def test_do_api_call_async_retryable_error(caplog, aioresponse):
+    """Test api call asynchronously with retryable error."""
+    caplog.set_level(logging.WARNING, logger="airflow.providers.http.hooks.http")
+    # No delay between attempts keeps the test fast.
+    hook = HttpAsyncHook(method="GET", retry_delay=0)
+    # Every attempt gets a 500, so the hook must retry until retry_limit is exhausted.
+    aioresponse.get("http://httpbin.org/non_existent_endpoint", status=500, repeat=True)
+
+    with pytest.raises(AirflowException) as exc, mock.patch.dict(
+        "os.environ",
+        AIRFLOW_CONN_HTTP_DEFAULT="http://httpbin.org/",
+    ):
+        await hook.run(endpoint="non_existent_endpoint")
+
+    assert str(exc.value) == "500:Internal Server Error"
+    assert "[Try 3 of 3] Request to http://httpbin.org/non_existent_endpoint failed" in caplog.text
+
+
+@pytest.mark.asyncio
+async def test_do_api_call_async_unknown_method():
+    """Test api call asynchronously for unknown method."""
+    hook = HttpAsyncHook(method="NOPE")
+    data = {
+        "existing_cluster_id": "xxxx-xxxxxx-xxxxxx",
+    }
+
+    # Provide the connection via the environment so the test does not depend on the metadata DB.
+    with pytest.raises(AirflowException) as exc, mock.patch.dict(
+        "os.environ",
+        AIRFLOW_CONN_HTTP_DEFAULT="http://httpbin.org/",
+    ):
+        await hook.run(endpoint="non_existent_endpoint", data=data)
+
+    assert str(exc.value) == "Unexpected HTTP Method: NOPE"
+
+
+@pytest.mark.asyncio
+async def test_async_post_request(aioresponse):
+    """Test api call asynchronously for POST request."""
+    hook = HttpAsyncHook()
+
+    aioresponse.post(
+        "http://test:8080/v1/test",
+        status=200,
+        payload='{"status":{"status": 200}}',
+        reason="OK",
+    )
+
+    with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection):
+        resp = await hook.run("v1/test")
+        assert resp.status == 200
+
+
+@pytest.mark.asyncio
+async def test_async_post_request_with_error_code(aioresponse):
+    """Test api call asynchronously for POST request with error."""
+    hook = HttpAsyncHook()
+
+    aioresponse.post(
+        "http://test:8080/v1/test",
+        status=418,
+        payload='{"status":{"status": 418}}',
+        reason="I am teapot",
+    )
+
+    with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection):
+        # 418 is not retryable, so the error surfaces immediately as "<status>:<reason>".
+        with pytest.raises(AirflowException, match="418:I am teapot"):
+            await hook.run("v1/test")

Reply via email to