This is an automated email from the ASF dual-hosted git repository.

amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 91f7df306b8 Decorate custom state refs with an envelope for UI clarity (#67530)
91f7df306b8 is described below

commit 91f7df306b88364d4260d1a9656c0a2476cb3840
Author: Amogh Desai <[email protected]>
AuthorDate: Fri May 29 09:34:32 2026 +0530

    Decorate custom state refs with an envelope for UI clarity (#67530)
---
 shared/state/src/airflow_shared/state/__init__.py  | 10 +++
 task-sdk/src/airflow/sdk/execution_time/context.py | 59 ++++++++++-----
 .../tests/task_sdk/execution_time/test_context.py  | 88 +++++++++++++++++++++-
 .../task_sdk/execution_time/test_task_runner.py    | 12 ++-
 4 files changed, 143 insertions(+), 26 deletions(-)

diff --git a/shared/state/src/airflow_shared/state/__init__.py b/shared/state/src/airflow_shared/state/__init__.py
index 688cd6a6301..c8abef95ff0 100644
--- a/shared/state/src/airflow_shared/state/__init__.py
+++ b/shared/state/src/airflow_shared/state/__init__.py
@@ -216,6 +216,11 @@ class BaseStateBackend(ABC):
         stored in the DB — typically a reference path (e.g. an S3 key) rather than the
         actual value. Default: return ``value`` unchanged.
 
+        **Important:** return only the raw reference string. The worker framework automatically
+        wraps it in ``{"__airflow_state_ref__": "<ref>"}`` before writing to the DB, and strips
+        that wrapper before passing ``stored`` to ``deserialize_task_state_from_ref()``. Do not
+        wrap the reference yourself.
+
         The returned reference must be deterministic — given the same ``ti_id`` and ``key`` it
         must always return the same string. Do not use timestamps or random UUIDs as part of
         the reference, otherwise ``delete()``/``clear()`` cannot reconstruct it and the external
@@ -241,6 +246,11 @@ class BaseStateBackend(ABC):
         stored in the DB — typically a reference path rather than the actual value.
         Default: return ``value`` unchanged.
 
+        **Important:** return only the raw reference string. The worker framework automatically
+        wraps it in ``{"__airflow_state_ref__": "<ref>"}`` before writing to the DB, and strips
+        that wrapper before passing ``stored`` to ``deserialize_asset_state_from_ref()``. Do not
+        wrap the reference yourself.
+
         ``asset_ref`` is either the asset name or URI, depending on how the accessor was
         constructed. It may be a URI string if the task inlet was declared as ``AssetUriRef``.
 
diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py
index cba613da85a..159995b840c 100644
--- a/task-sdk/src/airflow/sdk/execution_time/context.py
+++ b/task-sdk/src/airflow/sdk/execution_time/context.py
@@ -121,6 +121,17 @@ log = structlog.get_logger(logger_name="task")
 #: Example: ``context["task_state"].set("job_id", job_id, retention=NEVER_EXPIRE)``
 NEVER_EXPIRE: timedelta = timedelta.max
 
+_EXTERNAL_STATE_REF_KEY = "__airflow_state_ref__"
+
+
+def _wrap_external_ref(ref: str) -> dict[str, JsonValue]:
+    return {_EXTERNAL_STATE_REF_KEY: ref}
+
+
+def _unwrap_external_ref(stored: dict) -> str | None:
+    return stored.get(_EXTERNAL_STATE_REF_KEY)
+
+
 T = TypeVar("T")
 
 
@@ -512,14 +523,17 @@ class TaskStateAccessor:
             raise AirflowRuntimeError(resp)
         if isinstance(resp, TaskStateResult):
             stored = resp.value
-            # if custom backend is configured, the stored value in DB is a reference, fetch the actual value from
-            # custom backend using the reference
             backend = _get_worker_state_backend()
+            if backend is not None and isinstance(stored, dict) and (ref := _unwrap_external_ref(stored)):
+                # unwrap the marker to get the ref, and retrieve the actual value from the backend using the ref
+                return backend.deserialize_task_state_from_ref(ref)
             if backend is not None:
-                # serialize_task_state_to_ref always returns str by contract; stored contains the ref.
-                if TYPE_CHECKING:
-                    assert isinstance(stored, str)
-                return backend.deserialize_task_state_from_ref(stored)
+                log.warning(
+                    "Task state key %r was not written through the configured state backend — returning raw "
+                    "stored value. To use the backend, ensure the task that wrote this key had the same "
+                    "backend configured.",
+                    key,
+                )
             return stored
         return None
 
@@ -549,11 +563,11 @@ class TaskStateAccessor:
         # if custom backend is configured, store the value on the custom backend, and return the reference
         # to the stored value to store in the DB
         backend = _get_worker_state_backend()
-        stored = (
-            backend.serialize_task_state_to_ref(value=value, key=key, ti_id=str(self._ti_id))
-            if backend
-            else value
-        )
+        stored: JsonValue = value
+        if backend is not None:
+            ref: str = backend.serialize_task_state_to_ref(value=value, key=key, ti_id=str(self._ti_id))
+            # wrap the value with a marker to indicate that it's stored externally, and include the ref to the external storage
+            stored = _wrap_external_ref(ref)
 
         SUPERVISOR_COMMS.send(SetTaskState(ti_id=self._ti_id, key=key, value=stored, expires_at=expires_at))
 
@@ -648,11 +662,17 @@ class AssetStateAccessor:
         if isinstance(resp, AssetStateResult):
             stored = resp.value
             backend = _get_worker_state_backend()
+            if backend is not None and isinstance(stored, dict) and (ref := _unwrap_external_ref(stored)):
+                # unwrap the marker to get the ref, and retrieve the actual value from the backend using the ref
+                return backend.deserialize_asset_state_from_ref(ref)
             if backend is not None:
-                # serialize_asset_state_to_ref always returns str by contract; stored contains the ref.
-                if TYPE_CHECKING:
-                    assert isinstance(stored, str)
-                return backend.deserialize_asset_state_from_ref(stored)
+                log.warning(
+                    "Asset state key %r for asset %r was not written through the configured state backend — "
+                    "returning raw stored value. To use the backend, ensure the task that wrote this key had "
+                    "the same backend configured.",
+                    key,
+                    self._name or self._uri,
+                )
             return stored
         return None
 
@@ -665,11 +685,10 @@ class AssetStateAccessor:
         # to the stored value to store in the DB
         backend = _get_worker_state_backend()
         asset_ref = self._name or self._uri or ""
-        stored = (
-            backend.serialize_asset_state_to_ref(value=value, key=key, asset_ref=asset_ref)
-            if backend
-            else value
-        )
+        stored: JsonValue = value
+        if backend is not None:
+            ref = backend.serialize_asset_state_to_ref(value=value, key=key, asset_ref=asset_ref)
+            stored = _wrap_external_ref(ref)
 
         msg: ToSupervisor
         if self._name:
diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py
index 35c74278748..8b69d69b3bd 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_context.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_context.py
@@ -96,6 +96,7 @@ from airflow.sdk.execution_time.context import (
     _convert_variable_result_to_variable,
     _get_connection,
     _process_connection_result_conn,
+    _wrap_external_ref,
     context_to_airflow_vars,
     set_current_context,
 )
@@ -1225,6 +1226,43 @@ class TestTaskStateAccessor:
 
         mock_supervisor_comms.send.assert_not_called()
 
+    def test_set_with_custom_backend_decorates_value_with_marker(self, mock_supervisor_comms):
+        """Custom backend ref is wrapped in external state marker before going to DB."""
+        mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+        backend = MagicMock(spec=BaseStateBackend)
+        backend.serialize_task_state_to_ref.return_value = "s3://bucket/ti_123/job_id"
+
+        with (
+            patch("airflow.sdk.execution_time.context._get_worker_state_backend", return_value=backend),
+            conf_vars({("state_store", "default_retention_days"): "0"}),
+        ):
+            TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).set("job_id", "spark_001")
+
+        mock_supervisor_comms.send.assert_called_once_with(
+            SetTaskState(
+                ti_id=self.TI_ID,
+                key="job_id",
+                value=_wrap_external_ref("s3://bucket/ti_123/job_id"),
+                expires_at=None,
+            )
+        )
+
+    def test_get_with_custom_backend_removes_decoration_marker(self, mock_supervisor_comms):
+        """External state marker is detected and the ref is passed to deserialize."""
+        mock_supervisor_comms.send.return_value = TaskStateResult(
+            value=_wrap_external_ref("s3://bucket/ti_123/job_id")
+        )
+
+        backend = MagicMock(spec=BaseStateBackend)
+        backend.deserialize_task_state_from_ref.return_value = {"rows": 123}
+
+        with patch("airflow.sdk.execution_time.context._get_worker_state_backend", return_value=backend):
+            result = TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).get("job_id")
+
+        assert result == {"rows": 123}
+        backend.deserialize_task_state_from_ref.assert_called_once_with("s3://bucket/ti_123/job_id")
+
 
 class TestAssetStateAccessor:
     ASSET_NAME = "debug_watcher_asset"
