guan404ming commented on code in PR #57441:
URL: https://github.com/apache/airflow/pull/57441#discussion_r2508187649
##########
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##########
@@ -167,32 +169,155 @@ def __init__(
self.dag_bag = dag_bag
self.user = user
- def categorize_task_instances(
- self, task_keys: set[tuple[str, int]]
- ) -> tuple[dict[tuple[str, int], TI], set[tuple[str, int]], set[tuple[str,
int]]]:
+ def _extract_task_identifiers(
+ self, entity: str | BulkTaskInstanceBody
+ ) -> tuple[str, str, str, int | None]:
+ """
+ Extract task identifiers from an id or entity object.
+
+ :param entity: Task identifier as string or BulkTaskInstanceBody object
+ :return: tuple of (dag_id, dag_run_id, task_id, map_index)
+ """
+ if isinstance(entity, str):
+ dag_id = self.dag_id
+ dag_run_id = self.dag_run_id
+ task_id = entity
+ map_index = None
+ else:
+ dag_id = entity.dag_id if entity.dag_id else self.dag_id
+ dag_run_id = entity.dag_run_id if entity.dag_run_id else
self.dag_run_id
+ task_id = entity.task_id
+ map_index = entity.map_index
+
+ return dag_id, dag_run_id, task_id, map_index
+
+ def _categorize_entities(
+ self,
+ entities: Sequence[str | BulkTaskInstanceBody],
+ results: BulkActionResponse,
+ ) -> tuple[set[tuple[str, str, str, int]], set[tuple[str, str, str]]]:
"""
- Categorize the given task_ids into matched_task_keys and
not_found_task_keys based on existing task_ids.
+ Validate entities and categorize them into specific and all map index
update sets.
- :param task_keys: set of task_keys (tuple of task_id and map_index)
+ :param entities: Sequence of entities to validate
+ :param results: BulkActionResponse object to track errors
+ :return: tuple of (specific_map_index_task_keys,
all_map_index_task_keys)
+ """
+ specific_map_index_task_keys = set()
+ all_map_index_task_keys = set()
+
+ for entity in entities:
+ dag_id, dag_run_id, task_id, map_index =
self._extract_task_identifiers(entity)
+
+ # Validate that we have specific values, not wildcards
+ if dag_id == "~" or dag_run_id == "~":
+ if isinstance(entity, str):
+ error_msg = f"When using wildcard in path, dag_id and
dag_run_id must be specified in BulkTaskInstanceBody object, not as string for
task_id: {entity}"
+ else:
+ error_msg = f"When using wildcard in path, dag_id and
dag_run_id must be specified in request body for task_id: {entity.task_id}"
+ results.errors.append(
+ {
+ "error": error_msg,
+ "status_code": status.HTTP_400_BAD_REQUEST,
+ }
+ )
+ continue
+
+ # Separate logic for "update all" vs "update specific"
+ if map_index is not None:
+ specific_map_index_task_keys.add((dag_id, dag_run_id, task_id,
map_index))
+ else:
+ all_map_index_task_keys.add((dag_id, dag_run_id, task_id))
+
+ return specific_map_index_task_keys, all_map_index_task_keys
+
+ def _categorize_task_instances(
+ self, task_keys: set[tuple[str, str, str, int]]
+ ) -> tuple[
+ dict[tuple[str, str, str, int], TI], set[tuple[str, str, str, int]],
set[tuple[str, str, str, int]]
+ ]:
+ """
+ Categorize the given task_keys into matched and not_found based on
existing task instances.
+
+ :param task_keys: set of task_keys (tuple of dag_id, dag_run_id,
task_id, and map_index)
:return: tuple of (task_instances_map, matched_task_keys,
not_found_task_keys)
"""
- query = select(TI).where(
- TI.dag_id == self.dag_id,
- TI.run_id == self.dag_run_id,
- TI.task_id.in_([task_id for task_id, _ in task_keys]),
- )
+ # Build a query that handles potentially multiple dag_id/dag_run_id
combinations
+ query = select(TI)
+
+ # If path params are specific, use them for filtering
+ if self.dag_id != "~":
+ query = query.where(TI.dag_id == self.dag_id)
+ if self.dag_run_id != "~":
+ query = query.where(TI.run_id == self.dag_run_id)
+
+ # Extract unique dag_ids, run_ids, and task_ids from task_keys
+ dag_ids = {dag_id for dag_id, _, _, _ in task_keys}
+ run_ids = {run_id for _, run_id, _, _ in task_keys}
+ task_ids = {task_id for _, _, task_id, _ in task_keys}
+
+ # Apply filters based on the extracted values when using wildcards
+ if self.dag_id == "~" and dag_ids:
+ query = query.where(TI.dag_id.in_(dag_ids))
+ if self.dag_run_id == "~" and run_ids:
+ query = query.where(TI.run_id.in_(run_ids))
+ if task_ids:
+ query = query.where(TI.task_id.in_(task_ids))
+
task_instances = self.session.scalars(query).all()
task_instances_map = {
- (ti.task_id, ti.map_index if ti.map_index is not None else -1): ti
for ti in task_instances
+ (ti.dag_id, ti.run_id, ti.task_id, ti.map_index if ti.map_index is
not None else -1): ti
+ for ti in task_instances
}
matched_task_keys = {
- (task_id, map_index)
- for (task_id, map_index) in task_instances_map.keys()
- if (task_id, map_index) in task_keys
+ (dag_id, run_id, task_id, map_index)
+ for (dag_id, run_id, task_id, map_index) in
task_instances_map.keys()
+ if (dag_id, run_id, task_id, map_index) in task_keys
Review Comment:
Thanks for the suggestion. I've simplified the logic and moved this filtering into the query
step.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]