This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 8ca061ddf5f Deferrable sensors can implement sensor timeout (#33718)
8ca061ddf5f is described below

commit 8ca061ddf5fb85c79b1212ca29112190ebb0aab5
Author: Daniel Standish <[email protected]>
AuthorDate: Tue Dec 3 14:39:34 2024 -0800

    Deferrable sensors can implement sensor timeout (#33718)
    
    The goal here is to ensure behavioral parity w.r.t. sensor timeouts between 
deferrable and non-deferrable sensor operators.
    
    With non-deferrable sensors, if there's a sensor timeout, the task fails 
without retry.  But currently, with deferrable sensors, that does not happen.
    
    Since there's already a "timeout" capability on triggers, we can use this 
for sensor timeout.  Essentially all that was missing was the ability to 
distinguish between trigger timeouts and other trigger errors.  With this 
capability, base sensor can distinguish between the two, and reraise deferral 
timeouts as sensor timeouts.
    
    So, here we add a new exception type, TaskDeferralTimeout, which base 
sensor reraises as AirflowSensorTimeout. Then, to take advantage of this 
feature, a sensor need only ensure that its timeout is passed when deferring. 
For convenience, we update the task deferred exception signature to take int 
and float in addition to timedelta, since that's how `timeout` attr is defined 
on base sensor.  But we do not change the exception attribute type.
    
    In order to keep this PR focused, this PR only updates one sensor to use 
the timeout functionality, namely, time delta sensor.  Other sensors will have 
to be done as followups.
---
 airflow/exceptions.py                              | 14 ++++++++---
 airflow/jobs/scheduler_job_runner.py               |  5 ++--
 airflow/models/baseoperator.py                     | 11 +++++---
 airflow/models/taskinstance.py                     |  1 +
 airflow/models/trigger.py                          | 29 ++++++++++++++++++++--
 airflow/sensors/base.py                            |  5 +++-
 .../providers/standard/sensors/time_delta.py       | 21 +++++++++++++++-
 tests/models/test_baseoperator.py                  | 12 ++++++++-
 tests/sensors/test_base.py                         | 10 ++++++++
 9 files changed, 95 insertions(+), 13 deletions(-)

diff --git a/airflow/exceptions.py b/airflow/exceptions.py
index fee0b5a671d..4035488cf87 100644
--- a/airflow/exceptions.py
+++ b/airflow/exceptions.py
@@ -22,13 +22,13 @@
 from __future__ import annotations
 
 import warnings
+from datetime import timedelta
 from http import HTTPStatus
 from typing import TYPE_CHECKING, Any, NamedTuple
 
 from airflow.utils.trigger_rule import TriggerRule
 
 if TYPE_CHECKING:
-    import datetime
     from collections.abc import Sized
 
     from airflow.models import DagRun
@@ -385,14 +385,18 @@ class TaskDeferred(BaseException):
         trigger,
         method_name: str,
         kwargs: dict[str, Any] | None = None,
-        timeout: datetime.timedelta | None = None,
+        timeout: timedelta | int | float | None = None,
     ):
         super().__init__()
         self.trigger = trigger
         self.method_name = method_name
         self.kwargs = kwargs
-        self.timeout = timeout
+        self.timeout: timedelta | None
         # Check timeout type at runtime
+        if isinstance(timeout, (int, float)):
+            self.timeout = timedelta(seconds=timeout)
+        else:
+            self.timeout = timeout
         if self.timeout is not None and not hasattr(self.timeout, 
"total_seconds"):
             raise ValueError("Timeout value must be a timedelta")
 
@@ -417,6 +421,10 @@ class TaskDeferralError(AirflowException):
     """Raised when a task failed during deferral for some reason."""
 
 
+class TaskDeferralTimeout(AirflowException):
+    """Raised when there is a timeout on the deferral."""
+
+
 # The try/except handling is needed after we moved all k8s classes to 
cncf.kubernetes provider
 # These two exceptions are used internally by Kubernetes Executor but also by 
PodGenerator, so we need
 # to leave them here in case older version of cncf.kubernetes provider is used 
to run KubernetesPodOperator
diff --git a/airflow/jobs/scheduler_job_runner.py 
b/airflow/jobs/scheduler_job_runner.py
index 56a65009e2b..0dd6b32f741 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -65,6 +65,7 @@ from airflow.models.dagbag import DagBag
 from airflow.models.dagrun import DagRun
 from airflow.models.dagwarning import DagWarning, DagWarningType
 from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
+from airflow.models.trigger import TRIGGER_FAIL_REPR, TriggerFailureReason
 from airflow.stats import Stats
 from airflow.ti_deps.dependencies_states import EXECUTION_STATES
 from airflow.timetables.simple import AssetTriggeredTimetable
@@ -2057,8 +2058,8 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                     )
                     .values(
                         state=TaskInstanceState.SCHEDULED,
-                        next_method="__fail__",
-                        next_kwargs={"error": "Trigger/execution timeout"},
+                        next_method=TRIGGER_FAIL_REPR,
+                        next_kwargs={"error": 
TriggerFailureReason.TRIGGER_TIMEOUT},
                         trigger_id=None,
                     )
                 ).rowcount
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 13eb787b4f8..512eb189cc9 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -50,6 +50,7 @@ from airflow.configuration import conf
 from airflow.exceptions import (
     AirflowException,
     TaskDeferralError,
+    TaskDeferralTimeout,
     TaskDeferred,
 )
 from airflow.lineage import apply_lineage, prepare_lineage
