yuseok89 commented on code in PR #66554:
URL: https://github.com/apache/airflow/pull/66554#discussion_r3234878900
##########
airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py:
##########
@@ -86,3 +123,435 @@ async def wait(self) -> AsyncGenerator[str, None]:
await asyncio.sleep(self.interval)
yield await self._serialize_response(dag_run := await
self._get_dag_run())
yield "\n"
+
+
def _format_dag_run_key(dag_id: str, dag_run_id: str) -> str:
    """Return the ``"<dag_id>.<dag_run_id>"`` label used to report a run in bulk results."""
    separator = "."
    return f"{dag_id}{separator}{dag_run_id}"
+
+
def _authorize_dag_run(
    *,
    session: Session,
    user,
    dag_id: str,
    method: AuthMethod,
    cache: dict[str, bool],
) -> bool:
    """
    Check whether ``user`` may perform ``method`` on the Dag runs of ``dag_id``.

    Decisions are stored in ``cache`` keyed by ``dag_id`` alone, so a bulk request
    touching many runs of one Dag asks the auth manager only once per Dag. Because
    the key ignores ``method``, one ``cache`` dict must not be shared across methods.
    """
    if dag_id in cache:
        return cache[dag_id]

    # The Dag's team is part of the authorization details handed to the auth manager.
    team_name = DagModel.get_team_name(dag_id, session=session)
    details = DagDetails(id=dag_id, team_name=team_name)
    allowed = get_auth_manager().is_authorized_dag(
        method=method,
        access_entity=DagAccessEntity.RUN,
        details=details,
        user=user,
    )
    cache[dag_id] = allowed
    return allowed
+
+
def _apply_state_change(
    dag_run: DagRun,
    new_state: DAGRunPatchStates,
    dag: SerializedDAG,
    session: Session,
) -> None:
    """
    Move ``dag_run`` to ``new_state`` and notify listeners where applicable.

    SUCCESS and FAILED call the ``on_dag_run_success`` / ``on_dag_run_failed`` listener
    hook; an exception raised by the hook is logged and swallowed. QUEUED changes the
    state without a listener call. Any other value leaves the run untouched.
    """
    if new_state == DAGRunPatchStates.SUCCESS:
        set_dag_run_state_to_success(dag=dag, run_id=dag_run.run_id, commit=True, session=session)
        hook_name = "on_dag_run_success"
    elif new_state == DAGRunPatchStates.FAILED:
        set_dag_run_state_to_failed(dag=dag, run_id=dag_run.run_id, commit=True, session=session)
        hook_name = "on_dag_run_failed"
    else:
        if new_state == DAGRunPatchStates.QUEUED:
            # No queued notification on purpose: the scheduler emits RUNNING instead.
            set_dag_run_state_to_queued(dag=dag, run_id=dag_run.run_id, commit=True, session=session)
        return

    # Listener failures must not abort the state change that has already been applied.
    try:
        getattr(get_listener_manager().hook, hook_name)(dag_run=dag_run, msg="")
    except Exception:
        log.exception("error calling listener")
+
+
def _apply_note(dag_run: DagRun, note: str | None, user_id: str) -> None:
    """
    Set the note text of ``dag_run`` and record ``user_id`` as its author.

    If the run already has a ``dag_run_note`` it is updated in place; otherwise the
    ``(note, user_id)`` pair is assigned through ``dag_run.note``.
    """
    existing = dag_run.dag_run_note
    if existing is not None:
        existing.content = note
        existing.user_id = user_id
        return
    dag_run.note = (note, user_id)
+
+
def _validate_no_wildcard_in_resolved(
    *,
    dag_id: str,
    dag_run_id: str,
    results: BulkActionResponse,
) -> bool:
    """
    Reject identifiers that are still the literal ``~`` wildcard after resolution.

    When the path uses ``~``, every body entity must provide concrete ``dag_id`` and
    ``dag_run_id`` values. On rejection a 400 entry is appended to ``results.errors``
    and ``False`` is returned; otherwise ``True`` is returned and ``results`` is untouched.
    """
    if "~" not in (dag_id, dag_run_id):
        return True

    message = (
        "When the path uses the ``~`` wildcard, ``dag_id`` and ``dag_run_id`` must be "
        "specified in the body for each entity."
    )
    results.errors.append({"error": message, "status_code": status.HTTP_400_BAD_REQUEST})
    return False
