This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4c5cebf045 Add Fail Fast feature for DAGs (#29406)
4c5cebf045 is described below

commit 4c5cebf04555aa7b892a50a85f2f488acc4c057b
Author: RachitSharma2001 <[email protected]>
AuthorDate: Sun Apr 23 09:44:23 2023 -0700

    Add Fail Fast feature for DAGs (#29406)
---
 airflow/exceptions.py             | 18 ++++++++++++-
 airflow/models/baseoperator.py    |  4 ++-
 airflow/models/dag.py             | 12 +++++++++
 airflow/models/taskinstance.py    | 19 ++++++++++++++
 tests/models/test_baseoperator.py | 38 ++++++++++++++++++++++++++-
 tests/models/test_dag.py          | 36 +++++++++++++++++++++++++
 tests/models/test_taskinstance.py | 55 +++++++++++++++++++++++++++++++++++++++
 7 files changed, 179 insertions(+), 3 deletions(-)

diff --git a/airflow/exceptions.py b/airflow/exceptions.py
index ae45c6f343..613750028a 100644
--- a/airflow/exceptions.py
+++ b/airflow/exceptions.py
@@ -26,7 +26,7 @@ from http import HTTPStatus
 from typing import TYPE_CHECKING, Any, NamedTuple, Sized
 
 if TYPE_CHECKING:
-    from airflow.models import DagRun
+    from airflow.models import DAG, DagRun
 
 
 class AirflowException(Exception):
@@ -207,6 +207,22 @@ class DagFileExists(AirflowBadRequest):
         warnings.warn("DagFileExists is deprecated and will be removed.", 
DeprecationWarning, stacklevel=2)
 
 
+class DagInvalidTriggerRule(AirflowException):
+    """Raise when a dag has 'fail_stop' enabled yet has a non-default trigger 
rule"""
+
+    @classmethod
+    def check(cls, dag: DAG | None, trigger_rule: str):
+        from airflow.models.abstractoperator import DEFAULT_TRIGGER_RULE
+
+        if dag is not None and dag.fail_stop and trigger_rule != 
DEFAULT_TRIGGER_RULE:
+            raise cls()
+
+    def __str__(self) -> str:
+        from airflow.models.abstractoperator import DEFAULT_TRIGGER_RULE
+
+        return f"A 'fail-stop' dag can only have {DEFAULT_TRIGGER_RULE} 
trigger rule"
+
+
 class DuplicateTaskIdFound(AirflowException):
     """Raise when a Task with duplicate task_id is defined in the same DAG."""
 
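
For context, a minimal sketch of how the new guard behaves once this patch is applied. DEFAULT_TRIGGER_RULE resolves to TriggerRule.ALL_SUCCESS; the dag_id below is illustrative only:

    from airflow.exceptions import DagInvalidTriggerRule
    from airflow.models import DAG
    from airflow.utils.trigger_rule import TriggerRule

    dag = DAG(dag_id="fail_stop_demo", fail_stop=True)

    # The default trigger rule passes the check silently.
    DagInvalidTriggerRule.check(dag, TriggerRule.ALL_SUCCESS)

    # Any other rule on a fail-stop dag raises.
    try:
        DagInvalidTriggerRule.check(dag, TriggerRule.ONE_FAILED)
    except DagInvalidTriggerRule as err:
        print(err)  # e.g. "A 'fail-stop' dag can only have all_success trigger rule"
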
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 37106c580f..8eb6317874 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -53,7 +53,7 @@ from sqlalchemy.orm import Session
 from sqlalchemy.orm.exc import NoResultFound
 
 from airflow.configuration import conf
-from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, TaskDeferred
+from airflow.exceptions import AirflowException, DagInvalidTriggerRule, RemovedInAirflow3Warning, TaskDeferred
 from airflow.lineage import apply_lineage, prepare_lineage
 from airflow.models.abstractoperator import (
     DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
@@ -801,6 +801,8 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
         dag = dag or DagContext.get_current_dag()
         task_group = task_group or TaskGroupContext.get_current_task_group(dag)
 
+        DagInvalidTriggerRule.check(dag, trigger_rule)
+
         self.task_id = task_group.child_id(task_id) if task_group else task_id
         if not self.__from_mapped and task_group:
             task_group.add(self)
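
Because the check above runs inside BaseOperator.__init__ (picking up the dag from the active context when none is passed explicitly), a violating task is rejected while the dag file is being parsed. A hedged illustration, not part of the patch; dag_id and task_ids are made up:

    from airflow.exceptions import DagInvalidTriggerRule
    from airflow.models import DAG
    from airflow.operators.empty import EmptyOperator

    with DAG(dag_id="fail_stop_parse_time", fail_stop=True):
        EmptyOperator(task_id="ok")  # default trigger rule, accepted
        try:
            EmptyOperator(task_id="bad", trigger_rule="one_failed")
        except DagInvalidTriggerRule:
            print("rejected at dag-definition time")
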
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 89becc04c7..b9785a81b2 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -70,6 +70,7 @@ from airflow.exceptions import (
     AirflowDagInconsistent,
     AirflowException,
     AirflowSkipException,
+    DagInvalidTriggerRule,
     DuplicateTaskIdFound,
     RemovedInAirflow3Warning,
     TaskNotFound,
@@ -357,6 +358,9 @@ class DAG(LoggingMixin):
         Can be used as an HTTP link (for example the link to your Slack channel), or a mailto link.
         e.g: {"dag_owner": "https://airflow.apache.org/"}
     :param auto_register: Automatically register this DAG when it is used in a ``with`` block
+    :param fail_stop: Fails currently running tasks when a task in the DAG fails.
+        **Warning**: A fail-stop dag can only have tasks with the default trigger rule ("all_success").
+        An exception will be thrown if any task in a fail-stop dag has a non-default trigger rule.
     """
 
     _comps = {
@@ -419,6 +423,7 @@ class DAG(LoggingMixin):
         tags: list[str] | None = None,
         owner_links: dict[str, str] | None = None,
         auto_register: bool = True,
+        fail_stop: bool = False,
     ):
         from airflow.utils.task_group import TaskGroup
 
@@ -602,6 +607,8 @@ class DAG(LoggingMixin):
         self.is_paused_upon_creation = is_paused_upon_creation
         self.auto_register = auto_register
 
+        self.fail_stop = fail_stop
+
         self.jinja_environment_kwargs = jinja_environment_kwargs
         self.render_template_as_native_obj = render_template_as_native_obj
 
@@ -2353,6 +2360,8 @@ class DAG(LoggingMixin):
 
         :param task: the task you want to add
         """
+        DagInvalidTriggerRule.check(self, task.trigger_rule)
+
         from airflow.utils.task_group import TaskGroupContext
 
         if not self.start_date and not task.start_date:
@@ -3055,6 +3064,7 @@ class DAG(LoggingMixin):
                 "has_on_success_callback",
                 "has_on_failure_callback",
                 "auto_register",
+                "fail_stop",
             }
             cls.__serialized_fields = frozenset(vars(DAG(dag_id="test")).keys()) - exclusion_list
         return cls.__serialized_fields
@@ -3530,6 +3540,7 @@ def dag(
     tags: list[str] | None = None,
     owner_links: dict[str, str] | None = None,
     auto_register: bool = True,
+    fail_stop: bool = False,
 ) -> Callable[[Callable], Callable[..., DAG]]:
     """
     Python dag decorator. Wraps a function into an Airflow DAG.
@@ -3583,6 +3594,7 @@ def dag(
                 schedule=schedule,
                 owner_links=owner_links,
                 auto_register=auto_register,
+                fail_stop=fail_stop,
             ) as dag_obj:
                 # Set DAG documentation from function documentation if it exists and doc_md is not set.
                 if f.__doc__ and not dag_obj.doc_md:
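
Since the diff threads fail_stop through both the DAG constructor and the @dag decorator, a usage sketch could look like the following (dag and task names are illustrative):

    from datetime import datetime

    from airflow.decorators import dag, task

    @dag(schedule=None, start_date=datetime(2023, 1, 1), fail_stop=True)
    def fail_fast_example():
        @task
        def extract():
            return 42

        @task
        def load(value):
            print(value)

        load(extract())

    fail_fast_example()
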
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index c744d5b3b7..a103ad0cfd 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -181,6 +181,21 @@ def set_current_context(context: Context) -> Generator[Context, None, None]:
             )
 
 
+def stop_all_tasks_in_dag(tis: list[TaskInstance], session: Session, task_id_to_ignore: str):
+    for ti in tis:
+        if ti.task_id == task_id_to_ignore or ti.state in (
+            TaskInstanceState.SUCCESS,
+            TaskInstanceState.FAILED,
+        ):
+            continue
+        if ti.state == TaskInstanceState.RUNNING:
+            log.info("Forcing task %s to fail", ti.task_id)
+            ti.error(session)
+        else:
+            log.info("Setting task %s to SKIPPED", ti.task_id)
+            ti.set_state(state=TaskInstanceState.SKIPPED, session=session)
+
+
 def clear_task_instances(
     tis: list[TaskInstance],
     session: Session,
@@ -1896,6 +1911,10 @@ class TaskInstance(Base, LoggingMixin):
             email_for_state = operator.attrgetter("email_on_failure")
             callbacks = task.on_failure_callback if task else None
             callback_type = "on_failure"
+
+            if task and task.dag and task.dag.fail_stop:
+                tis = self.get_dagrun(session).get_task_instances()
+                stop_all_tasks_in_dag(tis, session, self.task_id)
         else:
             if self.state == State.QUEUED:
                 # We increase the try_number so as to fail the task if it fails to start after some time
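
Read together with the tests below, the end state of each sibling task instance after a fail-stop failure can be summarized as follows. This is a behavioral sketch (expected_fail_stop_state is a made-up helper, not part of the patch), not a re-implementation of stop_all_tasks_in_dag:

    from __future__ import annotations

    from airflow.utils.state import TaskInstanceState

    def expected_fail_stop_state(current: TaskInstanceState) -> TaskInstanceState | None:
        """Return the state a sibling ends up in, or None if it is left untouched."""
        if current in (TaskInstanceState.SUCCESS, TaskInstanceState.FAILED):
            return None  # finished tasks are left alone
        if current == TaskInstanceState.RUNNING:
            return TaskInstanceState.FAILED  # forced down via ti.error()
        return TaskInstanceState.SKIPPED  # queued, scheduled, deferred, etc.
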
diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py
index 45f456738e..418ffdba72 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -28,7 +28,7 @@ import jinja2
 import pytest
 
 from airflow.decorators import task as task_decorator
-from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
+from airflow.exceptions import AirflowException, DagInvalidTriggerRule, RemovedInAirflow3Warning
 from airflow.lineage.entities import File
 from airflow.models import DAG
 from airflow.models.baseoperator import BaseOperator, BaseOperatorMeta, chain, cross_downstream
@@ -163,6 +163,42 @@ class TestBaseOperator:
                 illegal_argument_1234="hello?",
             )
 
+    def test_trigger_rule_validation(self):
+        from airflow.models.abstractoperator import DEFAULT_TRIGGER_RULE
+
+        fail_stop_dag = DAG(
+            dag_id="test_dag_trigger_rule_validation", start_date=DEFAULT_DATE, fail_stop=True
+        )
+        non_fail_stop_dag = DAG(
+            dag_id="test_dag_trigger_rule_validation", start_date=DEFAULT_DATE, fail_stop=False
+        )
+
+        # An operator with default trigger rule and a fail-stop dag should be allowed
+        try:
+            BaseOperator(
+                task_id="test_valid_trigger_rule", dag=fail_stop_dag, trigger_rule=DEFAULT_TRIGGER_RULE
+            )
+        except DagInvalidTriggerRule as exception:
+            assert (
+                False
+            ), f"BaseOperator raises exception with fail-stop dag & default trigger rule: {exception}"
+
+        # An operator with non default trigger rule and a non fail-stop dag should be allowed
+        try:
+            BaseOperator(
+                task_id="test_valid_trigger_rule", dag=non_fail_stop_dag, trigger_rule=TriggerRule.DUMMY
+            )
+        except DagInvalidTriggerRule as exception:
+            assert (
+                False
+            ), f"BaseOperator raises exception with non fail-stop dag & non-default trigger rule: {exception}"
+
+        # An operator with non default trigger rule and a fail stop dag should not be allowed
+        with pytest.raises(DagInvalidTriggerRule):
+            BaseOperator(
+                task_id="test_invalid_trigger_rule", dag=fail_stop_dag, trigger_rule=TriggerRule.DUMMY
+            )
+
     @pytest.mark.parametrize(
         ("content", "context", "expected_output"),
         [
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index d9f269bca6..b920e0a8e4 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -1718,6 +1718,42 @@ class TestDag:
         )
         assert dr.creating_job_id == job_id
 
+    def test_dag_add_task_checks_trigger_rule(self):
+        # A non fail stop dag should allow any trigger rule
+        from airflow.exceptions import DagInvalidTriggerRule
+        from airflow.utils.trigger_rule import TriggerRule
+
+        task_with_non_default_trigger_rule = EmptyOperator(
+            task_id="task_with_non_default_trigger_rule", trigger_rule=TriggerRule.DUMMY
+        )
+        non_fail_stop_dag = DAG(
+            dag_id="test_dag_add_task_checks_trigger_rule", start_date=DEFAULT_DATE, fail_stop=False
+        )
+        try:
+            non_fail_stop_dag.add_task(task_with_non_default_trigger_rule)
+        except DagInvalidTriggerRule as exception:
+            assert False, f"dag add_task() raises DagInvalidTriggerRule for non fail stop dag: {exception}"
+
+        # a fail stop dag should allow default trigger rule
+        from airflow.models.abstractoperator import DEFAULT_TRIGGER_RULE
+
+        fail_stop_dag = DAG(
+            dag_id="test_dag_add_task_checks_trigger_rule", start_date=DEFAULT_DATE, fail_stop=True
+        )
+        task_with_default_trigger_rule = EmptyOperator(
+            task_id="task_with_default_trigger_rule", trigger_rule=DEFAULT_TRIGGER_RULE
+        )
+        try:
+            fail_stop_dag.add_task(task_with_default_trigger_rule)
+        except DagInvalidTriggerRule as exception:
+            assert (
+                False
+            ), f"dag.add_task() raises exception for fail-stop dag & default trigger rule: {exception}"
+
+        # a fail stop dag should not allow a non-default trigger rule
+        with pytest.raises(DagInvalidTriggerRule):
+            fail_stop_dag.add_task(task_with_non_default_trigger_rule)
+
     def test_dag_add_task_sets_default_task_group(self):
         dag = DAG(dag_id="test_dag_add_task_sets_default_task_group", 
start_date=DEFAULT_DATE)
         task_without_task_group = 
EmptyOperator(task_id="task_without_group_id")
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index f5ad9d30ee..dbd32e0a0b 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -2556,6 +2556,61 @@ class TestTaskInstance:
         del ti.task
         ti.handle_failure("test ti.task undefined")
 
+    @provide_session
+    def test_handle_failure_fail_stop(self, create_dummy_dag, session=None):
+        start_date = timezone.datetime(2016, 6, 1)
+        clear_db_runs()
+
+        dag, task1 = create_dummy_dag(
+            dag_id="test_handle_failure_fail_stop",
+            schedule=None,
+            start_date=start_date,
+            task_id="task1",
+            trigger_rule="all_success",
+            with_dagrun_type=DagRunType.MANUAL,
+            session=session,
+            fail_stop=True,
+        )
+        dr = dag.create_dagrun(
+            run_id="test_ff",
+            run_type=DagRunType.MANUAL,
+            execution_date=timezone.utcnow(),
+            state=None,
+            session=session,
+        )
+
+        ti1 = dr.get_task_instance(task1.task_id, session=session)
+        ti1.task = task1
+        ti1.state = State.SUCCESS
+
+        states = [State.RUNNING, State.FAILED, State.QUEUED, State.SCHEDULED, State.DEFERRED]
+        tasks = []
+        for i in range(len(states)):
+            op = EmptyOperator(
+                task_id=f"reg_Task{i}",
+                dag=dag,
+            )
+            ti = TI(task=op, run_id=dr.run_id)
+            ti.state = states[i]
+            session.add(ti)
+            tasks.append(ti)
+
+        fail_task = EmptyOperator(
+            task_id="fail_Task",
+            dag=dag,
+        )
+        ti_ff = TI(task=fail_task, run_id=dr.run_id)
+        ti_ff.state = State.FAILED
+        session.add(ti_ff)
+        session.flush()
+        ti_ff.handle_failure("test retry handling")
+
+        assert ti1.state == State.SUCCESS
+        assert ti_ff.state == State.FAILED
+        exp_states = [State.FAILED, State.FAILED, State.SKIPPED, State.SKIPPED, State.SKIPPED]
+        for i in range(len(states)):
+            assert tasks[i].state == exp_states[i]
+
     def test_does_not_retry_on_airflow_fail_exception(self, dag_maker):
         def fail():
             raise AirflowFailException("hopeless")