@@ -72,6 +73,7 @@ from airflow.models.base import _sentinel
 from airflow.models.mappedoperator import OperatorPartial, 
validate_mapping_kwargs
 from airflow.models.taskinstance import TaskInstance, clear_task_instances
 from airflow.models.taskmixin import DependencyMixin
+from airflow.models.trigger import TRIGGER_FAIL_REPR, TriggerFailureReason
 from airflow.sdk.definitions.baseoperator import (
     BaseOperatorMeta as TaskSDKBaseOperatorMeta,
     get_merged_defaults,
@@ -973,7 +975,7 @@ class BaseOperator(TaskSDKBaseOperator, AbstractOperator, 
metaclass=BaseOperator
         trigger: BaseTrigger,
         method_name: str,
         kwargs: dict[str, Any] | None = None,
-        timeout: timedelta | None = None,
+        timeout: timedelta | int | float | None = None,
     ) -> NoReturn:
         """
         Mark this Operator "deferred", suspending its execution until the 
provided trigger fires an event.
@@ -990,12 +992,15 @@ class BaseOperator(TaskSDKBaseOperator, AbstractOperator, 
metaclass=BaseOperator
         """Call this method when a deferred task is resumed."""
         # __fail__ is a special signal value for next_method that indicates
         # this task was scheduled specifically to fail.
-        if next_method == "__fail__":
+        if next_method == TRIGGER_FAIL_REPR:
             next_kwargs = next_kwargs or {}
             traceback = next_kwargs.get("traceback")
             if traceback is not None:
                 self.log.error("Trigger failed:\n%s", "\n".join(traceback))
-            raise TaskDeferralError(next_kwargs.get("error", "Unknown"))
+            if (error := next_kwargs.get("error", "Unknown")) == 
TriggerFailureReason.TRIGGER_TIMEOUT:
+                raise TaskDeferralTimeout(error)
+            else:
+                raise TaskDeferralError(error)
         # Grab the callable off the Operator/Task and add in any kwargs
         execute_callable = getattr(self, next_method)
         if next_kwargs:
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index d6b24f34000..705cc797ed1 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -1538,6 +1538,7 @@ def _defer_task(
 ) -> TaskInstance:
     from airflow.models.trigger import Trigger
 
+    timeout: timedelta | None
     if exception is not None:
         trigger_row = Trigger.from_object(exception.trigger)
         next_method = exception.method_name
diff --git a/airflow/models/trigger.py b/airflow/models/trigger.py
index b7b6ba9980d..f56512cdbc1 100644
--- a/airflow/models/trigger.py
+++ b/airflow/models/trigger.py
@@ -18,6 +18,7 @@ from __future__ import annotations
 
 import datetime
 from collections.abc import Iterable
+from enum import Enum
 from traceback import format_exception
 from typing import TYPE_CHECKING, Any
 
@@ -40,6 +41,27 @@ if TYPE_CHECKING:
 
     from airflow.triggers.base import BaseTrigger
 
+TRIGGER_FAIL_REPR = "__fail__"
+"""String value to represent trigger failure.
+
+Internal use only.
+
+:meta private:
+"""
+
+
+class TriggerFailureReason(str, Enum):
+    """
+    Reasons for trigger failures.
+
+    Internal use only.
+
+    :meta private:
+    """
+
+    TRIGGER_TIMEOUT = "Trigger timeout"
+    TRIGGER_FAILURE = "Trigger failure"
+
 
 class Trigger(Base):
     """
@@ -229,8 +251,11 @@ class Trigger(Base):
         ):
             # Add the error and set the next_method to the fail state
             traceback = format_exception(type(exc), exc, exc.__traceback__) if 
