This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c76555930a Refactor setup/teardown decorator so they return _TaskDecorator (#30342)
c76555930a is described below

commit c76555930aee9692d2a839b9c7b9e2220717b8a0
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Tue Mar 28 19:15:07 2023 +0100

    Refactor setup/teardown decorator so they return _TaskDecorator (#30342)
    
    * Refactor setup/teardown decorator so they return _TaskDecorator
    
    Prior to this, the setup/teardown decorators returned ordinary functions.
    Returning a _TaskDecorator, as @task does, will be useful for future work.
    
    * fixup! Refactor setup/teardown decorator so they return _TaskDecorator
    
    * Fix tests
    
    * Update tests/decorators/test_setup_teardown.py
    
    Co-authored-by: Daniel Standish <[email protected]>
    
    * Fix on_failure_fail_dagrun default
    
    ---------
    
    Co-authored-by: Daniel Standish <[email protected]>
    Co-authored-by: Jed Cunningham <[email protected]>
---
 airflow/decorators/base.py                    |  12 ++-
 airflow/decorators/setup_teardown.py          |  41 ++++------
 airflow/models/baseoperator.py                |  23 ++----
 airflow/utils/setup_teardown.py               |  55 -------------
 airflow/utils/task_group.py                   |   5 --
 tests/decorators/test_setup_teardown.py       |  92 +++++++++++++++++----
 tests/serialization/test_dag_serialization.py |   1 +
 tests/utils/test_setup_teardown.py            | 113 --------------------------
 8 files changed, 112 insertions(+), 230 deletions(-)

diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index bc83c7786b..3ece688dd9 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -301,6 +301,9 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams, 
FReturn, OperatorSubcla
     decorator_name: str = attr.ib(repr=False, default="task")
 
     _airflow_is_task_decorator: ClassVar[bool] = True
+    _is_setup: ClassVar[bool] = False
+    _is_teardown: ClassVar[bool] = False
+    _on_failure_fail_dagrun: ClassVar[bool] = False
 
     @multiple_outputs.default
     def _infer_multiple_outputs(self):
@@ -341,6 +344,9 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams, 
FReturn, OperatorSubcla
             multiple_outputs=self.multiple_outputs,
             **self.kwargs,
         )
+        op._is_setup = self._is_setup
+        op._is_teardown = self._is_teardown
+        op._on_failure_fail_dagrun = self._on_failure_fail_dagrun
         op_doc_attrs = [op.doc, op.doc_json, op.doc_md, op.doc_rst, 
op.doc_yaml]
         # Set the task's doc_md to the function's docstring if it exists and 
no other doc* args are set.
         if self.function.__doc__ and not any(op_doc_attrs):
@@ -473,7 +479,11 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams, 
FReturn, OperatorSubcla
         return attr.evolve(self, kwargs={**self.kwargs, "op_kwargs": kwargs})
 
     def override(self, **kwargs: Any) -> _TaskDecorator[FParams, FReturn, 
OperatorSubclass]:
-        return attr.evolve(self, kwargs={**self.kwargs, **kwargs})
+        result = attr.evolve(self, kwargs={**self.kwargs, **kwargs})
+        setattr(result, "_is_setup", self._is_setup)
+        setattr(result, "_is_teardown", self._is_teardown)
+        setattr(result, "_on_failure_fail_dagrun", 
self._on_failure_fail_dagrun)
+        return result
 
 
 @attr.define(kw_only=True, repr=False)
diff --git a/airflow/decorators/setup_teardown.py 
b/airflow/decorators/setup_teardown.py
index 8480bde636..8fccdb9fe5 100644
--- a/airflow/decorators/setup_teardown.py
+++ b/airflow/decorators/setup_teardown.py
@@ -16,39 +16,34 @@
 # under the License.
 from __future__ import annotations
 
-import functools
 import types
 from typing import Callable
 
+from airflow import AirflowException
 from airflow.decorators import python_task
-from airflow.utils.setup_teardown import SetupTeardownContext
+from airflow.decorators.task_group import _TaskGroupFactory
 
 
-def setup_task(python_callable: Callable) -> Callable:
+def setup_task(func: Callable) -> Callable:
     # Using FunctionType here since _TaskDecorator is also a callable
-    if isinstance(python_callable, types.FunctionType):
-        python_callable = python_task(python_callable)
+    if isinstance(func, types.FunctionType):
+        func = python_task(func)
+    if isinstance(func, _TaskGroupFactory):
+        raise AirflowException("Task groups cannot be marked as setup or 
teardown.")
+    func._is_setup = True  # type: ignore[attr-defined]
+    return func
 
-    @functools.wraps(python_callable)
-    def wrapper(*args, **kwargs):
-        with SetupTeardownContext.setup():
-            return python_callable(*args, **kwargs)
 
-    return wrapper
-
-
-def teardown_task(_func=None, *, on_failure_fail_dagrun: bool | None = None) 
-> Callable:
-    def teardown(python_callable: Callable) -> Callable:
+def teardown_task(_func=None, *, on_failure_fail_dagrun: bool = False) -> 
Callable:
+    def teardown(func: Callable) -> Callable:
         # Using FunctionType here since _TaskDecorator is also a callable
-        if isinstance(python_callable, types.FunctionType):
-            python_callable = python_task(python_callable)
-
-        @functools.wraps(python_callable)
-        def wrapper(*args, **kwargs) -> Callable:
-            with 
SetupTeardownContext.teardown(on_failure_fail_dagrun=on_failure_fail_dagrun):
-                return python_callable(*args, **kwargs)
-
-        return wrapper
+        if isinstance(func, types.FunctionType):
+            func = python_task(func)
+        if isinstance(func, _TaskGroupFactory):
+            raise AirflowException("Task groups cannot be marked as setup or 
teardown.")
+        func._is_teardown = True  # type: ignore[attr-defined]
+        func._on_failure_fail_dagrun = on_failure_fail_dagrun  # type: 
ignore[attr-defined]
+        return func
 
     if _func is None:
         return teardown
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index d78152d990..557dc634d8 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -88,7 +88,6 @@ from airflow.utils.decorators import 
fixup_decorator_warning_stack
 from airflow.utils.helpers import validate_key
 from airflow.utils.operator_resources import Resources
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.setup_teardown import SetupTeardownContext
 from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.weight_rule import WeightRule
 from airflow.utils.xcom import XCOM_RETURN_KEY
@@ -920,27 +919,19 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
             )
             self.template_fields = [self.template_fields]
 
-        if SetupTeardownContext.is_setup:
-            self._is_setup = True
-        elif SetupTeardownContext.is_teardown:
-            self._is_teardown = True
-            if SetupTeardownContext.on_failure_fail_dagrun:
-                self._on_failure_fail_dagrun = True
-
     @classmethod
     def as_setup(cls, *args, **kwargs):
-        from airflow.utils.setup_teardown import SetupTeardownContext
-
-        with SetupTeardownContext.setup():
-            return cls(*args, **kwargs)
+        op = cls(*args, **kwargs)
+        op._is_setup = True
+        return op
 
     @classmethod
     def as_teardown(cls, *args, **kwargs):
-        from airflow.utils.setup_teardown import SetupTeardownContext
-
         on_failure_fail_dagrun = kwargs.pop("on_failure_fail_dagrun", False)
-        with 
SetupTeardownContext.teardown(on_failure_fail_dagrun=on_failure_fail_dagrun):
-            return cls(*args, **kwargs)
+        op = cls(*args, **kwargs)
+        op._is_teardown = True
+        op._on_failure_fail_dagrun = on_failure_fail_dagrun
+        return op
 
     def __eq__(self, other):
         if type(self) is type(other):
diff --git a/airflow/utils/setup_teardown.py b/airflow/utils/setup_teardown.py
deleted file mode 100644
index e421678e46..0000000000
--- a/airflow/utils/setup_teardown.py
+++ /dev/null
@@ -1,55 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-from contextlib import contextmanager
-
-from airflow.exceptions import AirflowException
-
-
-class SetupTeardownContext:
-    """Track whether the next added task is a setup or teardown task"""
-
-    is_setup: bool = False
-    is_teardown: bool = False
-    on_failure_fail_dagrun: bool = False
-
-    @classmethod
-    @contextmanager
-    def setup(cls):
-        if cls.is_setup or cls.is_teardown:
-            raise AirflowException("You cannot mark a setup or teardown task 
as setup or teardown again.")
-
-        cls.is_setup = True
-        try:
-            yield
-        finally:
-            cls.is_setup = False
-
-    @classmethod
-    @contextmanager
-    def teardown(cls, *, on_failure_fail_dagrun=False):
-        if cls.is_setup or cls.is_teardown:
-            raise AirflowException("You cannot mark a setup or teardown task 
as setup or teardown again.")
-
-        cls.is_teardown = True
-        cls.on_failure_fail_dagrun = on_failure_fail_dagrun
-        try:
-            yield
-        finally:
-            cls.is_teardown = False
-            cls.on_failure_fail_dagrun = False
diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py
index b0611305d4..9b94a55952 100644
--- a/airflow/utils/task_group.py
+++ b/airflow/utils/task_group.py
@@ -38,7 +38,6 @@ from airflow.exceptions import (
 from airflow.models.taskmixin import DAGNode, DependencyMixin
 from airflow.serialization.enums import DagAttributeTypes
 from airflow.utils.helpers import validate_group_key
-from airflow.utils.setup_teardown import SetupTeardownContext
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
@@ -158,10 +157,6 @@ class TaskGroup(DAGNode):
         self.upstream_task_ids = set()
         self.downstream_task_ids = set()
 
-        if SetupTeardownContext.is_setup or SetupTeardownContext.is_teardown:
-            # TODO: This might not be the ideal place to check this.
-            raise AirflowException("Task groups cannot be marked as setup or 
teardown.")
-
     def _check_for_group_id_collisions(self, add_suffix_on_collision: bool):
         if self._group_id is None:
             return
diff --git a/tests/decorators/test_setup_teardown.py 
b/tests/decorators/test_setup_teardown.py
index 38793c863b..151ba1d0d8 100644
--- a/tests/decorators/test_setup_teardown.py
+++ b/tests/decorators/test_setup_teardown.py
@@ -92,37 +92,40 @@ class TestSetupTearDownTask:
         assert teardown_task._is_teardown
 
     def test_setup_taskgroup_decorator(self, dag_maker):
-        @setup
-        @task_group
-        def mygroup():
-            @task
-            def mytask():
-                print("I am a setup task")
-
-            mytask()
-
         with dag_maker():
             with pytest.raises(
                 expected_exception=AirflowException,
                 match="Task groups cannot be marked as setup or teardown.",
             ):
+
+                @setup
+                @task_group
+                def mygroup():
+                    @task
+                    def mytask():
+                        print("I am a setup task")
+
+                    mytask()
+
                 mygroup()
 
     def test_teardown_taskgroup_decorator(self, dag_maker):
-        @teardown
-        @task_group
-        def mygroup():
-            @task
-            def mytask():
-                print("I am a teardown task")
-
-            mytask()
 
         with dag_maker():
             with pytest.raises(
                 expected_exception=AirflowException,
                 match="Task groups cannot be marked as setup or teardown.",
             ):
+
+                @teardown
+                @task_group
+                def mygroup():
+                    @task
+                    def mytask():
+                        print("I am a teardown task")
+
+                    mytask()
+
                 mygroup()
 
     @pytest.mark.parametrize("on_failure_fail_dagrun", [True, False])
@@ -153,3 +156,58 @@ class TestSetupTearDownTask:
         assert teardown_task._is_teardown
         assert teardown_task._on_failure_fail_dagrun is on_failure_fail_dagrun
         assert len(dag.task_group.children) == 1
+
+    def test_setup_task_can_be_overriden(self, dag_maker):
+        @setup
+        def mytask():
+            print("I am a setup task")
+
+        with dag_maker() as dag:
+            mytask.override(task_id="mytask2")()
+        assert len(dag.task_group.children) == 1
+        setup_task = dag.task_group.children["mytask2"]
+        assert setup_task._is_setup
+
+    def test_setup_teardown_mixed_up_in_a_dag(self, dag_maker):
+        @setup
+        def setuptask():
+            print("setup")
+
+        @setup
+        def setuptask2():
+            print("setup")
+
+        @teardown
+        def teardowntask():
+            print("teardown")
+
+        @teardown
+        def teardowntask2():
+            print("teardown")
+
+        @task()
+        def mytask():
+            print("mytask")
+
+        @task()
+        def mytask2():
+            print("mytask")
+
+        with dag_maker() as dag:
+            setuptask()
+            teardowntask()
+            setuptask2()
+            teardowntask2()
+            mytask()
+            mytask2()
+
+        assert len(dag.task_group.children) == 6
+        assert [x for x in dag.tasks if not x.downstream_list]  # no deps have 
been set
+        assert dag.task_group.children["setuptask"]._is_setup
+        assert dag.task_group.children["teardowntask"]._is_teardown
+        assert dag.task_group.children["setuptask2"]._is_setup
+        assert dag.task_group.children["teardowntask2"]._is_teardown
+        assert dag.task_group.children["mytask"]._is_setup is False
+        assert dag.task_group.children["mytask"]._is_teardown is False
+        assert dag.task_group.children["mytask2"]._is_setup is False
+        assert dag.task_group.children["mytask2"]._is_teardown is False
diff --git a/tests/serialization/test_dag_serialization.py 
b/tests/serialization/test_dag_serialization.py
index 96cf19aa25..e514caf623 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -570,6 +570,7 @@ class TestStringifiedDAGs:
                 "on_retry_callback",
                 # Checked separately
                 "resources",
+                "_on_failure_fail_dagrun",
             }
         else:  # Promised to be mapped by the assert above.
             assert isinstance(serialized_task, MappedOperator)
diff --git a/tests/utils/test_setup_teardown.py 
b/tests/utils/test_setup_teardown.py
deleted file mode 100644
index 5907994bbc..0000000000
--- a/tests/utils/test_setup_teardown.py
+++ /dev/null
@@ -1,113 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-import pytest
-
-from airflow.exceptions import AirflowException
-from airflow.utils.setup_teardown import SetupTeardownContext
-
-
-class TestSetupTearDownContext:
-    def test_setup(self):
-        assert SetupTeardownContext.is_setup is False
-        assert SetupTeardownContext.is_teardown is False
-
-        with SetupTeardownContext.setup():
-            assert SetupTeardownContext.is_setup is True
-            assert SetupTeardownContext.is_teardown is False
-
-        assert SetupTeardownContext.is_setup is False
-        assert SetupTeardownContext.is_teardown is False
-
-    def test_teardown(self):
-        assert SetupTeardownContext.is_setup is False
-        assert SetupTeardownContext.is_teardown is False
-
-        with SetupTeardownContext.setup():
-            assert SetupTeardownContext.is_setup is True
-            assert SetupTeardownContext.is_teardown is False
-
-        assert SetupTeardownContext.is_setup is False
-        assert SetupTeardownContext.is_teardown is False
-
-    def test_setup_exception(self):
-        """Ensure context is reset even if an exception happens"""
-        with pytest.raises(Exception, match="Hello"):
-            with SetupTeardownContext.setup():
-                raise Exception("Hello")
-
-        assert SetupTeardownContext.is_setup is False
-        assert SetupTeardownContext.is_teardown is False
-
-    def test_teardown_exception(self):
-        """Ensure context is reset even if an exception happens"""
-        with pytest.raises(Exception, match="Hello"):
-            with SetupTeardownContext.teardown():
-                raise Exception("Hello")
-
-        assert SetupTeardownContext.is_setup is False
-        assert SetupTeardownContext.is_teardown is False
-
-    def test_setup_block_nested(self):
-        with SetupTeardownContext.setup():
-            with pytest.raises(
-                AirflowException,
-                match=("You cannot mark a setup or teardown task as setup or 
teardown again."),
-            ):
-                with SetupTeardownContext.setup():
-                    raise Exception("This should not be reached")
-
-        assert SetupTeardownContext.is_setup is False
-        assert SetupTeardownContext.is_teardown is False
-
-    def test_teardown_block_nested(self):
-        with SetupTeardownContext.teardown():
-            with pytest.raises(
-                AirflowException,
-                match=("You cannot mark a setup or teardown task as setup or 
teardown again."),
-            ):
-                with SetupTeardownContext.teardown():
-                    raise Exception("This should not be reached")
-
-        assert SetupTeardownContext.is_setup is False
-        assert SetupTeardownContext.is_teardown is False
-
-    def test_teardown_nested_in_setup_blocked(self):
-        with SetupTeardownContext.setup():
-            with pytest.raises(
-                AirflowException,
-                match=("You cannot mark a setup or teardown task as setup or 
teardown again."),
-            ):
-                with SetupTeardownContext.teardown():
-                    raise Exception("This should not be reached")
-
-        assert SetupTeardownContext.is_setup is False
-        assert SetupTeardownContext.is_teardown is False
-
-    def test_setup_nested_in_teardown_blocked(self):
-        with SetupTeardownContext.teardown():
-            with pytest.raises(
-                AirflowException,
-                match=("You cannot mark a setup or teardown task as setup or 
teardown again."),
-            ):
-                with SetupTeardownContext.setup():
-                    raise Exception("This should not be reached")
-
-        assert SetupTeardownContext.is_setup is False
-        assert SetupTeardownContext.is_teardown is False

Reply via email to