ephraimbuddy commented on code in PR #32235:
URL: https://github.com/apache/airflow/pull/32235#discussion_r1253374630


##########
airflow/utils/setup_teardown.py:
##########
@@ -97,70 +107,64 @@ def pop_context_managed_teardown_task(cls) -> Operator | 
list[Operator] | None:
                 else:
                     teardown_task.set_upstream(old_teardown_task)
         else:
-            cls._context_managed_teardown_task = None
+            cls._context_managed_teardown_task = []
         return old_teardown_task
 
     @classmethod
-    def get_context_managed_setup_task(cls) -> Operator | list[Operator] | 
None:
+    def get_context_managed_setup_task(cls) -> AbstractOperator | 
list[AbstractOperator]:
         return cls._context_managed_setup_task
 
     @classmethod
-    def get_context_managed_teardown_task(cls) -> Operator | list[Operator] | 
None:
+    def get_context_managed_teardown_task(cls) -> AbstractOperator | 
list[AbstractOperator]:
         return cls._context_managed_teardown_task
 
     @classmethod
-    def push_setup_teardown_task(cls, operator: Operator | list[Operator]):
+    def push_setup_teardown_task(cls, operator: AbstractOperator | 
list[AbstractOperator]):
         if isinstance(operator, list):
-            first_task: Operator = operator[0]
-            if first_task.is_teardown:
-                if not all(task.is_teardown == first_task.is_teardown for task 
in operator):
-                    raise ValueError("All tasks in the list must be either 
setup or teardown tasks")
-                upstream_tasks = first_task.upstream_list
-                for task in upstream_tasks:
-                    if not task.is_setup and not task.is_teardown:
-                        raise ValueError(
-                            "All upstream tasks in the context manager must be 
a setup or teardown task"
-                        )
-                
BaseSetupTeardownContext.push_context_managed_teardown_task(operator)
-                upstream_setup: list[Operator] = [task for task in 
upstream_tasks if task.is_setup]
-                if upstream_setup:
-                    
BaseSetupTeardownContext.push_context_managed_setup_task(upstream_setup)
-            elif first_task.is_setup:
-                if not all(task.is_setup == first_task.is_setup for task in 
operator):
-                    raise ValueError("All tasks in the list must be either 
setup or teardown tasks")
-                for task in first_task.upstream_list:
-                    if not task.is_setup and not task.is_teardown:
-                        raise ValueError(
-                            "All upstream tasks in the context manager must be 
a setup or teardown task"
-                        )
-                
BaseSetupTeardownContext.push_context_managed_setup_task(operator)
-                downstream_teardown: list[Operator] = [
-                    task for task in first_task.downstream_list if 
task.is_teardown
-                ]
-                if downstream_teardown:
-                    
BaseSetupTeardownContext.push_context_managed_teardown_task(downstream_teardown)
+            if operator[0].is_teardown:
+                cls._push_tasks(operator)
+            elif operator[0].is_setup:
+                cls._push_tasks(operator, setup=True)
         elif operator.is_teardown:
-            upstream_tasks = operator.upstream_list
-            for task in upstream_tasks:
-                if not task.is_setup and not task.is_teardown:
-                    raise ValueError(
-                        "All upstream tasks in the context manager must be a 
setup or teardown task"
-                    )
-            
BaseSetupTeardownContext.push_context_managed_teardown_task(operator)
-            upstream_setup = [task for task in upstream_tasks if task.is_setup]
-            if upstream_setup:
-                
BaseSetupTeardownContext.push_context_managed_setup_task(upstream_setup)
+            cls._push_tasks(operator)
         elif operator.is_setup:
-            for task in operator.upstream_list:
-                if not task.is_setup and not task.is_teardown:
-                    raise ValueError(
-                        "All upstream tasks in the context manager must be a 
setup or teardown task"
+            cls._push_tasks(operator, setup=True)
+        cls.active = True
+
+    @classmethod
+    def _push_tasks(cls, operator: AbstractOperator | list[AbstractOperator], 
setup: bool = False):
+        if isinstance(operator, list):
+            upstream_tasks = operator[0].upstream_list
+            downstream_list = operator[0].downstream_list
+            if not all(task.is_setup == operator[0].is_setup for task in 
operator):
+                cls.error("All tasks in the list must be either setup or 
teardown tasks")

Review Comment:
   I think it's OK. What I mean is: if `t` is a teardown task and `s` is a setup task, then a mixed list such as `[t1, s1, t2]` is not allowed — only `[s1, s2, s3]` or `[t1, t2, t3]`.
   For `[t1, s1, t2]`, the `is_setup` values are `[False, True, False]`, so the check fails (setup and teardown tasks are mixed).
   For `[s1, s2, s3]`, the `is_setup` values are `[True, True, True]`, so the check passes because all values are the same.
   For `[t1, t2, t3]`, the values are `[False, False, False]`, which also passes.
   So the error message is correct.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to