exc else None
-            task_instance.next_method = "__fail__"
-            task_instance.next_kwargs = {"error": "Trigger failure", 
"traceback": traceback}
+            task_instance.next_method = TRIGGER_FAIL_REPR
+            task_instance.next_kwargs = {
+                "error": TriggerFailureReason.TRIGGER_FAILURE,
+                "traceback": traceback,
+            }
             # Remove ourselves as its trigger
             task_instance.trigger_id = None
             # Finally, mark it as scheduled so it gets re-queued
diff --git a/airflow/sensors/base.py b/airflow/sensors/base.py
index a593a4519f1..e8d89b2365b 100644
--- a/airflow/sensors/base.py
+++ b/airflow/sensors/base.py
@@ -38,6 +38,7 @@ from airflow.exceptions import (
     AirflowSkipException,
     AirflowTaskTimeout,
     TaskDeferralError,
+    TaskDeferralTimeout,
 )
 from airflow.executors.executor_loader import ExecutorLoader
 from airflow.models.baseoperator import BaseOperator
@@ -174,7 +175,7 @@ class BaseSensorOperator(BaseOperator, SkipMixin):
         super().__init__(**kwargs)
         self.poke_interval = 
self._coerce_poke_interval(poke_interval).total_seconds()
         self.soft_fail = soft_fail
-        self.timeout = self._coerce_timeout(timeout).total_seconds()
+        self.timeout: int | float = 
self._coerce_timeout(timeout).total_seconds()
         self.mode = mode
         self.exponential_backoff = exponential_backoff
         self.max_wait = self._coerce_max_wait(max_wait)
@@ -338,6 +339,8 @@ class BaseSensorOperator(BaseOperator, SkipMixin):
     def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | 
None, context: Context):
         try:
             return super().resume_execution(next_method, next_kwargs, context)
+        except TaskDeferralTimeout as e:
+            raise AirflowSensorTimeout(*e.args) from e
         except (AirflowException, TaskDeferralError) as e:
             if self.soft_fail:
                 raise AirflowSkipException(str(e)) from e
diff --git a/providers/src/airflow/providers/standard/sensors/time_delta.py 
b/providers/src/airflow/providers/standard/sensors/time_delta.py
index a0d3189b027..8e0f26ac249 100644
--- a/providers/src/airflow/providers/standard/sensors/time_delta.py
+++ b/providers/src/airflow/providers/standard/sensors/time_delta.py
@@ -21,6 +21,8 @@ from datetime import timedelta
 from time import sleep
 from typing import TYPE_CHECKING, Any, NoReturn
 
