This is an automated email from the ASF dual-hosted git repository.
amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 91f7df306b8 Decorate custom state refs with an envelope for UI clarity
(#67530)
91f7df306b8 is described below
commit 91f7df306b88364d4260d1a9656c0a2476cb3840
Author: Amogh Desai <[email protected]>
AuthorDate: Fri May 29 09:34:32 2026 +0530
Decorate custom state refs with an envelope for UI clarity (#67530)
---
shared/state/src/airflow_shared/state/__init__.py | 10 +++
task-sdk/src/airflow/sdk/execution_time/context.py | 59 ++++++++++-----
.../tests/task_sdk/execution_time/test_context.py | 88 +++++++++++++++++++++-
.../task_sdk/execution_time/test_task_runner.py | 12 ++-
4 files changed, 143 insertions(+), 26 deletions(-)
diff --git a/shared/state/src/airflow_shared/state/__init__.py
b/shared/state/src/airflow_shared/state/__init__.py
index 688cd6a6301..c8abef95ff0 100644
--- a/shared/state/src/airflow_shared/state/__init__.py
+++ b/shared/state/src/airflow_shared/state/__init__.py
@@ -216,6 +216,11 @@ class BaseStateBackend(ABC):
stored in the DB — typically a reference path (e.g. an S3 key) rather
than the
actual value. Default: return ``value`` unchanged.
+ **Important:** return only the raw reference string. The worker
framework automatically
+ wraps it in ``{"__airflow_state_ref__": "<ref>"}`` before writing to
the DB, and strips
+ that wrapper before passing ``stored`` to
``deserialize_task_state_from_ref()``. Do not
+ wrap the reference yourself.
+
The returned reference must be deterministic — given the same
``ti_id`` and ``key`` it
must always return the same string. Do not use timestamps or random
UUIDs as part of
the reference, otherwise ``delete()``/``clear()`` cannot reconstruct
it and the external
@@ -241,6 +246,11 @@ class BaseStateBackend(ABC):
stored in the DB — typically a reference path rather than the actual
value.
Default: return ``value`` unchanged.
+ **Important:** return only the raw reference string. The worker
framework automatically
+ wraps it in ``{"__airflow_state_ref__": "<ref>"}`` before writing to
the DB, and strips
+ that wrapper before passing ``stored`` to
``deserialize_asset_state_from_ref()``. Do not
+ wrap the reference yourself.
+
``asset_ref`` is either the asset name or URI, depending on how the
accessor was
constructed. It may be a URI string if the task inlet was declared as
``AssetUriRef``.
diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py
b/task-sdk/src/airflow/sdk/execution_time/context.py
index cba613da85a..159995b840c 100644
--- a/task-sdk/src/airflow/sdk/execution_time/context.py
+++ b/task-sdk/src/airflow/sdk/execution_time/context.py
@@ -121,6 +121,17 @@ log = structlog.get_logger(logger_name="task")
#: Example: ``context["task_state"].set("job_id", job_id,
retention=NEVER_EXPIRE)``
NEVER_EXPIRE: timedelta = timedelta.max
+_EXTERNAL_STATE_REF_KEY = "__airflow_state_ref__"
+
+
+def _wrap_external_ref(ref: str) -> dict[str, JsonValue]:
+ return {_EXTERNAL_STATE_REF_KEY: ref}
+
+
+def _unwrap_external_ref(stored: dict) -> str | None:
+ return stored.get(_EXTERNAL_STATE_REF_KEY)
+
+
T = TypeVar("T")
@@ -512,14 +523,17 @@ class TaskStateAccessor:
raise AirflowRuntimeError(resp)
if isinstance(resp, TaskStateResult):
stored = resp.value
- # if custom backend is configured, the stored value in DB is a
reference, fetch the actual value from
- # custom backend using the reference
backend = _get_worker_state_backend()
+ if backend is not None and isinstance(stored, dict) and (ref :=
_unwrap_external_ref(stored)):
+ # unwrap the marker to get the ref, and retrieve the actual
value from the backend using the ref
+ return backend.deserialize_task_state_from_ref(ref)
if backend is not None:
- # serialize_task_state_to_ref always returns str by contract;
stored contains the ref.
- if TYPE_CHECKING:
- assert isinstance(stored, str)
- return backend.deserialize_task_state_from_ref(stored)
+ log.warning(
+ "Task state key %r was not written through the configured
state backend — returning raw "
+ "stored value. To use the backend, ensure the task that
wrote this key had the same "
+ "backend configured.",
+ key,
+ )
return stored
return None
@@ -549,11 +563,11 @@ class TaskStateAccessor:
# if custom backend is configured, store the value on the custom
backend, and return the reference
# to the stored value to store in the DB
backend = _get_worker_state_backend()
- stored = (
- backend.serialize_task_state_to_ref(value=value, key=key,
ti_id=str(self._ti_id))
- if backend
- else value
- )
+ stored: JsonValue = value
+ if backend is not None:
+ ref: str = backend.serialize_task_state_to_ref(value=value,
key=key, ti_id=str(self._ti_id))
+ # wrap the value with a marker to indicate that it's stored
externally, and include the ref to the external storage
+ stored = _wrap_external_ref(ref)
SUPERVISOR_COMMS.send(SetTaskState(ti_id=self._ti_id, key=key,
value=stored, expires_at=expires_at))
@@ -648,11 +662,17 @@ class AssetStateAccessor:
if isinstance(resp, AssetStateResult):
stored = resp.value
backend = _get_worker_state_backend()
+ if backend is not None and isinstance(stored, dict) and (ref :=
_unwrap_external_ref(stored)):
+ # unwrap the marker to get the ref, and retrieve the actual
value from the backend using the ref
+ return backend.deserialize_asset_state_from_ref(ref)
if backend is not None:
- # serialize_asset_state_to_ref always returns str by contract;
stored contains the ref.
- if TYPE_CHECKING:
- assert isinstance(stored, str)
- return backend.deserialize_asset_state_from_ref(stored)
+ log.warning(
+ "Asset state key %r for asset %r was not written through
the configured state backend — "
+ "returning raw stored value. To use the backend, ensure
the task that wrote this key had "
+ "the same backend configured.",
+ key,
+ self._name or self._uri,
+ )
return stored
return None
@@ -665,11 +685,10 @@ class AssetStateAccessor:
# to the stored value to store in the DB
backend = _get_worker_state_backend()
asset_ref = self._name or self._uri or ""
- stored = (
- backend.serialize_asset_state_to_ref(value=value, key=key,
asset_ref=asset_ref)
- if backend
- else value
- )
+ stored: JsonValue = value
+ if backend is not None:
+ ref = backend.serialize_asset_state_to_ref(value=value, key=key,
asset_ref=asset_ref)
+ stored = _wrap_external_ref(ref)
msg: ToSupervisor
if self._name:
diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py
b/task-sdk/tests/task_sdk/execution_time/test_context.py
index 35c74278748..8b69d69b3bd 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_context.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_context.py
@@ -96,6 +96,7 @@ from airflow.sdk.execution_time.context import (
_convert_variable_result_to_variable,
_get_connection,
_process_connection_result_conn,
+ _wrap_external_ref,
context_to_airflow_vars,
set_current_context,
)
@@ -1225,6 +1226,43 @@ class TestTaskStateAccessor:
mock_supervisor_comms.send.assert_not_called()
+ def test_set_with_custom_backend_decorates_value_with_marker(self,
mock_supervisor_comms):
+ """The ref returned by the custom backend is wrapped in the external-state envelope before being
sent to the DB."""
+ mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+ backend = MagicMock(spec=BaseStateBackend)
+ backend.serialize_task_state_to_ref.return_value =
"s3://bucket/ti_123/job_id"
+
+ with (
+
patch("airflow.sdk.execution_time.context._get_worker_state_backend",
return_value=backend),
+ conf_vars({("state_store", "default_retention_days"): "0"}),
+ ):
+ TaskStateAccessor(ti_id=self.TI_ID,
scope=self.SCOPE).set("job_id", "spark_001")
+
+ mock_supervisor_comms.send.assert_called_once_with(
+ SetTaskState(
+ ti_id=self.TI_ID,
+ key="job_id",
+ value=_wrap_external_ref("s3://bucket/ti_123/job_id"),
+ expires_at=None,
+ )
+ )
+
+ def test_get_with_custom_backend_removes_decoration_marker(self,
mock_supervisor_comms):
+ """A stored external-state envelope is unwrapped and only the raw ref is passed to
deserialize_task_state_from_ref()."""
+ mock_supervisor_comms.send.return_value = TaskStateResult(
+ value=_wrap_external_ref("s3://bucket/ti_123/job_id")
+ )
+
+ backend = MagicMock(spec=BaseStateBackend)
+ backend.deserialize_task_state_from_ref.return_value = {"rows": 123}
+
+ with
patch("airflow.sdk.execution_time.context._get_worker_state_backend",
return_value=backend):
+ result = TaskStateAccessor(ti_id=self.TI_ID,
scope=self.SCOPE).get("job_id")
+
+ assert result == {"rows": 123}
+
backend.deserialize_task_state_from_ref.assert_called_once_with("s3://bucket/ti_123/job_id")
+
class TestAssetStateAccessor:
ASSET_NAME = "debug_watcher_asset"
@@ -1317,6 +1355,41 @@ class TestAssetStateAccessor:
mock_supervisor_comms.send.assert_called_once_with(ClearAssetStateByUri(uri=self.ASSET_URI))
+ def test_set_with_custom_backend_decorates_value_with_marker(self,
mock_supervisor_comms):
+ """The ref returned by the custom backend is wrapped in the external-state envelope before being
sent to the DB."""
+ mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+ backend = MagicMock(spec=BaseStateBackend)
+ backend.serialize_asset_state_to_ref.return_value =
"s3://bucket/assets/orders/watermark"
+
+ with
patch("airflow.sdk.execution_time.context._get_worker_state_backend",
return_value=backend):
+ AssetStateAccessor(name=self.ASSET_NAME).set("watermark",
"2026-05-01")
+
+ mock_supervisor_comms.send.assert_called_once_with(
+ SetAssetStateByName(
+ name=self.ASSET_NAME,
+ key="watermark",
+
value=_wrap_external_ref("s3://bucket/assets/orders/watermark"),
+ )
+ )
+
+ def test_get_with_custom_backend_removes_decoration_marker(self,
mock_supervisor_comms):
+ """A stored external-state envelope is unwrapped and only the raw ref is passed to
deserialize_asset_state_from_ref()."""
+ mock_supervisor_comms.send.return_value = AssetStateResult(
+ value=_wrap_external_ref("s3://bucket/assets/orders/watermark")
+ )
+
+ backend = MagicMock(spec=BaseStateBackend)
+ backend.deserialize_asset_state_from_ref.return_value = "2026-05-01"
+
+ with
patch("airflow.sdk.execution_time.context._get_worker_state_backend",
return_value=backend):
+ result = AssetStateAccessor(name=self.ASSET_NAME).get("watermark")
+
+ assert result == "2026-05-01"
+ backend.deserialize_asset_state_from_ref.assert_called_once_with(
+ "s3://bucket/assets/orders/watermark"
+ )
+
class TestAssetStateAccessors:
ASSET_NAME = "my_asset"
@@ -1494,7 +1567,10 @@ class TestTaskStateAccessorWithCustomBackend:
# comms message has the mem:// reference, not the actual value
mock_supervisor_comms.send.assert_called_once_with(
SetTaskState(
- ti_id=self.TI_ID, key="job_id", value=expected_ref,
expires_at=frozen_dt + timedelta(days=30)
+ ti_id=self.TI_ID,
+ key="job_id",
+ value=_wrap_external_ref(expected_ref),
+ expires_at=frozen_dt + timedelta(days=30),
)
)
# actual value is stored on the backend, reference is stored for DB
@@ -1503,7 +1579,7 @@ class TestTaskStateAccessorWithCustomBackend:
def test_get_resolves_reference_to_actual_value(self,
mock_supervisor_comms, backend):
"""get() fetches mem:// reference from DB, resolves it to actual value
via backend."""
- ref = f"mem://{self.TI_ID}/job_id"
+ ref = _wrap_external_ref(f"mem://{self.TI_ID}/job_id")
backend._actual_key_value_store["job_id"] = "app_001"
mock_supervisor_comms.send.return_value = TaskStateResult(value=ref)
@@ -1557,7 +1633,11 @@ class TestAssetStateAccessorWithCustomBackend:
expected_ref = f"mem://{self.ASSET_NAME}/watermark"
# comms message has the mem:// reference, not the actual value
mock_supervisor_comms.send.assert_called_once_with(
- SetAssetStateByName(name=self.ASSET_NAME, key="watermark",
value=expected_ref)
+ SetAssetStateByName(
+ name=self.ASSET_NAME,
+ key="watermark",
+ value=_wrap_external_ref(expected_ref),
+ )
)
# actual value is stored on the backend, reference is stored for DB
assert backend._actual_key_value_store["watermark"] == "2026-05-01"
@@ -1565,7 +1645,7 @@ class TestAssetStateAccessorWithCustomBackend:
def test_get_resolves_reference_to_actual_value(self,
mock_supervisor_comms, backend):
"""get() fetches mem:// reference from DB, resolves it to actual value
via backend."""
- ref = f"mem://{self.ASSET_NAME}/watermark"
+ ref = _wrap_external_ref(f"mem://{self.ASSET_NAME}/watermark")
backend._actual_key_value_store["watermark"] = "2026-05-01"
mock_supervisor_comms.send.return_value = AssetStateResult(value=ref)
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 0fccb92573d..507e63cfff0 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -153,6 +153,7 @@ from airflow.sdk.execution_time.context import (
TaskStateAccessor,
TriggeringAssetEventsAccessor,
VariableAccessor,
+ _wrap_external_ref,
)
from airflow.sdk.execution_time.task_runner import (
RuntimeTaskInstance,
@@ -5593,7 +5594,11 @@ class TestTaskInstanceStateOperations:
value="2026-05-01", key="watermark", asset_ref="my_asset"
)
mock_supervisor_comms.send.assert_any_call(
- SetAssetStateByName(name="my_asset", key="watermark",
value="mem://my_asset/watermark")
+ SetAssetStateByName(
+ name="my_asset",
+ key="watermark",
+ value=_wrap_external_ref("mem://my_asset/watermark"),
+ )
)
def test_task_state_set_sends_reference_via_custom_backend(
@@ -5625,7 +5630,10 @@ class TestTaskInstanceStateOperations:
)
mock_supervisor_comms.send.assert_any_call(
SetTaskState(
- ti_id=runtime_ti.id, key="job_id", value=ref,
expires_at=frozen_dt + timedelta(days=30)
+ ti_id=runtime_ti.id,
+ key="job_id",
+ value=_wrap_external_ref(ref),
+ expires_at=frozen_dt + timedelta(days=30),
)
)