This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 12b0b6b607 Fix mapped setup teardown classic operator (#32934)
12b0b6b607 is described below

commit 12b0b6b60721d7ebfc4e5440a7ca9bf10efb8d11
Author: Daniel Standish <[email protected]>
AuthorDate: Sun Jul 30 18:40:54 2023 -0700

    Fix mapped setup teardown classic operator (#32934)
---
 airflow/models/abstractoperator.py            |  10 +
 airflow/models/baseoperator.py                |   9 -
 tests/models/test_mappedoperator.py           | 808 +++++++++++++++++---------
 tests/serialization/test_dag_serialization.py |  23 +
 4 files changed, 576 insertions(+), 274 deletions(-)

diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py
index 7cba88f28d..11e9184735 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -34,6 +34,7 @@ from airflow.utils.context import Context
 from airflow.utils.db import exists_query
 from airflow.utils.log.secrets_masker import redact
 from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.setup_teardown import SetupTeardownContext
 from airflow.utils.sqlalchemy import skip_locked, with_row_locks
 from airflow.utils.state import State, TaskInstanceState
 from airflow.utils.task_group import MappedTaskGroup
@@ -709,3 +710,12 @@ class AbstractOperator(Templater, DAGNode):
                 raise
             else:
                 setattr(parent, attr_name, rendered_content)
+
+    def __enter__(self):
+        if not self.is_setup and not self.is_teardown:
+            raise AirflowException("Only setup/teardown tasks can be used as context managers.")
+        SetupTeardownContext.push_setup_teardown_task(self)
+        return SetupTeardownContext
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        SetupTeardownContext.set_work_task_roots_and_leaves()
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index e72fcc2940..d49babbeb0 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -962,15 +962,6 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
         if SetupTeardownContext.active:
             SetupTeardownContext.update_context_map(self)
 
-    def __enter__(self):
-        if not self.is_setup and not self.is_teardown:
-            raise AirflowException("Only setup/teardown tasks can be used as context managers.")
-        SetupTeardownContext.push_setup_teardown_task(self)
-        return SetupTeardownContext
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        SetupTeardownContext.set_work_task_roots_and_leaves()
-
     def __eq__(self, other):
         if type(self) is type(other):
             # Use getattr() instead of __dict__ as __dict__ doesn't return
diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py
index 1cfb53628a..d626b8499e 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -34,6 +34,7 @@ from airflow.models.param import ParamsDict
 from airflow.models.taskinstance import TaskInstance
 from airflow.models.taskmap import TaskMap
 from airflow.models.xcom_arg import XComArg
+from airflow.operators.python import PythonOperator
 from airflow.utils.context import Context
 from airflow.utils.state import TaskInstanceState
 from airflow.utils.task_group import TaskGroup
@@ -667,32 +668,75 @@ class TestMappedSetupTeardown:
                 ti_dict[ti.task_id][ti.map_index] = ti.state
         return ti_dict
 
-    def test_one_to_many_work_failed(self, session, dag_maker):
+    def classic_operator(self, task_id, ret=None, partial=False, fail=False):
+        def success_callable(ret=None):
+            def inner(*args, **kwargs):
+                print(args)
+                print(kwargs)
+                if ret:
+                    return ret
+
+            return inner
+
+        def failure_callable():
+            def inner(*args, **kwargs):
+                print(args)
+                print(kwargs)
+                raise ValueError("fail")
+
+            return inner
+
+        kwargs = dict(task_id=task_id)
+        if not fail:
+            kwargs.update(python_callable=success_callable(ret=ret))
+        else:
+            kwargs.update(python_callable=failure_callable())
+        if partial:
+            return PythonOperator.partial(**kwargs)
+        else:
+            return PythonOperator(**kwargs)
+
+    @pytest.mark.parametrize("type_", ["taskflow", "classic"])
+    def test_one_to_many_work_failed(self, type_, dag_maker):
         """
         Work task failed.  Setup maps to teardown.  Should have 3 teardowns all successful even
         though the work task has failed.
         """
-        with dag_maker(dag_id="one_to_many") as dag:
+        if type_ == "taskflow":
+            with dag_maker() as dag:
+
+                @setup
+                def my_setup():
+                    print("setting up multiple things")
+                    return [1, 2, 3]
 
-            @setup
-            def my_setup():
-                print("setting up multiple things")
-                return [1, 2, 3]
+                @task
+                def my_work(val):
+                    print(f"doing work with multiple things: {val}")
+                    raise ValueError("fail!")
+
+                @teardown
+                def my_teardown(val):
+                    print(f"teardown: {val}")
+
+                s = my_setup()
+                t = my_teardown.expand(val=s)
+                with t:
+                    my_work(s)
+        else:
 
             @task
             def my_work(val):
-                print(f"doing work with multiple things: {val}")
-                raise ValueError("fail!")
-                return val
-
-            @teardown
-            def my_teardown(val):
-                print(f"teardown: {val}")
+                print(f"work: {val}")
+                raise ValueError("i fail")
 
-            s = my_setup()
-            t = my_teardown.expand(val=s)
-            with t:
-                my_work(s)
+            with dag_maker() as dag:
+                my_setup = self.classic_operator("my_setup", [[1], [2], [3]])
+                my_teardown = self.classic_operator("my_teardown", partial=True)
+                t = my_teardown.expand(op_args=my_setup.output)
+                with t.as_teardown(setups=my_setup):
+                    my_work(my_setup.output)
+            return dag
 
         dr = dag.test()
         states = self.get_states(dr)
@@ -703,55 +747,80 @@ class TestMappedSetupTeardown:
         }
         assert states == expected
 
-    def test_many_one_explicit_odd_setup_mapped_setups_fail(self, dag_maker):
+    @pytest.mark.parametrize("type_", ["taskflow", "classic"])
+    def test_many_one_explicit_odd_setup_mapped_setups_fail(self, type_, dag_maker):
         """
         one unmapped setup goes to two different teardowns
         one mapped setup goes to same teardown
         mapped setups fail
         teardowns should still run
         """
-        with dag_maker(
-            dag_id="many_one_explicit_odd_setup_mapped_setups_fail",
-        ) as dag:
-
-            @task
-            def other_setup():
-                print("other setup")
-                return "other setup"
-
-            @task
-            def other_work():
-                print("other work")
-                return "other work"
-
-            @task
-            def other_teardown():
-                print("other teardown")
-                return "other teardown"
-
-            @task
-            def my_setup(val):
-                print(f"setup: {val}")
-                raise ValueError("fail")
-                return val
-
-            @task
-            def my_work(val):
-                print(f"work: {val}")
+        if type_ == "taskflow":
+            with dag_maker() as dag:
+
+                @task
+                def other_setup():
+                    print("other setup")
+                    return "other setup"
+
+                @task
+                def other_work():
+                    print("other work")
+                    return "other work"
+
+                @task
+                def other_teardown():
+                    print("other teardown")
+                    return "other teardown"
+
+                @task
+                def my_setup(val):
+                    print(f"setup: {val}")
+                    raise ValueError("fail")
+                    return val
+
+                @task
+                def my_work(val):
+                    print(f"work: {val}")
+
+                @task
+                def my_teardown(val):
+                    print(f"teardown: {val}")
+
+                s = my_setup.expand(val=["data1.json", "data2.json", "data3.json"])
+                o_setup = other_setup()
+                o_teardown = other_teardown()
+                with o_teardown.as_teardown(setups=o_setup):
+                    other_work()
+                t = my_teardown(s).as_teardown(setups=s)
+                with t:
+                    my_work(s)
+                o_setup >> t
+        else:
+            with dag_maker() as dag:
+
+                @task
+                def other_work():
+                    print("other work")
+                    return "other work"
+
+                @task
+                def my_work(val):
+                    print(f"work: {val}")
+
+                my_teardown = self.classic_operator("my_teardown")
+
+                my_setup = self.classic_operator("my_setup", partial=True, fail=True)
+                s = my_setup.expand(op_args=[["data1.json"], ["data2.json"], ["data3.json"]])
+                o_setup = self.classic_operator("other_setup")
+                o_teardown = self.classic_operator("other_teardown")
+                with o_teardown.as_teardown(setups=o_setup):
+                    other_work()
+                t = my_teardown.as_teardown(setups=s)
+                with t:
+                    my_work(s.output)
+                o_setup >> t
 
-            @task
-            def my_teardown(val):
-                print(f"teardown: {val}")
-
-            s = my_setup.expand(val=["data1.json", "data2.json", "data3.json"])
-            o_setup = other_setup()
-            o_teardown = other_teardown()
-            with o_teardown.as_teardown(setups=o_setup):
-                other_work()
-            t = my_teardown(s).as_teardown(setups=s)
-            with t:
-                my_work(s)
-            o_setup >> t
         dr = dag.test()
         states = self.get_states(dr)
         expected = {
@@ -764,56 +833,90 @@ class TestMappedSetupTeardown:
         }
         assert states == expected
 
-    def test_many_one_explicit_odd_setup_all_setups_fail(self, dag_maker):
+    @pytest.mark.parametrize("type_", ["taskflow", "classic"])
+    def test_many_one_explicit_odd_setup_all_setups_fail(self, type_, dag_maker):
         """
         one unmapped setup goes to two different teardowns
         one mapped setup goes to same teardown
         all setups fail
         teardowns should not run
         """
-        with dag_maker(
-            dag_id="many_one_explicit_odd_setup_all_setups_fail",
-        ) as dag:
-
-            @task
-            def other_setup():
-                print("other setup")
-                raise ValueError("fail")
-                return "other setup"
-
-            @task
-            def other_work():
-                print("other work")
-                return "other work"
-
-            @task
-            def other_teardown():
-                print("other teardown")
-                return "other teardown"
-
-            @task
-            def my_setup(val):
-                print(f"setup: {val}")
-                raise ValueError("fail")
-                return val
-
-            @task
-            def my_work(val):
-                print(f"work: {val}")
-
-            @task
-            def my_teardown(val):
-                print(f"teardown: {val}")
-
-            s = my_setup.expand(val=["data1.json", "data2.json", "data3.json"])
-            o_setup = other_setup()
-            o_teardown = other_teardown()
-            with o_teardown.as_teardown(setups=o_setup):
-                other_work()
-            t = my_teardown(s).as_teardown(setups=s)
-            with t:
-                my_work(s)
-            o_setup >> t
+        if type_ == "taskflow":
+            with dag_maker() as dag:
+
+                @task
+                def other_setup():
+                    print("other setup")
+                    raise ValueError("fail")
+                    return "other setup"
+
+                @task
+                def other_work():
+                    print("other work")
+                    return "other work"
+
+                @task
+                def other_teardown():
+                    print("other teardown")
+                    return "other teardown"
+
+                @task
+                def my_setup(val):
+                    print(f"setup: {val}")
+                    raise ValueError("fail")
+                    return val
+
+                @task
+                def my_work(val):
+                    print(f"work: {val}")
+
+                @task
+                def my_teardown(val):
+                    print(f"teardown: {val}")
+
+                s = my_setup.expand(val=["data1.json", "data2.json", "data3.json"])
+                o_setup = other_setup()
+                o_teardown = other_teardown()
+                with o_teardown.as_teardown(setups=o_setup):
+                    other_work()
+                t = my_teardown(s).as_teardown(setups=s)
+                with t:
+                    my_work(s)
+                o_setup >> t
+        else:
+            with dag_maker() as dag:
+
+                @task
+                def other_setup():
+                    print("other setup")
+                    raise ValueError("fail")
+                    return "other setup"
+
+                @task
+                def other_work():
+                    print("other work")
+                    return "other work"
+
+                @task
+                def other_teardown():
+                    print("other teardown")
+                    return "other teardown"
+
+                @task
+                def my_work(val):
+                    print(f"work: {val}")
+
+                my_setup = self.classic_operator("my_setup", partial=True, fail=True)
+                s = my_setup.expand(op_args=[["data1.json"], ["data2.json"], ["data3.json"]])
+                o_setup = other_setup()
+                o_teardown = other_teardown()
+                with o_teardown.as_teardown(setups=o_setup):
+                    other_work()
+                my_teardown = self.classic_operator("my_teardown")
+                t = my_teardown.as_teardown(setups=s)
+                with t:
+                    my_work(s.output)
+                o_setup >> t
 
         dr = dag.test()
         states = self.get_states(dr)
@@ -827,56 +930,106 @@ class TestMappedSetupTeardown:
         }
         assert states == expected
 
