This is an automated email from the ASF dual-hosted git repository.

jasonliu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b2c42b713b8 Add map_index validation in categorize_task_instances 
(#54791)
b2c42b713b8 is described below

commit b2c42b713b8201b2e42a6142f7fdadc693eb231c
Author: Guan Ming(Wesley) Chiu <[email protected]>
AuthorDate: Thu Aug 28 14:27:20 2025 +0800

    Add map_index validation in categorize_task_instances (#54791)
---
 .../core_api/services/public/task_instances.py     |  14 ++-
 .../unit/api_fastapi/core_api/services/__init__.py |  16 +++
 .../core_api/services/public/__init__.py           |  16 +++
 .../services/public/test_task_instances.py         | 124 +++++++++++++++++++++
 4 files changed, 165 insertions(+), 5 deletions(-)

diff --git 
a/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py
 
b/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py
index e80fc9cbc0e..4f68b808b7c 100644
--- 
a/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py
+++ 
b/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py
@@ -168,25 +168,29 @@ class 
BulkTaskInstanceService(BulkService[BulkTaskInstanceBody]):
         self.user = user
 
     def categorize_task_instances(
-        self, task_ids: set[tuple[str, int]]
+        self, task_keys: set[tuple[str, int]]
     ) -> tuple[dict[tuple[str, int], TI], set[tuple[str, int]], set[tuple[str, 
int]]]:
         """
         Categorize the given task_keys into matched_task_keys and 
not_found_task_keys based on existing task instances.
 
-        :param task_ids: set of task_ids
+        :param task_keys: set of task_keys (tuple of task_id and map_index)
         :return: tuple of (task_instances_map, matched_task_keys, 
not_found_task_keys)
         """
         query = select(TI).where(
             TI.dag_id == self.dag_id,
             TI.run_id == self.dag_run_id,
-            TI.task_id.in_([task_id for task_id, _ in task_ids]),
+            TI.task_id.in_([task_id for task_id, _ in task_keys]),
         )
         task_instances = self.session.scalars(query).all()
         task_instances_map = {
             (ti.task_id, ti.map_index if ti.map_index is not None else -1): ti 
for ti in task_instances
         }
-        matched_task_keys = {(task_id, map_index) for (task_id, map_index) in 
task_instances_map.keys()}
-        not_found_task_keys = {(task_id, map_index) for task_id, map_index in 
task_ids} - matched_task_keys
+        matched_task_keys = {
+            (task_id, map_index)
+            for (task_id, map_index) in task_instances_map.keys()
+            if (task_id, map_index) in task_keys
+        }
+        not_found_task_keys = {(task_id, map_index) for task_id, map_index in 
task_keys} - matched_task_keys
         return task_instances_map, matched_task_keys, not_found_task_keys
 
     def handle_bulk_create(
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/services/__init__.py 
b/airflow-core/tests/unit/api_fastapi/core_api/services/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++ b/airflow-core/tests/unit/api_fastapi/core_api/services/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git 
a/airflow-core/tests/unit/api_fastapi/core_api/services/public/__init__.py 
b/airflow-core/tests/unit/api_fastapi/core_api/services/public/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++ b/airflow-core/tests/unit/api_fastapi/core_api/services/public/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git 
a/airflow-core/tests/unit/api_fastapi/core_api/services/public/test_task_instances.py
 
b/airflow-core/tests/unit/api_fastapi/core_api/services/public/test_task_instances.py
new file mode 100644
index 00000000000..5c8b5991ac5
--- /dev/null
+++ 
b/airflow-core/tests/unit/api_fastapi/core_api/services/public/test_task_instances.py
@@ -0,0 +1,124 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import pytest
+
+from airflow.api_fastapi.core_api.datamodels.common import BulkBody
+from airflow.api_fastapi.core_api.services.public.task_instances import 
BulkTaskInstanceService
+from airflow.providers.standard.operators.bash import BashOperator
+
+from tests_common.test_utils.db import (
+    clear_db_runs,
+)
+
+pytestmark = pytest.mark.db_test
+DAG_ID = "TEST_DAG"
+DAG_RUN_ID = "TEST_DAG_RUN"
+TASK_ID_1 = "TEST_TASK_1"
+TASK_ID_2 = "TEST_TASK_2"
+
+
+class TestTaskInstanceEndpoint:
+    @staticmethod
+    def clear_db():
+        clear_db_runs()
+
+
+class TestCategorizeTaskInstances(TestTaskInstanceEndpoint):
+    """Tests for the categorize_task_instances method in 
BulkTaskInstanceService."""
+
+    def setup_method(self):
+        self.clear_db()
+
+    def teardown_method(self):
+        self.clear_db()
+
+    class MockUser:
+        def get_id(self) -> str:
+            return "test_user"
+
+        def get_name(self) -> str:
+            return "test_user"
+
+    @pytest.mark.parametrize(
+        "task_keys, expected_matched_keys, expected_not_found_keys, 
expected_matched_count, expected_not_found_count",
+        [
+            pytest.param(
+                {(TASK_ID_1, -1), (TASK_ID_2, -1)},
+                {(TASK_ID_1, -1), (TASK_ID_2, -1)},
+                set(),
+                2,
+                0,
+                id="all_found",
+            ),
+            pytest.param(
+                {("nonexistent_task", -1), ("nonexistent_task", 0)},
+                set(),
+                {("nonexistent_task", -1), ("nonexistent_task", 0)},
+                0,
+                2,
+                id="none_found",
+            ),
+            pytest.param(
+                {(TASK_ID_1, -1), (TASK_ID_1, 0)},
+                {(TASK_ID_1, -1)},
+                {(TASK_ID_1, 0)},
+                1,
+                1,
+                id="mixed_found_and_not_found",
+            ),
+            pytest.param(set(), set(), set(), 0, 0, id="empty_input"),
+        ],
+    )
+    def test_categorize_task_instances(
+        self,
+        session,
+        dag_maker,
+        task_keys,
+        expected_matched_keys,
+        expected_not_found_keys,
+        expected_matched_count,
+        expected_not_found_count,
+    ):
+        """Test categorize_task_instances with various scenarios."""
+        with dag_maker(dag_id=DAG_ID, session=session):
+            BashOperator(task_id=TASK_ID_1, bash_command="echo 1")
+            BashOperator(task_id=TASK_ID_2, bash_command="echo 2")
+
+        dag_maker.create_dagrun(run_id=DAG_RUN_ID)
+
+        session.commit()
+
+        user = self.MockUser()
+        bulk_request = BulkBody(actions=[])
+        service = BulkTaskInstanceService(
+            session=session,
+            request=bulk_request,
+            dag_id=DAG_ID,
+            dag_run_id=DAG_RUN_ID,
+            dag_bag=dag_maker.dagbag,
+            user=user,
+        )
+
+        _, matched_task_keys, not_found_task_keys = 
service.categorize_task_instances(task_keys)
+
+        assert len(matched_task_keys) == expected_matched_count
+        assert len(not_found_task_keys) == expected_not_found_count
+        assert matched_task_keys == expected_matched_keys
+        assert not_found_task_keys == expected_not_found_keys

Reply via email to