This is an automated email from the ASF dual-hosted git repository.

amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 691b672d2d6 AIP-103: Add patch task state core API and support for expires_at in set API (#67319)
691b672d2d6 is described below

commit 691b672d2d6a710ef67e77faf3f8482f68e28357
Author: Amogh Desai <[email protected]>
AuthorDate: Mon Jun 1 12:38:50 2026 +0530

    AIP-103: Add patch task state core API and support for expires_at in set API (#67319)
---
 .../api_fastapi/core_api/datamodels/task_state.py  |  33 ++++++-
 .../core_api/openapi/v2-rest-api-generated.yaml    | 110 ++++++++++++++++++++-
 .../core_api/routes/public/task_state.py           |  90 ++++++++++++++---
 .../src/airflow/ui/openapi-gen/queries/common.ts   |   1 +
 .../src/airflow/ui/openapi-gen/queries/queries.ts  |  30 +++++-
 .../airflow/ui/openapi-gen/requests/schemas.gen.ts |  38 ++++++-
 .../ui/openapi-gen/requests/services.gen.ts        |  39 +++++++-
 .../airflow/ui/openapi-gen/requests/types.gen.ts   |  50 ++++++++++
 .../core_api/routes/public/test_task_state.py      |  97 +++++++++++++++++-
 .../src/airflowctl/api/datamodels/generated.py     |  18 ++++
 10 files changed, 485 insertions(+), 21 deletions(-)

diff --git 
a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_state.py 
b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_state.py
index e6622f842e1..16d21435230 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_state.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_state.py
@@ -18,8 +18,9 @@ from __future__ import annotations
 
 import json
 from datetime import datetime
+from typing import Literal
 
-from pydantic import JsonValue, field_validator
+from pydantic import AwareDatetime, JsonValue, field_validator
 
 from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel
 
@@ -43,7 +44,35 @@ class TaskStateCollectionResponse(BaseModel):
 
 
 class TaskStateBody(StrictBaseModel):
-    """Request body for setting a task state value."""
+    """
+    Request body for setting a task state value.
+
+    ``expires_at`` controls expiry:
+
+    - ``"default"``: apply the configured ``[state_store] 
default_retention_days``.
+    - ``null``: never expire.
+    - aware datetime: expire at that time.
+    """
+
+    value: JsonValue
+    expires_at: AwareDatetime | None | Literal["default"] = "default"
+
+    @field_validator("value")
+    @classmethod
+    def value_is_json_representable(cls, v: JsonValue) -> JsonValue:
+        if v is None:
+            raise ValueError("value cannot be null")
+        try:
+            serialized = json.dumps(v, allow_nan=False)
+        except ValueError:
+            raise ValueError("value contains non-finite numbers; NaN and Inf 
are not JSON representable")
+        if len(serialized) > _MAX_SERIALIZED_BYTES:
+            raise ValueError(f"value exceeds maximum serialized size of 
{_MAX_SERIALIZED_BYTES} bytes")
+        return v
+
+
+class TaskStatePatchBody(StrictBaseModel):
+    """Request body for patching only the value of an existing task state 
key."""
 
     value: JsonValue
 
diff --git 
a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml
 
b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml
index 86017f12a60..8c1097adbb1 100644
--- 
a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml
+++ 
b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml
@@ -6144,6 +6144,84 @@ paths:
             application/json:
               schema:
                 $ref: '#/components/schemas/HTTPValidationError'
+    patch:
+      tags:
+      - Task State
+      summary: Patch Task State
+      description: Update the value of an existing task state key.
+      operationId: patch_task_state
+      security:
+      - OAuth2PasswordBearer: []
+      - HTTPBearer: []
+      parameters:
+      - name: dag_id
+        in: path
+        required: true
+        schema:
+          type: string
+          title: Dag Id
+      - name: dag_run_id
+        in: path
+        required: true
+        schema:
+          type: string
+          title: Dag Run Id
+      - name: task_id
+        in: path
+        required: true
+        schema:
+          type: string
+          title: Task Id
+      - name: key
+        in: path
+        required: true
+        schema:
+          type: string
+          title: Key
+      - name: map_index
+        in: query
+        required: false
+        schema:
+          type: integer
+          minimum: -1
+          default: -1
+          title: Map Index
+      requestBody:
+        required: true
+        content:
+          application/json:
+            schema:
+              $ref: '#/components/schemas/TaskStatePatchBody'
+      responses:
+        '200':
+          description: Successful Response
+          content:
+            application/json:
+              schema: {}
+        '401':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Unauthorized
+        '403':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Forbidden
+        '404':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Not Found
+        '422':
+          description: Validation Error
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPValidationError'
     delete:
       tags:
       - Task State
@@ -15814,12 +15892,31 @@ components:
       properties:
         value:
           $ref: '#/components/schemas/JsonValue'
+        expires_at:
+          anyOf:
+          - type: string
+            format: date-time
+          - type: string
+            const: default
+          - type: 'null'
+          title: Expires At
+          default: default
       additionalProperties: false
       type: object
       required:
       - value
       title: TaskStateBody
-      description: Request body for setting a task state value.
+      description: 'Request body for setting a task state value.
+
+
+        ``expires_at`` controls expiry:
+
+
+        - ``"default"``: apply the configured ``[state_store] 
default_retention_days``.
+
+        - ``null``: never expire.
+
+        - aware datetime: expire at that time.'
     TaskStateCollectionResponse:
       properties:
         task_states:
@@ -15836,6 +15933,17 @@ components:
       - total_entries
       title: TaskStateCollectionResponse
       description: All task state entries for a task instance.
+    TaskStatePatchBody:
+      properties:
+        value:
+          $ref: '#/components/schemas/JsonValue'
+      additionalProperties: false
+      type: object
+      required:
+      - value
+      title: TaskStatePatchBody
+      description: Request body for patching only the value of an existing 
task state
+        key.
     TaskStateResponse:
       properties:
         key:
diff --git 
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_state.py 
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_state.py
index 31cc7272ddc..3ca667336f4 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_state.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_state.py
@@ -17,7 +17,8 @@
 from __future__ import annotations
 
 import json
-from typing import Annotated
+from datetime import datetime, timedelta, timezone
+from typing import Annotated, Literal
 
 from fastapi import Depends, HTTPException, Query, status
 from sqlalchemy import select
@@ -30,10 +31,12 @@ from airflow.api_fastapi.common.router import AirflowRouter
 from airflow.api_fastapi.core_api.datamodels.task_state import (
     TaskStateBody,
     TaskStateCollectionResponse,
+    TaskStatePatchBody,
     TaskStateResponse,
 )
 from airflow.api_fastapi.core_api.openapi.exceptions import 
create_openapi_http_exception_doc
 from airflow.api_fastapi.core_api.security import requires_access_dag
+from airflow.configuration import conf
 from airflow.models.task_state import TaskStateModel
 from airflow.models.taskinstance import TaskInstance as TI
 from airflow.state import get_state_backend
@@ -48,6 +51,36 @@ def _get_scope(dag_id: str, dag_run_id: str, task_id: str, 
map_index: int) -> Ta
     return TaskScope(dag_id=dag_id, run_id=dag_run_id, task_id=task_id, 
map_index=map_index)
 
 
+def _resolve_expires_at(expires_at: datetime | None | Literal["default"]) -> 
datetime | None:
+    """
+    Resolve the expires_at value from the request body.
+
+    - ``"default"``: apply configured default_retention_days
+    - ``None``: never expire
+    - datetime: use as-is
+    """
+    if expires_at == "default":
+        days = conf.getint("state_store", "default_retention_days")
+        return datetime.now(tz=timezone.utc) + timedelta(days=days)
+    return expires_at
+
+
+def _require_ti(dag_id: str, dag_run_id: str, task_id: str, map_index: int, 
session: SessionDep) -> None:
+    ti_exists = session.scalar(
+        select(TI.task_id).where(
+            TI.dag_id == dag_id,
+            TI.run_id == dag_run_id,
+            TI.task_id == task_id,
+            TI.map_index == map_index,
+        )
+    )
+    if ti_exists is None:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"Task instance not found for dag_id={dag_id!r}, 
run_id={dag_run_id!r}, task_id={task_id!r}, map_index={map_index}",
+        )
+
+
 @task_state_router.get(
     "",
     responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
@@ -150,24 +183,53 @@ def set_task_state(
     map_index: Annotated[int, Query(ge=-1)] = -1,
 ) -> None:
     """Set a task state value. Creates or overwrites the key."""
-    ti_exists = session.scalar(
-        select(TI.task_id).where(
-            TI.dag_id == dag_id,
-            TI.run_id == dag_run_id,
-            TI.task_id == task_id,
-            TI.map_index == map_index,
+    _require_ti(dag_id, dag_run_id, task_id, map_index, session)
+    expires_at = _resolve_expires_at(body.expires_at)
+    scope = _get_scope(dag_id, dag_run_id, task_id, map_index)
+    try:
+        get_state_backend().set(scope, key, json.dumps(body.value), 
expires_at=expires_at, session=session)
+    except ValueError as e:
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, 
detail=str(e)) from e
+
+
+@task_state_router.patch(
+    "/{key:path}",
+    status_code=status.HTTP_200_OK,
+    responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
+    dependencies=[Depends(requires_access_dag(method="PUT", 
access_entity=DagAccessEntity.TASK_INSTANCE))],
+)
+def patch_task_state(
+    dag_id: str,
+    dag_run_id: str,
+    task_id: str,
+    key: str,
+    body: TaskStatePatchBody,
+    session: SessionDep,
+    map_index: Annotated[int, Query(ge=-1)] = -1,
+) -> None:
+    """Update the value of an existing task state key."""
+    _require_ti(dag_id, dag_run_id, task_id, map_index, session)
+
+    existing = session.execute(
+        select(TaskStateModel.expires_at).where(
+            TaskStateModel.dag_id == dag_id,
+            TaskStateModel.run_id == dag_run_id,
+            TaskStateModel.task_id == task_id,
+            TaskStateModel.map_index == map_index,
+            TaskStateModel.key == key,
         )
-    )
-    if ti_exists is None:
+    ).one_or_none()
+
+    if existing is None:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
-            detail=f"Task instance not found for dag_id={dag_id!r}, 
run_id={dag_run_id!r}, task_id={task_id!r}, map_index={map_index}",
+            detail=f"Task state key {key!r} not found",
         )
+
     scope = _get_scope(dag_id, dag_run_id, task_id, map_index)
-    try:
-        get_state_backend().set(scope, key, json.dumps(body.value), 
session=session)
-    except ValueError as e:
-        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, 
detail=str(e)) from e
+    get_state_backend().set(
+        scope, key, json.dumps(body.value), expires_at=existing.expires_at, 
session=session
+    )
 
 
 @task_state_router.delete(
diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts 
b/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts
index f8e3dbe9af6..632d3111576 100644
--- a/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts
+++ b/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts
@@ -1067,6 +1067,7 @@ export type 
TaskInstanceServicePatchTaskInstanceDryRunMutationResult = Awaited<R
 export type TaskInstanceServiceUpdateHitlDetailMutationResult = 
Awaited<ReturnType<typeof TaskInstanceService.updateHitlDetail>>;
 export type PoolServicePatchPoolMutationResult = Awaited<ReturnType<typeof 
PoolService.patchPool>>;
 export type PoolServiceBulkPoolsMutationResult = Awaited<ReturnType<typeof 
PoolService.bulkPools>>;
+export type TaskStateServicePatchTaskStateMutationResult = 
Awaited<ReturnType<typeof TaskStateService.patchTaskState>>;
 export type XcomServiceUpdateXcomEntryMutationResult = 
Awaited<ReturnType<typeof XcomService.updateXcomEntry>>;
 export type VariableServicePatchVariableMutationResult = 
Awaited<ReturnType<typeof VariableService.patchVariable>>;
 export type VariableServiceBulkVariablesMutationResult = 
Awaited<ReturnType<typeof VariableService.bulkVariables>>;
diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts 
b/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts
index 8c0976ec328..01e3d78e69c 100644
--- a/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts
+++ b/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts
@@ -2,7 +2,7 @@
 
 import { UseMutationOptions, UseQueryOptions, useMutation, useQuery } from 
"@tanstack/react-query";
 import { AssetService, AssetStateService, AuthLinksService, BackfillService, 
CalendarService, ConfigService, ConnectionService, DagParsingService, 
DagRunService, DagService, DagSourceService, DagStatsService, 
DagVersionService, DagWarningService, DashboardService, DeadlinesService, 
DependenciesService, EventLogService, ExperimentalService, ExtraLinksService, 
GanttService, GridService, ImportErrorService, JobService, LoginService, 
MonitorService, PartitionedDagRunService, PluginService, P [...]
-import { AssetStateBody, BackfillPostBody, BulkBody_BulkDAGRunBody_, 
BulkBody_BulkTaskInstanceBody_, BulkBody_ConnectionBody_, BulkBody_PoolBody_, 
BulkBody_VariableBody_, ClearTaskInstancesBody, ConnectionBody, 
CreateAssetEventsBody, DAGPatchBody, DAGRunClearBody, DAGRunPatchBody, 
DAGRunsBatchBody, DagRunState, DagWarningType, GenerateTokenBody, 
MaterializeAssetBody, PatchTaskInstanceBody, PoolBody, PoolPatchBody, 
TaskInstancesBatchBody, TaskStateBody, TriggerDAGRunPostBody, UpdateHITLDe [...]
+import { AssetStateBody, BackfillPostBody, BulkBody_BulkDAGRunBody_, 
BulkBody_BulkTaskInstanceBody_, BulkBody_ConnectionBody_, BulkBody_PoolBody_, 
BulkBody_VariableBody_, ClearTaskInstancesBody, ConnectionBody, 
CreateAssetEventsBody, DAGPatchBody, DAGRunClearBody, DAGRunPatchBody, 
DAGRunsBatchBody, DagRunState, DagWarningType, GenerateTokenBody, 
MaterializeAssetBody, PatchTaskInstanceBody, PoolBody, PoolPatchBody, 
TaskInstancesBatchBody, TaskStateBody, TaskStatePatchBody, TriggerDAGRunPo [...]
 import * as Common from "./common";
 /**
 * Get Assets
@@ -2801,6 +2801,34 @@ export const usePoolServiceBulkPools = <TData = 
Common.PoolServiceBulkPoolsMutat
   requestBody: BulkBody_PoolBody_;
 }, TContext>({ mutationFn: ({ requestBody }) => PoolService.bulkPools({ 
requestBody }) as unknown as Promise<TData>, ...options });
 /**
+* Patch Task State
+* Update the value of an existing task state key.
+* @param data The data for the request.
+* @param data.dagId
+* @param data.dagRunId
+* @param data.taskId
+* @param data.key
+* @param data.requestBody
+* @param data.mapIndex
+* @returns unknown Successful Response
+* @throws ApiError
+*/
+export const useTaskStateServicePatchTaskState = <TData = 
Common.TaskStateServicePatchTaskStateMutationResult, TError = unknown, TContext 
= unknown>(options?: Omit<UseMutationOptions<TData, TError, {
+  dagId: string;
+  dagRunId: string;
+  key: string;
+  mapIndex?: number;
+  requestBody: TaskStatePatchBody;
+  taskId: string;
+}, TContext>, "mutationFn">) => useMutation<TData, TError, {
+  dagId: string;
+  dagRunId: string;
+  key: string;
+  mapIndex?: number;
+  requestBody: TaskStatePatchBody;
+  taskId: string;
+}, TContext>({ mutationFn: ({ dagId, dagRunId, key, mapIndex, requestBody, 
taskId }) => TaskStateService.patchTaskState({ dagId, dagRunId, key, mapIndex, 
requestBody, taskId }) as unknown as Promise<TData>, ...options });
+/**
 * Update Xcom Entry
 * Update an existing XCom entry.
 * @param data The data for the request.
diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts 
b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts
index b00c1833c16..742a18b09ef 100644
--- a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts
+++ b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts
@@ -6992,13 +6992,36 @@ export const $TaskStateBody = {
     properties: {
         value: {
             '$ref': '#/components/schemas/JsonValue'
+        },
+        expires_at: {
+            anyOf: [
+                {
+                    type: 'string',
+                    format: 'date-time'
+                },
+                {
+                    type: 'string',
+                    const: 'default'
+                },
+                {
+                    type: 'null'
+                }
+            ],
+            title: 'Expires At',
+            default: 'default'
         }
     },
     additionalProperties: false,
     type: 'object',
     required: ['value'],
     title: 'TaskStateBody',
-    description: 'Request body for setting a task state value.'
+    description: `Request body for setting a task state value.
+
+\`\`expires_at\`\` controls expiry:
+
+- \`\`"default"\`\`: apply the configured \`\`[state_store] 
default_retention_days\`\`.
+- \`\`null\`\`: never expire.
+- aware datetime: expire at that time.`
 } as const;
 
 export const $TaskStateCollectionResponse = {
@@ -7021,6 +7044,19 @@ export const $TaskStateCollectionResponse = {
     description: 'All task state entries for a task instance.'
 } as const;
 
+export const $TaskStatePatchBody = {
+    properties: {
+        value: {
+            '$ref': '#/components/schemas/JsonValue'
+        }
+    },
+    additionalProperties: false,
+    type: 'object',
+    required: ['value'],
+    title: 'TaskStatePatchBody',
+    description: 'Request body for patching only the value of an existing task 
state key.'
+} as const;
+
 export const $TaskStateResponse = {
     properties: {
         key: {
diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts 
b/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts
index e65f313e1d9..3e659835303 100644
--- a/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts
+++ b/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts
@@ -3,7 +3,7 @@
 import type { CancelablePromise } from './core/CancelablePromise';
 import { OpenAPI } from './core/OpenAPI';
 import { request as __request } from './core/request';
-import type { GetAssetsData, GetAssetsResponse, GetAssetAliasesData, 
GetAssetAliasesResponse, GetAssetAliasData, GetAssetAliasResponse, 
GetAssetEventsData, GetAssetEventsResponse, CreateAssetEventData, 
CreateAssetEventResponse, MaterializeAssetData, MaterializeAssetResponse, 
GetAssetQueuedEventsData, GetAssetQueuedEventsResponse, 
DeleteAssetQueuedEventsData, DeleteAssetQueuedEventsResponse, GetAssetData, 
GetAssetResponse, GetDagAssetQueuedEventsData, GetDagAssetQueuedEventsResponse, 
Dele [...]
+import type { GetAssetsData, GetAssetsResponse, GetAssetAliasesData, 
GetAssetAliasesResponse, GetAssetAliasData, GetAssetAliasResponse, 
GetAssetEventsData, GetAssetEventsResponse, CreateAssetEventData, 
CreateAssetEventResponse, MaterializeAssetData, MaterializeAssetResponse, 
GetAssetQueuedEventsData, GetAssetQueuedEventsResponse, 
DeleteAssetQueuedEventsData, DeleteAssetQueuedEventsResponse, GetAssetData, 
GetAssetResponse, GetDagAssetQueuedEventsData, GetDagAssetQueuedEventsResponse, 
Dele [...]
 
 export class AssetService {
     /**
@@ -3813,6 +3813,43 @@ export class TaskStateService {
         });
     }
     
+    /**
+     * Patch Task State
+     * Update the value of an existing task state key.
+     * @param data The data for the request.
+     * @param data.dagId
+     * @param data.dagRunId
+     * @param data.taskId
+     * @param data.key
+     * @param data.requestBody
+     * @param data.mapIndex
+     * @returns unknown Successful Response
+     * @throws ApiError
+     */
+    public static patchTaskState(data: PatchTaskStateData): 
CancelablePromise<PatchTaskStateResponse> {
+        return __request(OpenAPI, {
+            method: 'PATCH',
+            url: 
'/api/v2/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/states/{key}',
+            path: {
+                dag_id: data.dagId,
+                dag_run_id: data.dagRunId,
+                task_id: data.taskId,
+                key: data.key
+            },
+            query: {
+                map_index: data.mapIndex
+            },
+            body: data.requestBody,
+            mediaType: 'application/json',
+            errors: {
+                401: 'Unauthorized',
+                403: 'Forbidden',
+                404: 'Not Found',
+                422: 'Validation Error'
+            }
+        });
+    }
+    
     /**
      * Delete Task State
      * Delete a single task state key. No-op if the key does not exist.
diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts 
b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts
index e674f27e8a3..16929844a94 100644
--- a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts
+++ b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts
@@ -1714,9 +1714,16 @@ export type TaskResponse = {
 
 /**
  * Request body for setting a task state value.
+ *
+ * ``expires_at`` controls expiry:
+ *
+ * - ``"default"``: apply the configured ``[state_store] 
default_retention_days``.
+ * - ``null``: never expire.
+ * - aware datetime: expire at that time.
  */
 export type TaskStateBody = {
     value: JsonValue;
+    expires_at?: string | "default" | null;
 };
 
 /**
@@ -1727,6 +1734,13 @@ export type TaskStateCollectionResponse = {
     total_entries: number;
 };
 
+/**
+ * Request body for patching only the value of an existing task state key.
+ */
+export type TaskStatePatchBody = {
+    value: JsonValue;
+};
+
 /**
  * A single task state key/value pair with metadata.
  */
@@ -3968,6 +3982,17 @@ export type SetTaskStateData = {
 
 export type SetTaskStateResponse = void;
 
+export type PatchTaskStateData = {
+    dagId: string;
+    dagRunId: string;
+    key: string;
+    mapIndex?: number;
+    requestBody: TaskStatePatchBody;
+    taskId: string;
+};
+
+export type PatchTaskStateResponse = unknown;
+
 export type DeleteTaskStateData = {
     dagId: string;
     dagRunId: string;
@@ -7262,6 +7287,31 @@ export type $OpenApiTs = {
                 422: HTTPValidationError;
             };
         };
+        patch: {
+            req: PatchTaskStateData;
+            res: {
+                /**
+                 * Successful Response
+                 */
+                200: unknown;
+                /**
+                 * Unauthorized
+                 */
+                401: HTTPExceptionResponse;
+                /**
+                 * Forbidden
+                 */
+                403: HTTPExceptionResponse;
+                /**
+                 * Not Found
+                 */
+                404: HTTPExceptionResponse;
+                /**
+                 * Validation Error
+                 */
+                422: HTTPValidationError;
+            };
+        };
         delete: {
             req: DeleteTaskStateData;
             res: {
diff --git 
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py 
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py
index e96aca22cdd..8481a6fc43b 100644
--- 
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py
+++ 
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_state.py
@@ -23,12 +23,13 @@ from pydantic import ValidationError
 from sqlalchemy import select
 
 from airflow._shared.timezones import timezone
-from airflow.api_fastapi.core_api.datamodels.task_state import TaskStateBody
+from airflow.api_fastapi.core_api.datamodels.task_state import TaskStateBody, 
TaskStatePatchBody
 from airflow.models.dagrun import DagRun
 from airflow.models.task_state import TaskStateModel
 from airflow.providers.standard.operators.empty import EmptyOperator
 from airflow.utils.types import DagRunType
 
+from tests_common.test_utils.config import conf_vars
 from tests_common.test_utils.db import clear_db_dag_bundles, clear_db_dags, 
clear_db_runs
 
 pytestmark = pytest.mark.db_test
@@ -258,10 +259,104 @@ class TestSetTaskState(TestTaskStateEndpoint):
         assert response.status_code == 204
         assert test_client.get(f"{BASE_URL}/workflow/step_1").json()["key"] == 
"workflow/step_1"
 
+    def test_new_key_default_retention_applies_config(self, test_client, 
time_machine):
+        time_machine.move_to("2026-01-01T00:00:00+00:00", tick=False)
+        with conf_vars({("state_store", "default_retention_days"): "7"}):
+            test_client.put(f"{BASE_URL}/job_id", json={"value": "v", 
"expires_at": "default"})
+
+        resp = test_client.get(f"{BASE_URL}/job_id").json()
+        assert resp["expires_at"] == "2026-01-08T00:00:00Z"
+
+    def test_new_key_never_expiry(self, test_client):
+        """PUT with expires_at=null stores a key that never expires."""
+        test_client.put(f"{BASE_URL}/job_id", json={"value": "v", 
"expires_at": None})
+        assert test_client.get(f"{BASE_URL}/job_id").json()["expires_at"] is 
None
+
+    def test_new_key_explicit_expiry(self, test_client, time_machine):
+        """PUT with an explicit datetime uses that as expires_at."""
+        time_machine.move_to("2026-01-01T00:00:00+00:00", tick=False)
+        target = "2026-01-31T00:00:00Z"
+        test_client.put(f"{BASE_URL}/job_id", json={"value": "v", 
"expires_at": target})
+        assert test_client.get(f"{BASE_URL}/job_id").json()["expires_at"] == 
target
+
+    def test_put_overwrites_expiry_on_existing_key(self, test_client, 
time_machine):
+        """PUT on an existing key replaces expires_at with whatever the body 
specifies."""
+        time_machine.move_to("2026-01-01T00:00:00+00:00", tick=False)
+        test_client.put(f"{BASE_URL}/job_id", json={"value": "v1", 
"expires_at": "2026-01-31T00:00:00Z"})
+
+        # second request but with null expires_at
+        test_client.put(f"{BASE_URL}/job_id", json={"value": "v2", 
"expires_at": None})
+
+        resp = test_client.get(f"{BASE_URL}/job_id").json()
+        assert resp["value"] == "v2"
+        assert resp["expires_at"] is None
+
     def test_unauthorized_returns_401(self, unauthenticated_test_client):
         assert unauthenticated_test_client.put(f"{BASE_URL}/job_id", 
json={"value": "v"}).status_code == 401
 
 
+class TestPatchTaskState(TestTaskStateEndpoint):
+    def test_patch_updates_value(self, test_client):
+        _create_task_state(self._session, "job_id", "v1", self.dag_run)
+        self._session.commit()
+
+        assert test_client.patch(f"{BASE_URL}/job_id", json={"value": 
"v2"}).status_code == 200
+        row = self._session.scalar(
+            select(TaskStateModel).where(
+                TaskStateModel.dag_id == DAG_ID,
+                TaskStateModel.run_id == RUN_ID,
+                TaskStateModel.task_id == TASK_ID,
+                TaskStateModel.key == "job_id",
+            )
+        )
+        assert row.value == '"v2"'
+
+    def test_patch_missing_key_returns_404(self, test_client):
+        assert test_client.patch(f"{BASE_URL}/nonexistent", json={"value": 
"v"}).status_code == 404
+
+    def test_patch_empty_body_returns_422(self, test_client):
+        _create_task_state(self._session, "job_id", "v", self.dag_run)
+        self._session.commit()
+        assert test_client.patch(f"{BASE_URL}/job_id", json={}).status_code == 
422
+
+    def test_patch_null_value_returns_422(self, test_client):
+        _create_task_state(self._session, "job_id", "v", self.dag_run)
+        self._session.commit()
+        assert test_client.patch(f"{BASE_URL}/job_id", json={"value": 
None}).status_code == 422
+
+    @pytest.mark.parametrize("bad_value", [float("nan"), float("inf"), {"a": 
float("nan")}, [float("inf")]])
+    def test_patch_non_finite_float_rejected_by_validator(self, bad_value):
+        with pytest.raises(ValidationError, match="non-finite"):
+            TaskStatePatchBody(value=bad_value)
+
+    @pytest.mark.parametrize(
+        ("value", "expected_db"),
+        [
+            (42, "42"),
+            ("hello", '"hello"'),
+            ({"k": 1}, '{"k": 1}'),
+            ([1, 2], "[1, 2]"),
+        ],
+    )
+    def test_patch_stores_json_encoded_value(self, test_client, value, 
expected_db):
+        _create_task_state(self._session, "job_id", "initial", self.dag_run)
+        self._session.commit()
+        test_client.patch(f"{BASE_URL}/job_id", json={"value": value})
+        row = self._session.scalar(
+            select(TaskStateModel).where(
+                TaskStateModel.dag_id == DAG_ID,
+                TaskStateModel.run_id == RUN_ID,
+                TaskStateModel.task_id == TASK_ID,
+                TaskStateModel.key == "job_id",
+            )
+        )
+        self._session.refresh(row)
+        assert row.value == expected_db
+
+    def test_unauthorized_returns_401(self, unauthenticated_test_client):
+        assert unauthenticated_test_client.patch(f"{BASE_URL}/job_id", 
json={"value": "v"}).status_code == 401
+
+
 class TestDeleteTaskState(TestTaskStateEndpoint):
     def test_deletes_key(self, test_client):
         _create_task_state(self._session, "job_id", "spark_001", self.dag_run)
diff --git a/airflow-ctl/src/airflowctl/api/datamodels/generated.py 
b/airflow-ctl/src/airflowctl/api/datamodels/generated.py
index 63444c465d9..4ad4fd39cb9 100644
--- a/airflow-ctl/src/airflowctl/api/datamodels/generated.py
+++ b/airflow-ctl/src/airflowctl/api/datamodels/generated.py
@@ -956,6 +956,24 @@ class TaskOutletAssetReference(BaseModel):
 class TaskStateBody(BaseModel):
     """
     Request body for setting a task state value.
+
+    ``expires_at`` controls expiry:
+
+    - ``"default"``: apply the configured ``[state_store] 
default_retention_days``.
+    - ``null``: never expire.
+    - aware datetime: expire at that time.
+    """
+
+    model_config = ConfigDict(
+        extra="forbid",
+    )
+    value: JsonValue
+    expires_at: Annotated[datetime | str | None, Field(title="Expires At")] = 
"default"
+
+
+class TaskStatePatchBody(BaseModel):
+    """
+    Request body for patching only the value of an existing task state key.
     """
 
     model_config = ConfigDict(

Reply via email to