jason810496 commented on code in PR #60161:
URL: https://github.com/apache/airflow/pull/60161#discussion_r2701036546


##########
airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py:
##########
@@ -993,6 +1075,28 @@ def patch_task_instance(
                 update_mask=update_mask,
             )
 
+    # For task groups, refresh affected TIs from the database
+    if task_group_id is not None and affected_tis_dict:
+        ti_keys_list = list(affected_tis_dict.keys())
+        refreshed_tis = session.scalars(
+            select(TI)
+            .where(tuple_(TI.dag_id, TI.run_id, TI.task_id, 
TI.map_index).in_(ti_keys_list))
+            .options(joinedload(TI.rendered_task_instance_fields))
+        ).all()
+        # Update dict with refreshed TIs
+        for refreshed_ti in refreshed_tis:
+            ti_key = (
+                refreshed_ti.dag_id,
+                refreshed_ti.run_id,
+                refreshed_ti.task_id,
+                refreshed_ti.map_index if refreshed_ti.map_index is not None 
else -1,
+            )
+            affected_tis_dict[ti_key] = refreshed_ti
+        return TaskInstanceCollectionResponse(
+            task_instances=[TaskInstanceResponse.model_validate(ti) for ti in 
affected_tis_dict.values()],
+            total_entries=len(affected_tis_dict),
+        )
+

Review Comment:
   Same here, all the handling can be removed.



##########
airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py:
##########
@@ -955,35 +997,75 @@ def patch_task_instance(
     user: GetUserDep,
     session: SessionDep,
     map_index: int | None = None,
+    task_group_id: str | None = Query(None),
     update_mask: list[str] | None = Query(None),
 ) -> TaskInstanceCollectionResponse:
     """Update a task instance."""
     dag, tis, data = _patch_ti_validate_request(
-        dag_id, dag_run_id, task_id, dag_bag, body, session, map_index, 
update_mask
+        dag_id, dag_run_id, task_id, dag_bag, body, session, map_index, 
update_mask, task_group_id
     )
 
+    # Collect all affected task instances (including 
upstream/downstream/future/past)
+    # Use dict to track unique task instances by key
+    affected_tis_dict: dict[tuple[str, str, str, int], TI] = {}
+
     for key, _ in data.items():
         if key == "new_state":
-            # Create BulkTaskInstanceBody object with map_index field
-            bulk_ti_body = BulkTaskInstanceBody(
-                task_id=task_id,
-                map_index=map_index,
-                new_state=body.new_state,
-                note=body.note,
-                include_upstream=body.include_upstream,
-                include_downstream=body.include_downstream,
-                include_future=body.include_future,
-                include_past=body.include_past,
-            )
+            if task_group_id is not None:
+                # For task group: iterate over each task instance in the group
+                for ti in tis:
+                    # Create BulkTaskInstanceBody object with map_index field
+                    bulk_ti_body = BulkTaskInstanceBody(
+                        task_id=ti.task_id,
+                        map_index=ti.map_index,
+                        new_state=body.new_state,
+                        note=body.note,
+                        include_upstream=body.include_upstream,
+                        include_downstream=body.include_downstream,

Review Comment:
   Same here: we already generalized the logic for getting TIs in 
`_patch_ti_validate_request`, so we can remove all the task-group-specific logic.



##########
airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py:
##########
@@ -871,14 +871,50 @@ def patch_task_instance_dry_run(
     body: PatchTaskInstanceBody,
     session: SessionDep,
     map_index: int | None = None,
+    task_group_id: str | None = Query(None),
     update_mask: list[str] | None = Query(None),
 ) -> TaskInstanceCollectionResponse:
     """Update a task instance dry_run mode."""
     dag, tis, data = _patch_ti_validate_request(
-        dag_id, dag_run_id, task_id, dag_bag, body, session, map_index, 
update_mask
+        dag_id, dag_run_id, task_id, dag_bag, body, session, map_index, 
update_mask, task_group_id
     )
 
     if data.get("new_state"):
+        if task_group_id is not None:
+            # For task group: iterate over each task instance and collect 
affected TIs
+            affected_tis_dict: dict[tuple[str, str, str, int], TI] = {}
+            for ti in tis:
+                affected_tis = (
+                    dag.set_task_instance_state(
+                        task_id=ti.task_id,
+                        run_id=dag_run_id,
+                        map_indexes=[ti.map_index] if ti.map_index is not None 
else None,
+                        state=data["new_state"],
+                        upstream=body.include_upstream or False,
+                        downstream=body.include_downstream or False,
+                        future=body.include_future or False,
+                        past=body.include_past or False,
+                        commit=False,
+                        session=session,
+                    )
+                    or []
+                )
+
+                # Add unique task instances
+                for affected_ti in affected_tis:
+                    ti_key = (
+                        affected_ti.dag_id,
+                        affected_ti.run_id,
+                        affected_ti.task_id,
+                        affected_ti.map_index if affected_ti.map_index is not 
None else -1,
+                    )
+                    affected_tis_dict[ti_key] = affected_ti
+
+            return TaskInstanceCollectionResponse(
+                task_instances=[TaskInstanceResponse.model_validate(ti) for ti 
in affected_tis_dict.values()],
+                total_entries=len(affected_tis_dict),
+            )
+        # For regular task instance: use original behavior

Review Comment:
   The task group logic is already handled in `_patch_ti_validate_request`, so 
we can remove this whole part: the task group's TIs can be handled as regular 
TIs. 
   ```suggestion
   ```



##########
airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py:
##########
@@ -871,14 +871,50 @@ def patch_task_instance_dry_run(
     body: PatchTaskInstanceBody,
     session: SessionDep,
     map_index: int | None = None,
+    task_group_id: str | None = Query(None),

Review Comment:
   We need to define the route decorator for `task_group_id` before line 866 as:
   ```python
   @task_instances_router.patch(
       task_instances_prefix + "/{task_group_id}/dry_run",
       responses=create_openapi_http_exception_doc(
           [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST],
       ),
       dependencies=[Depends(requires_access_dag(method="PUT", 
access_entity=DagAccessEntity.TASK_INSTANCE))],
       operation_id="patch_task_group_dry_run",
   )
   ```
   
   The same applies to all the other routes.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to