-    def test_many_one_explicit_odd_setup_one_mapped_fails(self, dag_maker):
+    @pytest.mark.parametrize("type_", ["taskflow", "classic"])
+    def test_many_one_explicit_odd_setup_one_mapped_fails(self, type_, dag_maker):
         """
         one unmapped setup goes to two different teardowns
         one mapped setup goes to same teardown
         one of the mapped setup instances fails
         teardowns should all run
         """
-        with dag_maker(dag_id="many_one_explicit_odd_setup_one_mapped_fails") as dag:
-
-            @task
-            def other_setup():
-                print("other setup")
-                return "other setup"
-
-            @task
-            def other_work():
-                print("other work")
-                return "other work"
-
-            @task
-            def other_teardown():
-                print("other teardown")
-                return "other teardown"
-
-            @task
-            def my_setup(val):
-                if val == "data2.json":
-                    raise ValueError("fail!")
-                elif val == "data3.json":
-                    raise AirflowSkipException("skip!")
-                print(f"setup: {val}")
-                return val
-
-            @task
-            def my_work(val):
-                print(f"work: {val}")
+        if type_ == "taskflow":
+            with dag_maker() as dag:
+
+                @task
+                def other_setup():
+                    print("other setup")
+                    return "other setup"
+
+                @task
+                def other_work():
+                    print("other work")
+                    return "other work"
+
+                @task
+                def other_teardown():
+                    print("other teardown")
+                    return "other teardown"
+
+                @task
+                def my_setup(val):
+                    if val == "data2.json":
+                        raise ValueError("fail!")
+                    elif val == "data3.json":
+                        raise AirflowSkipException("skip!")
+                    print(f"setup: {val}")
+                    return val
+
+                @task
+                def my_work(val):
+                    print(f"work: {val}")
+
+                @task
+                def my_teardown(val):
+                    print(f"teardown: {val}")
+
+                s = my_setup.expand(val=["data1.json", "data2.json", "data3.json"])
+                o_setup = other_setup()
+                o_teardown = other_teardown()
+                with o_teardown.as_teardown(setups=o_setup):
+                    other_work()
+                t = my_teardown(s).as_teardown(setups=s)
+                with t:
+                    my_work(s)
+                o_setup >> t
+        else:
+            with dag_maker() as dag:
+
+                @task
+                def other_setup():
+                    print("other setup")
+                    return "other setup"
+
+                @task
+                def other_work():
+                    print("other work")
+                    return "other work"
+
+                @task
+                def other_teardown():
+                    print("other teardown")
+                    return "other teardown"
+
+                def my_setup_callable(val):
+                    if val == "data2.json":
+                        raise ValueError("fail!")
+                    elif val == "data3.json":
+                        raise AirflowSkipException("skip!")
+                    print(f"setup: {val}")
+                    return val
+
+                my_setup = PythonOperator.partial(task_id="my_setup", python_callable=my_setup_callable)
+
+                @task
+                def my_work(val):
+                    print(f"work: {val}")
+
+                def my_teardown_callable(val):
+                    print(f"teardown: {val}")
+
+                s = my_setup.expand(op_args=[["data1.json"], ["data2.json"], ["data3.json"]])
+                o_setup = other_setup()
+                o_teardown = other_teardown()
+                with o_teardown.as_teardown(setups=o_setup):
+                    other_work()
+                my_teardown = PythonOperator(
+                    task_id="my_teardown", op_args=[s.output], python_callable=my_teardown_callable
+                )
+                t = my_teardown.as_teardown(setups=s)
+                with t:
+                    my_work(s.output)
+                o_setup >> t
 
