dstandish commented on code in PR #32687:
URL: https://github.com/apache/airflow/pull/32687#discussion_r1268658727


##########
airflow/utils/setup_teardown.py:
##########
@@ -134,57 +168,147 @@ def push_setup_teardown_task(cls, operator: 
AbstractOperator | list[AbstractOper
     @classmethod
     def _push_tasks(cls, operator: AbstractOperator | list[AbstractOperator], 
setup: bool = False):
         if isinstance(operator, list):
-            upstream_tasks = operator[0].upstream_list
-            downstream_list = operator[0].downstream_list
             if not all(task.is_setup == operator[0].is_setup for task in 
operator):
                 cls.error("All tasks in the list must be either setup or 
teardown tasks")
-        else:
-            upstream_tasks = operator.upstream_list
-            downstream_list = operator.downstream_list
         if setup:
             cls.push_context_managed_setup_task(operator)
-            if downstream_list:
-                cls.push_context_managed_teardown_task(list(downstream_list))
+            # workout the teardown
+            cls._update_teardown_downstream(operator)
         else:
             cls.push_context_managed_teardown_task(operator)
-            if upstream_tasks:
-                cls.push_context_managed_setup_task(list(upstream_tasks))
+            # workout the setups
+            cls._update_setup_upstream(operator)
+
+    @classmethod
+    def _update_teardown_downstream(cls, operator: AbstractOperator | 
list[AbstractOperator]):
+        """This recursively go through the tasks downstream of the setup in 
the context manager,
+        if found, updates the _teardown_downstream_of_setup accordingly.
+
+        """
+        operator = operator[0] if isinstance(operator, list) else operator
+
+        def _get_teardowns(tasks):
+            teardowns = [i for i in tasks if i.is_teardown]
+            if not teardowns:
+                all_lists = [task.downstream_list + task.upstream_list for 
task in tasks]
+                new_list = [
+                    x
+                    for sublist in all_lists
+                    for x in sublist
+                    if (isinstance(operator, list) and x in operator) or x != 
operator
+                ]
+                if not new_list:
+                    return []
+                return _get_teardowns(new_list)
+            return teardowns
+
+        teardowns = _get_teardowns(operator.downstream_list)
+        teardown_task = cls._teardown_downstream_of_setup
+        if teardown_task and teardown_task != teardowns:
+            
cls._previous_teardown_downstream_of_setup.append(cls._teardown_downstream_of_setup)
+        cls._teardown_downstream_of_setup = teardowns
+
+    @classmethod
+    def _update_setup_upstream(cls, operator: AbstractOperator | 
list[AbstractOperator]):
+        """This recursively go through the tasks upstream of the teardown task 
in the context manager,
+        if found, updates the _setup_upstream_of_teardown accordingly.
+
+        """
+        operator = operator[0] if isinstance(operator, list) else operator
+
+        def _get_setups(tasks):
+            setups = [i for i in tasks if i.is_setup]
+            if not setups:
+                all_lists = [task.downstream_list + task.upstream_list for 
task in tasks]
+                new_list = [
+                    x
+                    for sublist in all_lists
+                    for x in sublist
+                    if (isinstance(operator, list) and x in operator) or x != 
operator
+                ]
+                if not new_list:
+                    return []
+                return _get_setups(new_list)
+            return setups
+
+        setups = _get_setups(operator.upstream_list)
+        setup_task = cls._setup_upstream_of_teardown
+        if setup_task and setup_task != setups:
+            
cls._previous_setup_upstream_of_teardown.append(cls._setup_upstream_of_teardown)
+        cls._setup_upstream_of_teardown = setups
+
+    @classmethod
+    def set_teardown_task_as_leaves(cls, leaves):
+        teardown_task = cls._teardown_downstream_of_setup
+        if cls._context_managed_teardown_task:
+            cls.set_dependency(cls._context_managed_teardown_task, 
teardown_task)
+        else:
+            cls.set_dependency(leaves, teardown_task)
+
+    @classmethod
+    def set_setup_task_as_roots(cls, roots):
+        setup_task = cls._setup_upstream_of_teardown
+        if cls._context_managed_setup_task:
+            cls.set_dependency(cls._context_managed_setup_task, setup_task, 
upstream=False)
+        else:
+            cls.set_dependency(roots, setup_task, upstream=False)
 
     @classmethod
     def set_work_task_roots_and_leaves(cls):
-        if setup_task := cls.get_context_managed_setup_task():
+        """Sets the work task roots and leaves."""
+        if setup_task := cls._context_managed_setup_task:
             if isinstance(setup_task, list):
                 setup_task = tuple(setup_task)
-            tasks_in_context = cls.context_map.get(setup_task, [])
+            tasks_in_context = [
+                x for x in cls.context_map.get(setup_task, []) if not 
x.is_teardown and not x.is_setup
+            ]
             if tasks_in_context:
                 roots = [task for task in tasks_in_context if not 
task.upstream_list]
                 if not roots:
-                    setup_task >> tasks_in_context[0]
-                elif isinstance(setup_task, tuple):
-                    for task in setup_task:
-                        task >> roots
+                    setup_task >> list(tasks_in_context)[0]
                 else:
-                    setup_task >> roots
-        if teardown_task := cls.get_context_managed_teardown_task():
+                    cls.set_dependency(roots, setup_task, upstream=False)
+                leaves = [task for task in tasks_in_context if not 
task.downstream_list]
+                cls.set_teardown_task_as_leaves(leaves)
+
+        if teardown_task := cls._context_managed_teardown_task:
             if isinstance(teardown_task, list):
                 teardown_task = tuple(teardown_task)
-            tasks_in_context = cls.context_map.get(teardown_task, [])
+            tasks_in_context = [
+                x for x in cls.context_map.get(teardown_task, []) if not 
x.is_teardown and not x.is_setup
+            ]
             if tasks_in_context:
                 leaves = [task for task in tasks_in_context if not 
task.downstream_list]
                 if not leaves:
-                    teardown_task << tasks_in_context[-1]
-                elif isinstance(teardown_task, tuple):
-                    for task in teardown_task:
-                        task << leaves
+                    teardown_task << list(tasks_in_context).pop()
                 else:
-                    teardown_task << leaves
+                    cls.set_dependency(leaves, teardown_task)
+                roots = [task for task in tasks_in_context if not 
task.upstream_list]
+                cls.set_setup_task_as_roots(roots)
+        cls.set_setup_teardown_relationships()
+        cls.active = False
+
+    @classmethod
+    def set_setup_teardown_relationships(cls):
+        """
+        Here we set relationship between setup to setup and
+        teardown to teardown.
+
+        code:: python
+            with setuptask >> teardowntask:
+                with setuptask2 >> teardowntask2:
+                    ...
+
+        We set setuptask >> setuptask2, teardowntask >> teardowntask2

Review Comment:
   It may be out of scope for this PR, but I think the behavior should be 
that we don't set a dependency from teardowntask to teardowntask2 (i.e. no 
teardowntask >> teardowntask2); by default, I think they should run in 
parallel. What do you think? Making this change would make it consistent 
with how we handle teardowns in task groups, where they are ignored when 
calculating leaves.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to