This is an automated email from the ASF dual-hosted git repository.

pierrejeambrun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 827e090824a Add number of queries guard in public task instances list 
endpoints (#57645)
827e090824a is described below

commit 827e090824acaef0659a4d2a4f6b92fbc701d6e0
Author: Pierre Jeambrun <[email protected]>
AuthorDate: Tue Nov 4 11:41:56 2025 +0100

    Add number of queries guard in public task instances list endpoints (#57645)
---
 .../api_fastapi/common/db/task_instances.py        |  40 ++++++++
 .../core_api/routes/public/task_instances.py       |  17 ++--
 .../core_api/routes/public/test_task_instances.py  | 102 +++++++++++++++------
 3 files changed, 121 insertions(+), 38 deletions(-)

diff --git a/airflow-core/src/airflow/api_fastapi/common/db/task_instances.py 
b/airflow-core/src/airflow/api_fastapi/common/db/task_instances.py
new file mode 100644
index 00000000000..423e57f2316
--- /dev/null
+++ b/airflow-core/src/airflow/api_fastapi/common/db/task_instances.py
@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from sqlalchemy.orm import joinedload
+from sqlalchemy.orm.interfaces import LoaderOption
+
+from airflow.models import Base
+from airflow.models.dag_version import DagVersion
+from airflow.models.dagrun import DagRun
+from airflow.models.taskinstance import TaskInstance
+
+
+def eager_load_TI_and_TIH_for_validation(orm_model: type[Base] | None = None) -> tuple[LoaderOption, ...]:
+    """Build the joinedload options TaskInstanceResponse / TaskInstanceHistoryResponse need, for the mapped class ``orm_model`` (default: TaskInstance)."""
+    if orm_model is None:
+        orm_model = TaskInstance  # a mapped class, not an instance -- hence type[Base] in the signature
+
+    options: tuple[LoaderOption, ...] = (
+        joinedload(orm_model.dag_version).joinedload(DagVersion.bundle),  # eager-load version + bundle in the same SELECT
+        joinedload(orm_model.dag_run).options(joinedload(DagRun.dag_model)),  # eager-load run + its DagModel
+    )
+    if orm_model is TaskInstance:
+        options += (joinedload(orm_model.task_instance_note),)  # note is only eager-loaded for the TaskInstance model
+    return options
diff --git 
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py 
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
index 740dc986830..25263a0e237 100644
--- 
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
+++ 
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
@@ -33,6 +33,7 @@ from airflow.api_fastapi.common.dagbag import (
     get_latest_version_of_dag,
 )
 from airflow.api_fastapi.common.db.common import SessionDep, paginated_select
+from airflow.api_fastapi.common.db.task_instances import 
eager_load_TI_and_TIH_for_validation
 from airflow.api_fastapi.common.parameters import (
     FilterOptionEnum,
     FilterParam,
@@ -193,8 +194,7 @@ def get_mapped_task_instances(
         select(TI)
         .where(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == 
task_id, TI.map_index >= 0)
         .join(TI.dag_run)
-        .options(joinedload(TI.dag_version))
-        .options(joinedload(TI.dag_run).options(joinedload(DagRun.dag_model)))
+        .options(*eager_load_TI_and_TIH_for_validation())
     )
     # 0 can mean a mapped TI that expanded to an empty list, so it is not an 
automatic 404
     unfiltered_total_count = get_query_count(query, session=session)
@@ -324,8 +324,7 @@ def get_task_instance_tries(
                 orm_object.task_id == task_id,
                 orm_object.map_index == map_index,
             )
-            .options(joinedload(orm_object.dag_version))
-            
.options(joinedload(orm_object.dag_run).options(joinedload(DagRun.dag_model)))
+            .options(*eager_load_TI_and_TIH_for_validation(orm_object))
             .options(joinedload(orm_object.hitl_detail))
         )
         return query
@@ -467,11 +466,7 @@ def get_task_instances(
     """
     dag_run = None
     query = (
-        select(TI)
-        .join(TI.dag_run)
-        .outerjoin(TI.dag_version)
-        .options(joinedload(TI.dag_version))
-        .options(joinedload(TI.dag_run).options(joinedload(DagRun.dag_model)))
+        
select(TI).join(TI.dag_run).outerjoin(TI.dag_version).options(*eager_load_TI_and_TIH_for_validation())
     )
     if dag_run_id != "~":
         dag_run = session.scalar(select(DagRun).filter_by(run_id=dag_run_id))
@@ -597,7 +592,9 @@ def get_task_instances_batch(
         TI,
     ).set_value([body.order_by] if body.order_by else None)
 
-    query = select(TI).join(TI.dag_run).outerjoin(TI.dag_version)
+    query = (
+        
select(TI).join(TI.dag_run).outerjoin(TI.dag_version).options(*eager_load_TI_and_TIH_for_validation())
+    )
     task_instance_select, total_entries = paginated_select(
         statement=query,
         filters=[
diff --git 
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
 
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
index 6a49c9fb63b..f1c6564e242 100644
--- 
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
+++ 
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
@@ -47,6 +47,7 @@ from airflow.utils.state import DagRunState, State, 
TaskInstanceState
 from airflow.utils.types import DagRunType
 
 from tests_common.test_utils.api_fastapi import _check_task_instance_note
+from tests_common.test_utils.asserts import assert_queries_count
 from tests_common.test_utils.db import (
     clear_db_runs,
     clear_rendered_ti_fields,
@@ -762,9 +763,10 @@ class TestGetMappedTaskInstances:
         assert response.json() == {"detail": "The Dag with ID: `mapped_tis` 
was not found"}
 
     def test_should_respond_200(self, one_task_with_many_mapped_tis, 
test_client):
-        response = test_client.get(
-            
"/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
-        )
+        with assert_queries_count(4):
+            response = test_client.get(
+                
"/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
+            )
 
         assert response.status_code == 200
         assert response.json()["total_entries"] == 110
@@ -803,10 +805,11 @@ class TestGetMappedTaskInstances:
     def test_mapped_instances_order(
         self, test_client, session, params, expected_map_indexes, 
one_task_with_many_mapped_tis
     ):
-        response = test_client.get(
-            
"/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
-            params=params,
-        )
+        with assert_queries_count(4):
+            response = test_client.get(
+                
"/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
+                params=params,
+            )
 
         assert response.status_code == 200
         body = response.json()
@@ -834,10 +837,11 @@ class TestGetMappedTaskInstances:
 
         session.commit()
 
-        response = test_client.get(
-            
"/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
-            params=params,
-        )
+        with assert_queries_count(4):
+            response = test_client.get(
+                
"/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
+                params=params,
+            )
         assert response.status_code == 200
         body = response.json()
         assert body["total_entries"] == 110
@@ -935,7 +939,7 @@ class TestGetMappedTaskInstances:
 
 class TestGetTaskInstances(TestTaskInstanceEndpoint):
     @pytest.mark.parametrize(
-        "task_instances, update_extras, url, params, expected_ti",
+        "task_instances, update_extras, url, params, expected_ti, 
expected_queries_number",
         [
             pytest.param(
                 [
@@ -947,6 +951,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 "/dags/example_python_operator/dagRuns/~/taskInstances",
                 {"logical_date_lte": DEFAULT_DATETIME_1},
                 1,
+                5,
                 id="test logical date filter",
             ),
             pytest.param(
@@ -959,6 +964,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 "/dags/example_python_operator/dagRuns/~/taskInstances",
                 {"start_date_gte": DEFAULT_DATETIME_1, "start_date_lte": 
DEFAULT_DATETIME_STR_2},
                 2,
+                5,
                 id="test start date filter",
             ),
             pytest.param(
@@ -974,6 +980,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                     "start_date_lt": DEFAULT_DATETIME_STR_2,
                 },
                 1,
+                5,
                 id="test start date gt and lt filter",
             ),
             pytest.param(
@@ -986,6 +993,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 "/dags/example_python_operator/dagRuns/~/taskInstances?",
                 {"end_date_gte": DEFAULT_DATETIME_1, "end_date_lte": 
DEFAULT_DATETIME_STR_2},
                 2,
+                5,
                 id="test end date filter",
             ),
             pytest.param(
@@ -1001,6 +1009,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                     "end_date_lt": (DEFAULT_DATETIME_2 + 
dt.timedelta(hours=1)).isoformat(),
                 },
                 1,
+                5,
                 id="test end date gt and lt filter",
             ),
             pytest.param(
@@ -1013,6 +1022,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 
"/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances",
                 {"duration_gte": 100, "duration_lte": 200},
                 3,
+                7,
                 id="test duration filter",
             ),
             pytest.param(
@@ -1025,6 +1035,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 "/dags/~/dagRuns/~/taskInstances",
                 {"duration_gte": 100, "duration_lte": 200},
                 3,
+                3,
                 id="test duration filter ~",
             ),
             pytest.param(
@@ -1037,6 +1048,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 "/dags/~/dagRuns/~/taskInstances",
                 {"duration_gt": 100, "duration_lt": 200},
                 1,
+                3,
                 id="test duration gt and lt filter ~",
             ),
             pytest.param(
@@ -1050,6 +1062,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 
("/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"),
                 {"state": ["running", "queued", "none"]},
                 3,
+                7,
                 id="test state filter",
             ),
             pytest.param(
@@ -1063,6 +1076,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 
("/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"),
                 {"state": ["no_status"]},
                 1,
+                7,
                 id="test no_status state filter",
             ),
             pytest.param(
@@ -1076,6 +1090,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 
("/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"),
                 {},
                 4,
+                7,
                 id="test null states with no filter",
             ),
             pytest.param(
@@ -1084,6 +1099,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 
"/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances",
                 {"start_date_gte": DEFAULT_DATETIME_STR_1},
                 1,
+                7,
                 id="test start_date coalesce with null",
             ),
             pytest.param(
@@ -1096,6 +1112,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 
("/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"),
                 {"pool": ["test_pool_1", "test_pool_2"]},
                 2,
+                7,
                 id="test pool filter",
             ),
             pytest.param(
@@ -1108,6 +1125,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 "/dags/~/dagRuns/~/taskInstances",
                 {"pool": ["test_pool_1", "test_pool_2"]},
                 2,
+                3,
                 id="test pool filter ~",
             ),
             pytest.param(
@@ -1120,6 +1138,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 
"/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances",
                 {"queue": ["test_queue_1", "test_queue_2"]},
                 2,
+                7,
                 id="test queue filter",
             ),
             pytest.param(
@@ -1132,6 +1151,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 "/dags/~/dagRuns/~/taskInstances",
                 {"queue": ["test_queue_1", "test_queue_2"]},
                 2,
+                3,
                 id="test queue filter ~",
             ),
             pytest.param(
@@ -1144,6 +1164,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 
("/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"),
                 {"executor": ["test_exec_1", "test_exec_2"]},
                 2,
+                7,
                 id="test_executor_filter",
             ),
             pytest.param(
@@ -1156,6 +1177,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 "/dags/~/dagRuns/~/taskInstances",
                 {"executor": ["test_exec_1", "test_exec_2"]},
                 2,
+                3,
                 id="test executor filter ~",
             ),
             pytest.param(
@@ -1168,6 +1190,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 ("/dags/~/dagRuns/~/taskInstances"),
                 {"task_display_name_pattern": "task_name"},
                 2,
+                3,
                 id="test task_display_name_pattern filter",
             ),
             pytest.param(
@@ -1180,6 +1203,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 ("/dags/~/dagRuns/~/taskInstances"),
                 {"task_id": "task_match_id_2"},
                 1,
+                3,
                 id="test task_id filter",
             ),
             pytest.param(
@@ -1190,6 +1214,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 ("/dags/~/dagRuns/~/taskInstances"),
                 {"version_number": [2]},
                 2,
+                3,
                 id="test version number filter",
             ),
             pytest.param(
@@ -1201,6 +1226,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 {"version_number": [1, 2, 3]},
                 7,  # apart from the TIs in the fixture, we also get one from
                 # the create_task_instances method
+                3,
                 id="test multiple version numbers filter",
             ),
             pytest.param(
@@ -1216,6 +1242,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 
("/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"),
                 {"try_number": [0, 1]},
                 5,
+                7,
                 id="test_try_number_filter",
             ),
             pytest.param(
@@ -1234,6 +1261,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 ("/dags/~/dagRuns/~/taskInstances"),
                 {"operator": ["FirstOperator", "SecondOperator"]},
                 5,
+                3,
                 id="test operator type filter filter",
             ),
             pytest.param(
@@ -1251,6 +1279,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 ("/dags/~/dagRuns/~/taskInstances"),
                 {"map_index": [0, 1]},
                 2,
+                3,
                 id="test map_index filter",
             ),
             pytest.param(
@@ -1259,6 +1288,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 "/dags/~/dagRuns/~/taskInstances",
                 {"dag_id_pattern": "example_python_operator"},
                 9,  # Based on test failure - example_python_operator creates 
9 task instances
+                3,
                 id="test dag_id_pattern exact match",
             ),
             pytest.param(
@@ -1267,6 +1297,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 "/dags/~/dagRuns/~/taskInstances",
                 {"dag_id_pattern": "example_%"},
                 17,  # Based on test failure - both DAGs together create 17 
task instances
+                3,
                 id="test dag_id_pattern wildcard prefix",
             ),
             pytest.param(
@@ -1275,6 +1306,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 "/dags/~/dagRuns/~/taskInstances",
                 {"dag_id_pattern": "%skip%"},
                 8,  # Based on test failure - example_skip_dag creates 8 task 
instances
+                3,
                 id="test dag_id_pattern wildcard contains",
             ),
             pytest.param(
@@ -1283,13 +1315,22 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 "/dags/~/dagRuns/~/taskInstances",
                 {"dag_id_pattern": "nonexistent"},
                 0,
+                3,
                 id="test dag_id_pattern no match",
             ),
         ],
     )
     @pytest.mark.usefixtures("make_dag_with_multiple_versions")
     def test_should_respond_200(
-        self, test_client, task_instances, update_extras, url, params, 
expected_ti, session
+        self,
+        test_client,
+        task_instances,
+        update_extras,
+        url,
+        params,
+        expected_ti,
+        expected_queries_number,
+        session,
     ):
         # Special handling for dag_id_pattern tests that require multiple DAGs
         if task_instances == "dag_id_pattern_test":
@@ -1307,7 +1348,8 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
         with mock.patch("airflow.models.dag_version.DagBundlesManager") as 
dag_bundle_manager_mock:
             dag_bundle_manager_mock.return_value.view_url.return_value = 
"some_url"
             # Mock DagBundlesManager to avoid checking if dags-folder bundle 
is configured
-            response = test_client.get(url, params=params)
+            with assert_queries_count(expected_queries_number):
+                response = test_client.get(url, params=params)
         if params == {"task_id_pattern": "task_match_id"}:
             import pprint
 
@@ -1664,10 +1706,11 @@ class 
TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
             update_extras=update_extras,
             task_instances=task_instances,
         )
-        response = test_client.post(
-            "/dags/~/dagRuns/~/taskInstances/list",
-            json=payload,
-        )
+        with assert_queries_count(4):
+            response = test_client.post(
+                "/dags/~/dagRuns/~/taskInstances/list",
+                json=payload,
+            )
         body = response.json()
         assert response.status_code == 200, body
         assert expected_ti_count == body["total_entries"]
@@ -3333,9 +3376,10 @@ class TestGetTaskInstanceTries(TestTaskInstanceEndpoint):
         self.create_task_instances(
             session=session, task_instances=[{"state": State.SUCCESS}], 
with_ti_history=True
         )
-        response = test_client.get(
-            
"/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/tries"
-        )
+        with assert_queries_count(3):
+            response = test_client.get(
+                
"/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/tries"
+            )
         assert response.status_code == 200
         assert response.json()["total_entries"] == 2  # The task instance and 
its history
         assert len(response.json()["task_instances"]) == 2
@@ -3425,9 +3469,10 @@ class TestGetTaskInstanceTries(TestTaskInstanceEndpoint):
         TaskInstanceHistory.record_ti(ti, session=session)
         session.flush()
 
-        response = test_client.get(
-            
f"/dags/{ti.dag_id}/dagRuns/{ti.run_id}/taskInstances/{ti.task_id}/tries",
-        )
+        with assert_queries_count(3):
+            response = test_client.get(
+                
f"/dags/{ti.dag_id}/dagRuns/{ti.run_id}/taskInstances/{ti.task_id}/tries",
+            )
         assert response.status_code == 200
         assert response.json() == {
             "task_instances": [
@@ -3569,10 +3614,11 @@ class 
TestGetTaskInstanceTries(TestTaskInstanceEndpoint):
         # in each loop, we should get the right mapped TI back
         for map_index in (1, 2):
             # Get the info from TIHistory: try_number 1, try_number 2 is TI 
table(latest)
-            response = test_client.get(
-                
"/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"
-                f"/print_the_context/{map_index}/tries",
-            )
+            with assert_queries_count(3):
+                response = test_client.get(
+                    
"/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"
+                    f"/print_the_context/{map_index}/tries",
+                )
             assert response.status_code == 200
             assert (
                 response.json()["total_entries"] == 2

Reply via email to