This is an automated email from the ASF dual-hosted git repository.

pierrejeambrun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new efce3b19955 fix: eliminate duplicate JOINs in get_task_instances 
endpoint (#62910)
efce3b19955 is described below

commit efce3b19955c2bfaf34f0d5e8c711434ac883c34
Author: Shubham Gondane <[email protected]>
AuthorDate: Wed Mar 18 08:27:36 2026 -0700

    fix: eliminate duplicate JOINs in get_task_instances endpoint (#62910)
    
    * fix: eliminate duplicate JOINs in get_task_instances endpoint (#62027)
    
    Replaces joinedload with contains_eager for already-joined tables, reducing 
the main query from 7 JOINs to 5.
    
    * tests: replace query-count test with SQL-structure regression test for 
#62027
    
    * fix: mypy type annotation for dialect in benchmark script
    
    * fix: centralise join + eager-loading in 
eager_load_TI_and_TIH_for_validation for all task instance endpoints
    
    * ci: retrigger checks
---
 .../api_fastapi/common/db/task_instances.py        | 34 +++++---
 .../core_api/routes/public/task_instances.py       | 31 +++-----
 .../core_api/routes/public/test_task_instances.py  | 93 ++++++++++++++++++++++
 3 files changed, 129 insertions(+), 29 deletions(-)

diff --git a/airflow-core/src/airflow/api_fastapi/common/db/task_instances.py 
b/airflow-core/src/airflow/api_fastapi/common/db/task_instances.py
index 423e57f2316..615e1d260e0 100644
--- a/airflow-core/src/airflow/api_fastapi/common/db/task_instances.py
+++ b/airflow-core/src/airflow/api_fastapi/common/db/task_instances.py
@@ -17,8 +17,8 @@
 
 from __future__ import annotations
 
-from sqlalchemy.orm import joinedload
-from sqlalchemy.orm.interfaces import LoaderOption
+from sqlalchemy import Select
+from sqlalchemy.orm import contains_eager, joinedload
 
 from airflow.models import Base
 from airflow.models.dag_version import DagVersion
@@ -26,15 +26,31 @@ from airflow.models.dagrun import DagRun
 from airflow.models.taskinstance import TaskInstance
 
 
-def eager_load_TI_and_TIH_for_validation(orm_model: Base | None = None) -> 
tuple[LoaderOption, ...]:
-    """Construct the eager loading options necessary for both 
TaskInstanceResponse and TaskInstanceHistoryResponse objects."""
+def eager_load_TI_and_TIH_for_validation(
+    query: Select,
+    orm_model: Base | None = None,
+) -> Select:
+    """
+    Add JOINs and eager-loading options for TaskInstanceResponse and 
TaskInstanceHistoryResponse.
+
+    Adds ``join(dag_run)`` and ``outerjoin(dag_version)`` to the query and
+    configures ``contains_eager`` so SQLAlchemy reuses those joins to populate
+    dag_run and dag_version (dag_model and bundle use nested ``joinedload``).
+    This keeps the join logic centralised, avoids duplicate JOINs that would
+    otherwise occur when combining explicit joins with ``joinedload``, and
+    ensures ORDER BY / WHERE clauses on DagRun columns resolve correctly.
+
+    :param query: The SELECT statement to augment.
+    :param orm_model: The ORM model to load options for (defaults to 
TaskInstance).
+    """
     if orm_model is None:
         orm_model = TaskInstance
 
-    options: tuple[LoaderOption, ...] = (
-        joinedload(orm_model.dag_version).joinedload(DagVersion.bundle),
-        joinedload(orm_model.dag_run).options(joinedload(DagRun.dag_model)),
+    query = query.join(orm_model.dag_run).outerjoin(orm_model.dag_version)
+    query = query.options(
+        
contains_eager(orm_model.dag_run).options(joinedload(DagRun.dag_model)),
+        
contains_eager(orm_model.dag_version).options(joinedload(DagVersion.bundle)),
     )
     if orm_model is TaskInstance:
-        options += (joinedload(orm_model.task_instance_note),)
-    return options
+        query = query.options(joinedload(orm_model.task_instance_note))
+    return query
diff --git 
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py 
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
index 10ea000ace9..97fa930b1c1 100644
--- 
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
+++ 
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
@@ -197,11 +197,10 @@ def get_mapped_task_instances(
     session: SessionDep,
 ) -> TaskInstanceCollectionResponse:
     """Get list of mapped task instances."""
-    query = (
-        select(TI)
-        .where(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == 
task_id, TI.map_index >= 0)
-        .join(TI.dag_run)
-        .options(*eager_load_TI_and_TIH_for_validation())
+    query = eager_load_TI_and_TIH_for_validation(
+        select(TI).where(
+            TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == 
task_id, TI.map_index >= 0
+        )
     )
     # 0 can mean a mapped TI that expanded to an empty list, so it is not an 
automatic 404
     unfiltered_total_count = get_query_count(query, session=session)
@@ -326,17 +325,15 @@ def get_task_instance_tries(
     """Get list of task instances history."""
 
     def _query(orm_object: Base) -> Select:
-        query = (
-            select(orm_object)
-            .where(
+        query = eager_load_TI_and_TIH_for_validation(
+            select(orm_object).where(
                 orm_object.dag_id == dag_id,
                 orm_object.run_id == dag_run_id,
                 orm_object.task_id == task_id,
                 orm_object.map_index == map_index,
-            )
-            .options(*eager_load_TI_and_TIH_for_validation(orm_object))
-            .options(joinedload(orm_object.hitl_detail))
-        )
+            ),
+            orm_model=orm_object,
+        ).options(joinedload(orm_object.hitl_detail))
         return query
 
     # Exclude TaskInstance with state UP_FOR_RETRY since they have been 
recorded in TaskInstanceHistory
@@ -480,9 +477,7 @@ def get_task_instances(
     and DAG runs.
     """
     dag_run = None
-    query = (
-        
select(TI).join(TI.dag_run).outerjoin(TI.dag_version).options(*eager_load_TI_and_TIH_for_validation())
-    )
+    query = eager_load_TI_and_TIH_for_validation(select(TI))
     if dag_run_id != "~":
         if dag_id == "~":
             raise HTTPException(
@@ -619,9 +614,7 @@ def get_task_instances_batch(
         TI,
     ).set_value([body.order_by] if body.order_by else None)
 
-    query = (
-        
select(TI).join(TI.dag_run).outerjoin(TI.dag_version).options(*eager_load_TI_and_TIH_for_validation())
-    )
+    query = eager_load_TI_and_TIH_for_validation(select(TI))
     task_instance_select, total_entries = paginated_select(
         statement=query,
         filters=[
@@ -646,8 +639,6 @@ def get_task_instances_batch(
     )
     task_instance_select = task_instance_select.options(
         joinedload(TI.rendered_task_instance_fields),
-        joinedload(TI.task_instance_note),
-        joinedload(TI.dag_run).options(joinedload(DagRun.dag_model)),
     )
 
     task_instances = session.scalars(task_instance_select)
diff --git 
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
 
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
index 6bf7213c9c8..ba8f06a7424 100644
--- 
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
+++ 
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
@@ -985,6 +985,36 @@ class TestGetMappedTaskInstances:
         assert response.status_code == 404
         assert response.json()["detail"] == "Task id nonexistent_task not 
found"
 
+    def test_no_duplicate_joins_in_get_mapped_task_instances_query(
+        self, one_task_with_mapped_tis, test_client
+    ):
+        """Regression test for #62027: the get_mapped_task_instances endpoint 
must not emit duplicate JOINs."""
+        from sqlalchemy import event
+
+        import airflow.settings
+
+        executed_statements: list[str] = []
+
+        def capture(_conn, _cursor, statement, _parameters, _context, 
_executemany):
+            executed_statements.append(statement.upper())
+
+        event.listen(airflow.settings.engine, "before_cursor_execute", capture)
+        try:
+            response = test_client.get(
+                
"/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
+            )
+        finally:
+            event.remove(airflow.settings.engine, "before_cursor_execute", 
capture)
+
+        assert response.status_code == 200
+
+        ti_queries = [s for s in executed_statements if "FROM TASK_INSTANCE" 
in s and "JOIN DAG_RUN" in s]
+        assert ti_queries, "Expected at least one query selecting from 
task_instance with JOIN dag_run"
+        for q in ti_queries:
+            assert q.count("JOIN DAG_RUN") == 1, "dag_run must appear exactly 
once in JOINs"
+            if "JOIN DAG_VERSION" in q:
+                assert q.count("JOIN DAG_VERSION") == 1, "dag_version must 
appear exactly once in JOINs"
+
 
 class TestGetTaskInstances(TestTaskInstanceEndpoint):
     @pytest.mark.parametrize(
@@ -1514,6 +1544,41 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
             == f"Invalid value for state. Valid values are {', 
'.join(TaskInstanceState)}"
         )
 
+    def test_no_duplicate_joins_in_get_task_instances_query(self, test_client, 
session):
+        """Regression test for #62027: the get_task_instances endpoint must 
not emit duplicate JOINs.
+
+        Combining explicit join() with joinedload() on the same tables causes 
SQLAlchemy
+        to emit duplicate JOINs (dag_run twice, dag_version twice). Because the explicit joins are
+        reused through contains_eager in eager_load_TI_and_TIH_for_validation, each table must appear
+        exactly once in the SQL emitted by the real endpoint.
+        """
+        from sqlalchemy import event
+
+        import airflow.settings
+
+        self.create_task_instances(session)
+
+        executed_statements: list[str] = []
+
+        def capture(_conn, _cursor, statement, _parameters, _context, 
_executemany):
+            executed_statements.append(statement.upper())
+
+        event.listen(airflow.settings.engine, "before_cursor_execute", capture)
+        try:
+            response = test_client.get("/dags/~/dagRuns/~/taskInstances")
+        finally:
+            event.remove(airflow.settings.engine, "before_cursor_execute", 
capture)
+
+        assert response.status_code == 200
+
+        # Find all statements that query task_instance joined with dag_run
+        ti_queries = [s for s in executed_statements if "FROM TASK_INSTANCE" 
in s and "JOIN DAG_RUN" in s]
+        assert ti_queries, "Expected at least one query selecting from 
task_instance with JOIN dag_run"
+        for q in ti_queries:
+            assert q.count("JOIN DAG_RUN") == 1, "dag_run must appear exactly 
once in JOINs"
+            if "JOIN DAG_VERSION" in q:
+                assert q.count("JOIN DAG_VERSION") == 1, "dag_version must 
appear exactly once in JOINs"
+
     def test_return_TI_only_from_readable_dags(self, test_client, session):
         task_instances = {
             "example_python_operator": 1,
@@ -2100,6 +2165,34 @@ class 
TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
         assert num_entries_batch3 == ti_count
         assert len(response_batch3.json()["task_instances"]) == ti_count
 
+    def test_no_duplicate_joins_in_get_task_instances_batch_query(self, 
test_client, session):
+        """Regression test for #62027: the get_task_instances_batch endpoint 
must not emit duplicate JOINs."""
+        from sqlalchemy import event
+
+        import airflow.settings
+
+        self.create_task_instances(session)
+
+        executed_statements: list[str] = []
+
+        def capture(_conn, _cursor, statement, _parameters, _context, 
_executemany):
+            executed_statements.append(statement.upper())
+
+        event.listen(airflow.settings.engine, "before_cursor_execute", capture)
+        try:
+            response = 
test_client.post("/dags/~/dagRuns/~/taskInstances/list", json={})
+        finally:
+            event.remove(airflow.settings.engine, "before_cursor_execute", 
capture)
+
+        assert response.status_code == 200
+
+        ti_queries = [s for s in executed_statements if "FROM TASK_INSTANCE" 
in s and "JOIN DAG_RUN" in s]
+        assert ti_queries, "Expected at least one query selecting from 
task_instance with JOIN dag_run"
+        for q in ti_queries:
+            assert q.count("JOIN DAG_RUN") == 1, "dag_run must appear exactly 
once in JOINs"
+            if "JOIN DAG_VERSION" in q:
+                assert q.count("JOIN DAG_VERSION") == 1, "dag_version must 
appear exactly once in JOINs"
+
 
 class TestGetTaskInstanceTry(TestTaskInstanceEndpoint):
     def test_should_respond_200(self, test_client, session):

Reply via email to