This is an automated email from the ASF dual-hosted git repository.
pierrejeambrun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new efce3b19955 fix: eliminate duplicate JOINs in get_task_instances
endpoint (#62910)
efce3b19955 is described below
commit efce3b19955c2bfaf34f0d5e8c711434ac883c34
Author: Shubham Gondane <[email protected]>
AuthorDate: Wed Mar 18 08:27:36 2026 -0700
fix: eliminate duplicate JOINs in get_task_instances endpoint (#62910)
* fix: eliminate duplicate JOINs in get_task_instances endpoint (#62027)
Replaces joinedload with contains_eager for already-joined tables, reducing
the main query from 7 JOINs to 5.
* tests: replace query-count test with SQL-structure regression test for
#62027
* fix: mypy type annotation for dialect in benchmark script
* fix: centralise join + eager-loading in
eager_load_TI_and_TIH_for_validation for all task instance endpoints
* ci: retrigger checks
---
.../api_fastapi/common/db/task_instances.py | 34 +++++---
.../core_api/routes/public/task_instances.py | 31 +++-----
.../core_api/routes/public/test_task_instances.py | 93 ++++++++++++++++++++++
3 files changed, 129 insertions(+), 29 deletions(-)
diff --git a/airflow-core/src/airflow/api_fastapi/common/db/task_instances.py
b/airflow-core/src/airflow/api_fastapi/common/db/task_instances.py
index 423e57f2316..615e1d260e0 100644
--- a/airflow-core/src/airflow/api_fastapi/common/db/task_instances.py
+++ b/airflow-core/src/airflow/api_fastapi/common/db/task_instances.py
@@ -17,8 +17,8 @@
from __future__ import annotations
-from sqlalchemy.orm import joinedload
-from sqlalchemy.orm.interfaces import LoaderOption
+from sqlalchemy import Select
+from sqlalchemy.orm import contains_eager, joinedload
from airflow.models import Base
from airflow.models.dag_version import DagVersion
@@ -26,15 +26,31 @@ from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
-def eager_load_TI_and_TIH_for_validation(orm_model: Base | None = None) ->
tuple[LoaderOption, ...]:
- """Construct the eager loading options necessary for both
TaskInstanceResponse and TaskInstanceHistoryResponse objects."""
+def eager_load_TI_and_TIH_for_validation(
+ query: Select,
+ orm_model: Base | None = None,
+) -> Select:
+ """
+ Add JOINs and eager-loading options for TaskInstanceResponse and
TaskInstanceHistoryResponse.
+
+ Adds ``join(dag_run)`` and ``outerjoin(dag_version)`` to the query and
+ configures ``contains_eager`` so SQLAlchemy reuses those joins for
+ populating the related objects (dag_run, dag_model, dag_version, bundle).
+ This keeps the join logic centralised, avoids duplicate JOINs that would
+ otherwise occur when combining explicit joins with ``joinedload``, and
+ ensures ORDER BY / WHERE clauses on DagRun columns resolve correctly.
+
+ :param query: The SELECT statement to augment.
+ :param orm_model: The ORM model to load options for (defaults to
TaskInstance).
+ """
if orm_model is None:
orm_model = TaskInstance
- options: tuple[LoaderOption, ...] = (
- joinedload(orm_model.dag_version).joinedload(DagVersion.bundle),
- joinedload(orm_model.dag_run).options(joinedload(DagRun.dag_model)),
+ query = query.join(orm_model.dag_run).outerjoin(orm_model.dag_version)
+ query = query.options(
+
contains_eager(orm_model.dag_run).options(joinedload(DagRun.dag_model)),
+
contains_eager(orm_model.dag_version).options(joinedload(DagVersion.bundle)),
)
if orm_model is TaskInstance:
- options += (joinedload(orm_model.task_instance_note),)
- return options
+ query = query.options(joinedload(orm_model.task_instance_note))
+ return query
diff --git
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
index 10ea000ace9..97fa930b1c1 100644
---
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
+++
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
@@ -197,11 +197,10 @@ def get_mapped_task_instances(
session: SessionDep,
) -> TaskInstanceCollectionResponse:
"""Get list of mapped task instances."""
- query = (
- select(TI)
- .where(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id ==
task_id, TI.map_index >= 0)
- .join(TI.dag_run)
- .options(*eager_load_TI_and_TIH_for_validation())
+ query = eager_load_TI_and_TIH_for_validation(
+ select(TI).where(
+ TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id ==
task_id, TI.map_index >= 0
+ )
)
# 0 can mean a mapped TI that expanded to an empty list, so it is not an
automatic 404
unfiltered_total_count = get_query_count(query, session=session)
@@ -326,17 +325,15 @@ def get_task_instance_tries(
"""Get list of task instances history."""
def _query(orm_object: Base) -> Select:
- query = (
- select(orm_object)
- .where(
+ query = eager_load_TI_and_TIH_for_validation(
+ select(orm_object).where(
orm_object.dag_id == dag_id,
orm_object.run_id == dag_run_id,
orm_object.task_id == task_id,
orm_object.map_index == map_index,
- )
- .options(*eager_load_TI_and_TIH_for_validation(orm_object))
- .options(joinedload(orm_object.hitl_detail))
- )
+ ),
+ orm_model=orm_object,
+ ).options(joinedload(orm_object.hitl_detail))
return query
# Exclude TaskInstance with state UP_FOR_RETRY since they have been
recorded in TaskInstanceHistory
@@ -480,9 +477,7 @@ def get_task_instances(
and DAG runs.
"""
dag_run = None
- query = (
-
select(TI).join(TI.dag_run).outerjoin(TI.dag_version).options(*eager_load_TI_and_TIH_for_validation())
- )
+ query = eager_load_TI_and_TIH_for_validation(select(TI))
if dag_run_id != "~":
if dag_id == "~":
raise HTTPException(
@@ -619,9 +614,7 @@ def get_task_instances_batch(
TI,
).set_value([body.order_by] if body.order_by else None)
- query = (
-
select(TI).join(TI.dag_run).outerjoin(TI.dag_version).options(*eager_load_TI_and_TIH_for_validation())
- )
+ query = eager_load_TI_and_TIH_for_validation(select(TI))
task_instance_select, total_entries = paginated_select(
statement=query,
filters=[
@@ -646,8 +639,6 @@ def get_task_instances_batch(
)
task_instance_select = task_instance_select.options(
joinedload(TI.rendered_task_instance_fields),
- joinedload(TI.task_instance_note),
- joinedload(TI.dag_run).options(joinedload(DagRun.dag_model)),
)
task_instances = session.scalars(task_instance_select)
diff --git
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
index 6bf7213c9c8..ba8f06a7424 100644
---
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
+++
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
@@ -985,6 +985,36 @@ class TestGetMappedTaskInstances:
assert response.status_code == 404
assert response.json()["detail"] == "Task id nonexistent_task not
found"
+ def test_no_duplicate_joins_in_get_mapped_task_instances_query(
+ self, one_task_with_mapped_tis, test_client
+ ):
+ """Regression test for #62027: the get_mapped_task_instances endpoint
must not emit duplicate JOINs."""
+ from sqlalchemy import event
+
+ import airflow.settings
+
+ executed_statements: list[str] = []
+
+ def capture(_conn, _cursor, statement, _parameters, _context,
_executemany):
+ executed_statements.append(statement.upper())
+
+ event.listen(airflow.settings.engine, "before_cursor_execute", capture)
+ try:
+ response = test_client.get(
+
"/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
+ )
+ finally:
+ event.remove(airflow.settings.engine, "before_cursor_execute",
capture)
+
+ assert response.status_code == 200
+
+ ti_queries = [s for s in executed_statements if "FROM TASK_INSTANCE"
in s and "JOIN DAG_RUN" in s]
+ assert ti_queries, "Expected at least one query selecting from
task_instance with JOIN dag_run"
+ for q in ti_queries:
+ assert q.count("JOIN DAG_RUN") == 1, "dag_run must appear exactly
once in JOINs"
+ if "JOIN DAG_VERSION" in q:
+ assert q.count("JOIN DAG_VERSION") == 1, "dag_version must
appear exactly once in JOINs"
+
class TestGetTaskInstances(TestTaskInstanceEndpoint):
@pytest.mark.parametrize(
@@ -1514,6 +1544,41 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
== f"Invalid value for state. Valid values are {',
'.join(TaskInstanceState)}"
)
+ def test_no_duplicate_joins_in_get_task_instances_query(self, test_client,
session):
+ """Regression test for #62027: the get_task_instances endpoint must
not emit duplicate JOINs.
+
+ Combining explicit join() with joinedload() on the same tables causes
SQLAlchemy
+ to emit duplicate JOINs (dag_run twice, dag_version twice). Because
+ eager_load_TI_and_TIH_for_validation adds each join once and reuses it via
+ contains_eager, each table must appear
+ exactly once in the SQL emitted by the real endpoint.
+ """
+ from sqlalchemy import event
+
+ import airflow.settings
+
+ self.create_task_instances(session)
+
+ executed_statements: list[str] = []
+
+ def capture(_conn, _cursor, statement, _parameters, _context,
_executemany):
+ executed_statements.append(statement.upper())
+
+ event.listen(airflow.settings.engine, "before_cursor_execute", capture)
+ try:
+ response = test_client.get("/dags/~/dagRuns/~/taskInstances")
+ finally:
+ event.remove(airflow.settings.engine, "before_cursor_execute",
capture)
+
+ assert response.status_code == 200
+
+ # Find all statements that query task_instance joined with dag_run
+ ti_queries = [s for s in executed_statements if "FROM TASK_INSTANCE"
in s and "JOIN DAG_RUN" in s]
+ assert ti_queries, "Expected at least one query selecting from
task_instance with JOIN dag_run"
+ for q in ti_queries:
+ assert q.count("JOIN DAG_RUN") == 1, "dag_run must appear exactly
once in JOINs"
+ if "JOIN DAG_VERSION" in q:
+ assert q.count("JOIN DAG_VERSION") == 1, "dag_version must
appear exactly once in JOINs"
+
def test_return_TI_only_from_readable_dags(self, test_client, session):
task_instances = {
"example_python_operator": 1,
@@ -2100,6 +2165,34 @@ class
TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
assert num_entries_batch3 == ti_count
assert len(response_batch3.json()["task_instances"]) == ti_count
+ def test_no_duplicate_joins_in_get_task_instances_batch_query(self,
test_client, session):
+ """Regression test for #62027: the get_task_instances_batch endpoint
must not emit duplicate JOINs."""
+ from sqlalchemy import event
+
+ import airflow.settings
+
+ self.create_task_instances(session)
+
+ executed_statements: list[str] = []
+
+ def capture(_conn, _cursor, statement, _parameters, _context,
_executemany):
+ executed_statements.append(statement.upper())
+
+ event.listen(airflow.settings.engine, "before_cursor_execute", capture)
+ try:
+ response =
test_client.post("/dags/~/dagRuns/~/taskInstances/list", json={})
+ finally:
+ event.remove(airflow.settings.engine, "before_cursor_execute",
capture)
+
+ assert response.status_code == 200
+
+ ti_queries = [s for s in executed_statements if "FROM TASK_INSTANCE"
in s and "JOIN DAG_RUN" in s]
+ assert ti_queries, "Expected at least one query selecting from
task_instance with JOIN dag_run"
+ for q in ti_queries:
+ assert q.count("JOIN DAG_RUN") == 1, "dag_run must appear exactly
once in JOINs"
+ if "JOIN DAG_VERSION" in q:
+ assert q.count("JOIN DAG_VERSION") == 1, "dag_version must
appear exactly once in JOINs"
+
class TestGetTaskInstanceTry(TestTaskInstanceEndpoint):
def test_should_respond_200(self, test_client, session):