pierrejeambrun commented on code in PR #57441:
URL: https://github.com/apache/airflow/pull/57441#discussion_r2494658774


##########
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##########
@@ -167,32 +169,155 @@ def __init__(
         self.dag_bag = dag_bag
         self.user = user
 
-    def categorize_task_instances(
-        self, task_keys: set[tuple[str, int]]
-    ) -> tuple[dict[tuple[str, int], TI], set[tuple[str, int]], set[tuple[str, 
int]]]:
+    def _extract_task_identifiers(
+        self, entity: str | BulkTaskInstanceBody
+    ) -> tuple[str, str, str, int | None]:
+        """
+        Extract task identifiers from an id or entity object.
+
+        :param entity: Task identifier as string or BulkTaskInstanceBody object
+        :return: tuple of (dag_id, dag_run_id, task_id, map_index)
+        """
+        if isinstance(entity, str):
+            dag_id = self.dag_id
+            dag_run_id = self.dag_run_id
+            task_id = entity
+            map_index = None
+        else:
+            dag_id = entity.dag_id if entity.dag_id else self.dag_id
+            dag_run_id = entity.dag_run_id if entity.dag_run_id else 
self.dag_run_id
+            task_id = entity.task_id
+            map_index = entity.map_index
+
+        return dag_id, dag_run_id, task_id, map_index
+
+    def _categorize_entities(
+        self,
+        entities: Sequence[str | BulkTaskInstanceBody],
+        results: BulkActionResponse,
+    ) -> tuple[set[tuple[str, str, str, int]], set[tuple[str, str, str]]]:
         """
-        Categorize the given task_ids into matched_task_keys and 
not_found_task_keys based on existing task_ids.
+        Validate entities and categorize them into specific and all map index 
update sets.
 
-        :param task_keys: set of task_keys (tuple of task_id and map_index)
+        :param entities: Sequence of entities to validate
+        :param results: BulkActionResponse object to track errors
+        :return: tuple of (specific_map_index_task_keys, 
all_map_index_task_keys)
+        """
+        specific_map_index_task_keys = set()
+        all_map_index_task_keys = set()
+
+        for entity in entities:
+            dag_id, dag_run_id, task_id, map_index = 
self._extract_task_identifiers(entity)
+
+            # Validate that we have specific values, not wildcards
+            if dag_id == "~" or dag_run_id == "~":
+                if isinstance(entity, str):
+                    error_msg = f"When using wildcard in path, dag_id and 
dag_run_id must be specified in BulkTaskInstanceBody object, not as string for 
task_id: {entity}"
+                else:
+                    error_msg = f"When using wildcard in path, dag_id and 
dag_run_id must be specified in request body for task_id: {entity.task_id}"
+                results.errors.append(
+                    {
+                        "error": error_msg,
+                        "status_code": status.HTTP_400_BAD_REQUEST,
+                    }
+                )
+                continue
+
+            # Separate logic for "update all" vs "update specific"
+            if map_index is not None:
+                specific_map_index_task_keys.add((dag_id, dag_run_id, task_id, 
map_index))
+            else:
+                all_map_index_task_keys.add((dag_id, dag_run_id, task_id))
+
+        return specific_map_index_task_keys, all_map_index_task_keys
+
+    def _categorize_task_instances(
+        self, task_keys: set[tuple[str, str, str, int]]
+    ) -> tuple[
+        dict[tuple[str, str, str, int], TI], set[tuple[str, str, str, int]], 
set[tuple[str, str, str, int]]
+    ]:
+        """
+        Categorize the given task_keys into matched and not_found based on 
existing task instances.
+
+        :param task_keys: set of task_keys (tuple of dag_id, dag_run_id, 
task_id, and map_index)
         :return: tuple of (task_instances_map, matched_task_keys, 
not_found_task_keys)
         """
-        query = select(TI).where(
-            TI.dag_id == self.dag_id,
-            TI.run_id == self.dag_run_id,
-            TI.task_id.in_([task_id for task_id, _ in task_keys]),
-        )
+        # Build a query that handles potentially multiple dag_id/dag_run_id 
combinations
+        query = select(TI)
+
+        # If path params are specific, use them for filtering
+        if self.dag_id != "~":
+            query = query.where(TI.dag_id == self.dag_id)
+        if self.dag_run_id != "~":
+            query = query.where(TI.run_id == self.dag_run_id)
+
+        # Extract unique dag_ids, run_ids, and task_ids from task_keys
+        dag_ids = {dag_id for dag_id, _, _, _ in task_keys}
+        run_ids = {run_id for _, run_id, _, _ in task_keys}
+        task_ids = {task_id for _, _, task_id, _ in task_keys}
+
+        # Apply filters based on the extracted values when using wildcards
+        if self.dag_id == "~" and dag_ids:
+            query = query.where(TI.dag_id.in_(dag_ids))
+        if self.dag_run_id == "~" and run_ids:
+            query = query.where(TI.run_id.in_(run_ids))
+        if task_ids:
+            query = query.where(TI.task_id.in_(task_ids))
+
         task_instances = self.session.scalars(query).all()
         task_instances_map = {
-            (ti.task_id, ti.map_index if ti.map_index is not None else -1): ti 
for ti in task_instances
+            (ti.dag_id, ti.run_id, ti.task_id, ti.map_index if ti.map_index is 
not None else -1): ti
+            for ti in task_instances
         }
         matched_task_keys = {
-            (task_id, map_index)
-            for (task_id, map_index) in task_instances_map.keys()
-            if (task_id, map_index) in task_keys
+            (dag_id, run_id, task_id, map_index)
+            for (dag_id, run_id, task_id, map_index) in 
task_instances_map.keys()
+            if (dag_id, run_id, task_id, map_index) in task_keys

Review Comment:
   Can't we do this in the query step, instead of querying too much data 
   and then filtering it again on the Python side?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to