@@ -1317,6 +1355,41 @@ class TestAssetStateAccessor:
 
         
mock_supervisor_comms.send.assert_called_once_with(ClearAssetStateByUri(uri=self.ASSET_URI))
 
+    def test_set_with_custom_backend_decorates_value_with_marker(self, mock_supervisor_comms):
+        """Custom backend ref is wrapped in external state marker before going to DB."""
+        mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+        backend = MagicMock(spec=BaseStateBackend)
+        backend.serialize_asset_state_to_ref.return_value = "s3://bucket/assets/orders/watermark"
+
+        with patch("airflow.sdk.execution_time.context._get_worker_state_backend", return_value=backend):
+            AssetStateAccessor(name=self.ASSET_NAME).set("watermark", "2026-05-01")
+
+        mock_supervisor_comms.send.assert_called_once_with(
+            SetAssetStateByName(
+                name=self.ASSET_NAME,
+                key="watermark",
+                value=_wrap_external_ref("s3://bucket/assets/orders/watermark"),
+            )
+        )
+
+    def test_get_with_custom_backend_removes_decoration_marker(self, mock_supervisor_comms):
+        """External state marker is detected and the ref is passed to deserialize."""
+        mock_supervisor_comms.send.return_value = AssetStateResult(
+            value=_wrap_external_ref("s3://bucket/assets/orders/watermark")
+        )
+
+        backend = MagicMock(spec=BaseStateBackend)
+        backend.deserialize_asset_state_from_ref.return_value = "2026-05-01"
+
+        with patch("airflow.sdk.execution_time.context._get_worker_state_backend", return_value=backend):
+            result = AssetStateAccessor(name=self.ASSET_NAME).get("watermark")
+
+        assert result == "2026-05-01"
+        backend.deserialize_asset_state_from_ref.assert_called_once_with(
+            "s3://bucket/assets/orders/watermark"
+        )
+
 
 class TestAssetStateAccessors:
     ASSET_NAME = "my_asset"
