pierrejeambrun commented on code in PR #44220:
URL: https://github.com/apache/airflow/pull/44220#discussion_r1851661003


##########
airflow/api_fastapi/core_api/routes/public/task_instances.py:
##########
@@ -265,7 +268,7 @@ def get_mapped_task_instance(
 
 
 @task_instances_router.get(
-    "",
+    task_instances_prefix + "",

Review Comment:
   ```suggestion
       task_instances_prefix,
   ```



##########
airflow/api_fastapi/common/types.py:
##########
@@ -59,6 +66,20 @@ class TimeDelta(BaseModel):
 TimeDeltaWithValidation = Annotated[TimeDelta, 
BeforeValidator(_validate_timedelta_field)]
 
 
+def _validate_nonnaive_datetime_field(dt: datetime | None) -> datetime | None:
+    """Validate and return the datetime field."""
+    if dt is None:
+        return None
+    if isinstance(dt, str):
+        dt = datetime.fromisoformat(dt)
+    if not dt.tzinfo:
+        raise ValueError("Invalid datetime format, Naive datetime is 
disallowed")
+    return dt
+
+
+DatetimeWithNonNaiveValidation = Annotated[datetime, 
BeforeValidator(_validate_nonnaive_datetime_field)]
+
+

Review Comment:
   Pydantic has a native `AwareDatetime` type (we already use it in a few places) that I believe does exactly this.
   
   At some point we might want to switch all datetime fields to `AwareDatetime`.



##########
airflow/api_fastapi/core_api/datamodels/task_instances.py:
##########
@@ -150,3 +154,54 @@ class TaskInstanceHistoryCollectionResponse(BaseModel):
 
     task_instances: list[TaskInstanceHistoryResponse]
     total_entries: int
+
+
+class ClearTaskInstancesBody(BaseModel):
+    """Request body for Clear Task Instances endpoint."""
+
+    dry_run: bool = True
+    start_date: DatetimeWithNonNaiveValidation | None = None
+    end_date: DatetimeWithNonNaiveValidation | None = None
+    only_failed: bool = True
+    only_running: bool = False
+    reset_dag_runs: bool = False
+    task_ids: list[str] | None = None
+    dag_run_id: str | None = None
+    include_upstream: bool = False
+    include_downstream: bool = False
+    include_future: bool = False
+    include_past: bool = False
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_model(cls, data: Any) -> Any:
+        """Validate clear task instance form."""
+        if data.get("only_failed") and data.get("only_running"):
+            raise ValidationError("only_failed and only_running both are set 
to True")
+        if data.get("start_date") and data.get("end_date"):
+            if data.get("start_date") > data.get("end_date"):
+                raise ValidationError("end_date is sooner than start_date")
+        if data.get("start_date") and data.get("end_date") and 
data.get("dag_run_id"):
+            raise ValidationError("Exactly one of dag_run_id or (start_date 
and end_date) must be provided")
+        if data.get("start_date") and data.get("dag_run_id"):
+            raise ValidationError("Exactly one of dag_run_id or start_date 
must be provided")
+        if data.get("end_date") and data.get("dag_run_id"):
+            raise ValidationError("Exactly one of dag_run_id or end_date must 
be provided")
+        if isinstance(data.get("task_ids"), list) and 
len(data.get("task_ids")) < 1:
+            raise ValidationError("task_ids list should have at least 1 
element.")
+        return data
+
+
+class TaskInstanceReferenceResponse(BaseModel):
+    """Task Instance Reference serializer for responses."""
+
+    task_id: str
+    dag_run_id: str = Field(validation_alias=AliasChoices("run_id"))

Review Comment:
   An `AliasChoices` with only one choice should be equivalent to this:
   ```suggestion
       dag_run_id: str = Field(validation_alias="run_id")
   ```



##########
airflow/api_fastapi/core_api/routes/public/task_instances.py:
##########
@@ -482,3 +485,88 @@ def get_mapped_task_instance_try_details(
         map_index=map_index,
         session=session,
     )
+
+
+@task_instances_router.post(
+    "/clearTaskInstances",
+    responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
+)
+def post_clear_task_instances(
+    dag_id: str,
+    request: Request,
+    body: ClearTaskInstancesBody,
+    session: Annotated[Session, Depends(get_session)],
+) -> TaskInstanceReferenceCollectionResponse:
+    """Clear task instances."""
+    dag = request.app.state.dag_bag.get_dag(dag_id)
+    if not dag:
+        error_message = f"DAG {dag_id} not found"
+        raise HTTPException(status.HTTP_404_NOT_FOUND, error_message)
+
+    reset_dag_runs = body.reset_dag_runs
+    dry_run = body.dry_run
+    # We always pass dry_run here, otherwise this would try to confirm on the 
terminal!
+    dag_run_id = body.dag_run_id
+    future = body.include_future
+    past = body.include_past
+    downstream = body.include_downstream
+    upstream = body.include_upstream
+
+    if dag_run_id is not None:
+        dag_run: DR | None = session.scalar(select(DR).where(DR.dag_id == 
dag_id, DR.run_id == dag_run_id))
+        if dag_run is None:
+            error_message = f"Dag Run id {dag_run_id} not found in dag 
{dag_id}"
+            raise HTTPException(status.HTTP_404_NOT_FOUND, error_message)
+        body.start_date = dag_run.logical_date
+        body.end_date = dag_run.logical_date
+
+    if past:
+        body.start_date = None
+
+    if future:
+        body.end_date = None
+
+    task_ids = body.task_ids
+    if task_ids is not None:
+        task_id = [task[0] if isinstance(task, tuple) else task for task in 
task_ids]
+        dag = dag.partial_subset(
+            task_ids_or_regex=task_id,
+            include_downstream=downstream,
+            include_upstream=upstream,
+        )
+
+        if len(dag.task_dict) > 1:
+            # If we had upstream/downstream etc then also include those!
+            task_ids.extend(tid for tid in dag.task_dict if tid != task_id)
+
+    task_instances = dag.clear(
+        dry_run=True,
+        task_ids=body.task_ids,

Review Comment:
   We need to use the extended `task_ids` (including upstream/downstream tasks) too, not only the `task_ids` from the payload.
   ```suggestion
           task_ids=task_ids,
   ```



##########
airflow/api_fastapi/core_api/routes/public/task_instances.py:
##########
@@ -482,3 +485,88 @@ def get_mapped_task_instance_try_details(
         map_index=map_index,
         session=session,
     )
+
+
+@task_instances_router.post(
+    "/clearTaskInstances",
+    responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
+)
+def post_clear_task_instances(
+    dag_id: str,
+    request: Request,
+    body: ClearTaskInstancesBody,
+    session: Annotated[Session, Depends(get_session)],
+) -> TaskInstanceReferenceCollectionResponse:
+    """Clear task instances."""
+    dag = request.app.state.dag_bag.get_dag(dag_id)
+    if not dag:
+        error_message = f"DAG {dag_id} not found"
+        raise HTTPException(status.HTTP_404_NOT_FOUND, error_message)
+
+    reset_dag_runs = body.reset_dag_runs
+    dry_run = body.dry_run
+    # We always pass dry_run here, otherwise this would try to confirm on the 
terminal!
+    dag_run_id = body.dag_run_id
+    future = body.include_future
+    past = body.include_past
+    downstream = body.include_downstream
+    upstream = body.include_upstream
+
+    if dag_run_id is not None:
+        dag_run: DR | None = session.scalar(select(DR).where(DR.dag_id == 
dag_id, DR.run_id == dag_run_id))
+        if dag_run is None:
+            error_message = f"Dag Run id {dag_run_id} not found in dag 
{dag_id}"
+            raise HTTPException(status.HTTP_404_NOT_FOUND, error_message)
+        body.start_date = dag_run.logical_date
+        body.end_date = dag_run.logical_date
+
+    if past:
+        body.start_date = None
+
+    if future:
+        body.end_date = None
+
+    task_ids = body.task_ids
+    if task_ids is not None:
+        task_id = [task[0] if isinstance(task, tuple) else task for task in 
task_ids]
+        dag = dag.partial_subset(
+            task_ids_or_regex=task_id,
+            include_downstream=downstream,
+            include_upstream=upstream,
+        )
+
+        if len(dag.task_dict) > 1:
+            # If we had upstream/downstream etc then also include those!
+            task_ids.extend(tid for tid in dag.task_dict if tid != task_id)
+
+    task_instances = dag.clear(
+        dry_run=True,
+        task_ids=body.task_ids,
+        dag_bag=request.app.state.dag_bag,
+        **body.model_dump(
+            include=[  # type: ignore[arg-type]
+                "start_date",
+                "end_date",
+                "only_failed",
+                "only_running",
+            ]

Review Comment:
   I think passing a set instead of a list will remove the typing error (and the need for the `# type: ignore`).
   ```suggestion
               include={
                   "start_date",
                   "end_date",
                   "only_failed",
                   "only_running",
               }
   ```



##########
tests/api_fastapi/core_api/routes/public/test_task_instances.py:
##########
@@ -1663,3 +1668,649 @@ def test_raises_404_for_nonexistent_task_instance(self, 
test_client, session):
         assert response.json() == {
             "detail": "The Task Instance with dag_id: 
`example_python_operator`, run_id: `TEST_DAG_RUN_ID`, task_id: 
`nonexistent_task`, try_number: `0` and map_index: `-1` was not found"
         }
+
+
+class TestPostClearTaskInstances(TestTaskInstanceEndpoint):
+    @pytest.mark.parametrize(
+        "main_dag, task_instances, request_dag, payload, expected_ti",
+        [
+            pytest.param(
+                "example_python_operator",
+                [
+                    {"logical_date": DEFAULT_DATETIME_1, "state": 
State.FAILED},
+                    {
+                        "logical_date": DEFAULT_DATETIME_1 + 
dt.timedelta(days=1),
+                        "state": State.FAILED,
+                    },
+                    {
+                        "logical_date": DEFAULT_DATETIME_1 + 
dt.timedelta(days=2),
+                        "state": State.FAILED,
+                    },
+                ],
+                "example_python_operator",
+                {
+                    "dry_run": True,
+                    "start_date": DEFAULT_DATETIME_STR_2,
+                    "only_failed": True,
+                },
+                2,
+                id="clear start date filter",
+            ),
+            pytest.param(
+                "example_python_operator",
+                [
+                    {"logical_date": DEFAULT_DATETIME_1, "state": 
State.FAILED},
+                    {
+                        "logical_date": DEFAULT_DATETIME_1 + 
dt.timedelta(days=1),
+                        "state": State.FAILED,
+                    },
+                    {
+                        "logical_date": DEFAULT_DATETIME_1 + 
dt.timedelta(days=2),
+                        "state": State.FAILED,
+                    },
+                ],
+                "example_python_operator",
+                {
+                    "dry_run": True,
+                    "end_date": DEFAULT_DATETIME_STR_2,
+                    "only_failed": True,
+                },
+                2,
+                id="clear end date filter",
+            ),
+            pytest.param(
+                "example_python_operator",
+                [
+                    {"logical_date": DEFAULT_DATETIME_1, "state": 
State.RUNNING},
+                    {
+                        "logical_date": DEFAULT_DATETIME_1 + 
dt.timedelta(days=1),
+                        "state": State.RUNNING,
+                    },
+                    {
+                        "logical_date": DEFAULT_DATETIME_1 + 
dt.timedelta(days=2),
+                        "state": State.FAILED,
+                    },
+                ],
+                "example_python_operator",
+                {"dry_run": True, "only_running": True, "only_failed": False},
+                2,
+                id="clear only running",
+            ),
+            pytest.param(
+                "example_python_operator",
+                [
+                    {"logical_date": DEFAULT_DATETIME_1, "state": 
State.FAILED},
+                    {
+                        "logical_date": DEFAULT_DATETIME_1 + 
dt.timedelta(days=1),
+                        "state": State.FAILED,
+                    },
+                    {
+                        "logical_date": DEFAULT_DATETIME_1 + 
dt.timedelta(days=2),
+                        "state": State.RUNNING,
+                    },
+                ],
+                "example_python_operator",
+                {
+                    "dry_run": True,
+                    "only_failed": True,
+                },
+                2,
+                id="clear only failed",
+            ),
+            pytest.param(
+                "example_python_operator",
+                [
+                    {"logical_date": DEFAULT_DATETIME_1, "state": 
State.FAILED},
+                    {
+                        "logical_date": DEFAULT_DATETIME_1 + 
dt.timedelta(days=1),
+                        "state": State.FAILED,
+                    },
+                    {
+                        "logical_date": DEFAULT_DATETIME_1 + 
dt.timedelta(days=2),
+                        "state": State.FAILED,
+                    },
+                    {
+                        "logical_date": DEFAULT_DATETIME_1 + 
dt.timedelta(days=3),
+                        "state": State.FAILED,
+                    },
+                ],
+                "example_python_operator",
+                {
+                    "dry_run": True,
+                    "task_ids": ["print_the_context", "sleep_for_1"],
+                },
+                2,
+                id="clear by task ids",
+            ),
+            pytest.param(
+                "example_python_operator",
+                [
+                    {"logical_date": DEFAULT_DATETIME_1, "state": 
State.FAILED},
+                    {
+                        "logical_date": DEFAULT_DATETIME_1 + 
dt.timedelta(days=1),
+                        "state": State.FAILED,
+                    },
+                    {
+                        "logical_date": DEFAULT_DATETIME_1 + 
dt.timedelta(days=2),
+                        "state": State.RUNNING,
+                    },
+                ],
+                "example_python_operator",
+                {
+                    "only_failed": True,
+                },
+                2,
+                id="dry_run default",
+            ),
+        ],
+    )
+    def test_should_respond_200(
+        self,
+        test_client,
+        session,
+        main_dag,
+        task_instances,
+        request_dag,
+        payload,
+        expected_ti,
+    ):
+        self.create_task_instances(
+            session,
+            dag_id=main_dag,
+            task_instances=task_instances,
+            update_extras=False,
+        )
+        self.dagbag.sync_to_db()
+        response = test_client.post(
+            f"/public/dags/{request_dag}/clearTaskInstances",
+            json=payload,
+        )
+        assert response.status_code == 200
+        assert len(response.json()["task_instances"]) == expected_ti
+
+    def test_clear_taskinstance_is_called_with_queued_dr_state(self, 
test_client, session):
+        """Test that if reset_dag_runs is True, then clear_task_instances is 
called with State.QUEUED"""
+        self.create_task_instances(session)
+        dag_id = "example_python_operator"
+        payload = {"reset_dag_runs": True, "dry_run": False}
+        self.dagbag.sync_to_db()
+        with mock.patch(
+            
"airflow.api_fastapi.core_api.routes.public.task_instances.clear_task_instances",
+        ) as mp:

Review Comment:
   I think the decorator-style mock (`@mock.patch`) used in the legacy tests is more readable.



##########
tests/api_fastapi/core_api/routes/public/test_task_instances.py:
##########
@@ -1663,3 +1668,649 @@ def test_raises_404_for_nonexistent_task_instance(self, 
test_client, session):
         assert response.json() == {
             "detail": "The Task Instance with dag_id: 
`example_python_operator`, run_id: `TEST_DAG_RUN_ID`, task_id: 
`nonexistent_task`, try_number: `0` and map_index: `-1` was not found"
         }
+
+
+class TestPostClearTaskInstances(TestTaskInstanceEndpoint):
+    @pytest.mark.parametrize(
+        "main_dag, task_instances, request_dag, payload, expected_ti",
+        [
+            pytest.param(
+                "example_python_operator",
+                [
+                    {"logical_date": DEFAULT_DATETIME_1, "state": 
State.FAILED},
+                    {
+                        "logical_date": DEFAULT_DATETIME_1 + 
dt.timedelta(days=1),
+                        "state": State.FAILED,
+                    },
+                    {
+                        "logical_date": DEFAULT_DATETIME_1 + 
dt.timedelta(days=2),
+                        "state": State.FAILED,
+                    },
+                ],
+                "example_python_operator",
+                {
+                    "dry_run": True,
+                    "start_date": DEFAULT_DATETIME_STR_2,
+                    "only_failed": True,
+                },
+                2,
+                id="clear start date filter",
+            ),
+            pytest.param(
+                "example_python_operator",
+                [
+                    {"logical_date": DEFAULT_DATETIME_1, "state": 
State.FAILED},
+                    {
+                        "logical_date": DEFAULT_DATETIME_1 + 
dt.timedelta(days=1),
+                        "state": State.FAILED,
+                    },
+                    {
+                        "logical_date": DEFAULT_DATETIME_1 + 
dt.timedelta(days=2),
+                        "state": State.FAILED,
+                    },
+                ],
+                "example_python_operator",
+                {
+                    "dry_run": True,
+                    "end_date": DEFAULT_DATETIME_STR_2,
+                    "only_failed": True,
+                },
+                2,
+                id="clear end date filter",
+            ),
+            pytest.param(
+                "example_python_operator",
+                [
+                    {"logical_date": DEFAULT_DATETIME_1, "state": 
State.RUNNING},
+                    {
+                        "logical_date": DEFAULT_DATETIME_1 + 
dt.timedelta(days=1),
+                        "state": State.RUNNING,
+                    },
+                    {
+                        "logical_date": DEFAULT_DATETIME_1 + 
dt.timedelta(days=2),
+                        "state": State.FAILED,
+                    },
+                ],
+                "example_python_operator",
+                {"dry_run": True, "only_running": True, "only_failed": False},
+                2,
+                id="clear only running",
+            ),
+            pytest.param(
+                "example_python_operator",
+                [
+                    {"logical_date": DEFAULT_DATETIME_1, "state": 
State.FAILED},
+                    {
+                        "logical_date": DEFAULT_DATETIME_1 + 
dt.timedelta(days=1),
+                        "state": State.FAILED,
+                    },
+                    {
+                        "logical_date": DEFAULT_DATETIME_1 + 
dt.timedelta(days=2),
+                        "state": State.RUNNING,
+                    },
+                ],
+                "example_python_operator",
+                {
+                    "dry_run": True,
+                    "only_failed": True,
+                },
+                2,
+                id="clear only failed",
+            ),
+            pytest.param(
+                "example_python_operator",
+                [
+                    {"logical_date": DEFAULT_DATETIME_1, "state": 
State.FAILED},
+                    {
+                        "logical_date": DEFAULT_DATETIME_1 + 
dt.timedelta(days=1),
+                        "state": State.FAILED,
+                    },
+                    {
+                        "logical_date": DEFAULT_DATETIME_1 + 
dt.timedelta(days=2),
+                        "state": State.FAILED,
+                    },
+                    {
+                        "logical_date": DEFAULT_DATETIME_1 + 
dt.timedelta(days=3),
+                        "state": State.FAILED,
+                    },
+                ],
+                "example_python_operator",
+                {
+                    "dry_run": True,
+                    "task_ids": ["print_the_context", "sleep_for_1"],
+                },
+                2,
+                id="clear by task ids",
+            ),
+            pytest.param(
+                "example_python_operator",
+                [
+                    {"logical_date": DEFAULT_DATETIME_1, "state": 
State.FAILED},
+                    {
+                        "logical_date": DEFAULT_DATETIME_1 + 
dt.timedelta(days=1),
+                        "state": State.FAILED,
+                    },
+                    {
+                        "logical_date": DEFAULT_DATETIME_1 + 
dt.timedelta(days=2),
+                        "state": State.RUNNING,
+                    },
+                ],
+                "example_python_operator",
+                {
+                    "only_failed": True,
+                },
+                2,
+                id="dry_run default",
+            ),
+        ],
+    )
+    def test_should_respond_200(
+        self,
+        test_client,
+        session,
+        main_dag,
+        task_instances,
+        request_dag,
+        payload,
+        expected_ti,
+    ):
+        self.create_task_instances(
+            session,
+            dag_id=main_dag,
+            task_instances=task_instances,
+            update_extras=False,
+        )
+        self.dagbag.sync_to_db()
+        response = test_client.post(
+            f"/public/dags/{request_dag}/clearTaskInstances",
+            json=payload,
+        )
+        assert response.status_code == 200
+        assert len(response.json()["task_instances"]) == expected_ti
+
+    def test_clear_taskinstance_is_called_with_queued_dr_state(self, 
test_client, session):
+        """Test that if reset_dag_runs is True, then clear_task_instances is 
called with State.QUEUED"""
+        self.create_task_instances(session)
+        dag_id = "example_python_operator"
+        payload = {"reset_dag_runs": True, "dry_run": False}
+        self.dagbag.sync_to_db()
+        with mock.patch(
+            
"airflow.api_fastapi.core_api.routes.public.task_instances.clear_task_instances",
+        ) as mp:
+            response = test_client.post(
+                f"/public/dags/{dag_id}/clearTaskInstances",
+                json=payload,
+            )
+            assert response.status_code == 200
+            mp.assert_called_once()

Review Comment:
   Can we also assert that the call was made with the expected parameters?



##########
tests/api_fastapi/conftest.py:
##########
@@ -36,3 +36,12 @@ def create_test_client(apps="all"):
         return TestClient(app)
 
     return create_test_client
+
+
+@pytest.fixture(scope="module")
+def dagbag():
+    from airflow.models import DagBag
+
+    dagbag_instance = DagBag(include_examples=True, read_dags_from_db=False)
+    dagbag_instance.sync_to_db()
+    return dagbag_instance

Review Comment:
   If you put the fixture at the `conftest.py` level, do you mind also using it in the other tests, replacing their local equivalents?
   
   Maybe `test_dag` in `test_dag_sources.py` can reuse this fixture:
   ```
   @pytest.fixture
   def test_dag(dagbag):
       return dagbag.dags[TEST_DAG_ID]
   ```



##########
tests/api_fastapi/core_api/routes/public/test_task_instances.py:
##########
@@ -973,8 +978,8 @@ def test_return_TI_only_from_readable_dags(self, 
test_client, session):
             )
         response = test_client.get("/public/dags/~/dagRuns/~/taskInstances")
         assert response.status_code == 200
-        assert response.json["total_entries"] == 3
-        assert len(response.json["task_instances"]) == 3
+        assert response.json()["total_entries"] == 3
+        assert len(response.json()["task_instances"]) == 3

Review Comment:
   I was wondering how this could pass. I just realized it's marked as xfail, thanks.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to