This is an automated email from the ASF dual-hosted git repository.
dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 016ce99486 Change `as_setup` and `as_teardown` to instance methods
(#32053)
016ce99486 is described below
commit 016ce9948625a556093b0182439aa50314c651da
Author: Daniel Standish <[email protected]>
AuthorDate: Mon Jun 26 10:32:25 2023 -0700
Change `as_setup` and `as_teardown` to instance methods (#32053)
This provides a number of benefits.
* provides a one-line syntax for setting setup / teardown deps
* makes it easy to convert dags to use the feature
* provides a mechanism to combine "reusable" taskflow tasks with setup /
teardown
* set setup and teardown in the same place you set deps
---------
Co-authored-by: Ephraim Anierobi <[email protected]>
---
airflow/example_dags/example_setup_teardown.py | 16 +-
.../example_setup_teardown_taskflow.py | 65 ++++--
airflow/models/abstractoperator.py | 94 +++++++-
airflow/models/baseoperator.py | 35 ---
airflow/models/mappedoperator.py | 26 ++-
airflow/models/taskmixin.py | 15 ++
airflow/models/xcom_arg.py | 53 ++++-
airflow/serialization/serialized_objects.py | 3 +-
tests/decorators/test_setup_teardown.py | 63 +++---
tests/models/test_dag.py | 16 +-
tests/models/test_taskinstance.py | 2 +-
tests/models/test_taskmixin.py | 248 +++++++++++++++++++++
tests/serialization/test_dag_serialization.py | 16 +-
tests/ti_deps/deps/test_trigger_rule_dep.py | 2 +-
14 files changed, 534 insertions(+), 120 deletions(-)
diff --git a/airflow/example_dags/example_setup_teardown.py
b/airflow/example_dags/example_setup_teardown.py
index 77d7d5bdc6..59aba9753a 100644
--- a/airflow/example_dags/example_setup_teardown.py
+++ b/airflow/example_dags/example_setup_teardown.py
@@ -30,21 +30,19 @@ with DAG(
catchup=False,
tags=["example"],
) as dag:
- root_setup = BashOperator.as_setup(task_id="root_setup",
bash_command="echo 'Hello from root_setup'")
+ root_setup = BashOperator(task_id="root_setup", bash_command="echo 'Hello
from root_setup'").as_setup()
root_normal = BashOperator(task_id="normal", bash_command="echo 'I am just
a normal task'")
- root_teardown = BashOperator.as_teardown(
+ root_teardown = BashOperator(
task_id="root_teardown", bash_command="echo 'Goodbye from
root_teardown'"
- )
+ ).as_teardown(setups=root_setup)
root_setup >> root_normal >> root_teardown
- root_setup >> root_teardown
with TaskGroup("section_1") as section_1:
- inner_setup = BashOperator.as_setup(
+ inner_setup = BashOperator(
task_id="taskgroup_setup", bash_command="echo 'Hello from
taskgroup_setup'"
- )
+ ).as_setup()
inner_normal = BashOperator(task_id="normal", bash_command="echo 'I am
just a normal task'")
- inner_teardown = BashOperator.as_teardown(
+ inner_teardown = BashOperator(
task_id="taskgroup_teardown", bash_command="echo 'Hello from
taskgroup_teardown'"
- )
+ ).as_teardown(setups=inner_setup)
inner_setup >> inner_normal >> inner_teardown
- inner_setup >> inner_teardown
root_normal >> section_1
diff --git a/airflow/example_dags/example_setup_teardown_taskflow.py
b/airflow/example_dags/example_setup_teardown_taskflow.py
index 245cc6a2e9..128534f1d2 100644
--- a/airflow/example_dags/example_setup_teardown_taskflow.py
+++ b/airflow/example_dags/example_setup_teardown_taskflow.py
@@ -29,30 +29,61 @@ with DAG(
catchup=False,
tags=["example"],
) as dag:
- # You can use the setup and teardown decorators to add setup and teardown
tasks at the DAG level
- @setup
+
@task
- def root_setup():
- print("Hello from root_setup")
+ def task_1():
+ print("Hello 1")
- @teardown
@task
- def root_teardown():
- print("Goodbye from root_teardown")
+ def task_2():
+ print("Hello 2")
+
+ @task
+ def task_3():
+ print("Hello 3")
+
+ # you can set setup / teardown relationships with the `as_teardown` method.
+ t1 = task_1()
+ t2 = task_2()
+ t3 = task_3()
+ t1 >> t2 >> t3.as_teardown(setups=t1)
+
+ # the method `as_teardown` will mark t3 as teardown, t1 as setup, and
arrow t1 >> t3
+ # now if you clear t2 (downstream), then t1 will be cleared in addition to
t3
+
+ # it's also possible to use a decorator to mark a task as setup or
+ # teardown when you define it. see below.
+
+ @setup
+ def dag_setup():
+ print("I am dag_setup")
+
+ @teardown
+ def dag_teardown():
+ print("I am dag_teardown")
@task
- def normal():
+ def dag_normal_task():
print("I am just a normal task")
+ s = dag_setup()
+ t = dag_teardown()
+
+ # by using the decorators, dag_setup and dag_teardown are already marked
as setup / teardown
+ # now we just need to make sure they are linked directly
+ # what we need to do is this::
+ # s >> t
+ # s >> dag_normal_task() >> t
+ # but we can use a context manager to make it cleaner
+ with s >> t:
+ dag_normal_task()
+
@task_group
def section_1():
- # You can also have setup and teardown tasks at the task group level
- @setup
@task
def my_setup():
print("I set up")
- @teardown
@task
def my_teardown():
print("I tear down")
@@ -61,13 +92,7 @@ with DAG(
def hello():
print("I say hello")
- s = my_setup()
- w = hello()
- t = my_teardown()
- s >> w >> t
- s >> t
+ (s := my_setup()) >> hello() >> my_teardown().as_teardown(setups=s)
- rs = root_setup()
- normal() >> section_1()
- rt = root_teardown()
- rs >> rt
+ # and let's put section 1 inside the "dag setup" and "dag teardown"
+ s >> section_1() >> t
diff --git a/airflow/models/abstractoperator.py
b/airflow/models/abstractoperator.py
index bf06b39da9..ff4f5c4140 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -26,7 +26,7 @@ from airflow.compat.functools import cache
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models.expandinput import NotFullyPopulated
-from airflow.models.taskmixin import DAGNode
+from airflow.models.taskmixin import DAGNode, DependencyMixin
from airflow.template.templater import Templater
from airflow.utils.context import Context
from airflow.utils.log.secrets_masker import redact
@@ -35,6 +35,7 @@ from airflow.utils.sqlalchemy import skip_locked,
with_row_locks
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.task_group import MappedTaskGroup
from airflow.utils.trigger_rule import TriggerRule
+from airflow.utils.types import NOTSET, ArgNotSet
from airflow.utils.weight_rule import WeightRule
TaskStateChangeCallback = Callable[[Context], None]
@@ -102,6 +103,11 @@ class AbstractOperator(Templater, DAGNode):
outlets: list
inlets: list
+ trigger_rule: TriggerRule
+
+ _is_setup = False
+ _is_teardown = False
+ _on_failure_fail_dagrun = False
HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset(
(
@@ -149,6 +155,92 @@ class AbstractOperator(Templater, DAGNode):
def node_id(self) -> str:
return self.task_id
+ @property
+ def is_setup(self):
+ """
+ Whether the operator is a setup task.
+
+ :meta private:
+ """
+ return self._is_setup
+
+ @is_setup.setter
+ def is_setup(self, value):
+ """
+ Setter for is_setup property.
+
+ :meta private:
+ """
+ if self.is_teardown is True and value is True:
+ raise ValueError(f"Cannot mark task '{self.task_id}' as setup;
task is already a teardown.")
+ self._is_setup = value
+
+ @property
+ def is_teardown(self):
+ """
+ Whether the operator is a teardown task.
+
+ :meta private:
+ """
+ return self._is_teardown
+
+ @is_teardown.setter
+ def is_teardown(self, value):
+ """
+ Setter for is_teardown property.
+
+ :meta private:
+ """
+ if self.is_setup is True and value is True:
+ raise ValueError(f"Cannot mark task '{self.task_id}' as teardown;
task is already a setup.")
+ self._is_teardown = value
+
+ @property
+ def on_failure_fail_dagrun(self):
+ """
+ Whether the operator should fail the dagrun on failure.
+
+ :meta private:
+ """
+ return self._on_failure_fail_dagrun
+
+ @on_failure_fail_dagrun.setter
+ def on_failure_fail_dagrun(self, value):
+ """
+ Setter for on_failure_fail_dagrun property.
+
+ :meta private:
+ """
+ if value is True and self.is_teardown is not True:
+ raise ValueError(
+ f"Cannot set task on_failure_fail_dagrun for "
+ f"'{self.task_id}' because it is not a teardown task."
+ )
+ self._on_failure_fail_dagrun = value
+
+ def as_setup(self):
+ self.is_setup = True
+ return self
+
+ def as_teardown(
+ self,
+ *,
+ setups: BaseOperator | Iterable[BaseOperator] | ArgNotSet = NOTSET,
+ on_failure_fail_dagrun=NOTSET,
+ ):
+ self.is_teardown = True
+ if TYPE_CHECKING:
+ assert isinstance(self, BaseOperator) # is_teardown not supported
for MappedOperator
+ self.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS
+ if on_failure_fail_dagrun is not NOTSET:
+ self.on_failure_fail_dagrun = on_failure_fail_dagrun
+ if not isinstance(setups, ArgNotSet):
+ setups = [setups] if isinstance(setups, DependencyMixin) else
setups
+ for s in setups:
+ s.is_setup = True
+ s >> self
+ return self
+
def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
"""Get direct relative IDs to the current task, upstream or
downstream."""
if upstream:
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 79d4637387..3d0812b62a 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -721,25 +721,6 @@ class BaseOperator(AbstractOperator,
metaclass=BaseOperatorMeta):
# Set to True for an operator instantiated by a mapped operator.
__from_mapped = False
- is_setup = False
- """
- Whether the operator is a setup task
-
- :meta private:
- """
- is_teardown = False
- """
- Whether the operator is a teardown task
-
- :meta private:
- """
- on_failure_fail_dagrun = False
- """
- Whether the operator should fail the dagrun on failure
-
- :meta private:
- """
-
def __init__(
self,
task_id: str,
@@ -976,22 +957,6 @@ class BaseOperator(AbstractOperator,
metaclass=BaseOperatorMeta):
if SetupTeardownContext.active:
SetupTeardownContext.update_context_map(self)
- @classmethod
- def as_setup(cls, *args, **kwargs):
- op = cls(*args, **kwargs)
- op.is_setup = True
- return op
-
- @classmethod
- def as_teardown(cls, *args, **kwargs):
- on_failure_fail_dagrun = kwargs.pop("on_failure_fail_dagrun", False)
- if "trigger_rule" in kwargs:
- raise ValueError("Cannot set trigger rule for teardown tasks.")
- op = cls(*args, **kwargs,
trigger_rule=TriggerRule.ALL_DONE_SETUP_SUCCESS)
- op.is_teardown = True
- op.on_failure_fail_dagrun = on_failure_fail_dagrun
- return op
-
def __enter__(self):
if not self.is_setup and not self.is_teardown:
raise AirflowException("Only setup/teardown tasks can be used as
context managers.")
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index 66b75923fd..dd8b49fb98 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -290,9 +290,6 @@ class MappedOperator(AbstractOperator):
subdag: None = None # Since we don't support SubDagOperator, this is
always None.
supports_lineage: bool = False
- is_setup: bool = False
- is_teardown: bool = False
- on_failure_fail_dagrun: bool = False
HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] =
AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset(
(
@@ -327,6 +324,24 @@ class MappedOperator(AbstractOperator):
f"{self.task_id!r}."
)
+ @AbstractOperator.is_setup.setter # type: ignore[attr-defined]
+ def is_setup(self, value):
+ """
+ Setter for is_setup property. Disabled for MappedOperator.
+
+ :meta private:
+ """
+ raise ValueError("Cannot set is_setup for mapped operator.")
+
+ @AbstractOperator.is_teardown.setter # type: ignore[attr-defined]
+ def is_teardown(self, value):
+ """
+ Setter for is_teardown property. Disabled for MappedOperator.
+
+ :meta private:
+ """
+ raise ValueError("Cannot set is_teardown for mapped operator.")
+
@classmethod
@cache
def get_serialized_fields(cls):
@@ -391,6 +406,11 @@ class MappedOperator(AbstractOperator):
def trigger_rule(self) -> TriggerRule:
return self.partial_kwargs.get("trigger_rule", DEFAULT_TRIGGER_RULE)
+ @trigger_rule.setter
+ def trigger_rule(self, value):
+ # required for mypy which complains about overriding writeable attr
with read-only property
+ raise ValueError("Cannot set trigger_rule for mapped operator.")
+
@property
def depends_on_past(self) -> bool:
return bool(self.partial_kwargs.get("depends_on_past"))
diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py
index 0c1c94b7b8..38d050e29e 100644
--- a/airflow/models/taskmixin.py
+++ b/airflow/models/taskmixin.py
@@ -24,10 +24,12 @@ import pendulum
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
from airflow.serialization.enums import DagAttributeTypes
+from airflow.utils.types import NOTSET, ArgNotSet
if TYPE_CHECKING:
from logging import Logger
+ from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DAG
from airflow.models.operator import Operator
from airflow.utils.edgemodifier import EdgeModifier
@@ -69,6 +71,19 @@ class DependencyMixin:
"""Set a task or a task list to be directly downstream from the
current task."""
raise NotImplementedError()
+ def as_setup(self) -> DependencyMixin:
+ """Mark a task as setup task."""
+ raise NotImplementedError()
+
+ def as_teardown(
+ self,
+ *,
+ setups: BaseOperator | Iterable[BaseOperator] | ArgNotSet = NOTSET,
+ on_failure_fail_dagrun=NOTSET,
+ ) -> DependencyMixin:
+ """Mark a task as teardown and set its setups as direct relatives."""
+ raise NotImplementedError()
+
def update_relative(
self, other: DependencyMixin, upstream: bool = True, edge_modifier:
EdgeModifier | None = None
) -> None:
diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index 72cd2278f4..7024fbd8a9 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -19,13 +19,14 @@ from __future__ import annotations
import contextlib
import inspect
-from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping, Sequence,
Union, overload
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Mapping,
Sequence, Union, overload
from sqlalchemy import func, or_
from sqlalchemy.orm import Session
from airflow.exceptions import AirflowException, XComNotFound
from airflow.models.abstractoperator import AbstractOperator
+from airflow.models.baseoperator import BaseOperator
from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskmixin import DAGNode, DependencyMixin
from airflow.utils.context import Context
@@ -34,6 +35,7 @@ from airflow.utils.mixins import ResolveMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.setup_teardown import SetupTeardownContext
from airflow.utils.state import State
+from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import NOTSET, ArgNotSet
from airflow.utils.xcom import XCOM_RETURN_KEY
@@ -296,6 +298,55 @@ class PlainXComArg(XComArg):
def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
return cls(dag.get_task(data["task_id"]), data["key"])
+ @property
+ def is_setup(self) -> bool:
+ return self.operator.is_setup
+
+ @is_setup.setter
+ def is_setup(self, val: bool):
+ self.operator.is_setup = val
+
+ @property
+ def is_teardown(self) -> bool:
+ return self.operator.is_teardown
+
+ @is_teardown.setter
+ def is_teardown(self, val: bool):
+ self.operator.is_teardown = val
+
+ @property
+ def on_failure_fail_dagrun(self) -> bool:
+ return self.operator.on_failure_fail_dagrun
+
+ @on_failure_fail_dagrun.setter
+ def on_failure_fail_dagrun(self, val: bool):
+ self.operator.on_failure_fail_dagrun = val
+
+ def as_setup(self) -> DependencyMixin:
+ for operator, _ in self.iter_references():
+ operator.is_setup = True
+ return self
+
+ def as_teardown(
+ self,
+ *,
+ setups: BaseOperator | Iterable[BaseOperator] | ArgNotSet = NOTSET,
+ on_failure_fail_dagrun=NOTSET,
+ ):
+ for operator, _ in self.iter_references():
+ operator.is_teardown = True
+ if TYPE_CHECKING:
+ assert isinstance(operator, BaseOperator) # Can't set
MappedOperator as teardown
+ operator.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS
+ if on_failure_fail_dagrun is not NOTSET:
+ operator.on_failure_fail_dagrun = on_failure_fail_dagrun
+ if not isinstance(setups, ArgNotSet):
+ setups = [setups] if isinstance(setups, DependencyMixin) else
setups
+ for s in setups:
+ s.is_setup = True
+ s >> operator
+ return self
+
def iter_references(self) -> Iterator[tuple[Operator, str]]:
yield self.operator, self.key
diff --git a/airflow/serialization/serialized_objects.py
b/airflow/serialization/serialized_objects.py
index 3a528e9c8a..8e53aa3465 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -960,7 +960,8 @@ class SerializedBaseOperator(BaseOperator,
BaseSerialization):
v = cls.deserialize(v)
elif k in ("outlets", "inlets"):
v = cls.deserialize(v)
-
+ elif k == "on_failure_fail_dagrun":
+ k = "_on_failure_fail_dagrun"
# else use v as it is
setattr(op, k, v)
diff --git a/tests/decorators/test_setup_teardown.py
b/tests/decorators/test_setup_teardown.py
index d91f7fac53..8b7f761798 100644
--- a/tests/decorators/test_setup_teardown.py
+++ b/tests/decorators/test_setup_teardown.py
@@ -65,7 +65,7 @@ class TestSetupTearDownTask:
def test_marking_operator_as_setup_task(self, dag_maker):
with dag_maker() as dag:
- BashOperator.as_setup(task_id="mytask", bash_command='echo "I am a
setup task"')
+ BashOperator(task_id="mytask", bash_command='echo "I am a setup
task"').as_setup()
assert len(dag.task_group.children) == 1
setup_task = dag.task_group.children["mytask"]
@@ -86,7 +86,7 @@ class TestSetupTearDownTask:
def test_marking_operator_as_teardown_task(self, dag_maker):
with dag_maker() as dag:
- BashOperator.as_teardown(task_id="mytask", bash_command='echo "I
am a setup task"')
+ BashOperator(task_id="mytask", bash_command='echo "I am a setup
task"').as_teardown()
assert len(dag.task_group.children) == 1
teardown_task = dag.task_group.children["mytask"]
@@ -146,11 +146,10 @@ class TestSetupTearDownTask:
@pytest.mark.parametrize("on_failure_fail_dagrun", [True, False])
def test_classic_teardown_task_works_with_on_failure_fail_dagrun(self,
on_failure_fail_dagrun, dag_maker):
with dag_maker() as dag:
- BashOperator.as_teardown(
+ BashOperator(
task_id="mytask",
bash_command='echo "I am a teardown task"',
- on_failure_fail_dagrun=on_failure_fail_dagrun,
- )
+ ).as_teardown(on_failure_fail_dagrun=on_failure_fail_dagrun)
teardown_task = dag.task_group.children["mytask"]
assert teardown_task.is_teardown
@@ -605,11 +604,11 @@ class TestSetupTearDownTask:
print("mytask")
with dag_maker() as dag:
- setuptask = BashOperator.as_setup(task_id="setuptask",
bash_command="echo 1")
- setuptask2 = BashOperator.as_setup(task_id="setuptask2",
bash_command="echo 1")
+ setuptask = BashOperator(task_id="setuptask", bash_command="echo
1").as_setup()
+ setuptask2 = BashOperator(task_id="setuptask2", bash_command="echo
1").as_setup()
- teardowntask = BashOperator.as_teardown(task_id="teardowntask",
bash_command="echo 1")
- teardowntask2 = BashOperator.as_teardown(task_id="teardowntask2",
bash_command="echo 1")
+ teardowntask = BashOperator(task_id="teardowntask",
bash_command="echo 1").as_teardown()
+ teardowntask2 = BashOperator(task_id="teardowntask2",
bash_command="echo 1").as_teardown()
with setuptask >> teardowntask:
with setuptask2 >> teardowntask2:
mytask() >> mytask2()
@@ -643,11 +642,11 @@ class TestSetupTearDownTask:
print("mytask")
with dag_maker() as dag:
- setuptask = BashOperator.as_setup(task_id="setuptask",
bash_command="echo 1")
- setuptask2 = BashOperator.as_setup(task_id="setuptask2",
bash_command="echo 1")
+ setuptask = BashOperator(task_id="setuptask", bash_command="echo
1").as_setup()
+ setuptask2 = BashOperator(task_id="setuptask2", bash_command="echo
1").as_setup()
- teardowntask = BashOperator.as_teardown(task_id="teardowntask",
bash_command="echo 1")
- teardowntask2 = BashOperator.as_teardown(task_id="teardowntask2",
bash_command="echo 1")
+ teardowntask = BashOperator(task_id="teardowntask",
bash_command="echo 1").as_teardown()
+ teardowntask2 = BashOperator(task_id="teardowntask2",
bash_command="echo 1").as_teardown()
with setuptask >> teardowntask:
with setuptask2 >> teardowntask2:
mytask() << mytask2()
@@ -676,7 +675,7 @@ class TestSetupTearDownTask:
print("mytask")
with dag_maker() as dag:
- setuptask = BashOperator.as_setup(task_id="setuptask",
bash_command="echo 1")
+ setuptask = BashOperator(task_id="setuptask", bash_command="echo
1").as_setup()
with setuptask:
mytask() >> mytask2()
@@ -698,7 +697,7 @@ class TestSetupTearDownTask:
print("mytask")
with dag_maker("foo") as dag:
- teardowntask = BashOperator.as_teardown(task_id="teardowntask",
bash_command="echo 1")
+ teardowntask = BashOperator(task_id="teardowntask",
bash_command="echo 1").as_teardown()
with teardowntask:
mytask() >> mytask2()
@@ -720,10 +719,10 @@ class TestSetupTearDownTask:
print("mytask")
with dag_maker() as dag:
- setuptask = BashOperator.as_setup(task_id="setuptask",
bash_command="echo 1")
- setuptask2 = BashOperator.as_setup(task_id="setuptask2",
bash_command="echo 1")
+ setuptask = BashOperator(task_id="setuptask", bash_command="echo
1").as_setup()
+ setuptask2 = BashOperator(task_id="setuptask2", bash_command="echo
1").as_setup()
- teardowntask = BashOperator.as_teardown(task_id="teardowntask",
bash_command="echo 1")
+ teardowntask = BashOperator(task_id="teardowntask",
bash_command="echo 1").as_teardown()
with setuptask >> teardowntask:
with setuptask2:
mytask() << mytask2()
@@ -758,8 +757,8 @@ class TestSetupTearDownTask:
print("mytask")
with dag_maker() as dag:
- setuptask = BashOperator.as_setup(task_id="setuptask",
bash_command="echo 1")
- setuptask2 = BashOperator.as_setup(task_id="setuptask2",
bash_command="echo 1")
+ setuptask = BashOperator(task_id="setuptask", bash_command="echo
1").as_setup()
+ setuptask2 = BashOperator(task_id="setuptask2", bash_command="echo
1").as_setup()
with setuptask:
t1 = mytask()
t2 = mytask2()
@@ -801,8 +800,8 @@ class TestSetupTearDownTask:
print("mytask")
with dag_maker() as dag:
- setuptask = BashOperator.as_setup(task_id="setuptask",
bash_command="echo 1")
- setuptask2 = BashOperator.as_setup(task_id="setuptask2",
bash_command="echo 1")
+ setuptask = BashOperator(task_id="setuptask", bash_command="echo
1").as_setup()
+ setuptask2 = BashOperator(task_id="setuptask2", bash_command="echo
1").as_setup()
with setuptask:
t1 = mytask()
t2 = mytask2()
@@ -841,11 +840,11 @@ class TestSetupTearDownTask:
print("mytask")
with dag_maker() as dag:
- setuptask = BashOperator.as_setup(task_id="setuptask",
bash_command="echo 1")
- setuptask2 = BashOperator.as_setup(task_id="setuptask2",
bash_command="echo 1")
+ setuptask = BashOperator(task_id="setuptask", bash_command="echo
1").as_setup()
+ setuptask2 = BashOperator(task_id="setuptask2", bash_command="echo
1").as_setup()
- teardowntask = BashOperator.as_teardown(task_id="teardowntask",
bash_command="echo 1")
- teardowntask2 = BashOperator.as_teardown(task_id="teardowntask2",
bash_command="echo 1")
+ teardowntask = BashOperator(task_id="teardowntask",
bash_command="echo 1").as_teardown()
+ teardowntask2 = BashOperator(task_id="teardowntask2",
bash_command="echo 1").as_teardown()
with setuptask >> teardowntask:
with setuptask2 >> teardowntask2:
mytask()
@@ -1047,9 +1046,9 @@ class TestSetupTearDownTask:
print("mytask")
with dag_maker() as dag:
- teardowntask = BashOperator.as_teardown(task_id="teardowntask",
bash_command="echo 1")
- teardowntask2 = BashOperator.as_teardown(task_id="teardowntask2",
bash_command="echo 1")
- setuptask = BashOperator.as_setup(task_id="setuptask",
bash_command="echo 1")
+ teardowntask = BashOperator(task_id="teardowntask",
bash_command="echo 1").as_teardown()
+ teardowntask2 = BashOperator(task_id="teardowntask2",
bash_command="echo 1").as_teardown()
+ setuptask = BashOperator(task_id="setuptask", bash_command="echo
1").as_setup()
with [teardowntask, teardowntask2] << setuptask:
mytask()
@@ -1077,9 +1076,9 @@ class TestSetupTearDownTask:
print("mytask")
with dag_maker() as dag:
- teardowntask = BashOperator.as_teardown(task_id="teardowntask",
bash_command="echo 1")
- teardowntask2 = BashOperator.as_teardown(task_id="teardowntask2",
bash_command="echo 1")
- setuptask = BashOperator.as_setup(task_id="setuptask",
bash_command="echo 1")
+ teardowntask = BashOperator(task_id="teardowntask",
bash_command="echo 1").as_teardown()
+ teardowntask2 = BashOperator(task_id="teardowntask2",
bash_command="echo 1").as_teardown()
+ setuptask = BashOperator(task_id="setuptask", bash_command="echo
1").as_setup()
with setuptask >> context_wrapper([teardowntask, teardowntask2]):
mytask()
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 3c739c115f..9eec01c44a 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -3541,16 +3541,16 @@ class TestTaskClearingSetupTeardownBehavior:
"""
def teardown_task(task_id):
- return BaseOperator.as_teardown(task_id=task_id)
+ return BaseOperator(task_id=task_id).as_teardown()
def teardown_task_f(task_id):
- return BaseOperator.as_teardown(task_id=task_id,
on_failure_fail_dagrun=True)
+ return
BaseOperator(task_id=task_id).as_teardown(on_failure_fail_dagrun=True)
def work_task(task_id):
return BaseOperator(task_id=task_id)
def setup_task(task_id):
- return BaseOperator.as_setup(task_id=task_id)
+ return BaseOperator(task_id=task_id).as_setup()
def make_task(task_id):
"""
@@ -3709,7 +3709,7 @@ class TestTaskClearingSetupTeardownBehavior:
assert self.cleared_downstream(w1) == {s1, w1, w2, t1}
assert self.cleared_downstream(w2) == {w2}
# and if there's a downstream setup, it will be included as well
- s2 = BaseOperator.as_setup(task_id="s2", dag=dag)
+ s2 = BaseOperator(task_id="s2", dag=dag).as_setup()
t1 >> s2
assert w1.get_flat_relative_ids(upstream=False) == {"t1", "w2", "s2"}
assert self.cleared_downstream(w1) == {s1, w1, w2, t1, s2}
@@ -3755,16 +3755,16 @@ class TestTaskClearingSetupTeardownBehavior:
"""
dag = DAG(dag_id="test_dag", start_date=pendulum.now())
with dag:
- dag_setup = BaseOperator.as_setup(task_id="dag_setup")
- dag_teardown = BaseOperator.as_teardown(task_id="dag_teardown")
+ dag_setup = BaseOperator(task_id="dag_setup").as_setup()
+ dag_teardown = BaseOperator(task_id="dag_teardown").as_teardown()
dag_setup >> dag_teardown
for group_name in ("g1", "g2"):
with TaskGroup(group_name) as tg:
- group_setup = BaseOperator.as_setup(task_id="group_setup")
+ group_setup =
BaseOperator(task_id="group_setup").as_setup()
w1 = BaseOperator(task_id="w1")
w2 = BaseOperator(task_id="w2")
w3 = BaseOperator(task_id="w3")
- group_teardown =
BaseOperator.as_teardown(task_id="group_teardown")
+ group_teardown =
BaseOperator(task_id="group_teardown").as_teardown()
group_setup >> w1 >> w2 >> w3 >> group_teardown
group_setup >> group_teardown
dag_setup >> tg >> dag_teardown
diff --git a/tests/models/test_taskinstance.py
b/tests/models/test_taskinstance.py
index 89aedaf3bb..a9b7ea9b08 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -1304,7 +1304,7 @@ class TestTaskInstance:
task = EmptyOperator(task_id=f"work_{i}", dag=dag)
task.set_downstream(downstream)
for i in range(upstream_setups):
- task = EmptyOperator.as_setup(task_id=f"setup_{i}", dag=dag)
+ task = EmptyOperator(task_id=f"setup_{i}", dag=dag).as_setup()
task.set_downstream(downstream)
assert task.start_date is not None
run_date = task.start_date + datetime.timedelta(days=5)
diff --git a/tests/models/test_taskmixin.py b/tests/models/test_taskmixin.py
new file mode 100644
index 0000000000..83a040b86e
--- /dev/null
+++ b/tests/models/test_taskmixin.py
@@ -0,0 +1,248 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from itertools import product
+
+import pytest
+
+from airflow.decorators import task
+from airflow.models.baseoperator import BaseOperator
+
+
+def cleared_tasks(dag, task_id):
+ dag_ = dag.partial_subset(task_id, include_downstream=True, include_upstream=False)
+ return {x.task_id for x in dag_.tasks}
+
+
+def get_task_attr(task_like, attr):
+ try:
+ return getattr(task_like, attr)
+ except AttributeError:
+ return getattr(task_like.operator, attr)
+
+
+def make_task(name, type_):
+ if type_ == "classic":
+ return BaseOperator(task_id=name)
+ else:
+
+ @task
+ def my_task():
+ pass
+
+ return my_task.override(task_id=name)()
+
+
[email protected]("setup_type, work_type, teardown_type", product(*3 * [["classic", "taskflow"]]))
+def test_as_teardown(dag_maker, setup_type, work_type, teardown_type):
+ """
+ Check that as_teardown works properly as implemented in PlainXComArg
+
+ It should mark the teardown as teardown, and if a task is provided, it should mark that as setup
+ and set it as a direct upstream.
+ """
+ with dag_maker() as dag:
+ s1 = make_task(name="s1", type_=setup_type)
+ w1 = make_task(name="w1", type_=work_type)
+ t1 = make_task(name="t1", type_=teardown_type)
+ # initial conditions
+ assert cleared_tasks(dag, "w1") == {"w1"}
+
+ # after setting deps, still none are setup / teardown
+ # verify relationships
+ s1 >> w1 >> t1
+ assert cleared_tasks(dag, "w1") == {"w1", "t1"}
+ assert get_task_attr(t1, "is_teardown") is False
+ assert get_task_attr(s1, "is_setup") is False
+ assert get_task_attr(t1, "upstream_task_ids") == {"w1"}
+
+ # now when we use as_teardown, s1 should be setup, t1 should be teardown, and we should have s1 >> t1
+ t1.as_teardown(setups=s1)
+ assert cleared_tasks(dag, "w1") == {"s1", "w1", "t1"}
+ assert get_task_attr(t1, "is_teardown") is True
+ assert get_task_attr(s1, "is_setup") is True
+ assert get_task_attr(t1, "upstream_task_ids") == {"w1", "s1"}
+
+
[email protected]("setup_type, work_type, teardown_type", product(*3 * [["classic", "taskflow"]]))
+def test_as_teardown_oneline(dag_maker, setup_type, work_type, teardown_type):
+ """
+ Check that as_teardown implementations work properly. Tests all combinations of taskflow and classic.
+
+ It should mark the teardown as teardown, and if a task is provided, it should mark that as setup
+ and set it as a direct upstream.
+ """
+
+ with dag_maker() as dag:
+ s1 = make_task(name="s1", type_=setup_type)
+ w1 = make_task(name="w1", type_=work_type)
+ t1 = make_task(name="t1", type_=teardown_type)
+
+ # verify initial conditions
+ for task_ in (s1, w1, t1):
+ assert get_task_attr(task_, "upstream_list") == []
+ assert get_task_attr(task_, "downstream_list") == []
+ assert get_task_attr(task_, "is_setup") is False
+ assert get_task_attr(task_, "is_teardown") is False
+ assert cleared_tasks(dag, get_task_attr(task_, "task_id")) == {get_task_attr(task_, "task_id")}
+
+ # now set the deps in one line
+ s1 >> w1 >> t1.as_teardown(setups=s1)
+
+ # verify resulting configuration
+ # should be equiv to the following:
+ # * s1.is_setup = True
+ # * t1.is_teardown = True
+ # * s1 >> t1
+ # * s1 >> w1 >> t1
+ for task_, exp_up, exp_down in [
+ (s1, set(), {"w1", "t1"}),
+ (w1, {"s1"}, {"t1"}),
+ (t1, {"s1", "w1"}, set()),
+ ]:
+ assert get_task_attr(task_, "upstream_task_ids") == exp_up
+ assert get_task_attr(task_, "downstream_task_ids") == exp_down
+ assert cleared_tasks(dag, "s1") == {"s1", "w1", "t1"}
+ assert cleared_tasks(dag, "w1") == {"s1", "w1", "t1"}
+ assert cleared_tasks(dag, "t1") == {"t1"}
+ for task_, exp_is_setup, exp_is_teardown in [
+ (s1, True, False),
+ (w1, False, False),
+ (t1, False, True),
+ ]:
+ assert get_task_attr(task_, "is_setup") is exp_is_setup
+ assert get_task_attr(task_, "is_teardown") is exp_is_teardown
+
+
[email protected]("type_", ["classic", "taskflow"])
+def test_cannot_be_both_setup_and_teardown(dag_maker, type_):
+ # can't change a setup task to a teardown task or vice versa
+ for first, second in [("setup", "teardown"), ("teardown", "setup")]:
+ with dag_maker():
+ s1 = make_task(name="s1", type_=type_)
+ getattr(s1, f"as_{first}")()
+ with pytest.raises(
+ ValueError, match=f"Cannot mark task 's1' as {second}; task is already a {first}."
+ ):
+ getattr(s1, f"as_{second}")()
+ s1.as_teardown()
+
+
+def test_cannot_set_on_failure_fail_dagrun_unless_teardown_classic(dag_maker):
+ with dag_maker():
+ t = make_task(name="t", type_="classic")
+ assert t.is_teardown is False
+ with pytest.raises(
+ ValueError,
+ match="Cannot set task on_failure_fail_dagrun for 't' because it is not a teardown task",
+ ):
+ t.on_failure_fail_dagrun = True
+
+
+def test_cannot_set_on_failure_fail_dagrun_unless_teardown_taskflow(dag_maker):
+ @task(on_failure_fail_dagrun=True)
+ def my_bad_task():
+ pass
+
+ @task
+ def my_ok_task():
+ pass
+
+ with dag_maker():
+ with pytest.raises(
+ ValueError,
+ match="Cannot set task on_failure_fail_dagrun for "
+ "'my_bad_task' because it is not a teardown task",
+ ):
+ my_bad_task()
+ # no issue
+ m = my_ok_task()
+ assert m.operator.is_teardown is False
+ # also fine
+ m = my_ok_task().as_teardown()
+ assert m.operator.is_teardown is True
+ assert m.operator.on_failure_fail_dagrun is False
+ # and also fine
+ m = my_ok_task().as_teardown(on_failure_fail_dagrun=True)
+ assert m.operator.is_teardown is True
+ assert m.operator.on_failure_fail_dagrun is True
+ # but we can't unset
+ with pytest.raises(
+ ValueError, match="Cannot mark task 'my_ok_task__2' as setup; task is already a teardown."
+ ):
+ m.as_setup()
+ with pytest.raises(
+ ValueError, match="Cannot mark task 'my_ok_task__2' as setup; task is already a teardown."
+ ):
+ m.operator.is_setup = True
+
+
+def test_no_setup_or_teardown_for_mapped_operator(dag_maker):
+ @task
+ def add_one(x):
+ return x + 1
+
+ @task
+ def print_task(values):
+ print(sum(values))
+
+ # vanilla mapped task
+ with dag_maker():
+ added_vals = add_one.expand(x=[1, 2, 3])
+ print_task(added_vals)
+
+ # combining setup and teardown with vanilla mapped task is fine
+ with dag_maker():
+ s1 = BaseOperator(task_id="s1").as_setup()
+ t1 = BaseOperator(task_id="t1").as_teardown(setups=s1)
+ added_vals = add_one.expand(x=[1, 2, 3])
+ print_task_task = print_task(added_vals)
+ s1 >> added_vals
+ print_task_task >> t1
+ # confirm structure
+ assert s1.downstream_task_ids == {"add_one", "t1"}
+ assert t1.upstream_task_ids == {"print_task", "s1"}
+ assert added_vals.operator.upstream_task_ids == {"s1"}
+ assert added_vals.operator.downstream_task_ids == {"print_task"}
+ assert print_task_task.operator.upstream_task_ids == {"add_one"}
+ assert print_task_task.operator.downstream_task_ids == {"t1"}
+
+ # but you can't use a mapped task as setup or teardown
+ with dag_maker():
+ added_vals = add_one.expand(x=[1, 2, 3])
+ with pytest.raises(ValueError, match="Cannot set is_teardown for mapped operator"):
+ added_vals.as_teardown()
+
+ # ... no matter how hard you try
+ with dag_maker():
+ added_vals = add_one.expand(x=[1, 2, 3])
+ with pytest.raises(ValueError, match="Cannot set is_teardown for mapped operator"):
+ added_vals.is_teardown = True
+
+ # same with setup
+ with dag_maker():
+ added_vals = add_one.expand(x=[1, 2, 3])
+ with pytest.raises(ValueError, match="Cannot set is_setup for mapped operator"):
+ added_vals.as_setup()
+
+ # and again, trying harder...
+ with dag_maker():
+ added_vals = add_one.expand(x=[1, 2, 3])
+ with pytest.raises(ValueError, match="Cannot set is_setup for mapped operator"):
+ added_vals.is_setup = True
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 221d7a3245..69a2c9df22 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1328,18 +1328,18 @@ class TestStringifiedDAGs:
execution_date = datetime(2020, 1, 1)
with DAG("test_task_group_setup_teardown_tasks", start_date=execution_date) as dag:
- EmptyOperator.as_setup(task_id="setup")
- EmptyOperator.as_teardown(task_id="teardown")
+ EmptyOperator(task_id="setup").as_setup()
+ EmptyOperator(task_id="teardown").as_teardown()
with TaskGroup("group1"):
- EmptyOperator.as_setup(task_id="setup1")
+ EmptyOperator(task_id="setup1").as_setup()
EmptyOperator(task_id="task1")
- EmptyOperator.as_teardown(task_id="teardown1")
+ EmptyOperator(task_id="teardown1").as_teardown()
with TaskGroup("group2"):
- EmptyOperator.as_setup(task_id="setup2")
+ EmptyOperator(task_id="setup2").as_setup()
EmptyOperator(task_id="task2")
- EmptyOperator.as_teardown(task_id="teardown2")
+ EmptyOperator(task_id="teardown2").as_teardown()
dag_dict = SerializedDAG.to_dict(dag)
SerializedDAG.validate_schema(dag_dict)
@@ -1394,8 +1394,8 @@ class TestStringifiedDAGs:
serialized_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag))
task = serialized_dag.task_group.children["mytask"]
- assert task.is_teardown
- assert task.on_failure_fail_dagrun
+ assert task.is_teardown is True
+ assert task.on_failure_fail_dagrun is True
def test_deps_sorted(self):
"""
diff --git a/tests/ti_deps/deps/test_trigger_rule_dep.py b/tests/ti_deps/deps/test_trigger_rule_dep.py
index 2beaa7097e..faa70b5a49 100644
--- a/tests/ti_deps/deps/test_trigger_rule_dep.py
+++ b/tests/ti_deps/deps/test_trigger_rule_dep.py
@@ -64,7 +64,7 @@ def get_task_instance(monkeypatch, session, dag_maker):
for task_id in normal_tasks or []:
EmptyOperator(task_id=task_id) >> task
for task_id in setup_tasks or []:
- EmptyOperator.as_setup(task_id=task_id) >> task
+ EmptyOperator(task_id=task_id).as_setup() >> task
dr = dag_maker.create_dagrun()
ti = dr.task_instances[0]
ti.task = task