Copilot commented on code in PR #67288:
URL: https://github.com/apache/airflow/pull/67288#discussion_r3284179588
##########
task-sdk/tests/task_sdk/definitions/test_taskgroup.py:
##########
@@ -957,3 +957,147 @@ def test_getitem_missing_is_key_error(self):
with pytest.raises(KeyError):
tg["nonexistent"]
+
+
+# --- topological_sort: cross-shape correctness ---
+#
+# Mirrors the shapes covered by the benchmark gist referenced from PR #67288
+# (https://gist.github.com/shahar1/9c61dc9f34f7e77cd29cfb9d67af7ceb).
+# Wall-clock timing is intentionally not asserted here — CI runners are too
+# variable for ms thresholds to be meaningful. The gist above can be run
+# manually to gauge performance.
+
+
+def _make_chain(n: int) -> DAG:
+ with DAG(f"chain_{n}", schedule=None, start_date=DEFAULT_DATE) as dag:
+ prev = None
+ for i in range(n):
+ t = EmptyOperator(task_id=f"t{i}")
+ if prev is not None:
+ prev >> t
+ prev = t
+ return dag
+
+
+def _make_reverse_chain(n: int) -> DAG:
+ with DAG(f"reverse_chain_{n}", schedule=None, start_date=DEFAULT_DATE) as dag:
+ tasks = [EmptyOperator(task_id=f"t{n - 1 - i}") for i in range(n)]
+ by_id = {t.task_id: t for t in tasks}
+ for i in range(n - 1):
+ by_id[f"t{i}"] >> by_id[f"t{i + 1}"]
+ return dag
+
+
+def _make_diamond(n: int) -> DAG:
+ with DAG(f"diamond_{n}", schedule=None, start_date=DEFAULT_DATE) as dag:
+ root = EmptyOperator(task_id="root")
+ sink = EmptyOperator(task_id="sink")
+ middles = [EmptyOperator(task_id=f"m{i}") for i in range(max(n - 2, 1))]
+ root >> middles >> sink
+ return dag
+
+
+def _make_independent(n: int) -> DAG:
+ with DAG(f"independent_{n}", schedule=None, start_date=DEFAULT_DATE) as dag:
+ for i in range(n):
+ EmptyOperator(task_id=f"t{i}")
+ return dag
+
+
+def _make_layered(n: int, layers: int = 4) -> DAG:
+ per_layer = max(n // layers, 1)
+ with DAG(f"layered_{n}", schedule=None, start_date=DEFAULT_DATE) as dag:
+ prev_layer: list[EmptyOperator] = []
+ for layer in range(layers):
+ cur = [EmptyOperator(task_id=f"L{layer}_t{i}") for i in range(per_layer)]
+ if prev_layer:
+ for upstream in prev_layer:
+ upstream >> cur
+ prev_layer = cur
+ return dag
+
+
+def _make_nested_groups(n: int, depth: int = 3) -> DAG:
+ per_group = max(n // (depth * depth), 1)
+ with DAG(f"nested_{n}", schedule=None, start_date=DEFAULT_DATE) as dag:
+
+ def build_group(level: int, idx: int) -> TaskGroup:
+ with TaskGroup(group_id=f"g{level}_{idx}") as tg:
+ prev = None
+ for i in range(per_group):
+ t = EmptyOperator(task_id=f"l{level}_g{idx}_t{i}")
+ if prev is not None:
+ prev >> t
+ prev = t
+ if level + 1 < depth:
+ inner_prev = None
+ for j in range(depth):
+ inner = build_group(level + 1, j)
+ if inner_prev is not None:
+ inner_prev >> inner
+ inner_prev = inner
+ return tg
+
+ top_prev = None
+ for j in range(depth):
+ top = build_group(0, j)
+ if top_prev is not None:
+ top_prev >> top
+ top_prev = top
+ return dag
+
+
+def _project_sibling(group: TaskGroup, upstream_task_id: str, child_id: str) -> str | None:
+ """Mirror of TaskGroup._project_child_deps' projection, returning a string ID."""
+ children = group.children
+ if upstream_task_id in children:
+ return upstream_task_id if upstream_task_id != child_id else None
+ upstream = group.dag.get_task(upstream_task_id)
+ tg = upstream.task_group
+ while tg is not None:
+ if tg.node_id in children:
+ return tg.node_id if tg.node_id != child_id else None
+ tg = tg.parent_group
+ return None
+
+
+def _walk_groups(tg: TaskGroup):
+ yield tg
+ for child in tg.children.values():
+ if isinstance(child, TaskGroup):
+ yield from _walk_groups(child)
+
+
+def _assert_valid_topological_order(group: TaskGroup, order: list[str]) -> None:
+ position = {node_id: i for i, node_id in enumerate(order)}
+ assert set(position) == set(group.children), (
+ f"topological_sort output {order!r} does not cover children of {group.node_id!r}"
+ )
+ for child_id, child in group.children.items():
+ for upstream_id in child.upstream_task_ids:
+ sib = _project_sibling(group, upstream_id, child_id)
+ if sib is None:
+ continue
+ assert position[sib] < position[child_id], (
+ f"In group {group.node_id!r}: sibling {sib!r} must precede {child_id!r}, got order {order!r}"
+ )
+
+
+@pytest.mark.parametrize(
+ ("shape", "builder"),
+ [
+ ("chain", _make_chain),
+ ("rev-chain", _make_reverse_chain),
+ ("diamond", _make_diamond),
+ ("independent", _make_independent),
+ ("layered", _make_layered),
+ ("nested", _make_nested_groups),
+ ],
+)
+@pytest.mark.parametrize("n", [20, 100, 500])
+def test_topological_sort_shape_correctness(shape, builder, n):
+ """topological_sort emits a valid order for every nested group across DAG shapes."""
+ dag = builder(n)
+ for group in _walk_groups(dag.task_group):
+ order = [node.node_id for node in group.topological_sort()]
+ _assert_valid_topological_order(group, order)
Review Comment:
The new shape-correctness test generates very large DAGs (e.g. for `n=500`,
`_make_layered` creates 500 tasks with ~46k edges, and `_make_nested_groups`
creates ~2k+ tasks). This significantly increases unit-test runtime and can
slow CI; consider reducing the parameter grid (smaller max `n`, fewer shapes at
large `n`, or splitting a single large stress case into a separate/marked test).
##########
airflow-core/newsfragments/67288.improvement.rst:
##########
@@ -0,0 +1 @@
+Speed up ``TaskGroup.topological_sort`` across Dag shapes (chain, diamond, layered, reverse-chain); benchmarks show roughly 2-8x faster on large groups.
Review Comment:
Use “DAG” instead of “Dag” in the newsfragment for consistency with project
terminology.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
users@infra.apache.org