jedcunningham commented on code in PR #32687:
URL: https://github.com/apache/airflow/pull/32687#discussion_r1268733976


##########
airflow/utils/setup_teardown.py:
##########
@@ -134,57 +168,147 @@ def push_setup_teardown_task(cls, operator: 
AbstractOperator | list[AbstractOper
     @classmethod
     def _push_tasks(cls, operator: AbstractOperator | list[AbstractOperator], 
setup: bool = False):
         if isinstance(operator, list):
-            upstream_tasks = operator[0].upstream_list
-            downstream_list = operator[0].downstream_list
             if not all(task.is_setup == operator[0].is_setup for task in 
operator):
                 cls.error("All tasks in the list must be either setup or 
teardown tasks")
-        else:
-            upstream_tasks = operator.upstream_list
-            downstream_list = operator.downstream_list
         if setup:
             cls.push_context_managed_setup_task(operator)
-            if downstream_list:
-                cls.push_context_managed_teardown_task(list(downstream_list))
+            # workout the teardown
+            cls._update_teardown_downstream(operator)
         else:
             cls.push_context_managed_teardown_task(operator)
-            if upstream_tasks:
-                cls.push_context_managed_setup_task(list(upstream_tasks))
+            # workout the setups
+            cls._update_setup_upstream(operator)
+
+    @classmethod
+    def _update_teardown_downstream(cls, operator: AbstractOperator | 
list[AbstractOperator]):
+        """This recursively go through the tasks downstream of the setup in 
the context manager,
+        if found, updates the _teardown_downstream_of_setup accordingly.
+
+        """
+        operator = operator[0] if isinstance(operator, list) else operator
+
+        def _get_teardowns(tasks):
+            teardowns = [i for i in tasks if i.is_teardown]
+            if not teardowns:
+                all_lists = [task.downstream_list + task.upstream_list for 
task in tasks]
+                new_list = [
+                    x
+                    for sublist in all_lists
+                    for x in sublist
+                    if (isinstance(operator, list) and x in operator) or x != 
operator
+                ]
+                if not new_list:
+                    return []
+                return _get_teardowns(new_list)
+            return teardowns
+
+        teardowns = _get_teardowns(operator.downstream_list)
+        teardown_task = cls._teardown_downstream_of_setup
+        if teardown_task and teardown_task != teardowns:
+            cls._previous_teardown_downstream_of_setup.append(cls._teardown_downstream_of_setup)
+        cls._teardown_downstream_of_setup = teardowns
+
+    @classmethod
+    def _update_setup_upstream(cls, operator: AbstractOperator | 
list[AbstractOperator]):
+        """This recursively go through the tasks upstream of the teardown task 
in the context manager,
+        if found, updates the _setup_upstream_of_teardown accordingly.
+
+        """

Review Comment:
   ```suggestion
           """This recursively goes through the tasks upstream of the teardown 
task in the context manager,
           if found, updates the _setup_upstream_of_teardown accordingly.
           """
   ```



##########
airflow/utils/setup_teardown.py:
##########
@@ -134,57 +168,147 @@ def push_setup_teardown_task(cls, operator: 
AbstractOperator | list[AbstractOper
     @classmethod
     def _push_tasks(cls, operator: AbstractOperator | list[AbstractOperator], 
setup: bool = False):
         if isinstance(operator, list):
-            upstream_tasks = operator[0].upstream_list
-            downstream_list = operator[0].downstream_list
             if not all(task.is_setup == operator[0].is_setup for task in 
operator):
                 cls.error("All tasks in the list must be either setup or 
teardown tasks")
-        else:
-            upstream_tasks = operator.upstream_list
-            downstream_list = operator.downstream_list
         if setup:
             cls.push_context_managed_setup_task(operator)
-            if downstream_list:
-                cls.push_context_managed_teardown_task(list(downstream_list))
+            # workout the teardown
+            cls._update_teardown_downstream(operator)
         else:
             cls.push_context_managed_teardown_task(operator)
-            if upstream_tasks:
-                cls.push_context_managed_setup_task(list(upstream_tasks))
+            # workout the setups
+            cls._update_setup_upstream(operator)
+
+    @classmethod
+    def _update_teardown_downstream(cls, operator: AbstractOperator | 
list[AbstractOperator]):
+        """This recursively go through the tasks downstream of the setup in 
the context manager,
+        if found, updates the _teardown_downstream_of_setup accordingly.
+
+        """

Review Comment:
   ```suggestion
           """This recursively goes through the tasks downstream of the setup 
in the context manager,
           if found, updates the _teardown_downstream_of_setup accordingly.
           """
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to