-            @task
-            def my_teardown(val):
-                print(f"teardown: {val}")
-
-            s = my_setup.expand(val=["data1.json", "data2.json", "data3.json"])
-            o_setup = other_setup()
-            o_teardown = other_teardown()
-            with o_teardown.as_teardown(setups=o_setup):
-                other_work()
-            t = my_teardown(s).as_teardown(setups=s)
-            with t:
-                my_work(s)
-            o_setup >> t
         dr = dag.test()
         states = self.get_states(dr)
         expected = {
@@ -889,7 +1042,8 @@ class TestMappedSetupTeardown:
         }
         assert states == expected
 
-    def test_one_to_many_as_teardown(self, dag_maker, session):
+    @pytest.mark.parametrize("type_", ["taskflow", "classic"])
+    def test_one_to_many_as_teardown(self, type_, dag_maker):
         """
         1 setup mapping to 3 teardowns
         1 work task
@@ -897,27 +1051,43 @@ class TestMappedSetupTeardown:
         teardowns succeed
         dagrun should be failure
         """
-        with dag_maker(dag_id="one_to_many_as_teardown") as dag:
+        if type_ == "taskflow":
+            with dag_maker() as dag:
+
+                @task
+                def my_setup():
+                    print("setting up multiple things")
+                    return [1, 2, 3]
+
+                @task
+                def my_work(val):
+                    print(f"doing work with multiple things: {val}")
+                    raise ValueError("this fails")
+                    return val
+
+                @task
+                def my_teardown(val):
+                    print(f"teardown: {val}")
+
+                s = my_setup()
+                t = my_teardown.expand(val=s).as_teardown(setups=s)
+                with t:
+                    my_work(s)
+        else:
+            with dag_maker() as dag:
 
