jason810496 commented on code in PR #67319: URL: https://github.com/apache/airflow/pull/67319#discussion_r3292096509
########## airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_state.py: ########## @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +from datetime import datetime, timedelta, timezone Review Comment: Should we import from shared? For example: https://github.com/apache/airflow/blob/d0f981c2ffb7cb0177096582ac4625f84fa1e6b9/airflow-core/src/airflow/api_fastapi/common/parameters.py#L43-L44 ########## airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_state.py: ########## @@ -40,6 +41,21 @@ class TaskStateCollectionResponse(BaseModel): class TaskStateBody(StrictBaseModel): - """Request body for setting a task state value.""" + """ + Request body for setting a task state value. + + ``expires_at`` controls expiry: + + - ``"default"``: apply the configured ``[state_store] default_retention_days``. + - ``null``: never expire. + - aware datetime: expire at that time. + """ + + value: str = Field(max_length=65535) Review Comment: May I ask where the `65535` threshold comes from? ########## airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_state.py: ########## @@ -147,24 +180,51 @@ def set_task_state( map_index: Annotated[int, Query(ge=-1)] = -1, ) -> None: """Set a task state value. 
Creates or overwrites the key.""" - ti_exists = session.scalar( - select(TI.task_id).where( - TI.dag_id == dag_id, - TI.run_id == dag_run_id, - TI.task_id == task_id, - TI.map_index == map_index, + _require_ti(dag_id, dag_run_id, task_id, map_index, session) + expires_at = _resolve_expires_at(body.expires_at) + scope = _get_scope(dag_id, dag_run_id, task_id, map_index) + try: + get_state_backend().set(scope, key, body.value, expires_at=expires_at, session=session) + except ValueError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e + + +@task_state_router.patch( + "/{key:path}", + status_code=status.HTTP_204_NO_CONTENT, + responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), + dependencies=[Depends(requires_access_dag(method="PUT", access_entity=DagAccessEntity.TASK_INSTANCE))], +) +def patch_task_state( + dag_id: str, + dag_run_id: str, + task_id: str, + key: str, + body: TaskStatePatchBody, + session: SessionDep, + map_index: Annotated[int, Query(ge=-1)] = -1, +) -> None: + """Update the value of an existing task state key.""" + _require_ti(dag_id, dag_run_id, task_id, map_index, session) + + existing = session.execute( + select(TaskStateModel.expires_at).where( + TaskStateModel.dag_id == dag_id, + TaskStateModel.run_id == dag_run_id, + TaskStateModel.task_id == task_id, + TaskStateModel.map_index == map_index, + TaskStateModel.key == key, ) - ) - if ti_exists is None: + ).one_or_none() + + if existing is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Task instance not found for dag_id={dag_id!r}, run_id={dag_run_id!r}, task_id={task_id!r}, map_index={map_index}", + detail=f"Task state key {key!r} not found", ) + scope = _get_scope(dag_id, dag_run_id, task_id, map_index) - try: - get_state_backend().set(scope, key, body.value, session=session) - except ValueError as e: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e + 
get_state_backend().set(scope, key, body.value, expires_at=existing.expires_at, session=session) Review Comment: Though it _should_ be safe to set the `expires_at` from `existing` in most cases, would it be better to make the upsert statement in `MetastoreStateBackend` only set the `value` and keep the `expires_at` as-is in the DB? There might still be cases where concurrent requests cause a phantom value for the `expires_at` field. ########## airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_state.py: ########## @@ -47,6 +50,36 @@ def _get_scope(dag_id: str, dag_run_id: str, task_id: str, map_index: int) -> Ta return TaskScope(dag_id=dag_id, run_id=dag_run_id, task_id=task_id, map_index=map_index) +def _resolve_expires_at(expires_at: datetime | None | str) -> datetime | None: Review Comment: Would it be better to introduce a computed property at the Pydantic level? ########## airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_state.py: ########## @@ -147,24 +180,51 @@ def set_task_state( map_index: Annotated[int, Query(ge=-1)] = -1, ) -> None: """Set a task state value. Creates or overwrites the key.""" - ti_exists = session.scalar( - select(TI.task_id).where( - TI.dag_id == dag_id, - TI.run_id == dag_run_id, - TI.task_id == task_id, - TI.map_index == map_index, + _require_ti(dag_id, dag_run_id, task_id, map_index, session) + expires_at = _resolve_expires_at(body.expires_at) + scope = _get_scope(dag_id, dag_run_id, task_id, map_index) + try: + get_state_backend().set(scope, key, body.value, expires_at=expires_at, session=session) + except ValueError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e + + +@task_state_router.patch( + "/{key:path}", + status_code=status.HTTP_204_NO_CONTENT, Review Comment: IIRC, the `patch` method should return a 200 status code instead of 204. ```suggestion ``` -- This is an automated message from the Apache Git Service. 
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
