pierrejeambrun commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r3063329169
##########
airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py:
##########
@@ -5919,3 +5919,350 @@ def test_should_respond_403(self,
unauthorized_test_client):
def test_should_respond_422(self, test_client):
response = test_client.patch(self.ENDPOINT_URL, json={})
assert response.status_code == 422
+
+
+class TestPatchTaskGroup(TestTaskInstanceEndpoint):
+ DAG_ID = "example_task_group"
+ RUN_ID = "TEST_DAG_RUN_ID"
+ GROUP_ID = "section_1"
+ BASE_URL = f"/dags/{DAG_ID}/dagRuns/{RUN_ID}/taskGroupInstances"
+ ENDPOINT_URL = f"{BASE_URL}/{GROUP_ID}"
+
+
@mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+ def test_patch_task_group_success(self, mock_set_ti_state, test_client,
session):
+ """Test that patching a task group sets state for all tasks in the
group."""
+ self.create_task_instances(session, dag_id=self.DAG_ID)
+
+ tis = session.scalars(
+ select(TaskInstance).where(
+ TaskInstance.dag_id == self.DAG_ID,
+ TaskInstance.run_id == self.RUN_ID,
+ TaskInstance.task_id.in_(["section_1.task_1",
"section_1.task_2", "section_1.task_3"]),
+ )
+ ).all()
+
+ ti_map = {ti.task_id: ti for ti in tis}
+ mock_set_ti_state.side_effect = lambda task_id, **kwargs:
[ti_map[task_id]]
+
+ response = test_client.patch(
+ self.ENDPOINT_URL,
+ json={"new_state": "success"},
+ )
+ assert response.status_code == 200
+ response_data = response.json()
+ assert response_data["total_entries"] == mock_set_ti_state.call_count
+ assert mock_set_ti_state.call_count == 3
+ called_task_ids = sorted(call.kwargs["task_id"] for call in
mock_set_ti_state.call_args_list)
+ assert called_task_ids == ["section_1.task_1", "section_1.task_2",
"section_1.task_3"]
+ for call in mock_set_ti_state.call_args_list:
+ assert call.kwargs["state"] == "success"
+
+
@mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+ def test_patch_task_group_failed_state(self, mock_set_ti_state,
test_client, session):
+ """Test that patching a task group with failed state works."""
+ self.create_task_instances(session, dag_id=self.DAG_ID)
+
+ tis = session.scalars(
+ select(TaskInstance).where(
+ TaskInstance.dag_id == self.DAG_ID,
+ TaskInstance.run_id == self.RUN_ID,
+ TaskInstance.task_id.in_(["section_1.task_1",
"section_1.task_2", "section_1.task_3"]),
+ )
+ ).all()
+
+ ti_map = {ti.task_id: ti for ti in tis}
+ mock_set_ti_state.side_effect = lambda task_id, **kwargs:
[ti_map[task_id]]
+
+ response = test_client.patch(
+ self.ENDPOINT_URL,
+ json={"new_state": "failed"},
+ )
+ assert response.status_code == 200
+ for call in mock_set_ti_state.call_args_list:
+ assert call.kwargs["state"] == "failed"
+
+
@mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+ def test_patch_task_group_nested(self, mock_set_ti_state, test_client,
session):
+ """Test that patching a nested task group includes tasks from inner
groups."""
+ self.create_task_instances(session, dag_id=self.DAG_ID)
+
+ tis = session.scalars(
+ select(TaskInstance).where(
+ TaskInstance.dag_id == self.DAG_ID,
+ TaskInstance.run_id == self.RUN_ID,
+ )
+ ).all()
+
+ ti_map = {ti.task_id: ti for ti in tis}
+ mock_set_ti_state.side_effect = lambda task_id, **kwargs:
[ti_map[task_id]]
+
+ # section_2 contains task_1, and inner_section_2 which contains
task_2, task_3, task_4
+ url =
f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/section_2"
+ response = test_client.patch(
+ url,
+ json={"new_state": "success"},
+ )
+ assert response.status_code == 200
+ assert mock_set_ti_state.call_count == 4
+ called_task_ids = sorted(call.kwargs["task_id"] for call in
mock_set_ti_state.call_args_list)
+ assert called_task_ids == [
+ "section_2.inner_section_2.task_2",
+ "section_2.inner_section_2.task_3",
+ "section_2.inner_section_2.task_4",
+ "section_2.task_1",
+ ]
+
+ def test_patch_task_group_not_found(self, test_client, session):
+ """Test that requesting a non-existent task group returns 404."""
+ self.create_task_instances(session, dag_id=self.DAG_ID)
+
+ url =
f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/nonexistent_group"
+ response = test_client.patch(
+ url,
+ json={"new_state": "success"},
+ )
+ assert response.status_code == 404
+ assert "nonexistent_group" in response.json()["detail"]
+
+ def test_patch_task_group_invalid_state(self, test_client, session):
+ """Test that an invalid new_state returns 422."""
+ self.create_task_instances(session, dag_id=self.DAG_ID)
+
+ response = test_client.patch(
+ self.ENDPOINT_URL,
+ json={"new_state": "invalid_state"},
+ )
+ assert response.status_code == 422
+
+ def test_patch_task_group_dag_not_found(self, test_client, session):
+ """Test that requesting a non-existent DAG returns 404."""
+ url =
f"/dags/nonexistent_dag/dagRuns/{self.RUN_ID}/taskGroupInstances/{self.GROUP_ID}"
+ response = test_client.patch(
+ url,
+ json={"new_state": "success"},
+ )
+ assert response.status_code == 404
+
+ def test_should_respond_401(self, unauthenticated_test_client):
+ response = unauthenticated_test_client.patch(self.ENDPOINT_URL,
json={"new_state": "success"})
+ assert response.status_code == 401
+
+ def test_should_respond_403(self, unauthorized_test_client):
+ response = unauthorized_test_client.patch(self.ENDPOINT_URL,
json={"new_state": "success"})
+ assert response.status_code == 403
+
+
@mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+ def test_query_count_does_not_scale_with_task_group_size(self,
mock_set_ti_state, test_client, session):
+ """Test that query count doesn't scale linearly with task group size -
single bulk query."""
Review Comment:
Yep, I actually checked Airflow 2.x: we used to call
`dag.set_task_group_state`, and we need something similar here.
As the code below shows, it took row locks to make sure we do not end up in a
deadlock.
```
@provide_session
def set_task_group_state(
    self,
    *,
    group_id: str,
    execution_date: datetime | None = None,
    run_id: str | None = None,
    state: TaskInstanceState,
    upstream: bool = False,
    downstream: bool = False,
    future: bool = False,
    past: bool = False,
    commit: bool = True,
    session: Session = NEW_SESSION,
) -> list[TaskInstance]:
    """
    Set TaskGroup to the given state and clear downstream tasks in failed or upstream_failed state.

    :param group_id: The group_id of the TaskGroup
    :param execution_date: Execution date of the TaskInstance
    :param run_id: The run_id of the TaskInstance
    :param state: State to set the TaskInstances to
    :param upstream: Include all upstream tasks of the group's tasks (passed to ``set_state``)
    :param downstream: Include all downstream tasks of the group's tasks (passed to ``set_state``)
    :param future: Include all future TaskInstances of the group's tasks
    :param past: Include all past TaskInstances of the group's tasks
    :param commit: Commit changes. When False, the altered task instances are
        returned right after ``set_state``; the session is not flushed and no
        downstream tasks are cleared.
    :param session: new session
    :return: The task instances altered by ``set_state``.
    :raises ValueError: if not exactly one of ``execution_date`` / ``run_id`` is
        provided, or if ``group_id`` is not a TaskGroup of this DAG.
    """
    from airflow.api.common.mark_tasks import set_state

    if not exactly_one(execution_date, run_id):
        raise ValueError("Exactly one of execution_date or run_id must be provided")

    if execution_date is None:
        dag_run = session.scalars(
            select(DagRun).where(DagRun.run_id == run_id, DagRun.dag_id == self.dag_id)
        ).one()  # Raises an error if not found
        resolve_execution_date = dag_run.execution_date
    else:
        resolve_execution_date = execution_date

    # A bound of None means "unbounded" on that side of the date window.
    end_date = resolve_execution_date if not future else None
    start_date = resolve_execution_date if not past else None

    task_group = self.task_group.get_task_group_dict().get(group_id)
    if task_group is None:
        raise ValueError(f"TaskGroup {group_id} could not be found")

    group_tasks = list(task_group.iter_tasks())
    tasks_to_set_state: list[BaseOperator | tuple[BaseOperator, int]] = [
        task for task in group_tasks if isinstance(task, BaseOperator)
    ]
    task_ids: list[str] = [task.task_id for task in group_tasks]

    # Select the DagRuns whose rows are locked below. Each bound is applied only
    # when it is set, so with past=True and future=True every run of this DAG is
    # selected. Comparing ``execution_date == None`` here would compile to
    # ``IS NULL`` and match only runs without an execution date.
    dag_runs_query = select(DagRun.id).where(DagRun.dag_id == self.dag_id)
    if start_date is not None:
        dag_runs_query = dag_runs_query.where(DagRun.execution_date >= start_date)
    if end_date is not None:
        dag_runs_query = dag_runs_query.where(DagRun.execution_date <= end_date)

    # Hold row locks on the selected DagRuns while task states are changed, to
    # avoid deadlocking with concurrent updates.
    with lock_rows(dag_runs_query, session):
        altered = set_state(
            tasks=tasks_to_set_state,
            execution_date=execution_date,
            run_id=run_id,
            upstream=upstream,
            downstream=downstream,
            future=future,
            past=past,
            state=state,
            commit=commit,
            session=session,
        )

        if not commit:
            return altered

        # Clear downstream tasks that are in failed/upstream_failed state to resume them.
        # Flush the session so that the tasks marked success are reflected in the db.
        session.flush()
        task_subset = self.partial_subset(
            task_ids_or_regex=task_ids,
            include_downstream=True,
            include_upstream=False,
        )

        task_subset.clear(
            start_date=start_date,
            end_date=end_date,
            include_subdags=True,
            include_parentdag=True,
            only_failed=True,
            session=session,
            # The group's own tasks were just set explicitly; do not clear them.
            exclude_task_ids=frozenset(task_ids),
        )

    return altered
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]