This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 4c5cebf045 Add Fail Fast feature for DAGs (#29406)
4c5cebf045 is described below
commit 4c5cebf04555aa7b892a50a85f2f488acc4c057b
Author: RachitSharma2001 <[email protected]>
AuthorDate: Sun Apr 23 09:44:23 2023 -0700
Add Fail Fast feature for DAGs (#29406)
---
airflow/exceptions.py | 18 ++++++++++++-
airflow/models/baseoperator.py | 4 ++-
airflow/models/dag.py | 12 +++++++++
airflow/models/taskinstance.py | 19 ++++++++++++++
tests/models/test_baseoperator.py | 38 ++++++++++++++++++++++++++-
tests/models/test_dag.py | 36 +++++++++++++++++++++++++
tests/models/test_taskinstance.py | 55 +++++++++++++++++++++++++++++++++++++++
7 files changed, 179 insertions(+), 3 deletions(-)
diff --git a/airflow/exceptions.py b/airflow/exceptions.py
index ae45c6f343..613750028a 100644
--- a/airflow/exceptions.py
+++ b/airflow/exceptions.py
@@ -26,7 +26,7 @@ from http import HTTPStatus
from typing import TYPE_CHECKING, Any, NamedTuple, Sized
if TYPE_CHECKING:
- from airflow.models import DagRun
+ from airflow.models import DAG, DagRun
class AirflowException(Exception):
@@ -207,6 +207,22 @@ class DagFileExists(AirflowBadRequest):
warnings.warn("DagFileExists is deprecated and will be removed.",
DeprecationWarning, stacklevel=2)
+class DagInvalidTriggerRule(AirflowException):
+ """Raise when a dag has 'fail_stop' enabled yet has a non-default trigger
rule"""
+
+ @classmethod
+ def check(cls, dag: DAG | None, trigger_rule: str):
+ from airflow.models.abstractoperator import DEFAULT_TRIGGER_RULE
+
+ if dag is not None and dag.fail_stop and trigger_rule != DEFAULT_TRIGGER_RULE:
+ raise cls()
+
+ def __str__(self) -> str:
+ from airflow.models.abstractoperator import DEFAULT_TRIGGER_RULE
+
+ return f"A 'fail-stop' dag can only have {DEFAULT_TRIGGER_RULE}
trigger rule"
+
+
class DuplicateTaskIdFound(AirflowException):
"""Raise when a Task with duplicate task_id is defined in the same DAG."""
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 37106c580f..8eb6317874 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -53,7 +53,7 @@ from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import NoResultFound
from airflow.configuration import conf
-from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, TaskDeferred
+from airflow.exceptions import AirflowException, DagInvalidTriggerRule, RemovedInAirflow3Warning, TaskDeferred
from airflow.lineage import apply_lineage, prepare_lineage
from airflow.models.abstractoperator import (
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
@@ -801,6 +801,8 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
dag = dag or DagContext.get_current_dag()
task_group = task_group or TaskGroupContext.get_current_task_group(dag)
+ DagInvalidTriggerRule.check(dag, trigger_rule)
+
self.task_id = task_group.child_id(task_id) if task_group else task_id
if not self.__from_mapped and task_group:
task_group.add(self)
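
The hook above enforces the constraint as soon as an operator is bound to a
dag, whether via the dag= argument or a "with DAG(...)" block. A small usage
sketch (assumes this commit is applied; EmptyOperator stands in for any
operator):

    from airflow.exceptions import DagInvalidTriggerRule
    from airflow.models import DAG
    from airflow.operators.empty import EmptyOperator

    with DAG(dag_id="fail_stop_example", fail_stop=True):
        EmptyOperator(task_id="ok")  # default trigger rule: accepted

        try:
            EmptyOperator(task_id="bad", trigger_rule="one_failed")
        except DagInvalidTriggerRule:
            pass  # rejected while the dag file is parsed, before any run
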
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 89becc04c7..b9785a81b2 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -70,6 +70,7 @@ from airflow.exceptions import (
AirflowDagInconsistent,
AirflowException,
AirflowSkipException,
+ DagInvalidTriggerRule,
DuplicateTaskIdFound,
RemovedInAirflow3Warning,
TaskNotFound,
@@ -357,6 +358,9 @@ class DAG(LoggingMixin):
Can be used as an HTTP link (for example the link to your Slack channel), or a mailto link.
e.g: {"dag_owner": "https://airflow.apache.org/"}
:param auto_register: Automatically register this DAG when it is used in a ``with`` block
+ :param fail_stop: Fails currently running tasks when task in DAG fails.
+ **Warning**: A fail stop dag can only have tasks with the default trigger rule ("all_success").
+ An exception will be thrown if any task in a fail stop dag has a non default trigger rule.
"""
_comps = {
@@ -419,6 +423,7 @@ class DAG(LoggingMixin):
tags: list[str] | None = None,
owner_links: dict[str, str] | None = None,
auto_register: bool = True,
+ fail_stop: bool = False,
):
from airflow.utils.task_group import TaskGroup
@@ -602,6 +607,8 @@ class DAG(LoggingMixin):
self.is_paused_upon_creation = is_paused_upon_creation
self.auto_register = auto_register
+ self.fail_stop = fail_stop
+
self.jinja_environment_kwargs = jinja_environment_kwargs
self.render_template_as_native_obj = render_template_as_native_obj
@@ -2353,6 +2360,8 @@ class DAG(LoggingMixin):
:param task: the task you want to add
"""
+ DagInvalidTriggerRule.check(self, task.trigger_rule)
+
from airflow.utils.task_group import TaskGroupContext
if not self.start_date and not task.start_date:
@@ -3055,6 +3064,7 @@ class DAG(LoggingMixin):
"has_on_success_callback",
"has_on_failure_callback",
"auto_register",
+ "fail_stop",
}
cls.__serialized_fields = frozenset(vars(DAG(dag_id="test")).keys()) - exclusion_list
return cls.__serialized_fields
@@ -3530,6 +3540,7 @@ def dag(
tags: list[str] | None = None,
owner_links: dict[str, str] | None = None,
auto_register: bool = True,
+ fail_stop: bool = False,
) -> Callable[[Callable], Callable[..., DAG]]:
"""
Python dag decorator. Wraps a function into an Airflow DAG.
@@ -3583,6 +3594,7 @@ def dag(
schedule=schedule,
owner_links=owner_links,
auto_register=auto_register,
+ fail_stop=fail_stop,
) as dag_obj:
# Set DAG documentation from function documentation if it exists and doc_md is not set.
if f.__doc__ and not dag_obj.doc_md:
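
Both ways of declaring a dag accept the new flag, since the dag() decorator
forwards fail_stop to the DAG constructor. A sketch (dag ids and dates here
are illustrative, not from the commit):

    import datetime

    from airflow.models.dag import DAG, dag

    # Constructor form.
    d1 = DAG(dag_id="ff_ctor", start_date=datetime.datetime(2023, 1, 1), fail_stop=True)

    # Decorator form.
    @dag(dag_id="ff_deco", start_date=datetime.datetime(2023, 1, 1), fail_stop=True)
    def ff_deco():
        pass

    d2 = ff_deco()
    assert d1.fail_stop and d2.fail_stop
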
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index c744d5b3b7..a103ad0cfd 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -181,6 +181,21 @@ def set_current_context(context: Context) -> Generator[Context, None, None]:
)
+def stop_all_tasks_in_dag(tis: list[TaskInstance], session: Session, task_id_to_ignore: int):
+ for ti in tis:
+ if ti.task_id == task_id_to_ignore or ti.state in (
+ TaskInstanceState.SUCCESS,
+ TaskInstanceState.FAILED,
+ ):
+ continue
+ if ti.state == TaskInstanceState.RUNNING:
+ log.info("Forcing task %s to fail", ti.task_id)
+ ti.error(session)
+ else:
+ log.info("Setting task %s to SKIPPED", ti.task_id)
+ ti.set_state(state=TaskInstanceState.SKIPPED, session=session)
+
+
def clear_task_instances(
tis: list[TaskInstance],
session: Session,
@@ -1896,6 +1911,10 @@ class TaskInstance(Base, LoggingMixin):
email_for_state = operator.attrgetter("email_on_failure")
callbacks = task.on_failure_callback if task else None
callback_type = "on_failure"
+
+ if task and task.dag and task.dag.fail_stop:
+ tis = self.get_dagrun(session).get_task_instances()
+ stop_all_tasks_in_dag(tis, session, self.task_id)
else:
if self.state == State.QUEUED:
# We increase the try_number so as to fail the task if it fails to start after sometime
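
When fail_stop is set, handle_failure fans out over every other task instance
in the run via stop_all_tasks_in_dag. A database-free sketch of the per-task
transition the helper applies (in the real code, "fail" is ti.error(session)
and "skip" is ti.set_state(state=TaskInstanceState.SKIPPED, session=session)):

    from airflow.utils.state import TaskInstanceState

    def fail_stop_target(state: TaskInstanceState) -> TaskInstanceState:
        # Already-finished tasks are left untouched.
        if state in (TaskInstanceState.SUCCESS, TaskInstanceState.FAILED):
            return state
        # Running tasks are forced to fail; everything else is skipped.
        if state == TaskInstanceState.RUNNING:
            return TaskInstanceState.FAILED
        return TaskInstanceState.SKIPPED

    assert fail_stop_target(TaskInstanceState.RUNNING) == TaskInstanceState.FAILED
    assert fail_stop_target(TaskInstanceState.QUEUED) == TaskInstanceState.SKIPPED
    assert fail_stop_target(TaskInstanceState.SUCCESS) == TaskInstanceState.SUCCESS
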
diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py
index 45f456738e..418ffdba72 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -28,7 +28,7 @@ import jinja2
import pytest
from airflow.decorators import task as task_decorator
-from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
+from airflow.exceptions import AirflowException, DagInvalidTriggerRule, RemovedInAirflow3Warning
from airflow.lineage.entities import File
from airflow.models import DAG
from airflow.models.baseoperator import BaseOperator, BaseOperatorMeta, chain, cross_downstream
@@ -163,6 +163,42 @@ class TestBaseOperator:
illegal_argument_1234="hello?",
)
+ def test_trigger_rule_validation(self):
+ from airflow.models.abstractoperator import DEFAULT_TRIGGER_RULE
+
+ fail_stop_dag = DAG(
+ dag_id="test_dag_trigger_rule_validation",
start_date=DEFAULT_DATE, fail_stop=True
+ )
+ non_fail_stop_dag = DAG(
+ dag_id="test_dag_trigger_rule_validation",
start_date=DEFAULT_DATE, fail_stop=False
+ )
+
+ # An operator with default trigger rule and a fail-stop dag should be allowed
+ try:
+ BaseOperator(
+ task_id="test_valid_trigger_rule", dag=fail_stop_dag,
trigger_rule=DEFAULT_TRIGGER_RULE
+ )
+ except DagInvalidTriggerRule as exception:
+ assert (
+ False
+ ), f"BaseOperator raises exception with fail-stop dag & default
trigger rule: {exception}"
+
+ # An operator with non default trigger rule and a non fail-stop dag should be allowed
+ try:
+ BaseOperator(
+ task_id="test_valid_trigger_rule", dag=non_fail_stop_dag,
trigger_rule=TriggerRule.DUMMY
+ )
+ except DagInvalidTriggerRule as exception:
+ assert (
+ False
+ ), f"BaseOperator raises exception with non fail-stop dag &
non-default trigger rule: {exception}"
+
+ # An operator with non default trigger rule and a fail stop dag should not be allowed
+ with pytest.raises(DagInvalidTriggerRule):
+ BaseOperator(
+ task_id="test_invalid_trigger_rule", dag=fail_stop_dag,
trigger_rule=TriggerRule.DUMMY
+ )
+
@pytest.mark.parametrize(
("content", "context", "expected_output"),
[
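
The constructor path exercised above is one of two guards; DAG.add_task
re-runs the same check (see the next file's tests), which matters for tasks
created before they are bound to any dag. A sketch (assumes this commit;
identifiers are illustrative):

    from airflow.exceptions import DagInvalidTriggerRule
    from airflow.models import DAG
    from airflow.operators.empty import EmptyOperator

    op = EmptyOperator(task_id="late_bound", trigger_rule="one_failed")  # no dag yet
    fail_stop_dag = DAG(dag_id="fs", fail_stop=True)

    try:
        fail_stop_dag.add_task(op)  # the add_task guard fires here
    except DagInvalidTriggerRule:
        pass
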
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index d9f269bca6..b920e0a8e4 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -1718,6 +1718,42 @@ class TestDag:
)
assert dr.creating_job_id == job_id
+ def test_dag_add_task_checks_trigger_rule(self):
+ # A non fail stop dag should allow any trigger rule
+ from airflow.exceptions import DagInvalidTriggerRule
+ from airflow.utils.trigger_rule import TriggerRule
+
+ task_with_non_default_trigger_rule = EmptyOperator(
+ task_id="task_with_non_default_trigger_rule",
trigger_rule=TriggerRule.DUMMY
+ )
+ non_fail_stop_dag = DAG(
+ dag_id="test_dag_add_task_checks_trigger_rule",
start_date=DEFAULT_DATE, fail_stop=False
+ )
+ try:
+ non_fail_stop_dag.add_task(task_with_non_default_trigger_rule)
+ except DagInvalidTriggerRule as exception:
+ assert False, f"dag add_task() raises DagInvalidTriggerRule for
non fail stop dag: {exception}"
+
+ # a fail stop dag should allow default trigger rule
+ from airflow.models.abstractoperator import DEFAULT_TRIGGER_RULE
+
+ fail_stop_dag = DAG(
+ dag_id="test_dag_add_task_checks_trigger_rule",
start_date=DEFAULT_DATE, fail_stop=True
+ )
+ task_with_default_trigger_rule = EmptyOperator(
+ task_id="task_with_default_trigger_rule",
trigger_rule=DEFAULT_TRIGGER_RULE
+ )
+ try:
+ fail_stop_dag.add_task(task_with_default_trigger_rule)
+ except DagInvalidTriggerRule as exception:
+ assert (
+ False
+ ), f"dag.add_task() raises exception for fail-stop dag & default
trigger rule: {exception}"
+
+ # a fail stop dag should not allow a non-default trigger rule
+ with pytest.raises(DagInvalidTriggerRule):
+ fail_stop_dag.add_task(task_with_non_default_trigger_rule)
+
def test_dag_add_task_sets_default_task_group(self):
dag = DAG(dag_id="test_dag_add_task_sets_default_task_group",
start_date=DEFAULT_DATE)
task_without_task_group = EmptyOperator(task_id="task_without_group_id")
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index f5ad9d30ee..dbd32e0a0b 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -2556,6 +2556,61 @@ class TestTaskInstance:
del ti.task
ti.handle_failure("test ti.task undefined")
+ @provide_session
+ def test_handle_failure_fail_stop(self, create_dummy_dag, session=None):
+ start_date = timezone.datetime(2016, 6, 1)
+ clear_db_runs()
+
+ dag, task1 = create_dummy_dag(
+ dag_id="test_handle_failure_fail_stop",
+ schedule=None,
+ start_date=start_date,
+ task_id="task1",
+ trigger_rule="all_success",
+ with_dagrun_type=DagRunType.MANUAL,
+ session=session,
+ fail_stop=True,
+ )
+ dr = dag.create_dagrun(
+ run_id="test_ff",
+ run_type=DagRunType.MANUAL,
+ execution_date=timezone.utcnow(),
+ state=None,
+ session=session,
+ )
+
+ ti1 = dr.get_task_instance(task1.task_id, session=session)
+ ti1.task = task1
+ ti1.state = State.SUCCESS
+
+ states = [State.RUNNING, State.FAILED, State.QUEUED, State.SCHEDULED, State.DEFERRED]
+ tasks = []
+ for i in range(len(states)):
+ op = EmptyOperator(
+ task_id=f"reg_Task{i}",
+ dag=dag,
+ )
+ ti = TI(task=op, run_id=dr.run_id)
+ ti.state = states[i]
+ session.add(ti)
+ tasks.append(ti)
+
+ fail_task = EmptyOperator(
+ task_id="fail_Task",
+ dag=dag,
+ )
+ ti_ff = TI(task=fail_task, run_id=dr.run_id)
+ ti_ff.state = State.FAILED
+ session.add(ti_ff)
+ session.flush()
+ ti_ff.handle_failure("test retry handling")
+
+ assert ti1.state == State.SUCCESS
+ assert ti_ff.state == State.FAILED
+ exp_states = [State.FAILED, State.FAILED, State.SKIPPED, State.SKIPPED, State.SKIPPED]
+ for i in range(len(states)):
+ assert tasks[i].state == exp_states[i]
+
def test_does_not_retry_on_airflow_fail_exception(self, dag_maker):
def fail():
raise AirflowFailException("hopeless")