This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9bad850a21 optimise task instances filtering (#27102)
9bad850a21 is described below

commit 9bad850a21c0df47da18de146cf1c5a38677cb35
Author: James Ong <[email protected]>
AuthorDate: Thu Nov 17 20:21:16 2022 -0500

    optimise task instances filtering (#27102)
---
 airflow/models/taskinstance.py | 64 +++++++++++++++++++++++++++++++++++-------
 1 file changed, 54 insertions(+), 10 deletions(-)

diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 54ce2a2d2e..9ec2854be7 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -2415,34 +2415,78 @@ class TaskInstance(Base, LoggingMixin):
         run_id = first.run_id
         map_index = first.map_index
         first_task_id = first.task_id
+
+        # pre-compute the set of dag_id, run_id, map_indices and task_ids
+        dag_ids, run_ids, map_indices, task_ids = set(), set(), set(), set()
+        for t in tis:
+            dag_ids.add(t.dag_id)
+            run_ids.add(t.run_id)
+            map_indices.add(t.map_index)
+            task_ids.add(t.task_id)
+
         # Common path optimisations: when all TIs are for the same dag_id and run_id, or same dag_id
         # and task_id -- this can be over 150x faster for huge numbers of TIs (20k+)
-        if all(t.dag_id == dag_id and t.run_id == run_id and t.map_index == map_index for t in tis):
+        if dag_ids == {dag_id} and run_ids == {run_id} and map_indices == {map_index}:
             return and_(
                 TaskInstance.dag_id == dag_id,
                 TaskInstance.run_id == run_id,
                 TaskInstance.map_index == map_index,
-                TaskInstance.task_id.in_(t.task_id for t in tis),
+                TaskInstance.task_id.in_(task_ids),
             )
-        if all(t.dag_id == dag_id and t.task_id == first_task_id and t.map_index == map_index for t in tis):
+        if dag_ids == {dag_id} and task_ids == {first_task_id} and map_indices == {map_index}:
             return and_(
                 TaskInstance.dag_id == dag_id,
-                TaskInstance.run_id.in_(t.run_id for t in tis),
+                TaskInstance.run_id.in_(run_ids),
                 TaskInstance.map_index == map_index,
                 TaskInstance.task_id == first_task_id,
             )
-        if all(t.dag_id == dag_id and t.run_id == run_id and t.task_id == first_task_id for t in tis):
+        if dag_ids == {dag_id} and run_ids == {run_id} and task_ids == {first_task_id}:
             return and_(
                 TaskInstance.dag_id == dag_id,
                 TaskInstance.run_id == run_id,
-                TaskInstance.map_index.in_(t.map_index for t in tis),
+                TaskInstance.map_index.in_(map_indices),
                 TaskInstance.task_id == first_task_id,
             )
 
-        return tuple_in_condition(
-            (TaskInstance.dag_id, TaskInstance.task_id, TaskInstance.run_id, TaskInstance.map_index),
-            (ti.key.primary for ti in tis),
-        )
+        filter_condition = []
+        # create 2 nested groups, both primarily grouped by dag_id and run_id,
+        # and in the nested group 1 grouped by task_id the other by map_index.
+        task_id_groups: dict[tuple, dict[Any, list[Any]]] = defaultdict(lambda: defaultdict(list))
+        map_index_groups: dict[tuple, dict[Any, list[Any]]] = defaultdict(lambda: defaultdict(list))
+        for t in tis:
+            task_id_groups[(t.dag_id, t.run_id)][t.task_id].append(t.map_index)
+            map_index_groups[(t.dag_id, t.run_id)][t.map_index].append(t.task_id)
+
+        # this assumes that most dags have dag_id as the largest grouping, followed by run_id. even
+        # if its not, this is still  a significant optimization over querying for every single tuple key
+        for cur_dag_id in dag_ids:
+            for cur_run_id in run_ids:
+                # we compare the group size between task_id and map_index and use the smaller group
+                dag_task_id_groups = task_id_groups[(cur_dag_id, cur_run_id)]
+                dag_map_index_groups = map_index_groups[(cur_dag_id, cur_run_id)]
+
+                if len(dag_task_id_groups) <= len(dag_map_index_groups):
+                    for cur_task_id, cur_map_indices in dag_task_id_groups.items():
+                        filter_condition.append(
+                            and_(
+                                TaskInstance.dag_id == cur_dag_id,
+                                TaskInstance.run_id == cur_run_id,
+                                TaskInstance.task_id == cur_task_id,
+                                TaskInstance.map_index.in_(cur_map_indices),
+                            )
+                        )
+                else:
+                    for cur_map_index, cur_task_ids in dag_map_index_groups.items():
+                        filter_condition.append(
+                            and_(
+                                TaskInstance.dag_id == cur_dag_id,
+                                TaskInstance.run_id == cur_run_id,
+                                TaskInstance.task_id.in_(cur_task_ids),
+                                TaskInstance.map_index == cur_map_index,
+                            )
+                        )
+
+        return or_(*filter_condition)
 
     @classmethod
     def ti_selector_condition(cls, vals: Collection[str | tuple[str, int]]) -> ColumnOperators:

Reply via email to