-            @task
-            def my_setup():
-                print("setting up multiple things")
-                return [1, 2, 3]
+                @task
+                def my_work(val):
+                    print(f"doing work with multiple things: {val}")
+                    raise ValueError("this fails")
+                    return val
 
-            @task
-            def my_work(val):
-                print(f"doing work with multiple things: {val}")
-                raise ValueError("this fails")
-                return val
+                my_teardown = self.classic_operator(task_id="my_teardown", partial=True)
 
-            @task
-            def my_teardown(val):
-                print(f"teardown: {val}")
-
-            s = my_setup()
-            t = my_teardown.expand(val=s).as_teardown(setups=s)
-            with t:
-                my_work(s)
+                s = self.classic_operator(task_id="my_setup", ret=[[1], [2], [3]])
+                t = my_teardown.expand(op_args=s.output).as_teardown(setups=s)
+                with t:
+                    my_work(s)
         dr = dag.test()
         states = self.get_states(dr)
         expected = {
@@ -927,38 +1097,61 @@ class TestMappedSetupTeardown:
         }
         assert states == expected
 
-    def test_one_to_many_as_teardown_offd(self, dag_maker, session):
+    @pytest.mark.parametrize("type_", ["taskflow", "classic"])
+    def test_one_to_many_as_teardown_on_failure_fail_dagrun(self, type_, dag_maker):
         """
         1 setup mapping to 3 teardowns
         1 work task
         work succeeds
         all but one teardown succeed
-        offd=True
+        on_failure_fail_dagrun=True
         dagrun should be success
         """
