jason810496 commented on code in PR #60161:
URL: https://github.com/apache/airflow/pull/60161#discussion_r2702205865


##########
airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py:
##########
@@ -921,11 +921,17 @@ def bulk_task_instances(
     dag_bag: DagBagDep,
     dag_run_id: str,
     user: GetUserDep,
-) -> BulkResponse:
+    dry_run: bool = Query(
+        False, description="If True, return affected task instances without making changes"
+    ),
+) -> BulkResponse | TaskInstanceCollectionResponse:

Review Comment:
   > We can't do this because now the client of the API doesn't know what type of structure he's getting, and it will force a lot of if / else check code on the client side. (the API here).
   
   We still seem to have the problem Pierre described here. Instead of adding a `dry_run` query parameter, let's add a separate `bulk_task_instances_dry_run` route. That way the response type of the existing route stays the same and client behavior changes as little as possible.
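A minimal sketch of the split, assuming simplified dataclass stand-ins for the response models and a hypothetical shared `_resolve_affected_tis` helper (the real routes take FastAPI dependencies and Pydantic bodies). Each route declares exactly one return type, so a client never has to branch on the shape of the response:

```python
from __future__ import annotations

from dataclasses import dataclass, field


# Simplified stand-ins for the real Pydantic response models (assumption).
@dataclass
class TaskInstanceCollectionResponse:
    task_instances: list[str]
    total_entries: int


@dataclass
class BulkResponse:
    updated: list[str] = field(default_factory=list)


def _resolve_affected_tis(entities: list[str]) -> list[str]:
    # Hypothetical shared helper: both routes compute the same affected set.
    return sorted(set(entities))


def bulk_task_instances(entities: list[str]) -> BulkResponse:
    # Existing route: applies the changes and always returns BulkResponse.
    affected = _resolve_affected_tis(entities)
    # ... persisting the changes is omitted in this sketch ...
    return BulkResponse(updated=affected)


def bulk_task_instances_dry_run(entities: list[str]) -> TaskInstanceCollectionResponse:
    # New route: no side effects, always returns TaskInstanceCollectionResponse.
    affected = _resolve_affected_tis(entities)
    return TaskInstanceCollectionResponse(task_instances=affected, total_entries=len(affected))
```

Because the preview logic lives in its own function, the signature and response schema of `bulk_task_instances` stay unchanged for existing clients.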



##########
airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py:
##########
@@ -866,33 +874,55 @@ def _collect_relatives(run_id: str, direction: Literal["upstream", "downstream"]
 def patch_task_instance_dry_run(
     dag_id: str,
     dag_run_id: str,
-    task_id: str,
     dag_bag: DagBagDep,
     body: PatchTaskInstanceBody,
     session: SessionDep,
+    task_id: str | None = None,
+    task_group_id: str | None = None,
     map_index: int | None = None,
     update_mask: list[str] | None = Query(None),
 ) -> TaskInstanceCollectionResponse:
     """Update a task instance dry_run mode."""
     dag, tis, data = _patch_ti_validate_request(
-        dag_id, dag_run_id, task_id, dag_bag, body, session, map_index, update_mask
+        dag_id, dag_run_id, task_id, dag_bag, body, session, map_index, update_mask, task_group_id
     )
 
     if data.get("new_state"):
-        tis = (
-            dag.set_task_instance_state(
-                task_id=task_id,
-                run_id=dag_run_id,
-                map_indexes=[map_index] if map_index is not None else None,
-                state=data["new_state"],
-                upstream=body.include_upstream,
-                downstream=body.include_downstream,
-                future=body.include_future,
-                past=body.include_past,
-                commit=False,
-                session=session,
+        # Use dict to track unique affected task instances
+        affected_tis_dict: dict[tuple[str, str, str, int], TI] = {}
+
+        # Iterate over all task instances - works for both single TI and task groups
+        # since _patch_ti_validate_request already returns the appropriate TIs
+        for ti in tis:
+            affected_tis = (
+                dag.set_task_instance_state(
+                    task_id=ti.task_id,
+                    run_id=dag_run_id,
+                    map_indexes=[ti.map_index] if ti.map_index is not None else None,
+                    state=data["new_state"],
+                    upstream=body.include_upstream or False,
+                    downstream=body.include_downstream or False,
+                    future=body.include_future or False,
+                    past=body.include_past or False,
+                    commit=False,
+                    session=session,
+                )
+                or []
             )
-            or []
+
+            # Add unique task instances
+            for affected_ti in affected_tis:
+                ti_key = (
+                    affected_ti.dag_id,
+                    affected_ti.run_id,
+                    affected_ti.task_id,
+                    affected_ti.map_index if affected_ti.map_index is not None else -1,
+                )
+                affected_tis_dict[ti_key] = affected_ti

Review Comment:
   We could extract this into another shared helper, e.g. `_collect_unique_tis(affected_tis_dict: dict[tuple[str, str, str, int], TI], affected_tis: Iterable[TI])`.



##########
airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py:
##########
@@ -993,6 +1057,12 @@ def patch_task_instance(
                 update_mask=update_mask,
             )
 
+    if affected_tis_dict:
+        return TaskInstanceCollectionResponse(
+            task_instances=[TaskInstanceResponse.model_validate(ti) for ti in affected_tis_dict.values()],
+            total_entries=len(affected_tis_dict),
+        )
+

Review Comment:
   
   ```suggestion
       return TaskInstanceCollectionResponse(
           task_instances=[TaskInstanceResponse.model_validate(ti) for ti in affected_tis_dict.values()],
           total_entries=len(affected_tis_dict),
       )
   
   ```
   
   Once the patch logic is generalized to cover `note` as well, we can build the response directly from `affected_tis_dict` and drop the `if` guard.
   



##########
airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py:
##########
@@ -993,6 +1057,12 @@ def patch_task_instance(
                 update_mask=update_mask,
             )

Review Comment:
   Shouldn't we update the `affected_tis_dict` for note as well?



-- 
This is an automated message from the Apache Git Service.