jason810496 commented on code in PR #57441:
URL: https://github.com/apache/airflow/pull/57441#discussion_r2567535134


##########
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##########
@@ -167,32 +169,132 @@ def __init__(
         self.dag_bag = dag_bag
         self.user = user
 
-    def categorize_task_instances(
-        self, task_keys: set[tuple[str, int]]
-    ) -> tuple[dict[tuple[str, int], TI], set[tuple[str, int]], set[tuple[str, 
int]]]:
+    def _extract_task_identifiers(
+        self, entity: str | BulkTaskInstanceBody
+    ) -> tuple[str, str, str, int | None]:
+        """
+        Extract task identifiers from an id or entity object.
+
+        :param entity: Task identifier as string or BulkTaskInstanceBody object
+        :return: tuple of (dag_id, dag_run_id, task_id, map_index)
+        """
+        if isinstance(entity, str):
+            dag_id = self.dag_id
+            dag_run_id = self.dag_run_id
+            task_id = entity
+            map_index = None
+        else:
+            dag_id = entity.dag_id if entity.dag_id else self.dag_id
+            dag_run_id = entity.dag_run_id if entity.dag_run_id else 
self.dag_run_id
+            task_id = entity.task_id
+            map_index = entity.map_index
+
+        return dag_id, dag_run_id, task_id, map_index
+
+    def _categorize_entities(
+        self,
+        entities: Sequence[str | BulkTaskInstanceBody],
+        results: BulkActionResponse,
+    ) -> tuple[set[tuple[str, str, str, int]], set[tuple[str, str, str]]]:
         """
-        Categorize the given task_ids into matched_task_keys and 
not_found_task_keys based on existing task_ids.
+        Validate entities and categorize them into specific and all map index 
update sets.
 
-        :param task_keys: set of task_keys (tuple of task_id and map_index)
+        :param entities: Sequence of entities to validate
+        :param results: BulkActionResponse object to track errors
+        :return: tuple of (specific_map_index_task_keys, 
all_map_index_task_keys)
+        """
+        specific_map_index_task_keys = set()
+        all_map_index_task_keys = set()
+
+        for entity in entities:
+            dag_id, dag_run_id, task_id, map_index = 
self._extract_task_identifiers(entity)
+
+            # Validate that we have specific values, not wildcards
+            if dag_id == "~" or dag_run_id == "~":
+                if isinstance(entity, str):
+                    error_msg = f"When using wildcard in path, dag_id and 
dag_run_id must be specified in BulkTaskInstanceBody object, not as string for 
task_id: {entity}"
+                else:
+                    error_msg = f"When using wildcard in path, dag_id and 
dag_run_id must be specified in request body for task_id: {entity.task_id}"
+                results.errors.append(
+                    {
+                        "error": error_msg,
+                        "status_code": status.HTTP_400_BAD_REQUEST,
+                    }
+                )
+                continue
+
+            # Separate logic for "update all" vs "update specific"
+            if map_index is not None:
+                specific_map_index_task_keys.add((dag_id, dag_run_id, task_id, 
map_index))
+            else:
+                all_map_index_task_keys.add((dag_id, dag_run_id, task_id))
+
+        return specific_map_index_task_keys, all_map_index_task_keys
+
+    def _categorize_task_instances(
+        self, task_keys: set[tuple[str, str, str, int]]
+    ) -> tuple[
+        dict[tuple[str, str, str, int], TI], set[tuple[str, str, str, int]], 
set[tuple[str, str, str, int]]
+    ]:
+        """
+        Categorize the given task_keys into matched and not_found based on 
existing task instances.
+
+        :param task_keys: set of task_keys (tuple of dag_id, dag_run_id, 
task_id, and map_index)
         :return: tuple of (task_instances_map, matched_task_keys, 
not_found_task_keys)
         """
-        query = select(TI).where(
-            TI.dag_id == self.dag_id,
-            TI.run_id == self.dag_run_id,
-            TI.task_id.in_([task_id for task_id, _ in task_keys]),
-        )
+        # Filter at database level using exact tuple matching instead of 
fetching all combinations
+        # and filtering in Python
+        task_keys_list = list(task_keys)
+        query = select(TI).where(tuple_(TI.dag_id, TI.run_id, TI.task_id, 
TI.map_index).in_(task_keys_list))

Review Comment:
   It _seems_ we don't need to cast to `list` again for the `in_` clause.
   
   ```suggestion
           query = select(TI).where(tuple_(TI.dag_id, TI.run_id, TI.task_id, 
TI.map_index).in_(task_keys))
   ```



##########
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##########
@@ -207,55 +309,99 @@ def handle_bulk_update(
         self, action: BulkUpdateAction[BulkTaskInstanceBody], results: 
BulkActionResponse
     ) -> None:
         """Bulk Update Task Instances."""
-        to_update_task_keys = {
-            (task_instance.task_id, task_instance.map_index if 
task_instance.map_index is not None else -1)
-            for task_instance in action.entities
-        }
-        _, _, not_found_task_keys = 
self.categorize_task_instances(to_update_task_keys)
+        # Validate and categorize entities into specific and all map index 
update sets
+        update_specific_map_index_task_keys, update_all_map_index_task_keys = 
self._categorize_entities(
+            action.entities, results
+        )
 
         try:
-            for task_instance_body in action.entities:
-                task_key = (
-                    task_instance_body.task_id,
-                    task_instance_body.map_index if 
task_instance_body.map_index is not None else -1,
+            specific_entity_map = {
+                (entity.dag_id, entity.dag_run_id, entity.task_id, 
entity.map_index): entity
+                for entity in action.entities
+                if entity.map_index is not None
+            }
+            all_map_entity_map = {
+                (entity.dag_id, entity.dag_run_id, entity.task_id): entity
+                for entity in action.entities
+                if entity.map_index is None
+            }
+
+            # Handle updates for specific map_index task instances
+            if update_specific_map_index_task_keys:
+                _, matched_task_keys, not_found_task_keys = 
self._categorize_task_instances(
+                    update_specific_map_index_task_keys
                 )
 
-                if task_key in not_found_task_keys:
-                    if action.action_on_non_existence == 
BulkActionNotOnExistence.FAIL:
+                if action.action_on_non_existence == 
BulkActionNotOnExistence.FAIL and not_found_task_keys:
+                    not_found_task_ids = [
+                        {"dag_id": dag_id, "dag_run_id": run_id, "task_id": 
task_id, "map_index": map_index}
+                        for dag_id, run_id, task_id, map_index in 
not_found_task_keys
+                    ]
+                    raise HTTPException(
+                        status_code=status.HTTP_404_NOT_FOUND,
+                        detail=f"The task instances with these identifiers: 
{not_found_task_ids} were not found",
+                    )
+
+                for dag_id, dag_run_id, task_id, map_index in 
matched_task_keys:
+                    entity = specific_entity_map.get((dag_id, dag_run_id, 
task_id, map_index))
+
+                    if entity is not None:
+                        self._perform_update(
+                            dag_id=dag_id,
+                            dag_run_id=dag_run_id,
+                            task_id=task_id,
+                            map_index=map_index,
+                            entity=entity,
+                            results=results,
+                            update_mask=action.update_mask,
+                        )
+
+            # Handle updates for all map indexes
+            if update_all_map_index_task_keys:
+                all_dag_ids = {dag_id for dag_id, _, _ in 
update_all_map_index_task_keys}
+                all_run_ids = {run_id for _, run_id, _ in 
update_all_map_index_task_keys}
+                all_task_ids = {task_id for _, _, task_id in 
update_all_map_index_task_keys}
+
+                batch_task_instances = self.session.scalars(
+                    select(TI).where(
+                        TI.dag_id.in_(all_dag_ids),
+                        TI.run_id.in_(all_run_ids),
+                        TI.task_id.in_(all_task_ids),
+                    )
+                ).all()

Review Comment:
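   `.where(tuple_(TI.dag_id, TI.run_id, TI.task_id, TI.map_index).in_(task_keys_list))`
   
   Would a tuple-based filter like the one above be more accurate for querying the TIs here? Since these keys carry no `map_index`, it would match on `(dag_id, run_id, task_id)`.
   
   The following statement looks correct at first glance, but it filters each column independently:
   ```
   select(TI).where(
       TI.dag_id.in_(all_dag_ids),
       TI.run_id.in_(all_run_ids),
       TI.task_id.in_(all_task_ids),
   )
   ```
   As a result, the query matches the cross product of the three sets. For example, requesting `(dag_a, run_1, task_x)` and `(dag_b, run_2, task_y)` would also return an existing `(dag_a, run_1, task_y)` TI that was never requested.
   
   So maybe we could unify both statements into a single tuple-based filter?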



##########
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##########
@@ -265,74 +411,81 @@ def handle_bulk_delete(
         self, action: BulkDeleteAction[BulkTaskInstanceBody], results: 
BulkActionResponse
     ) -> None:
         """Bulk delete task instances."""
-        delete_all_map_indexes: set[str] = set()
-        delete_specific_task_keys: set[tuple[str, int]] = set()
-
-        for entity in action.entities:
-            if isinstance(entity, str):
-                # String task ID - remove all task instances for this task
-                delete_all_map_indexes.add(entity)
-            else:
-                # BulkTaskInstanceBody object
-                if entity.map_index is None:
-                    delete_all_map_indexes.add(entity.task_id)
-                else:
-                    delete_specific_task_keys.add((entity.task_id, 
entity.map_index))
+        # Validate and categorize entities into specific and all map index 
delete sets
+        delete_specific_map_index_task_keys, delete_all_map_index_task_keys = 
self._categorize_entities(
+            action.entities, results
+        )
 
         try:
-            # Handle deletion of specific (task_id, map_index) pairs
-            if delete_specific_task_keys:
-                _, matched_task_keys, not_found_task_keys = 
self.categorize_task_instances(
-                    delete_specific_task_keys
+            # Handle deletion of specific (dag_id, dag_run_id, task_id, 
map_index) tuples
+            if delete_specific_map_index_task_keys:
+                _, matched_task_keys, not_found_task_keys = 
self._categorize_task_instances(
+                    delete_specific_map_index_task_keys
                 )
-                not_found_task_ids = [f"{task_id}[{map_index}]" for task_id, 
map_index in not_found_task_keys]
+                not_found_task_ids = [
+                    {"dag_id": dag_id, "dag_run_id": run_id, "task_id": 
task_id, "map_index": map_index}
+                    for dag_id, run_id, task_id, map_index in 
not_found_task_keys
+                ]
 
                 if action.action_on_non_existence == 
BulkActionNotOnExistence.FAIL and not_found_task_keys:
                     raise HTTPException(
                         status_code=status.HTTP_404_NOT_FOUND,
-                        detail=f"The task instances with these task_ids: 
{not_found_task_ids} were not found",
+                        detail=f"The task instances with these identifiers: 
{not_found_task_ids} were not found",
                     )
 
-                for task_id, map_index in matched_task_keys:
-                    result = (
+                for dag_id, run_id, task_id, map_index in matched_task_keys:
+                    ti = (
                         self.session.execute(
                             select(TI).where(
+                                TI.dag_id == dag_id,
+                                TI.run_id == run_id,
                                 TI.task_id == task_id,
-                                TI.dag_id == self.dag_id,
-                                TI.run_id == self.dag_run_id,
                                 TI.map_index == map_index,
                             )
                         )
                         .scalars()
                         .one_or_none()
                     )
 
-                    if result:
-                        existing_task_instance = result
-                        self.session.delete(existing_task_instance)
-                        results.success.append(f"{task_id}[{map_index}]")
+                    if ti:
+                        self.session.delete(ti)
+                        
results.success.append(f"{dag_id}.{run_id}.{task_id}[{map_index}]")
 
-            # Handle deletion of all map indexes for certain task_ids
-            for task_id in delete_all_map_indexes:
-                all_task_instances = self.session.scalars(
+            # Handle deletion of all map indexes for certain (dag_id, 
dag_run_id, task_id) tuples
+            if delete_all_map_index_task_keys:
+                all_dag_ids = {dag_id for dag_id, _, _ in 
delete_all_map_index_task_keys}
+                all_run_ids = {run_id for _, run_id, _ in 
delete_all_map_index_task_keys}
+                all_task_ids = {task_id for _, _, task_id in 
delete_all_map_index_task_keys}
+
+                batch_task_instances = self.session.scalars(
                     select(TI).where(
-                        TI.task_id == task_id,
-                        TI.dag_id == self.dag_id,
-                        TI.run_id == self.dag_run_id,
+                        TI.dag_id.in_(all_dag_ids),
+                        TI.run_id.in_(all_run_ids),
+                        TI.task_id.in_(all_task_ids),

Review Comment:
   Same question as above about the query statement.



##########
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##########
@@ -265,74 +411,81 @@ def handle_bulk_delete(
         self, action: BulkDeleteAction[BulkTaskInstanceBody], results: 
BulkActionResponse
     ) -> None:
         """Bulk delete task instances."""
-        delete_all_map_indexes: set[str] = set()
-        delete_specific_task_keys: set[tuple[str, int]] = set()
-
-        for entity in action.entities:
-            if isinstance(entity, str):
-                # String task ID - remove all task instances for this task
-                delete_all_map_indexes.add(entity)
-            else:
-                # BulkTaskInstanceBody object
-                if entity.map_index is None:
-                    delete_all_map_indexes.add(entity.task_id)
-                else:
-                    delete_specific_task_keys.add((entity.task_id, 
entity.map_index))
+        # Validate and categorize entities into specific and all map index 
delete sets
+        delete_specific_map_index_task_keys, delete_all_map_index_task_keys = 
self._categorize_entities(
+            action.entities, results
+        )
 
         try:
-            # Handle deletion of specific (task_id, map_index) pairs
-            if delete_specific_task_keys:
-                _, matched_task_keys, not_found_task_keys = 
self.categorize_task_instances(
-                    delete_specific_task_keys
+            # Handle deletion of specific (dag_id, dag_run_id, task_id, 
map_index) tuples
+            if delete_specific_map_index_task_keys:
+                _, matched_task_keys, not_found_task_keys = 
self._categorize_task_instances(
+                    delete_specific_map_index_task_keys
                 )
-                not_found_task_ids = [f"{task_id}[{map_index}]" for task_id, 
map_index in not_found_task_keys]
+                not_found_task_ids = [
+                    {"dag_id": dag_id, "dag_run_id": run_id, "task_id": 
task_id, "map_index": map_index}
+                    for dag_id, run_id, task_id, map_index in 
not_found_task_keys
+                ]
 
                 if action.action_on_non_existence == 
BulkActionNotOnExistence.FAIL and not_found_task_keys:
                     raise HTTPException(
                         status_code=status.HTTP_404_NOT_FOUND,
-                        detail=f"The task instances with these task_ids: 
{not_found_task_ids} were not found",
+                        detail=f"The task instances with these identifiers: 
{not_found_task_ids} were not found",
                     )
 
-                for task_id, map_index in matched_task_keys:
-                    result = (
+                for dag_id, run_id, task_id, map_index in matched_task_keys:

Review Comment:
   Not sure, but could we select all the matched TIs in a single query before the loop, and then only perform the deletes inside the loop?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to