This is an automated email from the ASF dual-hosted git repository.
shahar1 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7bf71a3d444 Speed up TaskGroup.topological_sort with int-indexed
projected sweep (#67288)
7bf71a3d444 is described below
commit 7bf71a3d444e2e1f5ef22b4c18f782c03281d001
Author: Shahar Epstein <[email protected]>
AuthorDate: Wed May 27 21:39:43 2026 +0300
Speed up TaskGroup.topological_sort with int-indexed projected sweep
(#67288)
---
airflow-core/newsfragments/67288.improvement.rst | 1 +
.../airflow/serialization/definitions/taskgroup.py | 97 ++++++++++----
airflow-core/tests/unit/utils/test_task_group.py | 28 ++++
task-sdk/src/airflow/sdk/definitions/taskgroup.py | 125 +++++++++++-------
.../tests/task_sdk/definitions/test_taskgroup.py | 144 +++++++++++++++++++++
5 files changed, 322 insertions(+), 73 deletions(-)
diff --git a/airflow-core/newsfragments/67288.improvement.rst b/airflow-core/newsfragments/67288.improvement.rst
new file mode 100644
index 00000000000..03293e4ffa2
--- /dev/null
+++ b/airflow-core/newsfragments/67288.improvement.rst
@@ -0,0 +1 @@
+Speed up ``TaskGroup.topological_sort`` across Dag shapes (chain, diamond, layered, reverse-chain); benchmarks show roughly 2-8x faster on large groups.
diff --git a/airflow-core/src/airflow/serialization/definitions/taskgroup.py b/airflow-core/src/airflow/serialization/definitions/taskgroup.py
index d971c303c7c..5db656019f1 100644
--- a/airflow-core/src/airflow/serialization/definitions/taskgroup.py
+++ b/airflow-core/src/airflow/serialization/definitions/taskgroup.py
@@ -18,7 +18,6 @@
from __future__ import annotations
-import copy
import functools
import operator
import weakref
@@ -217,35 +216,79 @@ class SerializedTaskGroup(DAGNode):
def topological_sort(self) -> list[DAGNode]:
"""
- Sorts children in topographical order.
+ Sort children topologically — a task always comes after its upstream
dependencies.
- A task in the result would come after any of its upstream dependencies.
+ See ``TaskGroup.topological_sort`` in task-sdk for the algorithm.
Cycles are
+ treated as corrupt input: ``DAG.check_cycle`` rejects cyclic Dags
before
+ serialization, so a cycle reaching this code indicates malformed
serialized data,
+ and we raise ``ValueError`` rather than silently looping forever.
"""
- # This uses a modified version of Kahn's Topological Sort algorithm to
- # not have to pre-compute the "in-degree" of the nodes.
- graph_unsorted = copy.copy(self.children)
- graph_sorted: list[DAGNode] = []
- if not self.children:
- return graph_sorted
- while graph_unsorted:
- for node in list(graph_unsorted.values()):
- for edge in node.upstream_list:
- if edge.node_id in graph_unsorted:
+ children = self.children
+ if not children:
+ return []
+ nodes = list(children.values())
+ id_to_idx = {nid: i for i, nid in enumerate(children)}
+ projected = [self._project_child_deps(i, c, id_to_idx) for i, c in
enumerate(nodes)]
+ return self._sweep_projection(nodes, projected)
+
+ def _project_child_deps(
+ self, child_idx: int, child: DAGNode, id_to_idx: dict[str, int]
+ ) -> tuple[int, ...]:
+ upstream_ids = child.upstream_task_ids
+ if not upstream_ids:
+ return ()
+ sib_deps: set[int] = set()
+ for edge_id in upstream_ids:
+ j = id_to_idx.get(edge_id)
+ if j is not None:
+ sib_deps.add(j)
+ continue
+ tg = self.dag.get_task(edge_id).task_group
+ while tg is not None:
+ j = id_to_idx.get(tg.node_id)
+ if j is not None:
+ sib_deps.add(j)
+ break
+ tg = tg.parent_group
+ sib_deps.discard(child_idx)
+ return tuple(sib_deps)
+
+ def _sweep_projection(self, nodes: list[DAGNode], projected:
list[tuple[int, ...]]) -> list[DAGNode]:
+ n = len(nodes)
+ emitted = bytearray(n)
+ order: list[DAGNode] = []
+ order_append = order.append
+ pending: list[int] = []
+ pending_append = pending.append
+ for i in range(n):
+ blocked = False
+ for d in projected[i]:
+ if not emitted[d]:
+ blocked = True
+ break
+ if blocked:
+ pending_append(i)
+ continue
+ emitted[i] = 1
+ order_append(nodes[i])
+ while pending:
+ next_pending: list[int] = []
+ next_pending_append = next_pending.append
+ for i in pending:
+ blocked = False
+ for d in projected[i]:
+ if not emitted[d]:
+ blocked = True
break
- # Check for task's group is a child (or grand child) of
this TG,
- tg = edge.task_group
- while tg:
- if tg.node_id in graph_unsorted:
- break
- tg = tg.parent_group
-
- if tg:
- # We are already going to visit that TG
- break
- else:
- del graph_unsorted[node.node_id]
- graph_sorted.append(node)
- return graph_sorted
+ if blocked:
+ next_pending_append(i)
+ continue
+ emitted[i] = 1
+ order_append(nodes[i])
+ if len(next_pending) == len(pending):
+ raise ValueError(f"A cyclic dependency occurred in dag: {self.dag_id}")
+ pending = next_pending
+ return order
def add(self, node: DAGNode) -> DAGNode:
# Set the TG first, as setting it might change the return value of
node_id!
diff --git a/airflow-core/tests/unit/utils/test_task_group.py b/airflow-core/tests/unit/utils/test_task_group.py
index 3b62ad75a72..ffc217fc078 100644
--- a/airflow-core/tests/unit/utils/test_task_group.py
+++ b/airflow-core/tests/unit/utils/test_task_group.py
@@ -1117,6 +1117,34 @@ def test_topological_group_dep():
]
+def test_topological_sort_serialized_layered():
+ """SerializedTaskGroup.topological_sort emits a valid order after DAG
round-trip.
+
+ Exercises the projected-sweep path on the serialization variant (which is
otherwise
+ untested), using a layered shape that forces multi-pass behavior.
+ """
+ with DAG("test_topo_sort_serialized", schedule=None,
start_date=DEFAULT_DATE) as dag:
+ layers: list[list[BaseOperator]] = []
+ for layer_idx in range(4):
+ cur = [EmptyOperator(task_id=f"L{layer_idx}_t{i}") for i in
range(3)]
+ if layers:
+ for upstream in layers[-1]:
+ upstream >> cur
+ layers.append(cur)
+
+ serialized = create_scheduler_dag(dag)
+ order = [node.node_id for node in serialized.task_group.topological_sort()]
+ position = {nid: i for i, nid in enumerate(order)}
+
+ assert set(position) == {t.task_id for layer in layers for t in layer}
+ for layer_idx in range(len(layers) - 1):
+ for upstream in layers[layer_idx]:
+ for downstream in layers[layer_idx + 1]:
+ assert position[upstream.task_id] < position[downstream.task_id], (
+ f"{upstream.task_id!r} must precede {downstream.task_id!r}, got {order!r}"
+ )
+
+
def test_task_group_arrow_with_setup_group():
with DAG(dag_id="setup_group_teardown_group") as dag:
with TaskGroup("group_1") as g1:
diff --git a/task-sdk/src/airflow/sdk/definitions/taskgroup.py b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
index 50527f6b43b..67376cb817a 100644
--- a/task-sdk/src/airflow/sdk/definitions/taskgroup.py
+++ b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
@@ -523,57 +523,90 @@ class TaskGroup(DAGNode):
key=lambda node: (not isinstance(node, TaskGroup), node.node_id),
)
- def topological_sort(self):
+ def topological_sort(self) -> list[DAGNode]:
"""
- Sorts children in topographical order, such that a task comes after
any of its upstream dependencies.
+ Sort children topologically — a task always comes after its upstream
dependencies.
- :return: list of tasks in topological order
+ Projects each child's per-task upstream IDs onto sibling-level integer
indices once,
+ then runs a greedy multi-pass sweep using a bytearray-backed emission
flag. Equivalent
+ in emission order to the previous modified-Kahn implementation, but
moves the per-edge
+ ``upstream_list`` materialization and ``parent_group`` walks out of
the sweep's inner
+ loop so they happen once per call instead of once per outer-loop pass.
"""
- # This uses a modified version of Kahn's Topological Sort algorithm to
- # not have to pre-compute the "in-degree" of the nodes.
- graph_unsorted = copy.copy(self.children)
-
- graph_sorted: list[DAGNode] = []
-
- # special case
- if not self.children:
- return graph_sorted
-
- # Run until the unsorted graph is empty.
- while graph_unsorted:
- # Go through each of the node/edges pairs in the unsorted graph.
If a set of edges doesn't contain
- # any nodes that haven't been resolved, that is, that are still in
the unsorted graph, remove the
- # pair from the unsorted graph, and append it to the sorted graph.
Note here that by using
- # the values() method for iterating, a copy of the unsorted graph
is used, allowing us to modify
- # the unsorted graph as we move through it.
- #
- # We also keep a flag for checking that graph is acyclic, which is
true if any nodes are resolved
- # during each pass through the graph. If not, we need to exit as
the graph therefore can't be
- # sorted.
- acyclic = False
- for node in list(graph_unsorted.values()):
- for edge in node.upstream_list:
- if edge.node_id in graph_unsorted:
- break
- # Check for task's group is a child (or grand child) of
this TG,
- tg = edge.task_group
- while tg:
- if tg.node_id in graph_unsorted:
- break
- tg = tg.parent_group
-
- if tg:
- # We are already going to visit that TG
+ children = self.children
+ if not children:
+ return []
+ nodes = list(children.values())
+ id_to_idx = {nid: i for i, nid in enumerate(children)}
+ projected = [self._project_child_deps(i, c, id_to_idx) for i, c in
enumerate(nodes)]
+ return self._sweep_projection(nodes, projected)
+
+ def _project_child_deps(
+ self, child_idx: int, child: DAGNode, id_to_idx: dict[str, int]
+ ) -> tuple[int, ...]:
+ # Project one child's per-task upstream IDs onto sibling-level integer
indices.
+ # Self-deps are filtered once at the end via ``discard`` so the inner
loop stays tight.
+ upstream_ids = child.upstream_task_ids
+ if not upstream_ids:
+ return ()
+ sib_deps: set[int] = set()
+ for edge_id in upstream_ids:
+ j = id_to_idx.get(edge_id)
+ if j is not None:
+ sib_deps.add(j)
+ continue
+ tg = self.dag.get_task(edge_id).task_group
+ while tg is not None:
+ j = id_to_idx.get(tg.node_id)
+ if j is not None:
+ sib_deps.add(j)
+ break
+ tg = tg.parent_group
+ sib_deps.discard(child_idx)
+ return tuple(sib_deps)
+
+ def _sweep_projection(self, nodes: list[DAGNode], projected:
list[tuple[int, ...]]) -> list[DAGNode]:
+ # Greedy multi-pass sweep. emitted[i] == 1 iff nodes[i] has been
emitted.
+ # Pass 1 iterates range(n) directly; only blocked nodes are recorded
into
+ # ``pending`` and re-checked in subsequent passes. Avoids paying for a
+ # ``list(range(n))`` allocation on single-pass shapes (the common
case) while
+ # still skipping already-emitted nodes on multi-pass shapes (e.g. a
diamond's
+ # single trailing sink).
+ n = len(nodes)
+ emitted = bytearray(n)
+ order: list[DAGNode] = []
+ order_append = order.append
+ pending: list[int] = []
+ pending_append = pending.append
+ for i in range(n):
+ blocked = False
+ for d in projected[i]:
+ if not emitted[d]:
+ blocked = True
+ break
+ if blocked:
+ pending_append(i)
+ continue
+ emitted[i] = 1
+ order_append(nodes[i])
+ while pending:
+ next_pending: list[int] = []
+ next_pending_append = next_pending.append
+ for i in pending:
+ blocked = False
+ for d in projected[i]:
+ if not emitted[d]:
+ blocked = True
break
- else:
- acyclic = True
- del graph_unsorted[node.node_id]
- graph_sorted.append(node)
-
- if not acyclic:
+ if blocked:
+ next_pending_append(i)
+ continue
+ emitted[i] = 1
+ order_append(nodes[i])
+ if len(next_pending) == len(pending):
raise AirflowDagCycleException(f"A cyclic dependency occurred in dag: {self.dag_id}")
-
- return graph_sorted
+ pending = next_pending
+ return order
def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
"""
diff --git a/task-sdk/tests/task_sdk/definitions/test_taskgroup.py b/task-sdk/tests/task_sdk/definitions/test_taskgroup.py
index 18c7f65faf2..d1ba11e3056 100644
--- a/task-sdk/tests/task_sdk/definitions/test_taskgroup.py
+++ b/task-sdk/tests/task_sdk/definitions/test_taskgroup.py
@@ -957,3 +957,147 @@ class TestTaskGroupGetItem:
with pytest.raises(KeyError):
tg["nonexistent"]
+
+
+# --- topological_sort: cross-shape correctness ---
+#
+# Mirrors the shapes covered by the benchmark gist referenced from PR #67288
+# (https://gist.github.com/shahar1/9c61dc9f34f7e77cd29cfb9d67af7ceb).
+# Wall-clock timing is intentionally not asserted here — CI runners are too
+# variable for ms thresholds to be meaningful. The gist above can be run
+# manually to gauge performance.
+
+
+def _make_chain(n: int) -> DAG:
+ with DAG(f"chain_{n}", schedule=None, start_date=DEFAULT_DATE) as dag:
+ prev = None
+ for i in range(n):
+ t = EmptyOperator(task_id=f"t{i}")
+ if prev is not None:
+ prev >> t
+ prev = t
+ return dag
+
+
+def _make_reverse_chain(n: int) -> DAG:
+ with DAG(f"reverse_chain_{n}", schedule=None, start_date=DEFAULT_DATE) as
dag:
+ tasks = [EmptyOperator(task_id=f"t{n - 1 - i}") for i in range(n)]
+ by_id = {t.task_id: t for t in tasks}
+ for i in range(n - 1):
+ by_id[f"t{i}"] >> by_id[f"t{i + 1}"]
+ return dag
+
+
+def _make_diamond(n: int) -> DAG:
+ with DAG(f"diamond_{n}", schedule=None, start_date=DEFAULT_DATE) as dag:
+ root = EmptyOperator(task_id="root")
+ sink = EmptyOperator(task_id="sink")
+ middles = [EmptyOperator(task_id=f"m{i}") for i in range(max(n - 2,
1))]
+ root >> middles >> sink
+ return dag
+
+
+def _make_independent(n: int) -> DAG:
+ with DAG(f"independent_{n}", schedule=None, start_date=DEFAULT_DATE) as
dag:
+ for i in range(n):
+ EmptyOperator(task_id=f"t{i}")
+ return dag
+
+
+def _make_layered(n: int, layers: int = 4) -> DAG:
+ per_layer = max(n // layers, 1)
+ with DAG(f"layered_{n}", schedule=None, start_date=DEFAULT_DATE) as dag:
+ prev_layer: list[EmptyOperator] = []
+ for layer in range(layers):
+ cur = [EmptyOperator(task_id=f"L{layer}_t{i}") for i in
range(per_layer)]
+ if prev_layer:
+ for upstream in prev_layer:
+ upstream >> cur
+ prev_layer = cur
+ return dag
+
+
+def _make_nested_groups(n: int, depth: int = 3) -> DAG:
+ per_group = max(n // (depth * depth), 1)
+ with DAG(f"nested_{n}", schedule=None, start_date=DEFAULT_DATE) as dag:
+
+ def build_group(level: int, idx: int) -> TaskGroup:
+ with TaskGroup(group_id=f"g{level}_{idx}") as tg:
+ prev = None
+ for i in range(per_group):
+ t = EmptyOperator(task_id=f"l{level}_g{idx}_t{i}")
+ if prev is not None:
+ prev >> t
+ prev = t
+ if level + 1 < depth:
+ inner_prev = None
+ for j in range(depth):
+ inner = build_group(level + 1, j)
+ if inner_prev is not None:
+ inner_prev >> inner
+ inner_prev = inner
+ return tg
+
+ top_prev = None
+ for j in range(depth):
+ top = build_group(0, j)
+ if top_prev is not None:
+ top_prev >> top
+ top_prev = top
+ return dag
+
+
+def _project_sibling(group: TaskGroup, upstream_task_id: str, child_id: str)
-> str | None:
+ """Mirror of TaskGroup._project_child_deps' projection, returning a string
ID."""
+ children = group.children
+ if upstream_task_id in children:
+ return upstream_task_id if upstream_task_id != child_id else None
+ upstream = group.dag.get_task(upstream_task_id)
+ tg = upstream.task_group
+ while tg is not None:
+ if tg.node_id in children:
+ return tg.node_id if tg.node_id != child_id else None
+ tg = tg.parent_group
+ return None
+
+
+def _walk_groups(tg: TaskGroup):
+ yield tg
+ for child in tg.children.values():
+ if isinstance(child, TaskGroup):
+ yield from _walk_groups(child)
+
+
+def _assert_valid_topological_order(group: TaskGroup, order: list[str]) ->
None:
+ position = {node_id: i for i, node_id in enumerate(order)}
+ assert set(position) == set(group.children), (
+ f"topological_sort output {order!r} does not cover children of {group.node_id!r}"
+ )
+ for child_id, child in group.children.items():
+ for upstream_id in child.upstream_task_ids:
+ sib = _project_sibling(group, upstream_id, child_id)
+ if sib is None:
+ continue
+ assert position[sib] < position[child_id], (
+ f"In group {group.node_id!r}: sibling {sib!r} must precede {child_id!r}, got order {order!r}"
+ )
+
+
+@pytest.mark.parametrize(
+ ("shape", "builder"),
+ [
+ ("chain", _make_chain),
+ ("rev-chain", _make_reverse_chain),
+ ("diamond", _make_diamond),
+ ("independent", _make_independent),
+ ("layered", _make_layered),
+ ("nested", _make_nested_groups),
+ ],
+)
+@pytest.mark.parametrize("n", [20, 100])
+def test_topological_sort_shape_correctness(shape, builder, n):
+ """topological_sort emits a valid order for every nested group across DAG
shapes."""
+ dag = builder(n)
+ for group in _walk_groups(dag.task_group):
+ order = [node.node_id for node in group.topological_sort()]
+ _assert_valid_topological_order(group, order)