Re: [PR] feat: implement patching of task group instances in API [airflow]
OscarLigthart commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r3116700513
##
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##
@@ -101,14 +112,72 @@ def _patch_ti_validate_request(
     return dag, list(tis), body.model_dump(include=fields_to_update, by_alias=True)


+def _get_task_group_task_instances(
+    dag_id: str,
+    dag_run_id: str,
+    task_group_id: str,
+    dag: SerializedDAG,
+    session: Session,
+) -> list[TI]:
+    """Get all task instances in a task group for a specific DAG run."""
+    task_group = dag.task_group_dict.get(task_group_id)
+    if not task_group:
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND, f"Task group '{task_group_id}' not found in DAG '{dag_id}'"
+        )
+
+    task_ids = [task.task_id for task in task_group.iter_tasks()]
+
+    query = (
+        select(TI)
+        .where(
+            TI.dag_id == dag_id,
+            TI.run_id == dag_run_id,
+            TI.task_id.in_(task_ids),
+        )
+        .join(TI.dag_run)
+        .options(joinedload(TI.rendered_task_instance_fields))
+        .order_by(TI.task_id, TI.map_index)
+    )
+
+    group_tis = list(session.scalars(query).all())
+
+    return group_tis
+
+
+def _patch_ti_group_validate_request(
+    dag_id: str,
+    dag_run_id: str,
+    task_group_id: str,
+    dag_bag: DagBagDep,
+    body: PatchTaskInstanceBody,
+    session: SessionDep,
+    update_mask: list[str] | None = Query(None),
+) -> tuple[SerializedDAG, list[TI], dict]:
+    """Validate and prepare data for task group patch request."""
+    dag = get_latest_version_of_dag(dag_bag, dag_id, session)
+    tis = _get_task_group_task_instances(dag_id, dag_run_id, task_group_id, dag, session)
+
+    fields_to_update = body.model_fields_set
+    if update_mask:
+        fields_to_update = fields_to_update.intersection(update_mask)
+    else:
+        try:
+            PatchTaskInstanceBody.model_validate(body)
+        except ValidationError as e:
+            raise RequestValidationError(errors=e.errors())
+
+    return dag, tis, body.model_dump(include=fields_to_update, by_alias=True)
Review Comment:
I took it out of both to prevent duplicate logic, but I'm not entirely sure if that's what you meant. Let me know 🙂
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
Re: [PR] feat: implement patching of task group instances in API [airflow]
OscarLigthart commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r3116571531
##
airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml:
##
@@ -7348,6 +7348,166 @@ paths:
             application/json:
               schema:
                 $ref: '#/components/schemas/HTTPValidationError'
+  /api/v2/dags/{dag_id}/dagRuns/{dag_run_id}/taskGroupInstances/{group_id}:
+    patch:
+      tags:
+        - Task Instance
+      summary: Patch Task Group Instances
+      description: Update the state of all task instances in a task group.
+      operationId: patch_task_group_instances
Review Comment:
Should be good, the 409 is only registered for the non-dry-run path, since no changes are committed in a dry run.
Re: [PR] feat: implement patching of task group instances in API [airflow]
pierrejeambrun commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r3111265585
##
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##
@@ -101,14 +112,72 @@ def _patch_ti_validate_request(
     return dag, list(tis), body.model_dump(include=fields_to_update, by_alias=True)


+def _get_task_group_task_instances(
+    dag_id: str,
+    dag_run_id: str,
+    task_group_id: str,
+    dag: SerializedDAG,
+    session: Session,
+) -> list[TI]:
+    """Get all task instances in a task group for a specific DAG run."""
+    task_group = dag.task_group_dict.get(task_group_id)
+    if not task_group:
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND, f"Task group '{task_group_id}' not found in DAG '{dag_id}'"
+        )
+
+    task_ids = [task.task_id for task in task_group.iter_tasks()]
+
+    query = (
+        select(TI)
+        .where(
+            TI.dag_id == dag_id,
+            TI.run_id == dag_run_id,
+            TI.task_id.in_(task_ids),
+        )
+        .join(TI.dag_run)
+        .options(joinedload(TI.rendered_task_instance_fields))
+        .order_by(TI.task_id, TI.map_index)
+    )
+
+    group_tis = list(session.scalars(query).all())
+
+    return group_tis
+
+
+def _patch_ti_group_validate_request(
+    dag_id: str,
+    dag_run_id: str,
+    task_group_id: str,
+    dag_bag: DagBagDep,
+    body: PatchTaskInstanceBody,
+    session: SessionDep,
+    update_mask: list[str] | None = Query(None),
Review Comment:
```suggestion
update_mask: list[str] | None = None,
```
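The `update_mask` handling in `_patch_ti_group_validate_request` is a plain set intersection: only fields both set on the body and named in the mask survive. A minimal stand-alone sketch (plain sets stand in for Pydantic's `model_fields_set`; names are illustrative):

```python
# Sketch of the update_mask semantics: intersect the fields the client set
# with the mask, or keep everything when no mask is given.
fields_set_on_body = {"new_state", "note"}


def fields_to_update(model_fields_set, update_mask):
    # Mirrors the helper's branch: mask present -> intersection, else all set fields.
    if update_mask:
        return model_fields_set.intersection(update_mask)
    return model_fields_set


print(sorted(fields_to_update(fields_set_on_body, ["new_state"])))  # → ['new_state']
print(sorted(fields_to_update(fields_set_on_body, None)))  # → ['new_state', 'note']
```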
Re: [PR] feat: implement patching of task group instances in API [airflow]
pierrejeambrun commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r3111265596
##
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##
@@ -145,6 +214,53 @@ def _patch_task_instance_state(
         except Exception:
             log.exception("error calling listener")

+    return updated_tis
+
+
+def _patch_task_group_state(
+    group_id: str,
+    dag_run_id: str,
+    dag: SerializedDAG,
+    body: PatchTaskInstanceBody,
+    data: dict,
+    *,
+    session: Session,
+) -> list[TI]:
+    """Update the state of all task instances in a task group."""
+    updated_tis = dag.set_task_group_state(
+        group_id=group_id,
+        run_id=dag_run_id,
+        state=data["new_state"],
+        upstream=body.include_upstream,
+        downstream=body.include_downstream,
+        future=body.include_future,
+        past=body.include_past,
+        commit=True,
Review Comment:
Nit: `commit=True` is hardcoded here but the helper could easily take it as
a param. Not critical, but would make this reusable for a dry-run path later
without duplicating the function.
Re: [PR] feat: implement patching of task group instances in API [airflow]
pierrejeambrun commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r3111265581
##
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##
@@ -101,14 +112,72 @@ def _patch_ti_validate_request(
     return dag, list(tis), body.model_dump(include=fields_to_update, by_alias=True)


+def _get_task_group_task_instances(
+    dag_id: str,
+    dag_run_id: str,
+    task_group_id: str,
+    dag: SerializedDAG,
+    session: Session,
+) -> list[TI]:
+    """Get all task instances in a task group for a specific DAG run."""
+    task_group = dag.task_group_dict.get(task_group_id)
+    if not task_group:
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND, f"Task group '{task_group_id}' not found in DAG '{dag_id}'"
+        )
+
+    task_ids = [task.task_id for task in task_group.iter_tasks()]
+
+    query = (
+        select(TI)
+        .where(
+            TI.dag_id == dag_id,
+            TI.run_id == dag_run_id,
+            TI.task_id.in_(task_ids),
+        )
+        .join(TI.dag_run)
+        .options(joinedload(TI.rendered_task_instance_fields))
Review Comment:
`joinedload(TI.rendered_task_instance_fields)` is applied here AND re-queried with `populate_existing=True` in the route after `set_task_group_state` returns. One of them is wasted - probably this one, since the route always re-queries. Can we drop it here?
##
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##
@@ -101,14 +112,72 @@ def _patch_ti_validate_request(
return dag, list(tis), body.model_dump(include=fields_to_update,
by_alias=True)
+def _get_task_group_task_instances(
+dag_id: str,
+dag_run_id: str,
+task_group_id: str,
+dag: SerializedDAG,
+session: Session,
+) -> list[TI]:
+"""Get all task instances in a task group for a specific DAG run."""
+task_group = dag.task_group_dict.get(task_group_id)
+if not task_group:
+raise HTTPException(
+status.HTTP_404_NOT_FOUND, f"Task group '{task_group_id}' not
found in DAG '{dag_id}'"
+)
+
+task_ids = [task.task_id for task in task_group.iter_tasks()]
+
+query = (
+select(TI)
+.where(
+TI.dag_id == dag_id,
+TI.run_id == dag_run_id,
+TI.task_id.in_(task_ids),
+)
+.join(TI.dag_run)
Review Comment:
Is this join needed? The `dag_id` + `run_id` filters already constrain the result. I see `_patch_ti_validate_request` does the same thing - pre-existing, but worth cleaning up in both places while we're here.
##
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##
@@ -56,6 +56,17 @@
 log = structlog.get_logger(__name__)


+def _collect_unique_tis(
+    affected_tis_dict: dict[tuple[str, str, str, int], TI],
+    tis: list[TI] | None,
+) -> None:
+    """Collect unique task instances into a dictionary keyed by (dag_id, run_id, task_id, map_index)."""
+    if tis:
+        for ti in tis:
+            key = (ti.dag_id, ti.run_id, ti.task_id, ti.map_index)
+            affected_tis_dict[key] = ti
Review Comment:
Do we really need this? `set_task_group_state` already returns unique TIs
(the underlying `set_state` dedupes), and the `note` branch only touches TIs
already in the group. Feels defensive for no real case.
##
airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py:
##
@@ -6477,3 +6478,383 @@ def test_should_respond_403(self, unauthorized_test_client):
     def test_should_respond_422(self, test_client):
         response = test_client.patch(self.ENDPOINT_URL, json={})
         assert response.status_code == 422
+
+
+class TestPatchTaskGroup(TestTaskInstanceEndpoint):
Review Comment:
Most tests in this class mock `set_task_group_state`, which means we're only
validating the route/service glue. The real v2.x-equivalent logic we just added
in `SerializedDAG.set_task_group_state` isn't exercised. Can we add unmocked
tests for:
- `include_upstream=True` / `include_downstream=True` actually affecting
up/downstream TIs
- `past=True` / `future=True` covering the right runs (and the
`exclude_run_ids` logic)
- The `clear()` of failed downstream actually resuming them
- The 409 path (all TIs already in target state)
The existing `test_patch_task_group_updates_ti_states_in_db` and the query-count tests are exactly the right pattern - just more of those.
##
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##
@@ -145,6 +214,53 @@ def _patch_task_instance_state(
         except Exception:
             log.exception("error calling listener")

+    return updated_tis
+
+
+def _patch_task_group_state(
+    group_id: str,
+    dag_run_id: str,
+    dag: SerializedDAG,
+    body: PatchTaskInstanceBody,
+    data: dict,
+    *,
+
Re: [PR] feat: implement patching of task group instances in API [airflow]
pierrejeambrun commented on PR #62812:
URL: https://github.com/apache/airflow/pull/62812#issuecomment-4281454132

### Concerns

1. **Most tests mock `set_task_group_state`** - `test_patch_task_group_success`, `_failed_state`, `_nested`, `_includes_upstream_downstream_parameters`, and `_dry_run_returns_affected_tis_without_committing` all mock the core logic. The tests therefore verify the route/service glue but NOT that the v2.x-equivalent logic in `SerializedDAG.set_task_group_state` is correct. The query-count tests (`test_query_count_does_not_scale_with_task_group_size`, `test_dry_run_query_count_does_not_scale`) and `test_patch_task_group_updates_ti_states_in_db` don't mock and are the real regression guards - keep those, but add more unmocked tests for:
   - `include_upstream=True` / `include_downstream=True` - verify actual upstream/downstream TIs get their state set
   - `past=True` / `future=True` - verify runs before/after the target are affected AND the `exclude_run_ids` logic works
   - Empty or no-op group (the 409 path)
   - The `clear()` of failed downstream tasks actually resuming them
2. **`_collect_unique_tis` may be unnecessary** - keying by `(dag_id, run_id, task_id, map_index)` isn't needed if `set_task_group_state` already returns unique TIs (the underlying `set_state` dedupes). Only the `"note"` branch adds TIs separately, and those are all within the group too, so the dedup helper is defensive rather than necessary. Could be simplified.
3. **`_patch_ti_group_validate_request` is a near-duplicate of `_patch_ti_validate_request`** - it differs only in replacing `task_id`/`map_index` with `task_group_id` and calling `_get_task_group_task_instances`. Consider parameterizing or extracting the common validation logic.
4. **`_get_task_group_task_instances` uses `.join(TI.dag_run)` with no DagRun filter** - Copilot flagged this too. Since `TI.dag_id == dag_id, TI.run_id == run_id` already constrain the result, the join is redundant.
Remove it or justify it with a comment (e.g., "inner join enforces the DagRun exists"). Also, `joinedload(TI.rendered_task_instance_fields)` is applied here AND again in the route (`patch_task_group_instances` does a re-query with `populate_existing=True`). That looks like duplicated eager-load work - verify the first one is actually used, otherwise drop it.
5. **Conflict on "already in target state"** (`_patch_task_group_state`):
   ```python
   if not updated_tis:
       raise HTTPException(status.HTTP_409_CONFLICT, ...)
   ```
   The single-TI endpoint returns an empty list instead of 409 (I argued against 404 earlier). That makes the behavior inconsistent between "single TI already in state" (empty 200) and "all group TIs already in state" (409). Align them: return an empty list with 200.
6. **`commit=True` is hardcoded** - the service calls `dag.set_task_group_state` with `commit=True`, and the route reaches it via `_patch_task_group_state`, which hardcodes the same value. Not a bug, but if someone ever calls `_patch_task_group_state` expecting `commit=False`, there's no way to override. Consider threading `commit` through as a parameter for future reuse.
7. **`_patch_ti_group_validate_request` uses `Query(None)` for `update_mask`** - Copilot flagged this too. `Query` is a FastAPI dependency marker; using it in a service helper is misleading (the param isn't actually populated from the request here, since the helper is called internally). Use `update_mask: list[str] | None = None`.
8. **OpenAPI spec regenerated but not verified** - `v2-rest-api-generated.yaml` has 162 additions. Review the spec diff for correct request/response schemas, especially the 409 response discussed in point 5 above.
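Point 5 is a one-line behavioral change. A hypothetical sketch of the aligned behavior (names are illustrative, not the actual route code): instead of raising 409 when every TI in the group is already in the target state, return an empty collection so the group endpoint matches the single-TI endpoint's empty-200 response.

```python
# Sketch: treat "nothing to update" as an empty 200 rather than a 409.
def patch_group_state(updated_tis):
    # Previously: if not updated_tis: raise HTTPException(409, ...)
    # Aligned behavior: an empty result is a valid, successful response.
    return {"task_instances": updated_tis, "total_entries": len(updated_tis)}


print(patch_group_state([]))  # → {'task_instances': [], 'total_entries': 0}
```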
Re: [PR] feat: implement patching of task group instances in API [airflow]
OscarLigthart commented on PR #62812:
URL: https://github.com/apache/airflow/pull/62812#issuecomment-4275693100

Hi @pierrejeambrun, I had to make some tweaks to make `set_task_group_state` work as it did in Airflow 2.x. Let me know if this is going in the right direction. As an alternative, the second-to-last commit has the original implementation with your initial feedback applied. That solution would cause a steeper query-count growth with scale, but doesn't add the logic to the DAG serialization just yet. Let me know how you would like to proceed!
Re: [PR] feat: implement patching of task group instances in API [airflow]
OscarLigthart commented on PR #62812:
URL: https://github.com/apache/airflow/pull/62812#issuecomment-4260172373

> @OscarLigthart I replied to your main concern there [#62812 (comment)](https://github.com/apache/airflow/pull/62812#discussion_r3063329169) just in case you missed it

Hey @pierrejeambrun, sorry for the radio silence on my side, I've had a few crazy days. I will find some time over the weekend to look through your suggestions and apply the necessary changes 🙂
Re: [PR] feat: implement patching of task group instances in API [airflow]
pierrejeambrun commented on PR #62812:
URL: https://github.com/apache/airflow/pull/62812#issuecomment-4250536838

@OscarLigthart I replied to your main concern there https://github.com/apache/airflow/pull/62812#discussion_r3063329169 just in case you missed it
Re: [PR] feat: implement patching of task group instances in API [airflow]
Copilot commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r3066481777
##
airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py:
##
@@ -83,8 +83,12 @@
 from airflow.api_fastapi.core_api.security import GetUserDep, ReadableTIFilterDep, requires_access_dag
 from airflow.api_fastapi.core_api.services.public.task_instances import (
     BulkTaskInstanceService,
+    _collect_unique_tis,
+    _get_task_group_task_instances,
+    _patch_task_group_state,
     _patch_task_instance_note,
     _patch_task_instance_state,
+    _patch_ti_group_validate_request,
     _patch_ti_validate_request,
 )
Review Comment:
This route module imports multiple underscore-prefixed helpers from the service module. Since these are now used cross-module, they're effectively part of the internal API surface and the underscore convention becomes misleading. Consider promoting these to public helpers (drop the `_` prefix), or encapsulate them behind a single public service function (e.g., `TaskGroupInstanceService.patch(...)`) to keep routing logic thinner and reduce coupling.
##
airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py:
##
@@ -864,6 +868,133 @@ def _collect_relatives(run_id: str, direction: Literal["upstream", "downstream"]
 )


+@task_instances_router.patch(
+    "/dagRuns/{dag_run_id}/taskGroupInstances/{group_id}",
+    responses=create_openapi_http_exception_doc(
+        [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST, status.HTTP_409_CONFLICT],
+    ),
+    dependencies=[
+        Depends(action_logging()),
+        Depends(requires_access_dag(method="PUT", access_entity=DagAccessEntity.TASK_INSTANCE)),
+    ],
+    operation_id="patch_task_group_instances",
+)
+def patch_task_group_instances(
+    dag_id: str,
+    dag_run_id: str,
+    group_id: str,
+    dag_bag: DagBagDep,
+    body: PatchTaskInstanceBody,
+    session: SessionDep,
+    user: GetUserDep,
+    update_mask: list[str] | None = Query(None),
+) -> TaskInstanceCollectionResponse:
+    """Update the state of all task instances in a task group."""
+    dag, tis, data = _patch_ti_group_validate_request(
+        dag_id, dag_run_id, group_id, dag_bag, body, session, update_mask
+    )
+    affected_tis_dict: dict[tuple[str, str, str, int], TI] = {}
+
+    for key, _ in data.items():
+        if key == "new_state":
+            updated_tis = _patch_task_group_state(
+                tis=tis,
+                dag_run_id=dag_run_id,
+                dag=dag,
+                body=body,
+                data=data,
+                session=session,
+            )
Review Comment:
`data` can include `"new_state": None` if the client explicitly sends
`{"new_state": null}` (or uses `update_mask=["new_state"]` with a null value).
In that case this route will still enter the `"new_state"` branch and
`_patch_task_group_state()` will attempt to set `state=None`, which will fail
downstream. Add explicit validation that if `"new_state"` is present in
`fields_to_update` / `update_mask` then `body.new_state` must be non-null
(raise a 422 `RequestValidationError` / `HTTPException`). Alternatively,
exclude `"new_state"` from `data` when its value is `None` so the route becomes
a no-op consistently (like the dry-run path which checks `if body.new_state:`).
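The second option from this comment (drop a null `new_state` from the dumped data so the route no-ops) can be sketched as a small filter. This is a hypothetical helper, not code from the PR; the name `sanitize_patch_data` is illustrative:

```python
# Sketch: strip "new_state" when its value is None, so the route never calls
# set_task_group_state with state=None and instead no-ops consistently.
def sanitize_patch_data(data: dict) -> dict:
    if data.get("new_state") is None:
        # Rebuild without the null key; other fields (e.g. "note") pass through.
        data = {k: v for k, v in data.items() if k != "new_state"}
    return data


print(sanitize_patch_data({"new_state": None, "note": "hi"}))  # → {'note': 'hi'}
print(sanitize_patch_data({"new_state": "success"}))  # → {'new_state': 'success'}
```

The alternative (raising a 422 when `new_state` is explicitly null) would instead validate before dumping, which surfaces the client error rather than silently ignoring it.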
##
airflow-core/src/airflow/serialization/definitions/dag.py:
##
@@ -761,6 +761,100 @@ def set_task_instance_state(
         subset.clear(exclude_run_ids=frozenset(session.scalars(exclude_run_id_stmt)), **clear_kwargs)
         return altered

+    @provide_session
+    def set_multiple_task_instances_state(
+        self,
+        *,
+        task_ids_with_map_indexes: list[tuple[str, int]],
+        run_id: str | None = None,
+        state: TaskInstanceState,
Review Comment:
`run_id` is annotated as optional, but this method unconditionally does a
`select(DagRun.id, DagRun.logical_date).where(DagRun.run_id == run_id, ...)`
followed by `.one()`, which will raise if `run_id` is `None` or not found.
Since this is effectively required for correctness, make `run_id` a required
`str` parameter (remove the default) or add an explicit guard that raises a
clear `ValueError`/`AirflowException` when `run_id` is not provided.
##
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##
@@ -101,14 +112,72 @@ def _patch_ti_validate_request(
     return dag, list(tis), body.model_dump(include=fields_to_update, by_alias=True)


+def _get_task_group_task_instances(
+    dag_id: str,
+    dag_run_id: str,
+    task_group_id: str,
+    dag: SerializedDAG,
+    session: Session,
+) -> list[TI]:
+    """Get all task instances in a task group for a specific DAG run."""
+    task_group = dag.task_group_dict.get(task_group_id)
+    if not task_group:
+        raise HTTPException(
+            statu
Re: [PR] feat: implement patching of task group instances in API [airflow]
pierrejeambrun commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r3063329169
##
airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py:
##
@@ -5919,3 +5919,350 @@ def test_should_respond_403(self, unauthorized_test_client):
     def test_should_respond_422(self, test_client):
         response = test_client.patch(self.ENDPOINT_URL, json={})
         assert response.status_code == 422
+
+
+class TestPatchTaskGroup(TestTaskInstanceEndpoint):
+    DAG_ID = "example_task_group"
+    RUN_ID = "TEST_DAG_RUN_ID"
+    GROUP_ID = "section_1"
+    BASE_URL = f"/dags/{DAG_ID}/dagRuns/{RUN_ID}/taskGroupInstances"
+    ENDPOINT_URL = f"{BASE_URL}/{GROUP_ID}"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_success(self, mock_set_ti_state, test_client, session):
+        """Test that patching a task group sets state for all tasks in the group."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", "section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        response_data = response.json()
+        assert response_data["total_entries"] == mock_set_ti_state.call_count
+        assert mock_set_ti_state.call_count == 3
+        called_task_ids = sorted(call.kwargs["task_id"] for call in mock_set_ti_state.call_args_list)
+        assert called_task_ids == ["section_1.task_1", "section_1.task_2", "section_1.task_3"]
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "success"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_failed_state(self, mock_set_ti_state, test_client, session):
+        """Test that patching a task group with failed state works."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", "section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "failed"},
+        )
+        assert response.status_code == 200
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "failed"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_nested(self, mock_set_ti_state, test_client, session):
+        """Test that patching a nested task group includes tasks from inner groups."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        # section_2 contains task_1, and inner_section_2 which contains task_2, task_3, task_4
+        url = f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/section_2"
+        response = test_client.patch(
+            url,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        assert mock_set_ti_state.call_count == 4
+        called_task_ids = sorted(call.kwargs["task_id"] for call in mock_set_ti_state.call_args_list)
+        assert called_task_ids == [
+            "section_2.inner_section_2.task_2",
+            "section_2.inner_section_2.task_3",
+            "section_2.inner_section_2.task_4",
+            "section_2.task_1",
+        ]
+
+    def test_patch_task_group_not_found(self, test_client, session):
+        """Test that requesting a non-existent task group returns 404."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        url = f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/nonexistent_group"
+        respo
Re: [PR] feat: implement patching of task group instances in API [airflow]
pierrejeambrun commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r3063193582
##
airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py:
##
@@ -5919,3 +5919,350 @@ def test_should_respond_403(self, unauthorized_test_client):
     def test_should_respond_422(self, test_client):
         response = test_client.patch(self.ENDPOINT_URL, json={})
         assert response.status_code == 422
+
+
+class TestPatchTaskGroup(TestTaskInstanceEndpoint):
+    DAG_ID = "example_task_group"
+    RUN_ID = "TEST_DAG_RUN_ID"
+    GROUP_ID = "section_1"
+    BASE_URL = f"/dags/{DAG_ID}/dagRuns/{RUN_ID}/taskGroupInstances"
+    ENDPOINT_URL = f"{BASE_URL}/{GROUP_ID}"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_success(self, mock_set_ti_state, test_client, session):
+        """Test that patching a task group sets state for all tasks in the group."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", "section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        response_data = response.json()
+        assert response_data["total_entries"] == mock_set_ti_state.call_count
+        assert mock_set_ti_state.call_count == 3
+        called_task_ids = sorted(call.kwargs["task_id"] for call in mock_set_ti_state.call_args_list)
+        assert called_task_ids == ["section_1.task_1", "section_1.task_2", "section_1.task_3"]
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "success"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_failed_state(self, mock_set_ti_state, test_client, session):
+        """Test that patching a task group with failed state works."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", "section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "failed"},
+        )
+        assert response.status_code == 200
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "failed"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_nested(self, mock_set_ti_state, test_client, session):
+        """Test that patching a nested task group includes tasks from inner groups."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        # section_2 contains task_1, and inner_section_2 which contains task_2, task_3, task_4
+        url = f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/section_2"
+        response = test_client.patch(
+            url,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        assert mock_set_ti_state.call_count == 4
+        called_task_ids = sorted(call.kwargs["task_id"] for call in mock_set_ti_state.call_args_list)
+        assert called_task_ids == [
+            "section_2.inner_section_2.task_2",
+            "section_2.inner_section_2.task_3",
+            "section_2.inner_section_2.task_4",
+            "section_2.task_1",
+        ]
+
+    def test_patch_task_group_not_found(self, test_client, session):
+        """Test that requesting a non-existent task group returns 404."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        url = f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/nonexistent_group"
+        respo
Re: [PR] feat: implement patching of task group instances in API [airflow]
OscarLigthart commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r3034320062
##
airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py:
##
@@ -5919,3 +5919,350 @@ def test_should_respond_403(self, unauthorized_test_client):
     def test_should_respond_422(self, test_client):
         response = test_client.patch(self.ENDPOINT_URL, json={})
         assert response.status_code == 422
+
+
+class TestPatchTaskGroup(TestTaskInstanceEndpoint):
+    DAG_ID = "example_task_group"
+    RUN_ID = "TEST_DAG_RUN_ID"
+    GROUP_ID = "section_1"
+    BASE_URL = f"/dags/{DAG_ID}/dagRuns/{RUN_ID}/taskGroupInstances"
+    ENDPOINT_URL = f"{BASE_URL}/{GROUP_ID}"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_success(self, mock_set_ti_state, test_client, session):
+        """Test that patching a task group sets state for all tasks in the group."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", "section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        response_data = response.json()
+        assert response_data["total_entries"] == mock_set_ti_state.call_count
+        assert mock_set_ti_state.call_count == 3
+        called_task_ids = sorted(call.kwargs["task_id"] for call in mock_set_ti_state.call_args_list)
+        assert called_task_ids == ["section_1.task_1", "section_1.task_2", "section_1.task_3"]
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "success"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_failed_state(self, mock_set_ti_state, test_client, session):
+        """Test that patching a task group with failed state works."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", "section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "failed"},
+        )
+        assert response.status_code == 200
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "failed"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_nested(self, mock_set_ti_state, test_client, session):
+        """Test that patching a nested task group includes tasks from inner groups."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        # section_2 contains task_1, and inner_section_2 which contains task_2, task_3, task_4
+        url = f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/section_2"
+        response = test_client.patch(
+            url,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        assert mock_set_ti_state.call_count == 4
+        called_task_ids = sorted(call.kwargs["task_id"] for call in mock_set_ti_state.call_args_list)
+        assert called_task_ids == [
+            "section_2.inner_section_2.task_2",
+            "section_2.inner_section_2.task_3",
+            "section_2.inner_section_2.task_4",
+            "section_2.task_1",
+        ]
+
+    def test_patch_task_group_not_found(self, test_client, session):
+        """Test that requesting a non-existent task group returns 404."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        url = f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/nonexistent_group"
+        respon
Re: [PR] feat: implement patching of task group instances in API [airflow]
OscarLigthart commented on PR #62812:
URL: https://github.com/apache/airflow/pull/62812#issuecomment-4185146488
Thanks for reopening and reviewing @pierrejeambrun! I've addressed most of your comments. I got stuck again on the N+1 query problem after removing the mock. I've consulted Claude and left a trace of the resulting conversation (can't believe I'm already calling these "conversations" with my agent, what a time to be alive). Let me know how you would like to proceed!
--
This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
Re: [PR] feat: implement patching of task group instances in API [airflow]
pierrejeambrun commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r3027386658
##
airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py:
##
@@ -5919,3 +5919,350 @@ def test_should_respond_403(self, unauthorized_test_client):
     def test_should_respond_422(self, test_client):
         response = test_client.patch(self.ENDPOINT_URL, json={})
         assert response.status_code == 422
+
+
+class TestPatchTaskGroup(TestTaskInstanceEndpoint):
+    DAG_ID = "example_task_group"
+    RUN_ID = "TEST_DAG_RUN_ID"
+    GROUP_ID = "section_1"
+    BASE_URL = f"/dags/{DAG_ID}/dagRuns/{RUN_ID}/taskGroupInstances"
+    ENDPOINT_URL = f"{BASE_URL}/{GROUP_ID}"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_success(self, mock_set_ti_state, test_client, session):
+        """Test that patching a task group sets state for all tasks in the group."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", "section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        response_data = response.json()
+        assert response_data["total_entries"] == mock_set_ti_state.call_count
+        assert mock_set_ti_state.call_count == 3
+        called_task_ids = sorted(call.kwargs["task_id"] for call in mock_set_ti_state.call_args_list)
+        assert called_task_ids == ["section_1.task_1", "section_1.task_2", "section_1.task_3"]
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "success"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_failed_state(self, mock_set_ti_state, test_client, session):
+        """Test that patching a task group with failed state works."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", "section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "failed"},
+        )
+        assert response.status_code == 200
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "failed"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_nested(self, mock_set_ti_state, test_client, session):
+        """Test that patching a nested task group includes tasks from inner groups."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        # section_2 contains task_1, and inner_section_2 which contains task_2, task_3, task_4
+        url = f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/section_2"
+        response = test_client.patch(
+            url,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        assert mock_set_ti_state.call_count == 4
+        called_task_ids = sorted(call.kwargs["task_id"] for call in mock_set_ti_state.call_args_list)
+        assert called_task_ids == [
+            "section_2.inner_section_2.task_2",
+            "section_2.inner_section_2.task_3",
+            "section_2.inner_section_2.task_4",
+            "section_2.task_1",
+        ]
+
+    def test_patch_task_group_not_found(self, test_client, session):
+        """Test that requesting a non-existent task group returns 404."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        url = f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/nonexistent_group"
+        respo
[PR] feat: implement patching of task group instances in API [airflow]
OscarLigthart opened a new pull request, #62812:
URL: https://github.com/apache/airflow/pull/62812
## Context
In Airflow 3, the ability to mark a full task group as failed or success is currently missing. This PR implements that logic in the API, so that it can then be called from the frontend to restore the functionality.
I deliberately split the feature into two separate PRs so the review can be more targeted to the individually touched components. Should this PR get merged, I will continue by building the remaining requirements into the frontend.
There is an already open PR here: #60161
However, it looks to be stale, and I would love to get this feature into the Airflow 3.2 release.
## Implementation
I make use of the BulkTaskInstanceService to perform the state updates for this endpoint. Using it, the implementation should be pretty straightforward :).
## Issues
related: #56103
---
# Was generative AI tooling used to co-author this PR?
- [x] Yes (please specify the tool below)
Generated-by: Claude Opus following [the guidelines](https://github.com/apache/airflow/blob/main/contributing-docs/05_pull_requests.rst#gen-ai-assisted-contributions)
---
* Read the **[Pull Request Guidelines](https://github.com/apache/airflow/blob/main/contributing-docs/05_pull_requests.rst#pull-request-guidelines)** for more information. Note: commit author/co-author name and email in commits become permanently public when merged.
* For fundamental code changes, an Airflow Improvement Proposal ([AIP](https://cwiki.apache.org/confluence/display/AIRFLOW/Airflow+Improvement+Proposals)) is needed.
* When adding a dependency, check compliance with the [ASF 3rd Party License Policy](https://www.apache.org/legal/resolved.html#category-x).
* For significant user-facing changes create a newsfragment: `{pr_number}.significant.rst` or `{issue_number}.significant.rst`, in [airflow-core/newsfragments](https://github.com/apache/airflow/tree/main/airflow-core/newsfragments).
Re: [PR] feat: implement patching of task group instances in API [airflow]
pierrejeambrun commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r3027370557
##
airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py:
##
@@ -5919,3 +5919,350 @@ def test_should_respond_403(self, unauthorized_test_client):
     def test_should_respond_422(self, test_client):
         response = test_client.patch(self.ENDPOINT_URL, json={})
         assert response.status_code == 422
+
+
+class TestPatchTaskGroup(TestTaskInstanceEndpoint):
+    DAG_ID = "example_task_group"
+    RUN_ID = "TEST_DAG_RUN_ID"
+    GROUP_ID = "section_1"
+    BASE_URL = f"/dags/{DAG_ID}/dagRuns/{RUN_ID}/taskGroupInstances"
+    ENDPOINT_URL = f"{BASE_URL}/{GROUP_ID}"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_success(self, mock_set_ti_state, test_client, session):
+        """Test that patching a task group sets state for all tasks in the group."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", "section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        response_data = response.json()
+        assert response_data["total_entries"] == mock_set_ti_state.call_count
+        assert mock_set_ti_state.call_count == 3
+        called_task_ids = sorted(call.kwargs["task_id"] for call in mock_set_ti_state.call_args_list)
+        assert called_task_ids == ["section_1.task_1", "section_1.task_2", "section_1.task_3"]
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "success"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_failed_state(self, mock_set_ti_state, test_client, session):
+        """Test that patching a task group with failed state works."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", "section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "failed"},
+        )
+        assert response.status_code == 200
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "failed"
Review Comment:
Same here: assert the API response.
##
airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py:
##
@@ -5919,3 +5919,350 @@ def test_should_respond_403(self, unauthorized_test_client):
     def test_should_respond_422(self, test_client):
         response = test_client.patch(self.ENDPOINT_URL, json={})
         assert response.status_code == 422
+
+
+class TestPatchTaskGroup(TestTaskInstanceEndpoint):
+    DAG_ID = "example_task_group"
+    RUN_ID = "TEST_DAG_RUN_ID"
+    GROUP_ID = "section_1"
+    BASE_URL = f"/dags/{DAG_ID}/dagRuns/{RUN_ID}/taskGroupInstances"
+    ENDPOINT_URL = f"{BASE_URL}/{GROUP_ID}"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_success(self, mock_set_ti_state, test_client, session):
+        """Test that patching a task group sets state for all tasks in the group."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", "section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        response_data = response.json()
+        assert response_data["total_entries"] == mock_s
Re: [PR] feat: implement patching of task group instances in API [airflow]
potiuk commented on PR #62812:
URL: https://github.com/apache/airflow/pull/62812#issuecomment-4169717356
This pull request has been converted to draft due to quality issues more than a week ago and there has been no response from the author since then. We are closing it for now to keep the review queue manageable. **@OscarLigthart**, you are welcome to reopen this PR after addressing the review comments. Thank you for your contribution!
Re: [PR] feat: implement patching of task group instances in API [airflow]
potiuk closed pull request #62812: feat: implement patching of task group instances in API
URL: https://github.com/apache/airflow/pull/62812
Re: [PR] feat: implement patching of task group instances in API [airflow]
potiuk commented on PR #62812:
URL: https://github.com/apache/airflow/pull/62812#issuecomment-4120213722
@OscarLigthart This PR has been converted to **draft** because it does not yet meet our [Pull Request quality criteria](https://github.com/apache/airflow/blob/main/contributing-docs/05_pull_requests.rst#pull-request-quality-criteria).
**Issues found:**
- :x: **Merge conflicts**: This PR has merge conflicts with the `main` branch. Your branch is 237 commits behind `main`. Please rebase your branch (`git fetch origin && git rebase origin/main`), resolve the conflicts, and push again. See [contributing quick start](https://github.com/apache/airflow/blob/main/contributing-docs/03a_contributors_quick_start_beginners.rst).
> **Note:** Your branch is **237 commits behind `main`**. Some check failures may be caused by changes in the base branch rather than by your PR. Please rebase your branch and push again to get up-to-date CI results.
**What to do next:**
- The comment informs you what you need to do.
- Fix each issue, then mark the PR as "Ready for review" in the GitHub UI, but only after making sure that all the issues are fixed.
- There is no rush: take your time and work at your own pace. We appreciate your contribution and are happy to wait for updates.
- Maintainers will then proceed with a normal review.
Converting a PR to draft is **not** a rejection; it is an invitation to bring the PR up to the project's standards so that maintainer review time is spent productively. If you have questions, feel free to ask on the [Airflow Slack](https://s.apache.org/airflow-slack).
Re: [PR] feat: implement patching of task group instances in API [airflow]
potiuk commented on PR #62812:
URL: https://github.com/apache/airflow/pull/62812#issuecomment-4038905256
@OscarLigthart This PR has been converted to **draft** because it does not yet meet our [Pull Request quality criteria](https://github.com/apache/airflow/blob/main/contributing-docs/05_pull_requests.rst#pull-request-quality-criteria).
**Issues found:**
- :x: **Merge conflicts**: This PR has merge conflicts with the `main` branch. Your branch is 64 commits behind `main`. Please rebase your branch (`git fetch origin && git rebase origin/main`), resolve the conflicts, and push again. See [contributing quick start](https://github.com/apache/airflow/blob/main/contributing-docs/03a_contributors_quick_start_beginners.rst).
- :warning: **Unresolved review comments**: This PR has 2 unresolved review threads from maintainers. Please review and resolve all inline review comments before requesting another review. You can resolve a conversation by clicking 'Resolve conversation' on each thread after addressing the feedback. See [pull request guidelines](https://github.com/apache/airflow/blob/main/contributing-docs/05_pull_requests.rst).
> **Note:** Your branch is **64 commits behind `main`**. Some check failures may be caused by changes in the base branch rather than by your PR. Please rebase your branch and push again to get up-to-date CI results.
**What to do next:**
- The comment informs you what you need to do.
- Fix each issue, then mark the PR as "Ready for review" in the GitHub UI, but only after making sure that all the issues are fixed.
- Maintainers will then proceed with a normal review.
Converting a PR to draft is **not** a rejection; it is an invitation to bring the PR up to the project's standards so that maintainer review time is spent productively. If you have questions, feel free to ask on the [Airflow Slack](https://s.apache.org/airflow-slack).
Re: [PR] feat: implement patching of task group instances in API [airflow]
OscarLigthart commented on PR #62812:
URL: https://github.com/apache/airflow/pull/62812#issuecomment-4027342056
> Thanks for the PR.
>
> Can you add query guards to tests (`assert_queries_count`)? I'm afraid this will generate an N+1 queries problem: the number of DB requests will scale linearly with the number of TIs in the group. We should rework the code not to do that.
>
> Also, I find the code overly complicated, with numerous duplicated calls made at different abstraction levels because of function nesting. It makes the whole thing hard to understand and probably sub-optimal.

Thanks for the detailed review! I think the points made a lot of sense; apologies for this oversight on my side. I tried following the implementation of the other PR more closely, removing the BulkTaskInstanceService usage and solving the N+1 query problem. I still opted to keep them in separate endpoints for clarity. Let me know if you disagree! I'm also happy to cast this aside and wait for the other PR to be picked up again, if that is more in line with your preference.

One question regarding `assert_queries_count`: I've created multiple cases where the counts remain the same, effectively eliminating the query problem, but I'm guessing there are many things that can impact this number, which would in turn break the tests implemented here. Is there a specific rule to keep in mind when applying it? Or is setting it the way I do in the tests correct?
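A query-count guard like the `assert_queries_count` helper discussed above can be approximated with a small counting context manager. The session, helper, and query strings below are toy placeholders rather than Airflow's real test utilities, but the pattern is the point: pin the count so it cannot grow with the number of task instances in the group.

```python
from contextlib import contextmanager


class FakeSession:
    """Toy session that counts every query it executes."""

    def __init__(self):
        self.query_count = 0

    def execute(self, statement):
        self.query_count += 1
        return f"result of {statement}"


@contextmanager
def assert_queries_count(session, expected):
    # Fail if the wrapped block issues more (or fewer) queries than expected
    start = session.query_count
    yield
    actual = session.query_count - start
    assert actual == expected, f"expected {expected} queries, got {actual}"


def patch_group_naive(session, task_ids):
    # N+1 shape: one query per task instance -- scales with group size
    for task_id in task_ids:
        session.execute(f"SELECT ... WHERE task_id = '{task_id}'")


def patch_group_batched(session, task_ids):
    # Batched shape: a single IN (...) query regardless of group size
    session.execute(f"SELECT ... WHERE task_id IN {tuple(task_ids)}")


session = FakeSession()

with assert_queries_count(session, 3):
    patch_group_naive(session, ["t1", "t2", "t3"])  # grows with the group

with assert_queries_count(session, 1):
    patch_group_batched(session, ["t1", "t2", "t3"])  # constant
```

The guard is intentionally brittle: if a refactor adds a query per TI, the pinned count breaks the test, which is exactly the signal the reviewer is asking for.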
Re: [PR] feat: implement patching of task group instances in API [airflow]
OscarLigthart commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r2908258611
##
airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_instances.py:
##
@@ -215,6 +215,29 @@ def validate_new_state(cls, ns: str | None) -> str:
         return ns

+class PatchTaskGroupBody(StrictBaseModel):
+    """Request body for patching the state of all task instances in a task group."""
+
+    new_state: TaskInstanceState
+    include_future: bool = False
+    include_past: bool = False
Review Comment:
Added back! I thought we wouldn't need them for the UI implementation, but it is better to keep them for the endpoint.
Re: [PR] feat: implement patching of task group instances in API [airflow]
OscarLigthart commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r2908255963
##
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##
@@ -153,6 +158,99 @@ def _patch_task_instance_note(
         ti.task_instance_note.user_id = user.get_id()

+def _get_task_group_task_ids(
+    dag: SerializedDAG,
+    group_id: str,
+) -> list[str]:
+    """
+    Get task ids that belong to a task group.
+
+    :param dag: The serialized DAG containing the task group.
+    :param group_id: The ID of the task group.
+    :return: List of task IDs in the group.
+    :raises HTTPException: If the task group is not found or has no tasks.
+    """
+    if not hasattr(dag, "task_group"):
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND,
+            f"DAG '{dag.dag_id}' does not have task groups",
+        )
+
+    task_groups = dag.task_group.get_task_group_dict()
+    task_group = task_groups.get(group_id)
+    if not task_group:
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND,
+            f"Task group '{group_id}' not found in DAG '{dag.dag_id}'",
+        )
+
+    task_ids = [task.task_id for task in task_group.iter_tasks()]
+    if not task_ids:
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND,
+            f"Task group '{group_id}' in DAG '{dag.dag_id}' has no tasks",
+        )
+
+    return task_ids
+
+
+def _patch_task_group_state(
+    *,
+    dag_id: str,
+    dag_run_id: str,
+    group_id: str,
+    body: PatchTaskGroupBody,
+    dag_bag: DagBagDep,
+    user: GetUserDep,
+    session: Session,
+) -> None:
+    """
+    Set the state of all task instances in a task group.
+
+    Uses BulkTaskInstanceService to update each task in the group.
+
+    :param dag_id: The DAG ID.
+    :param dag_run_id: The run_id of the DAG run.
+    :param group_id: The ID of the task group.
+    :param body: The request body with the new state and options.
+    :param dag_bag: The DAG bag for DAG resolution.
+    :param user: The authenticated user.
+    :param session: The database session.
+    """
+    dag = get_latest_version_of_dag(dag_bag, dag_id, session)
+    task_ids = _get_task_group_task_ids(dag, group_id)
+
+    entities = [
+        BulkTaskInstanceBody(
+            task_id=task_id,
+            dag_id=dag_id,
+            dag_run_id=dag_run_id,
+            new_state=body.new_state,
+            include_future=body.include_future,
+            include_past=body.include_past,
+        )
+        for task_id in task_ids
+    ]
+
+    action = BulkUpdateAction(
+        action=BulkAction.UPDATE,
+        entities=entities,
+        update_mask=["new_state"],
+        action_on_non_existence=BulkActionNotOnExistence.SKIP,
+    )
+    results = BulkActionResponse()
+
+    service = BulkTaskInstanceService(
+        session=session,
+        request=BulkBody(actions=[]),  # unused, but required by base class
+        dag_id=dag_id,
+        dag_run_id=dag_run_id,
+        dag_bag=dag_bag,
+        user=user,
+    )
+    service.handle_bulk_update(action, results)
Review Comment:
I tried taking inspiration from the other PR and follow that implementation
more. π
Re: [PR] feat: implement patching of task group instances in API [airflow]
pierrejeambrun commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r2907088878
##
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##
@@ -153,6 +158,99 @@ def _patch_task_instance_note(
         ti.task_instance_note.user_id = user.get_id()

+def _get_task_group_task_ids(
+    dag: SerializedDAG,
+    group_id: str,
+) -> list[str]:
+    """
+    Get task ids that belong to a task group.
+
+    :param dag: The serialized DAG containing the task group.
+    :param group_id: The ID of the task group.
+    :return: List of task IDs in the group.
+    :raises HTTPException: If the task group is not found or has no tasks.
+    """
+    if not hasattr(dag, "task_group"):
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND,
+            f"DAG '{dag.dag_id}' does not have task groups",
+        )
+
+    task_groups = dag.task_group.get_task_group_dict()
+    task_group = task_groups.get(group_id)
+    if not task_group:
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND,
+            f"Task group '{group_id}' not found in DAG '{dag.dag_id}'",
+        )
+
+    task_ids = [task.task_id for task in task_group.iter_tasks()]
+    if not task_ids:
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND,
+            f"Task group '{group_id}' in DAG '{dag.dag_id}' has no tasks",
+        )
+
+    return task_ids
+
+
+def _patch_task_group_state(
+    *,
+    dag_id: str,
+    dag_run_id: str,
+    group_id: str,
+    body: PatchTaskGroupBody,
+    dag_bag: DagBagDep,
+    user: GetUserDep,
+    session: Session,
+) -> None:
+    """
+    Set the state of all task instances in a task group.
+
+    Uses BulkTaskInstanceService to update each task in the group.
+
+    :param dag_id: The DAG ID.
+    :param dag_run_id: The run_id of the DAG run.
+    :param group_id: The ID of the task group.
+    :param body: The request body with the new state and options.
+    :param dag_bag: The DAG bag for DAG resolution.
+    :param user: The authenticated user.
+    :param session: The database session.
+    """
+    dag = get_latest_version_of_dag(dag_bag, dag_id, session)
+    task_ids = _get_task_group_task_ids(dag, group_id)
+
+    entities = [
+        BulkTaskInstanceBody(
+            task_id=task_id,
+            dag_id=dag_id,
+            dag_run_id=dag_run_id,
+            new_state=body.new_state,
+            include_future=body.include_future,
+            include_past=body.include_past,
+        )
+        for task_id in task_ids
+    ]
+
+    action = BulkUpdateAction(
+        action=BulkAction.UPDATE,
+        entities=entities,
+        update_mask=["new_state"],
+        action_on_non_existence=BulkActionNotOnExistence.SKIP,
+    )
+    results = BulkActionResponse()
+
+    service = BulkTaskInstanceService(
+        session=session,
+        request=BulkBody(actions=[]),  # unused, but required by base class
+        dag_id=dag_id,
+        dag_run_id=dag_run_id,
+        dag_bag=dag_bag,
+        user=user,
+    )
+    service.handle_bulk_update(action, results)
Review Comment:
Maybe not going through the bulk update service is actually better, the
interface wasn't ment for this. And it makes you do weird stuff to actually
plug into it.
The code for Updating a single TI is probably more re-usable and fitted to
this use case.
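The direction suggested above can be sketched as reusing a single-TI update routine in a loop over the group's members, instead of forcing the bulk-service interface to fit. All names here are hypothetical stand-ins; the real Airflow helpers have different signatures.

```python
def update_single_ti(ti, new_state):
    """Hypothetical single-TI update: the small, reusable unit."""
    ti["state"] = new_state
    return ti


def patch_task_group(tis, new_state):
    # Reuse the single-TI path for each member of the group; no
    # synthetic BulkBody / BulkUpdateAction scaffolding is needed.
    return [update_single_ti(ti, new_state) for ti in tis]


tis = [{"task_id": f"section_1.task_{i}", "state": "queued"} for i in (1, 2, 3)]
updated = patch_task_group(tis, "success")
```

The trade-off is that any batching (e.g. a single UPDATE statement) must then live inside the shared single-TI path or be layered on separately.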
Re: [PR] feat: implement patching of task group instances in API [airflow]
pierrejeambrun commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r2907045411
##
airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py:
##
@@ -847,6 +850,103 @@ def _collect_relatives(run_id: str, direction: Literal["upstream", "downstream"]
     )

+@task_instances_router.patch(
+    "/dagRuns/{dag_run_id}/taskGroupInstances/{group_id}",
+    responses=create_openapi_http_exception_doc(
+        [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST, status.HTTP_409_CONFLICT],
+    ),
+    dependencies=[
+        Depends(action_logging()),
+        Depends(requires_access_dag(method="PUT", access_entity=DagAccessEntity.TASK_INSTANCE)),
+    ],
+    operation_id="patch_task_group_instances",
+)
+def patch_task_group_instances(
+    dag_id: str,
+    dag_run_id: str,
+    group_id: str,
+    dag_bag: DagBagDep,
+    body: PatchTaskGroupBody,
+    session: SessionDep,
+    user: GetUserDep,
+) -> TaskInstanceCollectionResponse:
+    """Update the state of all task instances in a task group."""
+    _patch_task_group_state(
+        dag_id=dag_id,
+        dag_run_id=dag_run_id,
+        group_id=group_id,
+        body=body,
+        dag_bag=dag_bag,
+        user=user,
+        session=session,
+    )
+
+    # Collect all TIs for the task group to build the response
+    dag = get_latest_version_of_dag(dag_bag, dag_id, session)
+    task_ids = _get_task_group_task_ids(dag, group_id)
+    tis = (
+        session.scalars(
+            select(TI)
+            .where(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id.in_(task_ids))
+            .join(TI.dag_run)
+            .options(joinedload(TI.rendered_task_instance_fields))
+            .options(joinedload(TI.dag_version))
+            .options(joinedload(TI.dag_run).options(joinedload(DagRun.dag_model)))
+        )
+        .unique()
+        .all()
Review Comment:
This query could probably be avoided. we are fetching from the DB at
multiple places.
Re: [PR] feat: implement patching of task group instances in API [airflow]
pierrejeambrun commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r2906986909
##
airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_instances.py:
##
@@ -215,6 +215,29 @@ def validate_new_state(cls, ns: str | None) -> str:
         return ns

+class PatchTaskGroupBody(StrictBaseModel):
+    """Request body for patching the state of all task instances in a task group."""
+
+    new_state: TaskInstanceState
+    include_future: bool = False
+    include_past: bool = False
+
+    @field_validator("new_state", mode="before")
+    @classmethod
+    def validate_new_state(cls, ns: str | None) -> str:
+        """Validate new_state."""
+        valid_states = [
+            vs.name.lower()
+            for vs in (TaskInstanceState.SUCCESS, TaskInstanceState.FAILED, TaskInstanceState.SKIPPED)
+        ]
+        if ns is None:
+            raise ValueError("'new_state' should not be empty")
+        ns = ns.lower()
+        if ns not in valid_states:
+            raise ValueError(f"'{ns}' is not one of {valid_states}")
+        return ns
+
Review Comment:
This is a complete duplicate of the existing `validate_new_state`. Make a
common base body class for `PatchTaskGroup` and `PatchTaskInstanceBody`. Same
for other attributes.
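The deduplication suggested above can be sketched as a shared mixin carrying the validator, with both request bodies inheriting it. This sketch uses plain Python rather than Pydantic, and the mixin name is invented; only `PatchTaskInstanceBody` and `PatchTaskGroupBody` come from the diff.

```python
VALID_STATES = {"success", "failed", "skipped"}


class NewStateValidationMixin:
    """Shared new_state validation, written once for both patch bodies."""

    @staticmethod
    def validate_new_state(ns):
        if ns is None:
            raise ValueError("'new_state' should not be empty")
        ns = ns.lower()
        if ns not in VALID_STATES:
            raise ValueError(f"'{ns}' is not one of {sorted(VALID_STATES)}")
        return ns


class PatchTaskInstanceBody(NewStateValidationMixin):
    def __init__(self, new_state=None):
        self.new_state = self.validate_new_state(new_state)


class PatchTaskGroupBody(NewStateValidationMixin):
    def __init__(self, new_state=None, include_future=False, include_past=False):
        self.new_state = self.validate_new_state(new_state)
        self.include_future = include_future
        self.include_past = include_past


body = PatchTaskGroupBody("SUCCESS")
```

In the real Pydantic models the same idea would be a common `StrictBaseModel` subclass holding the `@field_validator` (and the shared attributes), which both bodies then extend.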
##
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##
@@ -153,6 +158,99 @@ def _patch_task_instance_note(
         ti.task_instance_note.user_id = user.get_id()

+def _get_task_group_task_ids(
+    dag: SerializedDAG,
+    group_id: str,
+) -> list[str]:
+    """
+    Get task ids that belong to a task group.
+
+    :param dag: The serialized DAG containing the task group.
+    :param group_id: The ID of the task group.
+    :return: List of task IDs in the group.
+    :raises HTTPException: If the task group is not found or has no tasks.
+    """
+    if not hasattr(dag, "task_group"):
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND,
+            f"DAG '{dag.dag_id}' does not have task groups",
+        )
+
+    task_groups = dag.task_group.get_task_group_dict()
+    task_group = task_groups.get(group_id)
+    if not task_group:
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND,
+            f"Task group '{group_id}' not found in DAG '{dag.dag_id}'",
+        )
+
+    task_ids = [task.task_id for task in task_group.iter_tasks()]
+    if not task_ids:
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND,
+            f"Task group '{group_id}' in DAG '{dag.dag_id}' has no tasks",
+        )
Review Comment:
Nit: Not sure we should 404 here. If there is nothing to patch, just return
`[]` and proceed.
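Following that suggestion, an existing-but-empty group becomes a no-op rather than an error, and only an unknown group raises. A minimal sketch with a toy group lookup (a plain dict standing in for the DAG's task-group dictionary):

```python
def get_task_group_task_ids(task_group_dict, group_id):
    """Collect task ids; empty group is a no-op, unknown group is an error."""
    task_ids = task_group_dict.get(group_id)
    if task_ids is None:
        # Unknown group is still a client error (maps to a 404 in the endpoint)
        raise LookupError(f"Task group '{group_id}' not found")
    # [] simply means there is nothing to patch: return an empty response
    return list(task_ids)


groups = {"section_1": ["section_1.task_1"], "empty_group": []}
result = get_task_group_task_ids(groups, "empty_group")
```

The endpoint can then return a collection response with `total_entries == 0` for the empty case instead of surfacing an error to the client.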
##
airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py:
##
@@ -847,6 +850,103 @@ def _collect_relatives(run_id: str, direction: Literal["upstream", "downstream"]
)
+@task_instances_router.patch(
+    "/dagRuns/{dag_run_id}/taskGroupInstances/{group_id}",
+    responses=create_openapi_http_exception_doc(
+        [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST, status.HTTP_409_CONFLICT],
+    ),
+    dependencies=[
+        Depends(action_logging()),
+        Depends(requires_access_dag(method="PUT", access_entity=DagAccessEntity.TASK_INSTANCE)),
+    ],
+    operation_id="patch_task_group_instances",
+)
+def patch_task_group_instances(
+    dag_id: str,
+    dag_run_id: str,
+    group_id: str,
+    dag_bag: DagBagDep,
+    body: PatchTaskGroupBody,
+    session: SessionDep,
+    user: GetUserDep,
+) -> TaskInstanceCollectionResponse:
+    """Update the state of all task instances in a task group."""
+    _patch_task_group_state(
+        dag_id=dag_id,
+        dag_run_id=dag_run_id,
+        group_id=group_id,
+        body=body,
+        dag_bag=dag_bag,
+        user=user,
+        session=session,
+    )
+
+    # Collect all TIs for the task group to build the response
+    dag = get_latest_version_of_dag(dag_bag, dag_id, session)
+    task_ids = _get_task_group_task_ids(dag, group_id)
Review Comment:
`_get_task_group_task_ids` is called multiple times, including inside
`_patch_task_group_state`. Same for `get_latest_version_of_dag`.
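One way to address this, sketched with hypothetical wiring (the helper callables stand in for `get_latest_version_of_dag`, `_get_task_group_task_ids`, and the patch logic): resolve the DAG and the group's task ids once at the top of the route and pass the results down, so no helper re-derives them.

```python
# Sketch: compute the expensive lookups exactly once in the route handler
# and hand the results to the helpers, instead of re-calling the lookups
# inside each helper.
def patch_task_group_instances(dag_id, group_id, resolve_dag, get_group_task_ids, patch_state):
    dag = resolve_dag(dag_id)                     # looked up once
    task_ids = get_group_task_ids(dag, group_id)  # computed once
    patch_state(dag, task_ids)                    # helper receives the values
    return task_ids                               # reused again for the response
```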
##
airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py:
##
@@ -268,34 +366,39 @@ def _perform_update(
    results: BulkActionResponse,
    update_mask: list[str] | None = Query(None),
) -> None:
-    dag, tis, data = _patch_ti_validate_request(
-        dag_id=dag_id,
-        dag_run_id=dag_run_id,
-        task_id=task_id,
-        dag_bag=self.dag_bag,
-        body=entity,
-        session=self.session,
-        update_mask=update_mask,
-    )
-
-    for key, _ in data.items():
-        if key == "new_state":
-            _patch_task_instance_state(
-                task_id
[PR] feat: implement patching of task group instances in API [airflow]
OscarLigthart opened a new pull request, #62812:
URL: https://github.com/apache/airflow/pull/62812
## Context
In Airflow 3, the ability to mark a full task group as successful or failed is
currently missing. In this PR, I implement the logic in the API so that it can
be called from the frontend to restore the functionality.
I deliberately split the feature into two separate PRs so the review process
can be more targeted to the individually touched components. Should this PR
get merged, I will continue to build the remaining requirements into the
frontend.
There is an already open PR here: #60161
However, it appears to be stale, and I would love to get this feature into the
Airflow 3.2 release.
## Implementation
I make use of the BulkTaskInstanceService to perform the state updates for
this endpoint. Using it, the implementation should be pretty straightforward :).
## Issues
related: #56103
---
# Was generative AI tooling used to co-author this PR?
- [x] Yes (please specify the tool below)
Generated-by: Claude Opus following [the
guidelines](https://github.com/apache/airflow/blob/main/contributing-docs/05_pull_requests.rst#gen-ai-assisted-contributions)
---
* Read the **[Pull Request
Guidelines](https://github.com/apache/airflow/blob/main/contributing-docs/05_pull_requests.rst#pull-request-guidelines)**
for more information. Note: commit author/co-author name and email in commits
become permanently public when merged.
* For fundamental code changes, an Airflow Improvement Proposal
([AIP](https://cwiki.apache.org/confluence/display/AIRFLOW/Airflow+Improvement+Proposals))
is needed.
* When adding dependency, check compliance with the [ASF 3rd Party License
Policy](https://www.apache.org/legal/resolved.html#category-x).
* For significant user-facing changes create newsfragment:
`{pr_number}.significant.rst` or `{issue_number}.significant.rst`, in
[airflow-core/newsfragments](https://github.com/apache/airflow/tree/main/airflow-core/newsfragments).
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
