This is an automated email from the ASF dual-hosted git repository.
shahar1 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d2f42e62cfc Optimize TaskGroup.topological_sort for reverse-declared
Dags (#67688)
d2f42e62cfc is described below
commit d2f42e62cfcc9c16de8cd1dbe108a1e201bfb51f
Author: Shahar Epstein <[email protected]>
AuthorDate: Fri Jun 12 07:20:25 2026 +0300
Optimize TaskGroup.topological_sort for reverse-declared Dags (#67688)
---
airflow-core/newsfragments/67688.improvement.rst | 1 +
.../airflow/serialization/definitions/taskgroup.py | 67 ++++++++++++++--
airflow-core/tests/unit/utils/test_task_group.py | 42 ++++++++++
task-sdk/src/airflow/sdk/definitions/taskgroup.py | 92 ++++++++++++++++++----
.../tests/task_sdk/definitions/test_taskgroup.py | 45 +++++++++++
5 files changed, 224 insertions(+), 23 deletions(-)
diff --git a/airflow-core/newsfragments/67688.improvement.rst
b/airflow-core/newsfragments/67688.improvement.rst
new file mode 100644
index 00000000000..d2a641ac00c
--- /dev/null
+++ b/airflow-core/newsfragments/67688.improvement.rst
@@ -0,0 +1 @@
+Further optimize ``TaskGroup.topological_sort`` for reverse-declared DAGs by switching to a pass-number traversal, which avoids the O(N²) worst case of the greedy sweep on adversarial shapes (e.g., reverse-insertion chains).
diff --git a/airflow-core/src/airflow/serialization/definitions/taskgroup.py
b/airflow-core/src/airflow/serialization/definitions/taskgroup.py
index a5d8b730b05..65d59cb15f1 100644
--- a/airflow-core/src/airflow/serialization/definitions/taskgroup.py
+++ b/airflow-core/src/airflow/serialization/definitions/taskgroup.py
@@ -227,9 +227,24 @@ class SerializedTaskGroup(DAGNode):
children = self.children
if not children:
return []
+
nodes = list(children.values())
+ n = len(nodes)
id_to_idx = {nid: i for i, nid in enumerate(children)}
- projected = [self._project_child_deps(i, c, id_to_idx) for i, c in
enumerate(nodes)]
+
+ projected: list[tuple[int, ...]] = [()] * n
+ nodes_with_back_edge = 0
+ for i, child in enumerate(nodes):
+ deps = self._project_child_deps(i, child, id_to_idx)
+ if deps:
+ projected[i] = deps
+ if any(d > i for d in deps):
+ nodes_with_back_edge += 1
+
+ # The ratio catches dense back-heavy groups; a 32-node absolute cutoff
keeps
+ # padded reverse-declared runs on the fast path once sweep rescans
overtake pass-numbering.
+ if nodes_with_back_edge >= 32 or nodes_with_back_edge * 2 > n:
+ return self._sort_via_pass_numbering(nodes, projected)
return self._sweep_projection(nodes, projected)
def _project_child_deps(
@@ -242,16 +257,18 @@ class SerializedTaskGroup(DAGNode):
for edge_id in upstream_ids:
j = id_to_idx.get(edge_id)
if j is not None:
- sib_deps.add(j)
+ if j != child_idx:
+ sib_deps.add(j)
continue
- tg = self.dag.get_task(edge_id).task_group
+ edge = self.dag.get_task(edge_id)
+ tg = edge.task_group
while tg is not None:
- j = id_to_idx.get(tg.node_id)
- if j is not None:
- sib_deps.add(j)
+ anc_idx = id_to_idx.get(tg.node_id)
+ if anc_idx is not None:
+ if anc_idx != child_idx:
+ sib_deps.add(anc_idx)
break
tg = tg.parent_group
- sib_deps.discard(child_idx)
return tuple(sib_deps)
def _sweep_projection(self, nodes: list[DAGNode], projected:
list[tuple[int, ...]]) -> list[DAGNode]:
@@ -291,6 +308,42 @@ class SerializedTaskGroup(DAGNode):
pending = next_pending
return order
+ def _sort_via_pass_numbering(
+ self, nodes: list[DAGNode], projected: list[tuple[int, ...]]
+ ) -> list[DAGNode]:
+ n = len(nodes)
+ in_degree = [len(deps) for deps in projected]  # unprocessed sibling deps per child (Kahn in-degree)
+ successors: list[list[int]] = [[] for _ in range(n)]  # successors[d]: children that depend on d
+ for i, deps in enumerate(projected):
+ for d in deps:
+ successors[d].append(i)
+
+ pass_of = [0] * n  # pass_of[i]: 1-based greedy-sweep pass in which nodes[i] would be emitted
+ queue: deque[int] = deque(i for i in range(n) if in_degree[i] == 0)  # seed: children with no sibling deps
+ processed = 0
+ while queue:
+ i = queue.popleft()  # every dep of i has already been assigned its final pass_of
+ my_pass = 1  # a child with no sibling deps lands in pass 1
+ for d in projected[i]:
+ d_pass = pass_of[d]
+ if d < i:  # dep declared earlier: the sweep can emit i in the same pass as d
+ if d_pass > my_pass:
+ my_pass = d_pass
+ elif d_pass + 1 > my_pass:  # dep declared later: i is visited before d, so it slips to the next pass
+ my_pass = d_pass + 1
+ pass_of[i] = my_pass
+ processed += 1
+ for s in successors[i]:
+ in_degree[s] -= 1
+ if in_degree[s] == 0:
+ queue.append(s)
+
+ if processed != n:  # some children never reached in-degree 0, which means a cycle
+ raise ValueError(f"A cyclic dependency occurred in dag:
{self.dag_id}")
+
+ sorted_indices = sorted(range(n), key=lambda i: (pass_of[i], i))  # by pass, ties broken by insertion order
+ return [nodes[i] for i in sorted_indices]
+
def add(self, node: DAGNode) -> DAGNode:
# Set the TG first, as setting it might change the return value of
node_id!
node.task_group = weakref.proxy(self)
diff --git a/airflow-core/tests/unit/utils/test_task_group.py
b/airflow-core/tests/unit/utils/test_task_group.py
index e866bce62af..2d1458e95fb 100644
--- a/airflow-core/tests/unit/utils/test_task_group.py
+++ b/airflow-core/tests/unit/utils/test_task_group.py
@@ -1100,6 +1100,21 @@ def test_hierarchical_alphabetical_sort():
]
+def _make_padded_reverse_chain(chain_length: int, independent_count: int) ->
DAG:  # chain r0 >> r1 >> ... declared in reverse order, followed by unconnected padding tasks
+ with DAG(
+ f"padded_reverse_chain_{chain_length}_{independent_count}",
+ schedule=None,
+ start_date=DEFAULT_DATE,
+ ) as dag:
+ tasks = [EmptyOperator(task_id=f"r{chain_length - 1 - i}") for i in
range(chain_length)]  # declares r{chain_length - 1} first and r0 last
+ by_id = {task.task_id: task for task in tasks}
+ for i in range(chain_length - 1):
+ by_id[f"r{i}"] >> by_id[f"r{i + 1}"]  # downstream r{i+1} was declared before its upstream r{i}
+ for i in range(independent_count):
+ EmptyOperator(task_id=f"i{i}")  # padding tasks with no dependencies
+ return dag
+
+
def test_topological_group_dep():
logical_date = pendulum.parse("20200101")
with DAG("test_dag_edges", schedule=None, start_date=logical_date) as dag:
@@ -1163,6 +1178,33 @@ def test_topological_sort_serialized_layered():
)
+def
test_topological_sort_serialized_padded_reverse_chain_uses_pass_numbering(monkeypatch):
+ dag = _make_padded_reverse_chain(chain_length=80, independent_count=80)
+ serialized = create_scheduler_dag(dag)
+ serialized.task_group.children = {  # force the order r79..r0, then i0..i79, on the serialized group
+ **{f"r{i}": serialized.task_group.children[f"r{i}"] for i in range(79,
-1, -1)},
+ **{f"i{i}": serialized.task_group.children[f"i{i}"] for i in
range(80)},
+ }
+
+ called = {"value": False}  # mutable flag so the spy can record that it ran
+ serialized_task_group_cls = type(serialized.task_group)
+ original = serialized_task_group_cls._sort_via_pass_numbering
+
+ def spy(self, nodes, projected):
+ called["value"] = True
+ return original(self, nodes, projected)
+
+ monkeypatch.setattr(serialized_task_group_cls, "_sort_via_pass_numbering",
spy)
+
+ order = [node.node_id for node in serialized.task_group.topological_sort()]
+ position = {node_id: i for i, node_id in enumerate(order)}
+
+ assert called["value"]  # 79 of 160 children depend on a later-declared sibling: under the 50% ratio, over the 32-node cutoff
+ assert set(position) == {*(f"r{i}" for i in range(80)), *(f"i{i}" for i in
range(80))}
+ for i in range(79):
+ assert position[f"r{i}"] < position[f"r{i + 1}"]  # r{i} is upstream of r{i+1}
+
+
def test_task_group_arrow_with_setup_group():
with DAG(dag_id="setup_group_teardown_group") as dag:
with TaskGroup("group_1") as g1:
diff --git a/task-sdk/src/airflow/sdk/definitions/taskgroup.py
b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
index cc1fc5cbda6..14bb2fba319 100644
--- a/task-sdk/src/airflow/sdk/definitions/taskgroup.py
+++ b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
@@ -544,25 +544,42 @@ class TaskGroup(DAGNode):
"""
Sort children topologically — a task always comes after its upstream
dependencies.
- Projects each child's per-task upstream IDs onto sibling-level integer
indices once,
- then runs a greedy multi-pass sweep using a bytearray-backed emission
flag. Equivalent
- in emission order to the previous modified-Kahn implementation, but
moves the per-edge
- ``upstream_list`` materialization and ``parent_group`` walks out of
the sweep's inner
- loop so they happen once per call instead of once per outer-loop pass.
+ Projects per-task upstream edges onto sibling-level integer indices,
then dispatches:
+
+ - Forward-declared DAGs (few/no children declared after their
dependents): greedy
+ multi-pass sweep over the projection, O(V + E) for the common case.
+ - Reverse-declared DAGs (many children declared before their
dependents): pass-number
+ traversal, O((V + E) log V), avoids the O(N²) blowup the sweep would
hit.
+
+ Both branches produce the same emission order: level-by-legacy-pass,
ties broken by
+ children insertion order.
"""
children = self.children
if not children:
return []
+
nodes = list(children.values())
+ n = len(nodes)
id_to_idx = {nid: i for i, nid in enumerate(children)}
- projected = [self._project_child_deps(i, c, id_to_idx) for i, c in
enumerate(nodes)]
+
+ projected: list[tuple[int, ...]] = [()] * n
+ nodes_with_back_edge = 0
+ for i, child in enumerate(nodes):
+ deps = self._project_child_deps(i, child, id_to_idx)
+ if deps:
+ projected[i] = deps
+ if any(d > i for d in deps):
+ nodes_with_back_edge += 1
+
+ # The ratio catches dense back-heavy groups; a 32-node absolute cutoff
keeps
+ # padded reverse-declared runs on the fast path once sweep rescans
overtake pass-numbering.
+ if nodes_with_back_edge >= 32 or nodes_with_back_edge * 2 > n:
+ return self._sort_via_pass_numbering(nodes, projected)
return self._sweep_projection(nodes, projected)
def _project_child_deps(
self, child_idx: int, child: DAGNode, id_to_idx: dict[str, int]
) -> tuple[int, ...]:
- # Project one child's per-task upstream IDs onto sibling-level integer
indices.
- # Self-deps are filtered once at the end via ``discard`` so the inner
loop stays tight.
upstream_ids = child.upstream_task_ids
if not upstream_ids:
return ()
@@ -570,23 +587,25 @@ class TaskGroup(DAGNode):
for edge_id in upstream_ids:
j = id_to_idx.get(edge_id)
if j is not None:
- sib_deps.add(j)
+ if j != child_idx:
+ sib_deps.add(j)
continue
- tg = self.dag.get_task(edge_id).task_group
+ edge = self.dag.get_task(edge_id)
+ tg = edge.task_group
while tg is not None:
- j = id_to_idx.get(tg.node_id)
- if j is not None:
- sib_deps.add(j)
+ anc_idx = id_to_idx.get(tg.node_id)
+ if anc_idx is not None:
+ if anc_idx != child_idx:
+ sib_deps.add(anc_idx)
break
tg = tg.parent_group
- sib_deps.discard(child_idx)
return tuple(sib_deps)
def _sweep_projection(self, nodes: list[DAGNode], projected:
list[tuple[int, ...]]) -> list[DAGNode]:
# Greedy multi-pass sweep. emitted[i] == 1 iff nodes[i] has been
emitted.
# Pass 1 iterates range(n) directly; only blocked nodes are recorded
into
- # ``pending`` and re-checked in subsequent passes. Avoids paying for a
- # ``list(range(n))`` allocation on single-pass shapes (the common
case) while
+ # `pending` and re-checked in subsequent passes. Avoids paying for a
+ # `list(range(n))` allocation on single-pass shapes (the common case)
while
# still skipping already-emitted nodes on multi-pass shapes (e.g. a
diamond's
# single trailing sink).
n = len(nodes)
@@ -625,6 +644,47 @@ class TaskGroup(DAGNode):
pending = next_pending
return order
+ def _sort_via_pass_numbering(
+ self, nodes: list[DAGNode], projected: list[tuple[int, ...]]
+ ) -> list[DAGNode]:
+ # Sort by (pass_number, insertion_index). pass_number(X) is the earliest pass at which a greedy-sweep emission of X would occur:
+ #   pass(X) = max(1, max over deps d of (pass(d) if idx(d) < idx(X) else pass(d) + 1))
+ # A dep declared before X can be emitted in the same pass; a dep declared after X forces X into the next pass.
+ # Pass numbers come from a Kahn's traversal in O(V + E); the final sort adds O(V log V).
+ # Raises AirflowDagCycleException when not every child can be reached (cycle among siblings).
+ n = len(nodes)
+ in_degree = [len(deps) for deps in projected]  # unprocessed sibling deps per child (Kahn in-degree)
+ successors: list[list[int]] = [[] for _ in range(n)]  # successors[d]: children that depend on d
+ for i, deps in enumerate(projected):
+ for d in deps:
+ successors[d].append(i)
+
+ pass_of = [0] * n  # pass_of[i]: 1-based greedy-sweep pass in which nodes[i] would be emitted
+ queue: deque[int] = deque(i for i in range(n) if in_degree[i] == 0)  # seed: children with no sibling deps
+ processed = 0
+ while queue:
+ i = queue.popleft()  # every dep of i has already been assigned its final pass_of
+ my_pass = 1
+ for d in projected[i]:
+ d_pass = pass_of[d]
+ if d < i:
+ if d_pass > my_pass:
+ my_pass = d_pass
+ elif d_pass + 1 > my_pass:
+ my_pass = d_pass + 1
+ pass_of[i] = my_pass
+ processed += 1
+ for s in successors[i]:
+ in_degree[s] -= 1
+ if in_degree[s] == 0:
+ queue.append(s)
+
+ if processed != n:  # some children never reached in-degree 0, which means a cycle
+ raise AirflowDagCycleException(f"A cyclic dependency occurred in
dag: {self.dag_id}")
+
+ sorted_indices = sorted(range(n), key=lambda i: (pass_of[i], i))
+ return [nodes[i] for i in sorted_indices]
+
def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
"""
Return mapped task groups in the hierarchy.
diff --git a/task-sdk/tests/task_sdk/definitions/test_taskgroup.py
b/task-sdk/tests/task_sdk/definitions/test_taskgroup.py
index cf6d309f305..d11f8eb3632 100644
--- a/task-sdk/tests/task_sdk/definitions/test_taskgroup.py
+++ b/task-sdk/tests/task_sdk/definitions/test_taskgroup.py
@@ -1011,6 +1011,21 @@ def _make_reverse_chain(n: int) -> DAG:
return dag
+def _make_padded_reverse_chain(chain_length: int, independent_count: int) ->
DAG:  # chain r0 >> r1 >> ... declared in reverse order, followed by unconnected padding tasks
+ with DAG(
+ f"padded_reverse_chain_{chain_length}_{independent_count}",
+ schedule=None,
+ start_date=DEFAULT_DATE,
+ ) as dag:
+ tasks = [EmptyOperator(task_id=f"r{chain_length - 1 - i}") for i in
range(chain_length)]  # declares r{chain_length - 1} first and r0 last
+ by_id = {t.task_id: t for t in tasks}
+ for i in range(chain_length - 1):
+ by_id[f"r{i}"] >> by_id[f"r{i + 1}"]  # downstream r{i+1} was declared before its upstream r{i}
+ for i in range(independent_count):
+ EmptyOperator(task_id=f"i{i}")  # padding tasks with no dependencies
+ return dag
+
+
def _make_diamond(n: int) -> DAG:
with DAG(f"diamond_{n}", schedule=None, start_date=DEFAULT_DATE) as dag:
root = EmptyOperator(task_id="root")
@@ -1124,3 +1139,33 @@ def test_topological_sort_shape_correctness(shape,
builder, n):
for group in _walk_groups(dag.task_group):
order = [node.node_id for node in group.topological_sort()]
_assert_valid_topological_order(group, order)
+
+
+def test_topological_sort_reverse_declared_order_matches_sweep():  # both sort strategies must emit the identical order
+ dag = _make_reverse_chain(100)
+ group = dag.task_group
+ nodes = list(group.children.values())
+ id_to_idx = {nid: i for i, nid in enumerate(group.children)}
+ projected = [group._project_child_deps(i, child, id_to_idx) for i, child
in enumerate(nodes)]  # same per-child projection that topological_sort builds
+
+ sweep_order = [node.node_id for node in group._sweep_projection(nodes,
projected)]
+ pass_number_order = [node.node_id for node in
group._sort_via_pass_numbering(nodes, projected)]
+
+ assert pass_number_order == sweep_order
+
+
+def
test_topological_sort_padded_reverse_chain_uses_pass_numbering(monkeypatch):
+ dag = _make_padded_reverse_chain(chain_length=80, independent_count=80)
+ called = {"value": False}  # mutable flag so the spy can record that it ran
+ original = TaskGroup._sort_via_pass_numbering
+
+ def spy(self, nodes, projected):
+ called["value"] = True
+ return original(self, nodes, projected)
+
+ monkeypatch.setattr(TaskGroup, "_sort_via_pass_numbering", spy)
+
+ order = [node.node_id for node in dag.task_group.topological_sort()]
+
+ assert called["value"]  # assuming children keep declaration order, 79 of 160 have a later-declared dep: over the 32-node cutoff
+ _assert_valid_topological_order(dag.task_group, order)