+
+
def _validate_path_dag_id_match(
    *,
    path_dag_id: str,
    entity_dag_id: str | None,
    dag_run_id: str,
    results: BulkActionResponse,
) -> bool:
    """
    Ensure an entity's explicit ``dag_id`` agrees with the ``dag_id`` in the path.

    The check is skipped when the path is the ``~`` wildcard or the entity gives no
    ``dag_id``. On a mismatch a 400 entry (carrying the entity's ``dag_id`` and
    ``dag_run_id``) is appended to ``results.errors`` and ``False`` is returned.
    """
    if path_dag_id == "~" or entity_dag_id is None or entity_dag_id == path_dag_id:
        return True

    message = (
        f"Entity dag_id '{entity_dag_id}' does not match path dag_id '{path_dag_id}'. "
        "Use ``~`` in the path for cross-DAG bulk operations."
    )
    error = {
        "error": message,
        "status_code": status.HTTP_400_BAD_REQUEST,
        "dag_id": entity_dag_id,
        "dag_run_id": dag_run_id,
    }
    results.errors.append(error)
    return False
+
+
class BulkDagRunService(BulkService[BulkDagRunBody]):
    """
    Service for handling bulk operations on Dag runs.

    ``dag_id`` is the Dag id from the request path. When it is ``~``, each body entity
    must name its own ``dag_id`` and ``dag_run_id``; otherwise an entity's ``dag_id``
    (if given) must equal the path value. Bulk create is rejected, bulk update changes
    ``state`` and/or ``note``, and bulk delete removes runs whose state is one of the
    ``DAGRunPatchStates`` values.
    """

    def __init__(
        self,
        session: Session,
        request: BulkBody[BulkDagRunBody],
        dag_id: str,
        dag_bag: DagBagDep,
        user: GetUserDep,
    ):
        super().__init__(session, request)
        self.dag_id = dag_id
        self.dag_bag = dag_bag
        self.user = user

    def _resolve_identifiers(self, entity: str | BulkDagRunBody) -> tuple[str, str]:
        """Return ``(dag_id, dag_run_id)`` for an entity, falling back to the path's ``dag_id``."""
        if isinstance(entity, str):
            # A bare string entity is a run id of the path's Dag.
            return self.dag_id, entity
        dag_id = entity.dag_id or self.dag_id
        return dag_id, entity.dag_run_id

    def _check_dag_authorization(
        self,
        dag_id: str,
        method: AuthMethod,
        action_name: str,
        results: BulkActionResponse,
        cache: dict[str, bool],
    ) -> bool:
        """
        Return whether the current user may use ``method`` on runs of ``dag_id``.

        On denial a 403 entry naming ``action_name`` is appended to ``results.errors``.
        ``cache`` is forwarded to ``_authorize_dag_run``, which keys it by Dag id only,
        so each caller uses a separate cache per ``method``.
        """
        if not _authorize_dag_run(
            session=self.session,
            user=self.user,
            dag_id=dag_id,
            method=method,
            cache=cache,
        ):
            results.errors.append(
                {
                    "error": f"User is not authorized to {action_name} Dag runs for DAG '{dag_id}'",
                    "status_code": status.HTTP_403_FORBIDDEN,
                }
            )
            return False
        return True

    def _fetch_dag_runs(
        self,
        keys: set[tuple[str, str]],
    ) -> tuple[dict[tuple[str, str], DagRun], set[tuple[str, str]]]:
        """
        Load the Dag runs identified by ``keys`` (``(dag_id, run_id)`` pairs) in one query.

        Returns ``(found, not_found)``: ``found`` maps each existing key to its ``DagRun``
        and ``not_found`` holds the keys with no matching row.
        """
        if not keys:
            return {}, set()
        # Filtering the two columns independently matches the cross product of the
        # requested Dag ids and run ids, so rows outside ``keys`` are discarded below.
        dag_runs = self.session.scalars(
            select(DagRun)
            .options(joinedload(DagRun.dag_model))
            .where(
                DagRun.dag_id.in_({dag_id for dag_id, _ in keys}),
                DagRun.run_id.in_({run_id for _, run_id in keys}),
            )
        ).all()
        found: dict[tuple[str, str], DagRun] = {}
        for dag_run in dag_runs:
            key = (dag_run.dag_id, dag_run.run_id)
            if key in keys:
                found[key] = dag_run
        return found, keys.difference(found)

    @staticmethod
    def _raise_if_missing(
        action_on_non_existence: BulkActionNotOnExistence,
        not_found: set[tuple[str, str]],
    ) -> None:
        """
        Raise a 404 ``HTTPException`` listing ``not_found`` when the action must fail on missing runs.

        Does nothing when ``not_found`` is empty or the action does not use
        ``BulkActionNotOnExistence.FAIL``.
        """
        if action_on_non_existence != BulkActionNotOnExistence.FAIL or not not_found:
            return
        # Sorted so the message does not depend on (hash-randomized) set iteration order.
        missing = [{"dag_id": d, "dag_run_id": r} for d, r in sorted(not_found)]
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"The Dag runs with these identifiers were not found: {missing}",
        )

    def handle_bulk_create(
        self, action: BulkCreateAction[BulkDagRunBody], results: BulkActionResponse
    ) -> None:
        """Reject bulk create with a 405 entry; Dag runs are created via the trigger endpoint."""
        results.errors.append(
            {
                "error": "Dag runs bulk create is not supported via this endpoint; use the trigger Dag run endpoint instead.",
                "status_code": status.HTTP_405_METHOD_NOT_ALLOWED,
            }
        )

    def handle_bulk_update(
        self, action: BulkUpdateAction[BulkDagRunBody], results: BulkActionResponse
    ) -> None:
        """
        Bulk update Dag runs (state and/or note).

        Entities must be objects. Each one is validated (no unresolved ``~``, ``dag_id``
        consistent with the path) and authorized with ``PUT`` before any run is loaded.
        Only fields explicitly set on the entity (intersected with ``update_mask`` when
        given) are applied, never ``dag_id``/``dag_run_id``. Each run is changed inside
        its own ``begin_nested()`` block, and a failure is reported per run.
        """
        update_mask = action.update_mask
        auth_cache: dict[str, bool] = {}
        keys: set[tuple[str, str]] = set()
        entity_map: dict[tuple[str, str], BulkDagRunBody] = {}

        for entity in action.entities:
            if isinstance(entity, str):
                results.errors.append(
                    {
                        "error": "Bulk update requires entities as objects, not strings.",
                        "status_code": status.HTTP_400_BAD_REQUEST,
                    }
                )
                continue
            dag_id, dag_run_id = self._resolve_identifiers(entity)
            if not _validate_no_wildcard_in_resolved(dag_id=dag_id, dag_run_id=dag_run_id, results=results):
                continue
            if not _validate_path_dag_id_match(
                path_dag_id=self.dag_id,
                entity_dag_id=entity.dag_id,
                dag_run_id=dag_run_id,
                results=results,
            ):
                continue
            if not self._check_dag_authorization(dag_id, "PUT", action.action.value, results, auth_cache):
                continue
            keys.add((dag_id, dag_run_id))
            # A later entity for the same run replaces an earlier one.
            entity_map[(dag_id, dag_run_id)] = entity

        try:
            found, not_found = self._fetch_dag_runs(keys)
            self._raise_if_missing(action.action_on_non_existence, not_found)

            # Sorted so the order of results does not depend on database row order.
            for key in sorted(found):
                dag_run = found[key]
                entity = entity_map[key]
                fields_to_update = entity.model_fields_set
                if update_mask:
                    fields_to_update = fields_to_update.intersection(update_mask)
                fields_to_update = fields_to_update - {"dag_id", "dag_run_id"}
                if not fields_to_update:
                    continue

                try:
                    with self.session.begin_nested():
                        dag = get_dag_for_run(self.dag_bag, dag_run, session=self.session)
                        if "state" in fields_to_update and entity.state is not None:
                            _apply_state_change(dag_run, entity.state, dag, self.session)
                        if "note" in fields_to_update:
                            # Re-fetch the run through the session before editing its note
                            # (presumably to see any changes made by the state update above).
                            refreshed = self.session.get(DagRun, dag_run.id)
                            if refreshed is not None:
                                _apply_note(refreshed, entity.note, self.user.get_id())
                except HTTPException as exc:
                    results.errors.append(
                        {
                            "error": str(exc.detail),
                            "status_code": exc.status_code,
                            "dag_id": key[0],
                            "dag_run_id": key[1],
                        }
                    )
                except Exception as exc:
                    results.errors.append(
                        {
                            "error": str(exc),
                            "status_code": status.HTTP_500_INTERNAL_SERVER_ERROR,
                            "dag_id": key[0],
                            "dag_run_id": key[1],
                        }
                    )
                else:
                    results.success.append(_format_dag_run_key(*key))
        except HTTPException as e:
            results.errors.append({"error": f"{e.detail}", "status_code": e.status_code})

    def handle_bulk_delete(
        self, action: BulkDeleteAction[BulkDagRunBody], results: BulkActionResponse
    ) -> None:
        """
        Bulk delete Dag runs.

        Entities may be run-id strings (resolved against the path ``dag_id``) or objects.
        Each one is validated and authorized with ``DELETE``. Runs whose state is not a
        ``DAGRunPatchStates`` value are reported with 409 instead of being deleted.
        """
        auth_cache: dict[str, bool] = {}
        keys: set[tuple[str, str]] = set()

        for entity in action.entities:
            dag_id, dag_run_id = self._resolve_identifiers(entity)
            entity_dag_id = entity.dag_id if isinstance(entity, BulkDagRunBody) else None
            if not _validate_no_wildcard_in_resolved(dag_id=dag_id, dag_run_id=dag_run_id, results=results):
                continue
            if not _validate_path_dag_id_match(
                path_dag_id=self.dag_id,
                entity_dag_id=entity_dag_id,
                dag_run_id=dag_run_id,
                results=results,
            ):
                continue
            if not self._check_dag_authorization(dag_id, "DELETE", action.action.value, results, auth_cache):
                continue
            keys.add((dag_id, dag_run_id))

        try:
            found, not_found = self._fetch_dag_runs(keys)
            self._raise_if_missing(action.action_on_non_existence, not_found)

            deletable_states = {s.value for s in DAGRunPatchStates}
            # Sorted so the order of results does not depend on database row order.
            for key in sorted(found):
                dag_run = found[key]
                if dag_run.state not in deletable_states:
                    results.errors.append(
                        {
                            "error": (
                                f"The DagRun with dag_id: `{dag_run.dag_id}` and run_id: `{dag_run.run_id}` "
                                f"cannot be deleted in {dag_run.state} state"
                            ),
                            "status_code": status.HTTP_409_CONFLICT,
                            "dag_id": dag_run.dag_id,
                            "dag_run_id": dag_run.run_id,
                        }
                    )
                    continue
                self.session.delete(dag_run)
                results.success.append(_format_dag_run_key(*key))
        except HTTPException as e:
            results.errors.append({"error": f"{e.detail}", "status_code": e.status_code})
