This is an automated email from the ASF dual-hosted git repository.

shahar1 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d2f42e62cfc Optimize TaskGroup.topological_sort for reverse-declared Dags (#67688)
d2f42e62cfc is described below

commit d2f42e62cfcc9c16de8cd1dbe108a1e201bfb51f
Author: Shahar Epstein <[email protected]>
AuthorDate: Fri Jun 12 07:20:25 2026 +0300

    Optimize TaskGroup.topological_sort for reverse-declared Dags (#67688)
---
 airflow-core/newsfragments/67688.improvement.rst   |  1 +
 .../airflow/serialization/definitions/taskgroup.py | 67 ++++++++++++++--
 airflow-core/tests/unit/utils/test_task_group.py   | 42 ++++++++++
 task-sdk/src/airflow/sdk/definitions/taskgroup.py  | 92 ++++++++++++++++++----
 .../tests/task_sdk/definitions/test_taskgroup.py   | 45 +++++++++++
 5 files changed, 224 insertions(+), 23 deletions(-)

diff --git a/airflow-core/newsfragments/67688.improvement.rst 
b/airflow-core/newsfragments/67688.improvement.rst
new file mode 100644
index 00000000000..d2a641ac00c
--- /dev/null
+++ b/airflow-core/newsfragments/67688.improvement.rst
@@ -0,0 +1 @@
+Further optimize ``TaskGroup.topological_sort`` for reverse-declared DAGs by switching to a pass-number traversal, which avoids the previous O(N²) worst case on adversarial shapes such as reverse-insertion chains.
diff --git a/airflow-core/src/airflow/serialization/definitions/taskgroup.py 
b/airflow-core/src/airflow/serialization/definitions/taskgroup.py
index a5d8b730b05..65d59cb15f1 100644
--- a/airflow-core/src/airflow/serialization/definitions/taskgroup.py
+++ b/airflow-core/src/airflow/serialization/definitions/taskgroup.py
@@ -227,9 +227,24 @@ class SerializedTaskGroup(DAGNode):
         children = self.children
         if not children:
             return []
+
         nodes = list(children.values())
+        n = len(nodes)
         id_to_idx = {nid: i for i, nid in enumerate(children)}
-        projected = [self._project_child_deps(i, c, id_to_idx) for i, c in 
enumerate(nodes)]
+
+        projected: list[tuple[int, ...]] = [()] * n
+        nodes_with_back_edge = 0
+        for i, child in enumerate(nodes):
+            deps = self._project_child_deps(i, child, id_to_idx)
+            if deps:
+                projected[i] = deps
+                if any(d > i for d in deps):
+                    nodes_with_back_edge += 1
+
+        # The ratio catches dense back-heavy groups; the 32-node absolute cutoff catches padded
+        # reverse-declared runs, where the sweep's per-pass rescans start costing more than pass-numbering.
+        if nodes_with_back_edge >= 32 or nodes_with_back_edge * 2 > n:
+            return self._sort_via_pass_numbering(nodes, projected)
         return self._sweep_projection(nodes, projected)
 
     def _project_child_deps(
@@ -242,16 +257,18 @@ class SerializedTaskGroup(DAGNode):
         for edge_id in upstream_ids:
             j = id_to_idx.get(edge_id)
             if j is not None:
-                sib_deps.add(j)
+                if j != child_idx:
+                    sib_deps.add(j)
                 continue
-            tg = self.dag.get_task(edge_id).task_group
+            edge = self.dag.get_task(edge_id)
+            tg = edge.task_group
             while tg is not None:
-                j = id_to_idx.get(tg.node_id)
-                if j is not None:
-                    sib_deps.add(j)
+                anc_idx = id_to_idx.get(tg.node_id)
+                if anc_idx is not None:
+                    if anc_idx != child_idx:
+                        sib_deps.add(anc_idx)
                     break
                 tg = tg.parent_group
-        sib_deps.discard(child_idx)
         return tuple(sib_deps)
 
     def _sweep_projection(self, nodes: list[DAGNode], projected: 
list[tuple[int, ...]]) -> list[DAGNode]:
@@ -291,6 +308,42 @@ class SerializedTaskGroup(DAGNode):
             pending = next_pending
         return order
 
+    def _sort_via_pass_numbering(
+        self, nodes: list[DAGNode], projected: list[tuple[int, ...]]
+    ) -> list[DAGNode]:
+        n = len(nodes)
+        in_degree = [len(deps) for deps in projected]
+        successors: list[list[int]] = [[] for _ in range(n)]
+        for i, deps in enumerate(projected):
+            for d in deps:
+                successors[d].append(i)
+
+        pass_of = [0] * n
+        queue: deque[int] = deque(i for i in range(n) if in_degree[i] == 0)
+        processed = 0
+        while queue:
+            i = queue.popleft()
+            my_pass = 1
+            for d in projected[i]:
+                d_pass = pass_of[d]
+                if d < i:
+                    if d_pass > my_pass:
+                        my_pass = d_pass
+                elif d_pass + 1 > my_pass:
+                    my_pass = d_pass + 1
+            pass_of[i] = my_pass
+            processed += 1
+            for s in successors[i]:
+                in_degree[s] -= 1
+                if in_degree[s] == 0:
+                    queue.append(s)
+
+        if processed != n:
+            raise ValueError(f"A cyclic dependency occurred in dag: 
{self.dag_id}")
+
+        sorted_indices = sorted(range(n), key=lambda i: (pass_of[i], i))
+        return [nodes[i] for i in sorted_indices]
+
     def add(self, node: DAGNode) -> DAGNode:
         # Set the TG first, as setting it might change the return value of 
node_id!
         node.task_group = weakref.proxy(self)
diff --git a/airflow-core/tests/unit/utils/test_task_group.py 
b/airflow-core/tests/unit/utils/test_task_group.py
index e866bce62af..2d1458e95fb 100644
--- a/airflow-core/tests/unit/utils/test_task_group.py
+++ b/airflow-core/tests/unit/utils/test_task_group.py
@@ -1100,6 +1100,21 @@ def test_hierarchical_alphabetical_sort():
     ]
 
 
+def _make_padded_reverse_chain(chain_length: int, independent_count: int) -> 
DAG:
+    with DAG(
+        f"padded_reverse_chain_{chain_length}_{independent_count}",
+        schedule=None,
+        start_date=DEFAULT_DATE,
+    ) as dag:
+        tasks = [EmptyOperator(task_id=f"r{chain_length - 1 - i}") for i in 
range(chain_length)]
+        by_id = {task.task_id: task for task in tasks}
+        for i in range(chain_length - 1):
+            by_id[f"r{i}"] >> by_id[f"r{i + 1}"]
+        for i in range(independent_count):
+            EmptyOperator(task_id=f"i{i}")
+    return dag
+
+
 def test_topological_group_dep():
     logical_date = pendulum.parse("20200101")
     with DAG("test_dag_edges", schedule=None, start_date=logical_date) as dag:
@@ -1163,6 +1178,33 @@ def test_topological_sort_serialized_layered():
                 )
 
 
+def 
test_topological_sort_serialized_padded_reverse_chain_uses_pass_numbering(monkeypatch):
+    dag = _make_padded_reverse_chain(chain_length=80, independent_count=80)
+    serialized = create_scheduler_dag(dag)
+    serialized.task_group.children = {
+        **{f"r{i}": serialized.task_group.children[f"r{i}"] for i in range(79, 
-1, -1)},
+        **{f"i{i}": serialized.task_group.children[f"i{i}"] for i in 
range(80)},
+    }
+
+    called = {"value": False}
+    serialized_task_group_cls = type(serialized.task_group)
+    original = serialized_task_group_cls._sort_via_pass_numbering
+
+    def spy(self, nodes, projected):
+        called["value"] = True
+        return original(self, nodes, projected)
+
+    monkeypatch.setattr(serialized_task_group_cls, "_sort_via_pass_numbering", 
spy)
+
+    order = [node.node_id for node in serialized.task_group.topological_sort()]
+    position = {node_id: i for i, node_id in enumerate(order)}
+
+    assert called["value"]
+    assert set(position) == {*(f"r{i}" for i in range(80)), *(f"i{i}" for i in 
range(80))}
+    for i in range(79):
+        assert position[f"r{i}"] < position[f"r{i + 1}"]
+
+
 def test_task_group_arrow_with_setup_group():
     with DAG(dag_id="setup_group_teardown_group") as dag:
         with TaskGroup("group_1") as g1:
diff --git a/task-sdk/src/airflow/sdk/definitions/taskgroup.py 
b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
index cc1fc5cbda6..14bb2fba319 100644
--- a/task-sdk/src/airflow/sdk/definitions/taskgroup.py
+++ b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
@@ -544,25 +544,42 @@ class TaskGroup(DAGNode):
         """
         Sort children topologically — a task always comes after its upstream 
dependencies.
 
-        Projects each child's per-task upstream IDs onto sibling-level integer 
indices once,
-        then runs a greedy multi-pass sweep using a bytearray-backed emission 
flag. Equivalent
-        in emission order to the previous modified-Kahn implementation, but 
moves the per-edge
-        ``upstream_list`` materialization and ``parent_group`` walks out of 
the sweep's inner
-        loop so they happen once per call instead of once per outer-loop pass.
+        Projects per-task upstream edges onto sibling-level integer indices, 
then dispatches:
+
+        - Forward-declared DAGs (few/no children declared after their 
dependents): greedy
+          multi-pass sweep over the projection, O(V + E) for the common case.
+        - Reverse-declared DAGs (many children declared after their dependents): pass-number
+          traversal, O(V log V + E), avoids the O(N²) blowup the sweep would hit.
+
+        Both branches produce the same emission order: nodes ordered by the sweep pass in which
+        they become ready, ties broken by children insertion order.
         """
         children = self.children
         if not children:
             return []
+
         nodes = list(children.values())
+        n = len(nodes)
         id_to_idx = {nid: i for i, nid in enumerate(children)}
-        projected = [self._project_child_deps(i, c, id_to_idx) for i, c in 
enumerate(nodes)]
+
+        projected: list[tuple[int, ...]] = [()] * n
+        nodes_with_back_edge = 0
+        for i, child in enumerate(nodes):
+            deps = self._project_child_deps(i, child, id_to_idx)
+            if deps:
+                projected[i] = deps
+                if any(d > i for d in deps):
+                    nodes_with_back_edge += 1
+
+        # The ratio catches dense back-heavy groups; the 32-node absolute cutoff catches padded
+        # reverse-declared runs, where the sweep's per-pass rescans start costing more than pass-numbering.
+        if nodes_with_back_edge >= 32 or nodes_with_back_edge * 2 > n:
+            return self._sort_via_pass_numbering(nodes, projected)
         return self._sweep_projection(nodes, projected)
 
     def _project_child_deps(
         self, child_idx: int, child: DAGNode, id_to_idx: dict[str, int]
     ) -> tuple[int, ...]:
-        # Project one child's per-task upstream IDs onto sibling-level integer 
indices.
-        # Self-deps are filtered once at the end via ``discard`` so the inner 
loop stays tight.
         upstream_ids = child.upstream_task_ids
         if not upstream_ids:
             return ()
@@ -570,23 +587,25 @@ class TaskGroup(DAGNode):
         for edge_id in upstream_ids:
             j = id_to_idx.get(edge_id)
             if j is not None:
-                sib_deps.add(j)
+                if j != child_idx:
+                    sib_deps.add(j)
                 continue
-            tg = self.dag.get_task(edge_id).task_group
+            edge = self.dag.get_task(edge_id)
+            tg = edge.task_group
             while tg is not None:
-                j = id_to_idx.get(tg.node_id)
-                if j is not None:
-                    sib_deps.add(j)
+                anc_idx = id_to_idx.get(tg.node_id)
+                if anc_idx is not None:
+                    if anc_idx != child_idx:
+                        sib_deps.add(anc_idx)
                     break
                 tg = tg.parent_group
-        sib_deps.discard(child_idx)
         return tuple(sib_deps)
 
     def _sweep_projection(self, nodes: list[DAGNode], projected: 
list[tuple[int, ...]]) -> list[DAGNode]:
         # Greedy multi-pass sweep. emitted[i] == 1 iff nodes[i] has been 
emitted.
         # Pass 1 iterates range(n) directly; only blocked nodes are recorded 
into
-        # ``pending`` and re-checked in subsequent passes. Avoids paying for a
-        # ``list(range(n))`` allocation on single-pass shapes (the common 
case) while
+        # `pending` and re-checked in subsequent passes. Avoids paying for a
+        # `list(range(n))` allocation on single-pass shapes (the common case) 
while
         # still skipping already-emitted nodes on multi-pass shapes (e.g. a 
diamond's
         # single trailing sink).
         n = len(nodes)
@@ -625,6 +644,47 @@ class TaskGroup(DAGNode):
             pending = next_pending
         return order
 
+    def _sort_via_pass_numbering(
+        self, nodes: list[DAGNode], projected: list[tuple[int, ...]]
+    ) -> list[DAGNode]:
+        # Sort by (pass_number, insertion_index). pass_number(X) is the 
earliest pass at
+        # which a greedy-sweep emission of X would occur:
+        #   pass(X) = max over deps d of (pass(d) if idx(d) < idx(X) else 
pass(d)+1)
+        # A dep declared before X can be emitted in the same pass; a dep 
declared after X
+        # forces X into the next pass. Computed via Kahn's traversal in O(V + 
E).
+        n = len(nodes)
+        in_degree = [len(deps) for deps in projected]
+        successors: list[list[int]] = [[] for _ in range(n)]
+        for i, deps in enumerate(projected):
+            for d in deps:
+                successors[d].append(i)
+
+        pass_of = [0] * n
+        queue: deque[int] = deque(i for i in range(n) if in_degree[i] == 0)
+        processed = 0
+        while queue:
+            i = queue.popleft()
+            my_pass = 1
+            for d in projected[i]:
+                d_pass = pass_of[d]
+                if d < i:
+                    if d_pass > my_pass:
+                        my_pass = d_pass
+                elif d_pass + 1 > my_pass:
+                    my_pass = d_pass + 1
+            pass_of[i] = my_pass
+            processed += 1
+            for s in successors[i]:
+                in_degree[s] -= 1
+                if in_degree[s] == 0:
+                    queue.append(s)
+
+        if processed != n:
+            raise AirflowDagCycleException(f"A cyclic dependency occurred in 
dag: {self.dag_id}")
+
+        sorted_indices = sorted(range(n), key=lambda i: (pass_of[i], i))
+        return [nodes[i] for i in sorted_indices]
+
     def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
         """
         Return mapped task groups in the hierarchy.
diff --git a/task-sdk/tests/task_sdk/definitions/test_taskgroup.py 
b/task-sdk/tests/task_sdk/definitions/test_taskgroup.py
index cf6d309f305..d11f8eb3632 100644
--- a/task-sdk/tests/task_sdk/definitions/test_taskgroup.py
+++ b/task-sdk/tests/task_sdk/definitions/test_taskgroup.py
@@ -1011,6 +1011,21 @@ def _make_reverse_chain(n: int) -> DAG:
     return dag
 
 
+def _make_padded_reverse_chain(chain_length: int, independent_count: int) -> 
DAG:
+    with DAG(
+        f"padded_reverse_chain_{chain_length}_{independent_count}",
+        schedule=None,
+        start_date=DEFAULT_DATE,
+    ) as dag:
+        tasks = [EmptyOperator(task_id=f"r{chain_length - 1 - i}") for i in 
range(chain_length)]
+        by_id = {t.task_id: t for t in tasks}
+        for i in range(chain_length - 1):
+            by_id[f"r{i}"] >> by_id[f"r{i + 1}"]
+        for i in range(independent_count):
+            EmptyOperator(task_id=f"i{i}")
+    return dag
+
+
 def _make_diamond(n: int) -> DAG:
     with DAG(f"diamond_{n}", schedule=None, start_date=DEFAULT_DATE) as dag:
         root = EmptyOperator(task_id="root")
@@ -1124,3 +1139,33 @@ def test_topological_sort_shape_correctness(shape, 
builder, n):
     for group in _walk_groups(dag.task_group):
         order = [node.node_id for node in group.topological_sort()]
         _assert_valid_topological_order(group, order)
+
+
+def test_topological_sort_reverse_declared_order_matches_sweep():
+    dag = _make_reverse_chain(100)
+    group = dag.task_group
+    nodes = list(group.children.values())
+    id_to_idx = {nid: i for i, nid in enumerate(group.children)}
+    projected = [group._project_child_deps(i, child, id_to_idx) for i, child 
in enumerate(nodes)]
+
+    sweep_order = [node.node_id for node in group._sweep_projection(nodes, 
projected)]
+    pass_number_order = [node.node_id for node in 
group._sort_via_pass_numbering(nodes, projected)]
+
+    assert pass_number_order == sweep_order
+
+
+def 
test_topological_sort_padded_reverse_chain_uses_pass_numbering(monkeypatch):
+    dag = _make_padded_reverse_chain(chain_length=80, independent_count=80)
+    called = {"value": False}
+    original = TaskGroup._sort_via_pass_numbering
+
+    def spy(self, nodes, projected):
+        called["value"] = True
+        return original(self, nodes, projected)
+
+    monkeypatch.setattr(TaskGroup, "_sort_via_pass_numbering", spy)
+
+    order = [node.node_id for node in dag.task_group.topological_sort()]
+
+    assert called["value"]
+    _assert_valid_topological_order(dag.task_group, order)

Reply via email to