This is an automated email from the ASF dual-hosted git repository.

shahar1 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7bf71a3d444 Speed up TaskGroup.topological_sort with int-indexed 
projected sweep (#67288)
7bf71a3d444 is described below

commit 7bf71a3d444e2e1f5ef22b4c18f782c03281d001
Author: Shahar Epstein <[email protected]>
AuthorDate: Wed May 27 21:39:43 2026 +0300

    Speed up TaskGroup.topological_sort with int-indexed projected sweep 
(#67288)
---
 airflow-core/newsfragments/67288.improvement.rst   |   1 +
 .../airflow/serialization/definitions/taskgroup.py |  97 ++++++++++----
 airflow-core/tests/unit/utils/test_task_group.py   |  28 ++++
 task-sdk/src/airflow/sdk/definitions/taskgroup.py  | 125 +++++++++++-------
 .../tests/task_sdk/definitions/test_taskgroup.py   | 144 +++++++++++++++++++++
 5 files changed, 322 insertions(+), 73 deletions(-)

diff --git a/airflow-core/newsfragments/67288.improvement.rst 
b/airflow-core/newsfragments/67288.improvement.rst
new file mode 100644
index 00000000000..03293e4ffa2
--- /dev/null
+++ b/airflow-core/newsfragments/67288.improvement.rst
@@ -0,0 +1 @@
+Speed up ``TaskGroup.topological_sort`` across Dag shapes (chain, diamond, 
layered, reverse-chain); benchmarks show roughly 2-8x faster on large groups.
diff --git a/airflow-core/src/airflow/serialization/definitions/taskgroup.py 
b/airflow-core/src/airflow/serialization/definitions/taskgroup.py
index d971c303c7c..5db656019f1 100644
--- a/airflow-core/src/airflow/serialization/definitions/taskgroup.py
+++ b/airflow-core/src/airflow/serialization/definitions/taskgroup.py
@@ -18,7 +18,6 @@
 
 from __future__ import annotations
 
-import copy
 import functools
 import operator
 import weakref
@@ -217,35 +216,79 @@ class SerializedTaskGroup(DAGNode):
 
     def topological_sort(self) -> list[DAGNode]:
         """
-        Sorts children in topographical order.
+        Sort children topologically — a task always comes after its upstream 
dependencies.
 
-        A task in the result would come after any of its upstream dependencies.
+        See ``TaskGroup.topological_sort`` in task-sdk for the algorithm. 
Cycles are
+        treated as corrupt input: ``DAG.check_cycle`` rejects cyclic Dags 
before
+        serialization, so a cycle reaching this code indicates malformed 
serialized data,
+        and we raise ``ValueError`` rather than silently looping forever.
         """
-        # This uses a modified version of Kahn's Topological Sort algorithm to
-        # not have to pre-compute the "in-degree" of the nodes.
-        graph_unsorted = copy.copy(self.children)
-        graph_sorted: list[DAGNode] = []
-        if not self.children:
-            return graph_sorted
-        while graph_unsorted:
-            for node in list(graph_unsorted.values()):
-                for edge in node.upstream_list:
-                    if edge.node_id in graph_unsorted:
+        children = self.children
+        if not children:
+            return []
+        nodes = list(children.values())
+        id_to_idx = {nid: i for i, nid in enumerate(children)}
+        projected = [self._project_child_deps(i, c, id_to_idx) for i, c in enumerate(nodes)]
+        return self._sweep_projection(nodes, projected)
+
+    def _project_child_deps(
+        self, child_idx: int, child: DAGNode, id_to_idx: dict[str, int]
+    ) -> tuple[int, ...]:
+        upstream_ids = child.upstream_task_ids
+        if not upstream_ids:
+            return ()
+        sib_deps: set[int] = set()
+        for edge_id in upstream_ids:
+            j = id_to_idx.get(edge_id)
+            if j is not None:
+                sib_deps.add(j)
+                continue
+            tg = self.dag.get_task(edge_id).task_group
+            while tg is not None:
+                j = id_to_idx.get(tg.node_id)
+                if j is not None:
+                    sib_deps.add(j)
+                    break
+                tg = tg.parent_group
+        sib_deps.discard(child_idx)
+        return tuple(sib_deps)
+
+    def _sweep_projection(self, nodes: list[DAGNode], projected: list[tuple[int, ...]]) -> list[DAGNode]:
+        n = len(nodes)
+        emitted = bytearray(n)
+        order: list[DAGNode] = []
+        order_append = order.append
+        pending: list[int] = []
+        pending_append = pending.append
+        for i in range(n):
+            blocked = False
+            for d in projected[i]:
+                if not emitted[d]:
+                    blocked = True
+                    break
+            if blocked:
+                pending_append(i)
+                continue
+            emitted[i] = 1
+            order_append(nodes[i])
+        while pending:
+            next_pending: list[int] = []
+            next_pending_append = next_pending.append
+            for i in pending:
+                blocked = False
+                for d in projected[i]:
+                    if not emitted[d]:
+                        blocked = True
                         break
-                    # Check for task's group is a child (or grand child) of 
this TG,
-                    tg = edge.task_group
-                    while tg:
-                        if tg.node_id in graph_unsorted:
-                            break
-                        tg = tg.parent_group
-
-                    if tg:
-                        # We are already going to visit that TG
-                        break
-                else:
-                    del graph_unsorted[node.node_id]
-                    graph_sorted.append(node)
-        return graph_sorted
+                if blocked:
+                    next_pending_append(i)
+                    continue
+                emitted[i] = 1
+                order_append(nodes[i])
+            if len(next_pending) == len(pending):
+                raise ValueError(f"A cyclic dependency occurred in dag: {self.dag_id}")
+            pending = next_pending
+        return order
 
     def add(self, node: DAGNode) -> DAGNode:
         # Set the TG first, as setting it might change the return value of 
node_id!
diff --git a/airflow-core/tests/unit/utils/test_task_group.py 
b/airflow-core/tests/unit/utils/test_task_group.py
index 3b62ad75a72..ffc217fc078 100644
--- a/airflow-core/tests/unit/utils/test_task_group.py
+++ b/airflow-core/tests/unit/utils/test_task_group.py
@@ -1117,6 +1117,34 @@ def test_topological_group_dep():
     ]
 
 
+def test_topological_sort_serialized_layered():
+    """SerializedTaskGroup.topological_sort emits a valid order after DAG 
round-trip.
+
+    Exercises the projected-sweep path on the serialization variant (which is 
otherwise
+    untested), using a layered shape that forces multi-pass behavior.
+    """
+    with DAG("test_topo_sort_serialized", schedule=None, 
start_date=DEFAULT_DATE) as dag:
+        layers: list[list[BaseOperator]] = []
+        for layer_idx in range(4):
+            cur = [EmptyOperator(task_id=f"L{layer_idx}_t{i}") for i in 
range(3)]
+            if layers:
+                for upstream in layers[-1]:
+                    upstream >> cur
+            layers.append(cur)
+
+    serialized = create_scheduler_dag(dag)
+    order = [node.node_id for node in serialized.task_group.topological_sort()]
+    position = {nid: i for i, nid in enumerate(order)}
+
+    assert set(position) == {t.task_id for layer in layers for t in layer}
+    for layer_idx in range(len(layers) - 1):
+        for upstream in layers[layer_idx]:
+            for downstream in layers[layer_idx + 1]:
+                assert position[upstream.task_id] < position[downstream.task_id], (
+                    f"{upstream.task_id!r} must precede {downstream.task_id!r}, got {order!r}"
+                )
+
+
 def test_task_group_arrow_with_setup_group():
     with DAG(dag_id="setup_group_teardown_group") as dag:
         with TaskGroup("group_1") as g1:
diff --git a/task-sdk/src/airflow/sdk/definitions/taskgroup.py 
b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
index 50527f6b43b..67376cb817a 100644
--- a/task-sdk/src/airflow/sdk/definitions/taskgroup.py
+++ b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
@@ -523,57 +523,90 @@ class TaskGroup(DAGNode):
             key=lambda node: (not isinstance(node, TaskGroup), node.node_id),
         )
 
