This is an automated email from the ASF dual-hosted git repository.
amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new ec2d56a473d AIP-103: Worker side custom state backend support (#66859)
ec2d56a473d is described below
commit ec2d56a473d6a919814788f7a0e3fff28b33c615
Author: Amogh Desai <[email protected]>
AuthorDate: Wed May 20 16:06:38 2026 +0530
AIP-103: Worker side custom state backend support (#66859)
---
.../src/airflow/config_templates/config.yml | 11 ++
shared/state/src/airflow_shared/state/__init__.py | 71 ++++++-
shared/state/tests/state/test_state.py | 83 ++++++++-
task-sdk/pyproject.toml | 3 +
task-sdk/src/airflow/sdk/_shared/state | 1 +
task-sdk/src/airflow/sdk/api/client.py | 3 +-
task-sdk/src/airflow/sdk/configuration.py | 3 +-
task-sdk/src/airflow/sdk/execution_time/context.py | 97 +++++++++-
.../src/airflow/sdk/execution_time/task_runner.py | 13 +-
task-sdk/src/airflow/sdk/state.py | 25 +++
task-sdk/tests/task_sdk/docs/test_public_api.py | 1 +
.../tests/task_sdk/execution_time/test_context.py | 207 +++++++++++++++++++--
.../task_sdk/execution_time/test_task_runner.py | 98 +++++++++-
13 files changed, 588 insertions(+), 28 deletions(-)
diff --git a/airflow-core/src/airflow/config_templates/config.yml
b/airflow-core/src/airflow/config_templates/config.yml
index f93642dd8ed..0dfbbbda6c4 100644
--- a/airflow-core/src/airflow/config_templates/config.yml
+++ b/airflow-core/src/airflow/config_templates/config.yml
@@ -1899,6 +1899,17 @@ workers:
sensitive: true
example: ~
default: ""
+ state_backend:
+ description: |
+ Full class name of a custom worker-side state backend. When set, task
state values are
+ routed through this backend so large payloads or credentialed storage
stay on worker
+ infrastructure. The Execution API still records a reference string in
the database.
+
+ Leave empty (default) to use the standard path through the Task SDK
supervisor.
+ version_added: 3.3.0
+ type: string
+ example: "mypackage.state.S3StateBackend"
+ default: ""
min_heartbeat_interval:
description: |
The minimum interval (in seconds) at which the worker checks the task
instance's
diff --git a/shared/state/src/airflow_shared/state/__init__.py
b/shared/state/src/airflow_shared/state/__init__.py
index bfce8db3328..7aa9fcba837 100644
--- a/shared/state/src/airflow_shared/state/__init__.py
+++ b/shared/state/src/airflow_shared/state/__init__.py
@@ -47,9 +47,25 @@ class TaskScope:
@dataclass(frozen=True)
class AssetScope:
- """Identifies the state namespace for an asset."""
+ """
+ Identifies the state namespace for an asset.
+
+ Server-side backends receive ``asset_id``. Worker-side backends receive
``name`` or ``uri``
+ since workers do not have access to the integer ``asset_id``.
+
+ Note: ``name`` and ``uri`` are not guaranteed to be unique over time — if
an asset is
+ deactivated and a new one created with the same name, both share the same
``name`` value.
+ State for inactive assets is cleaned up by the orphan GC pass; until then,
stale rows exist
+ in the DB but cannot be written to (the Execution API resolver filters to
active assets only).
+ """
+
+ asset_id: int | None = None
+ name: str | None = None
+ uri: str | None = None
- asset_id: int
+ def __post_init__(self) -> None:
+ if self.asset_id is None and self.name is None and self.uri is None:
+ raise ValueError("AssetScope requires at least one of: asset_id,
name, or uri")
StateScope = TaskScope | AssetScope
@@ -186,3 +202,54 @@ class BaseStateBackend(ABC):
retention policy. The backend is responsible for reading any relevant
config (e.g.
``[state_store] default_retention_days``) and deciding what to delete.
"""
+
+ def serialize_task_state_to_ref(self, *, value: str, key: str, ti_id: str)
-> str:
+ """
+ Serialize a task state value before it is sent to the execution API
for db persistence.
+
+ Called by ``TaskStateAccessor.set()`` on the worker. The return value
is what gets
+ stored in the DB — typically a reference path (e.g. an S3 key) rather
than the
+ actual value. Default: return ``value`` unchanged.
+
+ The returned reference must be deterministic — given the same
``ti_id`` and ``key`` it
+ must always return the same string. Do not use timestamps or random
UUIDs as part of
+ the reference, otherwise ``delete()``/``clear()`` cannot reconstruct
it and the external
+ object will be orphaned.
+ """
+ return value
+
+ def deserialize_task_state_from_ref(self, stored: str) -> str:
+ """
+ Resolve a stored task state string back to the actual value.
+
+ Called by ``TaskStateAccessor.get()`` after the stored string is
retrieved from
+ the execution API. Default: return ``stored`` unchanged.
+ """
+ return stored
+
+ def serialize_asset_state_to_ref(self, *, value: str, key: str, asset_ref:
str) -> str:
+ """
+ Serialize an asset state value before it is sent to the Execution API
for db persistence.
+
+ Called by ``AssetStateAccessor.set()`` on the worker. The return value
is what gets
+ stored in the DB — typically a reference path rather than the actual
value.
+ Default: return ``value`` unchanged.
+
+ ``asset_ref`` is either the asset name or URI, depending on how the
accessor was
+ constructed. It may be a URI string if the task inlet was declared as
``AssetUriRef``.
+
+ The returned reference must be deterministic — given the same
``asset_ref`` and ``key`` it
+ must always return the same string. Do not use timestamps or random
UUIDs as part of
+ the reference, otherwise ``delete()``/``clear()`` cannot reconstruct
it and the external
+ object will be orphaned.
+ """
+ return value
+
+ def deserialize_asset_state_from_ref(self, stored: str) -> str:
+ """
+ Resolve a stored asset state string back to the actual value.
+
+ Called by ``AssetStateAccessor.get()`` after the stored string is
retrieved from
+ the Execution API. Default: return ``stored`` unchanged.
+ """
+ return stored
diff --git a/shared/state/tests/state/test_state.py
b/shared/state/tests/state/test_state.py
index 47bce18a69e..1ea31194e27 100644
--- a/shared/state/tests/state/test_state.py
+++ b/shared/state/tests/state/test_state.py
@@ -18,7 +18,22 @@ from __future__ import annotations
import pytest
-from airflow_shared.state import BaseStateBackend, StateScope
+from airflow_shared.state import AssetScope, BaseStateBackend, StateScope
+
+
+class TestAssetScope:
+ def test_requires_at_least_one_identifier(self):
+ with pytest.raises(ValueError, match="at least one of"):
+ AssetScope()
+
+ def test_asset_id_alone_is_valid(self):
+ AssetScope(asset_id=1)
+
+ def test_name_alone_is_valid(self):
+ AssetScope(name="my_asset")
+
+ def test_uri_alone_is_valid(self):
+ AssetScope(uri="s3://bucket/key")
class TestBaseStateBackend:
@@ -70,3 +85,69 @@ class TestBaseStateBackend:
"""BaseStateBackend enforces all 8 sync+async methods as abstract."""
expected = {"get", "set", "delete", "clear", "aget", "aset",
"adelete", "aclear"}
assert BaseStateBackend.__abstractmethods__ == expected
+
+ def test_task_state_serialize_deserialize_round_trip(self, backend):
+ original = "app_1234"
+ serialized = backend.serialize_task_state_to_ref(value=original,
key="job_id", ti_id="abc-123")
+ deserialized = backend.deserialize_task_state_from_ref(serialized)
+ assert deserialized == original
+
+ def test_custom_backend_overrides_task_state_ser_deser(self):
+ class MyBackend(BaseStateBackend):
+ def get(self, scope, key): ...
+ def set(self, scope, key, value): ...
+ def delete(self, scope, key): ...
+ def clear(self, scope, *, all_map_indices=False): ...
+ async def aget(self, scope, key): ...
+ async def aset(self, scope, key, value): ...
+ async def adelete(self, scope, key): ...
+ async def aclear(self, scope, *, all_map_indices=False): ...
+
+ def serialize_task_state_to_ref(self, *, value, key, ti_id):
+ return f"s3://bucket/{ti_id}/{key}"
+
+ def deserialize_task_state_from_ref(self, stored):
+ return f"fetched:{stored}"
+
+ b = MyBackend()
+ assert b.serialize_task_state_to_ref(value="app_1234", key="job_id",
ti_id="abc-123") == (
+ "s3://bucket/abc-123/job_id"
+ )
+ assert (
+ b.deserialize_task_state_from_ref("s3://bucket/abc-123/job_id")
+ == "fetched:s3://bucket/abc-123/job_id"
+ )
+
+ def test_asset_state_serialize_deserialize_round_trip(self, backend):
+ original = "2026-05-01"
+ serialized = backend.serialize_asset_state_to_ref(
+ value="2026-05-01", key="watermark", asset_ref="my_asset"
+ )
+ deserialized = backend.deserialize_asset_state_from_ref(serialized)
+ assert deserialized == original
+
+ def test_custom_backend_overrides_asset_state_ser_deser(self):
+ class MyBackend(BaseStateBackend):
+ def get(self, scope, key): ...
+ def set(self, scope, key, value): ...
+ def delete(self, scope, key): ...
+ def clear(self, scope, *, all_map_indices=False): ...
+ async def aget(self, scope, key): ...
+ async def aset(self, scope, key, value): ...
+ async def adelete(self, scope, key): ...
+ async def aclear(self, scope, *, all_map_indices=False): ...
+
+ def serialize_asset_state_to_ref(self, *, value, key, asset_ref):
+ return f"s3://bucket/assets/{asset_ref}/{key}"
+
+ def deserialize_asset_state_from_ref(self, stored):
+ return f"resolved:{stored}"
+
+ b = MyBackend()
+ assert b.serialize_asset_state_to_ref(value="2026-05-01",
key="watermark", asset_ref="my_asset") == (
+ "s3://bucket/assets/my_asset/watermark"
+ )
+ assert (
+
b.deserialize_asset_state_from_ref("s3://bucket/assets/my_asset/watermark")
+ == "resolved:s3://bucket/assets/my_asset/watermark"
+ )
diff --git a/task-sdk/pyproject.toml b/task-sdk/pyproject.toml
index 89a17fb52c3..4fc9f3d586a 100644
--- a/task-sdk/pyproject.toml
+++ b/task-sdk/pyproject.toml
@@ -147,6 +147,7 @@ path = "src/airflow/sdk/__init__.py"
"../shared/listeners/src/airflow_shared/listeners" =
"src/airflow/sdk/_shared/listeners"
"../shared/plugins_manager/src/airflow_shared/plugins_manager" =
"src/airflow/sdk/_shared/plugins_manager"
"../shared/providers_discovery/src/airflow_shared/providers_discovery" =
"src/airflow/sdk/_shared/providers_discovery"
+"../shared/state/src/airflow_shared/state" = "src/airflow/sdk/_shared/state"
"../shared/template_rendering/src/airflow_shared/template_rendering" =
"src/airflow/sdk/_shared/template_rendering"
[tool.hatch.build.targets.wheel]
@@ -240,6 +241,7 @@ apache-airflow = {workspace = true}
apache-airflow-devel-common = {workspace = true}
apache-airflow-providers-common-sql = {workspace = true}
apache-airflow-providers-standard = {workspace = true}
+apache-airflow-shared-state = {workspace = true}
# To use:
#
@@ -316,6 +318,7 @@ shared_distributions = [
"apache-airflow-shared-secrets-backend",
"apache-airflow-shared-secrets-masker",
"apache-airflow-shared-serialization",
+ "apache-airflow-shared-state",
"apache-airflow-shared-timezones",
"apache-airflow-shared-observability",
"apache-airflow-shared-plugins-manager",
diff --git a/task-sdk/src/airflow/sdk/_shared/state
b/task-sdk/src/airflow/sdk/_shared/state
new file mode 120000
index 00000000000..752da632206
--- /dev/null
+++ b/task-sdk/src/airflow/sdk/_shared/state
@@ -0,0 +1 @@
+../../../../../shared/state/src/airflow_shared/state
\ No newline at end of file
diff --git a/task-sdk/src/airflow/sdk/api/client.py
b/task-sdk/src/airflow/sdk/api/client.py
index 4c8c565fe43..99b1aadb37f 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -93,6 +93,7 @@ from airflow.sdk.execution_time.comms import (
OKResponse,
PreviousDagRunResult,
PreviousTIResult,
+ RescheduleTask,
SkipDownstreamTasks,
TaskRescheduleStartDate,
TICount,
@@ -104,8 +105,6 @@ if TYPE_CHECKING:
from datetime import datetime
from typing import ParamSpec
- from airflow.sdk.execution_time.comms import RescheduleTask
-
P = ParamSpec("P")
T = TypeVar("T")
diff --git a/task-sdk/src/airflow/sdk/configuration.py
b/task-sdk/src/airflow/sdk/configuration.py
index 4e438a2cbf7..fb32f990c58 100644
--- a/task-sdk/src/airflow/sdk/configuration.py
+++ b/task-sdk/src/airflow/sdk/configuration.py
@@ -32,6 +32,7 @@ from airflow.sdk._shared.configuration.parser import (
configure_parser_from_configuration_description,
expand_env_var,
)
+from airflow.sdk._shared.module_loading import import_string
from airflow.sdk.execution_time.secrets import
_SERVER_DEFAULT_SECRETS_SEARCH_PATH
log = logging.getLogger(__name__)
@@ -236,8 +237,6 @@ def initialize_secrets_backends(
Uses SDK's conf instead of Core's conf.
"""
- from airflow.sdk._shared.module_loading import import_string
-
backend_list = []
worker_mode = False
# Determine worker mode - if default_backends is not the server default,
it's worker mode
diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py
b/task-sdk/src/airflow/sdk/execution_time/context.py
index 1e6874121fc..14922780da4 100644
--- a/task-sdk/src/airflow/sdk/execution_time/context.py
+++ b/task-sdk/src/airflow/sdk/execution_time/context.py
@@ -56,6 +56,7 @@ if TYPE_CHECKING:
from typing_extensions import Self
from airflow.sdk import Variable
+ from airflow.sdk._shared.state import TaskScope
from airflow.sdk.bases.operator import BaseOperator
from airflow.sdk.definitions.connection import Connection
from airflow.sdk.definitions.context import Context
@@ -70,6 +71,7 @@ if TYPE_CHECKING:
ReceiveMsgType,
VariableResult,
)
+ from airflow.sdk.state import BaseStateBackend
from airflow.sdk.types import OutletEventAccessorsProtocol
@@ -454,11 +456,29 @@ class VariableAccessor:
raise
+@cache
+def _get_worker_state_backend() -> BaseStateBackend | None:
+ """Return the configured worker-side state backend, instantiated once and
cached."""
+ class_name = conf.get("workers", "state_backend", fallback="")
+ if not class_name:
+ return None
+ from airflow.sdk._shared.module_loading import import_string
+
+ try:
+ return import_string(class_name)()
+ except (ImportError, AttributeError) as e:
+ raise ValueError(
+ f"Could not load worker state backend {class_name!r}. "
+ f"Check the [workers] state_backend config value. Error: {e}"
+ ) from e
+
+
class TaskStateAccessor:
"""Accessor for task state scoped to the current task instance. Available
as ``context['task_state']`` at task execution time."""
- def __init__(self, ti_id: UUID) -> None:
+ def __init__(self, ti_id: UUID, scope: TaskScope) -> None:
self._ti_id = ti_id
+ self._scope = scope
def __eq__(self, other: object) -> bool:
if not isinstance(other, TaskStateAccessor):
@@ -484,7 +504,11 @@ class TaskStateAccessor:
if isinstance(resp, ErrorResponse) and resp.error !=
ErrorType.TASK_STATE_NOT_FOUND:
raise AirflowRuntimeError(resp)
if isinstance(resp, TaskStateResult):
- return resp.value
+ stored = resp.value
+ # If a custom backend is configured, the value stored in the DB is a
reference; resolve it to the actual value
+ # through the custom backend.
+ backend = _get_worker_state_backend()
+ return backend.deserialize_task_state_from_ref(stored) if backend
else stored
return None
def set(self, key: str, value: str, *, retention: timedelta | None = None)
-> None:
@@ -509,14 +533,29 @@ class TaskStateAccessor:
else:
days = conf.getint("state_store", "default_retention_days")
expires_at = None if days <= 0 else now + timedelta(days=days)
- SUPERVISOR_COMMS.send(SetTaskState(ti_id=self._ti_id, key=key,
value=value, expires_at=expires_at))
+
+ # If a custom backend is configured, store the value in the custom
backend and send only the returned
+ # reference to be stored in the DB.
+ backend = _get_worker_state_backend()
+ stored = (
+ backend.serialize_task_state_to_ref(value=value, key=key,
ti_id=str(self._ti_id))
+ if backend
+ else value
+ )
+
+ SUPERVISOR_COMMS.send(SetTaskState(ti_id=self._ti_id, key=key,
value=stored, expires_at=expires_at))
def delete(self, key: str) -> None:
"""Delete a single key. No-op if the key does not exist."""
from airflow.sdk.execution_time.comms import DeleteTaskState
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+ # Clean up the DB ref first: if backend cleanup fails after this, the
ref is gone and
+ # deterministic keys are recoverable on the next set().
SUPERVISOR_COMMS.send(DeleteTaskState(ti_id=self._ti_id, key=key))
+ backend = _get_worker_state_backend()
+ if backend is not None:
+ backend.delete(self._scope, key)
def clear(self, all_map_indices: bool = False) -> None:
"""
@@ -529,7 +568,23 @@ class TaskStateAccessor:
from airflow.sdk.execution_time.comms import ClearTaskState
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+ # Clean up the DB ref first: if backend cleanup fails after this, the
ref is gone and
+ # deterministic keys are recoverable on the next set().
SUPERVISOR_COMMS.send(ClearTaskState(ti_id=self._ti_id,
all_map_indices=all_map_indices))
+ backend = _get_worker_state_backend()
+ if backend is not None:
+ backend.clear(self._scope, all_map_indices=all_map_indices)
+
+ def _clear_backend_only(self) -> None:
+ """
+ Clear external storage via the worker backend without sending a comms
message.
+
+ Used by clear_on_success: the server already clears DB rows as part of
SucceedTask,
+ so the comms round-trip is redundant.
+ """
+ backend = _get_worker_state_backend()
+ if backend is not None:
+ backend.clear(self._scope)
class AssetStateAccessor:
@@ -579,7 +634,11 @@ class AssetStateAccessor:
if isinstance(resp, ErrorResponse) and resp.error !=
ErrorType.ASSET_STATE_NOT_FOUND:
raise AirflowRuntimeError(resp)
if isinstance(resp, AssetStateResult):
- return resp.value
+ stored = resp.value
+ # If a custom backend is configured, the value stored in the DB is a
reference; resolve it to the actual value
+ # through the custom backend.
+ backend = _get_worker_state_backend()
+ return backend.deserialize_asset_state_from_ref(stored) if backend
else stored
return None
def set(self, key: str, value: str) -> None:
@@ -587,15 +646,26 @@ class AssetStateAccessor:
from airflow.sdk.execution_time.comms import SetAssetStateByName,
SetAssetStateByUri, ToSupervisor
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+ # If a custom backend is configured, store the value in the custom
backend and send only the returned
+ # reference to be stored in the DB.
+ backend = _get_worker_state_backend()
+ asset_ref = self._name or self._uri or ""
+ stored = (
+ backend.serialize_asset_state_to_ref(value=value, key=key,
asset_ref=asset_ref)
+ if backend
+ else value
+ )
+
msg: ToSupervisor
if self._name:
- msg = SetAssetStateByName(name=self._name, key=key, value=value)
+ msg = SetAssetStateByName(name=self._name, key=key, value=stored)
elif self._uri:
- msg = SetAssetStateByUri(uri=self._uri, key=key, value=value)
+ msg = SetAssetStateByUri(uri=self._uri, key=key, value=stored)
SUPERVISOR_COMMS.send(msg)
def delete(self, key: str) -> None:
"""Delete a single key. No-op if the key does not exist."""
+ from airflow.sdk._shared.state import AssetScope
from airflow.sdk.execution_time.comms import (
DeleteAssetStateByName,
DeleteAssetStateByUri,
@@ -608,11 +678,21 @@ class AssetStateAccessor:
msg = DeleteAssetStateByName(name=self._name, key=key)
elif self._uri:
msg = DeleteAssetStateByUri(uri=self._uri, key=key)
+ # DB ref first: if backend cleanup fails after this, the ref is gone
and
+ # deterministic keys are recoverable on next set().
SUPERVISOR_COMMS.send(msg)
+ backend = _get_worker_state_backend()
+ if backend is not None:
+ backend.delete(AssetScope(name=self._name, uri=self._uri), key)
def clear(self) -> None:
"""Delete all state keys for this asset."""
- from airflow.sdk.execution_time.comms import ClearAssetStateByName,
ClearAssetStateByUri, ToSupervisor
+ from airflow.sdk._shared.state import AssetScope
+ from airflow.sdk.execution_time.comms import (
+ ClearAssetStateByName,
+ ClearAssetStateByUri,
+ ToSupervisor,
+ )
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
msg: ToSupervisor
@@ -621,6 +701,9 @@ class AssetStateAccessor:
elif self._uri:
msg = ClearAssetStateByUri(uri=self._uri)
SUPERVISOR_COMMS.send(msg)
+ backend = _get_worker_state_backend()
+ if backend is not None:
+ backend.clear(AssetScope(name=self._name, uri=self._uri))
class AssetStateAccessors:
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 852b5da9cc6..10977fb011b 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -134,6 +134,7 @@ from airflow.sdk.execution_time.sentry import Sentry
from airflow.sdk.execution_time.xcom import XCom
from airflow.sdk.listener import get_listener_manager
from airflow.sdk.observability.metrics import stats_utils
+from airflow.sdk.state import TaskScope
from airflow.sdk.timezone import coerce_datetime
if TYPE_CHECKING:
@@ -260,7 +261,15 @@ class RuntimeTaskInstance(TaskInstance):
"value": VariableAccessor(deserialize_json=False),
},
"conn": ConnectionAccessor(),
- "task_state": TaskStateAccessor(ti_id=self.id),
+ "task_state": TaskStateAccessor(
+ ti_id=self.id,
+ scope=TaskScope(
+ dag_id=self.dag_id,
+ run_id=self.run_id,
+ task_id=self.task_id,
+ map_index=self.map_index if self.map_index is not None
else -1,
+ ),
+ ),
}
if any(isinstance(i, (Asset, AssetNameRef, AssetUriRef,
AssetAlias)) for i in self.task.inlets):
self._cached_template_context["asset_state"] =
AssetStateAccessors(self.task.inlets)
@@ -1492,6 +1501,8 @@ def _handle_current_task_success(
if conf.getboolean("state_store", "clear_on_success"):
log.info("Task state will be cleared by the server because
clear_on_success is enabled.")
+ context["task_state"]._clear_backend_only()
+
msg = SucceedTask(
end_date=end_date,
task_outlets=task_outlets,
diff --git a/task-sdk/src/airflow/sdk/state.py
b/task-sdk/src/airflow/sdk/state.py
new file mode 100644
index 00000000000..ac2a1126fe4
--- /dev/null
+++ b/task-sdk/src/airflow/sdk/state.py
@@ -0,0 +1,25 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from airflow.sdk._shared.state import (
+ AssetScope as AssetScope,
+ BaseStateBackend as BaseStateBackend,
+ TaskScope as TaskScope,
+)
diff --git a/task-sdk/tests/task_sdk/docs/test_public_api.py
b/task-sdk/tests/task_sdk/docs/test_public_api.py
index 98391927f8a..a21424ea101 100644
--- a/task-sdk/tests/task_sdk/docs/test_public_api.py
+++ b/task-sdk/tests/task_sdk/docs/test_public_api.py
@@ -65,6 +65,7 @@ def test_airflow_sdk_no_unexpected_exports():
"providers_manager_runtime",
"lineage",
"types",
+ "state",
}
unexpected = actual - public - ignore
assert not unexpected, f"Unexpected exports in airflow.sdk:
{sorted(unexpected)}"
diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py
b/task-sdk/tests/task_sdk/execution_time/test_context.py
index 062645d25b1..1763604e477 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_context.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_context.py
@@ -25,7 +25,12 @@ from uuid import UUID
import pytest
from airflow.sdk import BaseOperator, get_current_context, timezone
-from airflow.sdk.api.datamodels._generated import AssetEventResponse,
AssetResponse, DagRun
+from airflow.sdk._shared.state import TaskScope
+from airflow.sdk.api.datamodels._generated import (
+ AssetEventResponse,
+ AssetResponse,
+ DagRun,
+)
from airflow.sdk.bases.xcom import BaseXCom
from airflow.sdk.definitions.asset import (
Asset,
@@ -93,6 +98,7 @@ from airflow.sdk.execution_time.context import (
set_current_context,
)
from airflow.sdk.execution_time.secrets import ExecutionAPISecretsBackend
+from airflow.sdk.state import BaseStateBackend
from tests_common.test_utils.config import conf_vars
@@ -1092,11 +1098,12 @@ class TestSecretsBackend:
class TestTaskStateAccessor:
TI_ID = UUID("01900000-0000-0000-0000-000000000001")
+ SCOPE = TaskScope(dag_id="dag", run_id="run", task_id="task")
def test_get_returns_value(self, mock_supervisor_comms):
mock_supervisor_comms.send.return_value =
TaskStateResult(value="app_001")
- result = TaskStateAccessor(ti_id=self.TI_ID).get("job_id")
+ result = TaskStateAccessor(ti_id=self.TI_ID,
scope=self.SCOPE).get("job_id")
assert result == "app_001"
mock_supervisor_comms.send.assert_called_once_with(GetTaskState(ti_id=self.TI_ID,
key="job_id"))
@@ -1106,7 +1113,7 @@ class TestTaskStateAccessor:
error=ErrorType.TASK_STATE_NOT_FOUND, detail={"key": "missing_key"}
)
- result = TaskStateAccessor(ti_id=self.TI_ID).get("missing_key")
+ result = TaskStateAccessor(ti_id=self.TI_ID,
scope=self.SCOPE).get("missing_key")
assert result is None
@@ -1116,7 +1123,7 @@ class TestTaskStateAccessor:
)
with pytest.raises(AirflowRuntimeError):
- TaskStateAccessor(ti_id=self.TI_ID).get("some_key")
+ TaskStateAccessor(ti_id=self.TI_ID,
scope=self.SCOPE).get("some_key")
def test_set_operation_with_global_retention(self, mock_supervisor_comms,
time_machine):
"""set() with no retention uses global default_retention_days
config."""
@@ -1126,7 +1133,7 @@ class TestTaskStateAccessor:
time_machine.move_to(now, tick=False)
with conf_vars({("state_store", "default_retention_days"): "30"}):
- TaskStateAccessor(ti_id=self.TI_ID).set("job_id", "app_001")
+ TaskStateAccessor(ti_id=self.TI_ID,
scope=self.SCOPE).set("job_id", "app_001")
mock_supervisor_comms.send.assert_called_once_with(
SetTaskState(
@@ -1143,7 +1150,9 @@ class TestTaskStateAccessor:
now = datetime(2026, 5, 14, 12, 0, 0, tzinfo=dt_timezone.utc)
time_machine.move_to(now, tick=False)
- TaskStateAccessor(ti_id=self.TI_ID).set("job_id", "app_001",
retention=timedelta(days=7))
+ TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).set(
+ "job_id", "app_001", retention=timedelta(days=7)
+ )
mock_supervisor_comms.send.assert_called_once_with(
SetTaskState(
@@ -1159,7 +1168,7 @@ class TestTaskStateAccessor:
mock_supervisor_comms.send.return_value = OKResponse(ok=True)
- TaskStateAccessor(ti_id=self.TI_ID).set("job_id", "app_001",
retention=NEVER_EXPIRE)
+ TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).set("job_id",
"app_001", retention=NEVER_EXPIRE)
mock_supervisor_comms.send.assert_called_once_with(
SetTaskState(ti_id=self.TI_ID, key="job_id", value="app_001",
expires_at=None)
@@ -1170,7 +1179,7 @@ class TestTaskStateAccessor:
mock_supervisor_comms.send.return_value = OKResponse(ok=True)
with conf_vars({("state_store", "default_retention_days"): "0"}):
- TaskStateAccessor(ti_id=self.TI_ID).set("job_id", "app_001")
+ TaskStateAccessor(ti_id=self.TI_ID,
scope=self.SCOPE).set("job_id", "app_001")
mock_supervisor_comms.send.assert_called_once_with(
SetTaskState(ti_id=self.TI_ID, key="job_id", value="app_001",
expires_at=None)
@@ -1179,14 +1188,14 @@ class TestTaskStateAccessor:
def test_delete_operation(self, mock_supervisor_comms):
mock_supervisor_comms.send.return_value = OKResponse(ok=True)
- TaskStateAccessor(ti_id=self.TI_ID).delete("job_id")
+ TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).delete("job_id")
mock_supervisor_comms.send.assert_called_once_with(DeleteTaskState(ti_id=self.TI_ID,
key="job_id"))
def test_clear_default_sends_all_map_indices_false(self,
mock_supervisor_comms):
mock_supervisor_comms.send.return_value = OKResponse(ok=True)
- TaskStateAccessor(ti_id=self.TI_ID).clear()
+ TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).clear()
mock_supervisor_comms.send.assert_called_once_with(
ClearTaskState(ti_id=self.TI_ID, all_map_indices=False)
@@ -1195,7 +1204,7 @@ class TestTaskStateAccessor:
def test_clear_all_map_indices_sends_flag_true(self,
mock_supervisor_comms):
mock_supervisor_comms.send.return_value = OKResponse(ok=True)
- TaskStateAccessor(ti_id=self.TI_ID).clear(all_map_indices=True)
+ TaskStateAccessor(ti_id=self.TI_ID,
scope=self.SCOPE).clear(all_map_indices=True)
mock_supervisor_comms.send.assert_called_once_with(
ClearTaskState(ti_id=self.TI_ID, all_map_indices=True)
@@ -1399,3 +1408,179 @@ class TestAssetStateAccessors:
accessors = AssetStateAccessors([alias])
assert accessors._total == 0
+
+
+class InMemoryStateBackend(BaseStateBackend):
+ """Simple in-memory test backend."""
+
+ def __init__(self):
+ self._actual_key_value_store: dict[str, str] = {} # key -> actual value
+ self.reference: dict[str, str] = {} # key -> stored ref (mem:// URI)
+
+ def serialize_task_state_to_ref(self, *, value: str, key: str, ti_id: str) -> str:
+ ref = f"mem://{ti_id}/{key}"
+ self._actual_key_value_store[key] = value
+ self.reference[key] = ref
+ return ref
+
+ def deserialize_task_state_from_ref(self, stored: str) -> str:
+ key = stored.rsplit("/", 1)[-1]
+ return self._actual_key_value_store.get(key, stored)
+
+ def serialize_asset_state_to_ref(self, *, value: str, key: str, asset_ref: str) -> str:
+ ref = f"mem://{asset_ref}/{key}"
+ self._actual_key_value_store[key] = value
+ self.reference[key] = ref
+ return ref
+
+ def deserialize_asset_state_from_ref(self, stored: str) -> str:
+ key = stored.rsplit("/", 1)[-1]
+ return self._actual_key_value_store.get(key, stored)
+
+ def get(self, scope, key, *, session=None): ...
+ def set(self, scope, key, value, *, session=None): ...
+
+ def delete(self, scope, key, *, session=None) -> None:
+ self._actual_key_value_store.pop(key, None)
+ self.reference.pop(key, None)
+
+ def clear(self, scope, *, all_map_indices=False, session=None) -> None:
+ self._actual_key_value_store.clear()
+ self.reference.clear()
+
+ async def aget(self, scope, key): ...
+ async def aset(self, scope, key, value): ...
+ async def adelete(self, scope, key): ...
+ async def aclear(self, scope, *, all_map_indices=False): ...
+
+
+class TestTaskStateAccessorWithCustomBackend:
+ TI_ID = UUID("01900000-0000-0000-0000-000000000002")
+ SCOPE = TaskScope(dag_id="dag", run_id="run", task_id="task")
+
+ @pytest.fixture(autouse=True)
+ def backend(self):
+ b = InMemoryStateBackend()
+ with mock.patch(
+ "airflow.sdk.execution_time.context._get_worker_state_backend",
+ return_value=b,
+ ):
+ yield b
+
+ def test_set_returns_reference_to_storage(self, mock_supervisor_comms, backend, time_machine):
+ """set() stores actual value in backend and sends mem:// reference via comms."""
+ mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+ expected_ref = f"mem://{self.TI_ID}/job_id"
+
+ frozen_dt = datetime(2026, 1, 1, 12, 0, 0, tzinfo=dt_timezone.utc)
+ time_machine.move_to(frozen_dt, tick=False)
+
+ TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).set("job_id", "app_001")
+ # comms message has the mem:// reference, not the actual value
+ mock_supervisor_comms.send.assert_called_once_with(
+ SetTaskState(
+ ti_id=self.TI_ID, key="job_id", value=expected_ref, expires_at=frozen_dt + timedelta(days=30)
+ )
+ )
+ # actual value is stored on the backend, reference is stored for DB
+ assert backend._actual_key_value_store["job_id"] == "app_001"
+ assert backend.reference["job_id"] == expected_ref
+
+ def test_get_resolves_reference_to_actual_value(self, mock_supervisor_comms, backend):
+ """get() fetches mem:// reference from DB, resolves it to actual value via backend."""
+ ref = f"mem://{self.TI_ID}/job_id"
+ backend._actual_key_value_store["job_id"] = "app_001"
+ mock_supervisor_comms.send.return_value = TaskStateResult(value=ref)
+
+ result = TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).get("job_id")
+ # actual value is resolved from mem:// reference via backend
+ assert result == "app_001"
+
+ def test_deletes_from_backend_and_removes_db_ref(self, mock_supervisor_comms, backend):
+ """delete() purges from backend storage and removes the DB reference."""
+ backend._actual_key_value_store["job_id"] = "app_001"
+ mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+ TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).delete("job_id")
+
+ # backend does not have the value anymore
+ assert "job_id" not in backend._actual_key_value_store
+ # request to delete reference in DB was made
+ mock_supervisor_comms.send.assert_any_call(DeleteTaskState(ti_id=self.TI_ID, key="job_id"))
+
+ def test_clears_all_from_backend_and_clears_db(self, mock_supervisor_comms, backend):
+ """clear() purges all backend objects for the TI and removes all DB references."""
+ backend._actual_key_value_store["job_id"] = "app_001"
+ backend._actual_key_value_store["checkpoint"] = "step_3"
+ mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+ TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).clear()
+
+ assert "job_id" not in backend._actual_key_value_store
+ assert "checkpoint" not in backend._actual_key_value_store
+ mock_supervisor_comms.send.assert_any_call(ClearTaskState(ti_id=self.TI_ID, all_map_indices=False))
+
+
+class TestAssetStateAccessorWithCustomBackend:
+ ASSET_NAME = "my_asset"
+
+ @pytest.fixture(autouse=True)
+ def backend(self):
+ b = InMemoryStateBackend()
+ with mock.patch(
+ "airflow.sdk.execution_time.context._get_worker_state_backend",
+ return_value=b,
+ ):
+ yield b
+
+ def test_set_sends_reference_not_value(self, mock_supervisor_comms, backend):
+ """set() stores actual value in backend and sends mem:// reference via comms."""
+ mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+ AssetStateAccessor(name=self.ASSET_NAME).set("watermark", "2026-05-01")
+
+ expected_ref = f"mem://{self.ASSET_NAME}/watermark"
+ # comms message has the mem:// reference, not the actual value
+ mock_supervisor_comms.send.assert_called_once_with(
+ SetAssetStateByName(name=self.ASSET_NAME, key="watermark", value=expected_ref)
+ )
+ # actual value is stored on the backend, reference is stored for DB
+ assert backend._actual_key_value_store["watermark"] == "2026-05-01"
+ assert backend.reference["watermark"] == expected_ref
+
+ def test_get_resolves_reference_to_actual_value(self, mock_supervisor_comms, backend):
+ """get() fetches mem:// reference from DB, resolves it to actual value via backend."""
+ ref = f"mem://{self.ASSET_NAME}/watermark"
+ backend._actual_key_value_store["watermark"] = "2026-05-01"
+ mock_supervisor_comms.send.return_value = AssetStateResult(value=ref)
+
+ result = AssetStateAccessor(name=self.ASSET_NAME).get("watermark")
+
+ # actual value is resolved from mem:// reference via backend
+ assert result == "2026-05-01"
+
+ def test_delete_purges_from_backend_and_removes_db_ref(self, mock_supervisor_comms, backend):
+ """delete() purges from backend storage and removes the DB reference."""
+ backend._actual_key_value_store["watermark"] = "2026-05-01"
+ mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+ AssetStateAccessor(name=self.ASSET_NAME).delete("watermark")
+
+ # backend doesn't have the value anymore
+ assert "watermark" not in backend._actual_key_value_store
+ # request to delete reference in DB was made
+ mock_supervisor_comms.send.assert_any_call(
+ DeleteAssetStateByName(name=self.ASSET_NAME, key="watermark")
+ )
+
+ def test_clear_purges_all_from_backend_and_clears_db(self, mock_supervisor_comms, backend):
+ """clear() purges all backend objects and removes all DB references."""
+ backend._actual_key_value_store["watermark"] = "2026-05-01"
+ backend._actual_key_value_store["file_count"] = "42"
+ mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+ AssetStateAccessor(name=self.ASSET_NAME).clear()
+
+ assert "watermark" not in backend._actual_key_value_store
+ assert "file_count" not in backend._actual_key_value_store
+ mock_supervisor_comms.send.assert_any_call(ClearAssetStateByName(name=self.ASSET_NAME))
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 9425ce362ed..b37d2569ea4 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -55,6 +55,7 @@ from airflow.sdk import (
timezone,
)
from airflow.sdk._shared.observability.metrics.base_stats_logger import StatsLogger
+from airflow.sdk._shared.state import TaskScope
from airflow.sdk.api.datamodels._generated import (
AssetProfile,
AssetResponse,
@@ -1899,7 +1900,9 @@ class TestRuntimeTaskInstance:
"run_id": "test_run",
"task": task,
"task_instance": runtime_ti,
- "task_state": TaskStateAccessor(ti_id=ti_id),
+ "task_state": TaskStateAccessor(
+ ti_id=ti_id, scope=TaskScope(dag_id=dag_id, run_id="test_run", task_id="hello")
+ ),
"ti": runtime_ti,
}
@@ -1945,7 +1948,10 @@ class TestRuntimeTaskInstance:
"run_id": "test_run",
"task": task,
"task_instance": runtime_ti,
- "task_state": TaskStateAccessor(ti_id=runtime_ti.id),
+ "task_state": TaskStateAccessor(
+ ti_id=runtime_ti.id,
+ scope=TaskScope(dag_id=runtime_ti.dag_id, run_id="test_run", task_id="hello"),
+ ),
"ti": runtime_ti,
"dag_run": dr,
"data_interval_end": timezone.datetime(2024, 12, 1, 1, 0, 0),
@@ -5372,3 +5378,91 @@ class TestTaskInstanceStateOperations:
mock_supervisor_comms.send.assert_any_call(
SetAssetStateByName(name="asset_b", key="watermark_b", value="2026-05-02")
)
+
+ def test_asset_state_set_sends_reference_via_custom_backend(
+ self, create_runtime_ti, mock_supervisor_comms
+ ):
+ """When a worker backend is configured, asset state set() sends a reference, not the actual value."""
+ watched = Asset(name="my_asset", uri="s3://bucket/data")
+
+ class WatcherOperator(BaseOperator):
+ def execute(self, context):
+ context["asset_state"].set("watermark", "2026-05-01")
+
+ task = WatcherOperator(task_id="t", inlets=[watched])
+ runtime_ti = create_runtime_ti(task=task)
+ mock_supervisor_comms.send.side_effect = TestTaskInstanceStateOperations._watcher_side_effect
+
+ mock_backend = mock.MagicMock()
+ mock_backend.serialize_asset_state_to_ref.return_value = "mem://my_asset/watermark"
+
+ with mock.patch(
+ "airflow.sdk.execution_time.context._get_worker_state_backend", return_value=mock_backend
+ ):
+ run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())
+
+ mock_backend.serialize_asset_state_to_ref.assert_called_once_with(
+ value="2026-05-01", key="watermark", asset_ref="my_asset"
+ )
+ mock_supervisor_comms.send.assert_any_call(
+ SetAssetStateByName(name="my_asset", key="watermark", value="mem://my_asset/watermark")
+ )
+
+ def test_task_state_set_sends_reference_via_custom_backend(
+ self, create_runtime_ti, mock_supervisor_comms, time_machine
+ ):
+ """When a worker backend is configured, task state set() sends a reference, not the actual value."""
+
+ class MyOperator(BaseOperator):
+ def execute(self, context):
+ context["task_state"].set("job_id", "app_001")
+
+ frozen_dt = datetime(2026, 1, 1, 12, 0, 0, tzinfo=dt_timezone.utc)
+ time_machine.move_to(frozen_dt, tick=False)
+ task = MyOperator(task_id="t")
+ runtime_ti = create_runtime_ti(task=task)
+ mock_supervisor_comms.send.side_effect = TestTaskInstanceStateOperations._watcher_side_effect
+
+ mock_backend = mock.MagicMock()
+ ref = f"mem://{runtime_ti.id}/job_id"
+ mock_backend.serialize_task_state_to_ref.return_value = ref
+
+ with mock.patch(
+ "airflow.sdk.execution_time.context._get_worker_state_backend", return_value=mock_backend
+ ):
+ run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())
+
+ mock_backend.serialize_task_state_to_ref.assert_called_once_with(
+ value="app_001", key="job_id", ti_id=str(runtime_ti.id)
+ )
+ mock_supervisor_comms.send.assert_any_call(
+ SetTaskState(
+ ti_id=runtime_ti.id, key="job_id", value=ref, expires_at=frozen_dt + timedelta(days=30)
+ )
+ )
+
+ @conf_vars({("state_store", "clear_on_success"): "True"})
+ def test_clear_on_success_clears_backend_without_comms_roundtrip(
+ self, create_runtime_ti, mock_supervisor_comms
+ ):
+ """clear_on_success calls backend.clear() directly without sending ClearTaskState comms."""
+ mock_backend = mock.MagicMock()
+
+ class MyOperator(BaseOperator):
+ def execute(self, context):
+ pass
+
+ task = MyOperator(task_id="t")
+ runtime_ti = create_runtime_ti(task=task)
+
+ with mock.patch(
+ "airflow.sdk.execution_time.context._get_worker_state_backend", return_value=mock_backend
+ ):
+ run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())
+
+ mock_backend.clear.assert_called_once()
+ sent_types = [
+ type(call.kwargs.get("msg") or (call.args[0] if call.args else None))
+ for call in mock_supervisor_comms.send.call_args_list
+ ]
+ assert ClearTaskState not in sent_types