-        with dag_maker(dag_id="one_to_many_as_teardown_offd") as dag:
-
-            @task
-            def my_setup():
-                print("setting up multiple things")
-                return [1, 2, 3]
-
-            @task
-            def my_work(val):
-                print(f"doing work with multiple things: {val}")
-                return val
+        if type_ == "taskflow":
+            with dag_maker() as dag:
+
+                @task
+                def my_setup():
+                    print("setting up multiple things")
+                    return [1, 2, 3]
+
+                @task
+                def my_work(val):
+                    print(f"doing work with multiple things: {val}")
+                    return val
+
+                @task
+                def my_teardown(val):
+                    print(f"teardown: {val}")
+                    if val == 2:
+                        raise ValueError("failure")
+
+                s = my_setup()
+                t = my_teardown.expand(val=s).as_teardown(setups=s, on_failure_fail_dagrun=True)
+                with t:
+                    my_work(s)
+                # todo: if on_failure_fail_dagrun=True, should we still regard the WORK task as a leaf?
+        else:
+            with dag_maker() as dag:
+
+                @task
+                def my_work(val):
+                    print(f"doing work with multiple things: {val}")
+                    return val
+
+                def my_teardown_callable(val):
+                    print(f"teardown: {val}")
+                    if val == 2:
+                        raise ValueError("failure")
+
+                s = self.classic_operator(task_id="my_setup", ret=[[1], [2], [3]])
+                my_teardown = PythonOperator.partial(
+                    task_id="my_teardown", python_callable=my_teardown_callable
+                ).expand(op_args=s.output)
+                t = my_teardown.as_teardown(setups=s, on_failure_fail_dagrun=True)
+                with t:
+                    my_work(s.output)
 
