This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d800c1bc39 Support task mapping with setup teardown (#32820)
d800c1bc39 is described below

commit d800c1bc3967265280116a05d1855a4da0e1ba10
Author: Daniel Standish <[email protected]>
AuthorDate: Fri Jul 28 01:20:46 2023 -0700

    Support task mapping with setup teardown (#32820)
    
    Co-authored-by: Tzu-ping Chung <[email protected]>
---
 airflow/decorators/base.py                         |  17 +-
 .../0128_2_7_0_add_is_setup_to_task_instance.py    |  51 +++
 airflow/models/abstractoperator.py                 |  44 +--
 airflow/models/baseoperator.py                     |  42 ++
 airflow/models/dag.py                              |  14 +-
 airflow/models/dagrun.py                           |   2 +-
 airflow/models/mappedoperator.py                   |  49 +--
 airflow/models/taskinstance.py                     |   5 +
 airflow/models/xcom_arg.py                         |   2 -
 airflow/ti_deps/deps/trigger_rule_dep.py           |  18 +-
 docs/apache-airflow/img/airflow_erd.sha256         |   2 +-
 docs/apache-airflow/migrations-ref.rst             |   4 +-
 tests/models/test_mappedoperator.py                | 436 +++++++++++++++++++++
 tests/models/test_taskinstance.py                  |   1 +
 tests/models/test_taskmixin.py                     |  55 ---
 tests/serialization/test_dag_serialization.py      |  49 ++-
 tests/www/views/test_views_tasks.py                |   7 +
 17 files changed, 649 insertions(+), 149 deletions(-)

diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index b2059eb6b4..af37f191be 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -302,9 +302,9 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams, FReturn, OperatorSubcla
     decorator_name: str = attr.ib(repr=False, default="task")
 
     _airflow_is_task_decorator: ClassVar[bool] = True
-    is_setup: ClassVar[bool] = False
-    is_teardown: ClassVar[bool] = False
-    on_failure_fail_dagrun: ClassVar[bool] = False
+    is_setup: bool = False
+    is_teardown: bool = False
+    on_failure_fail_dagrun: bool = False
 
     @multiple_outputs.default
     def _infer_multiple_outputs(self):
@@ -382,6 +382,10 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams, FReturn, OperatorSubcla
         prevent_duplicates(self.kwargs, map_kwargs, fail_reason="mapping 
already partial")
         # Since the input is already checked at parse time, we can set strict
         # to False to skip the checks on execution.
+        if self.is_teardown:
+            if "trigger_rule" in self.kwargs:
+                raise ValueError("Trigger rule not configurable for teardown 
tasks.")
+            self.kwargs.update(trigger_rule=TriggerRule.ALL_DONE_SETUP_SUCCESS)
         return self._expand(DictOfListsExpandInput(map_kwargs), strict=False)
 
     def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: 
bool = True) -> XComArg:
@@ -406,7 +410,12 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams, FReturn, OperatorSubcla
             task_params=task_kwargs.pop("params", None),
             task_default_args=task_kwargs.pop("default_args", None),
         )
-        partial_kwargs.update(task_kwargs)
+        partial_kwargs.update(
+            task_kwargs,
+            is_setup=self.is_setup,
+            is_teardown=self.is_teardown,
+            on_failure_fail_dagrun=self.on_failure_fail_dagrun,
+        )
 
         task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, 
task_group)
         if task_group:
diff --git a/airflow/migrations/versions/0128_2_7_0_add_is_setup_to_task_instance.py b/airflow/migrations/versions/0128_2_7_0_add_is_setup_to_task_instance.py
new file mode 100644
index 0000000000..07e4fb4482
--- /dev/null
+++ b/airflow/migrations/versions/0128_2_7_0_add_is_setup_to_task_instance.py
@@ -0,0 +1,51 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Add is_setup to task_instance
+
+Revision ID: 0646b768db47
+Revises: 788397e78828
+Create Date: 2023-07-24 10:12:07.630608
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+
+# revision identifiers, used by Alembic.
+revision = "0646b768db47"
+down_revision = "788397e78828"
+branch_labels = None
+depends_on = None
+airflow_version = "2.7.0"
+
+
+TABLE_NAME = "task_instance"
+
+
+def upgrade():
+    """Apply is_setup column to task_instance"""
+    with op.batch_alter_table(TABLE_NAME) as batch_op:
+        batch_op.add_column(sa.Column("is_setup", sa.Boolean(), 
nullable=False, server_default="0"))
+
+
+def downgrade():
+    """Remove is_setup column from task_instance"""
+    with op.batch_alter_table(TABLE_NAME) as batch_op:
+        batch_op.drop_column("is_setup", mssql_drop_default=True)
diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py
index 908cb32580..7cba88f28d 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -109,8 +109,6 @@ class AbstractOperator(Templater, DAGNode):
     inlets: list
     trigger_rule: TriggerRule
 
-    _is_setup = False
-    _is_teardown = False
     _on_failure_fail_dagrun = False
 
     HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset(
@@ -160,44 +158,20 @@ class AbstractOperator(Templater, DAGNode):
         return self.task_id
 
     @property
-    def is_setup(self):
-        """
-        Whether the operator is a setup task.
-
-        :meta private:
-        """
-        return self._is_setup
+    def is_setup(self) -> bool:
+        raise NotImplementedError()
 
     @is_setup.setter
-    def is_setup(self, value):
-        """
-        Setter for is_setup property.
-
-        :meta private:
-        """
-        if self.is_teardown is True and value is True:
-            raise ValueError(f"Cannot mark task '{self.task_id}' as setup; 
task is already a teardown.")
-        self._is_setup = value
+    def is_setup(self, value: bool) -> None:
+        raise NotImplementedError()
 
     @property
-    def is_teardown(self):
-        """
-        Whether the operator is a teardown task.
-
-        :meta private:
-        """
-        return self._is_teardown
+    def is_teardown(self) -> bool:
+        raise NotImplementedError()
 
     @is_teardown.setter
-    def is_teardown(self, value):
-        """
-        Setter for is_teardown property.
-
-        :meta private:
-        """
-        if self.is_setup is True and value is True:
-            raise ValueError(f"Cannot mark task '{self.task_id}' as teardown; 
task is already a setup.")
-        self._is_teardown = value
+    def is_teardown(self, value: bool) -> None:
+        raise NotImplementedError()
 
     @property
     def on_failure_fail_dagrun(self):
@@ -233,8 +207,6 @@ class AbstractOperator(Templater, DAGNode):
         on_failure_fail_dagrun=NOTSET,
     ):
         self.is_teardown = True
-        if TYPE_CHECKING:
-            assert isinstance(self, BaseOperator)  # is_teardown not supported 
for MappedOperator
         self.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS
         if on_failure_fail_dagrun is not NOTSET:
             self.on_failure_fail_dagrun = on_failure_fail_dagrun
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index e389a8df8a..e72fcc2940 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -957,6 +957,8 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
             )
             self.template_fields = [self.template_fields]
 
+        self._is_setup = False
+        self._is_teardown = False
         if SetupTeardownContext.active:
             SetupTeardownContext.update_context_map(self)
 
@@ -1415,6 +1417,43 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
 
         return XComArg(operator=self)
 
+    @property
+    def is_setup(self) -> bool:
+        """Whether the operator is a setup task.
+
+        :meta private:
+        """
+        return self._is_setup
+
+    @is_setup.setter
+    def is_setup(self, value: bool) -> None:
+        """Setter for is_setup property.
+
+        :meta private:
+        """
+        if self.is_teardown and value:
+            raise ValueError(f"Cannot mark task '{self.task_id}' as setup; 
task is already a teardown.")
+        self._is_setup = value
+
+    @property
+    def is_teardown(self) -> bool:
+        """Whether the operator is a teardown task.
+
+        :meta private:
+        """
+        return self._is_teardown
+
+    @is_teardown.setter
+    def is_teardown(self, value: bool) -> None:
+        """
+        Setter for is_teardown property.
+
+        :meta private:
+        """
+        if self.is_setup and value:
+            raise ValueError(f"Cannot mark task '{self.task_id}' as teardown; 
task is already a setup.")
+        self._is_teardown = value
+
     @staticmethod
     def xcom_push(
         context: Any,
@@ -1501,6 +1540,9 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
                     "_BaseOperator__instantiated",
                     "_BaseOperator__init_kwargs",
                     "_BaseOperator__from_mapped",
+                    "_is_setup",
+                    "_is_teardown",
+                    "_on_failure_fail_dagrun",
                 }
                 | {  # Class level defaults need to be added to this list
                     "start_date",
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index ae8ed2b6b6..5d2172c56d 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -2674,7 +2674,7 @@ class DAG(LoggingMixin):
         conn_file_path: str | None = None,
         variable_file_path: str | None = None,
         session: Session = NEW_SESSION,
-    ) -> None:
+    ) -> DagRun:
         """
         Execute one single DagRun for a given DAG and execution date.
 
@@ -2738,19 +2738,17 @@ class DAG(LoggingMixin):
         # than creating a BackfillJob and allows us to surface logs to the user
         while dr.state == DagRunState.RUNNING:
             schedulable_tis, _ = dr.update_state(session=session)
-            try:
-                for ti in schedulable_tis:
+            for ti in schedulable_tis:
+                try:
                     add_logger_if_needed(ti)
                     ti.task = tasks[ti.task_id]
                     _run_task(ti, session=session)
-            except Exception:
-                self.log.info(
-                    "Task failed. DAG will continue to run until finished and 
be marked as failed.",
-                    exc_info=True,
-                )
+                except Exception:
+                    self.log.exception("Task failed; ti=%s", ti)
         if conn_file_path or variable_file_path:
             # Remove the local variables we have added to the 
secrets_backend_list
             secrets_backend_list.pop(0)
+        return dr
 
     @provide_session
     def create_dagrun(
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index baca096024..a6e1da6a0e 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -696,7 +696,7 @@ class DagRun(Base, LoggingMixin):
                     msg="all_tasks_deadlocked",
                 )
 
-        # finally, if the roots aren't done, the dag is still running
+        # finally, if the leaves aren't done, the dag is still running
         else:
             self.set_state(DagRunState.RUNNING)
 
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index dd8b49fb98..0cf8852ea2 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -324,24 +324,6 @@ class MappedOperator(AbstractOperator):
                 f"{self.task_id!r}."
             )
 
-    @AbstractOperator.is_setup.setter  # type: ignore[attr-defined]
-    def is_setup(self, value):
-        """
-        Setter for is_setup property. Disabled for MappedOperator.
-
-        :meta private:
-        """
-        raise ValueError("Cannot set is_setup for mapped operator.")
-
-    @AbstractOperator.is_teardown.setter  # type: ignore[attr-defined]
-    def is_teardown(self, value):
-        """
-        Setter for is_teardown property. Disabled for MappedOperator.
-
-        :meta private:
-        """
-        raise ValueError("Cannot set is_teardown for mapped operator.")
-
     @classmethod
     @cache
     def get_serialized_fields(cls):
@@ -354,9 +336,9 @@ class MappedOperator(AbstractOperator):
             "task_group",
             "upstream_task_ids",
             "supports_lineage",
-            "is_setup",
-            "is_teardown",
-            "on_failure_fail_dagrun",
+            "_is_setup",
+            "_is_teardown",
+            "_on_failure_fail_dagrun",
         }
 
     @staticmethod
@@ -408,8 +390,23 @@ class MappedOperator(AbstractOperator):
 
     @trigger_rule.setter
     def trigger_rule(self, value):
-        # required for mypy which complains about overriding writeable attr 
with read-only property
-        raise ValueError("Cannot set trigger_rule for mapped operator.")
+        self.partial_kwargs["trigger_rule"] = value
+
+    @property
+    def is_setup(self) -> bool:
+        return bool(self.partial_kwargs.get("is_setup"))
+
+    @is_setup.setter
+    def is_setup(self, value: bool) -> None:
+        self.partial_kwargs["is_setup"] = value
+
+    @property
+    def is_teardown(self) -> bool:
+        return bool(self.partial_kwargs.get("is_teardown"))
+
+    @is_teardown.setter
+    def is_teardown(self, value: bool) -> None:
+        self.partial_kwargs["is_teardown"] = value
 
     @property
     def depends_on_past(self) -> bool:
@@ -640,12 +637,18 @@ class MappedOperator(AbstractOperator):
             else:
                 raise RuntimeError("cannot unmap a non-serialized operator 
without context")
             kwargs = self._get_unmap_kwargs(kwargs, 
strict=self._disallow_kwargs_override)
+            is_setup = kwargs.pop("is_setup", False)
+            is_teardown = kwargs.pop("is_teardown", False)
+            on_failure_fail_dagrun = kwargs.pop("on_failure_fail_dagrun", 
False)
             op = self.operator_class(**kwargs, _airflow_from_mapped=True)
             # We need to overwrite task_id here because BaseOperator further
             # mangles the task_id based on the task hierarchy (namely, group_id
             # is prepended, and '__N' appended to deduplicate). This is hacky,
             # but better than duplicating the whole mangling logic.
             op.task_id = self.task_id
+            op.is_setup = is_setup
+            op.is_teardown = is_teardown
+            op.on_failure_fail_dagrun = on_failure_fail_dagrun
             return op
 
         # After a mapped operator is serialized, there's no real way to 
actually
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 6eafbab4a4..c31e2d8c15 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -41,6 +41,7 @@ import lazy_object_proxy
 import pendulum
 from jinja2 import TemplateAssertionError, UndefinedError
 from sqlalchemy import (
+    Boolean,
     Column,
     DateTime,
     Float,
@@ -419,6 +420,7 @@ class TaskInstance(Base, LoggingMixin):
     # Usually used when resuming from DEFERRED.
     next_method = Column(String(1000))
     next_kwargs = Column(MutableDict.as_mutable(ExtendedJSON))
+    is_setup = Column(Boolean, nullable=False, default=False, 
server_default="0")
 
     # If adding new fields here then remember to add them to
     # refresh_from_db() or they won't display in the UI correctly
@@ -577,6 +579,7 @@ class TaskInstance(Base, LoggingMixin):
             "operator": task.task_type,
             "custom_operator_name": getattr(task, "custom_operator_name", 
None),
             "map_index": map_index,
+            "is_setup": task.is_setup,
         }
 
     @reconstructor
@@ -875,6 +878,7 @@ class TaskInstance(Base, LoggingMixin):
             self.trigger_id = ti.trigger_id
             self.next_method = ti.next_method
             self.next_kwargs = ti.next_kwargs
+            self.is_setup = ti.is_setup
         else:
             self.state = None
 
@@ -896,6 +900,7 @@ class TaskInstance(Base, LoggingMixin):
         self.executor_config = task.executor_config
         self.operator = task.task_type
         self.custom_operator_name = getattr(task, "custom_operator_name", None)
+        self.is_setup = task.is_setup
 
     @provide_session
     def clear_xcom_data(self, session: Session = NEW_SESSION) -> None:
diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index 2383861483..e8cde16db8 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -346,8 +346,6 @@ class PlainXComArg(XComArg):
     ):
         for operator, _ in self.iter_references():
             operator.is_teardown = True
-            if TYPE_CHECKING:
-                assert isinstance(operator, BaseOperator)  # Can't set 
MappedOperator as teardown
             operator.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS
             if on_failure_fail_dagrun is not NOTSET:
                 operator.on_failure_fail_dagrun = on_failure_fail_dagrun
diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py
index b1c2d09b96..3d2cc21d06 100644
--- a/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -22,9 +22,8 @@ import collections.abc
 import functools
 from typing import TYPE_CHECKING, Iterator, NamedTuple
 
-from sqlalchemy import and_, func, or_
+from sqlalchemy import and_, case, func, or_, true
 
-from airflow.models import MappedOperator
 from airflow.models.taskinstance import PAST_DEPENDS_MET
 from airflow.ti_deps.dep_context import DepContext
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep, TIDepStatus
@@ -69,7 +68,7 @@ class _UpstreamTIStates(NamedTuple):
             curr_state = {ti.state: 1}
             counter.update(curr_state)
             # setup task cannot be mapped
-            if not isinstance(ti.task, MappedOperator) and ti.task.is_setup:
+            if ti.task.is_setup:
                 setup_counter.update(curr_state)
         return _UpstreamTIStates(
             success=counter.get(TaskInstanceState.SUCCESS, 0),
@@ -227,18 +226,15 @@ class TriggerRuleDep(BaseTIDep):
         # "simple" tasks (no task or task group mapping involved).
         if not any(needs_expansion(t) for t in upstream_tasks.values()):
             upstream = len(upstream_tasks)
-            upstream_setup = len(
-                [x for x in upstream_tasks.values() if not isinstance(x, 
MappedOperator) and x.is_setup]
-            )
+            upstream_setup = len([x for x in upstream_tasks.values() if 
x.is_setup])
         else:
-            upstream = (
-                session.query(func.count())
+            upstream, upstream_setup = (
+                session.query(func.count(), 
func.sum(case((TaskInstance.is_setup == true(), 1), else_=0)))
                 .filter(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id 
== ti.run_id)
                 .filter(or_(*_iter_upstream_conditions()))
-                .scalar()
+                .one()
             )
-            # todo: add support for mapped setup?
-            upstream_setup = None
+
         upstream_done = done >= upstream
 
         changed = False
diff --git a/docs/apache-airflow/img/airflow_erd.sha256 b/docs/apache-airflow/img/airflow_erd.sha256
index f10af90f5b..7e68174d5b 100644
--- a/docs/apache-airflow/img/airflow_erd.sha256
+++ b/docs/apache-airflow/img/airflow_erd.sha256
@@ -1 +1 @@
-7ff18b1eafa528dbdfb62d75151a526de280f73e1edf7fea18da7c64644f0da9
\ No newline at end of file
+aed14092a8168884ceda6b63c269779ccec09b6302da07a7914cc36aaf9a8d3b
\ No newline at end of file
diff --git a/docs/apache-airflow/migrations-ref.rst b/docs/apache-airflow/migrations-ref.rst
index 4f5abde251..c0f0443a9d 100644
--- a/docs/apache-airflow/migrations-ref.rst
+++ b/docs/apache-airflow/migrations-ref.rst
@@ -39,7 +39,9 @@ Here's the list of all the Database Migrations that are executed via when you ru
 
+---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
 | Revision ID                     | Revises ID        | Airflow Version   | 
Description                                                  |
 
+=================================+===================+===================+==============================================================+
-| ``788397e78828`` (head)         | ``937cbd173ca1``  | ``2.7.0``         | 
Add custom_operator_name column                              |
+| ``0646b768db47`` (head)         | ``788397e78828``  | ``2.7.0``         | 
Add is_setup to task_instance                                |
++---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
+| ``788397e78828``                | ``937cbd173ca1``  | ``2.7.0``         | 
Add custom_operator_name column                              |
 
+---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
 | ``937cbd173ca1``                | ``c804e5c76e3e``  | ``2.7.0``         | 
Add index to task_instance table                             |
 
+---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py
index d6aef85428..1cfb53628a 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -18,12 +18,15 @@
 from __future__ import annotations
 
 import logging
+from collections import defaultdict
 from datetime import timedelta
 from unittest.mock import patch
 
 import pendulum
 import pytest
 
+from airflow.decorators import setup, task, task_group, teardown
+from airflow.exceptions import AirflowSkipException
 from airflow.models import DAG
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.mappedoperator import MappedOperator
@@ -651,3 +654,436 @@ def test_task_mapping_with_explicit_task_group():
 
     assert finish.upstream_list == [mapped]
     assert mapped.downstream_list == [finish]
+
+
+class TestMappedSetupTeardown:
+    @staticmethod
+    def get_states(dr):
+        ti_dict = defaultdict(dict)
+        for ti in dr.get_task_instances():
+            if ti.map_index == -1:
+                ti_dict[ti.task_id] = ti.state
+            else:
+                ti_dict[ti.task_id][ti.map_index] = ti.state
+        return ti_dict
+
+    def test_one_to_many_work_failed(self, session, dag_maker):
+        """
+        Work task failed.  Setup maps to teardown.  Should have 3 teardowns 
all successful even
+        though the work task has failed.
+        """
+        with dag_maker(dag_id="one_to_many") as dag:
+
+            @setup
+            def my_setup():
+                print("setting up multiple things")
+                return [1, 2, 3]
+
+            @task
+            def my_work(val):
+                print(f"doing work with multiple things: {val}")
+                raise ValueError("fail!")
+                return val
+
+            @teardown
+            def my_teardown(val):
+                print(f"teardown: {val}")
+
+            s = my_setup()
+            t = my_teardown.expand(val=s)
+            with t:
+                my_work(s)
+
+        dr = dag.test()
+        states = self.get_states(dr)
+        expected = {
+            "my_setup": "success",
+            "my_work": "failed",
+            "my_teardown": {0: "success", 1: "success", 2: "success"},
+        }
+        assert states == expected
+
+    def test_many_one_explicit_odd_setup_mapped_setups_fail(self, dag_maker):
+        """
+        one unmapped setup goes to two different teardowns
+        one mapped setup goes to same teardown
+        mapped setups fail
+        teardowns should still run
+        """
+        with dag_maker(
+            dag_id="many_one_explicit_odd_setup_mapped_setups_fail",
+        ) as dag:
+
+            @task
+            def other_setup():
+                print("other setup")
+                return "other setup"
+
+            @task
+            def other_work():
+                print("other work")
+                return "other work"
+
+            @task
+            def other_teardown():
+                print("other teardown")
+                return "other teardown"
+
+            @task
+            def my_setup(val):
+                print(f"setup: {val}")
+                raise ValueError("fail")
+                return val
+
+            @task
+            def my_work(val):
+                print(f"work: {val}")
+
+            @task
+            def my_teardown(val):
+                print(f"teardown: {val}")
+
+            s = my_setup.expand(val=["data1.json", "data2.json", "data3.json"])
+            o_setup = other_setup()
+            o_teardown = other_teardown()
+            with o_teardown.as_teardown(setups=o_setup):
+                other_work()
+            t = my_teardown(s).as_teardown(setups=s)
+            with t:
+                my_work(s)
+            o_setup >> t
+        dr = dag.test()
+        states = self.get_states(dr)
+        expected = {
+            "my_setup": {0: "failed", 1: "failed", 2: "failed"},
+            "other_setup": "success",
+            "other_teardown": "success",
+            "other_work": "success",
+            "my_teardown": "success",
+            "my_work": "upstream_failed",
+        }
+        assert states == expected
+
+    def test_many_one_explicit_odd_setup_all_setups_fail(self, dag_maker):
+        """
+        one unmapped setup goes to two different teardowns
+        one mapped setup goes to same teardown
+        all setups fail
+        teardowns should not run
+        """
+        with dag_maker(
+            dag_id="many_one_explicit_odd_setup_all_setups_fail",
+        ) as dag:
+
+            @task
+            def other_setup():
+                print("other setup")
+                raise ValueError("fail")
+                return "other setup"
+
+            @task
+            def other_work():
+                print("other work")
+                return "other work"
+
+            @task
+            def other_teardown():
+                print("other teardown")
+                return "other teardown"
+
+            @task
+            def my_setup(val):
+                print(f"setup: {val}")
+                raise ValueError("fail")
+                return val
+
+            @task
+            def my_work(val):
+                print(f"work: {val}")
+
+            @task
+            def my_teardown(val):
+                print(f"teardown: {val}")
+
+            s = my_setup.expand(val=["data1.json", "data2.json", "data3.json"])
+            o_setup = other_setup()
+            o_teardown = other_teardown()
+            with o_teardown.as_teardown(setups=o_setup):
+                other_work()
+            t = my_teardown(s).as_teardown(setups=s)
+            with t:
+                my_work(s)
+            o_setup >> t
+
+        dr = dag.test()
+        states = self.get_states(dr)
+        expected = {
+            "my_teardown": "upstream_failed",
+            "other_setup": "failed",
+            "other_work": "upstream_failed",
+            "other_teardown": "upstream_failed",
+            "my_setup": {0: "failed", 1: "failed", 2: "failed"},
+            "my_work": "upstream_failed",
+        }
+        assert states == expected
+
+    def test_many_one_explicit_odd_setup_one_mapped_fails(self, dag_maker):
+        """
+        one unmapped setup goes to two different teardowns
+        one mapped setup goes to same teardown
+        one of the mapped setup instances fails
+        teardowns should all run
+        """
+        with dag_maker(dag_id="many_one_explicit_odd_setup_one_mapped_fails") 
as dag:
+
+            @task
+            def other_setup():
+                print("other setup")
+                return "other setup"
+
+            @task
+            def other_work():
+                print("other work")
+                return "other work"
+
+            @task
+            def other_teardown():
+                print("other teardown")
+                return "other teardown"
+
+            @task
+            def my_setup(val):
+                if val == "data2.json":
+                    raise ValueError("fail!")
+                elif val == "data3.json":
+                    raise AirflowSkipException("skip!")
+                print(f"setup: {val}")
+                return val
+
+            @task
+            def my_work(val):
+                print(f"work: {val}")
+
+            @task
+            def my_teardown(val):
+                print(f"teardown: {val}")
+
+            s = my_setup.expand(val=["data1.json", "data2.json", "data3.json"])
+            o_setup = other_setup()
+            o_teardown = other_teardown()
+            with o_teardown.as_teardown(setups=o_setup):
+                other_work()
+            t = my_teardown(s).as_teardown(setups=s)
+            with t:
+                my_work(s)
+            o_setup >> t
+        dr = dag.test()
+        states = self.get_states(dr)
+        expected = {
+            "my_setup": {0: "success", 1: "failed", 2: "skipped"},
+            "other_setup": "success",
+            "other_teardown": "success",
+            "other_work": "success",
+            "my_teardown": "success",
+            "my_work": "upstream_failed",
+        }
+        assert states == expected
+
+    def test_one_to_many_as_teardown(self, dag_maker, session):
+        """
+        1 setup mapping to 3 teardowns
+        1 work task
+        work fails
+        teardowns succeed
+        dagrun should be failure
+        """
+        with dag_maker(dag_id="one_to_many_as_teardown") as dag:
+
+            @task
+            def my_setup():
+                print("setting up multiple things")
+                return [1, 2, 3]
+
+            @task
+            def my_work(val):
+                print(f"doing work with multiple things: {val}")
+                raise ValueError("this fails")
+                return val
+
+            @task
+            def my_teardown(val):
+                print(f"teardown: {val}")
+
+            s = my_setup()
+            t = my_teardown.expand(val=s).as_teardown(setups=s)
+            with t:
+                my_work(s)
+        dr = dag.test()
+        states = self.get_states(dr)
+        expected = {
+            "my_setup": "success",
+            "my_teardown": {0: "success", 1: "success", 2: "success"},
+            "my_work": "failed",
+        }
+        assert states == expected
+
+    def test_one_to_many_as_teardown_offd(self, dag_maker, session):
+        """
+        1 setup mapping to 3 teardowns
+        1 work task
+        work succeeds
+        all but one teardown succeed (the instance for val == 2 fails)
+        offd=True (on_failure_fail_dagrun) on the mapped teardown
+        NOTE(review): "dagrun should be success" seems at odds with offd=True -- confirm (not asserted)
+        """
+        with dag_maker(dag_id="one_to_many_as_teardown_offd") as dag:
+
+            @task
+            def my_setup():
+                print("setting up multiple things")
+                return [1, 2, 3]
+
+            @task
+            def my_work(val):
+                print(f"doing work with multiple things: {val}")
+                return val
+
+            @task
+            def my_teardown(val):
+                print(f"teardown: {val}")
+                if val == 2:
+                    raise ValueError("failure")
+
+            s = my_setup()
+            t = my_teardown.expand(val=s).as_teardown(setups=s, on_failure_fail_dagrun=True)
+            with t:
+                my_work(s)
+            # todo: if on_failure_fail_dagrun=True, should we still regard the WORK task as a leaf?
+        dr = dag.test()
+        states = self.get_states(dr)
+        expected = {
+            "my_setup": "success",
+            "my_teardown": {0: "success", 1: "failed", 2: "success"},
+            "my_work": "success",
+        }
+        assert states == expected
+
+    def test_mapped_task_group_simple(self, dag_maker, session):
+        """
+        Mapped task group wherein there's a simple s >> w >> t pipeline.
+        When s is skipped, w and t of that mapped instance should be skipped too
+        When s is failed, w and t of that mapped instance should be upstream_failed
+        """
+        with dag_maker(dag_id="mapped_task_group_simple") as dag:
+
+            @setup
+            def my_setup(val):
+                if val == "data2.json":
+                    raise ValueError("fail!")
+                elif val == "data3.json":
+                    raise AirflowSkipException("skip!")
+                print(f"setup: {val}")
+
+            @task
+            def my_work(val):
+                print(f"work: {val}")
+
+            @teardown
+            def my_teardown(val):
+                print(f"teardown: {val}")
+
+            @task_group
+            def file_transforms(filename):
+                s = my_setup(filename)
+                t = my_teardown(filename).as_teardown(setups=s)
+                with t:
+                    my_work(filename)
+
+            file_transforms.expand(filename=["data1.json", "data2.json", "data3.json"])
+        dr = dag.test()
+        states = self.get_states(dr)
+        expected = {
+            "file_transforms.my_setup": {0: "success", 1: "failed", 2: "skipped"},
+            "file_transforms.my_work": {0: "success", 1: "upstream_failed", 2: "skipped"},
+            "file_transforms.my_teardown": {0: "success", 1: "upstream_failed", 2: "skipped"},
+        }
+
+        assert states == expected
+
+    def test_mapped_task_group_work_fail_or_skip(self, dag_maker, session):
+        """
+        Mapped task group wherein there's a simple s >> w >> t pipeline.
+        When w is skipped, teardown should still run (and succeed)
+        When w is failed, teardown should still run (and succeed)
+        """
+        with dag_maker(dag_id="mapped_task_group_work_fail_or_skip") as dag:
+
+            @setup
+            def my_setup(val):
+                print(f"setup: {val}")
+
+            @task
+            def my_work(val):
+                if val == "data2.json":
+                    raise ValueError("fail!")
+                elif val == "data3.json":
+                    raise AirflowSkipException("skip!")
+                print(f"work: {val}")
+
+            @teardown
+            def my_teardown(val):
+                print(f"teardown: {val}")
+
+            @task_group
+            def file_transforms(filename):
+                s = my_setup(filename)
+                t = my_teardown(filename).as_teardown(setups=s)
+                with t:
+                    my_work(filename)
+
+            file_transforms.expand(filename=["data1.json", "data2.json", "data3.json"])
+        dr = dag.test()
+        states = self.get_states(dr)
+        expected = {
+            "file_transforms.my_setup": {0: "success", 1: "success", 2: "success"},
+            "file_transforms.my_teardown": {0: "success", 1: "success", 2: "success"},
+            "file_transforms.my_work": {0: "success", 1: "failed", 2: "skipped"},
+        }
+        assert states == expected
+
+    def test_teardown_many_one_explicit(self, dag_maker, session):
+        """
+        one mapped setup going to one unmapped work (wrapped by one unmapped teardown)
+        3 diff states for setup: success / failed / skipped -> work ends up upstream_failed
+        teardown still runs, and receives the xcom from the single successful setup
+        """
+        with dag_maker(dag_id="teardown_many_one_explicit") as dag:
+
+            @task
+            def my_setup(val):
+                if val == "data2.json":
+                    raise ValueError("fail!")
+                elif val == "data3.json":
+                    raise AirflowSkipException("skip!")
+                print(f"setup: {val}")
+                return val
+
+            @task
+            def my_work(val):
+                print(f"work: {val}")
+
+            @task
+            def my_teardown(val):
+                print(f"teardown: {val}")
+
+            s = my_setup.expand(val=["data1.json", "data2.json", "data3.json"])
+            with my_teardown(s).as_teardown(setups=s):
+                my_work(s)
+        dr = dag.test()
+        states = self.get_states(dr)
+        expected = {
+            "my_setup": {0: "success", 1: "failed", 2: "skipped"},
+            "my_teardown": "success",
+            "my_work": "upstream_failed",
+        }
+        assert states == expected
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 9e7f5eb48a..78f7fb9690 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -3026,6 +3026,7 @@ class TestTaskInstance:
             "_try_number": 1,
             "max_tries": 1,
             "hostname": "some_unique_hostname",
+            "is_setup": False,
             "unixname": "some_unique_unixname",
             "job_id": 1234,
             "pool": "some_fake_pool_id",
diff --git a/tests/models/test_taskmixin.py b/tests/models/test_taskmixin.py
index 2d55b30cb3..c1795f22a8 100644
--- a/tests/models/test_taskmixin.py
+++ b/tests/models/test_taskmixin.py
@@ -211,61 +211,6 @@ def test_cannot_set_on_failure_fail_dagrun_unless_teardown_taskflow(dag_maker):
             m.operator.is_setup = True
 
 
-def test_no_setup_or_teardown_for_mapped_operator(dag_maker):
-    @task
-    def add_one(x):
-        return x + 1
-
-    @task
-    def print_task(values):
-        print(sum(values))
-
-    # vanilla mapped task
-    with dag_maker():
-        added_vals = add_one.expand(x=[1, 2, 3])
-        print_task(added_vals)
-
-    # combining setup and teardown with vanilla mapped task is fine
-    with dag_maker():
-        s1 = BaseOperator(task_id="s1").as_setup()
-        t1 = BaseOperator(task_id="t1").as_teardown(setups=s1)
-        added_vals = add_one.expand(x=[1, 2, 3])
-        print_task_task = print_task(added_vals)
-        s1 >> added_vals
-        print_task_task >> t1
-    # confirm structure
-    assert s1.downstream_task_ids == {"add_one", "t1"}
-    assert t1.upstream_task_ids == {"print_task", "s1"}
-    assert added_vals.operator.upstream_task_ids == {"s1"}
-    assert added_vals.operator.downstream_task_ids == {"print_task"}
-    assert print_task_task.operator.upstream_task_ids == {"add_one"}
-    assert print_task_task.operator.downstream_task_ids == {"t1"}
-
-    # but you can't use a mapped task as setup or teardown
-    with dag_maker():
-        added_vals = add_one.expand(x=[1, 2, 3])
-        with pytest.raises(ValueError, match="Cannot set is_teardown for mapped operator"):
-            added_vals.as_teardown()
-
-    # ... no matter how hard you try
-    with dag_maker():
-        added_vals = add_one.expand(x=[1, 2, 3])
-        with pytest.raises(ValueError, match="Cannot set is_teardown for mapped operator"):
-            added_vals.is_teardown = True
-
-    # same with setup
-    with dag_maker():
-        added_vals = add_one.expand(x=[1, 2, 3])
-        with pytest.raises(ValueError, match="Cannot set is_setup for mapped operator"):
-            added_vals.as_setup()
-
-    # and again, trying harder...
-    with dag_maker():
-        added_vals = add_one.expand(x=[1, 2, 3])
-        with pytest.raises(ValueError, match="Cannot set is_setup for mapped operator"):
-            added_vals.is_setup = True
-
-
 def test_set_setup_teardown_ctx_dependencies_using_decorated_tasks(dag_maker):
 
     with dag_maker():
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index c89122879e..f59794509e 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -53,6 +53,7 @@ from airflow.operators.empty import EmptyOperator
 from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator
 from airflow.security import permissions
 from airflow.sensors.bash import BashSensor
+from airflow.serialization.enums import Encoding
 from airflow.serialization.json_schema import load_dag_schema_dict
 from airflow.serialization.serialized_objects import (
     DagDependency,
@@ -112,7 +113,8 @@ executor_config_pod = k8s.V1Pod(
         ]
     ),
 )
-
+TYPE = Encoding.TYPE  # short alias of Encoding.TYPE (serialized "__type" marker, presumably) -- confirm
+VAR = Encoding.VAR  # short alias of Encoding.VAR (serialized "__var" marker, presumably) -- confirm
 serialized_simple_dag_ground_truth = {
     "__version": 1,
     "dag": {
@@ -383,7 +385,11 @@ class TestStringifiedDAGs:
             serialized_dags[v.dag_id] = dag
 
         # Compares with the ground truth of JSON string.
-        self.validate_serialized_dag(serialized_dags["simple_dag"], 
serialized_simple_dag_ground_truth)
+        actual, expected = self.prepare_ser_dags_for_comparison(
+            actual=serialized_dags["simple_dag"],
+            expected=serialized_simple_dag_ground_truth,
+        )
+        assert actual == expected
 
     @pytest.mark.parametrize(
         "timetable, serialized_timetable",
@@ -412,7 +418,14 @@ class TestStringifiedDAGs:
         del expected["dag"]["schedule_interval"]
         expected["dag"]["timetable"] = serialized_timetable
 
-        self.validate_serialized_dag(serialized_dag, expected)
+        actual, expected = self.prepare_ser_dags_for_comparison(
+            actual=serialized_dag,
+            expected=expected,
+        )
+        for task in actual["dag"]["tasks"]:
+            for k, v in task.items():
+                print(task["task_id"], k, v)
+        assert actual == expected
 
     def test_dag_serialization_unregistered_custom_timetable(self):
         """Verify serialization fails without timetable registration."""
@@ -429,10 +442,10 @@ class TestStringifiedDAGs:
         )
         assert str(ctx.value) == message
 
-    def validate_serialized_dag(self, json_dag, ground_truth_dag):
+    def prepare_ser_dags_for_comparison(self, actual, expected):
         """Verify serialized DAGs match the ground truth."""
-        assert json_dag["dag"]["fileloc"].split("/")[-1] == 
"test_dag_serialization.py"
-        json_dag["dag"]["fileloc"] = None
+        assert actual["dag"]["fileloc"].split("/")[-1] == 
"test_dag_serialization.py"
+        actual["dag"]["fileloc"] = None
 
         def sorted_serialized_dag(dag_dict: dict):
             """
@@ -447,7 +460,11 @@ class TestStringifiedDAGs:
             )
             return dag_dict
 
-        assert sorted_serialized_dag(ground_truth_dag) == 
sorted_serialized_dag(json_dag)
+        # by roundtripping to json we get a cleaner diff
+        # if not doing this, we get false alarms such as "__var" != VAR
+        actual = json.loads(json.dumps(sorted_serialized_dag(actual)))
+        expected = json.loads(json.dumps(sorted_serialized_dag(expected)))
+        return actual, expected
 
     def test_deserialization_across_process(self):
         """A serialized DAG can be deserialized in another process."""
@@ -2290,6 +2307,9 @@ def test_taskflow_expand_serde():
         "_operator_name": "@task",
         "downstream_task_ids": [],
         "partial_kwargs": {
+            "is_setup": False,
+            "is_teardown": False,
+            "on_failure_fail_dagrun": False,
             "op_args": [],
             "op_kwargs": {
                 "__type": "dict",
@@ -2329,6 +2349,9 @@ def test_taskflow_expand_serde():
         value={"arg2": {"a": 1, "b": 2}, "arg3": _XComRef({"task_id": "op1", 
"key": XCOM_RETURN_KEY})},
     )
     assert deserialized.partial_kwargs == {
+        "is_setup": False,
+        "is_teardown": False,
+        "on_failure_fail_dagrun": False,
         "op_args": [],
         "op_kwargs": {"arg1": [1, 2, {"a": "b"}]},
         "retry_delay": timedelta(seconds=30),
@@ -2344,6 +2367,9 @@ def test_taskflow_expand_serde():
         value={"arg2": {"a": 1, "b": 2}, "arg3": _XComRef({"task_id": "op1", 
"key": XCOM_RETURN_KEY})},
     )
     assert pickled.partial_kwargs == {
+        "is_setup": False,
+        "is_teardown": False,
+        "on_failure_fail_dagrun": False,
         "op_args": [],
         "op_kwargs": {"arg1": [1, 2, {"a": "b"}]},
         "retry_delay": timedelta(seconds=30),
@@ -2376,6 +2402,9 @@ def test_taskflow_expand_kwargs_serde(strict):
         "_operator_name": "@task",
         "downstream_task_ids": [],
         "partial_kwargs": {
+            "is_setup": False,
+            "is_teardown": False,
+            "on_failure_fail_dagrun": False,
             "op_args": [],
             "op_kwargs": {
                 "__type": "dict",
@@ -2413,6 +2442,9 @@ def test_taskflow_expand_kwargs_serde(strict):
         value=_XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY}),
     )
     assert deserialized.partial_kwargs == {
+        "is_setup": False,
+        "is_teardown": False,
+        "on_failure_fail_dagrun": False,
         "op_args": [],
         "op_kwargs": {"arg1": [1, 2, {"a": "b"}]},
         "retry_delay": timedelta(seconds=30),
@@ -2428,6 +2460,9 @@ def test_taskflow_expand_kwargs_serde(strict):
         _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY}),
     )
     assert pickled.partial_kwargs == {
+        "is_setup": False,
+        "is_teardown": False,
+        "on_failure_fail_dagrun": False,
         "op_args": [],
         "op_kwargs": {"arg1": [1, 2, {"a": "b"}]},
         "retry_delay": timedelta(seconds=30),
diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py
index 32df99c047..93bc52884c 100644
--- a/tests/www/views/test_views_tasks.py
+++ b/tests/www/views/test_views_tasks.py
@@ -1017,6 +1017,7 @@ def test_task_instances(admin_client):
             "executor_config": {},
             "external_executor_id": None,
             "hostname": "",
+            "is_setup": False,
             "job_id": None,
             "map_index": -1,
             "max_tries": 0,
@@ -1049,6 +1050,7 @@ def test_task_instances(admin_client):
             "executor_config": {},
             "external_executor_id": None,
             "hostname": "",
+            "is_setup": False,
             "job_id": None,
             "map_index": -1,
             "max_tries": 0,
@@ -1081,6 +1083,7 @@ def test_task_instances(admin_client):
             "executor_config": {},
             "external_executor_id": None,
             "hostname": "",
+            "is_setup": False,
             "job_id": None,
             "map_index": -1,
             "max_tries": 0,
@@ -1113,6 +1116,7 @@ def test_task_instances(admin_client):
             "executor_config": {},
             "external_executor_id": None,
             "hostname": "",
+            "is_setup": False,
             "job_id": None,
             "map_index": -1,
             "max_tries": 0,
@@ -1145,6 +1149,7 @@ def test_task_instances(admin_client):
             "executor_config": {},
             "external_executor_id": None,
             "hostname": "",
+            "is_setup": False,
             "job_id": None,
             "map_index": -1,
             "max_tries": 0,
@@ -1177,6 +1182,7 @@ def test_task_instances(admin_client):
             "executor_config": {},
             "external_executor_id": None,
             "hostname": "",
+            "is_setup": False,
             "job_id": None,
             "map_index": -1,
             "max_tries": 0,
@@ -1209,6 +1215,7 @@ def test_task_instances(admin_client):
             "executor_config": {},
             "external_executor_id": None,
             "hostname": "",
+            "is_setup": False,
             "job_id": None,
             "map_index": -1,
             "max_tries": 0,


Reply via email to