This is an automated email from the ASF dual-hosted git repository.

amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b59126d7c90 Simplifying authoring of task and asset states by allowing 
JSON types (#67418)
b59126d7c90 is described below

commit b59126d7c90e355476649dfb7a74a483c468422b
Author: Amogh Desai <[email protected]>
AuthorDate: Mon May 25 19:16:28 2026 +0530

    Simplifying authoring of task and asset states by allowing JSON types 
(#67418)
---
 .../execution_api/datamodels/asset_state.py        | 17 ++++++-
 .../execution_api/datamodels/task_state.py         | 16 ++++++-
 .../execution_api/routes/asset_state.py            |  9 ++--
 .../api_fastapi/execution_api/routes/task_state.py |  5 +-
 .../execution_api/versions/v2026_06_16.py          |  2 +-
 .../versions/head/test_asset_state.py              | 51 +++++++++++++++++++-
 .../execution_api/versions/head/test_task_state.py | 47 +++++++++++++++++-
 shared/state/pyproject.toml                        |  4 +-
 shared/state/src/airflow_shared/state/__init__.py  | 55 ++++++++++++----------
 shared/state/tests/state/test_state.py             | 23 +++++++++
 task-sdk/src/airflow/sdk/api/client.py             |  8 ++--
 .../src/airflow/sdk/api/datamodels/_generated.py   | 48 +++++++++----------
 task-sdk/src/airflow/sdk/execution_time/comms.py   |  6 +--
 task-sdk/src/airflow/sdk/execution_time/context.py | 37 ++++++++++-----
 .../airflow/sdk/execution_time/schema/schema.json  | 30 ++++++++----
 .../tests/task_sdk/execution_time/test_context.py  | 23 +++++++--
 .../task_sdk/execution_time/test_task_runner.py    | 34 +++++++++++++
 uv.lock                                            |  4 ++
 18 files changed, 324 insertions(+), 95 deletions(-)

diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_state.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_state.py
index ec773201c7e..ab8b3aa2aec 100644
--- 
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_state.py
+++ 
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_state.py
@@ -17,16 +17,29 @@
 
 from __future__ import annotations
 
+import math
+
+from pydantic import JsonValue, field_validator
+
 from airflow.api_fastapi.core_api.base import StrictBaseModel
 
 
 class AssetStateResponse(StrictBaseModel):
     """Asset state value returned to a worker."""
 
-    value: str
+    value: JsonValue
 
 
 class AssetStatePutBody(StrictBaseModel):
     """Request body for setting an asset state value."""
 
-    value: str
+    value: JsonValue
+
+    @field_validator("value")
+    @classmethod
+    def value_is_json_representable(cls, v: JsonValue) -> JsonValue:
+        if v is None:
+            raise ValueError("value cannot be null")
+        if isinstance(v, float) and not math.isfinite(v):
+            raise ValueError("value must be a finite number; NaN and Inf are 
not JSON representable")
+        return v
diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py
index 20980b315c3..15fc44b7267 100644
--- 
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py
+++ 
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py
@@ -17,19 +17,31 @@
 
 from __future__ import annotations
 
+import math
 from datetime import datetime
 
+from pydantic import JsonValue, field_validator
+
 from airflow.api_fastapi.core_api.base import StrictBaseModel
 
 
 class TaskStateResponse(StrictBaseModel):
     """Task state value returned to a worker."""
 
-    value: str
+    value: JsonValue
 
 
 class TaskStatePutBody(StrictBaseModel):
     """Request body for setting a task state value."""
 
-    value: str
+    value: JsonValue
     expires_at: datetime | None = None
+
+    @field_validator("value")
+    @classmethod
+    def value_is_json_representable(cls, v: JsonValue) -> JsonValue:
+        if v is None:
+            raise ValueError("value cannot be null")
+        if isinstance(v, float) and not math.isfinite(v):
+            raise ValueError("value must be a finite number; NaN and Inf are 
not JSON representable")
+        return v
diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py
index 2351caa6dfa..f7001c3158c 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py
@@ -28,6 +28,7 @@ Per-task asset registration checks are intentionally not 
implemented here
 
 from __future__ import annotations
 
+import json
 from typing import Annotated
 
 from cadwyn import VersionedAPIRouter
@@ -93,7 +94,7 @@ def get_asset_state_by_name(
             status_code=status.HTTP_404_NOT_FOUND,
             detail={"reason": "not_found", "message": f"Asset state key 
{key!r} not found"},
         )
-    return AssetStateResponse(value=value)
+    return AssetStateResponse(value=json.loads(value))
 
 
 @router.put("/by-name/value", status_code=status.HTTP_204_NO_CONTENT)
@@ -105,7 +106,7 @@ def set_asset_state_by_name(
 ) -> None:
     """Set an asset state value by asset name."""
     asset_id = _resolve_asset_id_by_name(name, session)
-    get_state_backend().set(AssetScope(asset_id=asset_id), key, body.value, 
session=session)
+    get_state_backend().set(AssetScope(asset_id=asset_id), key, 
json.dumps(body.value), session=session)
 
 
 @router.delete("/by-name/value", status_code=status.HTTP_204_NO_CONTENT)
@@ -143,7 +144,7 @@ def get_asset_state_by_uri(
             status_code=status.HTTP_404_NOT_FOUND,
             detail={"reason": "not_found", "message": f"Asset state key 
{key!r} not found"},
         )
-    return AssetStateResponse(value=value)
+    return AssetStateResponse(value=json.loads(value))
 
 
 @router.put("/by-uri/value", status_code=status.HTTP_204_NO_CONTENT)
@@ -155,7 +156,7 @@ def set_asset_state_by_uri(
 ) -> None:
     """Set an asset state value by asset URI."""
     asset_id = _resolve_asset_id_by_uri(uri, session)
-    get_state_backend().set(AssetScope(asset_id=asset_id), key, body.value, 
session=session)
+    get_state_backend().set(AssetScope(asset_id=asset_id), key, 
json.dumps(body.value), session=session)
 
 
 @router.delete("/by-uri/value", status_code=status.HTTP_204_NO_CONTENT)
diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py
index 2f824e3ebb2..c59f2461e2a 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+import json
 from typing import Annotated
 from uuid import UUID
 
@@ -74,7 +75,7 @@ def get_task_state(
                 "message": f"Task state key {key!r} not found",
             },
         )
-    return TaskStateResponse(value=value)
+    return TaskStateResponse(value=json.loads(value))
 
 
 @router.put("/{task_instance_id}/{key}", 
status_code=status.HTTP_204_NO_CONTENT)
@@ -86,7 +87,7 @@ def set_task_state(
 ) -> None:
     """Set a task state key, creating or updating the row."""
     scope = _get_task_scope_for_ti(task_instance_id, session)
-    get_state_backend().set(scope, key, body.value, 
expires_at=body.expires_at, session=session)
+    get_state_backend().set(scope, key, json.dumps(body.value), 
expires_at=body.expires_at, session=session)
 
 
 @router.delete("/{task_instance_id}/{key}", 
status_code=status.HTTP_204_NO_CONTENT)
diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_16.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_16.py
index 779612bbde1..cd2a6861d11 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_16.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_16.py
@@ -60,7 +60,7 @@ class AddAssetsByAliasEndpoint(VersionChange):
 
 
 class AddStateEndpoints(VersionChange):
-    """Add task state and asset state CRUD endpoints."""
+    """Add task state and asset state API endpoints."""
 
     description = __doc__
 
diff --git 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_asset_state.py
 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_asset_state.py
index 6041d01e7f1..c91171aa05e 100644
--- 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_asset_state.py
+++ 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_asset_state.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+import json
 from typing import TYPE_CHECKING
 
 import pytest
@@ -113,7 +114,37 @@ class TestPutAssetStateByName:
             )
         )
         assert row is not None
-        assert row.value == "2026-04-29"
+        # DB stores JSON-encoded string
+        assert row.value == '"2026-04-29"'
+
+    def test_put_int_value_roundtrip(self, client: TestClient, asset: 
AssetModel):
+        response = client.put(
+            _BY_NAME_VALUE, params={"name": asset.name, "key": "total_runs"}, 
json={"value": 5}
+        )
+        assert response.status_code == 204
+        assert client.get(_BY_NAME_VALUE, params={"name": asset.name, "key": 
"total_runs"}).json() == {
+            "value": 5
+        }
+
+    def test_put_dict_value_roundtrip(self, client: TestClient, asset: 
AssetModel):
+        response = client.put(
+            _BY_NAME_VALUE,
+            params={"name": asset.name, "key": "last_run"},
+            json={"value": {"rows": 1234, "status": "ok"}},
+        )
+        assert response.status_code == 204
+        assert client.get(_BY_NAME_VALUE, params={"name": asset.name, "key": 
"last_run"}).json() == {
+            "value": {"rows": 1234, "status": "ok"}
+        }
+
+    def test_put_list_value_roundtrip(self, client: TestClient, asset: 
AssetModel):
+        response = client.put(
+            _BY_NAME_VALUE, params={"name": asset.name, "key": "ids"}, 
json={"value": [1, 2, 3]}
+        )
+        assert response.status_code == 204
+        assert client.get(_BY_NAME_VALUE, params={"name": asset.name, "key": 
"ids"}).json() == {
+            "value": [1, 2, 3]
+        }
 
     def test_put_overwrites_existing(self, client: TestClient, asset: 
AssetModel):
         client.put(
@@ -134,6 +165,22 @@ class TestPutAssetStateByName:
 
         assert response.status_code == 422
 
+    def test_put_null_value_returns_422(self, client: TestClient, asset: 
AssetModel):
+        response = client.put(
+            _BY_NAME_VALUE, params={"name": asset.name, "key": "watermark"}, 
json={"value": None}
+        )
+        assert response.status_code == 422
+
+    @pytest.mark.parametrize("bad_float", [float("nan"), float("inf"), 
float("-inf")])
+    def test_put_non_finite_float_returns_422(self, client: TestClient, asset: 
AssetModel, bad_float: float):
+        with pytest.raises(ValueError, match="Out of range float values are 
not JSON compliant"):
+            _ = client.put(
+                _BY_NAME_VALUE,
+                params={"name": asset.name, "key": "watermark"},
+                content=json.dumps({"value": bad_float}, 
allow_nan=True).encode(),
+                headers={"Content-Type": "application/json"},
+            )
+
     def test_put_unknown_asset_returns_404(self, client: TestClient):
         response = client.put(
             _BY_NAME_VALUE, params={"name": "nonexistent", "key": 
"watermark"}, json={"value": "x"}
@@ -208,7 +255,7 @@ class TestPutAssetStateByUri:
             )
         )
         assert row is not None
-        assert row.value == "2026-04-29"
+        assert row.value == '"2026-04-29"'
 
     def test_put_unknown_uri_returns_404(self, client: TestClient):
         response = client.put(
diff --git 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py
 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py
index d83751050e7..97acc576a50 100644
--- 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py
+++ 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+import json
 from datetime import datetime
 from typing import TYPE_CHECKING
 from uuid import uuid4
@@ -95,7 +96,37 @@ class TestPutTaskState:
                 )
             )
             assert row is not None
-            assert row.value == "spark_001"
+            # DB stores a json string
+            assert row.value == '"spark_001"'
+
+    def test_put_int_value_roundtrip(self, client: TestClient, 
create_task_instance: CreateTaskInstance):
+        ti = create_task_instance()
+
+        response = client.put(_api_url(ti.id, "retry_count"), json={"value": 
3})
+
+        assert response.status_code == 204
+        assert client.get(_api_url(ti.id, "retry_count")).json() == {"value": 
3}
+
+    def test_put_dict_value_roundtrip(self, client: TestClient, 
create_task_instance: CreateTaskInstance):
+        ti = create_task_instance()
+
+        response = client.put(
+            _api_url(ti.id, "poll_result"),
+            json={"value": {"status": "succeeded", "rows": 1234}},
+        )
+
+        assert response.status_code == 204
+        assert client.get(_api_url(ti.id, "poll_result")).json() == {
+            "value": {"status": "succeeded", "rows": 1234}
+        }
+
+    def test_put_list_value_roundtrip(self, client: TestClient, 
create_task_instance: CreateTaskInstance):
+        ti = create_task_instance()
+
+        response = client.put(_api_url(ti.id, "checkpoints"), json={"value": 
[1, 2, 3]})
+
+        assert response.status_code == 204
+        assert client.get(_api_url(ti.id, "checkpoints")).json() == {"value": 
[1, 2, 3]}
 
     def test_put_with_expires_at_creates_row(
         self, client: TestClient, create_task_instance: CreateTaskInstance, 
time_machine
@@ -122,7 +153,7 @@ class TestPutTaskState:
                 )
             )
             assert row is not None
-            assert row.value == "spark_001"
+            assert row.value == '"spark_001"'
             assert row.expires_at == datetime(2026, 5, 15, 12, 0, 0, 
tzinfo=pendulum.UTC)
 
     def test_put_overwrites_existing(self, client: TestClient, 
create_task_instance: CreateTaskInstance):
@@ -155,6 +186,18 @@ class TestPutTaskState:
 
         assert response.status_code == 422
 
+    @pytest.mark.parametrize("bad_float", [float("nan"), float("inf"), 
float("-inf")])
+    def test_put_non_finite_float_returns_422(
+        self, client: TestClient, create_task_instance: CreateTaskInstance, 
bad_float: float
+    ):
+        ti = create_task_instance()
+        with pytest.raises(ValueError, match="Out of range float values are 
not JSON compliant"):
+            _ = client.put(
+                _api_url(ti.id, "job_id"),
+                content=json.dumps({"value": bad_float}, 
allow_nan=True).encode(),
+                headers={"Content-Type": "application/json"},
+            )
+
     def test_put_missing_ti_returns_404(self, client: TestClient):
         response = client.put(_api_url(uuid4(), "job_id"), json={"value": "x"})
 
diff --git a/shared/state/pyproject.toml b/shared/state/pyproject.toml
index f317eba1995..3d16d465726 100644
--- a/shared/state/pyproject.toml
+++ b/shared/state/pyproject.toml
@@ -23,7 +23,9 @@ classifiers = [
     "Private :: Do Not Upload",
 ]
 
-dependencies = []
+dependencies = [
+    "pydantic>=2.11.0",
+]
 
 [dependency-groups]
 dev = [
diff --git a/shared/state/src/airflow_shared/state/__init__.py 
b/shared/state/src/airflow_shared/state/__init__.py
index 7aa9fcba837..688cd6a6301 100644
--- a/shared/state/src/airflow_shared/state/__init__.py
+++ b/shared/state/src/airflow_shared/state/__init__.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+import json
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import TYPE_CHECKING
@@ -23,6 +24,7 @@ from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from datetime import datetime
 
+    from pydantic import JsonValue
     from sqlalchemy.ext.asyncio import AsyncSession
     from sqlalchemy.orm import Session
 
@@ -96,9 +98,10 @@ class BaseStateBackend(ABC):
     @abstractmethod
     def get(self, scope: StateScope, key: str, *, session: Session | None = 
None) -> str | None:
         """
-        Return the stored value, or None if the key does not exist.
+        Return the stored JSON encoded value string, or None if the key does 
not exist.
 
-        Must handle both ``TaskScope`` and ``AssetScope``.
+        Must handle both ``TaskScope`` and ``AssetScope``. The execution API 
calls
+        ``json.loads`` on the returned string from here, so it must be a valid 
JSON document.
         """
 
     @abstractmethod
@@ -112,9 +115,11 @@ class BaseStateBackend(ABC):
         session: Session | None = None,
     ) -> None:
         """
-        Write or overwrite the value for the given key.
+        Write or overwrite ``value`` for the given key.
 
-        Must handle both ``TaskScope`` and ``AssetScope``.
+        Must handle both ``TaskScope`` and ``AssetScope``. ``value`` is always 
a
+        JSON encoded string (the execution API calls ``json.dumps`` before 
passing it
+        here); store it verbatim so ``get`` can return it unchanged.
 
         ``expires_at`` is an absolute UTC datetime after which the row may be 
deleted.
         Pass ``None`` (default) for a key that should never expire — stored as 
``NULL``,
@@ -147,10 +152,10 @@ class BaseStateBackend(ABC):
     @abstractmethod
     async def aget(self, scope: StateScope, key: str, *, session: AsyncSession 
| None = None) -> str | None:
         """
-        Async variant of get. Must handle both ``TaskScope`` and 
``AssetScope``.
+        Async variant of ``get`` which returns a JSON encoded value string or 
None.
 
-        ``session`` is optional. If provided, implementations should use it 
directly.
-        If ``None``, implementations manage their own async session internally.
+        Must handle both ``TaskScope`` and ``AssetScope``. ``session`` is used 
directly
+        when provided; otherwise implementations manage their own session 
internally.
         """
 
     @abstractmethod
@@ -164,10 +169,10 @@ class BaseStateBackend(ABC):
         session: AsyncSession | None = None,
     ) -> None:
         """
-        Async variant of set. Must handle both ``TaskScope`` and 
``AssetScope``.
+        Async variant of ``set``. ``value`` is always a JSON encoded string.
 
-        ``session`` is optional. If provided, implementations should use it 
directly.
-        If ``None``, implementations manage their own async session internally.
+        Must handle both ``TaskScope`` and ``AssetScope``. ``session`` is used 
directly
+        when provided; otherwise implementations manage their own session 
internally.
         """
 
     @abstractmethod
@@ -203,7 +208,7 @@ class BaseStateBackend(ABC):
         ``[state_store] default_retention_days``) and deciding what to delete.
         """
 
-    def serialize_task_state_to_ref(self, *, value: str, key: str, ti_id: str) 
-> str:
+    def serialize_task_state_to_ref(self, *, value: JsonValue, key: str, 
ti_id: str) -> str:
         """
         Serialize a task state value before it is sent to the execution API 
for db persistence.
 
@@ -214,20 +219,21 @@ class BaseStateBackend(ABC):
         The returned reference must be deterministic — given the same 
``ti_id`` and ``key`` it
         must always return the same string. Do not use timestamps or random 
UUIDs as part of
         the reference, otherwise ``delete()``/``clear()`` cannot reconstruct 
it and the external
-        object will be orphaned.
+        object will be orphaned. By default, it JSON dumps the value and 
returns a JSON string.
         """
-        return value
+        return json.dumps(value)
 
-    def deserialize_task_state_from_ref(self, stored: str) -> str:
+    def deserialize_task_state_from_ref(self, stored: str) -> JsonValue:
         """
-        Resolve a stored task state string back to the actual value.
+        Resolve a stored task state reference back to the actual value.
 
         Called by ``TaskStateAccessor.get()`` after the stored string is 
retrieved from
-        the execution API. Default: return ``stored`` unchanged.
+        the execution API. By default, it JSON decodes ``stored`` to reverse 
the default
+        ``serialize_task_state_to_ref`` encoding.
         """
-        return stored
+        return json.loads(stored)
 
-    def serialize_asset_state_to_ref(self, *, value: str, key: str, asset_ref: 
str) -> str:
+    def serialize_asset_state_to_ref(self, *, value: JsonValue, key: str, 
asset_ref: str) -> str:
         """
         Serialize an asset state value before it is sent to the Execution API 
for db persistence.
 
@@ -241,15 +247,16 @@ class BaseStateBackend(ABC):
         The returned reference must be deterministic — given the same 
``asset_ref`` and ``key`` it
         must always return the same string. Do not use timestamps or random 
UUIDs as part of
         the reference, otherwise ``delete()``/``clear()`` cannot reconstruct 
it and the external
-        object will be orphaned.
+        object will be orphaned. By default, it JSON dumps the value and 
returns a JSON string.
         """
-        return value
+        return json.dumps(value)
 
-    def deserialize_asset_state_from_ref(self, stored: str) -> str:
+    def deserialize_asset_state_from_ref(self, stored: str) -> JsonValue:
         """
-        Resolve a stored asset state string back to the actual value.
+        Resolve a stored asset state reference back to the actual value.
 
         Called by ``AssetStateAccessor.get()`` after the stored string is 
retrieved from
-        the Execution API. Default: return ``stored`` unchanged.
+        the Execution API. By default, it JSON decodes ``stored`` to reverse 
the default
+        ``serialize_asset_state_to_ref`` encoding.
         """
-        return stored
+        return json.loads(stored)
diff --git a/shared/state/tests/state/test_state.py 
b/shared/state/tests/state/test_state.py
index 1ea31194e27..eb658ff8c74 100644
--- a/shared/state/tests/state/test_state.py
+++ b/shared/state/tests/state/test_state.py
@@ -92,6 +92,18 @@ class TestBaseStateBackend:
         deserialized = backend.deserialize_task_state_from_ref(serialized)
         assert deserialized == original
 
+    def test_task_state_serialize_deserialize_typed_values(self, backend):
+        """Default backend passes typed values through unchanged (custom 
backends handle storage)."""
+        assert (
+            backend.deserialize_task_state_from_ref(
+                backend.serialize_task_state_to_ref(value=42, key="count", 
ti_id="abc-123")
+            )
+            == 42
+        )
+        assert backend.deserialize_task_state_from_ref(
+            backend.serialize_task_state_to_ref(value={"status": "ok"}, 
key="result", ti_id="abc-123")
+        ) == {"status": "ok"}
+
     def test_custom_backend_overrides_task_state_ser_deser(self):
         class MyBackend(BaseStateBackend):
             def get(self, scope, key): ...
@@ -126,6 +138,17 @@ class TestBaseStateBackend:
         deserialized = backend.deserialize_asset_state_from_ref(serialized)
         assert deserialized == original
 
+    def test_asset_state_serialize_deserialize_typed_values(self, backend):
+        assert (
+            backend.deserialize_asset_state_from_ref(
+                backend.serialize_asset_state_to_ref(value=5, 
key="total_runs", asset_ref="my_asset")
+            )
+            == 5
+        )
+        assert backend.deserialize_asset_state_from_ref(
+            backend.serialize_asset_state_to_ref(value={"rows": 1234}, 
key="last_run", asset_ref="my_asset")
+        ) == {"rows": 1234}
+
     def test_custom_backend_overrides_asset_state_ser_deser(self):
         class MyBackend(BaseStateBackend):
             def get(self, scope, key): ...
diff --git a/task-sdk/src/airflow/sdk/api/client.py 
b/task-sdk/src/airflow/sdk/api/client.py
index 99b1aadb37f..1da539f29a3 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -32,7 +32,7 @@ import msgspec
 import structlog
 from opentelemetry import trace
 from opentelemetry.trace.propagation.tracecontext import 
TraceContextTextMapPropagator
-from pydantic import BaseModel
+from pydantic import BaseModel, JsonValue
 from tenacity import (
     before_log,
     retry,
@@ -721,7 +721,7 @@ class TaskStateOperations:
             raise
         return TaskStateResponse.model_validate_json(resp.read())
 
-    def set(self, ti_id: uuid.UUID, key: str, value: str, expires_at: datetime 
| None) -> OKResponse:
+    def set(self, ti_id: uuid.UUID, key: str, value: JsonValue, expires_at: 
datetime | None) -> OKResponse:
         """Set a task state value via the API server."""
         body = TaskStatePutBody(value=value, expires_at=expires_at)
         self.client.put(f"state/ti/{ti_id}/{key}", 
content=body.model_dump_json())
@@ -774,7 +774,9 @@ class AssetStateOperations:
             raise
         return AssetStateResponse.model_validate_json(resp.read())
 
-    def set(self, key: str, value: str, *, name: str | None = None, uri: str | 
None = None) -> OKResponse:
+    def set(
+        self, key: str, value: JsonValue, *, name: str | None = None, uri: str 
| None = None
+    ) -> OKResponse:
         """Set an asset state value via the API server."""
         endpoint, params = self._resolve_endpoint("value", key=key, name=name, 
uri=uri)
         self.client.put(endpoint, params=params, 
content=AssetStatePutBody(value=value).model_dump_json())
diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py 
b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
index fc966a76969..62c43ac17d1 100644
--- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -63,28 +63,6 @@ class AssetProfile(BaseModel):
     type: Annotated[str, Field(title="Type")]
 
 
-class AssetStatePutBody(BaseModel):
-    """
-    Request body for setting an asset state value.
-    """
-
-    model_config = ConfigDict(
-        extra="forbid",
-    )
-    value: Annotated[str, Field(title="Value")]
-
-
-class AssetStateResponse(BaseModel):
-    """
-    Asset state value returned to a worker.
-    """
-
-    model_config = ConfigDict(
-        extra="forbid",
-    )
-    value: Annotated[str, Field(title="Value")]
-
-
 class ConnectionResponse(BaseModel):
     """
     Connection schema for responses with fields that are needed for Runtime.
@@ -375,7 +353,7 @@ class TaskStatePutBody(BaseModel):
     model_config = ConfigDict(
         extra="forbid",
     )
-    value: Annotated[str, Field(title="Value")]
+    value: JsonValue
     expires_at: Annotated[AwareDatetime | None, Field(title="Expires At")] = 
None
 
 
@@ -387,7 +365,7 @@ class TaskStateResponse(BaseModel):
     model_config = ConfigDict(
         extra="forbid",
     )
-    value: Annotated[str, Field(title="Value")]
+    value: JsonValue
 
 
 class TaskStatesResponse(BaseModel):
@@ -596,6 +574,28 @@ class AssetResponse(BaseModel):
     extra: Annotated[dict[str, JsonValue] | None, Field(title="Extra")] = None
 
 
+class AssetStatePutBody(BaseModel):
+    """
+    Request body for setting an asset state value.
+    """
+
+    model_config = ConfigDict(
+        extra="forbid",
+    )
+    value: JsonValue
+
+
+class AssetStateResponse(BaseModel):
+    """
+    Asset state value returned to a worker.
+    """
+
+    model_config = ConfigDict(
+        extra="forbid",
+    )
+    value: JsonValue
+
+
 class HITLDetailRequest(BaseModel):
     """
     Schema for the request part of a Human-in-the-loop detail for a specific 
task instance.
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py 
b/task-sdk/src/airflow/sdk/execution_time/comms.py
index 2364e942ed0..7b494cab835 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -923,7 +923,7 @@ class GetTaskState(BaseModel):
 class SetTaskState(BaseModel):
     ti_id: UUID
     key: str
-    value: str
+    value: JsonValue
     expires_at: AwareDatetime | None
     type: Literal["SetTaskState"] = "SetTaskState"
 
@@ -955,14 +955,14 @@ class GetAssetStateByUri(BaseModel):
 class SetAssetStateByName(BaseModel):
     name: str
     key: str
-    value: str
+    value: JsonValue
     type: Literal["SetAssetStateByName"] = "SetAssetStateByName"
 
 
 class SetAssetStateByUri(BaseModel):
     uri: str
     key: str
-    value: str
+    value: JsonValue
     type: Literal["SetAssetStateByUri"] = "SetAssetStateByUri"
 
 
diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py 
b/task-sdk/src/airflow/sdk/execution_time/context.py
index 14922780da4..cba613da85a 100644
--- a/task-sdk/src/airflow/sdk/execution_time/context.py
+++ b/task-sdk/src/airflow/sdk/execution_time/context.py
@@ -495,12 +495,19 @@ class TaskStateAccessor:
     # is not implemented yet cos it's unclear whether task state values will be
     # used in templates.
 
-    def get(self, key: str) -> str | None:
-        """Return the stored value, or ``None`` if the key does not exist."""
+    def get(self, key: str) -> JsonValue:
+        """
+        Return the stored value, or ``None`` if the key does not exist.
+
+        Supported types: ``str``, ``int``, ``float``, ``bool``, ``list``, 
``dict``.
+        ``datetime`` is not JSON-serializable; store it as 
``value.isoformat()`` and
+        parse it back with ``datetime.fromisoformat(result)``.
+        """
         from airflow.sdk.execution_time.comms import ErrorResponse, 
GetTaskState, TaskStateResult
         from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
 
         resp = SUPERVISOR_COMMS.send(GetTaskState(ti_id=self._ti_id, key=key))
+
         if isinstance(resp, ErrorResponse) and resp.error != 
ErrorType.TASK_STATE_NOT_FOUND:
             raise AirflowRuntimeError(resp)
         if isinstance(resp, TaskStateResult):
@@ -508,10 +515,15 @@ class TaskStateAccessor:
             # if custom backend is configured, the stored value in DB is a 
reference, fetch the actual value from
             # custom backend using the reference
             backend = _get_worker_state_backend()
-            return backend.deserialize_task_state_from_ref(stored) if backend 
else stored
+            if backend is not None:
+                # serialize_task_state_to_ref always returns str by contract; 
stored contains the ref.
+                if TYPE_CHECKING:
+                    assert isinstance(stored, str)
+                return backend.deserialize_task_state_from_ref(stored)
+            return stored
         return None
 
-    def set(self, key: str, value: str, *, retention: timedelta | None = None) 
-> None:
+    def set(self, key: str, value: JsonValue, *, retention: timedelta | None = 
None) -> None:
         """
         Write or overwrite the value for the given key.
 
@@ -614,7 +626,7 @@ class AssetStateAccessor:
             return f"<AssetStateAccessor name={self._name!r}>"
         return f"<AssetStateAccessor uri={self._uri!r}>"
 
-    def get(self, key: str) -> str | None:
+    def get(self, key: str) -> JsonValue:
         """Return the stored value, or ``None`` if the key does not exist."""
         from airflow.sdk.execution_time.comms import (
             AssetStateResult,
@@ -635,13 +647,16 @@ class AssetStateAccessor:
             raise AirflowRuntimeError(resp)
         if isinstance(resp, AssetStateResult):
             stored = resp.value
-            # if custom backend is configured, the stored value in DB is a 
reference, fetch the actual value from
-            # custom backend using the reference
             backend = _get_worker_state_backend()
-            return backend.deserialize_asset_state_from_ref(stored) if backend 
else stored
+            if backend is not None:
+                # serialize_asset_state_to_ref always returns str by contract; 
stored contains the ref.
+                if TYPE_CHECKING:
+                    assert isinstance(stored, str)
+                return backend.deserialize_asset_state_from_ref(stored)
+            return stored
         return None
 
-    def set(self, key: str, value: str) -> None:
+    def set(self, key: str, value: JsonValue) -> None:
         """Write or overwrite the value for the given key."""
         from airflow.sdk.execution_time.comms import SetAssetStateByName, 
SetAssetStateByUri, ToSupervisor
         from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
@@ -756,11 +771,11 @@ class AssetStateAccessors:
             return next(iter(self._by_name.values()))
         return next(iter(self._by_uri.values()))
 
-    def get(self, key: str) -> str | None:
+    def get(self, key: str) -> JsonValue:
         """Return the stored value for the single-inlet task, or ``None`` if 
not found."""
         return self._single_accessor().get(key)
 
-    def set(self, key: str, value: str) -> None:
+    def set(self, key: str, value: JsonValue) -> None:
         """Write or overwrite the value for the single-inlet task."""
         self._single_accessor().set(key, value)
 
diff --git a/task-sdk/src/airflow/sdk/execution_time/schema/schema.json 
b/task-sdk/src/airflow/sdk/execution_time/schema/schema.json
index d4eb3d9c5a8..fec9596e493 100644
--- a/task-sdk/src/airflow/sdk/execution_time/schema/schema.json
+++ b/task-sdk/src/airflow/sdk/execution_time/schema/schema.json
@@ -317,6 +317,9 @@
       "type": "object"
     },
     "AssetStateResult": {
+      "$defs": {
+        "JsonValue": {}
+      },
       "additionalProperties": false,
       "description": "Response to GetAssetState; wraps the generated API 
response for supervisor to worker comms.",
       "properties": {
@@ -327,8 +330,7 @@
           "type": "string"
         },
         "value": {
-          "title": "Value",
-          "type": "string"
+          "$ref": "#/$defs/JsonValue"
         }
       },
       "required": [
@@ -4549,6 +4551,9 @@
       "type": "object"
     },
     "SetAssetStateByName": {
+      "$defs": {
+        "JsonValue": {}
+      },
       "properties": {
         "key": {
           "title": "Key",
@@ -4565,8 +4570,7 @@
           "type": "string"
         },
         "value": {
-          "title": "Value",
-          "type": "string"
+          "$ref": "#/$defs/JsonValue"
         }
       },
       "required": [
@@ -4578,6 +4582,9 @@
       "type": "object"
     },
     "SetAssetStateByUri": {
+      "$defs": {
+        "JsonValue": {}
+      },
       "properties": {
         "key": {
           "title": "Key",
@@ -4594,8 +4601,7 @@
           "type": "string"
         },
         "value": {
-          "title": "Value",
-          "type": "string"
+          "$ref": "#/$defs/JsonValue"
         }
       },
       "required": [
@@ -4653,6 +4659,9 @@
       "type": "object"
     },
     "SetTaskState": {
+      "$defs": {
+        "JsonValue": {}
+      },
       "properties": {
         "expires_at": {
           "anyOf": [
@@ -4682,8 +4691,7 @@
           "type": "string"
         },
         "value": {
-          "title": "Value",
-          "type": "string"
+          "$ref": "#/$defs/JsonValue"
         }
       },
       "required": [
@@ -5773,6 +5781,9 @@
       "type": "object"
     },
     "TaskStateResult": {
+      "$defs": {
+        "JsonValue": {}
+      },
       "additionalProperties": false,
       "description": "Response to GetTaskState; wraps the generated API 
response for supervisor to worker comms.",
       "properties": {
@@ -5783,8 +5794,7 @@
           "type": "string"
         },
         "value": {
-          "title": "Value",
-          "type": "string"
+          "$ref": "#/$defs/JsonValue"
         }
       },
       "required": [
diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py 
b/task-sdk/tests/task_sdk/execution_time/test_context.py
index 1763604e477..35c74278748 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_context.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_context.py
@@ -18,11 +18,13 @@
 from __future__ import annotations
 
 from datetime import datetime, timedelta, timezone as dt_timezone
+from typing import TYPE_CHECKING
 from unittest import mock
 from unittest.mock import MagicMock, patch
 from uuid import UUID
 
 import pytest
+from pydantic import ValidationError
 
 from airflow.sdk import BaseOperator, get_current_context, timezone
 from airflow.sdk._shared.state import TaskScope
@@ -102,6 +104,9 @@ from airflow.sdk.state import BaseStateBackend
 
 from tests_common.test_utils.config import conf_vars
 
+if TYPE_CHECKING:
+    from pydantic import JsonValue
+
 
 def test_convert_connection_result_conn():
     """Test that the ConnectionResult is converted to a Connection object."""
@@ -1210,6 +1215,16 @@ class TestTaskStateAccessor:
             ClearTaskState(ti_id=self.TI_ID, all_map_indices=True)
         )
 
+    def test_set_datetime_raises_validation_error(self, mock_supervisor_comms):
+        """datetime is not JSON-serializable; callers must use .isoformat() 
first."""
+        with pytest.raises(ValidationError):
+            TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).set(
+                "watermark",
+                datetime(2026, 5, 15, tzinfo=dt_timezone.utc),
+            )
+
+        mock_supervisor_comms.send.assert_not_called()
+
 
 class TestAssetStateAccessor:
     ASSET_NAME = "debug_watcher_asset"
@@ -1417,23 +1432,23 @@ class InMemoryStateBackend(BaseStateBackend):
         self._actual_key_value_store: dict[str, str] = {}  # key -> actual 
value
         self.reference: dict[str, str] = {}  # key -> stored ref (mem:// URI)
 
-    def serialize_task_state_to_ref(self, *, value: str, key: str, ti_id: str) 
-> str:
+    def serialize_task_state_to_ref(self, *, value, key: str, ti_id: str) -> 
str:
         ref = f"mem://{ti_id}/{key}"
         self._actual_key_value_store[key] = value
         self.reference[key] = ref
         return ref
 
-    def deserialize_task_state_from_ref(self, stored: str) -> str:
+    def deserialize_task_state_from_ref(self, stored: str) -> JsonValue:
         key = stored.rsplit("/", 1)[-1]
         return self._actual_key_value_store.get(key, stored)
 
-    def serialize_asset_state_to_ref(self, *, value: str, key: str, asset_ref: 
str) -> str:
+    def serialize_asset_state_to_ref(self, *, value, key: str, asset_ref: str) 
-> str:
         ref = f"mem://{asset_ref}/{key}"
         self._actual_key_value_store[key] = value
         self.reference[key] = ref
         return ref
 
-    def deserialize_asset_state_from_ref(self, stored: str) -> str:
+    def deserialize_asset_state_from_ref(self, stored: str) -> JsonValue:
         key = stored.rsplit("/", 1)[-1]
         return self._actual_key_value_store.get(key, stored)
 
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py 
b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 56900fbadab..4ba821e537e 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -5263,6 +5263,40 @@ class TestTaskInstanceStateOperations:
         )
         
mock_supervisor_comms.send.assert_any_call(GetTaskState(ti_id=runtime_ti.id, 
key="job_id"))
 
+    def test_task_state_set_sends_typed_values(self, create_runtime_ti, 
mock_supervisor_comms, time_machine):
+        """set() accepts any JsonValue — dict, int, list — not just strings."""
+
+        class MyOperator(BaseOperator):
+            def execute(self, context):
+                ts = context["task_state"]
+                ts.set("retry_count", 3)
+                ts.set("poll_result", {"status": "succeeded", "rows": 1234})
+                ts.set("checkpoints", [1, 2, 3])
+
+        frozen_dt = datetime(2026, 1, 1, 12, 0, 0, tzinfo=dt_timezone.utc)
+        time_machine.move_to(frozen_dt, tick=False)
+        task = MyOperator(task_id="t")
+        runtime_ti = create_runtime_ti(task=task)
+
+        with conf_vars({("state_store", "default_retention_days"): "30"}):
+            run(runtime_ti, context=runtime_ti.get_template_context(), 
log=mock.MagicMock())
+
+        expires_at = frozen_dt + timedelta(days=30)
+        mock_supervisor_comms.send.assert_any_call(
+            SetTaskState(ti_id=runtime_ti.id, key="retry_count", value=3, 
expires_at=expires_at)
+        )
+        mock_supervisor_comms.send.assert_any_call(
+            SetTaskState(
+                ti_id=runtime_ti.id,
+                key="poll_result",
+                value={"status": "succeeded", "rows": 1234},
+                expires_at=expires_at,
+            )
+        )
+        mock_supervisor_comms.send.assert_any_call(
+            SetTaskState(ti_id=runtime_ti.id, key="checkpoints", value=[1, 2, 
3], expires_at=expires_at)
+        )
+
     def test_task_can_set_state_with_retention(self, create_runtime_ti, 
mock_supervisor_comms, time_machine):
         class MyOperator(BaseOperator):
             def execute(self, context):
diff --git a/uv.lock b/uv.lock
index c1e290951ca..8cdd633364c 100644
--- a/uv.lock
+++ b/uv.lock
@@ -8528,6 +8528,9 @@ mypy = [{ name = "apache-airflow-devel-common", extras = 
["mypy"], editable = "d
 name = "apache-airflow-shared-state"
 version = "0.0"
 source = { editable = "shared/state" }
+dependencies = [
+    { name = "pydantic" },
+]
 
 [package.dev-dependencies]
 dev = [
@@ -8538,6 +8541,7 @@ mypy = [
 ]
 
 [package.metadata]
+requires-dist = [{ name = "pydantic", specifier = ">=2.11.0" }]
 
 [package.metadata.requires-dev]
 dev = [{ name = "apache-airflow-devel-common", editable = "devel-common" }]


Reply via email to