-    def topological_sort(self):
+    def topological_sort(self) -> list[DAGNode]:
         """
-        Sorts children in topographical order, such that a task comes after 
any of its upstream dependencies.
+        Sort children topologically — a task always comes after its upstream 
dependencies.
 
-        :return: list of tasks in topological order
+        Projects each child's per-task upstream IDs onto sibling-level integer 
indices once,
+        then runs a greedy multi-pass sweep using a bytearray-backed emission 
flag. Equivalent
+        in emission order to the previous modified-Kahn implementation, but 
moves the per-edge
+        ``upstream_list`` materialization and ``parent_group`` walks out of 
the sweep's inner
+        loop so they happen once per call instead of once per outer-loop pass.
         """
-        # This uses a modified version of Kahn's Topological Sort algorithm to
-        # not have to pre-compute the "in-degree" of the nodes.
-        graph_unsorted = copy.copy(self.children)
-
-        graph_sorted: list[DAGNode] = []
-
-        # special case
-        if not self.children:
-            return graph_sorted
-
-        # Run until the unsorted graph is empty.
-        while graph_unsorted:
-            # Go through each of the node/edges pairs in the unsorted graph. 
If a set of edges doesn't contain
-            # any nodes that haven't been resolved, that is, that are still in 
the unsorted graph, remove the
-            # pair from the unsorted graph, and append it to the sorted graph. 
Note here that by using
-            # the values() method for iterating, a copy of the unsorted graph 
is used, allowing us to modify
-            # the unsorted graph as we move through it.
-            #
-            # We also keep a flag for checking that graph is acyclic, which is 
true if any nodes are resolved
-            # during each pass through the graph. If not, we need to exit as 
the graph therefore can't be
-            # sorted.
-            acyclic = False
-            for node in list(graph_unsorted.values()):
-                for edge in node.upstream_list:
-                    if edge.node_id in graph_unsorted:
-                        break
-                    # Check for task's group is a child (or grand child) of 
this TG,
-                    tg = edge.task_group
-                    while tg:
-                        if tg.node_id in graph_unsorted:
-                            break
-                        tg = tg.parent_group
-
-                    if tg:
-                        # We are already going to visit that TG
+        children = self.children
+        if not children:
+            return []
+        nodes = list(children.values())
+        id_to_idx = {nid: i for i, nid in enumerate(children)}
+        projected = [self._project_child_deps(i, c, id_to_idx) for i, c in enumerate(nodes)]
+        return self._sweep_projection(nodes, projected)
+
+    def _project_child_deps(
+        self, child_idx: int, child: DAGNode, id_to_idx: dict[str, int]
+    ) -> tuple[int, ...]:
+        # Project one child's per-task upstream IDs onto sibling-level integer 
indices.
+        # Self-deps are filtered once at the end via ``discard`` so the inner 
loop stays tight.
+        upstream_ids = child.upstream_task_ids
+        if not upstream_ids:
+            return ()
+        sib_deps: set[int] = set()
+        for edge_id in upstream_ids:
+            j = id_to_idx.get(edge_id)
+            if j is not None:
+                sib_deps.add(j)
+                continue
+            tg = self.dag.get_task(edge_id).task_group
+            while tg is not None:
+                j = id_to_idx.get(tg.node_id)
+                if j is not None:
+                    sib_deps.add(j)
+                    break
+                tg = tg.parent_group
+        sib_deps.discard(child_idx)
+        return tuple(sib_deps)
+
+    def _sweep_projection(self, nodes: list[DAGNode], projected: list[tuple[int, ...]]) -> list[DAGNode]:
+        # Greedy multi-pass sweep. emitted[i] == 1 iff nodes[i] has been 
emitted.
+        # Pass 1 iterates range(n) directly; only blocked nodes are recorded 
into
+        # ``pending`` and re-checked in subsequent passes. Avoids paying for a
+        # ``list(range(n))`` allocation on single-pass shapes (the common 
case) while
+        # still skipping already-emitted nodes on multi-pass shapes (e.g. a 
diamond's
+        # single trailing sink).
+        n = len(nodes)
+        emitted = bytearray(n)
+        order: list[DAGNode] = []
+        order_append = order.append
+        pending: list[int] = []
+        pending_append = pending.append
+        for i in range(n):
+            blocked = False
+            for d in projected[i]:
+                if not emitted[d]:
+                    blocked = True
+                    break
+            if blocked:
+                pending_append(i)
+                continue
+            emitted[i] = 1
+            order_append(nodes[i])
+        while pending:
+            next_pending: list[int] = []
+            next_pending_append = next_pending.append
+            for i in pending:
+                blocked = False
+                for d in projected[i]:
+                    if not emitted[d]:
+                        blocked = True
                         break
