arjav1528 commented on code in PR #60161:
URL: https://github.com/apache/airflow/pull/60161#discussion_r2679276482


##########
airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py:
##########
@@ -1008,6 +1018,248 @@ def patch_task_instance(
     )
 
 
@task_instances_router.patch(
    task_instances_prefix + "/groups/{task_group_id}/dry_run",
    responses=create_openapi_http_exception_doc(
        [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST],
    ),
    dependencies=[Depends(requires_access_dag(method="PUT", access_entity=DagAccessEntity.TASK_INSTANCE))],
    operation_id="patch_task_group_dry_run",
)
def patch_task_group_dry_run(
    dag_id: str,
    dag_run_id: str,
    task_group_id: str,
    dag_bag: DagBagDep,
    body: PatchTaskInstanceBody,
    session: SessionDep,
    update_mask: list[str] | None = Query(None),
) -> TaskInstanceCollectionResponse:
    """
    Update task instances in a task group (dry_run mode).

    Looks up every task instance of ``task_group_id`` in run ``dag_run_id``. If the (masked)
    body carries a ``new_state``, the state change is simulated with ``commit=False``,
    honouring ``include_upstream``/``include_downstream``/``include_future``/``include_past``.
    The de-duplicated task instances it would touch are returned. Without a ``new_state``,
    the group's own task instances are returned.

    :raises HTTPException: 404 if no task instances exist for the group in this run.
    :raises RequestValidationError: if the body fails validation when no ``update_mask`` is given.
    """
    dag = get_latest_version_of_dag(dag_bag, dag_id, session)

    # Get all task instances in the task group
    group_tis = _get_task_group_task_instances(
        dag_id=dag_id,
        dag_run_id=dag_run_id,
        task_group_id=task_group_id,
        dag=dag,
        session=session,
    )

    if not group_tis:
        raise HTTPException(
            status.HTTP_404_NOT_FOUND,
            f"No task instances found for task group '{task_group_id}' in DAG '{dag_id}' and run '{dag_run_id}'",
        )

    # Validate request body
    fields_to_update = body.model_fields_set
    if update_mask:
        fields_to_update = fields_to_update.intersection(update_mask)
    else:
        try:
            PatchTaskInstanceBody.model_validate(body)
        except ValidationError as e:
            raise RequestValidationError(errors=e.errors())

    data = body.model_dump(include=fields_to_update, by_alias=True)

    new_state = data.get("new_state")
    if not new_state:
        # No state change to simulate, so the group's own task instances are the result.
        return TaskInstanceCollectionResponse(
            task_instances=[TaskInstanceResponse.model_validate(ti) for ti in group_tis],
            total_entries=len(group_tis),
        )

    # Batch map indexes per task. This makes one simulated state change per task instead of
    # one per mapped task instance. A task instance whose map_index is None makes that task's
    # entry None, mirroring the per-TI call it replaces. (Presumably None means "no map-index
    # filter" in set_task_instance_state; confirm.)
    map_indexes_by_task: dict[str, list[int] | None] = {}
    for ti in group_tis:
        if ti.map_index is None:
            map_indexes_by_task[ti.task_id] = None
            continue
        indexes = map_indexes_by_task.setdefault(ti.task_id, [])
        if indexes is not None:
            indexes.append(ti.map_index)

    # Collect all affected task instances (including upstream/downstream/future/past).
    # Relatives of different tasks can overlap, so de-duplicate on the full TI key.
    all_affected_tis: list[TI] = []
    seen_ti_keys: set[tuple[str, str, str, int]] = set()

    for task_id, map_indexes in map_indexes_by_task.items():
        affected_tis = (
            dag.set_task_instance_state(
                task_id=task_id,
                run_id=dag_run_id,
                map_indexes=map_indexes,
                state=new_state,
                upstream=body.include_upstream or False,
                downstream=body.include_downstream or False,
                future=body.include_future or False,
                past=body.include_past or False,
                commit=False,
                session=session,
            )
            or []
        )

        for affected_ti in affected_tis:
            ti_key = (
                affected_ti.dag_id,
                affected_ti.run_id,
                affected_ti.task_id,
                affected_ti.map_index if affected_ti.map_index is not None else -1,
            )
            if ti_key not in seen_ti_keys:
                seen_ti_keys.add(ti_key)
                all_affected_tis.append(affected_ti)

    return TaskInstanceCollectionResponse(
        task_instances=[TaskInstanceResponse.model_validate(ti) for ti in all_affected_tis],
        total_entries=len(all_affected_tis),
    )
