This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a45a7c149a Protect having work task upstream of setup/teardown task 
context manager (#31642)
a45a7c149a is described below

commit a45a7c149ae02cf9e0e64f4119223ed1743c81f7
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Fri Jun 2 06:27:02 2023 +0100

    Protect having work task upstream of setup/teardown task context manager 
(#31642)
    
    * Protect having work task upstream of setup/teardown task context manager
    
    This commit ensures users don't place a work task upstream of a
    setup/teardown task in a context manager. When this is detected, a
    ValueError is raised so the user is aware that it's an ambiguous setting.
    
    * fixup! Protect having work task upstream of setup/teardown task context 
manager
    
    * Update airflow/utils/setup_teardown.py
    
    Co-authored-by: Tzu-ping Chung <[email protected]>
    
    * fixup! Update airflow/utils/setup_teardown.py
    
    ---------
    
    Co-authored-by: Tzu-ping Chung <[email protected]>
---
 airflow/utils/setup_teardown.py         | 26 ++++++++++-
 tests/decorators/test_setup_teardown.py | 76 +++++++++++++++++++++++++++++++++
 2 files changed, 100 insertions(+), 2 deletions(-)

diff --git a/airflow/utils/setup_teardown.py b/airflow/utils/setup_teardown.py
index 9b28085880..9155e9704a 100644
--- a/airflow/utils/setup_teardown.py
+++ b/airflow/utils/setup_teardown.py
@@ -110,13 +110,24 @@ class SetupTeardownContext:
             if first_task.is_teardown:
                 if not all(task.is_teardown == first_task.is_teardown for task 
in operator):
                     raise ValueError("All tasks in the list must be either 
setup or teardown tasks")
+                upstream_tasks = first_task.upstream_list
+                for task in upstream_tasks:
+                    if not task.is_setup and not task.is_teardown:
+                        raise ValueError(
+                            "All upstream tasks in the context manager must be 
a setup or teardown task"
+                        )
                 
SetupTeardownContext.push_context_managed_teardown_task(operator)
-                upstream_setup: list[Operator] = [task for task in 
first_task.upstream_list if task.is_setup]
+                upstream_setup: list[Operator] = [task for task in 
upstream_tasks if task.is_setup]
                 if upstream_setup:
                     
SetupTeardownContext.push_context_managed_setup_task(upstream_setup)
             elif first_task.is_setup:
                 if not all(task.is_setup == first_task.is_setup for task in 
operator):
                     raise ValueError("All tasks in the list must be either 
setup or teardown tasks")
+                for task in first_task.upstream_list:
+                    if not task.is_setup and not task.is_teardown:
+                        raise ValueError(
+                            "All upstream tasks in the context manager must be 
a setup or teardown task"
+                        )
                 SetupTeardownContext.push_context_managed_setup_task(operator)
                 downstream_teardown: list[Operator] = [
                     task for task in first_task.downstream_list if 
task.is_teardown
@@ -124,11 +135,22 @@ class SetupTeardownContext:
                 if downstream_teardown:
                     
SetupTeardownContext.push_context_managed_teardown_task(downstream_teardown)
         elif operator.is_teardown:
+            upstream_tasks = operator.upstream_list
+            for task in upstream_tasks:
+                if not task.is_setup and not task.is_teardown:
+                    raise ValueError(
+                        "All upstream tasks in the context manager must be a 
setup or teardown task"
+                    )
             SetupTeardownContext.push_context_managed_teardown_task(operator)
-            upstream_setup = [task for task in operator.upstream_list if 
task.is_setup]
+            upstream_setup = [task for task in upstream_tasks if task.is_setup]
             if upstream_setup:
                 
SetupTeardownContext.push_context_managed_setup_task(upstream_setup)
         elif operator.is_setup:
+            for task in operator.upstream_list:
+                if not task.is_setup and not task.is_teardown:
+                    raise ValueError(
+                        "All upstream tasks in the context manager must be a 
setup or teardown task"
+                    )
             SetupTeardownContext.push_context_managed_setup_task(operator)
             downstream_teardown = [task for task in operator.downstream_list 
if task.is_teardown]
             if downstream_teardown:
diff --git a/tests/decorators/test_setup_teardown.py 
b/tests/decorators/test_setup_teardown.py
index 2a1e8136bd..4e4b8ed54f 100644
--- a/tests/decorators/test_setup_teardown.py
+++ b/tests/decorators/test_setup_teardown.py
@@ -1088,3 +1088,79 @@ class TestSetupTearDownTask:
             "setuptask",
             "mytask",
         }
+
+    def test_work_task_inbetween_setup_n_teardown_tasks(self, dag_maker):
+        @task
+        def mytask():
+            print("mytask")
+
+        @setup
+        def setuptask():
+            print("setuptask")
+
+        @teardown
+        def teardowntask():
+            print("teardowntask")
+
+        with pytest.raises(
+            ValueError, match="All upstream tasks in the context manager must 
be a setup or teardown task"
+        ):
+            with dag_maker():
+                with setuptask() >> mytask() >> teardowntask():
+                    ...
+
+    def test_errors_when_work_task_is_upstream_of_setup_task(self, dag_maker):
+        @task
+        def mytask():
+            print("mytask")
+
+        @setup
+        def setuptask():
+            print("setuptask")
+
+        with pytest.raises(
+            ValueError, match="All upstream tasks in the context manager must 
be a setup or teardown task"
+        ):
+            with dag_maker():
+                with mytask() >> setuptask():
+                    ...
+
+    def 
test_errors_when_work_task_is_upstream_of_context_wrapper_with_teardown(self, 
dag_maker):
+        @task
+        def mytask():
+            print("mytask")
+
+        @teardown
+        def teardowntask():
+            print("teardowntask")
+
+        @teardown
+        def teardowntask2():
+            print("teardowntask")
+
+        with pytest.raises(
+            ValueError, match="All upstream tasks in the context manager must 
be a setup or teardown task"
+        ):
+            with dag_maker():
+                with mytask() >> context_wrapper([teardowntask(), 
teardowntask2()]):
+                    ...
+
+    def 
test_errors_when_work_task_is_upstream_of_context_wrapper_with_setup(self, 
dag_maker):
+        @task
+        def mytask():
+            print("mytask")
+
+        @setup
+        def setuptask():
+            print("setuptask")
+
+        @setup
+        def setuptask2():
+            print("setuptask")
+
+        with pytest.raises(
+            ValueError, match="All upstream tasks in the context manager must 
be a setup or teardown task"
+        ):
+            with dag_maker():
+                with mytask() >> context_wrapper([setuptask(), setuptask2()]):
+                    ...

Reply via email to