amoghrajesh opened a new pull request, #66859:
URL: https://github.com/apache/airflow/pull/66859

    <!-- SPDX-License-Identifier: Apache-2.0
         https://www.apache.org/licenses/LICENSE-2.0 -->
   
   
   closes: https://github.com/apache/airflow/issues/66337
   
   ### What
   
   Right now, the default path for task state / asset state is that all state reads and writes go through the execution API → `MetastoreStateBackend` (or the server-side configured state backend) → database.
   
   However, some deployments require that storage credentials never leave the execution plane (i.e. the worker infrastructure) and never pass through an API server. The same constraint that motivated custom XCom backends and secrets backends applies to task and asset state.
   
   ### What's being done
   
   A new config key, `[workers] state_backend`, takes the class path of a custom state backend. If this config is not set, nothing changes.
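
A minimal sketch of how a dotted class path from `[workers] state_backend` could be resolved on the worker, assuming the `AIRFLOW__WORKERS__STATE_BACKEND` environment variable used in the testing section below. The helper name `load_worker_state_backend` is hypothetical and only for illustration; the actual lookup in this PR goes through Airflow's own config handling.

```python
import importlib
import os


def load_worker_state_backend(env=os.environ):
    """Return an instance of the configured backend class, or None if unset.

    Hypothetical helper, not the function added by this PR.
    """
    class_path = env.get("AIRFLOW__WORKERS__STATE_BACKEND", "").strip()
    if not class_path:
        # Option not set: the default path is used and nothing changes.
        return None
    # Split "package.module.ClassName" into module path and class name.
    module_name, _, class_name = class_path.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)()
```

Returning `None` when the option is unset mirrors the "nothing changes" behaviour described above: callers can branch on it and fall back to the default path.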
   
   When a worker backend is configured, each state operation changes behaviour:
   
   - `set(key, value)`: stores the value on the custom backend, gets back a reference string, and stores that reference in the DB instead of the raw value.
   - `get(key)`: fetches the reference from the DB and resolves it back to the actual value via the backend. The caller sees the original value, not the reference.
   - `delete(key)`: fetches the reference from the DB, deletes the externally stored object via the backend, then removes the DB reference. Both sides are cleaned up.
   - `clear()`: fetches all references for the scope, purges each one from the backend, then clears the DB rows. Also called automatically on task success if `clear_on_success = True`.
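
The four operations above can be sketched with plain dictionaries. This is an illustration under stated assumptions, not the code in this PR: `InMemoryRefBackend` and `TaskStateAccessor` are hypothetical names, and the `db` dict stands in for the metadata DB that is normally reached through the supervisor and the execution API.

```python
class InMemoryRefBackend:
    """Hypothetical backend: keeps values in a dict and hands out mem:// refs."""

    def __init__(self):
        self.objects = {}

    def serialize_task_state_value(self, *, value, key, ti_id):
        ref = f"mem://{ti_id}/{key}"
        self.objects[ref] = value
        return ref

    def deserialize_task_state_value(self, stored):
        return self.objects[stored]

    def purge_task_state(self, stored):
        self.objects.pop(stored, None)


class TaskStateAccessor:
    """Hypothetical worker-side accessor; `db` stands in for the metadata DB."""

    def __init__(self, backend, db, ti_id):
        self.backend, self.db, self.ti_id = backend, db, ti_id

    def set(self, key, value):
        # Store externally first, then persist only the reference.
        self.db[key] = self.backend.serialize_task_state_value(
            value=value, key=key, ti_id=self.ti_id
        )

    def get(self, key):
        ref = self.db.get(key)
        return None if ref is None else self.backend.deserialize_task_state_value(ref)

    def delete(self, key):
        ref = self.db.get(key)
        if ref is not None:
            self.backend.purge_task_state(ref)  # external object first
        self.db.pop(key, None)  # then the DB reference

    def clear(self):
        # Purge every externally stored object, then drop all DB rows.
        for ref in list(self.db.values()):
            self.backend.purge_task_state(ref)
        self.db.clear()
```

Purging the external object before removing the DB row means a failure in between leaves a dangling reference in the DB rather than an orphaned external object that nothing points to.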
   
   ### Interface changes to `BaseStateBackend`
   The custom backend flow above requires the worker to call into the backend at the right moment in each operation. To support this, three pairs of methods are added to `BaseStateBackend`:
   
   - `serialize_task_state_value` / `serialize_asset_state_value`: called during `set()`. Receives the actual value, stores it externally, and returns an opaque reference string to be stored in the DB.
   - `deserialize_task_state_value` / `deserialize_asset_state_value`: called during `get()`. Receives the reference string from the DB and returns the actual value.
   - `purge_task_state` / `purge_asset_state`: called during `delete()` and `clear()`. Receives the reference string and deletes the externally stored object.
   
   All six have no-op defaults.
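
A rough sketch of what the six no-op defaults amount to. The class below is a stand-in named `BaseStateBackendSketch`, not the real `BaseStateBackend`; the method signatures are copied from the `MemoryStateBackend` example in the testing section, and everything else is an assumption.

```python
class BaseStateBackendSketch:
    """Stand-in showing only the six new hooks with their no-op behaviour."""

    def serialize_task_state_value(self, *, value: str, key: str, ti_id: str) -> str:
        return value  # no-op: the raw value is what ends up in the DB

    def deserialize_task_state_value(self, stored: str) -> str:
        return stored  # no-op: the stored string already is the value

    def purge_task_state(self, stored: str) -> None:
        return None  # nothing was stored externally, so nothing to delete

    def serialize_asset_state_value(self, *, value: str, key: str, asset_name: str) -> str:
        return value

    def deserialize_asset_state_value(self, stored: str) -> str:
        return stored

    def purge_asset_state(self, stored: str) -> None:
        return None
```

With these defaults a round trip through `serialize_*` and `deserialize_*` is the identity, so a backend that overrides none of them behaves like the default path.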
   
   
   ### How it works (end-to-end)
   
   #### Default path (no worker backend configured):
   
   
   `task: context["task_state"].set("job_id", "app_001")`
     → `SetTaskState comms msg` → supervisor → `PUT /state/ti/{id}/job_id 
{"value": "app_001"}` → DB
   
   
   #### Custom backend path:
   
   `task: context["task_state"].set("job_id", "app_001")`
     → `backend.serialize_task_state_value(value="app_001", key="job_id", 
ti_id=...)` → `"s3://bucket/ti_123/job_id"`
     → `SetTaskState` comms msg with `value="s3://bucket/ti_123/job_id"` → 
supervisor → DB
   
   `task: context["task_state"].get("job_id")`
     → `GetTaskState` → supervisor → DB returns `"s3://bucket/ti_123/job_id"`
     → `backend.deserialize_task_state_value("s3://bucket/ti_123/job_id")` → `"app_001"`
   
   `task: context["task_state"].delete("job_id")`
     → `GetTaskState` → gets ref `"s3://bucket/ti_123/job_id"`
     → `backend.purge_task_state("s3://bucket/ti_123/job_id")` → deletes from S3
     → `DeleteTaskState` → removes ref from DB
   
   On task success with `clear_on_success=True`: all stored refs are fetched, 
`purge_*` is called for each, then DB is cleared.
   
   (Added in https://github.com/apache/airflow/pull/66586)
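
To make the difference between the two `set` flows above concrete, here is a small sketch of what the supervisor would see in each case. The `SetTaskState` dataclass is a simplified stand-in for the real comms message, and `build_set_message` and `FakeS3Backend` are hypothetical names used only for this illustration.

```python
from dataclasses import dataclass


@dataclass
class SetTaskState:
    """Simplified stand-in for the comms message; the real one has more fields."""

    key: str
    value: str


def build_set_message(key, value, ti_id, backend=None):
    """Return the message the task process would hand to the supervisor."""
    if backend is not None:
        # Custom path: the raw value stays on the worker, only the ref travels.
        value = backend.serialize_task_state_value(value=value, key=key, ti_id=ti_id)
    return SetTaskState(key=key, value=value)


class FakeS3Backend:
    """Hypothetical backend that pretends to upload and returns an s3:// ref."""

    def __init__(self):
        self.uploaded = {}

    def serialize_task_state_value(self, *, value, key, ti_id):
        ref = f"s3://bucket/{ti_id}/{key}"
        self.uploaded[ref] = value
        return ref
```

In the default path the message carries `"app_001"`; with a backend it carries only `"s3://bucket/ti_123/job_id"`, which is the property the credential-isolation requirement depends on.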
   
   
   ### Design decisions worth flagging
   
   Reference/pointer pattern, not value serialization. `serialize_*` returns an 
opaque string (a reference/URI) that gets stored in DB. The DB row always 
contains a string — either the raw value (default path) or a reference (custom 
path). `deserialize_*` resolves the reference back to the value. This is 
identical to how `BaseXCom.serialize_value / deserialize_value` works today.
   
   `serialize_*` / `deserialize_*` / `purge_*` are non-abstract with no-op 
defaults. The default `serialize_*` returns the value unchanged and 
`deserialize_*` returns the stored string as-is — making the default path 
functionally equivalent even if `_get_worker_state_backend()` somehow returned 
a base instance.
   
   UI visibility. Custom backends that store values externally can still surface them in the Airflow UI: the reference string stored in the DB can be queried via the Core API, which will be added later. This is the same convention as custom XCom backends today.
   
   ### What if a server side backend is configured too?
   
   ```
   set("job_id", "app_001")
     → [workers] state_backend.serialize()  →  ref = "s3://bucket/ti_123/job_id"
     → SetTaskState(value=ref)  →  supervisor  →  Execution API
     → [state_store] backend.set(value=ref)  →  e.g. Redis stores ref
   ```
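
The layering above can be sketched as two independent objects. Both classes are hypothetical stand-ins (a worker backend that produces references and a Redis-like server backend keyed by scope and key); the point is only that the server-side backend stores whatever string it receives, which in this setup is the reference.

```python
class WorkerRefBackend:
    """Hypothetical [workers] state_backend: turns a value into a reference."""

    def __init__(self):
        self.external = {}

    def serialize_task_state_value(self, *, value, key, ti_id):
        ref = f"s3://bucket/{ti_id}/{key}"
        self.external[ref] = value
        return ref

    def deserialize_task_state_value(self, stored):
        return self.external[stored]


class ServerKVBackend:
    """Hypothetical server-side backend (Redis-like) keyed by (scope, key)."""

    def __init__(self):
        self.rows = {}

    def set(self, scope, key, value):
        self.rows[(scope, key)] = value

    def get(self, scope, key):
        return self.rows.get((scope, key))


worker, server = WorkerRefBackend(), ServerKVBackend()

# Worker side: turn the value into a reference before it leaves the process.
ref = worker.serialize_task_state_value(value="app_001", key="job_id", ti_id="ti_123")
# Server side: the Execution API hands whatever it received to its own backend.
server.set("ti_123", "job_id", ref)

# Reading reverses the two layers: server returns the ref, worker resolves it.
stored = server.get("ti_123", "job_id")
value = worker.deserialize_task_state_value(stored)
```

The two backends never need to know about each other: the server side treats the reference as an ordinary string value.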
   
   
   ### Testing
   
   For testing, I wrote a custom in-memory state backend. It stores the actual values in a process-local dictionary and writes only a reference string (`mem://<namespace>/<key>`) to the metadata DB.
   
   
   Code:
   ```python
   from airflow.state import BaseStateBackend, StateScope
   
   
   class MemoryStateBackend(BaseStateBackend):
       """Worker-side state backend that stores values in a process-local dict."""
   
       _store: dict[str, str] = {}
   
       def _ref(self, namespace: str, key: str) -> str:
           return f"mem://{namespace}/{key}"
   
       def serialize_task_state_value(self, *, value: str, key: str, ti_id: str) -> str:
           ref = self._ref(ti_id, key)
           self._store[ref] = value
           print(f"[MemoryStateBackend] task stored {key!r} → {ref!r}")
           return ref
   
       def deserialize_task_state_value(self, stored: str) -> str:
           if stored.startswith("mem://"):
               value = self._store.get(stored, stored)
               print(f"[MemoryStateBackend] task resolved {stored!r} → {value!r}")
               return value
           return stored
   
       def serialize_asset_state_value(self, *, value: str, key: str, asset_name: str) -> str:
           ref = self._ref(asset_name, key)
           self._store[ref] = value
           print(f"[MemoryStateBackend] asset stored {key!r} → {ref!r}")
           return ref
   
       def deserialize_asset_state_value(self, stored: str) -> str:
           if stored.startswith("mem://"):
               value = self._store.get(stored, stored)
               print(f"[MemoryStateBackend] asset resolved {stored!r} → {value!r}")
               return value
           return stored
   
       def purge_task_state(self, stored: str) -> None:
           if stored.startswith("mem://") and stored in self._store:
               del self._store[stored]
               print(f"[MemoryStateBackend] task purged {stored!r}")
   
       def purge_asset_state(self, stored: str) -> None:
           if stored.startswith("mem://") and stored in self._store:
               del self._store[stored]
               print(f"[MemoryStateBackend] asset purged {stored!r}")
   
       def get(self, scope: StateScope, key: str, *, session=None) -> str | None:
           raise NotImplementedError("Worker-side backend — use serialize/deserialize instead")
   
       def set(self, scope: StateScope, key: str, value: str, *, session=None) -> None:
           raise NotImplementedError("Worker-side backend — use serialize/deserialize instead")
   
       def delete(self, scope: StateScope, key: str, *, session=None) -> None:
           raise NotImplementedError("Worker-side backend — use serialize/deserialize instead")
   
       def clear(self, scope: StateScope, *, all_map_indices: bool = False, session=None) -> None:
           raise NotImplementedError("Worker-side backend — use serialize/deserialize instead")
   
       async def aget(self, scope: StateScope, key: str, *, session=None) -> str | None:
           raise NotImplementedError
   
       async def aset(self, scope: StateScope, key: str, value: str, *, session=None) -> None:
           raise NotImplementedError
   
       async def adelete(self, scope: StateScope, key: str, *, session=None) -> None:
           raise NotImplementedError
   
       async def aclear(self, scope: StateScope, *, all_map_indices: bool = False, session=None) -> None:
           raise NotImplementedError
   ```
   
   #### Testing task_state
   
   ```python
   import pendulum
   
   from airflow.sdk import DAG, task
   
   with DAG(
       dag_id="aip103_memory_backend_test",
       schedule=None,
       start_date=pendulum.datetime(2026, 1, 1, tz="UTC"),
   ) as dag:
   
       @task
       def test_set_and_get(**context):
           """set() stores via backend, get() resolves via backend."""
           ts = context["task_state"]
           ts.set("job_id", "spark_app_001")
           ts.set("checkpoint", "step_3")
   
           result = ts.get("job_id")
           print(f"get('job_id') = {result!r}")
           assert result == "spark_app_001", f"Expected 'spark_app_001', got {result!r}"
   
           result2 = ts.get("checkpoint")
           print(f"get('checkpoint') = {result2!r}")
           assert result2 == "step_3", f"Expected 'step_3', got {result2!r}"
   
           print("set + get: PASS")
   
       @task
       def test_delete(**context):
           """delete() purges from backend and removes DB reference."""
           ts = context["task_state"]
           ts.set("to_delete", "temporary_value")
   
           assert ts.get("to_delete") == "temporary_value"
           ts.delete("to_delete")
   
           result = ts.get("to_delete")
           print(f"get after delete = {result!r}")
           assert result is None, f"Expected None after delete, got {result!r}"
   
           print("delete: PASS")
   
       @task
       def test_clear(**context):
           """clear() purges all backend objects and removes all DB references."""
           ts = context["task_state"]
           ts.set("key_a", "value_a")
           ts.set("key_b", "value_b")
   
           ts.clear()
   
           result_a = ts.get("key_a")
           result_b = ts.get("key_b")
           print(f"get after clear: key_a={result_a!r}, key_b={result_b!r}")
           assert result_a is None and result_b is None, "Expected None for all keys after clear"
   
           print("clear: PASS")
   
       test_set_and_get() >> test_delete() >> test_clear()
   ```
   
   Starting breeze with this: `export AIRFLOW__WORKERS__STATE_BACKEND=memory_state_backend.MemoryStateBackend`
   
   set + get:
   
   <img width="1722" height="780" alt="image" src="https://github.com/user-attachments/assets/7ca95c89-d699-43d3-b2b8-8b3c9d205112" />
   
   get after delete:
   
   <img width="1722" height="780" alt="image" src="https://github.com/user-attachments/assets/0f28d3c1-aae0-48b4-8a6d-0c9aed8987d5" />
   
   clear:
   
   <img width="1722" height="780" alt="image" src="https://github.com/user-attachments/assets/d0838bb1-cd0f-4329-90bc-edbda4b88a71" />
   
   
   DB only has refs left for task 1
   
   <img width="1722" height="780" alt="image" src="https://github.com/user-attachments/assets/4ccd23d2-f166-4a22-b504-fd0f03965374" />
   
   
   #### Testing asset_state
   
   Using a DAG whose tasks all take the same asset as an inlet and manage its asset state.
   
   DAG:
   
   ```python
   import pendulum
   
   from airflow.sdk import DAG, Asset, task
   
   watched_asset = Asset(name="memory_backend_test_asset", uri="s3://aip103-test/memory-backend")
   
   with DAG(
       dag_id="aip103_memory_backend_asset_state_test",
       schedule=None,
       start_date=pendulum.datetime(2026, 1, 1, tz="UTC"),
   ) as dag:
   
       @task(inlets=[watched_asset])
       def test_set_and_get(**context):
           """set() stores via backend, get() resolves via backend."""
           state = context["asset_state"][watched_asset]
           state.set("watermark", "2026-05-01")
           state.set("file_count", "42")
   
           result = state.get("watermark")
           print(f"get('watermark') = {result!r}")
           assert result == "2026-05-01", f"Expected '2026-05-01', got {result!r}"
   
           result2 = state.get("file_count")
           print(f"get('file_count') = {result2!r}")
           assert result2 == "42", f"Expected '42', got {result2!r}"
   
           print("set + get: PASS")
   
       @task(inlets=[watched_asset])
       def test_delete(**context):
           """delete() purges from backend and removes DB reference."""
           state = context["asset_state"][watched_asset]
           state.set("to_delete", "temporary_value")
   
           assert state.get("to_delete") == "temporary_value"
           state.delete("to_delete")
   
           result = state.get("to_delete")
           print(f"get after delete = {result!r}")
           assert result is None, f"Expected None after delete, got {result!r}"
   
           print("delete: PASS")
   
       @task(inlets=[watched_asset])
       def test_clear(**context):
           """clear() purges all backend objects and removes all DB references."""
           state = context["asset_state"][watched_asset]
           state.set("key_a", "value_a")
           state.set("key_b", "value_b")
   
           state.clear()
   
           result_a = state.get("key_a")
           result_b = state.get("key_b")
           print(f"get after clear: key_a={result_a!r}, key_b={result_b!r}")
           assert result_a is None and result_b is None, "Expected None for all keys after clear"
   
           print("clear: PASS")
   
       test_set_and_get() >> test_delete() >> test_clear()
   
   ```
   
   set + get:
   
   <img width="1722" height="780" alt="image" src="https://github.com/user-attachments/assets/66632067-6133-4efd-9a48-8b12375e5961" />
   
   
   get after delete:
   
   <img width="1722" height="780" alt="image" src="https://github.com/user-attachments/assets/b9d4197c-7097-481c-8432-374e1fc5f399" />
   
   
   clear:
   
   <img width="1722" height="780" alt="image" src="https://github.com/user-attachments/assets/ee0c405b-7738-4fa0-a509-8297400c457f" />
   
   
   
   Once `clear()` is called, all rows for the asset state are deleted, since asset state isn't task-scoped, so the DB is empty.
   
   <img width="1722" height="780" alt="image" src="https://github.com/user-attachments/assets/27387a67-9213-4033-85fa-d9b609af3d1f" />
   
   
   
   ---
   
   ##### Was generative AI tooling used to co-author this PR?
   
   - [ ] Yes (please specify the tool below)
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