+
+
+@task_instances_router.patch(
+    task_instances_prefix + "/groups/{task_group_id}",
+    responses=create_openapi_http_exception_doc(
+        [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST, 
status.HTTP_409_CONFLICT],
+    ),
+    dependencies=[
+        Depends(action_logging()),
+        Depends(requires_access_dag(method="PUT", 
access_entity=DagAccessEntity.TASK_INSTANCE)),
+    ],
+    operation_id="patch_task_group",
+)
+def patch_task_group(
+    dag_id: str,
+    dag_run_id: str,
+    task_group_id: str,
+    dag_bag: DagBagDep,
+    body: PatchTaskInstanceBody,
+    user: GetUserDep,
+    session: SessionDep,
+    update_mask: list[str] | None = Query(None),
+) -> TaskInstanceCollectionResponse:
+    """Update task instances in a task group."""
+    dag = get_latest_version_of_dag(dag_bag, dag_id, session)
+
+    # Get all task instances in the task group
+    group_tis = _get_task_group_task_instances(
+        dag_id=dag_id,
+        dag_run_id=dag_run_id,
+        task_group_id=task_group_id,
+        dag=dag,
+        session=session,
+    )
+
+    if not group_tis:
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND,
+            f"No task instances found for task group '{task_group_id}' in DAG 
'{dag_id}' and run '{dag_run_id}'",
+        )
+
+    # Validate request body
+    fields_to_update = body.model_fields_set
+    if update_mask:
+        fields_to_update = fields_to_update.intersection(update_mask)
+    else:
+        try:
+            PatchTaskInstanceBody.model_validate(body)
+        except ValidationError as e:
+            raise RequestValidationError(errors=e.errors())
+
+    data = body.model_dump(include=fields_to_update, by_alias=True)
+
+    # Collect all affected task instances (including 
upstream/downstream/future/past)
+    all_affected_tis: list[TI] = []
+    seen_ti_keys: set[tuple[str, str, str, int]] = set()
+
+    # Process each task in the group
+    for ti in group_tis:
+        # Update state if requested
+        if data.get("new_state"):
+            map_indexes = [ti.map_index] if ti.map_index is not None else None
+
+            # Update state and get all affected TIs (including 
upstream/downstream/future/past)
+            updated_tis = dag.set_task_instance_state(
+                task_id=ti.task_id,
+                run_id=dag_run_id,
+                map_indexes=map_indexes,
+                state=data["new_state"],
+                upstream=body.include_upstream or False,
+                downstream=body.include_downstream or False,
+                future=body.include_future or False,
+                past=body.include_past or False,
+                commit=True,
+                session=session,
+            )
+
+            if not updated_tis:
+                raise HTTPException(
+                    status.HTTP_409_CONFLICT,
+                    f"Task id {ti.task_id} is already in {data['new_state']} 
state",
+                )
+
+            # Track unique affected TIs and trigger listeners
+            for updated_ti in updated_tis:
+                ti_key = (
+                    updated_ti.dag_id,
+                    updated_ti.run_id,
+                    updated_ti.task_id,
+                    updated_ti.map_index if updated_ti.map_index is not None 
else -1,
+                )
+                if ti_key not in seen_ti_keys:
+                    seen_ti_keys.add(ti_key)
+                    all_affected_tis.append(updated_ti)
+
+                    # Trigger listeners
+                    try:
+                        if data["new_state"] == TaskInstanceState.SUCCESS:
+                            
get_listener_manager().hook.on_task_instance_success(
+                                previous_state=None, task_instance=updated_ti
+                            )
+                        elif data["new_state"] == TaskInstanceState.FAILED:
+                            
get_listener_manager().hook.on_task_instance_failed(
+                                previous_state=None,
+                                task_instance=updated_ti,
+                                error=f"TaskInstance's state was manually set 
to `{TaskInstanceState.FAILED}`.",
+                            )
+                    except Exception:
+                        log.exception("error calling listener")
+
+        # Update note if requested
+        if data.get("note") is not None:
+            _patch_task_instance_note(
+                task_instance_body=body,
+                tis=[ti],
+                user=user,
+                update_mask=update_mask,
+            )
+
+    # If we didn't collect affected TIs from state changes, use the group TIs
+    # (which may have had notes updated)
+    if not all_affected_tis:
+        all_affected_tis = group_tis
+    else:
+        # Refresh the affected TIs from the database to get the latest state 
and notes
+        ti_keys_list = list(seen_ti_keys)
+        refreshed_tis = session.scalars(
+            select(TI)
+            .where(tuple_(TI.dag_id, TI.run_id, TI.task_id, 
TI.map_index).in_(ti_keys_list))
+            .options(joinedload(TI.rendered_task_instance_fields))
+        ).all()
+        all_affected_tis = list(refreshed_tis)
+
+        # Also include group TIs that had notes updated but weren't in the 
state change list
+        # (to avoid duplicates, only add TIs not already in seen_ti_keys)
+        if data.get("note") is not None:
+            for ti in group_tis:
+                ti_key = (
+                    ti.dag_id,
+                    ti.run_id,
+                    ti.task_id,
+                    ti.map_index if ti.map_index is not None else -1,
+                )
+                if ti_key not in seen_ti_keys:
+                    seen_ti_keys.add(ti_key)
+                    all_affected_tis.append(ti)

Review Comment:
   @jason810496 I have applied all the requested changes. Please review them and merge at 
your convenience once the CI checks pass.
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to