This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f5ccebc8362 Fix N+1 query pattern in task instance states and count endpoints (#60352)
f5ccebc8362 is described below
commit f5ccebc8362e12e8283ea51d8fabbd7a5df9cf87
Author: Steve Ahn <[email protected]>
AuthorDate: Fri Apr 3 13:41:06 2026 -0700
Fix N+1 query pattern in task instance states and count endpoints (#60352)
* Push the map_index filter into the SQL queries instead of fetching all task instances and filtering them in Python
* Add unit test cases for filtering by map_index without a task group
---
.../execution_api/routes/task_instances.py | 26 ++++++++++++++--------
.../versions/head/test_task_instances.py | 10 +++++++++
2 files changed, 27 insertions(+), 9 deletions(-)
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index bb1666e5137..e1687206d55 100644
---
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -914,15 +914,13 @@ def get_task_instance_count(
query = query.where(TI.run_id.in_(run_ids))
if task_group_id:
- group_tasks = _get_group_tasks(dag_id, task_group_id, session,
dag_bag, logical_dates, run_ids)
+ group_tasks = _get_group_tasks(
+ dag_id, task_group_id, session, dag_bag, logical_dates, run_ids,
map_index
+ )
# Get unique (task_id, map_index) pairs
-
task_map_pairs = [(ti.task_id, ti.map_index) for ti in group_tasks]
- if map_index is not None:
- task_map_pairs = [(ti.task_id, ti.map_index) for ti in group_tasks
if ti.map_index == map_index]
-
if not task_map_pairs:
# If no task group tasks found, default to checking the task group
ID itself
# This matches the behavior in _get_external_task_group_task_ids
@@ -1022,15 +1020,18 @@ def get_task_instance_states(
if run_ids:
query = query.where(TI.run_id.in_(run_ids))
+ if map_index is not None:
+ query = query.where(TI.map_index == map_index)
+
results = session.scalars(query).all()
if task_group_id:
- group_tasks = _get_group_tasks(dag_id, task_group_id, session,
dag_bag, logical_dates, run_ids)
+ group_tasks = _get_group_tasks(
+ dag_id, task_group_id, session, dag_bag, logical_dates, run_ids,
map_index
+ )
results = results + group_tasks if task_ids else group_tasks
- if map_index is not None:
- results = [task for task in results if task.map_index == map_index]
[
run_id_task_state_map[task.run_id].update(
{task.task_id: task.state}
@@ -1071,7 +1072,13 @@ def _is_eligible_to_retry(state: str, try_number: int,
max_tries: int) -> bool:
def _get_group_tasks(
- dag_id: str, task_group_id: str, session: SessionDep, dag_bag: DagBagDep,
logical_dates=None, run_ids=None
+ dag_id: str,
+ task_group_id: str,
+ session: SessionDep,
+ dag_bag: DagBagDep,
+ logical_dates=None,
+ run_ids=None,
+ map_index: int | None = None,
):
# Get all tasks in the task group
dag = get_latest_version_of_dag(dag_bag, dag_id, session,
include_reason=True)
@@ -1092,6 +1099,7 @@ def _get_group_tasks(
TI.task_id.in_(task.task_id for task in task_group.iter_tasks()),
*([TI.logical_date.in_(logical_dates)] if logical_dates else []),
*([TI.run_id.in_(run_ids)] if run_ids else []),
+ *([TI.map_index == map_index] if map_index is not None else []),
)
).all()
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index a41396e4960..c6135711be9 100644
---
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -2370,6 +2370,7 @@ class TestGetCount:
("map_index", "dynamic_task_args", "task_ids", "task_group_name",
"expected_count"),
(
pytest.param(None, [1, 2, 3], None, None, 5,
id="use-default-map-index-None"),
+ pytest.param(0, [1, 2, 3], None, None, 1,
id="with-map-index-0-no-task-group"),
pytest.param(-1, [1, 2, 3], ["task1"], None, 1,
id="with-task-ids-and-map-index-(-1)"),
pytest.param(None, [1, 2, 3], None, "group1", 4,
id="with-task-group-id-and-map-index-None"),
pytest.param(0, [1, 2, 3], None, "group1", 1,
id="with-task-group-id-and-map-index-0"),
@@ -2887,6 +2888,15 @@ class TestGetTaskStates:
},
id="with-default-map-index-None",
),
+ pytest.param(
+ 0,
+ [1, 2, 3],
+ None,
+ None,
+ {"-1": State.SUCCESS, "0": State.FAILED, "1": State.SUCCESS,
"2": State.SUCCESS},
+ {"group1.add_one_0": "failed"},
+ id="with-map-index-0-no-task-group",
+ ),
pytest.param(
-1,
[1, 2, 3],