This is an automated email from the ASF dual-hosted git repository.
jasonliu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b2c42b713b8 Add map_index validation in categorize_task_instances
(#54791)
b2c42b713b8 is described below
commit b2c42b713b8201b2e42a6142f7fdadc693eb231c
Author: Guan Ming(Wesley) Chiu <[email protected]>
AuthorDate: Thu Aug 28 14:27:20 2025 +0800
Add map_index validation in categorize_task_instances (#54791)
---
.../core_api/services/public/task_instances.py | 14 ++-
.../unit/api_fastapi/core_api/services/__init__.py | 16 +++
.../core_api/services/public/__init__.py | 16 +++
.../services/public/test_task_instances.py | 124 +++++++++++++++++++++
4 files changed, 165 insertions(+), 5 deletions(-)
diff --git
a/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py
b/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py
index e80fc9cbc0e..4f68b808b7c 100644
---
a/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py
+++
b/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py
@@ -168,25 +168,29 @@ class
BulkTaskInstanceService(BulkService[BulkTaskInstanceBody]):
self.user = user
def categorize_task_instances(
- self, task_ids: set[tuple[str, int]]
+ self, task_keys: set[tuple[str, int]]
) -> tuple[dict[tuple[str, int], TI], set[tuple[str, int]], set[tuple[str,
int]]]:
"""
Categorize the given task_keys into matched_task_keys and
not_found_task_keys based on existing task instances.
- :param task_ids: set of task_ids
+ :param task_keys: set of task_keys (tuple of task_id and map_index)
:return: tuple of (task_instances_map, matched_task_keys,
not_found_task_keys)
"""
query = select(TI).where(
TI.dag_id == self.dag_id,
TI.run_id == self.dag_run_id,
- TI.task_id.in_([task_id for task_id, _ in task_ids]),
+ TI.task_id.in_([task_id for task_id, _ in task_keys]),
)
task_instances = self.session.scalars(query).all()
task_instances_map = {
(ti.task_id, ti.map_index if ti.map_index is not None else -1): ti
for ti in task_instances
}
- matched_task_keys = {(task_id, map_index) for (task_id, map_index) in
task_instances_map.keys()}
- not_found_task_keys = {(task_id, map_index) for task_id, map_index in
task_ids} - matched_task_keys
+ matched_task_keys = {
+ (task_id, map_index)
+ for (task_id, map_index) in task_instances_map.keys()
+ if (task_id, map_index) in task_keys
+ }
+ not_found_task_keys = {(task_id, map_index) for task_id, map_index in
task_keys} - matched_task_keys
return task_instances_map, matched_task_keys, not_found_task_keys
def handle_bulk_create(
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/services/__init__.py
b/airflow-core/tests/unit/api_fastapi/core_api/services/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++ b/airflow-core/tests/unit/api_fastapi/core_api/services/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git
a/airflow-core/tests/unit/api_fastapi/core_api/services/public/__init__.py
b/airflow-core/tests/unit/api_fastapi/core_api/services/public/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++ b/airflow-core/tests/unit/api_fastapi/core_api/services/public/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git
a/airflow-core/tests/unit/api_fastapi/core_api/services/public/test_task_instances.py
b/airflow-core/tests/unit/api_fastapi/core_api/services/public/test_task_instances.py
new file mode 100644
index 00000000000..5c8b5991ac5
--- /dev/null
+++
b/airflow-core/tests/unit/api_fastapi/core_api/services/public/test_task_instances.py
@@ -0,0 +1,124 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import pytest
+
+from airflow.api_fastapi.core_api.datamodels.common import BulkBody
+from airflow.api_fastapi.core_api.services.public.task_instances import
BulkTaskInstanceService
+from airflow.providers.standard.operators.bash import BashOperator
+
+from tests_common.test_utils.db import (
+ clear_db_runs,
+)
+
+pytestmark = pytest.mark.db_test
+DAG_ID = "TEST_DAG"
+DAG_RUN_ID = "TEST_DAG_RUN"
+TASK_ID_1 = "TEST_TASK_1"
+TASK_ID_2 = "TEST_TASK_2"
+
+
+class TestTaskInstanceEndpoint:
+ @staticmethod
+ def clear_db():
+ clear_db_runs()
+
+
+class TestCategorizeTaskInstances(TestTaskInstanceEndpoint):
+ """Tests for the categorize_task_instances method in
BulkTaskInstanceService."""
+
+ def setup_method(self):
+ self.clear_db()
+
+ def teardown_method(self):
+ self.clear_db()
+
+ class MockUser:
+ def get_id(self) -> str:
+ return "test_user"
+
+ def get_name(self) -> str:
+ return "test_user"
+
+ @pytest.mark.parametrize(
+ "task_keys, expected_matched_keys, expected_not_found_keys,
expected_matched_count, expected_not_found_count",
+ [
+ pytest.param(
+ {(TASK_ID_1, -1), (TASK_ID_2, -1)},
+ {(TASK_ID_1, -1), (TASK_ID_2, -1)},
+ set(),
+ 2,
+ 0,
+ id="all_found",
+ ),
+ pytest.param(
+ {("nonexistent_task", -1), ("nonexistent_task", 0)},
+ set(),
+ {("nonexistent_task", -1), ("nonexistent_task", 0)},
+ 0,
+ 2,
+ id="none_found",
+ ),
+ pytest.param(
+ {(TASK_ID_1, -1), (TASK_ID_1, 0)},
+ {(TASK_ID_1, -1)},
+ {(TASK_ID_1, 0)},
+ 1,
+ 1,
+ id="mixed_found_and_not_found",
+ ),
+ pytest.param(set(), set(), set(), 0, 0, id="empty_input"),
+ ],
+ )
+ def test_categorize_task_instances(
+ self,
+ session,
+ dag_maker,
+ task_keys,
+ expected_matched_keys,
+ expected_not_found_keys,
+ expected_matched_count,
+ expected_not_found_count,
+ ):
+ """Test categorize_task_instances with various scenarios."""
+ with dag_maker(dag_id=DAG_ID, session=session):
+ BashOperator(task_id=TASK_ID_1, bash_command="echo 1")
+ BashOperator(task_id=TASK_ID_2, bash_command="echo 2")
+
+ dag_maker.create_dagrun(run_id=DAG_RUN_ID)
+
+ session.commit()
+
+ user = self.MockUser()
+ bulk_request = BulkBody(actions=[])
+ service = BulkTaskInstanceService(
+ session=session,
+ request=bulk_request,
+ dag_id=DAG_ID,
+ dag_run_id=DAG_RUN_ID,
+ dag_bag=dag_maker.dagbag,
+ user=user,
+ )
+
+ _, matched_task_keys, not_found_task_keys =
service.categorize_task_instances(task_keys)
+
+ assert len(matched_task_keys) == expected_matched_count
+ assert len(not_found_task_keys) == expected_not_found_count
+ assert matched_task_keys == expected_matched_keys
+ assert not_found_task_keys == expected_not_found_keys