This is an automated email from the ASF dual-hosted git repository.
pierrejeambrun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 827e090824a Add number of queries guard in public task instances list
endpoints (#57645)
827e090824a is described below
commit 827e090824acaef0659a4d2a4f6b92fbc701d6e0
Author: Pierre Jeambrun <[email protected]>
AuthorDate: Tue Nov 4 11:41:56 2025 +0100
Add number of queries guard in public task instances list endpoints (#57645)
---
.../api_fastapi/common/db/task_instances.py | 40 ++++++++
.../core_api/routes/public/task_instances.py | 17 ++--
.../core_api/routes/public/test_task_instances.py | 102 +++++++++++++++------
3 files changed, 121 insertions(+), 38 deletions(-)
diff --git a/airflow-core/src/airflow/api_fastapi/common/db/task_instances.py
b/airflow-core/src/airflow/api_fastapi/common/db/task_instances.py
new file mode 100644
index 00000000000..423e57f2316
--- /dev/null
+++ b/airflow-core/src/airflow/api_fastapi/common/db/task_instances.py
@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from sqlalchemy.orm import joinedload
+from sqlalchemy.orm.interfaces import LoaderOption
+
+from airflow.models import Base
+from airflow.models.dag_version import DagVersion
+from airflow.models.dagrun import DagRun
+from airflow.models.taskinstance import TaskInstance
+
+
+def eager_load_TI_and_TIH_for_validation(
+    orm_model: type[Base] | None = None,
+) -> tuple[LoaderOption, ...]:
+    """
+    Construct the eager loading options for TaskInstance and TaskInstanceHistory queries.
+
+    These are the options necessary to build both ``TaskInstanceResponse`` and
+    ``TaskInstanceHistoryResponse`` objects. The returned options join-load
+    ``dag_version`` together with its ``bundle``, and ``dag_run`` together with
+    its ``dag_model``. When the model is ``TaskInstance`` the
+    ``task_instance_note`` relationship is join-loaded as well.
+
+    :param orm_model: The ORM model *class* (not an instance) whose relationships
+        should be eagerly loaded. Defaults to ``TaskInstance`` when ``None``.
+    :return: A tuple of loader options, meant to be unpacked into
+        ``.options(*...)`` on a select statement.
+    """
+    if orm_model is None:
+        orm_model = TaskInstance
+
+    options: tuple[LoaderOption, ...] = (
+        joinedload(orm_model.dag_version).joinedload(DagVersion.bundle),
+        joinedload(orm_model.dag_run).options(joinedload(DagRun.dag_model)),
+    )
+    if orm_model is TaskInstance:
+        # The note relationship is only requested for TaskInstance itself.
+        # NOTE(review): presumably other models passed here (e.g. the history
+        # model) do not expose ``task_instance_note`` -- confirm.
+        options += (joinedload(orm_model.task_instance_note),)
+    return options
diff --git
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
index 740dc986830..25263a0e237 100644
---
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
+++
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
@@ -33,6 +33,7 @@ from airflow.api_fastapi.common.dagbag import (
get_latest_version_of_dag,
)
from airflow.api_fastapi.common.db.common import SessionDep, paginated_select
+from airflow.api_fastapi.common.db.task_instances import
eager_load_TI_and_TIH_for_validation
from airflow.api_fastapi.common.parameters import (
FilterOptionEnum,
FilterParam,
@@ -193,8 +194,7 @@ def get_mapped_task_instances(
select(TI)
.where(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id ==
task_id, TI.map_index >= 0)
.join(TI.dag_run)
- .options(joinedload(TI.dag_version))
- .options(joinedload(TI.dag_run).options(joinedload(DagRun.dag_model)))
+ .options(*eager_load_TI_and_TIH_for_validation())
)
# 0 can mean a mapped TI that expanded to an empty list, so it is not an
automatic 404
unfiltered_total_count = get_query_count(query, session=session)
@@ -324,8 +324,7 @@ def get_task_instance_tries(
orm_object.task_id == task_id,
orm_object.map_index == map_index,
)
- .options(joinedload(orm_object.dag_version))
-
.options(joinedload(orm_object.dag_run).options(joinedload(DagRun.dag_model)))
+ .options(*eager_load_TI_and_TIH_for_validation(orm_object))
.options(joinedload(orm_object.hitl_detail))
)
return query
@@ -467,11 +466,7 @@ def get_task_instances(
"""
dag_run = None
query = (
- select(TI)
- .join(TI.dag_run)
- .outerjoin(TI.dag_version)
- .options(joinedload(TI.dag_version))
- .options(joinedload(TI.dag_run).options(joinedload(DagRun.dag_model)))
+
select(TI).join(TI.dag_run).outerjoin(TI.dag_version).options(*eager_load_TI_and_TIH_for_validation())
)
if dag_run_id != "~":
dag_run = session.scalar(select(DagRun).filter_by(run_id=dag_run_id))
@@ -597,7 +592,9 @@ def get_task_instances_batch(
TI,
).set_value([body.order_by] if body.order_by else None)
- query = select(TI).join(TI.dag_run).outerjoin(TI.dag_version)
+ query = (
+
select(TI).join(TI.dag_run).outerjoin(TI.dag_version).options(*eager_load_TI_and_TIH_for_validation())
+ )
task_instance_select, total_entries = paginated_select(
statement=query,
filters=[
diff --git
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
index 6a49c9fb63b..f1c6564e242 100644
---
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
+++
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
@@ -47,6 +47,7 @@ from airflow.utils.state import DagRunState, State,
TaskInstanceState
from airflow.utils.types import DagRunType
from tests_common.test_utils.api_fastapi import _check_task_instance_note
+from tests_common.test_utils.asserts import assert_queries_count
from tests_common.test_utils.db import (
clear_db_runs,
clear_rendered_ti_fields,
@@ -762,9 +763,10 @@ class TestGetMappedTaskInstances:
assert response.json() == {"detail": "The Dag with ID: `mapped_tis`
was not found"}
def test_should_respond_200(self, one_task_with_many_mapped_tis,
test_client):
- response = test_client.get(
-
"/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
- )
+ with assert_queries_count(4):
+ response = test_client.get(
+
"/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
+ )
assert response.status_code == 200
assert response.json()["total_entries"] == 110
@@ -803,10 +805,11 @@ class TestGetMappedTaskInstances:
def test_mapped_instances_order(
self, test_client, session, params, expected_map_indexes,
one_task_with_many_mapped_tis
):
- response = test_client.get(
-
"/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
- params=params,
- )
+ with assert_queries_count(4):
+ response = test_client.get(
+
"/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
+ params=params,
+ )
assert response.status_code == 200
body = response.json()
@@ -834,10 +837,11 @@ class TestGetMappedTaskInstances:
session.commit()
- response = test_client.get(
-
"/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
- params=params,
- )
+ with assert_queries_count(4):
+ response = test_client.get(
+
"/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
+ params=params,
+ )
assert response.status_code == 200
body = response.json()
assert body["total_entries"] == 110
@@ -935,7 +939,7 @@ class TestGetMappedTaskInstances:
class TestGetTaskInstances(TestTaskInstanceEndpoint):
@pytest.mark.parametrize(
- "task_instances, update_extras, url, params, expected_ti",
+ "task_instances, update_extras, url, params, expected_ti,
expected_queries_number",
[
pytest.param(
[
@@ -947,6 +951,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
"/dags/example_python_operator/dagRuns/~/taskInstances",
{"logical_date_lte": DEFAULT_DATETIME_1},
1,
+ 5,
id="test logical date filter",
),
pytest.param(
@@ -959,6 +964,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
"/dags/example_python_operator/dagRuns/~/taskInstances",
{"start_date_gte": DEFAULT_DATETIME_1, "start_date_lte":
DEFAULT_DATETIME_STR_2},
2,
+ 5,
id="test start date filter",
),
pytest.param(
@@ -974,6 +980,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
"start_date_lt": DEFAULT_DATETIME_STR_2,
},
1,
+ 5,
id="test start date gt and lt filter",
),
pytest.param(
@@ -986,6 +993,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
"/dags/example_python_operator/dagRuns/~/taskInstances?",
{"end_date_gte": DEFAULT_DATETIME_1, "end_date_lte":
DEFAULT_DATETIME_STR_2},
2,
+ 5,
id="test end date filter",
),
pytest.param(
@@ -1001,6 +1009,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
"end_date_lt": (DEFAULT_DATETIME_2 +
dt.timedelta(hours=1)).isoformat(),
},
1,
+ 5,
id="test end date gt and lt filter",
),
pytest.param(
@@ -1013,6 +1022,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
"/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances",
{"duration_gte": 100, "duration_lte": 200},
3,
+ 7,
id="test duration filter",
),
pytest.param(
@@ -1025,6 +1035,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
"/dags/~/dagRuns/~/taskInstances",
{"duration_gte": 100, "duration_lte": 200},
3,
+ 3,
id="test duration filter ~",
),
pytest.param(
@@ -1037,6 +1048,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
"/dags/~/dagRuns/~/taskInstances",
{"duration_gt": 100, "duration_lt": 200},
1,
+ 3,
id="test duration gt and lt filter ~",
),
pytest.param(
@@ -1050,6 +1062,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
("/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"),
{"state": ["running", "queued", "none"]},
3,
+ 7,
id="test state filter",
),
pytest.param(
@@ -1063,6 +1076,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
("/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"),
{"state": ["no_status"]},
1,
+ 7,
id="test no_status state filter",
),
pytest.param(
@@ -1076,6 +1090,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
("/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"),
{},
4,
+ 7,
id="test null states with no filter",
),
pytest.param(
@@ -1084,6 +1099,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
"/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances",
{"start_date_gte": DEFAULT_DATETIME_STR_1},
1,
+ 7,
id="test start_date coalesce with null",
),
pytest.param(
@@ -1096,6 +1112,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
("/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"),
{"pool": ["test_pool_1", "test_pool_2"]},
2,
+ 7,
id="test pool filter",
),
pytest.param(
@@ -1108,6 +1125,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
"/dags/~/dagRuns/~/taskInstances",
{"pool": ["test_pool_1", "test_pool_2"]},
2,
+ 3,
id="test pool filter ~",
),
pytest.param(
@@ -1120,6 +1138,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
"/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances",
{"queue": ["test_queue_1", "test_queue_2"]},
2,
+ 7,
id="test queue filter",
),
pytest.param(
@@ -1132,6 +1151,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
"/dags/~/dagRuns/~/taskInstances",
{"queue": ["test_queue_1", "test_queue_2"]},
2,
+ 3,
id="test queue filter ~",
),
pytest.param(
@@ -1144,6 +1164,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
("/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"),
{"executor": ["test_exec_1", "test_exec_2"]},
2,
+ 7,
id="test_executor_filter",
),
pytest.param(
@@ -1156,6 +1177,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
"/dags/~/dagRuns/~/taskInstances",
{"executor": ["test_exec_1", "test_exec_2"]},
2,
+ 3,
id="test executor filter ~",
),
pytest.param(
@@ -1168,6 +1190,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
("/dags/~/dagRuns/~/taskInstances"),
{"task_display_name_pattern": "task_name"},
2,
+ 3,
id="test task_display_name_pattern filter",
),
pytest.param(
@@ -1180,6 +1203,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
("/dags/~/dagRuns/~/taskInstances"),
{"task_id": "task_match_id_2"},
1,
+ 3,
id="test task_id filter",
),
pytest.param(
@@ -1190,6 +1214,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
("/dags/~/dagRuns/~/taskInstances"),
{"version_number": [2]},
2,
+ 3,
id="test version number filter",
),
pytest.param(
@@ -1201,6 +1226,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
{"version_number": [1, 2, 3]},
7, # apart from the TIs in the fixture, we also get one from
# the create_task_instances method
+ 3,
id="test multiple version numbers filter",
),
pytest.param(
@@ -1216,6 +1242,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
("/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"),
{"try_number": [0, 1]},
5,
+ 7,
id="test_try_number_filter",
),
pytest.param(
@@ -1234,6 +1261,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
("/dags/~/dagRuns/~/taskInstances"),
{"operator": ["FirstOperator", "SecondOperator"]},
5,
+ 3,
id="test operator type filter filter",
),
pytest.param(
@@ -1251,6 +1279,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
("/dags/~/dagRuns/~/taskInstances"),
{"map_index": [0, 1]},
2,
+ 3,
id="test map_index filter",
),
pytest.param(
@@ -1259,6 +1288,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
"/dags/~/dagRuns/~/taskInstances",
{"dag_id_pattern": "example_python_operator"},
9, # Based on test failure - example_python_operator creates
9 task instances
+ 3,
id="test dag_id_pattern exact match",
),
pytest.param(
@@ -1267,6 +1297,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
"/dags/~/dagRuns/~/taskInstances",
{"dag_id_pattern": "example_%"},
17, # Based on test failure - both DAGs together create 17
task instances
+ 3,
id="test dag_id_pattern wildcard prefix",
),
pytest.param(
@@ -1275,6 +1306,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
"/dags/~/dagRuns/~/taskInstances",
{"dag_id_pattern": "%skip%"},
8, # Based on test failure - example_skip_dag creates 8 task
instances
+ 3,
id="test dag_id_pattern wildcard contains",
),
pytest.param(
@@ -1283,13 +1315,22 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
"/dags/~/dagRuns/~/taskInstances",
{"dag_id_pattern": "nonexistent"},
0,
+ 3,
id="test dag_id_pattern no match",
),
],
)
@pytest.mark.usefixtures("make_dag_with_multiple_versions")
def test_should_respond_200(
- self, test_client, task_instances, update_extras, url, params,
expected_ti, session
+ self,
+ test_client,
+ task_instances,
+ update_extras,
+ url,
+ params,
+ expected_ti,
+ expected_queries_number,
+ session,
):
# Special handling for dag_id_pattern tests that require multiple DAGs
if task_instances == "dag_id_pattern_test":
@@ -1307,7 +1348,8 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
with mock.patch("airflow.models.dag_version.DagBundlesManager") as
dag_bundle_manager_mock:
dag_bundle_manager_mock.return_value.view_url.return_value =
"some_url"
# Mock DagBundlesManager to avoid checking if dags-folder bundle
is configured
- response = test_client.get(url, params=params)
+ with assert_queries_count(expected_queries_number):
+ response = test_client.get(url, params=params)
if params == {"task_id_pattern": "task_match_id"}:
import pprint
@@ -1664,10 +1706,11 @@ class
TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
update_extras=update_extras,
task_instances=task_instances,
)
- response = test_client.post(
- "/dags/~/dagRuns/~/taskInstances/list",
- json=payload,
- )
+ with assert_queries_count(4):
+ response = test_client.post(
+ "/dags/~/dagRuns/~/taskInstances/list",
+ json=payload,
+ )
body = response.json()
assert response.status_code == 200, body
assert expected_ti_count == body["total_entries"]
@@ -3333,9 +3376,10 @@ class TestGetTaskInstanceTries(TestTaskInstanceEndpoint):
self.create_task_instances(
session=session, task_instances=[{"state": State.SUCCESS}],
with_ti_history=True
)
- response = test_client.get(
-
"/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/tries"
- )
+ with assert_queries_count(3):
+ response = test_client.get(
+
"/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/tries"
+ )
assert response.status_code == 200
assert response.json()["total_entries"] == 2 # The task instance and
its history
assert len(response.json()["task_instances"]) == 2
@@ -3425,9 +3469,10 @@ class TestGetTaskInstanceTries(TestTaskInstanceEndpoint):
TaskInstanceHistory.record_ti(ti, session=session)
session.flush()
- response = test_client.get(
-
f"/dags/{ti.dag_id}/dagRuns/{ti.run_id}/taskInstances/{ti.task_id}/tries",
- )
+ with assert_queries_count(3):
+ response = test_client.get(
+
f"/dags/{ti.dag_id}/dagRuns/{ti.run_id}/taskInstances/{ti.task_id}/tries",
+ )
assert response.status_code == 200
assert response.json() == {
"task_instances": [
@@ -3569,10 +3614,11 @@ class
TestGetTaskInstanceTries(TestTaskInstanceEndpoint):
# in each loop, we should get the right mapped TI back
for map_index in (1, 2):
# Get the info from TIHistory: try_number 1, try_number 2 is TI
table(latest)
- response = test_client.get(
-
"/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"
- f"/print_the_context/{map_index}/tries",
- )
+ with assert_queries_count(3):
+ response = test_client.get(
+
"/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"
+ f"/print_the_context/{map_index}/tries",
+ )
assert response.status_code == 200
assert (
response.json()["total_entries"] == 2