+
+
+def bulk_clear_dag_runs(
+ body: BulkClearDagRunsBody,
+ dag_id: str,
+ dag_bag: DagBagDep,
+ session: Session,
+ user: GetUserDep,
+) -> BulkActionResponse:
+ """
+ Run ``dag.clear()`` for each ``(dag_id, dag_run_id)`` in ``body.runs``
within a single transaction.
+
+ Returns ``BulkActionResponse`` with per-run success keys and per-run
failure entries so that a partial
+ failure does not abort the entire batch.
+ """
+ results = BulkActionResponse()
+ auth_cache: dict[str, bool] = {}
+
+ for identifier in body.runs:
+ run_dag_id = identifier.dag_id or dag_id
+ run_id = identifier.dag_run_id
+
+ if not _validate_no_wildcard_in_resolved(dag_id=run_dag_id,
dag_run_id=run_id, results=results):
+ continue
+ if not _validate_path_dag_id_match(
+ path_dag_id=dag_id,
+ entity_dag_id=identifier.dag_id,
+ dag_run_id=run_id,
+ results=results,
+ ):
+ continue
+
+ if not _authorize_dag_run(
+ session=session, user=user, dag_id=run_dag_id, method="PUT",
cache=auth_cache
+ ):
+ results.errors.append(
+ {
+ "error": f"User is not authorized to clear Dag runs for
DAG '{run_dag_id}'",
+ "status_code": status.HTTP_403_FORBIDDEN,
+ "dag_id": run_dag_id,
+ "dag_run_id": run_id,
+ }
+ )
+ continue
+
+ dag_run = session.scalar(
+ select(DagRun)
+ .options(joinedload(DagRun.dag_model))
+ .where(DagRun.dag_id == run_dag_id, DagRun.run_id == run_id)
Review Comment:
Done.
Extracted `_fetch_dag_runs` to a module-level helper.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]