This is an automated email from the ASF dual-hosted git repository.
pierrejeambrun pushed a commit to branch v3-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v3-1-test by this push:
new df5d1c7f650 Add number of queries guard in public task instances list
endpoints (#57645) (#57794)
df5d1c7f650 is described below
commit df5d1c7f650e0b8813b09dbeb8b695bcf61bb48b
Author: Pierre Jeambrun <[email protected]>
AuthorDate: Tue Nov 4 14:05:29 2025 +0100
Add number of queries guard in public task instances list endpoints
(#57645) (#57794)
(cherry picked from commit 827e090824acaef0659a4d2a4f6b92fbc701d6e0)
---
.../api_fastapi/common/db/task_instances.py | 40 ++++++++++
.../core_api/routes/public/task_instances.py | 17 ++--
.../core_api/routes/public/test_task_instances.py | 91 ++++++++++++++++------
3 files changed, 113 insertions(+), 35 deletions(-)
diff --git a/airflow-core/src/airflow/api_fastapi/common/db/task_instances.py
b/airflow-core/src/airflow/api_fastapi/common/db/task_instances.py
new file mode 100644
index 00000000000..423e57f2316
--- /dev/null
+++ b/airflow-core/src/airflow/api_fastapi/common/db/task_instances.py
@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from sqlalchemy.orm import joinedload
+from sqlalchemy.orm.interfaces import LoaderOption
+
+from airflow.models import Base
+from airflow.models.dag_version import DagVersion
+from airflow.models.dagrun import DagRun
+from airflow.models.taskinstance import TaskInstance
+
+
+def eager_load_TI_and_TIH_for_validation(orm_model: type[Base] | None = None) -> tuple[LoaderOption, ...]:
+    """Return eager-load options for TaskInstanceResponse / TaskInstanceHistoryResponse; defaults to TaskInstance."""
+    if orm_model is None:
+        orm_model = TaskInstance  # a mapped class is expected (hence type[Base]), not an instance
+
+    options: tuple[LoaderOption, ...] = (  # relationships loaded for both TI and TIH models
+        joinedload(orm_model.dag_version).joinedload(DagVersion.bundle),
+        joinedload(orm_model.dag_run).options(joinedload(DagRun.dag_model)),
+    )
+    if orm_model is TaskInstance:
+        options += (joinedload(orm_model.task_instance_note),)  # note is only eager-loaded for TaskInstance
+    return options
diff --git
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
index 7169b34145a..68764b5456a 100644
---
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
+++
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
@@ -33,6 +33,7 @@ from airflow.api_fastapi.common.dagbag import (
get_latest_version_of_dag,
)
from airflow.api_fastapi.common.db.common import SessionDep, paginated_select
+from airflow.api_fastapi.common.db.task_instances import
eager_load_TI_and_TIH_for_validation
from airflow.api_fastapi.common.parameters import (
FilterOptionEnum,
FilterParam,
@@ -189,8 +190,7 @@ def get_mapped_task_instances(
select(TI)
.where(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id ==
task_id, TI.map_index >= 0)
.join(TI.dag_run)
- .options(joinedload(TI.dag_version))
- .options(joinedload(TI.dag_run).options(joinedload(DagRun.dag_model)))
+ .options(*eager_load_TI_and_TIH_for_validation())
)
# 0 can mean a mapped TI that expanded to an empty list, so it is not an
automatic 404
unfiltered_total_count = get_query_count(query, session=session)
@@ -317,8 +317,7 @@ def get_task_instance_tries(
orm_object.task_id == task_id,
orm_object.map_index == map_index,
)
- .options(joinedload(orm_object.dag_version))
-
.options(joinedload(orm_object.dag_run).options(joinedload(DagRun.dag_model)))
+ .options(*eager_load_TI_and_TIH_for_validation(orm_object))
)
return query
@@ -458,11 +457,7 @@ def get_task_instances(
"""
dag_run = None
query = (
- select(TI)
- .join(TI.dag_run)
- .outerjoin(TI.dag_version)
- .options(joinedload(TI.dag_version))
- .options(joinedload(TI.dag_run).options(joinedload(DagRun.dag_model)))
+
select(TI).join(TI.dag_run).outerjoin(TI.dag_version).options(*eager_load_TI_and_TIH_for_validation())
)
if dag_run_id != "~":
dag_run = session.scalar(select(DagRun).filter_by(run_id=dag_run_id))
@@ -587,7 +582,9 @@ def get_task_instances_batch(
TI,
).set_value([body.order_by] if body.order_by else None)
- query = select(TI)
+ query = (
+
select(TI).join(TI.dag_run).outerjoin(TI.dag_version).options(*eager_load_TI_and_TIH_for_validation())
+ )
task_instance_select, total_entries = paginated_select(
statement=query,
filters=[
diff --git
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
index fd8662aea4f..6f743cd7d4b 100644
---
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
+++
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
@@ -46,6 +46,7 @@ from airflow.utils.state import DagRunState, State,
TaskInstanceState
from airflow.utils.types import DagRunType
from tests_common.test_utils.api_fastapi import _check_task_instance_note
+from tests_common.test_utils.asserts import assert_queries_count
from tests_common.test_utils.db import (
clear_db_runs,
clear_rendered_ti_fields,
@@ -758,9 +759,10 @@ class TestGetMappedTaskInstances:
assert response.json() == {"detail": "The Dag with ID: `mapped_tis`
was not found"}
def test_should_respond_200(self, one_task_with_many_mapped_tis,
test_client):
- response = test_client.get(
-
"/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
- )
+ with assert_queries_count(4):
+ response = test_client.get(
+
"/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
+ )
assert response.status_code == 200
assert response.json()["total_entries"] == 110
@@ -799,10 +801,11 @@ class TestGetMappedTaskInstances:
def test_mapped_instances_order(
self, test_client, session, params, expected_map_indexes,
one_task_with_many_mapped_tis
):
- response = test_client.get(
-
"/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
- params=params,
- )
+ with assert_queries_count(4):
+ response = test_client.get(
+
"/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
+ params=params,
+ )
assert response.status_code == 200
body = response.json()
@@ -830,10 +833,11 @@ class TestGetMappedTaskInstances:
session.commit()
- response = test_client.get(
-
"/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
- params=params,
- )
+ with assert_queries_count(4):
+ response = test_client.get(
+
"/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
+ params=params,
+ )
assert response.status_code == 200
body = response.json()
assert body["total_entries"] == 110
@@ -931,7 +935,7 @@ class TestGetMappedTaskInstances:
class TestGetTaskInstances(TestTaskInstanceEndpoint):
@pytest.mark.parametrize(
- "task_instances, update_extras, url, params, expected_ti",
+ "task_instances, update_extras, url, params, expected_ti,
expected_queries_number",
[
pytest.param(
[
@@ -943,6 +947,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
"/dags/example_python_operator/dagRuns/~/taskInstances",
{"logical_date_lte": DEFAULT_DATETIME_1},
1,
+ 5,
id="test logical date filter",
),
pytest.param(
@@ -955,6 +960,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
"/dags/example_python_operator/dagRuns/~/taskInstances",
{"start_date_gte": DEFAULT_DATETIME_1, "start_date_lte":
DEFAULT_DATETIME_STR_2},
2,
+ 5,
id="test start date filter",
),
pytest.param(
@@ -970,6 +976,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
"start_date_lt": DEFAULT_DATETIME_STR_2,
},
1,
+ 5,
id="test start date gt and lt filter",
),
pytest.param(
@@ -982,6 +989,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
"/dags/example_python_operator/dagRuns/~/taskInstances?",
{"end_date_gte": DEFAULT_DATETIME_1, "end_date_lte":
DEFAULT_DATETIME_STR_2},
2,
+ 5,
id="test end date filter",
),
pytest.param(
@@ -997,6 +1005,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
"end_date_lt": (DEFAULT_DATETIME_2 +
dt.timedelta(hours=1)).isoformat(),
},
1,
+ 5,
id="test end date gt and lt filter",
),
pytest.param(
@@ -1009,6 +1018,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
"/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances",
{"duration_gte": 100, "duration_lte": 200},
3,
+ 7,
id="test duration filter",
),
pytest.param(
@@ -1021,6 +1031,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
"/dags/~/dagRuns/~/taskInstances",
{"duration_gte": 100, "duration_lte": 200},
3,
+ 3,
id="test duration filter ~",
),
pytest.param(
@@ -1033,6 +1044,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
"/dags/~/dagRuns/~/taskInstances",
{"duration_gt": 100, "duration_lt": 200},
1,
+ 3,
id="test duration gt and lt filter ~",
),
pytest.param(
@@ -1046,6 +1058,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
("/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"),
{"state": ["running", "queued", "none"]},
3,
+ 7,
id="test state filter",
),
pytest.param(
@@ -1059,6 +1072,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
("/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"),
{"state": ["no_status"]},
1,
+ 7,
id="test no_status state filter",
),
pytest.param(
@@ -1072,6 +1086,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
("/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"),
{},
4,
+ 7,
id="test null states with no filter",
),
pytest.param(
@@ -1080,6 +1095,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
"/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances",
{"start_date_gte": DEFAULT_DATETIME_STR_1},
1,
+ 7,
id="test start_date coalesce with null",
),
pytest.param(
@@ -1092,6 +1108,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
("/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"),
{"pool": ["test_pool_1", "test_pool_2"]},
2,
+ 7,
id="test pool filter",
),
pytest.param(
@@ -1104,6 +1121,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
"/dags/~/dagRuns/~/taskInstances",
{"pool": ["test_pool_1", "test_pool_2"]},
2,
+ 3,
id="test pool filter ~",
),
pytest.param(
@@ -1116,6 +1134,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
"/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances",
{"queue": ["test_queue_1", "test_queue_2"]},
2,
+ 7,
id="test queue filter",
),
pytest.param(
@@ -1128,6 +1147,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
"/dags/~/dagRuns/~/taskInstances",
{"queue": ["test_queue_1", "test_queue_2"]},
2,
+ 3,
id="test queue filter ~",
),
pytest.param(
@@ -1140,6 +1160,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
("/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"),
{"executor": ["test_exec_1", "test_exec_2"]},
2,
+ 7,
id="test_executor_filter",
),
pytest.param(
@@ -1152,6 +1173,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
"/dags/~/dagRuns/~/taskInstances",
{"executor": ["test_exec_1", "test_exec_2"]},
2,
+ 3,
id="test executor filter ~",
),
pytest.param(
@@ -1164,6 +1186,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
("/dags/~/dagRuns/~/taskInstances"),
{"task_display_name_pattern": "task_name"},
2,
+ 3,
id="test task_display_name_pattern filter",
),
pytest.param(
@@ -1176,6 +1199,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
("/dags/~/dagRuns/~/taskInstances"),
{"task_id": "task_match_id_2"},
1,
+ 3,
id="test task_id filter",
),
pytest.param(
@@ -1186,6 +1210,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
("/dags/~/dagRuns/~/taskInstances"),
{"version_number": [2]},
2,
+ 3,
id="test version number filter",
),
pytest.param(
@@ -1197,6 +1222,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
{"version_number": [1, 2, 3]},
7, # apart from the TIs in the fixture, we also get one from
# the create_task_instances method
+ 3,
id="test multiple version numbers filter",
),
pytest.param(
@@ -1212,6 +1238,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
("/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"),
{"try_number": [0, 1]},
5,
+ 7,
id="test_try_number_filter",
),
pytest.param(
@@ -1230,6 +1257,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
("/dags/~/dagRuns/~/taskInstances"),
{"operator": ["FirstOperator", "SecondOperator"]},
5,
+ 3,
id="test operator type filter filter",
),
pytest.param(
@@ -1247,13 +1275,22 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
("/dags/~/dagRuns/~/taskInstances"),
{"map_index": [0, 1]},
2,
+ 3,
id="test map_index filter",
),
],
)
@pytest.mark.usefixtures("make_dag_with_multiple_versions")
def test_should_respond_200(
- self, test_client, task_instances, update_extras, url, params,
expected_ti, session
+ self,
+ test_client,
+ task_instances,
+ update_extras,
+ url,
+ params,
+ expected_ti,
+ expected_queries_number,
+ session,
):
self.create_task_instances(
session,
@@ -1262,7 +1299,8 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
)
with mock.patch("airflow.models.dag_version.DagBundlesManager"):
# Mock DagBundlesManager to avoid checking if dags-folder bundle
is configured
- response = test_client.get(url, params=params)
+ with assert_queries_count(expected_queries_number):
+ response = test_client.get(url, params=params)
if params == {"task_id_pattern": "task_match_id"}:
import pprint
@@ -1619,10 +1657,11 @@ class
TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
update_extras=update_extras,
task_instances=task_instances,
)
- response = test_client.post(
- "/dags/~/dagRuns/~/taskInstances/list",
- json=payload,
- )
+ with assert_queries_count(4):
+ response = test_client.post(
+ "/dags/~/dagRuns/~/taskInstances/list",
+ json=payload,
+ )
body = response.json()
assert response.status_code == 200, body
assert expected_ti_count == body["total_entries"]
@@ -3211,9 +3250,10 @@ class TestGetTaskInstanceTries(TestTaskInstanceEndpoint):
self.create_task_instances(
session=session, task_instances=[{"state": State.SUCCESS}],
with_ti_history=True
)
- response = test_client.get(
-
"/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/tries"
- )
+ with assert_queries_count(3):
+ response = test_client.get(
+
"/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/tries"
+ )
assert response.status_code == 200
assert response.json()["total_entries"] == 2 # The task instance and
its history
assert len(response.json()["task_instances"]) == 2
@@ -3369,10 +3409,11 @@ class
TestGetTaskInstanceTries(TestTaskInstanceEndpoint):
# in each loop, we should get the right mapped TI back
for map_index in (1, 2):
# Get the info from TIHistory: try_number 1, try_number 2 is TI
table(latest)
- response = test_client.get(
-
"/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"
- f"/print_the_context/{map_index}/tries",
- )
+ with assert_queries_count(3):
+ response = test_client.get(
+
"/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"
+ f"/print_the_context/{map_index}/tries",
+ )
assert response.status_code == 200
assert (
response.json()["total_entries"] == 2