jason810496 commented on code in PR #57441:
URL: https://github.com/apache/airflow/pull/57441#discussion_r2567535134
##########
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##########
@@ -167,32 +169,132 @@ def __init__(
self.dag_bag = dag_bag
self.user = user
- def categorize_task_instances(
- self, task_keys: set[tuple[str, int]]
-    ) -> tuple[dict[tuple[str, int], TI], set[tuple[str, int]], set[tuple[str, int]]]:
+ def _extract_task_identifiers(
+ self, entity: str | BulkTaskInstanceBody
+ ) -> tuple[str, str, str, int | None]:
+ """
+ Extract task identifiers from an id or entity object.
+
+ :param entity: Task identifier as string or BulkTaskInstanceBody object
+ :return: tuple of (dag_id, dag_run_id, task_id, map_index)
+ """
+ if isinstance(entity, str):
+ dag_id = self.dag_id
+ dag_run_id = self.dag_run_id
+ task_id = entity
+ map_index = None
+ else:
+ dag_id = entity.dag_id if entity.dag_id else self.dag_id
+            dag_run_id = entity.dag_run_id if entity.dag_run_id else self.dag_run_id
+ task_id = entity.task_id
+ map_index = entity.map_index
+
+ return dag_id, dag_run_id, task_id, map_index
+
+ def _categorize_entities(
+ self,
+ entities: Sequence[str | BulkTaskInstanceBody],
+ results: BulkActionResponse,
+ ) -> tuple[set[tuple[str, str, str, int]], set[tuple[str, str, str]]]:
"""
-        Categorize the given task_ids into matched_task_keys and not_found_task_keys based on existing task_ids.
+        Validate entities and categorize them into specific and all map index update sets.
-        :param task_keys: set of task_keys (tuple of task_id and map_index)
+        :param entities: Sequence of entities to validate
+        :param results: BulkActionResponse object to track errors
+        :return: tuple of (specific_map_index_task_keys, all_map_index_task_keys)
+ """
+ specific_map_index_task_keys = set()
+ all_map_index_task_keys = set()
+
+ for entity in entities:
+            dag_id, dag_run_id, task_id, map_index = self._extract_task_identifiers(entity)
+
+ # Validate that we have specific values, not wildcards
+ if dag_id == "~" or dag_run_id == "~":
+ if isinstance(entity, str):
+                    error_msg = f"When using wildcard in path, dag_id and dag_run_id must be specified in BulkTaskInstanceBody object, not as string for task_id: {entity}"
+                else:
+                    error_msg = f"When using wildcard in path, dag_id and dag_run_id must be specified in request body for task_id: {entity.task_id}"
+ results.errors.append(
+ {
+ "error": error_msg,
+ "status_code": status.HTTP_400_BAD_REQUEST,
+ }
+ )
+ continue
+
+ # Separate logic for "update all" vs "update specific"
+ if map_index is not None:
+                specific_map_index_task_keys.add((dag_id, dag_run_id, task_id, map_index))
+ else:
+ all_map_index_task_keys.add((dag_id, dag_run_id, task_id))
+
+ return specific_map_index_task_keys, all_map_index_task_keys
+
+ def _categorize_task_instances(
+ self, task_keys: set[tuple[str, str, str, int]]
+ ) -> tuple[
+        dict[tuple[str, str, str, int], TI], set[tuple[str, str, str, int]], set[tuple[str, str, str, int]]
+ ]:
+ """
+        Categorize the given task_keys into matched and not_found based on existing task instances.
+
+        :param task_keys: set of task_keys (tuple of dag_id, dag_run_id, task_id, and map_index)
         :return: tuple of (task_instances_map, matched_task_keys, not_found_task_keys)
"""
- query = select(TI).where(
- TI.dag_id == self.dag_id,
- TI.run_id == self.dag_run_id,
- TI.task_id.in_([task_id for task_id, _ in task_keys]),
- )
+        # Filter at database level using exact tuple matching instead of fetching all combinations
+        # and filtering in Python
+ task_keys_list = list(task_keys)
+        query = select(TI).where(tuple_(TI.dag_id, TI.run_id, TI.task_id, TI.map_index).in_(task_keys_list))
Review Comment:
It _seems_ we don't need the extra cast to `list` for the `in_` clause; `in_` should accept the set directly.
```suggestion
        query = select(TI).where(tuple_(TI.dag_id, TI.run_id, TI.task_id, TI.map_index).in_(task_keys))
```
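For reference, a minimal standalone sketch of why the intermediate list looks unnecessary: SQLAlchemy's `in_()` should accept any iterable of tuples here, so the set can be passed as-is. The key values below are made up, and `TI` just stands for the TaskInstance model as in this module.

```python
from sqlalchemy import select, tuple_

from airflow.models.taskinstance import TaskInstance as TI

# Hypothetical task keys, only to illustrate the shape of the collection.
task_keys = {
    ("example_dag", "manual__2024-01-01T00:00:00", "extract", -1),
    ("example_dag", "manual__2024-01-01T00:00:00", "load", 0),
}

# Passing the set directly; SQLAlchemy expands it into a tuple IN clause,
# i.e. WHERE (dag_id, run_id, task_id, map_index) IN ((...), (...)).
query = select(TI).where(
    tuple_(TI.dag_id, TI.run_id, TI.task_id, TI.map_index).in_(task_keys)
)
```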
##########
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##########
@@ -207,55 +309,99 @@ def handle_bulk_update(
         self, action: BulkUpdateAction[BulkTaskInstanceBody], results: BulkActionResponse
) -> None:
"""Bulk Update Task Instances."""
- to_update_task_keys = {
-            (task_instance.task_id, task_instance.map_index if task_instance.map_index is not None else -1)
- for task_instance in action.entities
- }
-        _, _, not_found_task_keys = self.categorize_task_instances(to_update_task_keys)
+        # Validate and categorize entities into specific and all map index update sets
+        update_specific_map_index_task_keys, update_all_map_index_task_keys = self._categorize_entities(
+ action.entities, results
+ )
try:
- for task_instance_body in action.entities:
- task_key = (
- task_instance_body.task_id,
-                    task_instance_body.map_index if task_instance_body.map_index is not None else -1,
+            specific_entity_map = {
+                (entity.dag_id, entity.dag_run_id, entity.task_id, entity.map_index): entity
+ for entity in action.entities
+ if entity.map_index is not None
+ }
+ all_map_entity_map = {
+ (entity.dag_id, entity.dag_run_id, entity.task_id): entity
+ for entity in action.entities
+ if entity.map_index is None
+ }
+
+ # Handle updates for specific map_index task instances
+ if update_specific_map_index_task_keys:
+                _, matched_task_keys, not_found_task_keys = self._categorize_task_instances(
+ update_specific_map_index_task_keys
)
- if task_key in not_found_task_keys:
-                    if action.action_on_non_existence == BulkActionNotOnExistence.FAIL:
+                if action.action_on_non_existence == BulkActionNotOnExistence.FAIL and not_found_task_keys:
+                    not_found_task_ids = [
+                        {"dag_id": dag_id, "dag_run_id": run_id, "task_id": task_id, "map_index": map_index}
+                        for dag_id, run_id, task_id, map_index in not_found_task_keys
+ ]
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+                        detail=f"The task instances with these identifiers: {not_found_task_ids} were not found",
+ )
+
+                for dag_id, dag_run_id, task_id, map_index in matched_task_keys:
+                    entity = specific_entity_map.get((dag_id, dag_run_id, task_id, map_index))
+
+ if entity is not None:
+ self._perform_update(
+ dag_id=dag_id,
+ dag_run_id=dag_run_id,
+ task_id=task_id,
+ map_index=map_index,
+ entity=entity,
+ results=results,
+ update_mask=action.update_mask,
+ )
+
+ # Handle updates for all map indexes
+ if update_all_map_index_task_keys:
+                all_dag_ids = {dag_id for dag_id, _, _ in update_all_map_index_task_keys}
+                all_run_ids = {run_id for _, run_id, _ in update_all_map_index_task_keys}
+                all_task_ids = {task_id for _, _, task_id in update_all_map_index_task_keys}
+
+ batch_task_instances = self.session.scalars(
+ select(TI).where(
+ TI.dag_id.in_(all_dag_ids),
+ TI.run_id.in_(all_run_ids),
+ TI.task_id.in_(all_task_ids),
+ )
+ ).all()
Review Comment:
`.where(tuple_(TI.dag_id, TI.run_id, TI.task_id, TI.map_index).in_(task_keys_list))`
Would the tuple-based filter above be more accurate for querying the TIs? Three independent `in_` clauses can also match unintended cross-combinations of `dag_id`, `run_id`, and `task_id`. That said, the following statement seems correct as well, and I couldn't come up with a concrete edge case right now.
```
select(TI).where(
TI.dag_id.in_(all_dag_ids),
TI.run_id.in_(all_run_ids),
TI.task_id.in_(all_task_ids),
)
```
So maybe we could unify both statements?
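A rough sketch of the unified form I have in mind, reusing the names from this PR (`update_all_map_index_task_keys` holds `(dag_id, dag_run_id, task_id)` tuples), in case it helps:

```python
# Sketch only: exact tuple matching avoids the cross-combinations that the three
# independent IN clauses would also accept, e.g. a dag_id paired with a run_id
# that actually belongs to a different DAG.
batch_task_instances = self.session.scalars(
    select(TI).where(
        tuple_(TI.dag_id, TI.run_id, TI.task_id).in_(update_all_map_index_task_keys)
    )
).all()
```

That would also keep the batch query symmetric with the tuple filter already used in `_categorize_task_instances`.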
##########
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##########
@@ -265,74 +411,81 @@ def handle_bulk_delete(
         self, action: BulkDeleteAction[BulkTaskInstanceBody], results: BulkActionResponse
) -> None:
"""Bulk delete task instances."""
- delete_all_map_indexes: set[str] = set()
- delete_specific_task_keys: set[tuple[str, int]] = set()
-
- for entity in action.entities:
- if isinstance(entity, str):
- # String task ID - remove all task instances for this task
- delete_all_map_indexes.add(entity)
- else:
- # BulkTaskInstanceBody object
- if entity.map_index is None:
- delete_all_map_indexes.add(entity.task_id)
- else:
-                    delete_specific_task_keys.add((entity.task_id, entity.map_index))
+        # Validate and categorize entities into specific and all map index delete sets
+        delete_specific_map_index_task_keys, delete_all_map_index_task_keys = self._categorize_entities(
+ action.entities, results
+ )
try:
- # Handle deletion of specific (task_id, map_index) pairs
- if delete_specific_task_keys:
-                _, matched_task_keys, not_found_task_keys = self.categorize_task_instances(
-                    delete_specific_task_keys
+            # Handle deletion of specific (dag_id, dag_run_id, task_id, map_index) tuples
+            if delete_specific_map_index_task_keys:
+                _, matched_task_keys, not_found_task_keys = self._categorize_task_instances(
+ delete_specific_map_index_task_keys
)
-                not_found_task_ids = [f"{task_id}[{map_index}]" for task_id, map_index in not_found_task_keys]
+                not_found_task_ids = [
+                    {"dag_id": dag_id, "dag_run_id": run_id, "task_id": task_id, "map_index": map_index}
+                    for dag_id, run_id, task_id, map_index in not_found_task_keys
+ ]
                 if action.action_on_non_existence == BulkActionNotOnExistence.FAIL and not_found_task_keys:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
-                        detail=f"The task instances with these task_ids: {not_found_task_ids} were not found",
+                        detail=f"The task instances with these identifiers: {not_found_task_ids} were not found",
)
- for task_id, map_index in matched_task_keys:
- result = (
+ for dag_id, run_id, task_id, map_index in matched_task_keys:
+ ti = (
self.session.execute(
select(TI).where(
+ TI.dag_id == dag_id,
+ TI.run_id == run_id,
TI.task_id == task_id,
- TI.dag_id == self.dag_id,
- TI.run_id == self.dag_run_id,
TI.map_index == map_index,
)
)
.scalars()
.one_or_none()
)
- if result:
- existing_task_instance = result
- self.session.delete(existing_task_instance)
- results.success.append(f"{task_id}[{map_index}]")
+ if ti:
+ self.session.delete(ti)
+                        results.success.append(f"{dag_id}.{run_id}.{task_id}[{map_index}]")
- # Handle deletion of all map indexes for certain task_ids
- for task_id in delete_all_map_indexes:
- all_task_instances = self.session.scalars(
+            # Handle deletion of all map indexes for certain (dag_id, dag_run_id, task_id) tuples
+            if delete_all_map_index_task_keys:
+                all_dag_ids = {dag_id for dag_id, _, _ in delete_all_map_index_task_keys}
+                all_run_ids = {run_id for _, run_id, _ in delete_all_map_index_task_keys}
+                all_task_ids = {task_id for _, _, task_id in delete_all_map_index_task_keys}
+
+ batch_task_instances = self.session.scalars(
select(TI).where(
- TI.task_id == task_id,
- TI.dag_id == self.dag_id,
- TI.run_id == self.dag_run_id,
+ TI.dag_id.in_(all_dag_ids),
+ TI.run_id.in_(all_run_ids),
+ TI.task_id.in_(all_task_ids),
Review Comment:
Same question as above about the query statement.
##########
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##########
@@ -265,74 +411,81 @@ def handle_bulk_delete(
         self, action: BulkDeleteAction[BulkTaskInstanceBody], results: BulkActionResponse
) -> None:
"""Bulk delete task instances."""
- delete_all_map_indexes: set[str] = set()
- delete_specific_task_keys: set[tuple[str, int]] = set()
-
- for entity in action.entities:
- if isinstance(entity, str):
- # String task ID - remove all task instances for this task
- delete_all_map_indexes.add(entity)
- else:
- # BulkTaskInstanceBody object
- if entity.map_index is None:
- delete_all_map_indexes.add(entity.task_id)
- else:
-                    delete_specific_task_keys.add((entity.task_id, entity.map_index))
+        # Validate and categorize entities into specific and all map index delete sets
+        delete_specific_map_index_task_keys, delete_all_map_index_task_keys = self._categorize_entities(
+ action.entities, results
+ )
try:
- # Handle deletion of specific (task_id, map_index) pairs
- if delete_specific_task_keys:
-                _, matched_task_keys, not_found_task_keys = self.categorize_task_instances(
-                    delete_specific_task_keys
+            # Handle deletion of specific (dag_id, dag_run_id, task_id, map_index) tuples
+            if delete_specific_map_index_task_keys:
+                _, matched_task_keys, not_found_task_keys = self._categorize_task_instances(
+ delete_specific_map_index_task_keys
)
-                not_found_task_ids = [f"{task_id}[{map_index}]" for task_id, map_index in not_found_task_keys]
+                not_found_task_ids = [
+                    {"dag_id": dag_id, "dag_run_id": run_id, "task_id": task_id, "map_index": map_index}
+                    for dag_id, run_id, task_id, map_index in not_found_task_keys
+ ]
                 if action.action_on_non_existence == BulkActionNotOnExistence.FAIL and not_found_task_keys:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
-                        detail=f"The task instances with these task_ids: {not_found_task_ids} were not found",
+                        detail=f"The task instances with these identifiers: {not_found_task_ids} were not found",
)
- for task_id, map_index in matched_task_keys:
- result = (
+ for dag_id, run_id, task_id, map_index in matched_task_keys:
Review Comment:
I'm not sure whether we could select all the matched TIs before the loop, and then only perform the delete inside the loop.
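Something like the following is what I had in mind, just as a sketch reusing `matched_task_keys` and the tuple filter from this PR:

```python
# Fetch every matched TI in a single query, then only delete inside the loop
# instead of issuing one SELECT per (dag_id, run_id, task_id, map_index) key.
matched_tis = self.session.scalars(
    select(TI).where(
        tuple_(TI.dag_id, TI.run_id, TI.task_id, TI.map_index).in_(matched_task_keys)
    )
).all()
for ti in matched_tis:
    self.session.delete(ti)
    results.success.append(f"{ti.dag_id}.{ti.run_id}.{ti.task_id}[{ti.map_index}]")
```

Or, since `_categorize_task_instances` already returns `task_instances_map` as its first element, maybe we could reuse that map here instead of re-querying at all.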