This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new fad297c900 Speed up calculation of leaves and roots for task groups (#32592)
fad297c900 is described below

commit fad297c900551301c1dcb33e9128959e18fe737e
Author: Daniel Standish <[email protected]>
AuthorDate: Fri Jul 14 10:50:52 2023 -0700

    Speed up calculation of leaves and roots for task groups (#32592)
    
    Previously, every call to has_task would iterate over the group. Also, using a set operation is faster than `any`.
    
    Co-authored-by: Tzu-ping Chung <[email protected]>
---
 airflow/utils/task_group.py | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py
index e55a4abbe1..2b117ca7da 100644
--- a/airflow/utils/task_group.py
+++ b/airflow/utils/task_group.py
@@ -362,8 +362,10 @@ class TaskGroup(DAGNode):
         Returns a generator of tasks that are root tasks, i.e. those with no 
upstream
         dependencies within the TaskGroup.
         """
-        for task in self:
-            if not any(self.has_task(parent) for parent in 
task.get_direct_relatives(upstream=True)):
+        tasks = list(self)
+        ids = {x.task_id for x in tasks}
+        for task in tasks:
+            if not task.upstream_task_ids.intersection(ids):
                 yield task
 
     def get_leaves(self) -> Generator[BaseOperator, None, None]:
@@ -371,22 +373,24 @@ class TaskGroup(DAGNode):
         Returns a generator of tasks that are leaf tasks, i.e. those with no 
downstream
         dependencies within the TaskGroup.
         """
+        tasks = list(self)
+        ids = {x.task_id for x in tasks}
 
-        def recurse_for_first_non_setup_teardown(group, task):
+        def recurse_for_first_non_setup_teardown(task):
             for upstream_task in task.upstream_list:
-                if not group.has_task(upstream_task):
+                if upstream_task.task_id not in ids:
                     continue
                 if upstream_task.is_setup or upstream_task.is_teardown:
-                    yield from recurse_for_first_non_setup_teardown(group, 
upstream_task)
+                    yield from 
recurse_for_first_non_setup_teardown(upstream_task)
                 else:
                     yield upstream_task
 
-        for task in self:
-            if not any(self.has_task(x) for x in 
task.get_direct_relatives(upstream=False)):
+        for task in tasks:
+            if not task.downstream_task_ids.intersection(ids):
                 if not (task.is_teardown or task.is_setup):
                     yield task
                 else:
-                    yield from recurse_for_first_non_setup_teardown(self, task)
+                    yield from recurse_for_first_non_setup_teardown(task)
 
     def child_id(self, label):
         """

Reply via email to