-            @task
-            def my_teardown(val):
-                print(f"teardown: {val}")
-                if val == 2:
-                    raise ValueError("failure")
-
-            s = my_setup()
-            t = my_teardown.expand(val=s).as_teardown(setups=s, on_failure_fail_dagrun=True)
-            with t:
-                my_work(s)
-            # todo: if on_failure_fail_dagrun=True, should we still regard the WORK task as a leaf?
         dr = dag.test()
         states = self.get_states(dr)
         expected = {
@@ -968,38 +1161,70 @@ class TestMappedSetupTeardown:
         }
         assert states == expected
 
-    def test_mapped_task_group_simple(self, dag_maker, session):
+    @pytest.mark.parametrize("type_", ["taskflow", "classic"])
+    def test_mapped_task_group_simple(self, type_, dag_maker, session):
         """
         Mapped task group wherein there's a simple s >> w >> t pipeline.
         When s is skipped, all should be skipped
         When s is failed, all should be upstream failed
         """
-        with dag_maker(dag_id="mapped_task_group_simple") as dag:
-
-            @setup
-            def my_setup(val):
-                if val == "data2.json":
-                    raise ValueError("fail!")
-                elif val == "data3.json":
-                    raise AirflowSkipException("skip!")
-                print(f"setup: {val}")
-
-            @task
-            def my_work(val):
-                print(f"work: {val}")
-
-            @teardown
-            def my_teardown(val):
-                print(f"teardown: {val}")
-
-            @task_group
-            def file_transforms(filename):
-                s = my_setup(filename)
-                t = my_teardown(filename).as_teardown(setups=s)
-                with t:
-                    my_work(filename)
-
-            file_transforms.expand(filename=["data1.json", "data2.json", "data3.json"])
+        if type_ == "taskflow":
+            with dag_maker() as dag:
+
+                @setup
+                def my_setup(val):
+                    if val == "data2.json":
+                        raise ValueError("fail!")
+                    elif val == "data3.json":
+                        raise AirflowSkipException("skip!")
+                    print(f"setup: {val}")
+
+                @task
+                def my_work(val):
+                    print(f"work: {val}")
+
+                @teardown
+                def my_teardown(val):
+                    print(f"teardown: {val}")
+
+                @task_group
+                def file_transforms(filename):
+                    s = my_setup(filename)
+                    t = my_teardown(filename)
+                    s >> t
+                    with t:
+                        my_work(filename)
+
+                file_transforms.expand(filename=["data1.json", "data2.json", "data3.json"])
+        else:
+            with dag_maker() as dag:
+
+                def my_setup_callable(val):
+                    if val == "data2.json":
+                        raise ValueError("fail!")
+                    elif val == "data3.json":
+                        raise AirflowSkipException("skip!")
+                    print(f"setup: {val}")
+
+                @task
+                def my_work(val):
+                    print(f"work: {val}")
+
+                def my_teardown_callable(val):
+                    print(f"teardown: {val}")
+
+                @task_group
+                def file_transforms(filename):
+                    s = PythonOperator(
+                        task_id="my_setup", python_callable=my_setup_callable, op_args=filename
+                    )
+                    t = PythonOperator(
+                        task_id="my_teardown", python_callable=my_teardown_callable, op_args=filename
+                    )
+                    with t.as_teardown(setups=s):
+                        my_work(filename)
+
+                file_transforms.expand(filename=[["data1.json"], ["data2.json"], ["data3.json"]])
         dr = dag.test()
         states = self.get_states(dr)
         expected = {
@@ -1010,38 +1235,68 @@ class TestMappedSetupTeardown:
 
         assert states == expected
 
-    def test_mapped_task_group_work_fail_or_skip(self, dag_maker, session):
+    @pytest.mark.parametrize("type_", ["taskflow", "classic"])
+    def test_mapped_task_group_work_fail_or_skip(self, type_, dag_maker):
         """
         Mapped task group wherein there's a simple s >> w >> t pipeline.
         When w is skipped, teardown should still run
         When w is failed, teardown should still run
         """
-        with dag_maker(dag_id="mapped_task_group_work_fail_or_skip") as dag:
-
-            @setup
-            def my_setup(val):
-                print(f"setup: {val}")
-
-            @task
-            def my_work(val):
-                if val == "data2.json":
-                    raise ValueError("fail!")
-                elif val == "data3.json":
-                    raise AirflowSkipException("skip!")
-                print(f"work: {val}")
-
-            @teardown
-            def my_teardown(val):
-                print(f"teardown: {val}")
-
-            @task_group
-            def file_transforms(filename):
-                s = my_setup(filename)
-                t = my_teardown(filename).as_teardown(setups=s)
-                with t:
-                    my_work(filename)
-
-            file_transforms.expand(filename=["data1.json", "data2.json", "data3.json"])
+        if type_ == "taskflow":
+            with dag_maker() as dag:
+
+                @setup
+                def my_setup(val):
+                    print(f"setup: {val}")
+
+                @task
+                def my_work(val):
+                    if val == "data2.json":
+                        raise ValueError("fail!")
+                    elif val == "data3.json":
+                        raise AirflowSkipException("skip!")
+                    print(f"work: {val}")
+
+                @teardown
+                def my_teardown(val):
+                    print(f"teardown: {val}")
+
+                @task_group
+                def file_transforms(filename):
+                    s = my_setup(filename)
+                    t = my_teardown(filename).as_teardown(setups=s)
+                    with t:
+                        my_work(filename)
+
+                file_transforms.expand(filename=["data1.json", "data2.json", "data3.json"])
+        else:
+            with dag_maker() as dag:
+
+                @task
+                def my_work(vals):
+                    val = vals[0]
+                    if val == "data2.json":
+                        raise ValueError("fail!")
+                    elif val == "data3.json":
+                        raise AirflowSkipException("skip!")
+                    print(f"work: {val}")
+
+                @teardown
+                def my_teardown(val):
+                    print(f"teardown: {val}")
+
+                def null_callable(val):
+                    pass
+
+                @task_group
+                def file_transforms(filename):
+                    s = PythonOperator(task_id="my_setup", python_callable=null_callable, op_args=filename)
+                    t = PythonOperator(task_id="my_teardown", python_callable=null_callable, op_args=filename)
+                    t = t.as_teardown(setups=s)
+                    with t:
+                        my_work(filename)
+
+                file_transforms.expand(filename=[["data1.json"], ["data2.json"], ["data3.json"]])
         dr = dag.test()
         states = self.get_states(dr)
         expected = {
@@ -1051,34 +1306,57 @@ class TestMappedSetupTeardown:
         }
         assert states == expected
 
-    def test_teardown_many_one_explicit(self, dag_maker, session):
+    @pytest.mark.parametrize("type_", ["taskflow", "classic"])
+    def test_teardown_many_one_explicit(self, type_, dag_maker):
         """-- passing
         one mapped setup going to one unmapped work
         3 diff states for setup: success / failed / skipped
         teardown still runs, and receives the xcom from the single successful setup
         """
-        with dag_maker(dag_id="teardown_many_one_explicit") as dag:
-
-            @task
-            def my_setup(val):
-                if val == "data2.json":
-                    raise ValueError("fail!")
-                elif val == "data3.json":
-                    raise AirflowSkipException("skip!")
-                print(f"setup: {val}")
-                return val
-
-            @task
-            def my_work(val):
-                print(f"work: {val}")
-
-            @task
-            def my_teardown(val):
-                print(f"teardown: {val}")
+        if type_ == "taskflow":
+            with dag_maker() as dag:
+
+                @task
+                def my_setup(val):
+                    if val == "data2.json":
+                        raise ValueError("fail!")
+                    elif val == "data3.json":
+                        raise AirflowSkipException("skip!")
+                    print(f"setup: {val}")
+                    return val
+
+                @task
+                def my_work(val):
+                    print(f"work: {val}")
+
+                @task
+                def my_teardown(val):
+                    print(f"teardown: {val}")
+
+                s = my_setup.expand(val=["data1.json", "data2.json", "data3.json"])
+                with my_teardown(s).as_teardown(setups=s):
+                    my_work(s)
+        else:
+            with dag_maker() as dag:
+
+                def my_setup_callable(val):
+                    if val == "data2.json":
+                        raise ValueError("fail!")
+                    elif val == "data3.json":
+                        raise AirflowSkipException("skip!")
+                    print(f"setup: {val}")
+                    return val
+
+                @task
+                def my_work(val):
+                    print(f"work: {val}")
+
+                s = PythonOperator.partial(task_id="my_setup", python_callable=my_setup_callable)
+                s = s.expand(op_args=[["data1.json"], ["data2.json"], ["data3.json"]])
+                t = self.classic_operator("my_teardown")
+                with t.as_teardown(setups=s):
+                    my_work(s.output)
 
-            s = my_setup.expand(val=["data1.json", "data2.json", "data3.json"])
-            with my_teardown(s).as_teardown(setups=s):
-                my_work(s)
         dr = dag.test()
         states = self.get_states(dr)
         expected = {
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index f59794509e..301e54a0ec 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1414,6 +1414,29 @@ class TestStringifiedDAGs:
         assert task.is_teardown is True
         assert task.on_failure_fail_dagrun is True
 
+    def test_teardown_mapped_serialization(self, dag_maker):
+        with dag_maker() as dag:
+
+            @teardown(on_failure_fail_dagrun=True)
+            def mytask(val=None):
+                print(1)
+
+            mytask.expand(val=[1, 2, 3])
+
+        task = dag.task_group.children["mytask"]
+        assert task.partial_kwargs["is_teardown"] is True
+        assert task.partial_kwargs["on_failure_fail_dagrun"] is True
+
+        dag_dict = SerializedDAG.to_dict(dag)
+        SerializedDAG.validate_schema(dag_dict)
+        json_dag = SerializedDAG.from_json(SerializedDAG.to_json(dag))
+        self.validate_deserialized_dag(json_dag, dag)
+
+        serialized_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag))
+        task = serialized_dag.task_group.children["mytask"]
+        assert task.partial_kwargs["is_teardown"] is True
+        assert task.partial_kwargs["on_failure_fail_dagrun"] is True
+
     def test_deps_sorted(self):
         """
         Tests serialize_operator, make sure the deps is in order


Reply via email to