guan404ming commented on code in PR #57441:
URL: https://github.com/apache/airflow/pull/57441#discussion_r2476877453
##########
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##########
@@ -207,55 +327,91 @@ def handle_bulk_update(
self, action: BulkUpdateAction[BulkTaskInstanceBody], results:
BulkActionResponse
) -> None:
"""Bulk Update Task Instances."""
- to_update_task_keys = {
- (task_instance.task_id, task_instance.map_index if
task_instance.map_index is not None else -1)
- for task_instance in action.entities
- }
- _, _, not_found_task_keys =
self.categorize_task_instances(to_update_task_keys)
+ # Validate and categorize entities into specific and "all" update sets
+ update_specific_map_index_task_keys, update_all_map_index_task_keys =
self._categorize_entities(
+ action.entities, results
+ )
try:
- for task_instance_body in action.entities:
- task_key = (
- task_instance_body.task_id,
- task_instance_body.map_index if
task_instance_body.map_index is not None else -1,
+ # Handle updates for specific map_index task instances
+ if update_specific_map_index_task_keys:
+ _, matched_task_keys, not_found_task_keys =
self._categorize_task_instances(
+ update_specific_map_index_task_keys
)
- if task_key in not_found_task_keys:
- if action.action_on_non_existence ==
BulkActionNotOnExistence.FAIL:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail=f"The Task Instance with dag_id:
`{self.dag_id}`, run_id: `{self.dag_run_id}`, task_id:
`{task_instance_body.task_id}` and map_index: `{task_instance_body.map_index}`
was not found",
+ if action.action_on_non_existence ==
BulkActionNotOnExistence.FAIL and not_found_task_keys:
+ not_found_task_ids = [
+ {"dag_id": dag_id, "dag_run_id": run_id, "task_id":
task_id, "map_index": map_index}
+ for dag_id, run_id, task_id, map_index in
not_found_task_keys
+ ]
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=f"The task instances with these identifiers:
{not_found_task_ids} were not found",
+ )
+
+ for dag_id, dag_run_id, task_id, map_index in
matched_task_keys:
+ entity = next(
+ (
+ entity
+ for entity in action.entities
+ if entity.dag_id == dag_id
+ and entity.dag_run_id == dag_run_id
+ and entity.task_id == task_id
+ and entity.map_index == map_index
+ ),
+ None,
+ )
+
+ if entity is not None:
+ self._perform_update(
+ dag_id=dag_id,
+ dag_run_id=dag_run_id,
+ task_id=task_id,
+ map_index=map_index,
+ entity=entity,
+ results=results,
+ update_mask=action.update_mask,
)
- if action.action_on_non_existence ==
BulkActionNotOnExistence.SKIP:
- continue
-
- dag, tis, data = _patch_ti_validate_request(
- dag_id=self.dag_id,
- dag_run_id=self.dag_run_id,
- task_id=task_instance_body.task_id,
- dag_bag=self.dag_bag,
- body=task_instance_body,
- session=self.session,
- map_index=task_instance_body.map_index,
- update_mask=None,
+
+ # Handle updates for all map indexes
+ for dag_id, run_id, task_id in update_all_map_index_task_keys:
+ all_task_instances = self.session.scalars(
+ select(TI).where(
+ TI.dag_id == dag_id,
+ TI.run_id == run_id,
+ TI.task_id == task_id,
+ )
+ ).all()
Review Comment:
Thanks for pointing this out. I have optimized this code to avoid the N+1
query problem.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]