kaxil commented on code in PR #66859:
URL: https://github.com/apache/airflow/pull/66859#discussion_r3270160478
##########
shared/state/tests/state/test_state.py:
##########
@@ -70,3 +70,71 @@ def test_abstract_methods_cover_full_interface(self):
"""BaseStateBackend enforces all 8 sync+async methods as abstract."""
expected = {"get", "set", "delete", "clear", "aget", "aset",
"adelete", "aclear"}
assert BaseStateBackend.__abstractmethods__ == expected
+
+ def test_task_state_serialize_deserialize_round_trip(self, backend):
+ original = "app_1234"
+ serialized = backend.serialize_task_state_to_ref(value=original,
key="job_id", ti_id="abc-123")
+ deserialized = backend.deserialize_task_state_from_ref(serialized)
+ assert deserialized == original
+
+ def test_custom_backend_overrides_task_state_ser_deser(self):
+ class MyBackend(BaseStateBackend):
+ def get(self, scope, key): ...
+ def set(self, scope, key, value): ...
+ def delete(self, scope, key): ...
+ def clear(self, scope, *, all_map_indices=False): ...
+ async def aget(self, scope, key): ...
+ async def aset(self, scope, key, value): ...
+ async def adelete(self, scope, key): ...
+ async def aclear(self, scope, *, all_map_indices=False): ...
+
+ def serialize_task_state_value(self, *, value, key, ti_id):
+ return f"s3://bucket/{ti_id}/{key}"
+
+ def deserialize_task_state_value(self, stored):
+ return f"fetched:{stored}"
+
+ b = MyBackend()
+ assert (
+ b.serialize_task_state_value(value="app_1234", key="job_id",
ti_id="abc-123")
+ == "s3://bucket/abc-123/job_id"
+ )
+ assert (
+ b.deserialize_task_state_value("s3://bucket/abc-123/job_id")
+ == "fetched:s3://bucket/abc-123/job_id"
+ )
+
+ def test_asset_state_serialize_deserialize_round_trip(self, backend):
+ original = "2026-05-01"
+ serialized = backend.serialize_asset_state_to_ref(
+ value="2026-05-01", key="watermark", asset_ref="my_asset"
+ )
+ deserialized = backend.deserialize_asset_state_from_ref(serialized)
+ assert deserialized == original
+
+ def test_custom_backend_overrides_asset_state_ser_deser(self):
+ class MyBackend(BaseStateBackend):
+ def get(self, scope, key): ...
+ def set(self, scope, key, value): ...
+ def delete(self, scope, key): ...
+ def clear(self, scope, *, all_map_indices=False): ...
+ async def aget(self, scope, key): ...
+ async def aset(self, scope, key, value): ...
+ async def adelete(self, scope, key): ...
+ async def aclear(self, scope, *, all_map_indices=False): ...
+
+ def serialize_asset_state_value(self, *, value, key, asset_ref):
+ return f"s3://bucket/assets/{asset_ref}/{key}"
Review Comment:
Same naming mismatch for asset state: `serialize_asset_state_value` /
`deserialize_asset_state_value` don't match the real hooks
`serialize_asset_state_to_ref` / `deserialize_asset_state_from_ref`. Because the
test only calls the user-defined methods directly, it never verifies that the
framework actually invokes them.
##########
shared/state/tests/state/test_state.py:
##########
@@ -70,3 +70,71 @@ def test_abstract_methods_cover_full_interface(self):
"""BaseStateBackend enforces all 8 sync+async methods as abstract."""
expected = {"get", "set", "delete", "clear", "aget", "aset",
"adelete", "aclear"}
assert BaseStateBackend.__abstractmethods__ == expected
+
+ def test_task_state_serialize_deserialize_round_trip(self, backend):
+ original = "app_1234"
+ serialized = backend.serialize_task_state_to_ref(value=original,
key="job_id", ti_id="abc-123")
+ deserialized = backend.deserialize_task_state_from_ref(serialized)
+ assert deserialized == original
+
+ def test_custom_backend_overrides_task_state_ser_deser(self):
+ class MyBackend(BaseStateBackend):
+ def get(self, scope, key): ...
+ def set(self, scope, key, value): ...
+ def delete(self, scope, key): ...
+ def clear(self, scope, *, all_map_indices=False): ...
+ async def aget(self, scope, key): ...
+ async def aset(self, scope, key, value): ...
+ async def adelete(self, scope, key): ...
+ async def aclear(self, scope, *, all_map_indices=False): ...
+
+ def serialize_task_state_value(self, *, value, key, ti_id):
Review Comment:
These overrides use the wrong method names. The base class defines
`serialize_task_state_to_ref` / `deserialize_task_state_from_ref`, but here you
override `serialize_task_state_value` / `deserialize_task_state_value`. The
test passes because it invokes the overrides directly, but it doesn't verify
that these are the hooks the framework actually calls. The same wrong names
appear in the PR description's example backend, so users will hit this too.
Rename them to match the real API, and ideally route the assertion through
`TaskStateAccessor` so the test catches future regressions.
##########
task-sdk/src/airflow/sdk/execution_time/task_runner.py:
##########
@@ -1455,6 +1465,10 @@ def _handle_current_task_success(
if conf.getboolean("state_store", "clear_on_success"):
log.info("Task state will be cleared by the server because clear_on_success is enabled.")
+ if _get_worker_state_backend() is not None:
+ # clear the task state keys for custom state backends configured on worker side
+ context["task_state"].clear()
Review Comment:
`clear_on_success` only clears task state via the worker backend, not asset
state. If a custom backend stored asset state externally (e.g. as S3 objects
written via `context["asset_state"].set(...)`), those objects are never cleaned
up on task success, even though the DB rows may be cleared elsewhere. Should
this walk `context.get("asset_state")` and call `.clear()` on each accessor too?
##########
airflow-core/src/airflow/config_templates/config.yml:
##########
@@ -1889,6 +1889,16 @@ workers:
sensitive: true
example: ~
default: ""
+ state_backend:
+ description: |
+ Full class name of the state backend to use on workers for direct task state access,
Review Comment:
"bypassing the execution API" isn't quite right. The Execution API is still
on the hot path: `set` still sends `SetTaskState` and `get` still calls
`GetTaskState`. What changes is that the DB now stores a reference instead of
the raw value, and the worker handles storage of the actual value.
Suggested wording: "Routes task state values through a custom worker-side
backend so large payloads or credentialed storage stay on worker
infrastructure. The Execution API still records a reference string."
##########
task-sdk/src/airflow/sdk/execution_time/task_runner.py:
##########
@@ -1455,6 +1465,10 @@ def _handle_current_task_success(
if conf.getboolean("state_store", "clear_on_success"):
log.info("Task state will be cleared by the server because clear_on_success is enabled.")
+ if _get_worker_state_backend() is not None:
+ # clear the task state keys for custom state backends configured on worker side
+ context["task_state"].clear()
Review Comment:
`context["task_state"].clear()` also sends a `ClearTaskState` comms message,
but the server clears task state anyway as part of `SucceedTask`'s
`clear_on_success` handling (see
`api_fastapi/execution_api/routes/task_instances.py:469`). That's a redundant
round trip per task. It would be worth adding a "backend-only clear" path that
skips the comms message when we know the server will handle the DB side.
##########
shared/state/src/airflow_shared/state/__init__.py:
##########
@@ -47,9 +47,21 @@ class TaskScope:
@dataclass(frozen=True)
class AssetScope:
- """Identifies the state namespace for an asset."""
+ """
+ Identifies the state namespace for an asset.
+
+ Server-side backends receive ``asset_id``. Worker-side backends receive ``name`` or ``uri``
+ since workers do not have access to the integer ``asset_id``.
+
+ Note: ``name`` and ``uri`` are not guaranteed to be unique over time — if an asset is
+ deactivated and a new one created with the same name, both share the same ``name`` value.
+ State for inactive assets is cleaned up by the orphan GC pass; until then, stale rows exist
+ in the DB but cannot be written to (the Execution API resolver filters to active assets only).
+ """
- asset_id: int
+ asset_id: int | None = None
Review Comment:
Making `asset_id` optional with a `None` default is a silent failure trap. A
server-side backend that receives `AssetScope(name=..., uri=...,
asset_id=None)` (or a worker-only backend that forgets `name`/`uri` dispatch
and falls through to the `asset_id` branch) will run `AssetStateModel.asset_id
== None` and get zero rows with no error. Consider validating in
`__post_init__` that the scope has either `asset_id` or `name`/`uri` set
(depending on which construction path was used), so a malformed scope raises
instead of silently matching nothing.
##########
task-sdk/src/airflow/sdk/execution_time/context.py:
##########
@@ -495,13 +519,26 @@ def set(self, key: str, value: str, *, retention:
timedelta | None = None) -> No
else:
days = conf.getint("state_store", "default_retention_days")
expires_at = None if days <= 0 else now + timedelta(days=days)
- SUPERVISOR_COMMS.send(SetTaskState(ti_id=self._ti_id, key=key,
value=value, expires_at=expires_at))
+
+ # if custom backend is configured, store the value on the custom backend, and return the reference
+ # to the stored value to store in the DB
+ backend = _get_worker_state_backend()
+ stored = (
+ backend.serialize_task_state_to_ref(value=value, key=key,
ti_id=str(self._ti_id))
+ if backend
+ else value
+ )
+
+ SUPERVISOR_COMMS.send(SetTaskState(ti_id=self._ti_id, key=key,
value=stored, expires_at=expires_at))
def delete(self, key: str) -> None:
"""Delete a single key. No-op if the key does not exist."""
from airflow.sdk.execution_time.comms import DeleteTaskState
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+ backend = _get_worker_state_backend()
+ if backend is not None:
Review Comment:
Backend delete runs before the DB ref is cleared. If
`SUPERVISOR_COMMS.send(DeleteTaskState(...))` fails, the external object is
gone but the DB still has the reference; the next `get()` fetches a dangling
ref and calls `deserialize_task_state_from_ref` on it. Same ordering concern in
`clear()` here and in `AssetStateAccessor.delete` / `.clear()`. Reversing the
order (DB ref first, then backend) makes the failure mode "DB cleared, external
orphaned" instead, which the deterministic-ref docstring already covers via
overwrite-on-next-set.
##########
task-sdk/src/airflow/sdk/execution_time/context.py:
##########
@@ -598,9 +659,21 @@ def delete(self, key: str) -> None:
def clear(self) -> None:
"""Delete all state keys for this asset."""
- from airflow.sdk.execution_time.comms import ClearAssetStateByName,
ClearAssetStateByUri, ToSupervisor
+ from airflow.sdk._shared.state import AssetScope
+ from airflow.sdk.execution_time.comms import (
+ ClearAssetStateByName,
+ ClearAssetStateByUri,
+ ToSupervisor,
+ )
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+ backend = _get_worker_state_backend()
+ # custom backends handle external storage cleanup only;
Review Comment:
These two comment lines say the same thing. Drop one.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]