+from packaging.version import Version
+
 from airflow.configuration import conf
 from airflow.exceptions import AirflowSkipException
 from airflow.providers.standard.triggers.temporal import DateTimeTrigger, 
TimeDeltaTrigger
@@ -32,6 +34,12 @@ if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
+def _get_airflow_version():
+    from airflow import __version__ as airflow_version
+
+    return Version(Version(airflow_version).base_version)
+
+
 class TimeDeltaSensor(BaseSensorOperator):
     """
     Waits for a timedelta after the run's data interval.
@@ -91,7 +99,18 @@ class TimeDeltaSensorAsync(TimeDeltaSensor):
                 raise AirflowSkipException("Skipping due to soft_fail is set 
to True.") from e
             raise
 
-        self.defer(trigger=trigger, method_name="execute_complete")
+        # todo: remove backcompat when min airflow version greater than 2.11
+        timeout: int | float | timedelta
+        if _get_airflow_version() >= Version("2.11.0"):
+            timeout = self.timeout
+        else:
+            timeout = timedelta(seconds=self.timeout)
+
+        self.defer(
+            trigger=trigger,
+            method_name="execute_complete",
+            timeout=timeout,
+        )
 
     def execute_complete(self, context: Context, event: Any = None) -> None:
         """Handle the event when the trigger fires and return immediately."""
diff --git a/tests/models/test_baseoperator.py 
b/tests/models/test_baseoperator.py
index 638f012a3a5..e95866d95a5 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -29,7 +29,7 @@ import jinja2
 import pytest
 
 from airflow.decorators import task as task_decorator
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferralTimeout
 from airflow.lineage.entities import File
 from airflow.models.baseoperator import (
     BaseOperator,
@@ -40,6 +40,7 @@ from airflow.models.baseoperator import (
 from airflow.models.dag import DAG
 from airflow.models.dagrun import DagRun
 from airflow.models.taskinstance import TaskInstance
+from airflow.models.trigger import TriggerFailureReason
 from airflow.providers.common.sql.operators import sql
 from airflow.utils.edgemodifier import Label
 from airflow.utils.task_group import TaskGroup
@@ -582,6 +583,15 @@ class TestBaseOperator:
         # leaking a lot of state)
         assert caplog.messages == ["test"]
 
+    def test_resume_execution(self):
+        op = BaseOperator(task_id="hi")
+        with pytest.raises(TaskDeferralTimeout):
+            op.resume_execution(
+                next_method="__fail__",
+                next_kwargs={"error": TriggerFailureReason.TRIGGER_TIMEOUT},
+                context={},
+            )
+
 
 def test_deepcopy():
     # Test bug when copying an operator attached to a DAG
diff --git a/tests/sensors/test_base.py b/tests/sensors/test_base.py
index b1d398265a7..9bb4f5b9934 100644
--- a/tests/sensors/test_base.py
+++ b/tests/sensors/test_base.py
@@ -45,6 +45,7 @@ from airflow.executors.executor_constants import (
 from airflow.executors.local_executor import LocalExecutor
 from airflow.executors.sequential_executor import SequentialExecutor
 from airflow.models import TaskInstance, TaskReschedule
+from airflow.models.trigger import TriggerFailureReason
 from airflow.models.xcom import XCom
 from airflow.operators.empty import EmptyOperator
 from airflow.providers.celery.executors.celery_executor import CeleryExecutor
@@ -1061,6 +1062,15 @@ class TestBaseSensor:
             task = sensor.prepare_for_execution()
             assert task.mode == mode
 
+    def test_resume_execution(self):
+        op = BaseSensorOperator(task_id="hi")
+        with pytest.raises(AirflowSensorTimeout):
+            op.resume_execution(
+                next_method="__fail__",
+                next_kwargs={"error": TriggerFailureReason.TRIGGER_TIMEOUT},
+                context={},
+            )
+
 
 @poke_mode_only
 class DummyPokeOnlySensor(BaseSensorOperator):

Reply via email to