This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d800c1bc39 Support task mapping with setup teardown (#32820)
d800c1bc39 is described below
commit d800c1bc3967265280116a05d1855a4da0e1ba10
Author: Daniel Standish <[email protected]>
AuthorDate: Fri Jul 28 01:20:46 2023 -0700
Support task mapping with setup teardown (#32820)
Co-authored-by: Tzu-ping Chung <[email protected]>
---
airflow/decorators/base.py | 17 +-
.../0128_2_7_0_add_is_setup_to_task_instance.py | 51 +++
airflow/models/abstractoperator.py | 44 +--
airflow/models/baseoperator.py | 42 ++
airflow/models/dag.py | 14 +-
airflow/models/dagrun.py | 2 +-
airflow/models/mappedoperator.py | 49 +--
airflow/models/taskinstance.py | 5 +
airflow/models/xcom_arg.py | 2 -
airflow/ti_deps/deps/trigger_rule_dep.py | 18 +-
docs/apache-airflow/img/airflow_erd.sha256 | 2 +-
docs/apache-airflow/migrations-ref.rst | 4 +-
tests/models/test_mappedoperator.py | 436 +++++++++++++++++++++
tests/models/test_taskinstance.py | 1 +
tests/models/test_taskmixin.py | 55 ---
tests/serialization/test_dag_serialization.py | 49 ++-
tests/www/views/test_views_tasks.py | 7 +
17 files changed, 649 insertions(+), 149 deletions(-)
diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index b2059eb6b4..af37f191be 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -302,9 +302,9 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams,
FReturn, OperatorSubcla
decorator_name: str = attr.ib(repr=False, default="task")
_airflow_is_task_decorator: ClassVar[bool] = True
- is_setup: ClassVar[bool] = False
- is_teardown: ClassVar[bool] = False
- on_failure_fail_dagrun: ClassVar[bool] = False
+ is_setup: bool = False
+ is_teardown: bool = False
+ on_failure_fail_dagrun: bool = False
@multiple_outputs.default
def _infer_multiple_outputs(self):
@@ -382,6 +382,10 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams,
FReturn, OperatorSubcla
prevent_duplicates(self.kwargs, map_kwargs, fail_reason="mapping
already partial")
# Since the input is already checked at parse time, we can set strict
# to False to skip the checks on execution.
+ if self.is_teardown:
+ if "trigger_rule" in self.kwargs:
+ raise ValueError("Trigger rule not configurable for teardown
tasks.")
+ self.kwargs.update(trigger_rule=TriggerRule.ALL_DONE_SETUP_SUCCESS)
return self._expand(DictOfListsExpandInput(map_kwargs), strict=False)
def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict:
bool = True) -> XComArg:
@@ -406,7 +410,12 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams,
FReturn, OperatorSubcla
task_params=task_kwargs.pop("params", None),
task_default_args=task_kwargs.pop("default_args", None),
)
- partial_kwargs.update(task_kwargs)
+ partial_kwargs.update(
+ task_kwargs,
+ is_setup=self.is_setup,
+ is_teardown=self.is_teardown,
+ on_failure_fail_dagrun=self.on_failure_fail_dagrun,
+ )
task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag,
task_group)
if task_group:
diff --git
a/airflow/migrations/versions/0128_2_7_0_add_is_setup_to_task_instance.py
b/airflow/migrations/versions/0128_2_7_0_add_is_setup_to_task_instance.py
new file mode 100644
index 0000000000..07e4fb4482
--- /dev/null
+++ b/airflow/migrations/versions/0128_2_7_0_add_is_setup_to_task_instance.py
@@ -0,0 +1,51 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Add is_setup to task_instance
+
+Revision ID: 0646b768db47
+Revises: 788397e78828
+Create Date: 2023-07-24 10:12:07.630608
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+
+# revision identifiers, used by Alembic.
+revision = "0646b768db47"
+down_revision = "788397e78828"
+branch_labels = None
+depends_on = None
+airflow_version = "2.7.0"
+
+
+TABLE_NAME = "task_instance"
+
+
+def upgrade():
+ """Apply is_setup column to task_instance"""
+ with op.batch_alter_table(TABLE_NAME) as batch_op:
+ batch_op.add_column(sa.Column("is_setup", sa.Boolean(),
nullable=False, server_default="0"))
+
+
+def downgrade():
+ """Remove is_setup column from task_instance"""
+ with op.batch_alter_table(TABLE_NAME) as batch_op:
+ batch_op.drop_column("is_setup", mssql_drop_default=True)
diff --git a/airflow/models/abstractoperator.py
b/airflow/models/abstractoperator.py
index 908cb32580..7cba88f28d 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -109,8 +109,6 @@ class AbstractOperator(Templater, DAGNode):
inlets: list
trigger_rule: TriggerRule
- _is_setup = False
- _is_teardown = False
_on_failure_fail_dagrun = False
HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset(
@@ -160,44 +158,20 @@ class AbstractOperator(Templater, DAGNode):
return self.task_id
@property
- def is_setup(self):
- """
- Whether the operator is a setup task.
-
- :meta private:
- """
- return self._is_setup
+ def is_setup(self) -> bool:
+ raise NotImplementedError()
@is_setup.setter
- def is_setup(self, value):
- """
- Setter for is_setup property.
-
- :meta private:
- """
- if self.is_teardown is True and value is True:
- raise ValueError(f"Cannot mark task '{self.task_id}' as setup;
task is already a teardown.")
- self._is_setup = value
+ def is_setup(self, value: bool) -> None:
+ raise NotImplementedError()
@property
- def is_teardown(self):
- """
- Whether the operator is a teardown task.
-
- :meta private:
- """
- return self._is_teardown
+ def is_teardown(self) -> bool:
+ raise NotImplementedError()
@is_teardown.setter
- def is_teardown(self, value):
- """
- Setter for is_teardown property.
-
- :meta private:
- """
- if self.is_setup is True and value is True:
- raise ValueError(f"Cannot mark task '{self.task_id}' as teardown;
task is already a setup.")
- self._is_teardown = value
+ def is_teardown(self, value: bool) -> None:
+ raise NotImplementedError()
@property
def on_failure_fail_dagrun(self):
@@ -233,8 +207,6 @@ class AbstractOperator(Templater, DAGNode):
on_failure_fail_dagrun=NOTSET,
):
self.is_teardown = True
- if TYPE_CHECKING:
- assert isinstance(self, BaseOperator) # is_teardown not supported
for MappedOperator
self.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS
if on_failure_fail_dagrun is not NOTSET:
self.on_failure_fail_dagrun = on_failure_fail_dagrun
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index e389a8df8a..e72fcc2940 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -957,6 +957,8 @@ class BaseOperator(AbstractOperator,
metaclass=BaseOperatorMeta):
)
self.template_fields = [self.template_fields]
+ self._is_setup = False
+ self._is_teardown = False
if SetupTeardownContext.active:
SetupTeardownContext.update_context_map(self)
@@ -1415,6 +1417,43 @@ class BaseOperator(AbstractOperator,
metaclass=BaseOperatorMeta):
return XComArg(operator=self)
+ @property
+ def is_setup(self) -> bool:
+ """Whether the operator is a setup task.
+
+ :meta private:
+ """
+ return self._is_setup
+
+ @is_setup.setter
+ def is_setup(self, value: bool) -> None:
+ """Setter for is_setup property.
+
+ :meta private:
+ """
+ if self.is_teardown and value:
+ raise ValueError(f"Cannot mark task '{self.task_id}' as setup;
task is already a teardown.")
+ self._is_setup = value
+
+ @property
+ def is_teardown(self) -> bool:
+ """Whether the operator is a teardown task.
+
+ :meta private:
+ """
+ return self._is_teardown
+
+ @is_teardown.setter
+ def is_teardown(self, value: bool) -> None:
+ """
+ Setter for is_teardown property.
+
+ :meta private:
+ """
+ if self.is_setup and value:
+ raise ValueError(f"Cannot mark task '{self.task_id}' as teardown;
task is already a setup.")
+ self._is_teardown = value
+
@staticmethod
def xcom_push(
context: Any,
@@ -1501,6 +1540,9 @@ class BaseOperator(AbstractOperator,
metaclass=BaseOperatorMeta):
"_BaseOperator__instantiated",
"_BaseOperator__init_kwargs",
"_BaseOperator__from_mapped",
+ "_is_setup",
+ "_is_teardown",
+ "_on_failure_fail_dagrun",
}
| { # Class level defaults need to be added to this list
"start_date",
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index ae8ed2b6b6..5d2172c56d 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -2674,7 +2674,7 @@ class DAG(LoggingMixin):
conn_file_path: str | None = None,
variable_file_path: str | None = None,
session: Session = NEW_SESSION,
- ) -> None:
+ ) -> DagRun:
"""
Execute one single DagRun for a given DAG and execution date.
@@ -2738,19 +2738,17 @@ class DAG(LoggingMixin):
# than creating a BackfillJob and allows us to surface logs to the user
while dr.state == DagRunState.RUNNING:
schedulable_tis, _ = dr.update_state(session=session)
- try:
- for ti in schedulable_tis:
+ for ti in schedulable_tis:
+ try:
add_logger_if_needed(ti)
ti.task = tasks[ti.task_id]
_run_task(ti, session=session)
- except Exception:
- self.log.info(
- "Task failed. DAG will continue to run until finished and
be marked as failed.",
- exc_info=True,
- )
+ except Exception:
+ self.log.exception("Task failed; ti=%s", ti)
if conn_file_path or variable_file_path:
# Remove the local variables we have added to the
secrets_backend_list
secrets_backend_list.pop(0)
+ return dr
@provide_session
def create_dagrun(
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index baca096024..a6e1da6a0e 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -696,7 +696,7 @@ class DagRun(Base, LoggingMixin):
msg="all_tasks_deadlocked",
)
- # finally, if the roots aren't done, the dag is still running
+ # finally, if the leaves aren't done, the dag is still running
else:
self.set_state(DagRunState.RUNNING)
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index dd8b49fb98..0cf8852ea2 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -324,24 +324,6 @@ class MappedOperator(AbstractOperator):
f"{self.task_id!r}."
)
- @AbstractOperator.is_setup.setter # type: ignore[attr-defined]
- def is_setup(self, value):
- """
- Setter for is_setup property. Disabled for MappedOperator.
-
- :meta private:
- """
- raise ValueError("Cannot set is_setup for mapped operator.")
-
- @AbstractOperator.is_teardown.setter # type: ignore[attr-defined]
- def is_teardown(self, value):
- """
- Setter for is_teardown property. Disabled for MappedOperator.
-
- :meta private:
- """
- raise ValueError("Cannot set is_teardown for mapped operator.")
-
@classmethod
@cache
def get_serialized_fields(cls):
@@ -354,9 +336,9 @@ class MappedOperator(AbstractOperator):
"task_group",
"upstream_task_ids",
"supports_lineage",
- "is_setup",
- "is_teardown",
- "on_failure_fail_dagrun",
+ "_is_setup",
+ "_is_teardown",
+ "_on_failure_fail_dagrun",
}
@staticmethod
@@ -408,8 +390,23 @@ class MappedOperator(AbstractOperator):
@trigger_rule.setter
def trigger_rule(self, value):
- # required for mypy which complains about overriding writeable attr
with read-only property
- raise ValueError("Cannot set trigger_rule for mapped operator.")
+ self.partial_kwargs["trigger_rule"] = value
+
+ @property
+ def is_setup(self) -> bool:
+ return bool(self.partial_kwargs.get("is_setup"))
+
+ @is_setup.setter
+ def is_setup(self, value: bool) -> None:
+ self.partial_kwargs["is_setup"] = value
+
+ @property
+ def is_teardown(self) -> bool:
+ return bool(self.partial_kwargs.get("is_teardown"))
+
+ @is_teardown.setter
+ def is_teardown(self, value: bool) -> None:
+ self.partial_kwargs["is_teardown"] = value
@property
def depends_on_past(self) -> bool:
@@ -640,12 +637,18 @@ class MappedOperator(AbstractOperator):
else:
raise RuntimeError("cannot unmap a non-serialized operator
without context")
kwargs = self._get_unmap_kwargs(kwargs,
strict=self._disallow_kwargs_override)
+ is_setup = kwargs.pop("is_setup", False)
+ is_teardown = kwargs.pop("is_teardown", False)
+ on_failure_fail_dagrun = kwargs.pop("on_failure_fail_dagrun",
False)
op = self.operator_class(**kwargs, _airflow_from_mapped=True)
# We need to overwrite task_id here because BaseOperator further
# mangles the task_id based on the task hierarchy (namely, group_id
# is prepended, and '__N' appended to deduplicate). This is hacky,
# but better than duplicating the whole mangling logic.
op.task_id = self.task_id
+ op.is_setup = is_setup
+ op.is_teardown = is_teardown
+ op.on_failure_fail_dagrun = on_failure_fail_dagrun
return op
# After a mapped operator is serialized, there's no real way to
actually
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 6eafbab4a4..c31e2d8c15 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -41,6 +41,7 @@ import lazy_object_proxy
import pendulum
from jinja2 import TemplateAssertionError, UndefinedError
from sqlalchemy import (
+ Boolean,
Column,
DateTime,
Float,
@@ -419,6 +420,7 @@ class TaskInstance(Base, LoggingMixin):
# Usually used when resuming from DEFERRED.
next_method = Column(String(1000))
next_kwargs = Column(MutableDict.as_mutable(ExtendedJSON))
+ is_setup = Column(Boolean, nullable=False, default=False,
server_default="0")
# If adding new fields here then remember to add them to
# refresh_from_db() or they won't display in the UI correctly
@@ -577,6 +579,7 @@ class TaskInstance(Base, LoggingMixin):
"operator": task.task_type,
"custom_operator_name": getattr(task, "custom_operator_name",
None),
"map_index": map_index,
+ "is_setup": task.is_setup,
}
@reconstructor
@@ -875,6 +878,7 @@ class TaskInstance(Base, LoggingMixin):
self.trigger_id = ti.trigger_id
self.next_method = ti.next_method
self.next_kwargs = ti.next_kwargs
+ self.is_setup = ti.is_setup
else:
self.state = None
@@ -896,6 +900,7 @@ class TaskInstance(Base, LoggingMixin):
self.executor_config = task.executor_config
self.operator = task.task_type
self.custom_operator_name = getattr(task, "custom_operator_name", None)
+ self.is_setup = task.is_setup
@provide_session
def clear_xcom_data(self, session: Session = NEW_SESSION) -> None:
diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index 2383861483..e8cde16db8 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -346,8 +346,6 @@ class PlainXComArg(XComArg):
):
for operator, _ in self.iter_references():
operator.is_teardown = True
- if TYPE_CHECKING:
- assert isinstance(operator, BaseOperator) # Can't set
MappedOperator as teardown
operator.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS
if on_failure_fail_dagrun is not NOTSET:
operator.on_failure_fail_dagrun = on_failure_fail_dagrun
diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py
b/airflow/ti_deps/deps/trigger_rule_dep.py
index b1c2d09b96..3d2cc21d06 100644
--- a/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -22,9 +22,8 @@ import collections.abc
import functools
from typing import TYPE_CHECKING, Iterator, NamedTuple
-from sqlalchemy import and_, func, or_
+from sqlalchemy import and_, case, func, or_, true
-from airflow.models import MappedOperator
from airflow.models.taskinstance import PAST_DEPENDS_MET
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep, TIDepStatus
@@ -69,7 +68,7 @@ class _UpstreamTIStates(NamedTuple):
curr_state = {ti.state: 1}
counter.update(curr_state)
# setup task cannot be mapped
- if not isinstance(ti.task, MappedOperator) and ti.task.is_setup:
+ if ti.task.is_setup:
setup_counter.update(curr_state)
return _UpstreamTIStates(
success=counter.get(TaskInstanceState.SUCCESS, 0),
@@ -227,18 +226,15 @@ class TriggerRuleDep(BaseTIDep):
# "simple" tasks (no task or task group mapping involved).
if not any(needs_expansion(t) for t in upstream_tasks.values()):
upstream = len(upstream_tasks)
- upstream_setup = len(
- [x for x in upstream_tasks.values() if not isinstance(x,
MappedOperator) and x.is_setup]
- )
+ upstream_setup = len([x for x in upstream_tasks.values() if
x.is_setup])
else:
- upstream = (
- session.query(func.count())
+ upstream, upstream_setup = (
+ session.query(func.count(),
func.sum(case((TaskInstance.is_setup == true(), 1), else_=0)))
.filter(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id
== ti.run_id)
.filter(or_(*_iter_upstream_conditions()))
- .scalar()
+ .one()
)
- # todo: add support for mapped setup?
- upstream_setup = None
+
upstream_done = done >= upstream
changed = False
diff --git a/docs/apache-airflow/img/airflow_erd.sha256
b/docs/apache-airflow/img/airflow_erd.sha256
index f10af90f5b..7e68174d5b 100644
--- a/docs/apache-airflow/img/airflow_erd.sha256
+++ b/docs/apache-airflow/img/airflow_erd.sha256
@@ -1 +1 @@
-7ff18b1eafa528dbdfb62d75151a526de280f73e1edf7fea18da7c64644f0da9
\ No newline at end of file
+aed14092a8168884ceda6b63c269779ccec09b6302da07a7914cc36aaf9a8d3b
\ No newline at end of file
diff --git a/docs/apache-airflow/migrations-ref.rst
b/docs/apache-airflow/migrations-ref.rst
index 4f5abde251..c0f0443a9d 100644
--- a/docs/apache-airflow/migrations-ref.rst
+++ b/docs/apache-airflow/migrations-ref.rst
@@ -39,7 +39,9 @@ Here's the list of all the Database Migrations that are
executed via when you ru
+---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
| Revision ID | Revises ID | Airflow Version |
Description |
+=================================+===================+===================+==============================================================+
-| ``788397e78828`` (head) | ``937cbd173ca1`` | ``2.7.0`` |
Add custom_operator_name column |
+| ``0646b768db47`` (head) | ``788397e78828`` | ``2.7.0`` |
Add is_setup to task_instance |
++---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
+| ``788397e78828`` | ``937cbd173ca1`` | ``2.7.0`` |
Add custom_operator_name column |
+---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
| ``937cbd173ca1`` | ``c804e5c76e3e`` | ``2.7.0`` |
Add index to task_instance table |
+---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
diff --git a/tests/models/test_mappedoperator.py
b/tests/models/test_mappedoperator.py
index d6aef85428..1cfb53628a 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -18,12 +18,15 @@
from __future__ import annotations
import logging
+from collections import defaultdict
from datetime import timedelta
from unittest.mock import patch
import pendulum
import pytest
+from airflow.decorators import setup, task, task_group, teardown
+from airflow.exceptions import AirflowSkipException
from airflow.models import DAG
from airflow.models.baseoperator import BaseOperator
from airflow.models.mappedoperator import MappedOperator
@@ -651,3 +654,436 @@ def test_task_mapping_with_explicit_task_group():
assert finish.upstream_list == [mapped]
assert mapped.downstream_list == [finish]
+
+
+class TestMappedSetupTeardown:
+ @staticmethod
+ def get_states(dr):
+ ti_dict = defaultdict(dict)
+ for ti in dr.get_task_instances():
+ if ti.map_index == -1:
+ ti_dict[ti.task_id] = ti.state
+ else:
+ ti_dict[ti.task_id][ti.map_index] = ti.state
+ return ti_dict
+
+ def test_one_to_many_work_failed(self, session, dag_maker):
+ """
+ Work task failed. Setup maps to teardown. Should have 3 teardowns
all successful even
+ though the work task has failed.
+ """
+ with dag_maker(dag_id="one_to_many") as dag:
+
+ @setup
+ def my_setup():
+ print("setting up multiple things")
+ return [1, 2, 3]
+
+ @task
+ def my_work(val):
+ print(f"doing work with multiple things: {val}")
+ raise ValueError("fail!")
+ return val
+
+ @teardown
+ def my_teardown(val):
+ print(f"teardown: {val}")
+
+ s = my_setup()
+ t = my_teardown.expand(val=s)
+ with t:
+ my_work(s)
+
+ dr = dag.test()
+ states = self.get_states(dr)
+ expected = {
+ "my_setup": "success",
+ "my_work": "failed",
+ "my_teardown": {0: "success", 1: "success", 2: "success"},
+ }
+ assert states == expected
+
+ def test_many_one_explicit_odd_setup_mapped_setups_fail(self, dag_maker):
+ """
+ one unmapped setup goes to two different teardowns
+ one mapped setup goes to same teardown
+ mapped setups fail
+ teardowns should still run
+ """
+ with dag_maker(
+ dag_id="many_one_explicit_odd_setup_mapped_setups_fail",
+ ) as dag:
+
+ @task
+ def other_setup():
+ print("other setup")
+ return "other setup"
+
+ @task
+ def other_work():
+ print("other work")
+ return "other work"
+
+ @task
+ def other_teardown():
+ print("other teardown")
+ return "other teardown"
+
+ @task
+ def my_setup(val):
+ print(f"setup: {val}")
+ raise ValueError("fail")
+ return val
+
+ @task
+ def my_work(val):
+ print(f"work: {val}")
+
+ @task
+ def my_teardown(val):
+ print(f"teardown: {val}")
+
+ s = my_setup.expand(val=["data1.json", "data2.json", "data3.json"])
+ o_setup = other_setup()
+ o_teardown = other_teardown()
+ with o_teardown.as_teardown(setups=o_setup):
+ other_work()
+ t = my_teardown(s).as_teardown(setups=s)
+ with t:
+ my_work(s)
+ o_setup >> t
+ dr = dag.test()
+ states = self.get_states(dr)
+ expected = {
+ "my_setup": {0: "failed", 1: "failed", 2: "failed"},
+ "other_setup": "success",
+ "other_teardown": "success",
+ "other_work": "success",
+ "my_teardown": "success",
+ "my_work": "upstream_failed",
+ }
+ assert states == expected
+
+ def test_many_one_explicit_odd_setup_all_setups_fail(self, dag_maker):
+ """
+ one unmapped setup goes to two different teardowns
+ one mapped setup goes to same teardown
+ all setups fail
+ teardowns should not run
+ """
+ with dag_maker(
+ dag_id="many_one_explicit_odd_setup_all_setups_fail",
+ ) as dag:
+
+ @task
+ def other_setup():
+ print("other setup")
+ raise ValueError("fail")
+ return "other setup"
+
+ @task
+ def other_work():
+ print("other work")
+ return "other work"
+
+ @task
+ def other_teardown():
+ print("other teardown")
+ return "other teardown"
+
+ @task
+ def my_setup(val):
+ print(f"setup: {val}")
+ raise ValueError("fail")
+ return val
+
+ @task
+ def my_work(val):
+ print(f"work: {val}")
+
+ @task
+ def my_teardown(val):
+ print(f"teardown: {val}")
+
+ s = my_setup.expand(val=["data1.json", "data2.json", "data3.json"])
+ o_setup = other_setup()
+ o_teardown = other_teardown()
+ with o_teardown.as_teardown(setups=o_setup):
+ other_work()
+ t = my_teardown(s).as_teardown(setups=s)
+ with t:
+ my_work(s)
+ o_setup >> t
+
+ dr = dag.test()
+ states = self.get_states(dr)
+ expected = {
+ "my_teardown": "upstream_failed",
+ "other_setup": "failed",
+ "other_work": "upstream_failed",
+ "other_teardown": "upstream_failed",
+ "my_setup": {0: "failed", 1: "failed", 2: "failed"},
+ "my_work": "upstream_failed",
+ }
+ assert states == expected
+
+ def test_many_one_explicit_odd_setup_one_mapped_fails(self, dag_maker):
+ """
+ one unmapped setup goes to two different teardowns
+ one mapped setup goes to same teardown
+ one of the mapped setup instances fails
+ teardowns should all run
+ """
+ with dag_maker(dag_id="many_one_explicit_odd_setup_one_mapped_fails")
as dag:
+
+ @task
+ def other_setup():
+ print("other setup")
+ return "other setup"
+
+ @task
+ def other_work():
+ print("other work")
+ return "other work"
+
+ @task
+ def other_teardown():
+ print("other teardown")
+ return "other teardown"
+
+ @task
+ def my_setup(val):
+ if val == "data2.json":
+ raise ValueError("fail!")
+ elif val == "data3.json":
+ raise AirflowSkipException("skip!")
+ print(f"setup: {val}")
+ return val
+
+ @task
+ def my_work(val):
+ print(f"work: {val}")
+
+ @task
+ def my_teardown(val):
+ print(f"teardown: {val}")
+
+ s = my_setup.expand(val=["data1.json", "data2.json", "data3.json"])
+ o_setup = other_setup()
+ o_teardown = other_teardown()
+ with o_teardown.as_teardown(setups=o_setup):
+ other_work()
+ t = my_teardown(s).as_teardown(setups=s)
+ with t:
+ my_work(s)
+ o_setup >> t
+ dr = dag.test()
+ states = self.get_states(dr)
+ expected = {
+ "my_setup": {0: "success", 1: "failed", 2: "skipped"},
+ "other_setup": "success",
+ "other_teardown": "success",
+ "other_work": "success",
+ "my_teardown": "success",
+ "my_work": "upstream_failed",
+ }
+ assert states == expected
+
+ def test_one_to_many_as_teardown(self, dag_maker, session):
+ """
+ 1 setup mapping to 3 teardowns
+ 1 work task
+ work fails
+ teardowns succeed
+ dagrun should be failure
+ """
+ with dag_maker(dag_id="one_to_many_as_teardown") as dag:
+
+ @task
+ def my_setup():
+ print("setting up multiple things")
+ return [1, 2, 3]
+
+ @task
+ def my_work(val):
+ print(f"doing work with multiple things: {val}")
+ raise ValueError("this fails")
+ return val
+
+ @task
+ def my_teardown(val):
+ print(f"teardown: {val}")
+
+ s = my_setup()
+ t = my_teardown.expand(val=s).as_teardown(setups=s)
+ with t:
+ my_work(s)
+ dr = dag.test()
+ states = self.get_states(dr)
+ expected = {
+ "my_setup": "success",
+ "my_teardown": {0: "success", 1: "success", 2: "success"},
+ "my_work": "failed",
+ }
+ assert states == expected
+
+ def test_one_to_many_as_teardown_offd(self, dag_maker, session):
+ """
+ 1 setup mapping to 3 teardowns
+ 1 work task
+ work succeeds
+ all but one teardown succeed
+ offd=True
+ dagrun should be success
+ """
+ with dag_maker(dag_id="one_to_many_as_teardown_offd") as dag:
+
+ @task
+ def my_setup():
+ print("setting up multiple things")
+ return [1, 2, 3]
+
+ @task
+ def my_work(val):
+ print(f"doing work with multiple things: {val}")
+ return val
+
+ @task
+ def my_teardown(val):
+ print(f"teardown: {val}")
+ if val == 2:
+ raise ValueError("failure")
+
+ s = my_setup()
+ t = my_teardown.expand(val=s).as_teardown(setups=s,
on_failure_fail_dagrun=True)
+ with t:
+ my_work(s)
+ # todo: if on_failure_fail_dagrun=True, should we still regard the
WORK task as a leaf?
+ dr = dag.test()
+ states = self.get_states(dr)
+ expected = {
+ "my_setup": "success",
+ "my_teardown": {0: "success", 1: "failed", 2: "success"},
+ "my_work": "success",
+ }
+ assert states == expected
+
+ def test_mapped_task_group_simple(self, dag_maker, session):
+ """
+ Mapped task group wherein there's a simple s >> w >> t pipeline.
+ When s is skipped, all should be skipped
+ When s is failed, all should be upstream failed
+ """
+ with dag_maker(dag_id="mapped_task_group_simple") as dag:
+
+ @setup
+ def my_setup(val):
+ if val == "data2.json":
+ raise ValueError("fail!")
+ elif val == "data3.json":
+ raise AirflowSkipException("skip!")
+ print(f"setup: {val}")
+
+ @task
+ def my_work(val):
+ print(f"work: {val}")
+
+ @teardown
+ def my_teardown(val):
+ print(f"teardown: {val}")
+
+ @task_group
+ def file_transforms(filename):
+ s = my_setup(filename)
+ t = my_teardown(filename).as_teardown(setups=s)
+ with t:
+ my_work(filename)
+
+ file_transforms.expand(filename=["data1.json", "data2.json",
"data3.json"])
+ dr = dag.test()
+ states = self.get_states(dr)
+ expected = {
+ "file_transforms.my_setup": {0: "success", 1: "failed", 2:
"skipped"},
+ "file_transforms.my_work": {0: "success", 1: "upstream_failed", 2:
"skipped"},
+ "file_transforms.my_teardown": {0: "success", 1:
"upstream_failed", 2: "skipped"},
+ }
+
+ assert states == expected
+
+ def test_mapped_task_group_work_fail_or_skip(self, dag_maker, session):
+ """
+ Mapped task group wherein there's a simple s >> w >> t pipeline.
+ When w is skipped, teardown should still run
+ When w is failed, teardown should still run
+ """
+ with dag_maker(dag_id="mapped_task_group_work_fail_or_skip") as dag:
+
+ @setup
+ def my_setup(val):
+ print(f"setup: {val}")
+
+ @task
+ def my_work(val):
+ if val == "data2.json":
+ raise ValueError("fail!")
+ elif val == "data3.json":
+ raise AirflowSkipException("skip!")
+ print(f"work: {val}")
+
+ @teardown
+ def my_teardown(val):
+ print(f"teardown: {val}")
+
+ @task_group
+ def file_transforms(filename):
+ s = my_setup(filename)
+ t = my_teardown(filename).as_teardown(setups=s)
+ with t:
+ my_work(filename)
+
+ file_transforms.expand(filename=["data1.json", "data2.json",
"data3.json"])
+ dr = dag.test()
+ states = self.get_states(dr)
+ expected = {
+ "file_transforms.my_setup": {0: "success", 1: "success", 2:
"success"},
+ "file_transforms.my_teardown": {0: "success", 1: "success", 2:
"success"},
+ "file_transforms.my_work": {0: "success", 1: "failed", 2:
"skipped"},
+ }
+ assert states == expected
+
+ def test_teardown_many_one_explicit(self, dag_maker, session):
+ """-- passing
+ one mapped setup going to one unmapped work
+ 3 diff states for setup: success / failed / skipped
+ teardown still runs, and receives the xcom from the single successful
setup
+ """
+ with dag_maker(dag_id="teardown_many_one_explicit") as dag:
+
+ @task
+ def my_setup(val):
+ if val == "data2.json":
+ raise ValueError("fail!")
+ elif val == "data3.json":
+ raise AirflowSkipException("skip!")
+ print(f"setup: {val}")
+ return val
+
+ @task
+ def my_work(val):
+ print(f"work: {val}")
+
+ @task
+ def my_teardown(val):
+ print(f"teardown: {val}")
+
+ s = my_setup.expand(val=["data1.json", "data2.json", "data3.json"])
+ with my_teardown(s).as_teardown(setups=s):
+ my_work(s)
+ dr = dag.test()
+ states = self.get_states(dr)
+ expected = {
+ "my_setup": {0: "success", 1: "failed", 2: "skipped"},
+ "my_teardown": "success",
+ "my_work": "upstream_failed",
+ }
+ assert states == expected
diff --git a/tests/models/test_taskinstance.py
b/tests/models/test_taskinstance.py
index 9e7f5eb48a..78f7fb9690 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -3026,6 +3026,7 @@ class TestTaskInstance:
"_try_number": 1,
"max_tries": 1,
"hostname": "some_unique_hostname",
+ "is_setup": False,
"unixname": "some_unique_unixname",
"job_id": 1234,
"pool": "some_fake_pool_id",
diff --git a/tests/models/test_taskmixin.py b/tests/models/test_taskmixin.py
index 2d55b30cb3..c1795f22a8 100644
--- a/tests/models/test_taskmixin.py
+++ b/tests/models/test_taskmixin.py
@@ -211,61 +211,6 @@ def
test_cannot_set_on_failure_fail_dagrun_unless_teardown_taskflow(dag_maker):
m.operator.is_setup = True
-def test_no_setup_or_teardown_for_mapped_operator(dag_maker):
- @task
- def add_one(x):
- return x + 1
-
- @task
- def print_task(values):
- print(sum(values))
-
- # vanilla mapped task
- with dag_maker():
- added_vals = add_one.expand(x=[1, 2, 3])
- print_task(added_vals)
-
- # combining setup and teardown with vanilla mapped task is fine
- with dag_maker():
- s1 = BaseOperator(task_id="s1").as_setup()
- t1 = BaseOperator(task_id="t1").as_teardown(setups=s1)
- added_vals = add_one.expand(x=[1, 2, 3])
- print_task_task = print_task(added_vals)
- s1 >> added_vals
- print_task_task >> t1
- # confirm structure
- assert s1.downstream_task_ids == {"add_one", "t1"}
- assert t1.upstream_task_ids == {"print_task", "s1"}
- assert added_vals.operator.upstream_task_ids == {"s1"}
- assert added_vals.operator.downstream_task_ids == {"print_task"}
- assert print_task_task.operator.upstream_task_ids == {"add_one"}
- assert print_task_task.operator.downstream_task_ids == {"t1"}
-
- # but you can't use a mapped task as setup or teardown
- with dag_maker():
- added_vals = add_one.expand(x=[1, 2, 3])
- with pytest.raises(ValueError, match="Cannot set is_teardown for
mapped operator"):
- added_vals.as_teardown()
-
- # ... no matter how hard you try
- with dag_maker():
- added_vals = add_one.expand(x=[1, 2, 3])
- with pytest.raises(ValueError, match="Cannot set is_teardown for
mapped operator"):
- added_vals.is_teardown = True
-
- # same with setup
- with dag_maker():
- added_vals = add_one.expand(x=[1, 2, 3])
- with pytest.raises(ValueError, match="Cannot set is_setup for mapped
operator"):
- added_vals.as_setup()
-
- # and again, trying harder...
- with dag_maker():
- added_vals = add_one.expand(x=[1, 2, 3])
- with pytest.raises(ValueError, match="Cannot set is_setup for mapped
operator"):
- added_vals.is_setup = True
-
-
def test_set_setup_teardown_ctx_dependencies_using_decorated_tasks(dag_maker):
with dag_maker():
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index c89122879e..f59794509e 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -53,6 +53,7 @@ from airflow.operators.empty import EmptyOperator
from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator
from airflow.security import permissions
from airflow.sensors.bash import BashSensor
+from airflow.serialization.enums import Encoding
from airflow.serialization.json_schema import load_dag_schema_dict
from airflow.serialization.serialized_objects import (
DagDependency,
@@ -112,7 +113,8 @@ executor_config_pod = k8s.V1Pod(
]
),
)
-
+TYPE = Encoding.TYPE
+VAR = Encoding.VAR
serialized_simple_dag_ground_truth = {
"__version": 1,
"dag": {
@@ -383,7 +385,11 @@ class TestStringifiedDAGs:
serialized_dags[v.dag_id] = dag
# Compares with the ground truth of JSON string.
-    self.validate_serialized_dag(serialized_dags["simple_dag"], serialized_simple_dag_ground_truth)
+ actual, expected = self.prepare_ser_dags_for_comparison(
+ actual=serialized_dags["simple_dag"],
+ expected=serialized_simple_dag_ground_truth,
+ )
+ assert actual == expected
@pytest.mark.parametrize(
"timetable, serialized_timetable",
@@ -412,7 +418,14 @@ class TestStringifiedDAGs:
del expected["dag"]["schedule_interval"]
expected["dag"]["timetable"] = serialized_timetable
- self.validate_serialized_dag(serialized_dag, expected)
+ actual, expected = self.prepare_ser_dags_for_comparison(
+ actual=serialized_dag,
+ expected=expected,
+ )
+ for task in actual["dag"]["tasks"]:
+ for k, v in task.items():
+ print(task["task_id"], k, v)
+ assert actual == expected
def test_dag_serialization_unregistered_custom_timetable(self):
"""Verify serialization fails without timetable registration."""
@@ -429,10 +442,10 @@ class TestStringifiedDAGs:
)
assert str(ctx.value) == message
- def validate_serialized_dag(self, json_dag, ground_truth_dag):
+ def prepare_ser_dags_for_comparison(self, actual, expected):
"""Verify serialized DAGs match the ground truth."""
-    assert json_dag["dag"]["fileloc"].split("/")[-1] == "test_dag_serialization.py"
- json_dag["dag"]["fileloc"] = None
+    assert actual["dag"]["fileloc"].split("/")[-1] == "test_dag_serialization.py"
+    actual["dag"]["fileloc"] = None
def sorted_serialized_dag(dag_dict: dict):
"""
@@ -447,7 +460,11 @@ class TestStringifiedDAGs:
)
return dag_dict
-    assert sorted_serialized_dag(ground_truth_dag) == sorted_serialized_dag(json_dag)
+ # by roundtripping to json we get a cleaner diff
+ # if not doing this, we get false alarms such as "__var" != VAR
+ actual = json.loads(json.dumps(sorted_serialized_dag(actual)))
+ expected = json.loads(json.dumps(sorted_serialized_dag(expected)))
+ return actual, expected
def test_deserialization_across_process(self):
"""A serialized DAG can be deserialized in another process."""
@@ -2290,6 +2307,9 @@ def test_taskflow_expand_serde():
"_operator_name": "@task",
"downstream_task_ids": [],
"partial_kwargs": {
+ "is_setup": False,
+ "is_teardown": False,
+ "on_failure_fail_dagrun": False,
"op_args": [],
"op_kwargs": {
"__type": "dict",
@@ -2329,6 +2349,9 @@ def test_taskflow_expand_serde():
value={"arg2": {"a": 1, "b": 2}, "arg3": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY})},
)
assert deserialized.partial_kwargs == {
+ "is_setup": False,
+ "is_teardown": False,
+ "on_failure_fail_dagrun": False,
"op_args": [],
"op_kwargs": {"arg1": [1, 2, {"a": "b"}]},
"retry_delay": timedelta(seconds=30),
@@ -2344,6 +2367,9 @@ def test_taskflow_expand_serde():
value={"arg2": {"a": 1, "b": 2}, "arg3": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY})},
)
assert pickled.partial_kwargs == {
+ "is_setup": False,
+ "is_teardown": False,
+ "on_failure_fail_dagrun": False,
"op_args": [],
"op_kwargs": {"arg1": [1, 2, {"a": "b"}]},
"retry_delay": timedelta(seconds=30),
@@ -2376,6 +2402,9 @@ def test_taskflow_expand_kwargs_serde(strict):
"_operator_name": "@task",
"downstream_task_ids": [],
"partial_kwargs": {
+ "is_setup": False,
+ "is_teardown": False,
+ "on_failure_fail_dagrun": False,
"op_args": [],
"op_kwargs": {
"__type": "dict",
@@ -2413,6 +2442,9 @@ def test_taskflow_expand_kwargs_serde(strict):
value=_XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY}),
)
assert deserialized.partial_kwargs == {
+ "is_setup": False,
+ "is_teardown": False,
+ "on_failure_fail_dagrun": False,
"op_args": [],
"op_kwargs": {"arg1": [1, 2, {"a": "b"}]},
"retry_delay": timedelta(seconds=30),
@@ -2428,6 +2460,9 @@ def test_taskflow_expand_kwargs_serde(strict):
_XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY}),
)
assert pickled.partial_kwargs == {
+ "is_setup": False,
+ "is_teardown": False,
+ "on_failure_fail_dagrun": False,
"op_args": [],
"op_kwargs": {"arg1": [1, 2, {"a": "b"}]},
"retry_delay": timedelta(seconds=30),
diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py
index 32df99c047..93bc52884c 100644
--- a/tests/www/views/test_views_tasks.py
+++ b/tests/www/views/test_views_tasks.py
@@ -1017,6 +1017,7 @@ def test_task_instances(admin_client):
"executor_config": {},
"external_executor_id": None,
"hostname": "",
+ "is_setup": False,
"job_id": None,
"map_index": -1,
"max_tries": 0,
@@ -1049,6 +1050,7 @@ def test_task_instances(admin_client):
"executor_config": {},
"external_executor_id": None,
"hostname": "",
+ "is_setup": False,
"job_id": None,
"map_index": -1,
"max_tries": 0,
@@ -1081,6 +1083,7 @@ def test_task_instances(admin_client):
"executor_config": {},
"external_executor_id": None,
"hostname": "",
+ "is_setup": False,
"job_id": None,
"map_index": -1,
"max_tries": 0,
@@ -1113,6 +1116,7 @@ def test_task_instances(admin_client):
"executor_config": {},
"external_executor_id": None,
"hostname": "",
+ "is_setup": False,
"job_id": None,
"map_index": -1,
"max_tries": 0,
@@ -1145,6 +1149,7 @@ def test_task_instances(admin_client):
"executor_config": {},
"external_executor_id": None,
"hostname": "",
+ "is_setup": False,
"job_id": None,
"map_index": -1,
"max_tries": 0,
@@ -1177,6 +1182,7 @@ def test_task_instances(admin_client):
"executor_config": {},
"external_executor_id": None,
"hostname": "",
+ "is_setup": False,
"job_id": None,
"map_index": -1,
"max_tries": 0,
@@ -1209,6 +1215,7 @@ def test_task_instances(admin_client):
"executor_config": {},
"external_executor_id": None,
"hostname": "",
+ "is_setup": False,
"job_id": None,
"map_index": -1,
"max_tries": 0,