ephraimbuddy commented on code in PR #32687:
URL: https://github.com/apache/airflow/pull/32687#discussion_r1269348327


##########
airflow/utils/setup_teardown.py:
##########
@@ -134,57 +168,147 @@ def push_setup_teardown_task(cls, operator: 
AbstractOperator | list[AbstractOper
     @classmethod
     def _push_tasks(cls, operator: AbstractOperator | list[AbstractOperator], 
setup: bool = False):
         if isinstance(operator, list):
-            upstream_tasks = operator[0].upstream_list
-            downstream_list = operator[0].downstream_list
             if not all(task.is_setup == operator[0].is_setup for task in 
operator):
                 cls.error("All tasks in the list must be either setup or 
teardown tasks")
-        else:
-            upstream_tasks = operator.upstream_list
-            downstream_list = operator.downstream_list
         if setup:
             cls.push_context_managed_setup_task(operator)
-            if downstream_list:
-                cls.push_context_managed_teardown_task(list(downstream_list))
+            # workout the teardown
+            cls._update_teardown_downstream(operator)
         else:
             cls.push_context_managed_teardown_task(operator)
-            if upstream_tasks:
-                cls.push_context_managed_setup_task(list(upstream_tasks))
+            # workout the setups
+            cls._update_setup_upstream(operator)
+
+    @classmethod
+    def _update_teardown_downstream(cls, operator: AbstractOperator | 
list[AbstractOperator]):
+        """This recursively go through the tasks downstream of the setup in 
the context manager,
+        if found, updates the _teardown_downstream_of_setup accordingly.
+
+        """
+        operator = operator[0] if isinstance(operator, list) else operator
+
+        def _get_teardowns(tasks):
+            teardowns = [i for i in tasks if i.is_teardown]
+            if not teardowns:
+                all_lists = [task.downstream_list + task.upstream_list for 
task in tasks]
+                new_list = [
+                    x
+                    for sublist in all_lists
+                    for x in sublist
+                    if (isinstance(operator, list) and x in operator) or x != 
operator
+                ]
+                if not new_list:
+                    return []
+                return _get_teardowns(new_list)
+            return teardowns
+
+        teardowns = _get_teardowns(operator.downstream_list)
+        teardown_task = cls._teardown_downstream_of_setup
+        if teardown_task and teardown_task != teardowns:
+            
cls._previous_teardown_downstream_of_setup.append(cls._teardown_downstream_of_setup)
+        cls._teardown_downstream_of_setup = teardowns
+
+    @classmethod
+    def _update_setup_upstream(cls, operator: AbstractOperator | 
list[AbstractOperator]):
+        """This recursively go through the tasks upstream of the teardown task 
in the context manager,
+        if found, updates the _setup_upstream_of_teardown accordingly.
+
+        """
+        operator = operator[0] if isinstance(operator, list) else operator
+
+        def _get_setups(tasks):
+            setups = [i for i in tasks if i.is_setup]
+            if not setups:
+                all_lists = [task.downstream_list + task.upstream_list for 
task in tasks]
+                new_list = [
+                    x
+                    for sublist in all_lists
+                    for x in sublist
+                    if (isinstance(operator, list) and x in operator) or x != 
operator
+                ]
+                if not new_list:
+                    return []
+                return _get_setups(new_list)
+            return setups
+
+        setups = _get_setups(operator.upstream_list)
+        setup_task = cls._setup_upstream_of_teardown
+        if setup_task and setup_task != setups:
+            
cls._previous_setup_upstream_of_teardown.append(cls._setup_upstream_of_teardown)
+        cls._setup_upstream_of_teardown = setups
+
+    @classmethod
+    def set_teardown_task_as_leaves(cls, leaves):
+        teardown_task = cls._teardown_downstream_of_setup
+        if cls._context_managed_teardown_task:
+            cls.set_dependency(cls._context_managed_teardown_task, 
teardown_task)
+        else:
+            cls.set_dependency(leaves, teardown_task)
+
+    @classmethod
+    def set_setup_task_as_roots(cls, roots):
+        setup_task = cls._setup_upstream_of_teardown
+        if cls._context_managed_setup_task:
+            cls.set_dependency(cls._context_managed_setup_task, setup_task, 
upstream=False)
+        else:
+            cls.set_dependency(roots, setup_task, upstream=False)
 
     @classmethod
     def set_work_task_roots_and_leaves(cls):
-        if setup_task := cls.get_context_managed_setup_task():
+        """Sets the work task roots and leaves."""
+        if setup_task := cls._context_managed_setup_task:
             if isinstance(setup_task, list):
                 setup_task = tuple(setup_task)
-            tasks_in_context = cls.context_map.get(setup_task, [])
+            tasks_in_context = [
+                x for x in cls.context_map.get(setup_task, []) if not 
x.is_teardown and not x.is_setup
+            ]
             if tasks_in_context:
                 roots = [task for task in tasks_in_context if not 
task.upstream_list]
                 if not roots:
-                    setup_task >> tasks_in_context[0]
-                elif isinstance(setup_task, tuple):
-                    for task in setup_task:
-                        task >> roots
+                    setup_task >> list(tasks_in_context)[0]
                 else:
-                    setup_task >> roots
-        if teardown_task := cls.get_context_managed_teardown_task():
+                    cls.set_dependency(roots, setup_task, upstream=False)
+                leaves = [task for task in tasks_in_context if not 
task.downstream_list]
+                cls.set_teardown_task_as_leaves(leaves)
+
+        if teardown_task := cls._context_managed_teardown_task:
             if isinstance(teardown_task, list):
                 teardown_task = tuple(teardown_task)
-            tasks_in_context = cls.context_map.get(teardown_task, [])
+            tasks_in_context = [
+                x for x in cls.context_map.get(teardown_task, []) if not 
x.is_teardown and not x.is_setup
+            ]
             if tasks_in_context:
                 leaves = [task for task in tasks_in_context if not 
task.downstream_list]
                 if not leaves:
-                    teardown_task << tasks_in_context[-1]
-                elif isinstance(teardown_task, tuple):
-                    for task in teardown_task:
-                        task << leaves
+                    teardown_task << list(tasks_in_context).pop()
                 else:
-                    teardown_task << leaves
+                    cls.set_dependency(leaves, teardown_task)
+                roots = [task for task in tasks_in_context if not 
task.upstream_list]
+                cls.set_setup_task_as_roots(roots)
+        cls.set_setup_teardown_relationships()
+        cls.active = False
+
+    @classmethod
+    def set_setup_teardown_relationships(cls):
+        """
+        Here we set relationship between setup to setup and
+        teardown to teardown.
+
+        code:: python
+            with setuptask >> teardowntask:
+                with setuptask2 >> teardowntask2:
+                    ...
+
+        We set setuptask >> setuptask2, teardowntask >> teardowntask2

Review Comment:
   I'm not sure that running in parallel aligns with what the context manager 
construct expresses. When the context managers are nested, users would expect 
execution to proceed from the inside out; only when they are not nested should 
the tasks run in parallel, IMO. 
   e.g.
   ```
   with s1 >> t1:
       with s2 >> t2:
   ```
   Making t1 and t2 run in parallel here would not match what the nested 
context managers actually say, i.e. on entering, run s1 -> s2, and on exiting, 
run t2 -> t1. Changing this to run in parallel would not make sense from a 
readability standpoint. 
   By contrast, constructs like the ones below make it clear that the tasks can run in parallel:
   ```
   with t1, t2:
   ```
   or 
   ```
   with s1 >> t1:
       w1
   with s2 >> t2:
       w2
   ```
   These show no connection between t1 and t2, whereas nested context managers 
do imply a connection, IMO.
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to