Re: [PR] feat: implement patching of task group instances in API [airflow]

2026-04-21 Thread via GitHub


OscarLigthart commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r3116700513


##
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##
@@ -101,14 +112,72 @@ def _patch_ti_validate_request(
     return dag, list(tis), body.model_dump(include=fields_to_update, by_alias=True)
 
 
+def _get_task_group_task_instances(
+    dag_id: str,
+    dag_run_id: str,
+    task_group_id: str,
+    dag: SerializedDAG,
+    session: Session,
+) -> list[TI]:
+    """Get all task instances in a task group for a specific DAG run."""
+    task_group = dag.task_group_dict.get(task_group_id)
+    if not task_group:
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND, f"Task group '{task_group_id}' not found in DAG '{dag_id}'"
+        )
+
+    task_ids = [task.task_id for task in task_group.iter_tasks()]
+
+    query = (
+        select(TI)
+        .where(
+            TI.dag_id == dag_id,
+            TI.run_id == dag_run_id,
+            TI.task_id.in_(task_ids),
+        )
+        .join(TI.dag_run)
+        .options(joinedload(TI.rendered_task_instance_fields))
+        .order_by(TI.task_id, TI.map_index)
+    )
+
+    group_tis = list(session.scalars(query).all())
+
+    return group_tis
+
+
+def _patch_ti_group_validate_request(
+    dag_id: str,
+    dag_run_id: str,
+    task_group_id: str,
+    dag_bag: DagBagDep,
+    body: PatchTaskInstanceBody,
+    session: SessionDep,
+    update_mask: list[str] | None = Query(None),
+) -> tuple[SerializedDAG, list[TI], dict]:
+    """Validate and prepare data for task group patch request."""
+    dag = get_latest_version_of_dag(dag_bag, dag_id, session)
+    tis = _get_task_group_task_instances(dag_id, dag_run_id, task_group_id, dag, session)
+
+    fields_to_update = body.model_fields_set
+    if update_mask:
+        fields_to_update = fields_to_update.intersection(update_mask)
+    else:
+        try:
+            PatchTaskInstanceBody.model_validate(body)
+        except ValidationError as e:
+            raise RequestValidationError(errors=e.errors())
+
+    return dag, tis, body.model_dump(include=fields_to_update, by_alias=True)

Review Comment:
   I took it out of both to prevent duplicate logic, but I'm not entirely sure if that's what you meant. Let me know 👍



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]



Re: [PR] feat: implement patching of task group instances in API [airflow]

2026-04-21 Thread via GitHub


OscarLigthart commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r3116571531


##
airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml:
##
@@ -7348,6 +7348,166 @@ paths:
 application/json:
   schema:
 $ref: '#/components/schemas/HTTPValidationError'
+  /api/v2/dags/{dag_id}/dagRuns/{dag_run_id}/taskGroupInstances/{group_id}:
+patch:
+  tags:
+  - Task Instance
+  summary: Patch Task Group Instances
+  description: Update the state of all task instances in a task group.
+  operationId: patch_task_group_instances

Review Comment:
   Should be good, the 409 is only registered for the non-dry-run path, since no changes are committed during a dry run.






Re: [PR] feat: implement patching of task group instances in API [airflow]

2026-04-20 Thread via GitHub


pierrejeambrun commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r3111265585


##
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##
@@ -101,14 +112,72 @@ def _patch_ti_validate_request(
     return dag, list(tis), body.model_dump(include=fields_to_update, by_alias=True)
 
 
+def _get_task_group_task_instances(
+    dag_id: str,
+    dag_run_id: str,
+    task_group_id: str,
+    dag: SerializedDAG,
+    session: Session,
+) -> list[TI]:
+    """Get all task instances in a task group for a specific DAG run."""
+    task_group = dag.task_group_dict.get(task_group_id)
+    if not task_group:
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND, f"Task group '{task_group_id}' not found in DAG '{dag_id}'"
+        )
+
+    task_ids = [task.task_id for task in task_group.iter_tasks()]
+
+    query = (
+        select(TI)
+        .where(
+            TI.dag_id == dag_id,
+            TI.run_id == dag_run_id,
+            TI.task_id.in_(task_ids),
+        )
+        .join(TI.dag_run)
+        .options(joinedload(TI.rendered_task_instance_fields))
+        .order_by(TI.task_id, TI.map_index)
+    )
+
+    group_tis = list(session.scalars(query).all())
+
+    return group_tis
+
+
+def _patch_ti_group_validate_request(
+    dag_id: str,
+    dag_run_id: str,
+    task_group_id: str,
+    dag_bag: DagBagDep,
+    body: PatchTaskInstanceBody,
+    session: SessionDep,
+    update_mask: list[str] | None = Query(None),

Review Comment:
   ```suggestion
   update_mask: list[str] | None = None,
   ```






Re: [PR] feat: implement patching of task group instances in API [airflow]

2026-04-20 Thread via GitHub


pierrejeambrun commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r3111265596


##
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##
@@ -145,6 +214,53 @@ def _patch_task_instance_state(
         except Exception:
             log.exception("error calling listener")
 
+    return updated_tis
+
+
+def _patch_task_group_state(
+    group_id: str,
+    dag_run_id: str,
+    dag: SerializedDAG,
+    body: PatchTaskInstanceBody,
+    data: dict,
+    *,
+    session: Session,
+) -> list[TI]:
+    """Update the state of all task instances in a task group."""
+    updated_tis = dag.set_task_group_state(
+        group_id=group_id,
+        run_id=dag_run_id,
+        state=data["new_state"],
+        upstream=body.include_upstream,
+        downstream=body.include_downstream,
+        future=body.include_future,
+        past=body.include_past,
+        commit=True,

Review Comment:
   Nit: `commit=True` is hardcoded here but the helper could easily take it as a param. Not critical, but would make this reusable for a dry-run path later without duplicating the function.






Re: [PR] feat: implement patching of task group instances in API [airflow]

2026-04-20 Thread via GitHub


pierrejeambrun commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r3111265581


##
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##
@@ -101,14 +112,72 @@ def _patch_ti_validate_request(
     return dag, list(tis), body.model_dump(include=fields_to_update, by_alias=True)
 
 
+def _get_task_group_task_instances(
+    dag_id: str,
+    dag_run_id: str,
+    task_group_id: str,
+    dag: SerializedDAG,
+    session: Session,
+) -> list[TI]:
+    """Get all task instances in a task group for a specific DAG run."""
+    task_group = dag.task_group_dict.get(task_group_id)
+    if not task_group:
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND, f"Task group '{task_group_id}' not found in DAG '{dag_id}'"
+        )
+
+    task_ids = [task.task_id for task in task_group.iter_tasks()]
+
+    query = (
+        select(TI)
+        .where(
+            TI.dag_id == dag_id,
+            TI.run_id == dag_run_id,
+            TI.task_id.in_(task_ids),
+        )
+        .join(TI.dag_run)
+        .options(joinedload(TI.rendered_task_instance_fields))

Review Comment:
   `joinedload(TI.rendered_task_instance_fields)` is applied here AND re-queried with `populate_existing=True` in the route after `set_task_group_state` returns. One of them is wasted — probably this one since the route always re-queries. Can we drop it here?



##
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##
@@ -101,14 +112,72 @@ def _patch_ti_validate_request(
     return dag, list(tis), body.model_dump(include=fields_to_update, by_alias=True)
 
 
+def _get_task_group_task_instances(
+    dag_id: str,
+    dag_run_id: str,
+    task_group_id: str,
+    dag: SerializedDAG,
+    session: Session,
+) -> list[TI]:
+    """Get all task instances in a task group for a specific DAG run."""
+    task_group = dag.task_group_dict.get(task_group_id)
+    if not task_group:
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND, f"Task group '{task_group_id}' not found in DAG '{dag_id}'"
+        )
+
+    task_ids = [task.task_id for task in task_group.iter_tasks()]
+
+    query = (
+        select(TI)
+        .where(
+            TI.dag_id == dag_id,
+            TI.run_id == dag_run_id,
+            TI.task_id.in_(task_ids),
+        )
+        .join(TI.dag_run)

Review Comment:
   Is this join needed? The `dag_id` + `run_id` filters already constrain the result. I see `_patch_ti_validate_request` does the same thing — pre-existing but worth cleaning up in both places while we're here.
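
   The claim is easy to check with a toy schema. This sketch assumes SQLAlchemy is available; `DagRun` and `TI` below are simplified stand-ins for the real models, not Airflow code:

   ```python
   # With dag_id and run_id already in the WHERE clause, the filters alone
   # pin the result to one run's task instances; no join to dag_run needed.
   from sqlalchemy import Column, Integer, String, create_engine, select
   from sqlalchemy.orm import Session, declarative_base

   Base = declarative_base()

   class DagRun(Base):
       __tablename__ = "dag_run"
       id = Column(Integer, primary_key=True)
       dag_id = Column(String)
       run_id = Column(String)

   class TI(Base):
       __tablename__ = "task_instance"
       id = Column(Integer, primary_key=True)
       dag_id = Column(String)
       run_id = Column(String)
       task_id = Column(String)

   engine = create_engine("sqlite://")
   Base.metadata.create_all(engine)
   with Session(engine) as session:
       session.add_all([
           DagRun(dag_id="d", run_id="r"),
           TI(dag_id="d", run_id="r", task_id="g.t1"),
           TI(dag_id="d", run_id="r", task_id="g.t2"),
           TI(dag_id="d", run_id="other", task_id="g.t1"),
       ])
       session.commit()
       # Join-free query: dag_id + run_id + task_id filters do all the work.
       stmt = select(TI).where(TI.dag_id == "d", TI.run_id == "r", TI.task_id.in_(["g.t1", "g.t2"]))
       rows = session.scalars(stmt).all()
       print(len(rows))  # 2
   ```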



##
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##
 @@ -56,6 +56,17 @@
 log = structlog.get_logger(__name__)
 
 
+def _collect_unique_tis(
+    affected_tis_dict: dict[tuple[str, str, str, int], TI],
+    tis: list[TI] | None,
+) -> None:
+    """Collect unique task instances into a dictionary keyed by (dag_id, run_id, task_id, map_index)."""
+    if tis:
+        for ti in tis:
+            key = (ti.dag_id, ti.run_id, ti.task_id, ti.map_index)
+            affected_tis_dict[key] = ti

Review Comment:
   Do we really need this? `set_task_group_state` already returns unique TIs (the underlying `set_state` dedupes), and the `note` branch only touches TIs already in the group. Feels defensive for no real case.
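
   For reference, the dedup behavior in isolation, with plain dicts standing in for `TaskInstance` objects:

   ```python
   # If set_task_group_state already returns unique TIs, keying by
   # (dag_id, run_id, task_id, map_index) is a no-op: re-adding the same
   # instances leaves the dict unchanged.

   def collect_unique(affected, tis):
       for ti in tis or []:
           key = (ti["dag_id"], ti["run_id"], ti["task_id"], ti["map_index"])
           affected[key] = ti  # later duplicates overwrite; count stays the same

   affected = {}
   unique_tis = [{"dag_id": "d", "run_id": "r", "task_id": f"t{i}", "map_index": -1} for i in range(3)]
   collect_unique(affected, unique_tis)
   collect_unique(affected, unique_tis)  # same TIs again: still 3 entries
   print(len(affected))  # 3
   ```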



##
airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py:
##
 @@ -6477,3 +6478,383 @@ def test_should_respond_403(self, unauthorized_test_client):
     def test_should_respond_422(self, test_client):
         response = test_client.patch(self.ENDPOINT_URL, json={})
         assert response.status_code == 422
+
+
+class TestPatchTaskGroup(TestTaskInstanceEndpoint):

Review Comment:
   Most tests in this class mock `set_task_group_state`, which means we're only validating the route/service glue. The real v2.x-equivalent logic we just added in `SerializedDAG.set_task_group_state` isn't exercised. Can we add unmocked tests for:
   - `include_upstream=True` / `include_downstream=True` actually affecting up/downstream TIs
   - `past=True` / `future=True` covering the right runs (and the `exclude_run_ids` logic)
   - The `clear()` of failed downstream actually resuming them
   - The 409 path (all TIs already in target state)
   
   The existing `test_patch_task_group_updates_ti_states_in_db` and the query-count tests are exactly the right pattern — just more of those.



##
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##
@@ -145,6 +214,53 @@ def _patch_task_instance_state(
         except Exception:
             log.exception("error calling listener")
 
+    return updated_tis
+
+
+def _patch_task_group_state(
+    group_id: str,
+    dag_run_id: str,
+    dag: SerializedDAG,
+    body: PatchTaskInstanceBody,
+    data: dict,
+    *,
+  

Re: [PR] feat: implement patching of task group instances in API [airflow]

2026-04-20 Thread via GitHub


pierrejeambrun commented on PR #62812:
URL: https://github.com/apache/airflow/pull/62812#issuecomment-4281454132

   ### Concerns
   
   1. **Most tests mock `set_task_group_state`** — `test_patch_task_group_success`, `_failed_state`, `_nested`, `_includes_upstream_downstream_parameters`, `_dry_run_returns_affected_tis_without_committing` all mock the core logic. This means the tests verify the route/service glue but NOT that the v2.x-equivalent logic in `SerializedDAG.set_task_group_state` is correct. The query-count tests (`test_query_count_does_not_scale_with_task_group_size`, `test_dry_run_query_count_does_not_scale`) and `test_patch_task_group_updates_ti_states_in_db` don't mock and are the real regression guards — keep those, but add more unmocked tests for:
  - `include_upstream=True` / `include_downstream=True` — verify actual upstream/downstream TIs get their state set
  - `past=True` / `future=True` — verify runs before/after the target are affected AND the `exclude_run_ids` logic works
  - Empty or no-op group (409 path)
  - The `clear()` of failed downstream tasks actually resumes them
   
   2. **`_collect_unique_tis` may be unnecessary** — the key `(dag_id, run_id, task_id, map_index)` isn't necessary if `set_task_group_state` already returns unique TIs (the underlying `set_state` dedupes). Only the `"note"` branch adds the tis separately, but those are all within the group too — so the dedup helper is defensive rather than necessary. Could be simplified.
   
   3. **`_patch_ti_group_validate_request` is near-duplicate of `_patch_ti_validate_request`** — differs only in replacing `task_id`/`map_index` with `task_group_id` and calling `_get_task_group_task_instances`. Consider parameterizing or extracting the common validation logic.
   
   4. **`_get_task_group_task_instances` uses `.join(TI.dag_run)` with no DagRun filter** — Copilot flagged this too. Since `TI.dag_id == dag_id, TI.run_id == run_id` already constrain the result, the join is redundant. Remove or justify with a comment (e.g., "inner join enforces the DagRun exists"). Also `joinedload(TI.rendered_task_instance_fields)` is applied here AND again in the route (`patch_task_group_instances` does a re-query with `populate_existing=True`). Looks like duplicated eager-load work — verify the first one is actually used, otherwise drop it.
   
   5. **Conflict on "already in target state"** (`_patch_task_group_state`)
  ```python
  if not updated_tis:
      raise HTTPException(status.HTTP_409_CONFLICT, ...)
  ```
  The single-TI endpoint returns an empty list instead of 409 (I argued against 404 earlier). Inconsistent behavior between "single TI already in state" (empty 200) and "all group TIs already in state" (409). Align them: return empty list 200.
   
   6. **`dag.set_task_group_state` in the service is called with `commit=True`, but the route function also calls it via `_patch_task_group_state` which also defaults `commit=True`**. Not a bug, but if someone ever calls `_patch_task_group_state` expecting `commit=False`, there's no way to override — the helper hardcodes `commit=True`. Consider threading `commit` as a parameter for future reuse.
   
   7. **`_patch_ti_group_validate_request` uses `Query(None)`** for `update_mask` — Copilot flagged this. `Query` is a FastAPI dependency marker; using it in a service helper is misleading (the param isn't actually populated from the request here since this is called internally). Use `update_mask: list[str] | None = None`.
   
   8. **OpenAPI spec regenerated but not verified** — `v2-rest-api-generated.yaml` has 162 additions. Review the spec diff for correct request/response schemas, especially the 409 response added per point 5 above.





Re: [PR] feat: implement patching of task group instances in API [airflow]

2026-04-19 Thread via GitHub


OscarLigthart commented on PR #62812:
URL: https://github.com/apache/airflow/pull/62812#issuecomment-4275693100

   Hi @pierrejeambrun, I had to make some tweaks to get `set_task_group_state` working as it did in Airflow 2.x. Let me know if this is going in the right direction.
   
   As an alternative, the second-to-last commit has the original implementation with your initial feedback. That solution would cause a steeper query-count growth with scale, but doesn't add the logic to the DAG serialization just yet.
   
   Let me know how you would like to proceed!





Re: [PR] feat: implement patching of task group instances in API [airflow]

2026-04-16 Thread via GitHub


OscarLigthart commented on PR #62812:
URL: https://github.com/apache/airflow/pull/62812#issuecomment-4260172373

   > @OscarLigthart I replied to your main concern there [#62812 
(comment)](https://github.com/apache/airflow/pull/62812#discussion_r3063329169) 
just in case you missed it
   
   Hey @pierrejeambrun, sorry for the radio silence on my side, had a few crazy days. I will find some time over the weekend to look through your suggestions and apply the necessary changes 👍





Re: [PR] feat: implement patching of task group instances in API [airflow]

2026-04-15 Thread via GitHub


pierrejeambrun commented on PR #62812:
URL: https://github.com/apache/airflow/pull/62812#issuecomment-4250536838

   @OscarLigthart I replied to your main concern there 
https://github.com/apache/airflow/pull/62812#discussion_r3063329169 just in 
case you missed it 





Re: [PR] feat: implement patching of task group instances in API [airflow]

2026-04-10 Thread via GitHub


Copilot commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r3066481777


##
airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py:
##
@@ -83,8 +83,12 @@
 from airflow.api_fastapi.core_api.security import GetUserDep, 
ReadableTIFilterDep, requires_access_dag
 from airflow.api_fastapi.core_api.services.public.task_instances import (
 BulkTaskInstanceService,
+_collect_unique_tis,
+_get_task_group_task_instances,
+_patch_task_group_state,
 _patch_task_instance_note,
 _patch_task_instance_state,
+_patch_ti_group_validate_request,
 _patch_ti_validate_request,
 )

Review Comment:
   This route module imports multiple underscore-prefixed helpers from the service module. Since these are now used cross-module, they're effectively part of the internal API surface and the underscore convention becomes misleading. Consider promoting these to public helpers (drop the `_` prefix), or encapsulate them behind a single public service function (e.g., `TaskGroupInstanceService.patch(...)`) to keep routing logic thinner and reduce coupling.



##
airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py:
##
 @@ -864,6 +868,133 @@ def _collect_relatives(run_id: str, direction: Literal["upstream", "downstream"]
     )
 
 
+@task_instances_router.patch(
+    "/dagRuns/{dag_run_id}/taskGroupInstances/{group_id}",
+    responses=create_openapi_http_exception_doc(
+        [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST, status.HTTP_409_CONFLICT],
+    ),
+    dependencies=[
+        Depends(action_logging()),
+        Depends(requires_access_dag(method="PUT", access_entity=DagAccessEntity.TASK_INSTANCE)),
+    ],
+    operation_id="patch_task_group_instances",
+)
+def patch_task_group_instances(
+    dag_id: str,
+    dag_run_id: str,
+    group_id: str,
+    dag_bag: DagBagDep,
+    body: PatchTaskInstanceBody,
+    session: SessionDep,
+    user: GetUserDep,
+    update_mask: list[str] | None = Query(None),
+) -> TaskInstanceCollectionResponse:
+    """Update the state of all task instances in a task group."""
+    dag, tis, data = _patch_ti_group_validate_request(
+        dag_id, dag_run_id, group_id, dag_bag, body, session, update_mask
+    )
+    affected_tis_dict: dict[tuple[str, str, str, int], TI] = {}
+
+    for key, _ in data.items():
+        if key == "new_state":
+            updated_tis = _patch_task_group_state(
+                tis=tis,
+                dag_run_id=dag_run_id,
+                dag=dag,
+                body=body,
+                data=data,
+                session=session,
+            )

Review Comment:
   `data` can include `"new_state": None` if the client explicitly sends 
`{"new_state": null}` (or uses `update_mask=["new_state"]` with a null value). 
In that case this route will still enter the `"new_state"` branch and 
`_patch_task_group_state()` will attempt to set `state=None`, which will fail 
downstream. Add explicit validation that if `"new_state"` is present in 
`fields_to_update` / `update_mask` then `body.new_state` must be non-null 
(raise a 422 `RequestValidationError` / `HTTPException`). Alternatively, 
exclude `"new_state"` from `data` when its value is `None` so the route becomes 
a no-op consistently (like the dry-run path which checks `if body.new_state:`).
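
   A sketch of the first option, explicit validation. Function name and error type are illustrative; the real route would raise a 422 `RequestValidationError`:

   ```python
   # Reject an explicit null new_state before it can reach the state-setting
   # branch; absent keys and non-null values pass through untouched.

   def validate_new_state(data: dict) -> dict:
       if "new_state" in data and data["new_state"] is None:
           raise ValueError("new_state may not be null when included in the patch")
       return data


   print(validate_new_state({"new_state": "success"}))  # {'new_state': 'success'}
   try:
       validate_new_state({"new_state": None})
   except ValueError as err:
       print(err)
   ```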



##
airflow-core/src/airflow/serialization/definitions/dag.py:
##
 @@ -761,6 +761,100 @@ def set_task_instance_state(
 
         subset.clear(exclude_run_ids=frozenset(session.scalars(exclude_run_id_stmt)), **clear_kwargs)
         return altered
 
+    @provide_session
+    def set_multiple_task_instances_state(
+        self,
+        *,
+        task_ids_with_map_indexes: list[tuple[str, int]],
+        run_id: str | None = None,
+        state: TaskInstanceState,

Review Comment:
   `run_id` is annotated as optional, but this method unconditionally does a 
`select(DagRun.id, DagRun.logical_date).where(DagRun.run_id == run_id, ...)` 
followed by `.one()`, which will raise if `run_id` is `None` or not found. 
Since this is effectively required for correctness, make `run_id` a required 
`str` parameter (remove the default) or add an explicit guard that raises a 
clear `ValueError`/`AirflowException` when `run_id` is not provided.
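
   A sketch of the suggested guard; the function name, return shape, and message are illustrative stand-ins for the real lookup:

   ```python
   # Fail fast with a clear message instead of letting the later .one()
   # lookup raise an opaque error when run_id is missing.

   def resolve_dag_run(run_id):
       if run_id is None:
           raise ValueError("run_id is required for set_multiple_task_instances_state")
       return {"run_id": run_id}  # stand-in for the DagRun row lookup


   print(resolve_dag_run("manual__2026-04-10")["run_id"])  # manual__2026-04-10
   ```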



##
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##
@@ -101,14 +112,72 @@ def _patch_ti_validate_request(
     return dag, list(tis), body.model_dump(include=fields_to_update, by_alias=True)
 
 
+def _get_task_group_task_instances(
+    dag_id: str,
+    dag_run_id: str,
+    task_group_id: str,
+    dag: SerializedDAG,
+    session: Session,
+) -> list[TI]:
+    """Get all task instances in a task group for a specific DAG run."""
+    task_group = dag.task_group_dict.get(task_group_id)
+    if not task_group:
+        raise HTTPException(
+            statu

Re: [PR] feat: implement patching of task group instances in API [airflow]

2026-04-10 Thread via GitHub


pierrejeambrun commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r3063329169


##
airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py:
##
 @@ -5919,3 +5919,350 @@ def test_should_respond_403(self, unauthorized_test_client):
     def test_should_respond_422(self, test_client):
         response = test_client.patch(self.ENDPOINT_URL, json={})
         assert response.status_code == 422
+
+
+class TestPatchTaskGroup(TestTaskInstanceEndpoint):
+    DAG_ID = "example_task_group"
+    RUN_ID = "TEST_DAG_RUN_ID"
+    GROUP_ID = "section_1"
+    BASE_URL = f"/dags/{DAG_ID}/dagRuns/{RUN_ID}/taskGroupInstances"
+    ENDPOINT_URL = f"{BASE_URL}/{GROUP_ID}"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_success(self, mock_set_ti_state, test_client, session):
+        """Test that patching a task group sets state for all tasks in the group."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", "section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        response_data = response.json()
+        assert response_data["total_entries"] == mock_set_ti_state.call_count
+        assert mock_set_ti_state.call_count == 3
+        called_task_ids = sorted(call.kwargs["task_id"] for call in mock_set_ti_state.call_args_list)
+        assert called_task_ids == ["section_1.task_1", "section_1.task_2", "section_1.task_3"]
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "success"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_failed_state(self, mock_set_ti_state, test_client, session):
+        """Test that patching a task group with failed state works."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", "section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "failed"},
+        )
+        assert response.status_code == 200
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "failed"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_nested(self, mock_set_ti_state, test_client, session):
+        """Test that patching a nested task group includes tasks from inner groups."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        # section_2 contains task_1, and inner_section_2 which contains task_2, task_3, task_4
+        url = f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/section_2"
+        response = test_client.patch(
+            url,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        assert mock_set_ti_state.call_count == 4
+        called_task_ids = sorted(call.kwargs["task_id"] for call in mock_set_ti_state.call_args_list)
+        assert called_task_ids == [
+            "section_2.inner_section_2.task_2",
+            "section_2.inner_section_2.task_3",
+            "section_2.inner_section_2.task_4",
+            "section_2.task_1",
+        ]
+
+    def test_patch_task_group_not_found(self, test_client, session):
+        """Test that requesting a non-existent task group returns 404."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        url = f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/nonexistent_group"
+        respo


Re: [PR] feat: implement patching of task group instances in API [airflow]

2026-04-10 Thread via GitHub


pierrejeambrun commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r3063193582


##
airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py:
##
@@ -5919,3 +5919,350 @@ def test_should_respond_403(self, unauthorized_test_client):
     def test_should_respond_422(self, test_client):
         response = test_client.patch(self.ENDPOINT_URL, json={})
         assert response.status_code == 422
+
+
+class TestPatchTaskGroup(TestTaskInstanceEndpoint):
+    DAG_ID = "example_task_group"
+    RUN_ID = "TEST_DAG_RUN_ID"
+    GROUP_ID = "section_1"
+    BASE_URL = f"/dags/{DAG_ID}/dagRuns/{RUN_ID}/taskGroupInstances"
+    ENDPOINT_URL = f"{BASE_URL}/{GROUP_ID}"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_success(self, mock_set_ti_state, test_client, session):
+        """Test that patching a task group sets state for all tasks in the group."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", "section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        response_data = response.json()
+        assert response_data["total_entries"] == mock_set_ti_state.call_count
+        assert mock_set_ti_state.call_count == 3
+        called_task_ids = sorted(call.kwargs["task_id"] for call in mock_set_ti_state.call_args_list)
+        assert called_task_ids == ["section_1.task_1", "section_1.task_2", "section_1.task_3"]
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "success"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_failed_state(self, mock_set_ti_state, test_client, session):
+        """Test that patching a task group with failed state works."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", "section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "failed"},
+        )
+        assert response.status_code == 200
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "failed"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_nested(self, mock_set_ti_state, test_client, session):
+        """Test that patching a nested task group includes tasks from inner groups."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        # section_2 contains task_1, and inner_section_2 which contains task_2, task_3, task_4
+        url = f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/section_2"
+        response = test_client.patch(
+            url,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        assert mock_set_ti_state.call_count == 4
+        called_task_ids = sorted(call.kwargs["task_id"] for call in mock_set_ti_state.call_args_list)
+        assert called_task_ids == [
+            "section_2.inner_section_2.task_2",
+            "section_2.inner_section_2.task_3",
+            "section_2.inner_section_2.task_4",
+            "section_2.task_1",
+        ]
+
+    def test_patch_task_group_not_found(self, test_client, session):
+        """Test that requesting a non-existent task group returns 404."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        url = f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/nonexistent_group"
+        respo
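The `test_patch_task_group_nested` expectation above (four task ids for `section_2`) follows from task-group iteration recursing into nested groups. A minimal toy sketch of that behavior, not Airflow's actual `TaskGroup` implementation (the class, `add`, and `iter_tasks` names here are illustrative stand-ins):

```python
class TaskGroup:
    """Toy stand-in for a task group; holds plain task-id strings or nested groups."""

    def __init__(self, group_id):
        self.group_id = group_id
        self.children = []  # mix of task-id strings and nested TaskGroup objects

    def add(self, child):
        self.children.append(child)
        return child

    def iter_tasks(self):
        # Depth-first walk: nested groups contribute their tasks too,
        # which is why patching an outer group touches inner-group tasks.
        for child in self.children:
            if isinstance(child, TaskGroup):
                yield from child.iter_tasks()
            else:
                yield child


section_2 = TaskGroup("section_2")
section_2.add("section_2.task_1")
inner = section_2.add(TaskGroup("section_2.inner_section_2"))
for name in ("task_2", "task_3", "task_4"):
    inner.add(f"section_2.inner_section_2.{name}")

collected = sorted(section_2.iter_tasks())
print(collected)
# ['section_2.inner_section_2.task_2', 'section_2.inner_section_2.task_3',
#  'section_2.inner_section_2.task_4', 'section_2.task_1']
```

Sorting the flattened ids reproduces exactly the four-element list the nested test asserts against.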

Re: [PR] feat: implement patching of task group instances in API [airflow]

2026-04-03 Thread via GitHub


OscarLigthart commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r3034320062


##
airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py:
##
@@ -5919,3 +5919,350 @@ def test_should_respond_403(self, unauthorized_test_client):
     def test_should_respond_422(self, test_client):
         response = test_client.patch(self.ENDPOINT_URL, json={})
         assert response.status_code == 422
+
+
+class TestPatchTaskGroup(TestTaskInstanceEndpoint):
+    DAG_ID = "example_task_group"
+    RUN_ID = "TEST_DAG_RUN_ID"
+    GROUP_ID = "section_1"
+    BASE_URL = f"/dags/{DAG_ID}/dagRuns/{RUN_ID}/taskGroupInstances"
+    ENDPOINT_URL = f"{BASE_URL}/{GROUP_ID}"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_success(self, mock_set_ti_state, test_client, session):
+        """Test that patching a task group sets state for all tasks in the group."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", "section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        response_data = response.json()
+        assert response_data["total_entries"] == mock_set_ti_state.call_count
+        assert mock_set_ti_state.call_count == 3
+        called_task_ids = sorted(call.kwargs["task_id"] for call in mock_set_ti_state.call_args_list)
+        assert called_task_ids == ["section_1.task_1", "section_1.task_2", "section_1.task_3"]
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "success"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_failed_state(self, mock_set_ti_state, test_client, session):
+        """Test that patching a task group with failed state works."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", "section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "failed"},
+        )
+        assert response.status_code == 200
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "failed"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_nested(self, mock_set_ti_state, test_client, session):
+        """Test that patching a nested task group includes tasks from inner groups."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        # section_2 contains task_1, and inner_section_2 which contains task_2, task_3, task_4
+        url = f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/section_2"
+        response = test_client.patch(
+            url,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        assert mock_set_ti_state.call_count == 4
+        called_task_ids = sorted(call.kwargs["task_id"] for call in mock_set_ti_state.call_args_list)
+        assert called_task_ids == [
+            "section_2.inner_section_2.task_2",
+            "section_2.inner_section_2.task_3",
+            "section_2.inner_section_2.task_4",
+            "section_2.task_1",
+        ]
+
+    def test_patch_task_group_not_found(self, test_client, session):
+        """Test that requesting a non-existent task group returns 404."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        url = f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/nonexistent_group"
+        respon
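These tests all rely on the same `unittest.mock` pattern: a mocked `set_task_instance_state` whose `side_effect` dispatches on the `task_id` keyword, so each call returns the matching pre-built task instance. A self-contained sketch of that pattern, with a hypothetical `FakeTI` stand-in for Airflow's `TaskInstance`:

```python
from unittest import mock


class FakeTI:
    """Hypothetical stand-in for a task instance; only task_id matters here."""

    def __init__(self, task_id):
        self.task_id = task_id


tis = [FakeTI("g.task_1"), FakeTI("g.task_2")]
ti_map = {ti.task_id: ti for ti in tis}

# side_effect receives the call's arguments, so the return value can
# depend on which task_id the code under test asked for.
set_ti_state = mock.Mock(side_effect=lambda task_id, **kwargs: [ti_map[task_id]])

result = set_ti_state(task_id="g.task_2", state="success")
print(result[0].task_id)  # g.task_2

# call_args_list records every invocation, enabling assertions like those
# in the tests above (which task ids were touched, with which state).
called = sorted(c.kwargs["task_id"] for c in set_ti_state.call_args_list)
```

Using a per-task `side_effect` rather than a fixed `return_value` lets one mock serve all tasks in a group while the test still checks exactly which task ids were passed.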

Re: [PR] feat: implement patching of task group instances in API [airflow]

2026-04-03 Thread via GitHub


OscarLigthart commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r3034320062


##
airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py:
##
@@ -5919,3 +5919,350 @@ def test_should_respond_403(self, 
unauthorized_test_client):
 def test_should_respond_422(self, test_client):
 response = test_client.patch(self.ENDPOINT_URL, json={})
 assert response.status_code == 422
+
+
+class TestPatchTaskGroup(TestTaskInstanceEndpoint):
+DAG_ID = "example_task_group"
+RUN_ID = "TEST_DAG_RUN_ID"
+GROUP_ID = "section_1"
+BASE_URL = f"/dags/{DAG_ID}/dagRuns/{RUN_ID}/taskGroupInstances"
+ENDPOINT_URL = f"{BASE_URL}/{GROUP_ID}"
+
+
@mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+def test_patch_task_group_success(self, mock_set_ti_state, test_client, 
session):
+"""Test that patching a task group sets state for all tasks in the 
group."""
+self.create_task_instances(session, dag_id=self.DAG_ID)
+
+tis = session.scalars(
+select(TaskInstance).where(
+TaskInstance.dag_id == self.DAG_ID,
+TaskInstance.run_id == self.RUN_ID,
+TaskInstance.task_id.in_(["section_1.task_1", 
"section_1.task_2", "section_1.task_3"]),
+)
+).all()
+
+ti_map = {ti.task_id: ti for ti in tis}
+mock_set_ti_state.side_effect = lambda task_id, **kwargs: 
[ti_map[task_id]]
+
+response = test_client.patch(
+self.ENDPOINT_URL,
+json={"new_state": "success"},
+)
+assert response.status_code == 200
+response_data = response.json()
+assert response_data["total_entries"] == mock_set_ti_state.call_count
+assert mock_set_ti_state.call_count == 3
+called_task_ids = sorted(call.kwargs["task_id"] for call in 
mock_set_ti_state.call_args_list)
+assert called_task_ids == ["section_1.task_1", "section_1.task_2", 
"section_1.task_3"]
+for call in mock_set_ti_state.call_args_list:
+assert call.kwargs["state"] == "success"
+
+
@mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+def test_patch_task_group_failed_state(self, mock_set_ti_state, 
test_client, session):
+"""Test that patching a task group with failed state works."""
+self.create_task_instances(session, dag_id=self.DAG_ID)
+
+tis = session.scalars(
+select(TaskInstance).where(
+TaskInstance.dag_id == self.DAG_ID,
+TaskInstance.run_id == self.RUN_ID,
+TaskInstance.task_id.in_(["section_1.task_1", 
"section_1.task_2", "section_1.task_3"]),
+)
+).all()
+
+ti_map = {ti.task_id: ti for ti in tis}
+mock_set_ti_state.side_effect = lambda task_id, **kwargs: 
[ti_map[task_id]]
+
+response = test_client.patch(
+self.ENDPOINT_URL,
+json={"new_state": "failed"},
+)
+assert response.status_code == 200
+for call in mock_set_ti_state.call_args_list:
+assert call.kwargs["state"] == "failed"
+
+
@mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+def test_patch_task_group_nested(self, mock_set_ti_state, test_client, 
session):
+"""Test that patching a nested task group includes tasks from inner 
groups."""
+self.create_task_instances(session, dag_id=self.DAG_ID)
+
+tis = session.scalars(
+select(TaskInstance).where(
+TaskInstance.dag_id == self.DAG_ID,
+TaskInstance.run_id == self.RUN_ID,
+)
+).all()
+
+ti_map = {ti.task_id: ti for ti in tis}
+mock_set_ti_state.side_effect = lambda task_id, **kwargs: 
[ti_map[task_id]]
+
+# section_2 contains task_1, and inner_section_2 which contains 
task_2, task_3, task_4
+url = 
f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/section_2"
+response = test_client.patch(
+url,
+json={"new_state": "success"},
+)
+assert response.status_code == 200
+assert mock_set_ti_state.call_count == 4
+called_task_ids = sorted(call.kwargs["task_id"] for call in 
mock_set_ti_state.call_args_list)
+assert called_task_ids == [
+"section_2.inner_section_2.task_2",
+"section_2.inner_section_2.task_3",
+"section_2.inner_section_2.task_4",
+"section_2.task_1",
+]
+
+def test_patch_task_group_not_found(self, test_client, session):
+"""Test that requesting a non-existent task group returns 404."""
+self.create_task_instances(session, dag_id=self.DAG_ID)
+
+url = 
f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/nonexistent_group"
+respon

Re: [PR] feat: implement patching of task group instances in API [airflow]

2026-04-03 Thread via GitHub


OscarLigthart commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r3034320062


##
airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py:
##
@@ -5919,3 +5919,350 @@ def test_should_respond_403(self, 
unauthorized_test_client):
 def test_should_respond_422(self, test_client):
 response = test_client.patch(self.ENDPOINT_URL, json={})
 assert response.status_code == 422
+
+
+class TestPatchTaskGroup(TestTaskInstanceEndpoint):
+    DAG_ID = "example_task_group"
+    RUN_ID = "TEST_DAG_RUN_ID"
+    GROUP_ID = "section_1"
+    BASE_URL = f"/dags/{DAG_ID}/dagRuns/{RUN_ID}/taskGroupInstances"
+    ENDPOINT_URL = f"{BASE_URL}/{GROUP_ID}"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_success(self, mock_set_ti_state, test_client, session):
+        """Test that patching a task group sets state for all tasks in the group."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", "section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        response_data = response.json()
+        assert response_data["total_entries"] == mock_set_ti_state.call_count
+        assert mock_set_ti_state.call_count == 3
+        called_task_ids = sorted(call.kwargs["task_id"] for call in mock_set_ti_state.call_args_list)
+        assert called_task_ids == ["section_1.task_1", "section_1.task_2", "section_1.task_3"]
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "success"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_failed_state(self, mock_set_ti_state, test_client, session):
+        """Test that patching a task group with failed state works."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", "section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "failed"},
+        )
+        assert response.status_code == 200
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "failed"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_nested(self, mock_set_ti_state, test_client, session):
+        """Test that patching a nested task group includes tasks from inner groups."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        # section_2 contains task_1, and inner_section_2 which contains task_2, task_3, task_4
+        url = f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/section_2"
+        response = test_client.patch(
+            url,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        assert mock_set_ti_state.call_count == 4
+        called_task_ids = sorted(call.kwargs["task_id"] for call in mock_set_ti_state.call_args_list)
+        assert called_task_ids == [
+            "section_2.inner_section_2.task_2",
+            "section_2.inner_section_2.task_3",
+            "section_2.inner_section_2.task_4",
+            "section_2.task_1",
+        ]
+
+    def test_patch_task_group_not_found(self, test_client, session):
+        """Test that requesting a non-existent task group returns 404."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        url = f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/nonexistent_group"
+        respon

Re: [PR] feat: implement patching of task group instances in API [airflow]

2026-04-03 Thread via GitHub


OscarLigthart commented on PR #62812:
URL: https://github.com/apache/airflow/pull/62812#issuecomment-4185146488

   Thanks for reopening and reviewing @pierrejeambrun ! I've addressed most of 
your comments. I got stuck again on the N+1 query problem after removing the 
mock. I've consulted with Claude and left a trace of the resulting conversation 
(can't believe I'm already calling this "conversations" with my agent, what a 
time to be alive).
   
   Let me know how you would like to proceed!


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]



Re: [PR] feat: implement patching of task group instances in API [airflow]

2026-04-02 Thread via GitHub


pierrejeambrun commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r3027386658


##
airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py:
##
@@ -5919,3 +5919,350 @@ def test_should_respond_403(self, unauthorized_test_client):
     def test_should_respond_422(self, test_client):
         response = test_client.patch(self.ENDPOINT_URL, json={})
         assert response.status_code == 422
+
+
+class TestPatchTaskGroup(TestTaskInstanceEndpoint):
+    DAG_ID = "example_task_group"
+    RUN_ID = "TEST_DAG_RUN_ID"
+    GROUP_ID = "section_1"
+    BASE_URL = f"/dags/{DAG_ID}/dagRuns/{RUN_ID}/taskGroupInstances"
+    ENDPOINT_URL = f"{BASE_URL}/{GROUP_ID}"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_success(self, mock_set_ti_state, test_client, session):
+        """Test that patching a task group sets state for all tasks in the group."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", "section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        response_data = response.json()
+        assert response_data["total_entries"] == mock_set_ti_state.call_count
+        assert mock_set_ti_state.call_count == 3
+        called_task_ids = sorted(call.kwargs["task_id"] for call in mock_set_ti_state.call_args_list)
+        assert called_task_ids == ["section_1.task_1", "section_1.task_2", "section_1.task_3"]
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "success"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_failed_state(self, mock_set_ti_state, test_client, session):
+        """Test that patching a task group with failed state works."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", "section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "failed"},
+        )
+        assert response.status_code == 200
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "failed"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_nested(self, mock_set_ti_state, test_client, session):
+        """Test that patching a nested task group includes tasks from inner groups."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        # section_2 contains task_1, and inner_section_2 which contains task_2, task_3, task_4
+        url = f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/section_2"
+        response = test_client.patch(
+            url,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        assert mock_set_ti_state.call_count == 4
+        called_task_ids = sorted(call.kwargs["task_id"] for call in mock_set_ti_state.call_args_list)
+        assert called_task_ids == [
+            "section_2.inner_section_2.task_2",
+            "section_2.inner_section_2.task_3",
+            "section_2.inner_section_2.task_4",
+            "section_2.task_1",
+        ]
+
+    def test_patch_task_group_not_found(self, test_client, session):
+        """Test that requesting a non-existent task group returns 404."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        url = f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/nonexistent_group"
+        respo

[PR] feat: implement patching of task group instances in API [airflow]

2026-04-02 Thread via GitHub


OscarLigthart opened a new pull request, #62812:
URL: https://github.com/apache/airflow/pull/62812


   
   
   
   ## Context
   In Airflow 3, the ability to mark a full task group as failed or success is 
currently missing. In this PR, I implement the logic in the API so that it can 
then be called from the frontend to bring the functionality back.
   
   I deliberately split the feature into two separate PRs, so the review can 
focus on the individual components being touched. Should this PR get merged, I 
will continue building the remaining pieces into the frontend.
   
   There is an already open PR here: #60161
   
   However, it looks to be stale, and I would love to get this feature into the 
Airflow 3.2 release.
   
   ## Implementation
   
   I make use of the BulkTaskInstanceService to perform the state updates with 
this endpoint. Using it, the implementation should be pretty straightforward :).
   
   ## Issues
   
   related: #56103 
   
   ---
   
   # Was generative AI tooling used to co-author this PR?
   
   
   
   - [x] Yes (please specify the tool below)
   
   
   Generated-by: Claude Opus following [the 
guidelines](https://github.com/apache/airflow/blob/main/contributing-docs/05_pull_requests.rst#gen-ai-assisted-contributions)
   
   
   ---
   
   * Read the **[Pull Request 
Guidelines](https://github.com/apache/airflow/blob/main/contributing-docs/05_pull_requests.rst#pull-request-guidelines)**
 for more information. Note: commit author/co-author name and email in commits 
become permanently public when merged.
   * For fundamental code changes, an Airflow Improvement Proposal 
([AIP](https://cwiki.apache.org/confluence/display/AIRFLOW/Airflow+Improvement+Proposals))
 is needed.
   * When adding dependency, check compliance with the [ASF 3rd Party License 
Policy](https://www.apache.org/legal/resolved.html#category-x).
   * For significant user-facing changes create newsfragment: 
`{pr_number}.significant.rst` or `{issue_number}.significant.rst`, in 
[airflow-core/newsfragments](https://github.com/apache/airflow/tree/main/airflow-core/newsfragments).
   





Re: [PR] feat: implement patching of task group instances in API [airflow]

2026-04-02 Thread via GitHub


pierrejeambrun commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r3027370557


##
airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py:
##
@@ -5919,3 +5919,350 @@ def test_should_respond_403(self, unauthorized_test_client):
     def test_should_respond_422(self, test_client):
         response = test_client.patch(self.ENDPOINT_URL, json={})
         assert response.status_code == 422
+
+
+class TestPatchTaskGroup(TestTaskInstanceEndpoint):
+    DAG_ID = "example_task_group"
+    RUN_ID = "TEST_DAG_RUN_ID"
+    GROUP_ID = "section_1"
+    BASE_URL = f"/dags/{DAG_ID}/dagRuns/{RUN_ID}/taskGroupInstances"
+    ENDPOINT_URL = f"{BASE_URL}/{GROUP_ID}"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_success(self, mock_set_ti_state, test_client, session):
+        """Test that patching a task group sets state for all tasks in the group."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", "section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        response_data = response.json()
+        assert response_data["total_entries"] == mock_set_ti_state.call_count
+        assert mock_set_ti_state.call_count == 3
+        called_task_ids = sorted(call.kwargs["task_id"] for call in mock_set_ti_state.call_args_list)
+        assert called_task_ids == ["section_1.task_1", "section_1.task_2", "section_1.task_3"]
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "success"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_failed_state(self, mock_set_ti_state, test_client, session):
+        """Test that patching a task group with failed state works."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", "section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "failed"},
+        )
+        assert response.status_code == 200
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "failed"

Review Comment:
   Same here assert the API response.



##
airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py:
##
@@ -5919,3 +5919,350 @@ def test_should_respond_403(self, unauthorized_test_client):
     def test_should_respond_422(self, test_client):
         response = test_client.patch(self.ENDPOINT_URL, json={})
         assert response.status_code == 422
+
+
+class TestPatchTaskGroup(TestTaskInstanceEndpoint):
+    DAG_ID = "example_task_group"
+    RUN_ID = "TEST_DAG_RUN_ID"
+    GROUP_ID = "section_1"
+    BASE_URL = f"/dags/{DAG_ID}/dagRuns/{RUN_ID}/taskGroupInstances"
+    ENDPOINT_URL = f"{BASE_URL}/{GROUP_ID}"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_success(self, mock_set_ti_state, test_client, session):
+        """Test that patching a task group sets state for all tasks in the group."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", "section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        response_data = response.json()
+        assert response_data["total_entries"] == mock_s

Re: [PR] feat: implement patching of task group instances in API [airflow]

2026-04-01 Thread via GitHub


potiuk commented on PR #62812:
URL: https://github.com/apache/airflow/pull/62812#issuecomment-4169717356

   This pull request has been converted to draft due to quality issues more 
than a week ago and there has been no response from the author since then. We 
are closing it for now to keep the review queue manageable.
   
   **@OscarLigthart**, you are welcome to reopen this PR after addressing the 
review comments. Thank you for your contribution!
   
   





Re: [PR] feat: implement patching of task group instances in API [airflow]

2026-04-01 Thread via GitHub


potiuk closed pull request #62812: feat: implement patching of task group 
instances in API
URL: https://github.com/apache/airflow/pull/62812





Re: [PR] feat: implement patching of task group instances in API [airflow]

2026-03-24 Thread via GitHub


potiuk commented on PR #62812:
URL: https://github.com/apache/airflow/pull/62812#issuecomment-4120213722

   @OscarLigthart This PR has been converted to **draft** because it does not 
yet meet our [Pull Request quality 
criteria](https://github.com/apache/airflow/blob/main/contributing-docs/05_pull_requests.rst#pull-request-quality-criteria).
   
   **Issues found:**
   - :x: **Merge conflicts**: This PR has merge conflicts with the `main` 
branch. Your branch is 237 commits behind `main`. Please rebase your branch 
(`git fetch origin && git rebase origin/main`), resolve the conflicts, and push 
again. See [contributing quick 
start](https://github.com/apache/airflow/blob/main/contributing-docs/03a_contributors_quick_start_beginners.rst).
   
   > **Note:** Your branch is **237 commits behind `main`**. Some check 
failures may be caused by changes in the base branch rather than by your PR. 
Please rebase your branch and push again to get up-to-date CI results.
   
   **What to do next:**
   - The comment informs you what you need to do.
   - Fix each issue, then mark the PR as "Ready for review" in the GitHub UI - 
but only after making sure that all the issues are fixed.
   - There is no rush — take your time and work at your own pace. We appreciate 
your contribution and are happy to wait for updates.
   - Maintainers will then proceed with a normal review.
   
   Converting a PR to draft is **not** a rejection — it is an invitation to 
bring the PR up to the project's standards so that maintainer review time is 
spent productively. There is no rush — take your time and work at your own 
pace. We appreciate your contribution and are happy to wait for updates. If you 
have questions, feel free to ask on the [Airflow 
Slack](https://s.apache.org/airflow-slack).





Re: [PR] feat: implement patching of task group instances in API [airflow]

2026-03-11 Thread via GitHub


potiuk commented on PR #62812:
URL: https://github.com/apache/airflow/pull/62812#issuecomment-4038905256

   @OscarLigthart This PR has been converted to **draft** because it does not 
yet meet our [Pull Request quality 
criteria](https://github.com/apache/airflow/blob/main/contributing-docs/05_pull_requests.rst#pull-request-quality-criteria).
   
   **Issues found:**
   - :x: **Merge conflicts**: This PR has merge conflicts with the `main` 
branch. Your branch is 64 commits behind `main`. Please rebase your branch 
(`git fetch origin && git rebase origin/main`), resolve the conflicts, and push 
again. See [contributing quick 
start](https://github.com/apache/airflow/blob/main/contributing-docs/03a_contributors_quick_start_beginners.rst).
   - :warning: **Unresolved review comments**: This PR has 2 unresolved review 
threads from maintainers. Please review and resolve all inline review comments 
before requesting another review. You can resolve a conversation by clicking 
'Resolve conversation' on each thread after addressing the feedback. See [pull 
request 
guidelines](https://github.com/apache/airflow/blob/main/contributing-docs/05_pull_requests.rst).
   
   > **Note:** Your branch is **64 commits behind `main`**. Some check failures 
may be caused by changes in the base branch rather than by your PR. Please 
rebase your branch and push again to get up-to-date CI results.
   
   **What to do next:**
   - The comment informs you what you need to do.
   - Fix each issue, then mark the PR as "Ready for review" in the GitHub UI - 
but only after making sure that all the issues are fixed.
   - Maintainers will then proceed with a normal review.
   
   Converting a PR to draft is **not** a rejection — it is an invitation to 
bring the PR up to the project's standards so that maintainer review time is 
spent productively. If you have questions, feel free to ask on the [Airflow 
Slack](https://s.apache.org/airflow-slack).





Re: [PR] feat: implement patching of task group instances in API [airflow]

2026-03-09 Thread via GitHub


OscarLigthart commented on PR #62812:
URL: https://github.com/apache/airflow/pull/62812#issuecomment-4027342056

   > Thanks for the PR.
   > 
   > Can you add query guards to tests (`assert_queries_count`). I'm afraid 
this will generate an N+1 queries problem -> the number of DB requests will scale 
linearly with the number of TIs in the group. And we should rework the code not 
to do that.
   > 
   > Also I find the code overly complicated, with numerous duplicated calls made 
at different abstraction levels because of function nesting; it makes the whole 
thing hard to understand and probably suboptimal.
   
   Thanks for the detailed review! Your points make a lot of sense; apologies 
for the oversight on my side. I followed the implementation of the other PR more 
closely, removing the BulkTaskInstanceService usage and solving the N+1 query 
problem.
   
   I still opted to keep them in separate endpoints for clarity. Let me know if 
you disagree! Also happy to cast this aside and wait for the other PR to be 
picked up again, if that is more in line with your preference.
   
   One question regarding `assert_queries_count`: I've created multiple cases 
where the query count remains the same regardless of group size, effectively 
eliminating the N+1 problem. But I'm guessing many things can affect this count, 
which would in turn break the tests implemented here. Is there a specific rule 
to keep in mind when applying it? Or is just setting it in the way I do in the 
tests correct?





Re: [PR] feat: implement patching of task group instances in API [airflow]

2026-03-09 Thread via GitHub


OscarLigthart commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r2908258611


##
airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_instances.py:
##
@@ -215,6 +215,29 @@ def validate_new_state(cls, ns: str | None) -> str:
         return ns
 
 
+class PatchTaskGroupBody(StrictBaseModel):
+    """Request body for patching the state of all task instances in a task group."""
+
+    new_state: TaskInstanceState
+    include_future: bool = False
+    include_past: bool = False

Review Comment:
   Added back! I thought we wouldn't need them for the UI implementation, but 
better to keep them for the endpoint.






Re: [PR] feat: implement patching of task group instances in API [airflow]

2026-03-09 Thread via GitHub


OscarLigthart commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r2908255963


##
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##
@@ -153,6 +158,99 @@ def _patch_task_instance_note(
     ti.task_instance_note.user_id = user.get_id()
 
 
+def _get_task_group_task_ids(
+    dag: SerializedDAG,
+    group_id: str,
+) -> list[str]:
+    """
+    Get task ids that belong to a task group.
+
+    :param dag: The serialized DAG containing the task group.
+    :param group_id: The ID of the task group.
+    :return: List of task IDs in the group.
+    :raises HTTPException: If the task group is not found or has no tasks.
+    """
+    if not hasattr(dag, "task_group"):
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND,
+            f"DAG '{dag.dag_id}' does not have task groups",
+        )
+
+    task_groups = dag.task_group.get_task_group_dict()
+    task_group = task_groups.get(group_id)
+    if not task_group:
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND,
+            f"Task group '{group_id}' not found in DAG '{dag.dag_id}'",
+        )
+
+    task_ids = [task.task_id for task in task_group.iter_tasks()]
+    if not task_ids:
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND,
+            f"Task group '{group_id}' in DAG '{dag.dag_id}' has no tasks",
+        )
+
+    return task_ids
+
+
+def _patch_task_group_state(
+    *,
+    dag_id: str,
+    dag_run_id: str,
+    group_id: str,
+    body: PatchTaskGroupBody,
+    dag_bag: DagBagDep,
+    user: GetUserDep,
+    session: Session,
+) -> None:
+    """
+    Set the state of all task instances in a task group.
+
+    Uses BulkTaskInstanceService to update each task in the group.
+
+    :param dag_id: The DAG ID.
+    :param dag_run_id: The run_id of the DAG run.
+    :param group_id: The ID of the task group.
+    :param body: The request body with the new state and options.
+    :param dag_bag: The DAG bag for DAG resolution.
+    :param user: The authenticated user.
+    :param session: The database session.
+    """
+    dag = get_latest_version_of_dag(dag_bag, dag_id, session)
+    task_ids = _get_task_group_task_ids(dag, group_id)
+
+    entities = [
+        BulkTaskInstanceBody(
+            task_id=task_id,
+            dag_id=dag_id,
+            dag_run_id=dag_run_id,
+            new_state=body.new_state,
+            include_future=body.include_future,
+            include_past=body.include_past,
+        )
+        for task_id in task_ids
+    ]
+
+    action = BulkUpdateAction(
+        action=BulkAction.UPDATE,
+        entities=entities,
+        update_mask=["new_state"],
+        action_on_non_existence=BulkActionNotOnExistence.SKIP,
+    )
+    results = BulkActionResponse()
+
+    service = BulkTaskInstanceService(
+        session=session,
+        request=BulkBody(actions=[]),  # unused, but required by base class
+        dag_id=dag_id,
+        dag_run_id=dag_run_id,
+        dag_bag=dag_bag,
+        user=user,
+    )
+    service.handle_bulk_update(action, results)

Review Comment:
   I took inspiration from the other PR and followed that implementation more 
closely. 👍






Re: [PR] feat: implement patching of task group instances in API [airflow]

2026-03-09 Thread via GitHub


pierrejeambrun commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r2907088878


##
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##
@@ -153,6 +158,99 @@ def _patch_task_instance_note(
     ti.task_instance_note.user_id = user.get_id()
 
 
+def _get_task_group_task_ids(
+    dag: SerializedDAG,
+    group_id: str,
+) -> list[str]:
+    """
+    Get task ids that belong to a task group.
+
+    :param dag: The serialized DAG containing the task group.
+    :param group_id: The ID of the task group.
+    :return: List of task IDs in the group.
+    :raises HTTPException: If the task group is not found or has no tasks.
+    """
+    if not hasattr(dag, "task_group"):
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND,
+            f"DAG '{dag.dag_id}' does not have task groups",
+        )
+
+    task_groups = dag.task_group.get_task_group_dict()
+    task_group = task_groups.get(group_id)
+    if not task_group:
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND,
+            f"Task group '{group_id}' not found in DAG '{dag.dag_id}'",
+        )
+
+    task_ids = [task.task_id for task in task_group.iter_tasks()]
+    if not task_ids:
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND,
+            f"Task group '{group_id}' in DAG '{dag.dag_id}' has no tasks",
+        )
+
+    return task_ids
+
+
+def _patch_task_group_state(
+    *,
+    dag_id: str,
+    dag_run_id: str,
+    group_id: str,
+    body: PatchTaskGroupBody,
+    dag_bag: DagBagDep,
+    user: GetUserDep,
+    session: Session,
+) -> None:
+    """
+    Set the state of all task instances in a task group.
+
+    Uses BulkTaskInstanceService to update each task in the group.
+
+    :param dag_id: The DAG ID.
+    :param dag_run_id: The run_id of the DAG run.
+    :param group_id: The ID of the task group.
+    :param body: The request body with the new state and options.
+    :param dag_bag: The DAG bag for DAG resolution.
+    :param user: The authenticated user.
+    :param session: The database session.
+    """
+    dag = get_latest_version_of_dag(dag_bag, dag_id, session)
+    task_ids = _get_task_group_task_ids(dag, group_id)
+
+    entities = [
+        BulkTaskInstanceBody(
+            task_id=task_id,
+            dag_id=dag_id,
+            dag_run_id=dag_run_id,
+            new_state=body.new_state,
+            include_future=body.include_future,
+            include_past=body.include_past,
+        )
+        for task_id in task_ids
+    ]
+
+    action = BulkUpdateAction(
+        action=BulkAction.UPDATE,
+        entities=entities,
+        update_mask=["new_state"],
+        action_on_non_existence=BulkActionNotOnExistence.SKIP,
+    )
+    results = BulkActionResponse()
+
+    service = BulkTaskInstanceService(
+        session=session,
+        request=BulkBody(actions=[]),  # unused, but required by base class
+        dag_id=dag_id,
+        dag_run_id=dag_run_id,
+        dag_bag=dag_bag,
+        user=user,
+    )
+    service.handle_bulk_update(action, results)

Review Comment:
   Maybe not going through the bulk update service is actually better; the 
interface wasn't meant for this, and it makes you do weird stuff to actually 
plug into it.
   
   The code for updating a single TI is probably more reusable and better 
fitted to this use case.






Re: [PR] feat: implement patching of task group instances in API [airflow]

2026-03-09 Thread via GitHub


pierrejeambrun commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r2907045411


##
airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py:
##
@@ -847,6 +850,103 @@ def _collect_relatives(run_id: str, direction: Literal["upstream", "downstream"]
 )
 
 
+@task_instances_router.patch(
+    "/dagRuns/{dag_run_id}/taskGroupInstances/{group_id}",
+    responses=create_openapi_http_exception_doc(
+        [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST, status.HTTP_409_CONFLICT],
+    ),
+    dependencies=[
+        Depends(action_logging()),
+        Depends(requires_access_dag(method="PUT", access_entity=DagAccessEntity.TASK_INSTANCE)),
+    ],
+    operation_id="patch_task_group_instances",
+)
+def patch_task_group_instances(
+    dag_id: str,
+    dag_run_id: str,
+    group_id: str,
+    dag_bag: DagBagDep,
+    body: PatchTaskGroupBody,
+    session: SessionDep,
+    user: GetUserDep,
+) -> TaskInstanceCollectionResponse:
+    """Update the state of all task instances in a task group."""
+    _patch_task_group_state(
+        dag_id=dag_id,
+        dag_run_id=dag_run_id,
+        group_id=group_id,
+        body=body,
+        dag_bag=dag_bag,
+        user=user,
+        session=session,
+    )
+
+    # Collect all TIs for the task group to build the response
+    dag = get_latest_version_of_dag(dag_bag, dag_id, session)
+    task_ids = _get_task_group_task_ids(dag, group_id)
+    tis = (
+        session.scalars(
+            select(TI)
+            .where(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id.in_(task_ids))
+            .join(TI.dag_run)
+            .options(joinedload(TI.rendered_task_instance_fields))
+            .options(joinedload(TI.dag_version))
+            .options(joinedload(TI.dag_run).options(joinedload(DagRun.dag_model)))
+        )
+        .unique()
+        .all()

Review Comment:
   This query could probably be avoided. We are fetching from the DB in 
multiple places.
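
   One way to avoid the duplicate fetch, sketched here with hypothetical 
stand-ins (`fetch_tis` and `set_state` are not real Airflow helpers), is to have 
the update helper return the task instances it already loaded, so the route 
handler can serialize them directly instead of issuing a second SELECT:

```python
# Hypothetical refactor sketch: the update helper returns the task instances
# it already fetched, so the caller builds the response without re-querying.
def patch_task_group_state(task_ids, fetch_tis, set_state, new_state):
    tis = fetch_tis(task_ids)  # one batched SELECT ... WHERE task_id IN (...)
    for ti in tis:
        set_state(ti, new_state)
    return tis  # reused for the response; no second fetch


select_count = 0


def fetch_tis(task_ids):
    # Stand-in for the session query; counts how many SELECTs are issued.
    global select_count
    select_count += 1
    return [{"task_id": t, "state": None} for t in task_ids]


def set_state(ti, state):
    ti["state"] = state


tis = patch_task_group_state(["g.t1", "g.t2", "g.t3"], fetch_tis, set_state, "success")
# One SELECT total, and the same objects serve as the response payload.
```

   The design point is that fetching and updating share one result set, so the 
number of queries stays flat no matter how large the group is.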






Re: [PR] feat: implement patching of task group instances in API [airflow]

2026-03-09 Thread via GitHub


pierrejeambrun commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r2906986909


##
airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_instances.py:
##
@@ -215,6 +215,29 @@ def validate_new_state(cls, ns: str | None) -> str:
 return ns
 
 
+class PatchTaskGroupBody(StrictBaseModel):
+"""Request body for patching the state of all task instances in a task group."""
+
+new_state: TaskInstanceState
+include_future: bool = False
+include_past: bool = False
+
+@field_validator("new_state", mode="before")
+@classmethod
+def validate_new_state(cls, ns: str | None) -> str:
+"""Validate new_state."""
+valid_states = [
+vs.name.lower()
+for vs in (TaskInstanceState.SUCCESS, TaskInstanceState.FAILED, TaskInstanceState.SKIPPED)
+]
+if ns is None:
+raise ValueError("'new_state' should not be empty")
+ns = ns.lower()
+if ns not in valid_states:
+raise ValueError(f"'{ns}' is not one of {valid_states}")
+return ns
+

Review Comment:
   This is a complete duplicate of the existing `validate_new_state`. Make a common base body class for `PatchTaskGroupBody` and `PatchTaskInstanceBody`; do the same for the other shared attributes.



##
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##
@@ -153,6 +158,99 @@ def _patch_task_instance_note(
 ti.task_instance_note.user_id = user.get_id()
 
 
+def _get_task_group_task_ids(
+dag: SerializedDAG,
+group_id: str,
+) -> list[str]:
+"""
+Get task ids that belong to a task group.
+
+:param dag: The serialized DAG containing the task group.
+:param group_id: The ID of the task group.
+:return: List of task IDs in the group.
+:raises HTTPException: If the task group is not found or has no tasks.
+"""
+if not hasattr(dag, "task_group"):
+raise HTTPException(
+status.HTTP_404_NOT_FOUND,
+f"DAG '{dag.dag_id}' does not have task groups",
+)
+
+task_groups = dag.task_group.get_task_group_dict()
+task_group = task_groups.get(group_id)
+if not task_group:
+raise HTTPException(
+status.HTTP_404_NOT_FOUND,
+f"Task group '{group_id}' not found in DAG '{dag.dag_id}'",
+)
+
+task_ids = [task.task_id for task in task_group.iter_tasks()]
+if not task_ids:
+raise HTTPException(
+status.HTTP_404_NOT_FOUND,
+f"Task group '{group_id}' in DAG '{dag.dag_id}' has no tasks",
+)

Review Comment:
   Nit: not sure we should 404 here. If there is nothing to patch, just return `[]` and proceed.
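A sketch of the suggested behavior (plain dicts stand in for the real `TaskGroup` objects; names are illustrative): keep the error only for a group that does not exist, and let an empty group fall through as `[]`:

```python
def get_task_group_task_ids(task_groups: dict[str, list[str]], group_id: str, dag_id: str) -> list[str]:
    """Resolve a task group to its task ids.

    An unknown group is still an error (the resource does not exist), but an
    empty group simply yields [] so the caller patches nothing and returns an
    empty collection instead of raising.
    """
    group = task_groups.get(group_id)
    if group is None:
        # The group itself is missing -> still worth a 404 in the route.
        raise LookupError(f"Task group '{group_id}' not found in DAG '{dag_id}'")
    return list(group)  # may be empty; not an error
```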



##
airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py:
##
@@ -847,6 +850,103 @@ def _collect_relatives(run_id: str, direction: Literal["upstream", "downstream"]
 )
 
 
+@task_instances_router.patch(
+"/dagRuns/{dag_run_id}/taskGroupInstances/{group_id}",
+responses=create_openapi_http_exception_doc(
+[status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST, status.HTTP_409_CONFLICT],
+),
+dependencies=[
+Depends(action_logging()),
+Depends(requires_access_dag(method="PUT", access_entity=DagAccessEntity.TASK_INSTANCE)),
+],
+operation_id="patch_task_group_instances",
+)
+def patch_task_group_instances(
+dag_id: str,
+dag_run_id: str,
+group_id: str,
+dag_bag: DagBagDep,
+body: PatchTaskGroupBody,
+session: SessionDep,
+user: GetUserDep,
+) -> TaskInstanceCollectionResponse:
+"""Update the state of all task instances in a task group."""
+_patch_task_group_state(
+dag_id=dag_id,
+dag_run_id=dag_run_id,
+group_id=group_id,
+body=body,
+dag_bag=dag_bag,
+user=user,
+session=session,
+)
+
+# Collect all TIs for the task group to build the response
+dag = get_latest_version_of_dag(dag_bag, dag_id, session)
+task_ids = _get_task_group_task_ids(dag, group_id)

Review Comment:
   `_get_task_group_task_ids` is called multiple times, including inside `_patch_task_group_state`.
   
   Same for `get_latest_version_of_dag`.
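The de-duplication could look roughly like this (hypothetical sketch; injected callables stand in for `get_latest_version_of_dag` and `_get_task_group_task_ids`): resolve once at the top of the route, then pass the resolved values to both the patch step and the response builder.

```python
def patch_task_group_instances(dag_id, dag_run_id, group_id,
                               fetch_dag, fetch_task_ids,
                               patch_states, fetch_tis):
    """Resolve the DAG and the group's task ids exactly once.

    The resolved values are handed to both the patch step and the response
    builder, so neither has to re-derive them (no duplicate lookups).
    """
    dag = fetch_dag(dag_id)                # resolved once
    task_ids = fetch_task_ids(dag, group_id)  # resolved once
    patch_states(dag, dag_run_id, task_ids)
    return fetch_tis(dag_id, dag_run_id, task_ids)
```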



##
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##
@@ -268,34 +366,39 @@ def _perform_update(
 results: BulkActionResponse,
 update_mask: list[str] | None = Query(None),
 ) -> None:
-dag, tis, data = _patch_ti_validate_request(
-dag_id=dag_id,
-dag_run_id=dag_run_id,
-task_id=task_id,
-dag_bag=self.dag_bag,
-body=entity,
-session=self.session,
-update_mask=update_mask,
-)
-
-for key, _ in data.items():
-if key == "new_state":
-_patch_task_instance_state(
-task_id

[PR] feat: implement patching of task group instances in API [airflow]

2026-03-03 Thread via GitHub


OscarLigthart opened a new pull request, #62812:
URL: https://github.com/apache/airflow/pull/62812


   
   
   
   ## Context
   In Airflow 3 the ability to mark a full task group as failed or success is currently missing. In this PR I implement the logic in the API, which can then be called from the frontend to restore the functionality.
   
   I deliberately split the feature into two separate PRs, so the review process can focus on the individually touched components. Should this PR get merged, I will continue to build the remaining requirements into the frontend.
   
   There is already an open PR here: #60161
   
   However, it appears to be stale, and I would love to get this feature into the Airflow 3.2 release.
   
   ## Implementation
   
   I make use of the BulkTaskInstanceService to perform the state updates in this endpoint, which keeps the implementation pretty straightforward :).
   
   ## Issues
   
   related: #56103 
   
   ---
   
   # Was generative AI tooling used to co-author this PR?
   
   
   
   - [x] Yes (please specify the tool below)
   
   
   Generated-by: Claude Opus following [the guidelines](https://github.com/apache/airflow/blob/main/contributing-docs/05_pull_requests.rst#gen-ai-assisted-contributions)
   
   
   ---
   
   * Read the **[Pull Request Guidelines](https://github.com/apache/airflow/blob/main/contributing-docs/05_pull_requests.rst#pull-request-guidelines)** for more information. Note: commit author/co-author name and email in commits become permanently public when merged.
   * For fundamental code changes, an Airflow Improvement Proposal ([AIP](https://cwiki.apache.org/confluence/display/AIRFLOW/Airflow+Improvement+Proposals)) is needed.
   * When adding a dependency, check compliance with the [ASF 3rd Party License Policy](https://www.apache.org/legal/resolved.html#category-x).
   * For significant user-facing changes create a newsfragment: `{pr_number}.significant.rst` or `{issue_number}.significant.rst`, in [airflow-core/newsfragments](https://github.com/apache/airflow/tree/main/airflow-core/newsfragments).
   

