This is an automated email from the ASF dual-hosted git repository.
phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9596bbdafa Add deferrable implementation in HTTPSensor (#36904)
9596bbdafa is described below
commit 9596bbdafa0cbe7e3c7d7181c98fa241041db3af
Author: vatsrahul1001 <[email protected]>
AuthorDate: Tue Jan 23 11:47:58 2024 +0530
Add deferrable implementation in HTTPSensor (#36904)
Co-authored-by: Wei Lee <[email protected]>
---
airflow/providers/http/sensors/http.py | 31 ++++++++++
airflow/providers/http/triggers/http.py | 72 ++++++++++++++++++++++++
docs/apache-airflow-providers-http/operators.rst | 7 +++
tests/providers/http/sensors/test_http.py | 54 +++++++++++++++++-
tests/system/providers/http/example_http.py | 19 ++++++-
5 files changed, 181 insertions(+), 2 deletions(-)
diff --git a/airflow/providers/http/sensors/http.py
b/airflow/providers/http/sensors/http.py
index 41c2fc2001..3691764333 100644
--- a/airflow/providers/http/sensors/http.py
+++ b/airflow/providers/http/sensors/http.py
@@ -17,10 +17,13 @@
# under the License.
from __future__ import annotations
+from datetime import timedelta
from typing import TYPE_CHECKING, Any, Callable, Sequence
+from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.http.hooks.http import HttpHook
+from airflow.providers.http.triggers.http import HttpSensorTrigger
from airflow.sensors.base import BaseSensorOperator
if TYPE_CHECKING:
@@ -78,6 +81,8 @@ class HttpSensor(BaseSensorOperator):
:param tcp_keep_alive_count: The TCP Keep Alive count parameter
(corresponds to ``socket.TCP_KEEPCNT``)
:param tcp_keep_alive_interval: The TCP Keep Alive interval parameter
(corresponds to
``socket.TCP_KEEPINTVL``)
+ :param deferrable: If waiting for completion, whether to defer the task
until done,
+ default is ``False``
"""
template_fields: Sequence[str] = ("endpoint", "request_params", "headers")
@@ -97,6 +102,7 @@ class HttpSensor(BaseSensorOperator):
tcp_keep_alive_idle: int = 120,
tcp_keep_alive_count: int = 20,
tcp_keep_alive_interval: int = 30,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -114,6 +120,7 @@ class HttpSensor(BaseSensorOperator):
self.tcp_keep_alive_idle = tcp_keep_alive_idle
self.tcp_keep_alive_count = tcp_keep_alive_count
self.tcp_keep_alive_interval = tcp_keep_alive_interval
+ self.deferrable = deferrable
def poke(self, context: Context) -> bool:
from airflow.utils.operator_helpers import determine_kwargs
@@ -135,9 +142,12 @@ class HttpSensor(BaseSensorOperator):
headers=self.headers,
extra_options=self.extra_options,
)
+
if self.response_check:
kwargs = determine_kwargs(self.response_check, [response],
context)
+
return self.response_check(response, **kwargs)
+
except AirflowException as exc:
if str(exc).startswith(self.response_error_codes_allowlist):
return False
@@ -148,3 +158,24 @@ class HttpSensor(BaseSensorOperator):
raise exc
return True
+
+ def execute(self, context: Context) -> None:
+ if not self.deferrable or self.response_check:
+ super().execute(context=context)
+ elif not self.poke(context):
+ self.defer(
+ timeout=timedelta(seconds=self.timeout),
+ trigger=HttpSensorTrigger(
+ endpoint=self.endpoint,
+ http_conn_id=self.http_conn_id,
+ data=self.request_params,
+ headers=self.headers,
+ method=self.method,
+ extra_options=self.extra_options,
+ poke_interval=self.poke_interval,
+ ),
+ method_name="execute_complete",
+ )
+
+ def execute_complete(self, context: Context, event: dict[str, Any] | None
= None) -> None:
+ self.log.info("%s completed successfully.", self.task_id)
diff --git a/airflow/providers/http/triggers/http.py
b/airflow/providers/http/triggers/http.py
index b4598984f3..52b76a0104 100644
--- a/airflow/providers/http/triggers/http.py
+++ b/airflow/providers/http/triggers/http.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+import asyncio
import base64
import pickle
from typing import TYPE_CHECKING, Any, AsyncIterator
@@ -24,6 +25,7 @@ import requests
from requests.cookies import RequestsCookieJar
from requests.structures import CaseInsensitiveDict
+from airflow.exceptions import AirflowException
from airflow.providers.http.hooks.http import HttpAsyncHook
from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -124,3 +126,73 @@ class HttpTrigger(BaseTrigger):
cookies.set(k, v)
response.cookies = cookies
return response
+
+
+class HttpSensorTrigger(BaseTrigger):
+ """
+ A trigger that fires when the request to a URL returns a non-404 status
code.
+
+ :param endpoint: The relative part of the full url
+ :param http_conn_id: The HTTP Connection ID to run the sensor against
+ :param method: The HTTP request method to use
+ :param data: payload to be uploaded or aiohttp parameters
+ :param headers: The HTTP headers to be added to the GET request
+ :param extra_options: Additional kwargs to pass when creating a request.
+ For example, ``run(json=obj)`` is passed as
``aiohttp.ClientSession().get(json=obj)``
+ :param poke_interval: Time to sleep using asyncio
+ """
+
+ def __init__(
+ self,
+ endpoint: str | None = None,
+ http_conn_id: str = "http_default",
+ method: str = "GET",
+ data: dict[str, Any] | str | None = None,
+ headers: dict[str, str] | None = None,
+ extra_options: dict[str, Any] | None = None,
+ poke_interval: float = 5.0,
+ ):
+ super().__init__()
+ self.endpoint = endpoint
+ self.method = method
+ self.data = data
+ self.headers = headers
+ self.extra_options = extra_options or {}
+ self.http_conn_id = http_conn_id
+ self.poke_interval = poke_interval
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ """Serializes HttpSensorTrigger arguments and classpath."""
+ return (
+ "airflow.providers.http.triggers.http.HttpSensorTrigger",
+ {
+ "endpoint": self.endpoint,
+ "data": self.data,
+ "headers": self.headers,
+ "extra_options": self.extra_options,
+ "http_conn_id": self.http_conn_id,
+ "poke_interval": self.poke_interval,
+ },
+ )
+
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ """Makes a series of asynchronous http calls via an http hook."""
+ hook = self._get_async_hook()
+ while True:
+ try:
+ await hook.run(
+ endpoint=self.endpoint,
+ data=self.data,
+ headers=self.headers,
+ extra_options=self.extra_options,
+ )
+ yield TriggerEvent(True)
+ except AirflowException as exc:
+ if str(exc).startswith("404"):
+ await asyncio.sleep(self.poke_interval)
+
+ def _get_async_hook(self) -> HttpAsyncHook:
+ return HttpAsyncHook(
+ method=self.method,
+ http_conn_id=self.http_conn_id,
+ )
diff --git a/docs/apache-airflow-providers-http/operators.rst
b/docs/apache-airflow-providers-http/operators.rst
index 944a293120..3f52ca0a62 100644
--- a/docs/apache-airflow-providers-http/operators.rst
+++ b/docs/apache-airflow-providers-http/operators.rst
@@ -37,6 +37,13 @@ Here we are poking until httpbin gives us a response text
containing ``httpbin``
:start-after: [START howto_operator_http_http_sensor_check]
:end-before: [END howto_operator_http_http_sensor_check]
+This sensor can also be used in deferrable mode
+
+.. exampleinclude:: /../../tests/system/providers/http/example_http.py
+ :language: python
+ :start-after: [START howto_operator_http_http_sensor_check_deferrable]
+ :end-before: [END howto_operator_http_http_sensor_check_deferrable]
+
.. _howto/operator:HttpOperator:
HttpOperator
diff --git a/tests/providers/http/sensors/test_http.py
b/tests/providers/http/sensors/test_http.py
index f842ea91fc..4e95c84405 100644
--- a/tests/providers/http/sensors/test_http.py
+++ b/tests/providers/http/sensors/test_http.py
@@ -23,10 +23,11 @@ from unittest.mock import patch
import pytest
import requests
-from airflow.exceptions import AirflowException, AirflowSensorTimeout,
AirflowSkipException
+from airflow.exceptions import AirflowException, AirflowSensorTimeout,
AirflowSkipException, TaskDeferred
from airflow.models.dag import DAG
from airflow.providers.http.operators.http import HttpOperator
from airflow.providers.http.sensors.http import HttpSensor
+from airflow.providers.http.triggers.http import HttpSensorTrigger
from airflow.utils.timezone import datetime
pytestmark = pytest.mark.db_test
@@ -330,3 +331,54 @@ class TestHttpOpSensor:
dag=self.dag,
)
sensor.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
ignore_ti_state=True)
+
+
+class TestHttpSensorAsync:
+ @mock.patch("airflow.providers.http.sensors.http.HttpSensor.defer")
+ @mock.patch(
+ "airflow.providers.http.sensors.http.HttpSensor.poke",
+ return_value=True,
+ )
+ def test_execute_finished_before_deferred(
+ self,
+ mock_poke,
+ mock_defer,
+ ):
+ """
+ Asserts that a task is not deferred when task is already finished
+ """
+
+ task = HttpSensor(task_id="run_now", endpoint="test-endpoint",
deferrable=True)
+
+ task.execute({})
+ assert not mock_defer.called
+
+ @mock.patch(
+ "airflow.providers.http.sensors.http.HttpSensor.poke",
+ return_value=False,
+ )
+ def test_execute_is_deferred(self, mock_poke):
+ """
+ Asserts that a task is deferred and an HttpSensorTrigger will be fired
+ when the HttpSensor is executed in deferrable mode.
+ """
+
+ task = HttpSensor(task_id="run_now", endpoint="test-endpoint",
deferrable=True)
+
+ with pytest.raises(TaskDeferred) as exc:
+ task.execute({})
+
+ assert isinstance(exc.value.trigger, HttpSensorTrigger), "Trigger is
not a HttpTrigger"
+
+ @mock.patch("airflow.providers.http.sensors.http.HttpSensor.defer")
+ @mock.patch("airflow.sensors.base.BaseSensorOperator.execute")
+ def test_execute_not_defer_when_response_check_is_not_none(self,
mock_execute, mock_defer):
+ task = HttpSensor(
+ task_id="run_now",
+ endpoint="test-endpoint",
+ response_check=lambda response: "httpbin" in response.text,
+ deferrable=True,
+ )
+ task.execute({})
+ mock_execute.assert_called_once()
+ mock_defer.assert_not_called()
diff --git a/tests/system/providers/http/example_http.py
b/tests/system/providers/http/example_http.py
index 3cd52edf84..5c99c2bcba 100644
--- a/tests/system/providers/http/example_http.py
+++ b/tests/system/providers/http/example_http.py
@@ -110,6 +110,17 @@ task_http_sensor_check = HttpSensor(
dag=dag,
)
# [END howto_operator_http_http_sensor_check]
+# [START howto_operator_http_http_sensor_check_deferrable]
+task_http_sensor_check_async = HttpSensor(
+ task_id="http_sensor_check_async",
+ http_conn_id="http_default",
+ endpoint="",
+ deferrable=True,
+ request_params={},
+ poke_interval=5,
+ dag=dag,
+)
+# [END howto_operator_http_http_sensor_check_deferrable]
# [START howto_operator_http_pagination_function]
@@ -134,7 +145,13 @@ task_get_paginated = HttpOperator(
dag=dag,
)
# [END howto_operator_http_pagination_function]
-task_http_sensor_check >> task_post_op >> task_get_op >>
task_get_op_response_filter
+(
+ task_http_sensor_check
+ >> task_http_sensor_check_async
+ >> task_post_op
+ >> task_get_op
+ >> task_get_op_response_filter
+)
task_get_op_response_filter >> task_put_op >> task_del_op >>
task_post_op_formenc
task_post_op_formenc >> task_get_paginated