This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3720689e38 Add `_on_failure_fail_dagrun` attribute for teardown tasks (#29832)
3720689e38 is described below

commit 3720689e38276d47ab9fe764c5c034c35dcaaf01
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Thu Mar 16 19:33:53 2023 +0100

    Add `_on_failure_fail_dagrun` attribute for teardown tasks (#29832)
    
    * Add `_on_failure_fail_dagrun` attribute for teardown tasks
    
    This adds a basic implementation of the on_failure_fail_dagrun parameter for teardown tasks.
    
    * fixup! Add `_on_failure_fail_dagrun` attribute for teardown tasks
    
    * fixup! fixup! Add `_on_failure_fail_dagrun` attribute for teardown tasks
    
    * Apply suggestions from code review
    
    Co-authored-by: Jed Cunningham <[email protected]>
    
    * fixup! Apply suggestions from code review
    
    ---------
    
    Co-authored-by: Jed Cunningham <[email protected]>
---
 airflow/decorators/setup_teardown.py          | 27 ++++++-----
 airflow/decorators/task_group.py              |  1 +
 airflow/models/baseoperator.py                |  5 +-
 airflow/utils/setup_teardown.py               |  5 +-
 airflow/utils/task_group.py                   |  6 +++
 tests/decorators/test_setup_teardown.py       | 70 ++++++++++++++++++++++++++-
 tests/serialization/test_dag_serialization.py | 44 +++++++++++++++--
 7 files changed, 141 insertions(+), 17 deletions(-)

diff --git a/airflow/decorators/setup_teardown.py 
b/airflow/decorators/setup_teardown.py
index b354d44528..8480bde636 100644
--- a/airflow/decorators/setup_teardown.py
+++ b/airflow/decorators/setup_teardown.py
@@ -37,14 +37,19 @@ def setup_task(python_callable: Callable) -> Callable:
     return wrapper
 
 
-def teardown_task(python_callable: Callable) -> Callable:
-    # Using FunctionType here since _TaskDecorator is also a callable
-    if isinstance(python_callable, types.FunctionType):
-        python_callable = python_task(python_callable)
-
-    @functools.wraps(python_callable)
-    def wrapper(*args, **kwargs) -> Callable:
-        with SetupTeardownContext.teardown():
-            return python_callable(*args, **kwargs)
-
-    return wrapper
+def teardown_task(_func=None, *, on_failure_fail_dagrun: bool | None = None) 
-> Callable:
+    def teardown(python_callable: Callable) -> Callable:
+        # Using FunctionType here since _TaskDecorator is also a callable
+        if isinstance(python_callable, types.FunctionType):
+            python_callable = python_task(python_callable)
+
+        @functools.wraps(python_callable)
+        def wrapper(*args, **kwargs) -> Callable:
+            with 
SetupTeardownContext.teardown(on_failure_fail_dagrun=on_failure_fail_dagrun):
+                return python_callable(*args, **kwargs)
+
+        return wrapper
+
+    if _func is None:
+        return teardown
+    return teardown(_func)
diff --git a/airflow/decorators/task_group.py b/airflow/decorators/task_group.py
index 1431a10452..bd9248032d 100644
--- a/airflow/decorators/task_group.py
+++ b/airflow/decorators/task_group.py
@@ -184,6 +184,7 @@ def task_group(
     add_suffix_on_collision: bool = False,
     setup: bool = False,
     teardown: bool = False,
+    on_failure_fail_dagrun: bool = False,
 ) -> Callable[[Callable[FParams, FReturn]], _TaskGroupFactory[FParams, 
FReturn]]:
     ...
 
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 21ee60d93c..6fd033d9d2 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -688,6 +688,7 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
 
     _is_setup = False
     _is_teardown = False
+    _on_failure_fail_dagrun = False
 
     def __init__(
         self,
@@ -929,7 +930,8 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
     def as_teardown(cls, *args, **kwargs):
         from airflow.utils.setup_teardown import SetupTeardownContext
 
-        with SetupTeardownContext.teardown():
+        on_failure_fail_dagrun = kwargs.pop("on_failure_fail_dagrun", False)
+        with 
SetupTeardownContext.teardown(on_failure_fail_dagrun=on_failure_fail_dagrun):
             return cls(*args, **kwargs)
 
     def __eq__(self, other):
@@ -1477,6 +1479,7 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
                     "params",
                     "_is_setup",
                     "_is_teardown",
+                    "_on_failure_fail_dagrun",
                 }
             )
             DagContext.pop_context_managed_dag()
diff --git a/airflow/utils/setup_teardown.py b/airflow/utils/setup_teardown.py
index e685b678c3..6c8c8500ff 100644
--- a/airflow/utils/setup_teardown.py
+++ b/airflow/utils/setup_teardown.py
@@ -26,6 +26,7 @@ class SetupTeardownContext:
 
     is_setup: bool = False
     is_teardown: bool = False
+    on_failure_fail_dagrun: bool = False
 
     @classmethod
     @contextmanager
@@ -43,7 +44,7 @@ class SetupTeardownContext:
 
     @classmethod
     @contextmanager
-    def teardown(cls):
+    def teardown(cls, *, on_failure_fail_dagrun=False):
         if cls.is_setup or cls.is_teardown:
             raise AirflowException(
                 "A teardown task or taskgroup cannot be nested inside another"
@@ -51,7 +52,9 @@ class SetupTeardownContext:
             )
 
         cls.is_teardown = True
+        cls.on_failure_fail_dagrun = on_failure_fail_dagrun
         try:
             yield
         finally:
             cls.is_teardown = False
+            cls.on_failure_fail_dagrun = False
diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py
index 85249d3a98..a7d45e0e44 100644
--- a/airflow/utils/task_group.py
+++ b/airflow/utils/task_group.py
@@ -94,16 +94,20 @@ class TaskGroup(DAGNode):
         add_suffix_on_collision: bool = False,
         setup: bool = False,
         teardown: bool = False,
+        on_failure_fail_dagrun: bool = False,
     ):
         from airflow.models.dag import DagContext
 
         if setup and teardown:
             raise AirflowException("Cannot set both setup and teardown to 
True")
+        if on_failure_fail_dagrun and not teardown:
+            raise AirflowException("on_failure_fail_dagrun can only be set to 
True if teardown is True")
 
         self.prefix_group_id = prefix_group_id
         self.default_args = copy.deepcopy(default_args or {})
         self.setup = setup
         self.teardown = teardown
+        self.on_failure_fail_dagrun = on_failure_fail_dagrun
 
         dag = dag or DagContext.get_current_dag()
 
@@ -246,6 +250,8 @@ class TaskGroup(DAGNode):
         elif SetupTeardownContext.is_teardown or is_teardown:
             if isinstance(task, AbstractOperator):
                 setattr(task, "_is_teardown", True)
+                if SetupTeardownContext.on_failure_fail_dagrun or 
self.on_failure_fail_dagrun:
+                    setattr(task, "_on_failure_fail_dagrun", True)
 
         self.children[key] = task
 
diff --git a/tests/decorators/test_setup_teardown.py 
b/tests/decorators/test_setup_teardown.py
index c6873a98bf..26eea29a29 100644
--- a/tests/decorators/test_setup_teardown.py
+++ b/tests/decorators/test_setup_teardown.py
@@ -65,6 +65,7 @@ class TestSetupTearDownTask:
         assert setup_task._is_setup
 
     def test_marking_operator_as_setup_task(self, dag_maker):
+
         with dag_maker() as dag:
             BashOperator.as_setup(task_id="mytask", bash_command='echo "I am a 
setup task"')
 
@@ -86,7 +87,6 @@ class TestSetupTearDownTask:
         assert teardown_task._is_teardown
 
     def test_marking_operator_as_teardown_task(self, dag_maker):
-        from airflow.operators.bash import BashOperator
 
         with dag_maker() as dag:
             BashOperator.as_teardown(task_id="mytask", bash_command='echo "I 
am a setup task"')
@@ -454,3 +454,71 @@ class TestSetupTearDownTask:
         with dag_maker():
             with pytest.raises(AirflowException, match="Cannot set both setup 
and teardown to True"):
                 TaskGroup("mygroup", setup=True, teardown=True)
+
+    def test_cannot_use_on_failure_fail_dagrun_without_teardown(self, 
dag_maker):
+        """Test that on_failure_fail_dagrun can only be used with teardown"""
+        with dag_maker():
+            with pytest.raises(
+                AirflowException, match="on_failure_fail_dagrun can only be 
set to True if teardown is True"
+            ):
+                TaskGroup("mygroup", on_failure_fail_dagrun=True)
+
+    @pytest.mark.parametrize("on_failure_fail_dagrun", [True, False])
+    def test_teardown_task_decorators_works_with_on_failure_fail_dagrun(
+        self, on_failure_fail_dagrun, dag_maker
+    ):
+        @teardown(on_failure_fail_dagrun=on_failure_fail_dagrun)
+        def mytask():
+            print("I am a teardown task")
+
+        with dag_maker() as dag:
+            mytask()
+        teardown_task = dag.task_group.children["mytask"]
+        assert teardown_task._is_teardown
+        assert teardown_task._on_failure_fail_dagrun is on_failure_fail_dagrun
+        assert len(dag.task_group.children) == 1
+
+    @pytest.mark.parametrize("on_failure_fail_dagrun", [True, False])
+    def test_classic_teardown_task_works_with_on_failure_fail_dagrun(self, 
on_failure_fail_dagrun, dag_maker):
+        with dag_maker() as dag:
+            BashOperator.as_teardown(
+                task_id="mytask",
+                bash_command='echo "I am a teardown task"',
+                on_failure_fail_dagrun=on_failure_fail_dagrun,
+            )
+
+        teardown_task = dag.task_group.children["mytask"]
+        assert teardown_task._is_teardown
+        assert teardown_task._on_failure_fail_dagrun is on_failure_fail_dagrun
+        assert len(dag.task_group.children) == 1
+
+    @pytest.mark.parametrize("on_failure_fail_dagrun", [True, False])
+    def test_teardown_taskgroup_classic_works_with_on_failure_fail_dagrun(
+        self, on_failure_fail_dagrun, dag_maker
+    ):
+        with dag_maker() as dag:
+            with TaskGroup("mygroup", teardown=True, 
on_failure_fail_dagrun=on_failure_fail_dagrun):
+                BashOperator(task_id="mytask", bash_command="echo 1")
+
+        teardown_task = 
dag.task_group.children["mygroup"].children["mygroup.mytask"]
+        assert teardown_task._is_teardown
+        assert teardown_task._on_failure_fail_dagrun is on_failure_fail_dagrun
+
+    @pytest.mark.parametrize("on_failure_fail_dagrun", [True, False])
+    def test_teardown_taskgroup_decorator_works_with_on_failure_fail_dagrun(
+        self, on_failure_fail_dagrun, dag_maker
+    ):
+        with dag_maker() as dag:
+
+            @task_group(teardown=True, 
on_failure_fail_dagrun=on_failure_fail_dagrun)
+            def mygroup():
+                @task
+                def mytask():
+                    print(1)
+
+                mytask()
+
+            mygroup()
+        teardown_task = 
dag.task_group.children["mygroup"].children["mygroup.mytask"]
+        assert teardown_task._is_teardown
+        assert teardown_task._on_failure_fail_dagrun is on_failure_fail_dagrun
diff --git a/tests/serialization/test_dag_serialization.py 
b/tests/serialization/test_dag_serialization.py
index 6e620b2aca..a044213932 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -38,6 +38,7 @@ from kubernetes.client import models as k8s
 
 import airflow
 from airflow.datasets import Dataset
+from airflow.decorators import teardown
 from airflow.exceptions import AirflowException, SerializationError
 from airflow.hooks.base import BaseHook
 from airflow.kubernetes.pod_generator import PodGenerator
@@ -48,6 +49,7 @@ from airflow.models.mappedoperator import MappedOperator
 from airflow.models.param import Param, ParamsDict
 from airflow.models.xcom import XCOM_RETURN_KEY, XCom
 from airflow.operators.bash import BashOperator
+from airflow.operators.empty import EmptyOperator
 from airflow.security import permissions
 from airflow.sensors.bash import BashSensor
 from airflow.serialization.json_schema import load_dag_schema_dict
@@ -161,6 +163,7 @@ serialized_simple_dag_ground_truth = {
                 "pool": "default_pool",
                 "_is_setup": False,
                 "_is_teardown": False,
+                "_on_failure_fail_dagrun": False,
                 "executor_config": {
                     "__type": "dict",
                     "__var": {
@@ -192,6 +195,7 @@ serialized_simple_dag_ground_truth = {
                 "pool": "default_pool",
                 "_is_setup": False,
                 "_is_teardown": False,
+                "_on_failure_fail_dagrun": False,
             },
         ],
         "schedule_interval": {"__type": "timedelta", "__var": 86400.0},
@@ -1260,7 +1264,6 @@ class TestStringifiedDAGs:
         """
         Test TaskGroup serialization/deserialization.
         """
-        from airflow.operators.empty import EmptyOperator
 
         execution_date = datetime(2020, 1, 1)
         with DAG("test_task_group_serialization", start_date=execution_date) 
as dag:
@@ -1317,7 +1320,6 @@ class TestStringifiedDAGs:
         """
         Test TaskGroup setup and teardown task serialization/deserialization.
         """
-        from airflow.operators.empty import EmptyOperator
 
         execution_date = datetime(2020, 1, 1)
         with DAG("test_task_group_setup_teardown_tasks", 
start_date=execution_date) as dag:
@@ -1376,7 +1378,6 @@ class TestStringifiedDAGs:
         Test TaskGroup setup and teardown taskgroup 
serialization/deserialization.
         """
         from airflow.decorators import setup, task_group, teardown
-        from airflow.operators.empty import EmptyOperator
 
         execution_date = datetime(2020, 1, 1)
         with DAG("test_task_group_setup_teardown_task_groups", 
start_date=execution_date) as dag:
@@ -1435,6 +1436,43 @@ class TestStringifiedDAGs:
             se_teardown_group.children["teardown_group.teardown1"], 
is_teardown=True
         )
 
+    def test_teardown_task_on_failure_fail_dagrun_serialization(self, 
dag_maker):
+        with dag_maker() as dag:
+
+            @teardown(on_failure_fail_dagrun=True)
+            def mytask():
+                print(1)
+
+            mytask()
+
+        dag_dict = SerializedDAG.to_dict(dag)
+        SerializedDAG.validate_schema(dag_dict)
+        json_dag = SerializedDAG.from_json(SerializedDAG.to_json(dag))
+        self.validate_deserialized_dag(json_dag, dag)
+
+        serialized_dag = 
SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag))
+        task = serialized_dag.task_group.children["mytask"]
+        assert task._is_teardown
+        assert task._on_failure_fail_dagrun
+
+    @pytest.mark.parametrize("on_failure_fail_dagrun", [True, False])
+    def test_teardown_task_on_failure_fail_dagrun_serialization_taskgroup(
+        self, dag_maker, on_failure_fail_dagrun
+    ):
+        with dag_maker() as dag:
+            with TaskGroup("mygroup", teardown=True, 
on_failure_fail_dagrun=on_failure_fail_dagrun):
+                EmptyOperator(task_id="mytask")
+
+        dag_dict = SerializedDAG.to_dict(dag)
+        SerializedDAG.validate_schema(dag_dict)
+        json_dag = SerializedDAG.from_json(SerializedDAG.to_json(dag))
+        self.validate_deserialized_dag(json_dag, dag)
+
+        serialized_dag = 
SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag))
+        task = 
serialized_dag.task_group.children["mygroup"].children["mygroup.mytask"]
+        assert task._is_teardown
+        assert task._on_failure_fail_dagrun is on_failure_fail_dagrun
+
     def test_deps_sorted(self):
         """
         Tests serialize_operator, make sure the deps is in order

Reply via email to