-                else:
-                    acyclic = True
-                    del graph_unsorted[node.node_id]
-                    graph_sorted.append(node)
-
-            if not acyclic:
+                if blocked:
+                    next_pending_append(i)
+                    continue
+                emitted[i] = 1
+                order_append(nodes[i])
+            if len(next_pending) == len(pending):
                 raise AirflowDagCycleException(f"A cyclic dependency occurred in dag: {self.dag_id}")
-
-        return graph_sorted
+            pending = next_pending
+        return order
 
     def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
         """
diff --git a/task-sdk/tests/task_sdk/definitions/test_taskgroup.py 
b/task-sdk/tests/task_sdk/definitions/test_taskgroup.py
index 18c7f65faf2..d1ba11e3056 100644
--- a/task-sdk/tests/task_sdk/definitions/test_taskgroup.py
+++ b/task-sdk/tests/task_sdk/definitions/test_taskgroup.py
@@ -957,3 +957,147 @@ class TestTaskGroupGetItem:
 
         with pytest.raises(KeyError):
             tg["nonexistent"]
+
+
+# --- topological_sort: cross-shape correctness ---
+#
+# Mirrors the shapes covered by the benchmark gist referenced from PR #67288
+# (https://gist.github.com/shahar1/9c61dc9f34f7e77cd29cfb9d67af7ceb).
+# Wall-clock timing is intentionally not asserted here — CI runners are too
+# variable for ms thresholds to be meaningful. The gist above can be run
+# manually to gauge performance.
+
+
+def _make_chain(n: int) -> DAG:
+    with DAG(f"chain_{n}", schedule=None, start_date=DEFAULT_DATE) as dag:
+        prev = None
+        for i in range(n):
+            t = EmptyOperator(task_id=f"t{i}")
+            if prev is not None:
+                prev >> t
+            prev = t
+    return dag
+
+
+def _make_reverse_chain(n: int) -> DAG:
+    with DAG(f"reverse_chain_{n}", schedule=None, start_date=DEFAULT_DATE) as 
dag:
+        tasks = [EmptyOperator(task_id=f"t{n - 1 - i}") for i in range(n)]
+        by_id = {t.task_id: t for t in tasks}
+        for i in range(n - 1):
+            by_id[f"t{i}"] >> by_id[f"t{i + 1}"]
+    return dag
+
+
+def _make_diamond(n: int) -> DAG:
+    with DAG(f"diamond_{n}", schedule=None, start_date=DEFAULT_DATE) as dag:
+        root = EmptyOperator(task_id="root")
+        sink = EmptyOperator(task_id="sink")
+        middles = [EmptyOperator(task_id=f"m{i}") for i in range(max(n - 2, 
1))]
+        root >> middles >> sink
+    return dag
+
+
+def _make_independent(n: int) -> DAG:
+    with DAG(f"independent_{n}", schedule=None, start_date=DEFAULT_DATE) as 
dag:
+        for i in range(n):
+            EmptyOperator(task_id=f"t{i}")
+    return dag
+
+
+def _make_layered(n: int, layers: int = 4) -> DAG:
+    per_layer = max(n // layers, 1)
+    with DAG(f"layered_{n}", schedule=None, start_date=DEFAULT_DATE) as dag:
+        prev_layer: list[EmptyOperator] = []
+        for layer in range(layers):
+            cur = [EmptyOperator(task_id=f"L{layer}_t{i}") for i in 
range(per_layer)]
+            if prev_layer:
+                for upstream in prev_layer:
+                    upstream >> cur
+            prev_layer = cur
+    return dag
+
+
+def _make_nested_groups(n: int, depth: int = 3) -> DAG:
+    per_group = max(n // (depth * depth), 1)
+    with DAG(f"nested_{n}", schedule=None, start_date=DEFAULT_DATE) as dag:
+
+        def build_group(level: int, idx: int) -> TaskGroup:
+            with TaskGroup(group_id=f"g{level}_{idx}") as tg:
+                prev = None
+                for i in range(per_group):
+                    t = EmptyOperator(task_id=f"l{level}_g{idx}_t{i}")
+                    if prev is not None:
+                        prev >> t
+                    prev = t
+                if level + 1 < depth:
+                    inner_prev = None
+                    for j in range(depth):
+                        inner = build_group(level + 1, j)
+                        if inner_prev is not None:
+                            inner_prev >> inner
+                        inner_prev = inner
+            return tg
+
+        top_prev = None
+        for j in range(depth):
+            top = build_group(0, j)
+            if top_prev is not None:
+                top_prev >> top
+            top_prev = top
+    return dag
+
+
+def _project_sibling(group: TaskGroup, upstream_task_id: str, child_id: str) 
-> str | None:
+    """Mirror of TaskGroup._project_child_deps' projection, returning a string 
ID."""
+    children = group.children
+    if upstream_task_id in children:
+        return upstream_task_id if upstream_task_id != child_id else None
+    upstream = group.dag.get_task(upstream_task_id)
+    tg = upstream.task_group
+    while tg is not None:
+        if tg.node_id in children:
+            return tg.node_id if tg.node_id != child_id else None
+        tg = tg.parent_group
+    return None
+
+
+def _walk_groups(tg: TaskGroup):
+    yield tg
+    for child in tg.children.values():
+        if isinstance(child, TaskGroup):
+            yield from _walk_groups(child)
+
+
+def _assert_valid_topological_order(group: TaskGroup, order: list[str]) -> 
None:
+    position = {node_id: i for i, node_id in enumerate(order)}
+    assert set(position) == set(group.children), (
+        f"topological_sort output {order!r} does not cover children of {group.node_id!r}"
+    )
+    for child_id, child in group.children.items():
+        for upstream_id in child.upstream_task_ids:
+            sib = _project_sibling(group, upstream_id, child_id)
+            if sib is None:
+                continue
+            assert position[sib] < position[child_id], (
+                f"In group {group.node_id!r}: sibling {sib!r} must precede {child_id!r}, got order {order!r}"
+            )
+
+
+@pytest.mark.parametrize(
+    ("shape", "builder"),
+    [
+        ("chain", _make_chain),
+        ("rev-chain", _make_reverse_chain),
+        ("diamond", _make_diamond),
+        ("independent", _make_independent),
+        ("layered", _make_layered),
+        ("nested", _make_nested_groups),
+    ],
+)
+@pytest.mark.parametrize("n", [20, 100])
+def test_topological_sort_shape_correctness(shape, builder, n):
+    """topological_sort emits a valid order for every nested group across DAG 
shapes."""
+    dag = builder(n)
+    for group in _walk_groups(dag.task_group):
+        order = [node.node_id for node in group.topological_sort()]
+        _assert_valid_topological_order(group, order)

Reply via email to