pierrejeambrun commented on code in PR #44220:
URL: https://github.com/apache/airflow/pull/44220#discussion_r1851661003
##########
airflow/api_fastapi/core_api/routes/public/task_instances.py:
##########
@@ -265,7 +268,7 @@ def get_mapped_task_instance(
@task_instances_router.get(
- "",
+ task_instances_prefix + "",
Review Comment:
```suggestion
task_instances_prefix,
```
##########
airflow/api_fastapi/common/types.py:
##########
@@ -59,6 +66,20 @@ class TimeDelta(BaseModel):
TimeDeltaWithValidation = Annotated[TimeDelta,
BeforeValidator(_validate_timedelta_field)]
+def _validate_nonnaive_datetime_field(dt: datetime | None) -> datetime | None:
+ """Validate and return the datetime field."""
+ if dt is None:
+ return None
+ if isinstance(dt, str):
+ dt = datetime.fromisoformat(dt)
+ if not dt.tzinfo:
+ raise ValueError("Invalid datetime format, Naive datetime is
disallowed")
+ return dt
+
+
+DatetimeWithNonNaiveValidation = Annotated[datetime,
BeforeValidator(_validate_nonnaive_datetime_field)]
+
+
Review Comment:
Pydantic has a native `AwareDatetime` (we use it a little bit) that does
exactly that, I believe.
At some point we might want to switch all datetimes to `AwareDatetime`, I think.
##########
airflow/api_fastapi/core_api/datamodels/task_instances.py:
##########
@@ -150,3 +154,54 @@ class TaskInstanceHistoryCollectionResponse(BaseModel):
task_instances: list[TaskInstanceHistoryResponse]
total_entries: int
+
+
+class ClearTaskInstancesBody(BaseModel):
+ """Request body for Clear Task Instances endpoint."""
+
+ dry_run: bool = True
+ start_date: DatetimeWithNonNaiveValidation | None = None
+ end_date: DatetimeWithNonNaiveValidation | None = None
+ only_failed: bool = True
+ only_running: bool = False
+ reset_dag_runs: bool = False
+ task_ids: list[str] | None = None
+ dag_run_id: str | None = None
+ include_upstream: bool = False
+ include_downstream: bool = False
+ include_future: bool = False
+ include_past: bool = False
+
+ @model_validator(mode="before")
+ @classmethod
+ def validate_model(cls, data: Any) -> Any:
+ """Validate clear task instance form."""
+ if data.get("only_failed") and data.get("only_running"):
+ raise ValidationError("only_failed and only_running both are set
to True")
+ if data.get("start_date") and data.get("end_date"):
+ if data.get("start_date") > data.get("end_date"):
+ raise ValidationError("end_date is sooner than start_date")
+ if data.get("start_date") and data.get("end_date") and
data.get("dag_run_id"):
+ raise ValidationError("Exactly one of dag_run_id or (start_date
and end_date) must be provided")
+ if data.get("start_date") and data.get("dag_run_id"):
+ raise ValidationError("Exactly one of dag_run_id or start_date
must be provided")
+ if data.get("end_date") and data.get("dag_run_id"):
+ raise ValidationError("Exactly one of dag_run_id or end_date must
be provided")
+ if isinstance(data.get("task_ids"), list) and
len(data.get("task_ids")) < 1:
+ raise ValidationError("task_ids list should have at least 1
element.")
+ return data
+
+
+class TaskInstanceReferenceResponse(BaseModel):
+ """Task Instance Reference serializer for responses."""
+
+ task_id: str
+ dag_run_id: str = Field(validation_alias=AliasChoices("run_id"))
Review Comment:
An `AliasChoices` with only one choice should be equivalent to passing the alias directly:
```suggestion
dag_run_id: str = Field(validation_alias="run_id")
```
##########
airflow/api_fastapi/core_api/routes/public/task_instances.py:
##########
@@ -482,3 +485,88 @@ def get_mapped_task_instance_try_details(
map_index=map_index,
session=session,
)
+
+
+@task_instances_router.post(
+ "/clearTaskInstances",
+ responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
+)
+def post_clear_task_instances(
+ dag_id: str,
+ request: Request,
+ body: ClearTaskInstancesBody,
+ session: Annotated[Session, Depends(get_session)],
+) -> TaskInstanceReferenceCollectionResponse:
+ """Clear task instances."""
+ dag = request.app.state.dag_bag.get_dag(dag_id)
+ if not dag:
+ error_message = f"DAG {dag_id} not found"
+ raise HTTPException(status.HTTP_404_NOT_FOUND, error_message)
+
+ reset_dag_runs = body.reset_dag_runs
+ dry_run = body.dry_run
+ # We always pass dry_run here, otherwise this would try to confirm on the
terminal!
+ dag_run_id = body.dag_run_id
+ future = body.include_future
+ past = body.include_past
+ downstream = body.include_downstream
+ upstream = body.include_upstream
+
+ if dag_run_id is not None:
+ dag_run: DR | None = session.scalar(select(DR).where(DR.dag_id ==
dag_id, DR.run_id == dag_run_id))
+ if dag_run is None:
+ error_message = f"Dag Run id {dag_run_id} not found in dag
{dag_id}"
+ raise HTTPException(status.HTTP_404_NOT_FOUND, error_message)
+ body.start_date = dag_run.logical_date
+ body.end_date = dag_run.logical_date
+
+ if past:
+ body.start_date = None
+
+ if future:
+ body.end_date = None
+
+ task_ids = body.task_ids
+ if task_ids is not None:
+ task_id = [task[0] if isinstance(task, tuple) else task for task in
task_ids]
+ dag = dag.partial_subset(
+ task_ids_or_regex=task_id,
+ include_downstream=downstream,
+ include_upstream=upstream,
+ )
+
+ if len(dag.task_dict) > 1:
+ # If we had upstream/downstream etc then also include those!
+ task_ids.extend(tid for tid in dag.task_dict if tid != task_id)
+
+ task_instances = dag.clear(
+ dry_run=True,
+ task_ids=body.task_ids,
Review Comment:
We need to use the `extended` task_ids with upstream/downstream too, not
only the `task_ids` from the payload.
```suggestion
task_ids=task_ids,
```
##########
airflow/api_fastapi/core_api/routes/public/task_instances.py:
##########
@@ -482,3 +485,88 @@ def get_mapped_task_instance_try_details(
map_index=map_index,
session=session,
)
+
+
+@task_instances_router.post(
+ "/clearTaskInstances",
+ responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
+)
+def post_clear_task_instances(
+ dag_id: str,
+ request: Request,
+ body: ClearTaskInstancesBody,
+ session: Annotated[Session, Depends(get_session)],
+) -> TaskInstanceReferenceCollectionResponse:
+ """Clear task instances."""
+ dag = request.app.state.dag_bag.get_dag(dag_id)
+ if not dag:
+ error_message = f"DAG {dag_id} not found"
+ raise HTTPException(status.HTTP_404_NOT_FOUND, error_message)
+
+ reset_dag_runs = body.reset_dag_runs
+ dry_run = body.dry_run
+ # We always pass dry_run here, otherwise this would try to confirm on the
terminal!
+ dag_run_id = body.dag_run_id
+ future = body.include_future
+ past = body.include_past
+ downstream = body.include_downstream
+ upstream = body.include_upstream
+
+ if dag_run_id is not None:
+ dag_run: DR | None = session.scalar(select(DR).where(DR.dag_id ==
dag_id, DR.run_id == dag_run_id))
+ if dag_run is None:
+ error_message = f"Dag Run id {dag_run_id} not found in dag
{dag_id}"
+ raise HTTPException(status.HTTP_404_NOT_FOUND, error_message)
+ body.start_date = dag_run.logical_date
+ body.end_date = dag_run.logical_date
+
+ if past:
+ body.start_date = None
+
+ if future:
+ body.end_date = None
+
+ task_ids = body.task_ids
+ if task_ids is not None:
+ task_id = [task[0] if isinstance(task, tuple) else task for task in
task_ids]
+ dag = dag.partial_subset(
+ task_ids_or_regex=task_id,
+ include_downstream=downstream,
+ include_upstream=upstream,
+ )
+
+ if len(dag.task_dict) > 1:
+ # If we had upstream/downstream etc then also include those!
+ task_ids.extend(tid for tid in dag.task_dict if tid != task_id)
+
+ task_instances = dag.clear(
+ dry_run=True,
+ task_ids=body.task_ids,
+ dag_bag=request.app.state.dag_bag,
+ **body.model_dump(
+ include=[ # type: ignore[arg-type]
+ "start_date",
+ "end_date",
+ "only_failed",
+ "only_running",
+ ]
Review Comment:
I think a set will remove the typing error.
```suggestion
include={
"start_date",
"end_date",
"only_failed",
"only_running",
}
```
##########
tests/api_fastapi/core_api/routes/public/test_task_instances.py:
##########
@@ -1663,3 +1668,649 @@ def test_raises_404_for_nonexistent_task_instance(self,
test_client, session):
assert response.json() == {
"detail": "The Task Instance with dag_id:
`example_python_operator`, run_id: `TEST_DAG_RUN_ID`, task_id:
`nonexistent_task`, try_number: `0` and map_index: `-1` was not found"
}
+
+
+class TestPostClearTaskInstances(TestTaskInstanceEndpoint):
+ @pytest.mark.parametrize(
+ "main_dag, task_instances, request_dag, payload, expected_ti",
+ [
+ pytest.param(
+ "example_python_operator",
+ [
+ {"logical_date": DEFAULT_DATETIME_1, "state":
State.FAILED},
+ {
+ "logical_date": DEFAULT_DATETIME_1 +
dt.timedelta(days=1),
+ "state": State.FAILED,
+ },
+ {
+ "logical_date": DEFAULT_DATETIME_1 +
dt.timedelta(days=2),
+ "state": State.FAILED,
+ },
+ ],
+ "example_python_operator",
+ {
+ "dry_run": True,
+ "start_date": DEFAULT_DATETIME_STR_2,
+ "only_failed": True,
+ },
+ 2,
+ id="clear start date filter",
+ ),
+ pytest.param(
+ "example_python_operator",
+ [
+ {"logical_date": DEFAULT_DATETIME_1, "state":
State.FAILED},
+ {
+ "logical_date": DEFAULT_DATETIME_1 +
dt.timedelta(days=1),
+ "state": State.FAILED,
+ },
+ {
+ "logical_date": DEFAULT_DATETIME_1 +
dt.timedelta(days=2),
+ "state": State.FAILED,
+ },
+ ],
+ "example_python_operator",
+ {
+ "dry_run": True,
+ "end_date": DEFAULT_DATETIME_STR_2,
+ "only_failed": True,
+ },
+ 2,
+ id="clear end date filter",
+ ),
+ pytest.param(
+ "example_python_operator",
+ [
+ {"logical_date": DEFAULT_DATETIME_1, "state":
State.RUNNING},
+ {
+ "logical_date": DEFAULT_DATETIME_1 +
dt.timedelta(days=1),
+ "state": State.RUNNING,
+ },
+ {
+ "logical_date": DEFAULT_DATETIME_1 +
dt.timedelta(days=2),
+ "state": State.FAILED,
+ },
+ ],
+ "example_python_operator",
+ {"dry_run": True, "only_running": True, "only_failed": False},
+ 2,
+ id="clear only running",
+ ),
+ pytest.param(
+ "example_python_operator",
+ [
+ {"logical_date": DEFAULT_DATETIME_1, "state":
State.FAILED},
+ {
+ "logical_date": DEFAULT_DATETIME_1 +
dt.timedelta(days=1),
+ "state": State.FAILED,
+ },
+ {
+ "logical_date": DEFAULT_DATETIME_1 +
dt.timedelta(days=2),
+ "state": State.RUNNING,
+ },
+ ],
+ "example_python_operator",
+ {
+ "dry_run": True,
+ "only_failed": True,
+ },
+ 2,
+ id="clear only failed",
+ ),
+ pytest.param(
+ "example_python_operator",
+ [
+ {"logical_date": DEFAULT_DATETIME_1, "state":
State.FAILED},
+ {
+ "logical_date": DEFAULT_DATETIME_1 +
dt.timedelta(days=1),
+ "state": State.FAILED,
+ },
+ {
+ "logical_date": DEFAULT_DATETIME_1 +
dt.timedelta(days=2),
+ "state": State.FAILED,
+ },
+ {
+ "logical_date": DEFAULT_DATETIME_1 +
dt.timedelta(days=3),
+ "state": State.FAILED,
+ },
+ ],
+ "example_python_operator",
+ {
+ "dry_run": True,
+ "task_ids": ["print_the_context", "sleep_for_1"],
+ },
+ 2,
+ id="clear by task ids",
+ ),
+ pytest.param(
+ "example_python_operator",
+ [
+ {"logical_date": DEFAULT_DATETIME_1, "state":
State.FAILED},
+ {
+ "logical_date": DEFAULT_DATETIME_1 +
dt.timedelta(days=1),
+ "state": State.FAILED,
+ },
+ {
+ "logical_date": DEFAULT_DATETIME_1 +
dt.timedelta(days=2),
+ "state": State.RUNNING,
+ },
+ ],
+ "example_python_operator",
+ {
+ "only_failed": True,
+ },
+ 2,
+ id="dry_run default",
+ ),
+ ],
+ )
+ def test_should_respond_200(
+ self,
+ test_client,
+ session,
+ main_dag,
+ task_instances,
+ request_dag,
+ payload,
+ expected_ti,
+ ):
+ self.create_task_instances(
+ session,
+ dag_id=main_dag,
+ task_instances=task_instances,
+ update_extras=False,
+ )
+ self.dagbag.sync_to_db()
+ response = test_client.post(
+ f"/public/dags/{request_dag}/clearTaskInstances",
+ json=payload,
+ )
+ assert response.status_code == 200
+ assert len(response.json()["task_instances"]) == expected_ti
+
+ def test_clear_taskinstance_is_called_with_queued_dr_state(self,
test_client, session):
+ """Test that if reset_dag_runs is True, then clear_task_instances is
called with State.QUEUED"""
+ self.create_task_instances(session)
+ dag_id = "example_python_operator"
+ payload = {"reset_dag_runs": True, "dry_run": False}
+ self.dagbag.sync_to_db()
+ with mock.patch(
+
"airflow.api_fastapi.core_api.routes.public.task_instances.clear_task_instances",
+ ) as mp:
Review Comment:
I think the decorator-style `mock.patch` used in the legacy API tests is more readable than the context-manager form here.
##########
tests/api_fastapi/core_api/routes/public/test_task_instances.py:
##########
@@ -1663,3 +1668,649 @@ def test_raises_404_for_nonexistent_task_instance(self,
test_client, session):
assert response.json() == {
"detail": "The Task Instance with dag_id:
`example_python_operator`, run_id: `TEST_DAG_RUN_ID`, task_id:
`nonexistent_task`, try_number: `0` and map_index: `-1` was not found"
}
+
+
+class TestPostClearTaskInstances(TestTaskInstanceEndpoint):
+ @pytest.mark.parametrize(
+ "main_dag, task_instances, request_dag, payload, expected_ti",
+ [
+ pytest.param(
+ "example_python_operator",
+ [
+ {"logical_date": DEFAULT_DATETIME_1, "state":
State.FAILED},
+ {
+ "logical_date": DEFAULT_DATETIME_1 +
dt.timedelta(days=1),
+ "state": State.FAILED,
+ },
+ {
+ "logical_date": DEFAULT_DATETIME_1 +
dt.timedelta(days=2),
+ "state": State.FAILED,
+ },
+ ],
+ "example_python_operator",
+ {
+ "dry_run": True,
+ "start_date": DEFAULT_DATETIME_STR_2,
+ "only_failed": True,
+ },
+ 2,
+ id="clear start date filter",
+ ),
+ pytest.param(
+ "example_python_operator",
+ [
+ {"logical_date": DEFAULT_DATETIME_1, "state":
State.FAILED},
+ {
+ "logical_date": DEFAULT_DATETIME_1 +
dt.timedelta(days=1),
+ "state": State.FAILED,
+ },
+ {
+ "logical_date": DEFAULT_DATETIME_1 +
dt.timedelta(days=2),
+ "state": State.FAILED,
+ },
+ ],
+ "example_python_operator",
+ {
+ "dry_run": True,
+ "end_date": DEFAULT_DATETIME_STR_2,
+ "only_failed": True,
+ },
+ 2,
+ id="clear end date filter",
+ ),
+ pytest.param(
+ "example_python_operator",
+ [
+ {"logical_date": DEFAULT_DATETIME_1, "state":
State.RUNNING},
+ {
+ "logical_date": DEFAULT_DATETIME_1 +
dt.timedelta(days=1),
+ "state": State.RUNNING,
+ },
+ {
+ "logical_date": DEFAULT_DATETIME_1 +
dt.timedelta(days=2),
+ "state": State.FAILED,
+ },
+ ],
+ "example_python_operator",
+ {"dry_run": True, "only_running": True, "only_failed": False},
+ 2,
+ id="clear only running",
+ ),
+ pytest.param(
+ "example_python_operator",
+ [
+ {"logical_date": DEFAULT_DATETIME_1, "state":
State.FAILED},
+ {
+ "logical_date": DEFAULT_DATETIME_1 +
dt.timedelta(days=1),
+ "state": State.FAILED,
+ },
+ {
+ "logical_date": DEFAULT_DATETIME_1 +
dt.timedelta(days=2),
+ "state": State.RUNNING,
+ },
+ ],
+ "example_python_operator",
+ {
+ "dry_run": True,
+ "only_failed": True,
+ },
+ 2,
+ id="clear only failed",
+ ),
+ pytest.param(
+ "example_python_operator",
+ [
+ {"logical_date": DEFAULT_DATETIME_1, "state":
State.FAILED},
+ {
+ "logical_date": DEFAULT_DATETIME_1 +
dt.timedelta(days=1),
+ "state": State.FAILED,
+ },
+ {
+ "logical_date": DEFAULT_DATETIME_1 +
dt.timedelta(days=2),
+ "state": State.FAILED,
+ },
+ {
+ "logical_date": DEFAULT_DATETIME_1 +
dt.timedelta(days=3),
+ "state": State.FAILED,
+ },
+ ],
+ "example_python_operator",
+ {
+ "dry_run": True,
+ "task_ids": ["print_the_context", "sleep_for_1"],
+ },
+ 2,
+ id="clear by task ids",
+ ),
+ pytest.param(
+ "example_python_operator",
+ [
+ {"logical_date": DEFAULT_DATETIME_1, "state":
State.FAILED},
+ {
+ "logical_date": DEFAULT_DATETIME_1 +
dt.timedelta(days=1),
+ "state": State.FAILED,
+ },
+ {
+ "logical_date": DEFAULT_DATETIME_1 +
dt.timedelta(days=2),
+ "state": State.RUNNING,
+ },
+ ],
+ "example_python_operator",
+ {
+ "only_failed": True,
+ },
+ 2,
+ id="dry_run default",
+ ),
+ ],
+ )
+ def test_should_respond_200(
+ self,
+ test_client,
+ session,
+ main_dag,
+ task_instances,
+ request_dag,
+ payload,
+ expected_ti,
+ ):
+ self.create_task_instances(
+ session,
+ dag_id=main_dag,
+ task_instances=task_instances,
+ update_extras=False,
+ )
+ self.dagbag.sync_to_db()
+ response = test_client.post(
+ f"/public/dags/{request_dag}/clearTaskInstances",
+ json=payload,
+ )
+ assert response.status_code == 200
+ assert len(response.json()["task_instances"]) == expected_ti
+
+ def test_clear_taskinstance_is_called_with_queued_dr_state(self,
test_client, session):
+ """Test that if reset_dag_runs is True, then clear_task_instances is
called with State.QUEUED"""
+ self.create_task_instances(session)
+ dag_id = "example_python_operator"
+ payload = {"reset_dag_runs": True, "dry_run": False}
+ self.dagbag.sync_to_db()
+ with mock.patch(
+
"airflow.api_fastapi.core_api.routes.public.task_instances.clear_task_instances",
+ ) as mp:
+ response = test_client.post(
+ f"/public/dags/{dag_id}/clearTaskInstances",
+ json=payload,
+ )
+ assert response.status_code == 200
+ mp.assert_called_once()
Review Comment:
Can we also assert that the parameters of the call are the expected ones?
##########
tests/api_fastapi/conftest.py:
##########
@@ -36,3 +36,12 @@ def create_test_client(apps="all"):
return TestClient(app)
return create_test_client
+
+
[email protected](scope="module")
+def dagbag():
+ from airflow.models import DagBag
+
+ dagbag_instance = DagBag(include_examples=True, read_dags_from_db=False)
+ dagbag_instance.sync_to_db()
+ return dagbag_instance
Review Comment:
Since you are putting the fixture at the `conftest.py` level, do you mind also using it
in the other tests and replacing the duplicated fixtures?
Maybe `test_dag` in `test_dag_sources.py` can reuse this fixture:
```
@pytest.fixture
def test_dag(dagbag):
return dagbag.dags[TEST_DAG_ID]
```
##########
tests/api_fastapi/core_api/routes/public/test_task_instances.py:
##########
@@ -973,8 +978,8 @@ def test_return_TI_only_from_readable_dags(self,
test_client, session):
)
response = test_client.get("/public/dags/~/dagRuns/~/taskInstances")
assert response.status_code == 200
- assert response.json["total_entries"] == 3
- assert len(response.json["task_instances"]) == 3
+ assert response.json()["total_entries"] == 3
+ assert len(response.json()["task_instances"]) == 3
Review Comment:
I was wondering how this could pass. Just realized it's xfailed, thanks.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]