Copilot commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r3066481777


##########
airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py:
##########
@@ -83,8 +83,12 @@
 from airflow.api_fastapi.core_api.security import GetUserDep, 
ReadableTIFilterDep, requires_access_dag
 from airflow.api_fastapi.core_api.services.public.task_instances import (
     BulkTaskInstanceService,
+    _collect_unique_tis,
+    _get_task_group_task_instances,
+    _patch_task_group_state,
     _patch_task_instance_note,
     _patch_task_instance_state,
+    _patch_ti_group_validate_request,
     _patch_ti_validate_request,
 )

Review Comment:
   This route module imports multiple underscore-prefixed helpers from the 
service module. Since these are now used cross-module, they’re effectively part 
of the internal API surface and the underscore convention becomes misleading. 
Consider promoting these to public helpers (drop the `_` prefix), or 
encapsulate them behind a single public service function (e.g., 
`TaskGroupInstanceService.patch(...)`) to keep routing logic thinner and reduce 
coupling.



##########
airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py:
##########
@@ -864,6 +868,133 @@ def _collect_relatives(run_id: str, direction: 
Literal["upstream", "downstream"]
     )
 
 
+@task_instances_router.patch(
+    "/dagRuns/{dag_run_id}/taskGroupInstances/{group_id}",
+    responses=create_openapi_http_exception_doc(
+        [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST, 
status.HTTP_409_CONFLICT],
+    ),
+    dependencies=[
+        Depends(action_logging()),
+        Depends(requires_access_dag(method="PUT", 
access_entity=DagAccessEntity.TASK_INSTANCE)),
+    ],
+    operation_id="patch_task_group_instances",
+)
+def patch_task_group_instances(
+    dag_id: str,
+    dag_run_id: str,
+    group_id: str,
+    dag_bag: DagBagDep,
+    body: PatchTaskInstanceBody,
+    session: SessionDep,
+    user: GetUserDep,
+    update_mask: list[str] | None = Query(None),
+) -> TaskInstanceCollectionResponse:
+    """Update the state of all task instances in a task group."""
+    dag, tis, data = _patch_ti_group_validate_request(
+        dag_id, dag_run_id, group_id, dag_bag, body, session, update_mask
+    )
+    affected_tis_dict: dict[tuple[str, str, str, int], TI] = {}
+
+    for key, _ in data.items():
+        if key == "new_state":
+            updated_tis = _patch_task_group_state(
+                tis=tis,
+                dag_run_id=dag_run_id,
+                dag=dag,
+                body=body,
+                data=data,
+                session=session,
+            )

Review Comment:
   `data` can include `"new_state": None` if the client explicitly sends 
`{"new_state": null}` (or uses `update_mask=["new_state"]` with a null value). 
In that case this route will still enter the `"new_state"` branch and 
`_patch_task_group_state()` will attempt to set `state=None`, which will fail 
downstream. Add explicit validation that, if `"new_state"` is present in 
`fields_to_update` / `update_mask`, then `body.new_state` must be non-null 
(raise a 422 `RequestValidationError` / `HTTPException`). Alternatively, 
exclude `"new_state"` from `data` when its value is `None` so that the route 
consistently becomes a no-op (like the dry-run path, which checks 
`if body.new_state:`).



##########
airflow-core/src/airflow/serialization/definitions/dag.py:
##########
@@ -761,6 +761,100 @@ def set_task_instance_state(
             
subset.clear(exclude_run_ids=frozenset(session.scalars(exclude_run_id_stmt)), 
**clear_kwargs)
         return altered
 
+    @provide_session
+    def set_multiple_task_instances_state(
+        self,
+        *,
+        task_ids_with_map_indexes: list[tuple[str, int]],
+        run_id: str | None = None,
+        state: TaskInstanceState,

Review Comment:
   `run_id` is annotated as optional, but this method unconditionally does a 
`select(DagRun.id, DagRun.logical_date).where(DagRun.run_id == run_id, ...)` 
followed by `.one()`, which will raise `NoResultFound` if `run_id` is `None` or 
no matching DagRun exists. Since `run_id` is effectively required for 
correctness, make it a required `str` parameter (remove the default) or add an 
explicit guard that raises a clear `ValueError`/`AirflowException` when 
`run_id` is not provided.



##########
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##########
@@ -101,14 +112,72 @@ def _patch_ti_validate_request(
     return dag, list(tis), body.model_dump(include=fields_to_update, 
by_alias=True)
 
 
+def _get_task_group_task_instances(
+    dag_id: str,
+    dag_run_id: str,
+    task_group_id: str,
+    dag: SerializedDAG,
+    session: Session,
+) -> list[TI]:
+    """Get all task instances in a task group for a specific DAG run."""
+    task_group = dag.task_group_dict.get(task_group_id)
+    if not task_group:
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND, f"Task group '{task_group_id}' not 
found in DAG '{dag_id}'"
+        )
+
+    task_ids = [task.task_id for task in task_group.iter_tasks()]
+
+    query = (
+        select(TI)
+        .where(
+            TI.dag_id == dag_id,
+            TI.run_id == dag_run_id,
+            TI.task_id.in_(task_ids),
+        )
+        .join(TI.dag_run)
+        .options(joinedload(TI.rendered_task_instance_fields))
+        .order_by(TI.task_id, TI.map_index)
+    )
+
+    group_tis = list(session.scalars(query).all())
+
+    return group_tis
+
+
+def _patch_ti_group_validate_request(
+    dag_id: str,
+    dag_run_id: str,
+    task_group_id: str,
+    dag_bag: DagBagDep,
+    body: PatchTaskInstanceBody,
+    session: SessionDep,
+    update_mask: list[str] | None = Query(None),
+) -> tuple[SerializedDAG, list[TI], dict]:
+    """Validate and prepare data for task group patch request."""
+    dag = get_latest_version_of_dag(dag_bag, dag_id, session)
+    tis = _get_task_group_task_instances(dag_id, dag_run_id, task_group_id, 
dag, session)
+
+    fields_to_update = body.model_fields_set
+    if update_mask:
+        fields_to_update = fields_to_update.intersection(update_mask)
+    else:
+        try:
+            PatchTaskInstanceBody.model_validate(body)
+        except ValidationError as e:
+            raise RequestValidationError(errors=e.errors())

Review Comment:
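   `_patch_ti_group_validate_request()` is not a FastAPI endpoint, so using 
`Query(None)` in its signature is misleading and couples service-layer code to 
FastAPI parameter utilities. It is also a latent bug: a caller that omits 
`update_mask` receives the `Query(...)` marker object rather than `None`, so the 
`if update_mask:` check would not behave as intended. Prefer a plain default of 
`None`. Also, `PatchTaskInstanceBody.model_validate(body)` is redundant here 
because `body` is already a validated model instance (with Pydantic's default 
`revalidate_instances='never'`, passing an existing instance simply returns it). 
If you need additional semantic checks (e.g., requiring `new_state` when 
present), implement explicit validation for the relevant fields instead of 
re-validating the model.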



##########
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##########
@@ -101,14 +112,72 @@ def _patch_ti_validate_request(
     return dag, list(tis), body.model_dump(include=fields_to_update, 
by_alias=True)
 
 
+def _get_task_group_task_instances(
+    dag_id: str,
+    dag_run_id: str,
+    task_group_id: str,
+    dag: SerializedDAG,
+    session: Session,
+) -> list[TI]:
+    """Get all task instances in a task group for a specific DAG run."""
+    task_group = dag.task_group_dict.get(task_group_id)
+    if not task_group:
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND, f"Task group '{task_group_id}' not 
found in DAG '{dag_id}'"
+        )
+
+    task_ids = [task.task_id for task in task_group.iter_tasks()]
+
+    query = (
+        select(TI)
+        .where(
+            TI.dag_id == dag_id,
+            TI.run_id == dag_run_id,
+            TI.task_id.in_(task_ids),
+        )
+        .join(TI.dag_run)

Review Comment:
   The `.join(TI.dag_run)` appears unused here since no columns/filters from 
`DagRun` are referenced (and `TI.dag_id` + `TI.run_id` already constrain the 
result). Dropping the join would simplify the generated SQL and may reduce 
planner work, especially for large groups; if the join is needed for a subtle 
reason (e.g., enforcing an inner join to exclude orphaned TIs), consider adding 
an inline comment explaining that.
   ```suggestion
   
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to