@@ -1494,7 +1567,10 @@ class TestTaskStateAccessorWithCustomBackend:
         # comms message has the mem:// reference, not the actual value
         mock_supervisor_comms.send.assert_called_once_with(
             SetTaskState(
-                ti_id=self.TI_ID, key="job_id", value=expected_ref, expires_at=frozen_dt + timedelta(days=30)
+                ti_id=self.TI_ID,
+                key="job_id",
+                value=_wrap_external_ref(expected_ref),
+                expires_at=frozen_dt + timedelta(days=30),
             )
         )
         # actual value is stored on the backend, reference is stored for DB
@@ -1503,7 +1579,7 @@ class TestTaskStateAccessorWithCustomBackend:
 
     def test_get_resolves_reference_to_actual_value(self, mock_supervisor_comms, backend):
         """get() fetches mem:// reference from DB, resolves it to actual value via backend."""
-        ref = f"mem://{self.TI_ID}/job_id"
+        ref = _wrap_external_ref(f"mem://{self.TI_ID}/job_id")
         backend._actual_key_value_store["job_id"] = "app_001"
         mock_supervisor_comms.send.return_value = TaskStateResult(value=ref)
 
@@ -1557,7 +1633,11 @@ class TestAssetStateAccessorWithCustomBackend:
         expected_ref = f"mem://{self.ASSET_NAME}/watermark"
         # comms message has the mem:// reference, not the actual value
         mock_supervisor_comms.send.assert_called_once_with(
-            SetAssetStateByName(name=self.ASSET_NAME, key="watermark", value=expected_ref)
+            SetAssetStateByName(
+                name=self.ASSET_NAME,
+                key="watermark",
+                value=_wrap_external_ref(expected_ref),
+            )
         )
         # actual value is stored on the backend, reference is stored for DB
         assert backend._actual_key_value_store["watermark"] == "2026-05-01"
@@ -1565,7 +1645,7 @@ class TestAssetStateAccessorWithCustomBackend:
 
     def test_get_resolves_reference_to_actual_value(self, mock_supervisor_comms, backend):
         """get() fetches mem:// reference from DB, resolves it to actual value via backend."""
-        ref = f"mem://{self.ASSET_NAME}/watermark"
+        ref = _wrap_external_ref(f"mem://{self.ASSET_NAME}/watermark")
         backend._actual_key_value_store["watermark"] = "2026-05-01"
         mock_supervisor_comms.send.return_value = AssetStateResult(value=ref)
 
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 0fccb92573d..507e63cfff0 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -153,6 +153,7 @@ from airflow.sdk.execution_time.context import (
     TaskStateAccessor,
     TriggeringAssetEventsAccessor,
     VariableAccessor,
+    _wrap_external_ref,
 )
 from airflow.sdk.execution_time.task_runner import (
     RuntimeTaskInstance,
@@ -5593,7 +5594,11 @@ class TestTaskInstanceStateOperations:
             value="2026-05-01", key="watermark", asset_ref="my_asset"
         )
         mock_supervisor_comms.send.assert_any_call(
-            SetAssetStateByName(name="my_asset", key="watermark", value="mem://my_asset/watermark")
+            SetAssetStateByName(
+                name="my_asset",
+                key="watermark",
+                value=_wrap_external_ref("mem://my_asset/watermark"),
+            )
         )
 
     def test_task_state_set_sends_reference_via_custom_backend(
@@ -5625,7 +5630,10 @@ class TestTaskInstanceStateOperations:
         )
         mock_supervisor_comms.send.assert_any_call(
             SetTaskState(
-                ti_id=runtime_ti.id, key="job_id", value=ref, expires_at=frozen_dt + timedelta(days=30)
+                ti_id=runtime_ti.id,
+                key="job_id",
+                value=_wrap_external_ref(ref),
+                expires_at=frozen_dt + timedelta(days=30),
             )
         )
 

Reply via email to