This is an automated email from the ASF dual-hosted git repository.

phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9596bbdafa Add deferrable implementation in HTTPSensor (#36904)
9596bbdafa is described below

commit 9596bbdafa0cbe7e3c7d7181c98fa241041db3af
Author: vatsrahul1001 <[email protected]>
AuthorDate: Tue Jan 23 11:47:58 2024 +0530

    Add deferrable implementation in HTTPSensor (#36904)
    
    
    
    Co-authored-by: Wei Lee <[email protected]>
---
 airflow/providers/http/sensors/http.py           | 31 ++++++++++
 airflow/providers/http/triggers/http.py          | 72 ++++++++++++++++++++++++
 docs/apache-airflow-providers-http/operators.rst |  7 +++
 tests/providers/http/sensors/test_http.py        | 54 +++++++++++++++++-
 tests/system/providers/http/example_http.py      | 19 ++++++-
 5 files changed, 181 insertions(+), 2 deletions(-)

diff --git a/airflow/providers/http/sensors/http.py b/airflow/providers/http/sensors/http.py
index 41c2fc2001..3691764333 100644
--- a/airflow/providers/http/sensors/http.py
+++ b/airflow/providers/http/sensors/http.py
@@ -17,10 +17,13 @@
 # under the License.
 from __future__ import annotations
 
+from datetime import timedelta
 from typing import TYPE_CHECKING, Any, Callable, Sequence
 
+from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowSkipException
 from airflow.providers.http.hooks.http import HttpHook
+from airflow.providers.http.triggers.http import HttpSensorTrigger
 from airflow.sensors.base import BaseSensorOperator
 
 if TYPE_CHECKING:
@@ -78,6 +81,8 @@ class HttpSensor(BaseSensorOperator):
     :param tcp_keep_alive_count: The TCP Keep Alive count parameter (corresponds to ``socket.TCP_KEEPCNT``)
     :param tcp_keep_alive_interval: The TCP Keep Alive interval parameter (corresponds to
         ``socket.TCP_KEEPINTVL``)
+    :param deferrable: If waiting for completion, whether to defer the task until done,
+        default is ``False``
     """
 
     template_fields: Sequence[str] = ("endpoint", "request_params", "headers")
@@ -97,6 +102,7 @@ class HttpSensor(BaseSensorOperator):
         tcp_keep_alive_idle: int = 120,
         tcp_keep_alive_count: int = 20,
         tcp_keep_alive_interval: int = 30,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -114,6 +120,7 @@ class HttpSensor(BaseSensorOperator):
         self.tcp_keep_alive_idle = tcp_keep_alive_idle
         self.tcp_keep_alive_count = tcp_keep_alive_count
         self.tcp_keep_alive_interval = tcp_keep_alive_interval
+        self.deferrable = deferrable
 
     def poke(self, context: Context) -> bool:
         from airflow.utils.operator_helpers import determine_kwargs
@@ -135,9 +142,12 @@ class HttpSensor(BaseSensorOperator):
                 headers=self.headers,
                 extra_options=self.extra_options,
             )
+
             if self.response_check:
                 kwargs = determine_kwargs(self.response_check, [response], context)
+
                 return self.response_check(response, **kwargs)
+
         except AirflowException as exc:
             if str(exc).startswith(self.response_error_codes_allowlist):
                 return False
@@ -148,3 +158,24 @@ class HttpSensor(BaseSensorOperator):
             raise exc
 
         return True
+
+    def execute(self, context: Context) -> None:
+        """
+        Run the sensor, deferring to ``HttpSensorTrigger`` when possible.
+
+        The parent ``execute`` is used when ``deferrable`` is off or a ``response_check``
+        is set. Otherwise the endpoint is poked once and the task is deferred only
+        when that first poke returns a falsy value.
+
+        :param context: The task context, forwarded to ``poke`` or the parent ``execute``.
+        """
+        if not self.deferrable or self.response_check:
+            super().execute(context=context)
+            return
+
+        if self.poke(context):
+            # The condition is already met, so there is nothing to defer.
+            return
+
+        sensor_trigger = HttpSensorTrigger(
+            endpoint=self.endpoint,
+            http_conn_id=self.http_conn_id,
+            data=self.request_params,
+            headers=self.headers,
+            method=self.method,
+            extra_options=self.extra_options,
+            poke_interval=self.poke_interval,
+        )
+        self.defer(
+            timeout=timedelta(seconds=self.timeout),
+            trigger=sensor_trigger,
+            method_name="execute_complete",
+        )
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        """
+        Resume point named in ``defer(method_name="execute_complete")``.
+
+        Only logs that the task finished; neither argument is inspected.
+
+        :param context: The task context (unused here).
+        :param event: Presumably the payload of the trigger event (unused here).
+        """
+        self.log.info("%s completed successfully.", self.task_id)
diff --git a/airflow/providers/http/triggers/http.py b/airflow/providers/http/triggers/http.py
index b4598984f3..52b76a0104 100644
--- a/airflow/providers/http/triggers/http.py
+++ b/airflow/providers/http/triggers/http.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+import asyncio
 import base64
 import pickle
 from typing import TYPE_CHECKING, Any, AsyncIterator
@@ -24,6 +25,7 @@ import requests
 from requests.cookies import RequestsCookieJar
 from requests.structures import CaseInsensitiveDict
 
+from airflow.exceptions import AirflowException
 from airflow.providers.http.hooks.http import HttpAsyncHook
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
@@ -124,3 +126,73 @@ class HttpTrigger(BaseTrigger):
             cookies.set(k, v)
         response.cookies = cookies
         return response
+
+
+class HttpSensorTrigger(BaseTrigger):
+    """
+    A trigger that polls an HTTP endpoint and fires once a request succeeds.
+
+    The request is repeated every ``poke_interval`` seconds for as long as the hook
+    raises an ``AirflowException``; the first request that does not raise fires the
+    event and ends the trigger.
+
+    :param endpoint: The relative part of the full url
+    :param http_conn_id: The HTTP Connection ID to run the sensor against
+    :param method: The HTTP request method to use
+    :param data: payload to be uploaded or aiohttp parameters
+    :param headers: The HTTP headers to be added to the request
+    :param extra_options: Additional kwargs to pass when creating a request.
+        For example, ``run(json=obj)`` is passed as ``aiohttp.ClientSession().get(json=obj)``
+    :param poke_interval: Seconds to sleep (using asyncio) between two requests
+    """
+
+    def __init__(
+        self,
+        endpoint: str | None = None,
+        http_conn_id: str = "http_default",
+        method: str = "GET",
+        data: dict[str, Any] | str | None = None,
+        headers: dict[str, str] | None = None,
+        extra_options: dict[str, Any] | None = None,
+        poke_interval: float = 5.0,
+    ):
+        super().__init__()
+        self.endpoint = endpoint
+        self.method = method
+        self.data = data
+        self.headers = headers
+        self.extra_options = extra_options or {}
+        self.http_conn_id = http_conn_id
+        self.poke_interval = poke_interval
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize HttpSensorTrigger arguments and classpath."""
+        return (
+            "airflow.providers.http.triggers.http.HttpSensorTrigger",
+            {
+                "endpoint": self.endpoint,
+                # ``method`` has to be part of the kwargs: a trigger rebuilt from this
+                # dict would otherwise fall back to the ``"GET"`` default of ``__init__``.
+                "method": self.method,
+                "data": self.data,
+                "headers": self.headers,
+                "extra_options": self.extra_options,
+                "http_conn_id": self.http_conn_id,
+                "poke_interval": self.poke_interval,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """Make asynchronous http calls via an http hook until one of them succeeds."""
+        hook = self._get_async_hook()
+        while True:
+            try:
+                await hook.run(
+                    endpoint=self.endpoint,
+                    data=self.data,
+                    headers=self.headers,
+                    extra_options=self.extra_options,
+                )
+            except AirflowException as exc:
+                if not str(exc).startswith("404"):
+                    # Non-404 errors are still retried, but made visible in the log.
+                    self.log.warning("Request to %s failed, retrying: %s", self.endpoint, exc)
+                # Always pause before the next attempt; retrying without a sleep would
+                # turn every non-404 error into a busy loop.
+                await asyncio.sleep(self.poke_interval)
+            else:
+                yield TriggerEvent(True)
+                # One event is all this trigger reports, so stop polling afterwards.
+                return
+
+    def _get_async_hook(self) -> HttpAsyncHook:
+        """Build the async hook for the configured method and connection."""
+        return HttpAsyncHook(
+            method=self.method,
+            http_conn_id=self.http_conn_id,
+        )
diff --git a/docs/apache-airflow-providers-http/operators.rst b/docs/apache-airflow-providers-http/operators.rst
index 944a293120..3f52ca0a62 100644
--- a/docs/apache-airflow-providers-http/operators.rst
+++ b/docs/apache-airflow-providers-http/operators.rst
@@ -37,6 +37,13 @@ Here we are poking until httpbin gives us a response text containing ``httpbin``
     :start-after: [START howto_operator_http_http_sensor_check]
     :end-before: [END howto_operator_http_http_sensor_check]
 
+This sensor can also be used in deferrable mode.
+
+.. exampleinclude:: /../../tests/system/providers/http/example_http.py
+    :language: python
+    :start-after: [START howto_operator_http_http_sensor_check_deferrable]
+    :end-before: [END howto_operator_http_http_sensor_check_deferrable]
+
 .. _howto/operator:HttpOperator:
 
 HttpOperator
diff --git a/tests/providers/http/sensors/test_http.py b/tests/providers/http/sensors/test_http.py
index f842ea91fc..4e95c84405 100644
--- a/tests/providers/http/sensors/test_http.py
+++ b/tests/providers/http/sensors/test_http.py
@@ -23,10 +23,11 @@ from unittest.mock import patch
 import pytest
 import requests
 
-from airflow.exceptions import AirflowException, AirflowSensorTimeout, AirflowSkipException
+from airflow.exceptions import AirflowException, AirflowSensorTimeout, AirflowSkipException, TaskDeferred
 from airflow.models.dag import DAG
 from airflow.providers.http.operators.http import HttpOperator
 from airflow.providers.http.sensors.http import HttpSensor
+from airflow.providers.http.triggers.http import HttpSensorTrigger
 from airflow.utils.timezone import datetime
 
 pytestmark = pytest.mark.db_test
@@ -330,3 +331,54 @@ class TestHttpOpSensor:
             dag=self.dag,
         )
         sensor.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+
+class TestHttpSensorAsync:
+    """Tests for the deferrable mode of ``HttpSensor``."""
+
+    @mock.patch("airflow.providers.http.sensors.http.HttpSensor.defer")
+    @mock.patch(
+        "airflow.providers.http.sensors.http.HttpSensor.poke",
+        return_value=True,
+    )
+    def test_execute_finished_before_deferred(
+        self,
+        mock_poke,
+        mock_defer,
+    ):
+        """
+        Asserts that a task is not deferred when task is already finished
+        """
+
+        task = HttpSensor(task_id="run_now", endpoint="test-endpoint", deferrable=True)
+
+        task.execute({})
+        mock_defer.assert_not_called()
+
+    @mock.patch(
+        "airflow.providers.http.sensors.http.HttpSensor.poke",
+        return_value=False,
+    )
+    def test_execute_is_deferred(self, mock_poke):
+        """
+        Asserts that a task is deferred and a HttpSensorTrigger will be fired
+        when the HttpSensor is executed in deferrable mode.
+        """
+
+        task = HttpSensor(task_id="run_now", endpoint="test-endpoint", deferrable=True)
+
+        with pytest.raises(TaskDeferred) as exc:
+            task.execute({})
+
+        assert isinstance(exc.value.trigger, HttpSensorTrigger), "Trigger is not a HttpSensorTrigger"
+
+    @mock.patch("airflow.providers.http.sensors.http.HttpSensor.defer")
+    @mock.patch("airflow.sensors.base.BaseSensorOperator.execute")
+    def test_execute_not_defer_when_response_check_is_not_none(self, mock_execute, mock_defer):
+        """
+        Asserts that a sensor with a ``response_check`` runs the parent ``execute``
+        and is not deferred, even when ``deferrable=True``.
+        """
+        task = HttpSensor(
+            task_id="run_now",
+            endpoint="test-endpoint",
+            response_check=lambda response: "httpbin" in response.text,
+            deferrable=True,
+        )
+        task.execute({})
+        mock_execute.assert_called_once()
+        mock_defer.assert_not_called()
diff --git a/tests/system/providers/http/example_http.py b/tests/system/providers/http/example_http.py
index 3cd52edf84..5c99c2bcba 100644
--- a/tests/system/providers/http/example_http.py
+++ b/tests/system/providers/http/example_http.py
@@ -110,6 +110,17 @@ task_http_sensor_check = HttpSensor(
     dag=dag,
 )
 # [END howto_operator_http_http_sensor_check]
+# [START howto_operator_http_http_sensor_check_deferrable]
+task_http_sensor_check_async = HttpSensor(
+    task_id="http_sensor_check_async",
+    http_conn_id="http_default",
+    endpoint="",
+    deferrable=True,
+    request_params={},
+    poke_interval=5,
+    dag=dag,
+)
+# [END howto_operator_http_http_sensor_check_deferrable]
 # [START howto_operator_http_pagination_function]
 
 
@@ -134,7 +145,13 @@ task_get_paginated = HttpOperator(
     dag=dag,
 )
 # [END howto_operator_http_pagination_function]
-task_http_sensor_check >> task_post_op >> task_get_op >> task_get_op_response_filter
+(
+    task_http_sensor_check
+    >> task_http_sensor_check_async
+    >> task_post_op
+    >> task_get_op
+    >> task_get_op_response_filter
+)
 task_get_op_response_filter >> task_put_op >> task_del_op >> task_post_op_formenc
 task_post_op_formenc >> task_get_paginated
 

Reply via email to