This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 016ce99486 Change `as_setup` and `as_teardown` to instance methods 
(#32053)
016ce99486 is described below

commit 016ce9948625a556093b0182439aa50314c651da
Author: Daniel Standish <[email protected]>
AuthorDate: Mon Jun 26 10:32:25 2023 -0700

    Change `as_setup` and `as_teardown` to instance methods (#32053)
    
    This provides a number of benefits.
    * provides a one-line syntax for setting setup / teardown deps
    * makes it easy to convert dags to use feature
    * provides a mechanism to combine "reusable" taskflow tasks with setup / 
teardown
    * set setup and teardown in the same place you set deps
    
    ---------
    
    Co-authored-by: Ephraim Anierobi <[email protected]>
---
 airflow/example_dags/example_setup_teardown.py     |  16 +-
 .../example_setup_teardown_taskflow.py             |  65 ++++--
 airflow/models/abstractoperator.py                 |  94 +++++++-
 airflow/models/baseoperator.py                     |  35 ---
 airflow/models/mappedoperator.py                   |  26 ++-
 airflow/models/taskmixin.py                        |  15 ++
 airflow/models/xcom_arg.py                         |  53 ++++-
 airflow/serialization/serialized_objects.py        |   3 +-
 tests/decorators/test_setup_teardown.py            |  63 +++---
 tests/models/test_dag.py                           |  16 +-
 tests/models/test_taskinstance.py                  |   2 +-
 tests/models/test_taskmixin.py                     | 248 +++++++++++++++++++++
 tests/serialization/test_dag_serialization.py      |  16 +-
 tests/ti_deps/deps/test_trigger_rule_dep.py        |   2 +-
 14 files changed, 534 insertions(+), 120 deletions(-)

diff --git a/airflow/example_dags/example_setup_teardown.py 
b/airflow/example_dags/example_setup_teardown.py
index 77d7d5bdc6..59aba9753a 100644
--- a/airflow/example_dags/example_setup_teardown.py
+++ b/airflow/example_dags/example_setup_teardown.py
@@ -30,21 +30,19 @@ with DAG(
     catchup=False,
     tags=["example"],
 ) as dag:
-    root_setup = BashOperator.as_setup(task_id="root_setup", 
bash_command="echo 'Hello from root_setup'")
+    root_setup = BashOperator(task_id="root_setup", bash_command="echo 'Hello 
from root_setup'").as_setup()
     root_normal = BashOperator(task_id="normal", bash_command="echo 'I am just 
a normal task'")
-    root_teardown = BashOperator.as_teardown(
+    root_teardown = BashOperator(
         task_id="root_teardown", bash_command="echo 'Goodbye from 
root_teardown'"
-    )
+    ).as_teardown(setups=root_setup)
     root_setup >> root_normal >> root_teardown
-    root_setup >> root_teardown
     with TaskGroup("section_1") as section_1:
-        inner_setup = BashOperator.as_setup(
+        inner_setup = BashOperator(
             task_id="taskgroup_setup", bash_command="echo 'Hello from 
taskgroup_setup'"
-        )
+        ).as_setup()
         inner_normal = BashOperator(task_id="normal", bash_command="echo 'I am 
just a normal task'")
-        inner_teardown = BashOperator.as_teardown(
+        inner_teardown = BashOperator(
             task_id="taskgroup_teardown", bash_command="echo 'Hello from 
taskgroup_teardown'"
-        )
+        ).as_teardown(setups=inner_setup)
         inner_setup >> inner_normal >> inner_teardown
-        inner_setup >> inner_teardown
     root_normal >> section_1
diff --git a/airflow/example_dags/example_setup_teardown_taskflow.py 
b/airflow/example_dags/example_setup_teardown_taskflow.py
index 245cc6a2e9..128534f1d2 100644
--- a/airflow/example_dags/example_setup_teardown_taskflow.py
+++ b/airflow/example_dags/example_setup_teardown_taskflow.py
@@ -29,30 +29,61 @@ with DAG(
     catchup=False,
     tags=["example"],
 ) as dag:
-    # You can use the setup and teardown decorators to add setup and teardown 
tasks at the DAG level
-    @setup
+
     @task
-    def root_setup():
-        print("Hello from root_setup")
+    def task_1():
+        print("Hello 1")
 
-    @teardown
     @task
-    def root_teardown():
-        print("Goodbye from root_teardown")
+    def task_2():
+        print("Hello 2")
+
+    @task
+    def task_3():
+        print("Hello 3")
+
+    # you can set setup / teardown relationships with the `as_teardown` method.
+    t1 = task_1()
+    t2 = task_2()
+    t3 = task_3()
+    t1 >> t2 >> t3.as_teardown(setups=t1)
+
+    # the method `as_teardown` will mark t3 as teardown, t1 as setup, and 
arrow t1 >> t3
+    # now if you clear t2 (downstream), then t1 will be cleared in addition to 
t3
+
+    # it's also possible to use a decorator to mark a task as setup or
+    # teardown when you define it. see below.
+
+    @setup
+    def dag_setup():
+        print("I am dag_setup")
+
+    @teardown
+    def dag_teardown():
+        print("I am dag_teardown")
 
     @task
-    def normal():
+    def dag_normal_task():
         print("I am just a normal task")
 
+    s = dag_setup()
+    t = dag_teardown()
+
+    # by using the decorators, dag_setup and dag_teardown are already marked 
as setup / teardown
+    # now we just need to make sure they are linked directly
+    # what we need to do is this::
+    #     s >> t
+    #     s >> dag_normal_task() >> t
+    # but we can use a context manager to make it cleaner
+    with s >> t:
+        dag_normal_task()
+
     @task_group
     def section_1():
-        # You can also have setup and teardown tasks at the task group level
-        @setup
         @task
         def my_setup():
             print("I set up")
 
-        @teardown
         @task
         def my_teardown():
             print("I tear down")
@@ -61,13 +92,7 @@ with DAG(
         def hello():
             print("I say hello")
 
-        s = my_setup()
-        w = hello()
-        t = my_teardown()
-        s >> w >> t
-        s >> t
+        (s := my_setup()) >> hello() >> my_teardown().as_teardown(setups=s)
 
-    rs = root_setup()
-    normal() >> section_1()
-    rt = root_teardown()
-    rs >> rt
+    # and let's put section 1 inside the "dag setup" and "dag teardown"
+    s >> section_1() >> t
diff --git a/airflow/models/abstractoperator.py 
b/airflow/models/abstractoperator.py
index bf06b39da9..ff4f5c4140 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -26,7 +26,7 @@ from airflow.compat.functools import cache
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.models.expandinput import NotFullyPopulated
-from airflow.models.taskmixin import DAGNode
+from airflow.models.taskmixin import DAGNode, DependencyMixin
 from airflow.template.templater import Templater
 from airflow.utils.context import Context
 from airflow.utils.log.secrets_masker import redact
@@ -35,6 +35,7 @@ from airflow.utils.sqlalchemy import skip_locked, 
with_row_locks
 from airflow.utils.state import State, TaskInstanceState
 from airflow.utils.task_group import MappedTaskGroup
 from airflow.utils.trigger_rule import TriggerRule
+from airflow.utils.types import NOTSET, ArgNotSet
 from airflow.utils.weight_rule import WeightRule
 
 TaskStateChangeCallback = Callable[[Context], None]
@@ -102,6 +103,11 @@ class AbstractOperator(Templater, DAGNode):
 
     outlets: list
     inlets: list
+    trigger_rule: TriggerRule
+
+    _is_setup = False
+    _is_teardown = False
+    _on_failure_fail_dagrun = False
 
     HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset(
         (
@@ -149,6 +155,92 @@ class AbstractOperator(Templater, DAGNode):
     def node_id(self) -> str:
         return self.task_id
 
+    @property
+    def is_setup(self):
+        """
+        Whether the operator is a setup task.
+
+        :meta private:
+        """
+        return self._is_setup
+
+    @is_setup.setter
+    def is_setup(self, value):
+        """
+        Setter for is_setup property.
+
+        :meta private:
+        """
+        if self.is_teardown is True and value is True:
+            raise ValueError(f"Cannot mark task '{self.task_id}' as setup; 
task is already a teardown.")
+        self._is_setup = value
+
+    @property
+    def is_teardown(self):
+        """
+        Whether the operator is a teardown task.
+
+        :meta private:
+        """
+        return self._is_teardown
+
+    @is_teardown.setter
+    def is_teardown(self, value):
+        """
+        Setter for is_teardown property.
+
+        :meta private:
+        """
+        if self.is_setup is True and value is True:
+            raise ValueError(f"Cannot mark task '{self.task_id}' as teardown; 
task is already a setup.")
+        self._is_teardown = value
+
+    @property
+    def on_failure_fail_dagrun(self):
+        """
+        Whether the operator should fail the dagrun on failure.
+
+        :meta private:
+        """
+        return self._on_failure_fail_dagrun
+
+    @on_failure_fail_dagrun.setter
+    def on_failure_fail_dagrun(self, value):
+        """
+        Setter for on_failure_fail_dagrun property.
+
+        :meta private:
+        """
+        if value is True and self.is_teardown is not True:
+            raise ValueError(
+                f"Cannot set task on_failure_fail_dagrun for "
+                f"'{self.task_id}' because it is not a teardown task."
+            )
+        self._on_failure_fail_dagrun = value
+
+    def as_setup(self):
+        self.is_setup = True
+        return self
+
+    def as_teardown(
+        self,
+        *,
+        setups: BaseOperator | Iterable[BaseOperator] | ArgNotSet = NOTSET,
+        on_failure_fail_dagrun=NOTSET,
+    ):
+        self.is_teardown = True
+        if TYPE_CHECKING:
+            assert isinstance(self, BaseOperator)  # is_teardown not supported 
for MappedOperator
+        self.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS
+        if on_failure_fail_dagrun is not NOTSET:
+            self.on_failure_fail_dagrun = on_failure_fail_dagrun
+        if not isinstance(setups, ArgNotSet):
+            setups = [setups] if isinstance(setups, DependencyMixin) else 
setups
+            for s in setups:
+                s.is_setup = True
+                s >> self
+        return self
+
     def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
         """Get direct relative IDs to the current task, upstream or 
downstream."""
         if upstream:
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 79d4637387..3d0812b62a 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -721,25 +721,6 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
     # Set to True for an operator instantiated by a mapped operator.
     __from_mapped = False
 
-    is_setup = False
-    """
-    Whether the operator is a setup task
-
-    :meta private:
-    """
-    is_teardown = False
-    """
-    Whether the operator is a teardown task
-
-    :meta private:
-    """
-    on_failure_fail_dagrun = False
-    """
-    Whether the operator should fail the dagrun on failure
-
-    :meta private:
-    """
-
     def __init__(
         self,
         task_id: str,
@@ -976,22 +957,6 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
         if SetupTeardownContext.active:
             SetupTeardownContext.update_context_map(self)
 
-    @classmethod
-    def as_setup(cls, *args, **kwargs):
-        op = cls(*args, **kwargs)
-        op.is_setup = True
-        return op
-
-    @classmethod
-    def as_teardown(cls, *args, **kwargs):
-        on_failure_fail_dagrun = kwargs.pop("on_failure_fail_dagrun", False)
-        if "trigger_rule" in kwargs:
-            raise ValueError("Cannot set trigger rule for teardown tasks.")
-        op = cls(*args, **kwargs, 
trigger_rule=TriggerRule.ALL_DONE_SETUP_SUCCESS)
-        op.is_teardown = True
-        op.on_failure_fail_dagrun = on_failure_fail_dagrun
-        return op
-
     def __enter__(self):
         if not self.is_setup and not self.is_teardown:
             raise AirflowException("Only setup/teardown tasks can be used as 
context managers.")
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index 66b75923fd..dd8b49fb98 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -290,9 +290,6 @@ class MappedOperator(AbstractOperator):
 
     subdag: None = None  # Since we don't support SubDagOperator, this is 
always None.
     supports_lineage: bool = False
-    is_setup: bool = False
-    is_teardown: bool = False
-    on_failure_fail_dagrun: bool = False
 
     HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = 
AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset(
         (
@@ -327,6 +324,24 @@ class MappedOperator(AbstractOperator):
                 f"{self.task_id!r}."
             )
 
+    @AbstractOperator.is_setup.setter  # type: ignore[attr-defined]
+    def is_setup(self, value):
+        """
+        Setter for is_setup property. Disabled for MappedOperator.
+
+        :meta private:
+        """
+        raise ValueError("Cannot set is_setup for mapped operator.")
+
+    @AbstractOperator.is_teardown.setter  # type: ignore[attr-defined]
+    def is_teardown(self, value):
+        """
+        Setter for is_teardown property. Disabled for MappedOperator.
+
+        :meta private:
+        """
+        raise ValueError("Cannot set is_teardown for mapped operator.")
+
     @classmethod
     @cache
     def get_serialized_fields(cls):
@@ -391,6 +406,11 @@ class MappedOperator(AbstractOperator):
     def trigger_rule(self) -> TriggerRule:
         return self.partial_kwargs.get("trigger_rule", DEFAULT_TRIGGER_RULE)
 
+    @trigger_rule.setter
+    def trigger_rule(self, value):
+        # required for mypy which complains about overriding writeable attr 
with read-only property
+        raise ValueError("Cannot set trigger_rule for mapped operator.")
+
     @property
     def depends_on_past(self) -> bool:
         return bool(self.partial_kwargs.get("depends_on_past"))
diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py
index 0c1c94b7b8..38d050e29e 100644
--- a/airflow/models/taskmixin.py
+++ b/airflow/models/taskmixin.py
@@ -24,10 +24,12 @@ import pendulum
 
 from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
 from airflow.serialization.enums import DagAttributeTypes
+from airflow.utils.types import NOTSET, ArgNotSet
 
 if TYPE_CHECKING:
     from logging import Logger
 
+    from airflow.models.baseoperator import BaseOperator
     from airflow.models.dag import DAG
     from airflow.models.operator import Operator
     from airflow.utils.edgemodifier import EdgeModifier
@@ -69,6 +71,19 @@ class DependencyMixin:
         """Set a task or a task list to be directly downstream from the 
current task."""
         raise NotImplementedError()
 
+    def as_setup(self) -> DependencyMixin:
+        """Mark a task as setup task."""
+        raise NotImplementedError()
+
+    def as_teardown(
+        self,
+        *,
+        setups: BaseOperator | Iterable[BaseOperator] | ArgNotSet = NOTSET,
+        on_failure_fail_dagrun=NOTSET,
+    ) -> DependencyMixin:
+        """Mark a task as teardown and set its setups as direct relatives."""
+        raise NotImplementedError()
+
     def update_relative(
         self, other: DependencyMixin, upstream: bool = True, edge_modifier: 
EdgeModifier | None = None
     ) -> None:
diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index 72cd2278f4..7024fbd8a9 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -19,13 +19,14 @@ from __future__ import annotations
 
 import contextlib
 import inspect
-from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping, Sequence, 
Union, overload
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Mapping, 
Sequence, Union, overload
 
 from sqlalchemy import func, or_
 from sqlalchemy.orm import Session
 
 from airflow.exceptions import AirflowException, XComNotFound
 from airflow.models.abstractoperator import AbstractOperator
+from airflow.models.baseoperator import BaseOperator
 from airflow.models.mappedoperator import MappedOperator
 from airflow.models.taskmixin import DAGNode, DependencyMixin
 from airflow.utils.context import Context
@@ -34,6 +35,7 @@ from airflow.utils.mixins import ResolveMixin
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.setup_teardown import SetupTeardownContext
 from airflow.utils.state import State
+from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.types import NOTSET, ArgNotSet
 from airflow.utils.xcom import XCOM_RETURN_KEY
 
@@ -296,6 +298,55 @@ class PlainXComArg(XComArg):
     def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
         return cls(dag.get_task(data["task_id"]), data["key"])
 
+    @property
+    def is_setup(self) -> bool:
+        return self.operator.is_setup
+
+    @is_setup.setter
+    def is_setup(self, val: bool):
+        self.operator.is_setup = val
+
+    @property
+    def is_teardown(self) -> bool:
+        return self.operator.is_teardown
+
+    @is_teardown.setter
+    def is_teardown(self, val: bool):
+        self.operator.is_teardown = val
+
+    @property
+    def on_failure_fail_dagrun(self) -> bool:
+        return self.operator.on_failure_fail_dagrun
+
+    @on_failure_fail_dagrun.setter
+    def on_failure_fail_dagrun(self, val: bool):
+        self.operator.on_failure_fail_dagrun = val
+
+    def as_setup(self) -> DependencyMixin:
+        for operator, _ in self.iter_references():
+            operator.is_setup = True
+        return self
+
+    def as_teardown(
+        self,
+        *,
+        setups: BaseOperator | Iterable[BaseOperator] | ArgNotSet = NOTSET,
+        on_failure_fail_dagrun=NOTSET,
+    ):
+        for operator, _ in self.iter_references():
+            operator.is_teardown = True
+            if TYPE_CHECKING:
+                assert isinstance(operator, BaseOperator)  # Can't set 
MappedOperator as teardown
+            operator.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS
+            if on_failure_fail_dagrun is not NOTSET:
+                operator.on_failure_fail_dagrun = on_failure_fail_dagrun
+            if not isinstance(setups, ArgNotSet):
+                setups = [setups] if isinstance(setups, DependencyMixin) else 
setups
+                for s in setups:
+                    s.is_setup = True
+                    s >> operator
+        return self
+
     def iter_references(self) -> Iterator[tuple[Operator, str]]:
         yield self.operator, self.key
 
diff --git a/airflow/serialization/serialized_objects.py 
b/airflow/serialization/serialized_objects.py
index 3a528e9c8a..8e53aa3465 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -960,7 +960,8 @@ class SerializedBaseOperator(BaseOperator, 
BaseSerialization):
                 v = cls.deserialize(v)
             elif k in ("outlets", "inlets"):
                 v = cls.deserialize(v)
-
+            elif k == "on_failure_fail_dagrun":
+                k = "_on_failure_fail_dagrun"
             # else use v as it is
 
             setattr(op, k, v)
diff --git a/tests/decorators/test_setup_teardown.py 
b/tests/decorators/test_setup_teardown.py
index d91f7fac53..8b7f761798 100644
--- a/tests/decorators/test_setup_teardown.py
+++ b/tests/decorators/test_setup_teardown.py
@@ -65,7 +65,7 @@ class TestSetupTearDownTask:
 
     def test_marking_operator_as_setup_task(self, dag_maker):
         with dag_maker() as dag:
-            BashOperator.as_setup(task_id="mytask", bash_command='echo "I am a 
setup task"')
+            BashOperator(task_id="mytask", bash_command='echo "I am a setup 
task"').as_setup()
 
         assert len(dag.task_group.children) == 1
         setup_task = dag.task_group.children["mytask"]
@@ -86,7 +86,7 @@ class TestSetupTearDownTask:
 
     def test_marking_operator_as_teardown_task(self, dag_maker):
         with dag_maker() as dag:
-            BashOperator.as_teardown(task_id="mytask", bash_command='echo "I 
am a setup task"')
+            BashOperator(task_id="mytask", bash_command='echo "I am a setup 
task"').as_teardown()
 
         assert len(dag.task_group.children) == 1
         teardown_task = dag.task_group.children["mytask"]
@@ -146,11 +146,10 @@ class TestSetupTearDownTask:
     @pytest.mark.parametrize("on_failure_fail_dagrun", [True, False])
     def test_classic_teardown_task_works_with_on_failure_fail_dagrun(self, 
on_failure_fail_dagrun, dag_maker):
         with dag_maker() as dag:
-            BashOperator.as_teardown(
+            BashOperator(
                 task_id="mytask",
                 bash_command='echo "I am a teardown task"',
-                on_failure_fail_dagrun=on_failure_fail_dagrun,
-            )
+            ).as_teardown(on_failure_fail_dagrun=on_failure_fail_dagrun)
 
         teardown_task = dag.task_group.children["mytask"]
         assert teardown_task.is_teardown
@@ -605,11 +604,11 @@ class TestSetupTearDownTask:
             print("mytask")
 
         with dag_maker() as dag:
-            setuptask = BashOperator.as_setup(task_id="setuptask", 
bash_command="echo 1")
-            setuptask2 = BashOperator.as_setup(task_id="setuptask2", 
bash_command="echo 1")
+            setuptask = BashOperator(task_id="setuptask", bash_command="echo 
1").as_setup()
+            setuptask2 = BashOperator(task_id="setuptask2", bash_command="echo 
1").as_setup()
 
-            teardowntask = BashOperator.as_teardown(task_id="teardowntask", 
bash_command="echo 1")
-            teardowntask2 = BashOperator.as_teardown(task_id="teardowntask2", 
bash_command="echo 1")
+            teardowntask = BashOperator(task_id="teardowntask", 
bash_command="echo 1").as_teardown()
+            teardowntask2 = BashOperator(task_id="teardowntask2", 
bash_command="echo 1").as_teardown()
             with setuptask >> teardowntask:
                 with setuptask2 >> teardowntask2:
                     mytask() >> mytask2()
@@ -643,11 +642,11 @@ class TestSetupTearDownTask:
             print("mytask")
 
         with dag_maker() as dag:
-            setuptask = BashOperator.as_setup(task_id="setuptask", 
bash_command="echo 1")
-            setuptask2 = BashOperator.as_setup(task_id="setuptask2", 
bash_command="echo 1")
+            setuptask = BashOperator(task_id="setuptask", bash_command="echo 
1").as_setup()
+            setuptask2 = BashOperator(task_id="setuptask2", bash_command="echo 
1").as_setup()
 
-            teardowntask = BashOperator.as_teardown(task_id="teardowntask", 
bash_command="echo 1")
-            teardowntask2 = BashOperator.as_teardown(task_id="teardowntask2", 
bash_command="echo 1")
+            teardowntask = BashOperator(task_id="teardowntask", 
bash_command="echo 1").as_teardown()
+            teardowntask2 = BashOperator(task_id="teardowntask2", 
bash_command="echo 1").as_teardown()
             with setuptask >> teardowntask:
                 with setuptask2 >> teardowntask2:
                     mytask() << mytask2()
@@ -676,7 +675,7 @@ class TestSetupTearDownTask:
             print("mytask")
 
         with dag_maker() as dag:
-            setuptask = BashOperator.as_setup(task_id="setuptask", 
bash_command="echo 1")
+            setuptask = BashOperator(task_id="setuptask", bash_command="echo 
1").as_setup()
             with setuptask:
                 mytask() >> mytask2()
 
@@ -698,7 +697,7 @@ class TestSetupTearDownTask:
             print("mytask")
 
         with dag_maker("foo") as dag:
-            teardowntask = BashOperator.as_teardown(task_id="teardowntask", 
bash_command="echo 1")
+            teardowntask = BashOperator(task_id="teardowntask", 
bash_command="echo 1").as_teardown()
             with teardowntask:
                 mytask() >> mytask2()
 
@@ -720,10 +719,10 @@ class TestSetupTearDownTask:
             print("mytask")
 
         with dag_maker() as dag:
-            setuptask = BashOperator.as_setup(task_id="setuptask", 
bash_command="echo 1")
-            setuptask2 = BashOperator.as_setup(task_id="setuptask2", 
bash_command="echo 1")
+            setuptask = BashOperator(task_id="setuptask", bash_command="echo 
1").as_setup()
+            setuptask2 = BashOperator(task_id="setuptask2", bash_command="echo 
1").as_setup()
 
-            teardowntask = BashOperator.as_teardown(task_id="teardowntask", 
bash_command="echo 1")
+            teardowntask = BashOperator(task_id="teardowntask", 
bash_command="echo 1").as_teardown()
             with setuptask >> teardowntask:
                 with setuptask2:
                     mytask() << mytask2()
@@ -758,8 +757,8 @@ class TestSetupTearDownTask:
             print("mytask")
 
         with dag_maker() as dag:
-            setuptask = BashOperator.as_setup(task_id="setuptask", 
bash_command="echo 1")
-            setuptask2 = BashOperator.as_setup(task_id="setuptask2", 
bash_command="echo 1")
+            setuptask = BashOperator(task_id="setuptask", bash_command="echo 
1").as_setup()
+            setuptask2 = BashOperator(task_id="setuptask2", bash_command="echo 
1").as_setup()
             with setuptask:
                 t1 = mytask()
                 t2 = mytask2()
@@ -801,8 +800,8 @@ class TestSetupTearDownTask:
             print("mytask")
 
         with dag_maker() as dag:
-            setuptask = BashOperator.as_setup(task_id="setuptask", 
bash_command="echo 1")
-            setuptask2 = BashOperator.as_setup(task_id="setuptask2", 
bash_command="echo 1")
+            setuptask = BashOperator(task_id="setuptask", bash_command="echo 
1").as_setup()
+            setuptask2 = BashOperator(task_id="setuptask2", bash_command="echo 
1").as_setup()
             with setuptask:
                 t1 = mytask()
                 t2 = mytask2()
@@ -841,11 +840,11 @@ class TestSetupTearDownTask:
             print("mytask")
 
         with dag_maker() as dag:
-            setuptask = BashOperator.as_setup(task_id="setuptask", 
bash_command="echo 1")
-            setuptask2 = BashOperator.as_setup(task_id="setuptask2", 
bash_command="echo 1")
+            setuptask = BashOperator(task_id="setuptask", bash_command="echo 
1").as_setup()
+            setuptask2 = BashOperator(task_id="setuptask2", bash_command="echo 
1").as_setup()
 
-            teardowntask = BashOperator.as_teardown(task_id="teardowntask", 
bash_command="echo 1")
-            teardowntask2 = BashOperator.as_teardown(task_id="teardowntask2", 
bash_command="echo 1")
+            teardowntask = BashOperator(task_id="teardowntask", 
bash_command="echo 1").as_teardown()
+            teardowntask2 = BashOperator(task_id="teardowntask2", 
bash_command="echo 1").as_teardown()
             with setuptask >> teardowntask:
                 with setuptask2 >> teardowntask2:
                     mytask()
@@ -1047,9 +1046,9 @@ class TestSetupTearDownTask:
             print("mytask")
 
         with dag_maker() as dag:
-            teardowntask = BashOperator.as_teardown(task_id="teardowntask", 
bash_command="echo 1")
-            teardowntask2 = BashOperator.as_teardown(task_id="teardowntask2", 
bash_command="echo 1")
-            setuptask = BashOperator.as_setup(task_id="setuptask", 
bash_command="echo 1")
+            teardowntask = BashOperator(task_id="teardowntask", 
bash_command="echo 1").as_teardown()
+            teardowntask2 = BashOperator(task_id="teardowntask2", 
bash_command="echo 1").as_teardown()
+            setuptask = BashOperator(task_id="setuptask", bash_command="echo 
1").as_setup()
             with [teardowntask, teardowntask2] << setuptask:
                 mytask()
 
@@ -1077,9 +1076,9 @@ class TestSetupTearDownTask:
             print("mytask")
 
         with dag_maker() as dag:
-            teardowntask = BashOperator.as_teardown(task_id="teardowntask", 
bash_command="echo 1")
-            teardowntask2 = BashOperator.as_teardown(task_id="teardowntask2", 
bash_command="echo 1")
-            setuptask = BashOperator.as_setup(task_id="setuptask", 
bash_command="echo 1")
+            teardowntask = BashOperator(task_id="teardowntask", 
bash_command="echo 1").as_teardown()
+            teardowntask2 = BashOperator(task_id="teardowntask2", 
bash_command="echo 1").as_teardown()
+            setuptask = BashOperator(task_id="setuptask", bash_command="echo 
1").as_setup()
             with setuptask >> context_wrapper([teardowntask, teardowntask2]):
                 mytask()
 
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 3c739c115f..9eec01c44a 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -3541,16 +3541,16 @@ class TestTaskClearingSetupTeardownBehavior:
         """
 
         def teardown_task(task_id):
-            return BaseOperator.as_teardown(task_id=task_id)
+            return BaseOperator(task_id=task_id).as_teardown()
 
         def teardown_task_f(task_id):
-            return BaseOperator.as_teardown(task_id=task_id, 
on_failure_fail_dagrun=True)
+            return 
BaseOperator(task_id=task_id).as_teardown(on_failure_fail_dagrun=True)
 
         def work_task(task_id):
             return BaseOperator(task_id=task_id)
 
         def setup_task(task_id):
-            return BaseOperator.as_setup(task_id=task_id)
+            return BaseOperator(task_id=task_id).as_setup()
 
         def make_task(task_id):
             """
@@ -3709,7 +3709,7 @@ class TestTaskClearingSetupTeardownBehavior:
         assert self.cleared_downstream(w1) == {s1, w1, w2, t1}
         assert self.cleared_downstream(w2) == {w2}
         # and if there's a downstream setup, it will be included as well
-        s2 = BaseOperator.as_setup(task_id="s2", dag=dag)
+        s2 = BaseOperator(task_id="s2", dag=dag).as_setup()
         t1 >> s2
         assert w1.get_flat_relative_ids(upstream=False) == {"t1", "w2", "s2"}
         assert self.cleared_downstream(w1) == {s1, w1, w2, t1, s2}
@@ -3755,16 +3755,16 @@ class TestTaskClearingSetupTeardownBehavior:
         """
         dag = DAG(dag_id="test_dag", start_date=pendulum.now())
         with dag:
-            dag_setup = BaseOperator.as_setup(task_id="dag_setup")
-            dag_teardown = BaseOperator.as_teardown(task_id="dag_teardown")
+            dag_setup = BaseOperator(task_id="dag_setup").as_setup()
+            dag_teardown = BaseOperator(task_id="dag_teardown").as_teardown()
             dag_setup >> dag_teardown
             for group_name in ("g1", "g2"):
                 with TaskGroup(group_name) as tg:
-                    group_setup = BaseOperator.as_setup(task_id="group_setup")
+                    group_setup = 
BaseOperator(task_id="group_setup").as_setup()
                     w1 = BaseOperator(task_id="w1")
                     w2 = BaseOperator(task_id="w2")
                     w3 = BaseOperator(task_id="w3")
-                    group_teardown = 
BaseOperator.as_teardown(task_id="group_teardown")
+                    group_teardown = 
BaseOperator(task_id="group_teardown").as_teardown()
                     group_setup >> w1 >> w2 >> w3 >> group_teardown
                     group_setup >> group_teardown
                 dag_setup >> tg >> dag_teardown
diff --git a/tests/models/test_taskinstance.py 
b/tests/models/test_taskinstance.py
index 89aedaf3bb..a9b7ea9b08 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -1304,7 +1304,7 @@ class TestTaskInstance:
                 task = EmptyOperator(task_id=f"work_{i}", dag=dag)
                 task.set_downstream(downstream)
             for i in range(upstream_setups):
-                task = EmptyOperator.as_setup(task_id=f"setup_{i}", dag=dag)
+                task = EmptyOperator(task_id=f"setup_{i}", dag=dag).as_setup()
                 task.set_downstream(downstream)
             assert task.start_date is not None
             run_date = task.start_date + datetime.timedelta(days=5)
diff --git a/tests/models/test_taskmixin.py b/tests/models/test_taskmixin.py
new file mode 100644
index 0000000000..83a040b86e
--- /dev/null
+++ b/tests/models/test_taskmixin.py
@@ -0,0 +1,248 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from itertools import product
+
+import pytest
+
+from airflow.decorators import task
+from airflow.models.baseoperator import BaseOperator
+
+
+def cleared_tasks(dag, task_id):
+    dag_ = dag.partial_subset(task_id, include_downstream=True, include_upstream=False)
+    return {x.task_id for x in dag_.tasks}
+
+
+def get_task_attr(task_like, attr):
+    try:
+        return getattr(task_like, attr)
+    except AttributeError:
+        return getattr(task_like.operator, attr)
+
+
+def make_task(name, type_):
+    if type_ == "classic":
+        return BaseOperator(task_id=name)
+    else:
+
+        @task
+        def my_task():
+            pass
+
+        return my_task.override(task_id=name)()
+
+
+@pytest.mark.parametrize("setup_type, work_type, teardown_type", product(*3 * [["classic", "taskflow"]]))
+def test_as_teardown(dag_maker, setup_type, work_type, teardown_type):
+    """
+    Check that as_teardown works properly as implemented in PlainXComArg
+
+    It should mark the teardown as teardown, and if a task is provided, it should mark that as setup
+    and set it as a direct upstream.
+    """
+    with dag_maker() as dag:
+        s1 = make_task(name="s1", type_=setup_type)
+        w1 = make_task(name="w1", type_=work_type)
+        t1 = make_task(name="t1", type_=teardown_type)
+    # initial conditions
+    assert cleared_tasks(dag, "w1") == {"w1"}
+
+    # after setting deps, still none are setup / teardown
+    # verify relationships
+    s1 >> w1 >> t1
+    assert cleared_tasks(dag, "w1") == {"w1", "t1"}
+    assert get_task_attr(t1, "is_teardown") is False
+    assert get_task_attr(s1, "is_setup") is False
+    assert get_task_attr(t1, "upstream_task_ids") == {"w1"}
+
+    # now when we use as_teardown, s1 should be setup, t1 should be teardown, and we should have s1 >> t1
+    t1.as_teardown(setups=s1)
+    assert cleared_tasks(dag, "w1") == {"s1", "w1", "t1"}
+    assert get_task_attr(t1, "is_teardown") is True
+    assert get_task_attr(s1, "is_setup") is True
+    assert get_task_attr(t1, "upstream_task_ids") == {"w1", "s1"}
+
+
+@pytest.mark.parametrize("setup_type, work_type, teardown_type", product(*3 * [["classic", "taskflow"]]))
+def test_as_teardown_oneline(dag_maker, setup_type, work_type, teardown_type):
+    """
+    Check that as_teardown implementations work properly. Tests all combinations of taskflow and classic.
+
+    It should mark the teardown as teardown, and if a task is provided, it should mark that as setup
+    and set it as a direct upstream.
+    """
+
+    with dag_maker() as dag:
+        s1 = make_task(name="s1", type_=setup_type)
+        w1 = make_task(name="w1", type_=work_type)
+        t1 = make_task(name="t1", type_=teardown_type)
+
+    # verify initial conditions
+    for task_ in (s1, w1, t1):
+        assert get_task_attr(task_, "upstream_list") == []
+        assert get_task_attr(task_, "downstream_list") == []
+        assert get_task_attr(task_, "is_setup") is False
+        assert get_task_attr(task_, "is_teardown") is False
+        assert cleared_tasks(dag, get_task_attr(task_, "task_id")) == {get_task_attr(task_, "task_id")}
+
+    # now set the deps in one line
+    s1 >> w1 >> t1.as_teardown(setups=s1)
+
+    # verify resulting configuration
+    # should be equiv to the following:
+    #   * s1.is_setup = True
+    #   * t1.is_teardown = True
+    #   * s1 >> t1
+    #   * s1 >> w1 >> t1
+    for task_, exp_up, exp_down in [
+        (s1, set(), {"w1", "t1"}),
+        (w1, {"s1"}, {"t1"}),
+        (t1, {"s1", "w1"}, set()),
+    ]:
+        assert get_task_attr(task_, "upstream_task_ids") == exp_up
+        assert get_task_attr(task_, "downstream_task_ids") == exp_down
+    assert cleared_tasks(dag, "s1") == {"s1", "w1", "t1"}
+    assert cleared_tasks(dag, "w1") == {"s1", "w1", "t1"}
+    assert cleared_tasks(dag, "t1") == {"t1"}
+    for task_, exp_is_setup, exp_is_teardown in [
+        (s1, True, False),
+        (w1, False, False),
+        (t1, False, True),
+    ]:
+        assert get_task_attr(task_, "is_setup") is exp_is_setup
+        assert get_task_attr(task_, "is_teardown") is exp_is_teardown
+
+
+@pytest.mark.parametrize("type_", ["classic", "taskflow"])
+def test_cannot_be_both_setup_and_teardown(dag_maker, type_):
+    # can't change a setup task to a teardown task or vice versa
+    for first, second in [("setup", "teardown"), ("teardown", "setup")]:
+        with dag_maker():
+            s1 = make_task(name="s1", type_=type_)
+            getattr(s1, f"as_{first}")()
+            with pytest.raises(
+                ValueError, match=f"Cannot mark task 's1' as {second}; task is already a {first}."
+            ):
+                getattr(s1, f"as_{second}")()
+                s1.as_teardown()
+
+
+def test_cannot_set_on_failure_fail_dagrun_unless_teardown_classic(dag_maker):
+    with dag_maker():
+        t = make_task(name="t", type_="classic")
+        assert t.is_teardown is False
+        with pytest.raises(
+            ValueError,
+            match="Cannot set task on_failure_fail_dagrun for 't' because it is not a teardown task",
+        ):
+            t.on_failure_fail_dagrun = True
+
+
+def test_cannot_set_on_failure_fail_dagrun_unless_teardown_taskflow(dag_maker):
+    @task(on_failure_fail_dagrun=True)
+    def my_bad_task():
+        pass
+
+    @task
+    def my_ok_task():
+        pass
+
+    with dag_maker():
+        with pytest.raises(
+            ValueError,
+            match="Cannot set task on_failure_fail_dagrun for "
+            "'my_bad_task' because it is not a teardown task",
+        ):
+            my_bad_task()
+        # no issue
+        m = my_ok_task()
+        assert m.operator.is_teardown is False
+        # also fine
+        m = my_ok_task().as_teardown()
+        assert m.operator.is_teardown is True
+        assert m.operator.on_failure_fail_dagrun is False
+        # and also fine
+        m = my_ok_task().as_teardown(on_failure_fail_dagrun=True)
+        assert m.operator.is_teardown is True
+        assert m.operator.on_failure_fail_dagrun is True
+        # but we can't unset
+        with pytest.raises(
+            ValueError, match="Cannot mark task 'my_ok_task__2' as setup; task is already a teardown."
+        ):
+            m.as_setup()
+        with pytest.raises(
+            ValueError, match="Cannot mark task 'my_ok_task__2' as setup; task is already a teardown."
+        ):
+            m.operator.is_setup = True
+
+
+def test_no_setup_or_teardown_for_mapped_operator(dag_maker):
+    @task
+    def add_one(x):
+        return x + 1
+
+    @task
+    def print_task(values):
+        print(sum(values))
+
+    # vanilla mapped task
+    with dag_maker():
+        added_vals = add_one.expand(x=[1, 2, 3])
+        print_task(added_vals)
+
+    # combining setup and teardown with vanilla mapped task is fine
+    with dag_maker():
+        s1 = BaseOperator(task_id="s1").as_setup()
+        t1 = BaseOperator(task_id="t1").as_teardown(setups=s1)
+        added_vals = add_one.expand(x=[1, 2, 3])
+        print_task_task = print_task(added_vals)
+        s1 >> added_vals
+        print_task_task >> t1
+    # confirm structure
+    assert s1.downstream_task_ids == {"add_one", "t1"}
+    assert t1.upstream_task_ids == {"print_task", "s1"}
+    assert added_vals.operator.upstream_task_ids == {"s1"}
+    assert added_vals.operator.downstream_task_ids == {"print_task"}
+    assert print_task_task.operator.upstream_task_ids == {"add_one"}
+    assert print_task_task.operator.downstream_task_ids == {"t1"}
+
+    # but you can't use a mapped task as setup or teardown
+    with dag_maker():
+        added_vals = add_one.expand(x=[1, 2, 3])
+        with pytest.raises(ValueError, match="Cannot set is_teardown for mapped operator"):
+            added_vals.as_teardown()
+
+    # ... no matter how hard you try
+    with dag_maker():
+        added_vals = add_one.expand(x=[1, 2, 3])
+        with pytest.raises(ValueError, match="Cannot set is_teardown for mapped operator"):
+            added_vals.is_teardown = True
+
+    # same with setup
+    with dag_maker():
+        added_vals = add_one.expand(x=[1, 2, 3])
+        with pytest.raises(ValueError, match="Cannot set is_setup for mapped operator"):
+            added_vals.as_setup()
+
+    # and again, trying harder...
+    with dag_maker():
+        added_vals = add_one.expand(x=[1, 2, 3])
+        with pytest.raises(ValueError, match="Cannot set is_setup for mapped operator"):
+            added_vals.is_setup = True
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 221d7a3245..69a2c9df22 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1328,18 +1328,18 @@ class TestStringifiedDAGs:
 
         execution_date = datetime(2020, 1, 1)
         with DAG("test_task_group_setup_teardown_tasks", start_date=execution_date) as dag:
-            EmptyOperator.as_setup(task_id="setup")
-            EmptyOperator.as_teardown(task_id="teardown")
+            EmptyOperator(task_id="setup").as_setup()
+            EmptyOperator(task_id="teardown").as_teardown()
 
             with TaskGroup("group1"):
-                EmptyOperator.as_setup(task_id="setup1")
+                EmptyOperator(task_id="setup1").as_setup()
                 EmptyOperator(task_id="task1")
-                EmptyOperator.as_teardown(task_id="teardown1")
+                EmptyOperator(task_id="teardown1").as_teardown()
 
                 with TaskGroup("group2"):
-                    EmptyOperator.as_setup(task_id="setup2")
+                    EmptyOperator(task_id="setup2").as_setup()
                     EmptyOperator(task_id="task2")
-                    EmptyOperator.as_teardown(task_id="teardown2")
+                    EmptyOperator(task_id="teardown2").as_teardown()
 
         dag_dict = SerializedDAG.to_dict(dag)
         SerializedDAG.validate_schema(dag_dict)
@@ -1394,8 +1394,8 @@ class TestStringifiedDAGs:
 
         serialized_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag))
         task = serialized_dag.task_group.children["mytask"]
-        assert task.is_teardown
-        assert task.on_failure_fail_dagrun
+        assert task.is_teardown is True
+        assert task.on_failure_fail_dagrun is True
 
     def test_deps_sorted(self):
         """
diff --git a/tests/ti_deps/deps/test_trigger_rule_dep.py b/tests/ti_deps/deps/test_trigger_rule_dep.py
index 2beaa7097e..faa70b5a49 100644
--- a/tests/ti_deps/deps/test_trigger_rule_dep.py
+++ b/tests/ti_deps/deps/test_trigger_rule_dep.py
@@ -64,7 +64,7 @@ def get_task_instance(monkeypatch, session, dag_maker):
             for task_id in normal_tasks or []:
                 EmptyOperator(task_id=task_id) >> task
             for task_id in setup_tasks or []:
-                EmptyOperator.as_setup(task_id=task_id) >> task
+                EmptyOperator(task_id=task_id).as_setup() >> task
         dr = dag_maker.create_dagrun()
         ti = dr.task_instances[0]
         ti.task = task


Reply via email to