amoghrajesh commented on PR #67292:
URL: https://github.com/apache/airflow/pull/67292#issuecomment-4645633357

   > Thanks for the awesome work, @bbovenzi!
   > 
   > Looking really nice! I left some comments on the PRs for issues I found after testing all of this. You already have my DAGs, but the one for mapped tasks is here:
   > 
   > ```python
   > from __future__ import annotations
   > 
   > import json
   > import random
   > from datetime import datetime, timezone
   > 
   > from airflow.sdk import DAG, task
   > 
   > TABLES = ["orders", "customers", "products"]
   > 
   > with DAG(
   >     dag_id="example_task_state_mapped",
   >     schedule=None,
   >     start_date=datetime(2026, 1, 1),
   >     catchup=False,
   >     tags=["example", "aip-103", "task-state", "mapped"],
   >     doc_md=__doc__,
   > ) as dag:
   > 
   >     @task
   >     def get_tables() -> list[str]:
   >         """Return the list of tables to process."""
   >         return TABLES
   > 
   >     @task
   >     def process_table(table: str, **context) -> dict:
   >         """Process one table — each mapped instance gets its own task state."""
   >         ts = context["task_state"]
   >         map_index = context["task_instance"].map_index
   > 
   >         row_count = random.randint(100, 10000)
   >         result = {
   >             "table": table,
   >             "map_index": map_index,
   >             "row_count": row_count,
   >             "processed_at": datetime.now(tz=timezone.utc).isoformat(timespec="seconds"),
   >         }
   > 
   >         ts.set("table", table)
   >         ts.set("status", "complete")
   >         ts.set("row_count", str(row_count))
   >         ts.set("result", json.dumps(result))
   > 
   >         print(f"[map_index={map_index}] Processed {table}: {row_count} rows")
   >         return result
   > 
   >     tables = get_tables()
   >     process_table.expand(table=tables)
   > ```
   > 
   > ## Task State — Spark DAG
   > * [x]  All keys visible after a completed run (`job_id`, `submitted_at`, `status`, `poll_result`, `completed_at`)
   > * [x]  `poll_result` JSON is pretty-printed
   > * [x]  `job_id` shows **Never** in Expires At, other keys show a date
   > * [x]  After retry-reattach: same `job_id` persists, `status` updates to `complete`
   > * [x]  Delete a single key — row gone, others intact
   > * [x]  Edit a key — new value shows immediately
   > * [x]  Clear all — table goes empty
   > 
   > ## Asset State — Watermark DAG
   > * [x]  First run: `watermark`, `total_runs=1`, `last_run_summary` appear on asset detail page
   > * [x]  Subsequent runs: `total_runs` increments, `watermark` advances, `prev_watermark` matches previous run
   > * [x]  Consumer DAG fires automatically after each producer run
   > * [x]  Clear asset state then re-trigger: `total_runs=1`, `prev_watermark=null`
   > 
   > ## Mapped Tasks — Mapped DAG (`example_task_state_mapped`)
   > * [x]  Trigger DAG: 3 mapped instances run (map_index 0, 1, 2 for orders/customers/products)
   > * [x]  Each mapped TI shows its own `table`, `row_count`, `result` in the Storage tab, with no bleed between instances
   > * [x]  Switching between map_index 0/1/2 in the UI shows different state values
   > * [x]  Clear single instance (`map_index=0`): only that instance's state is gone, others intact
   > * [x]  Clear all (`all_map_indices=true`): state wiped across all 3 instances
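
The asset-state checks in the quoted checklist reduce to a small contract: each producer run increments `total_runs`, stores the previous `watermark` as `prev_watermark`, and clearing the asset state resets both. The sketch below models only that contract in memory; `ToyAssetState` and `run_producer` are hypothetical names, and this is neither the example DAG's code nor the Airflow asset-state API.

```python
from __future__ import annotations

import json


class ToyAssetState:
    """Hypothetical in-memory stand-in for one asset's key/value state (not the Airflow API)."""

    def __init__(self) -> None:
        self._data: dict[str, object] = {}

    def get(self, key: str, default: object = None) -> object:
        return self._data.get(key, default)

    def set(self, key: str, value: object) -> None:
        self._data[key] = value

    def clear(self) -> None:
        # Models the "clear asset state" action: every key is dropped.
        self._data.clear()


def run_producer(state: ToyAssetState, new_watermark: str) -> dict:
    """Simulate one producer run that advances the watermark."""
    prev = state.get("watermark")  # None on the first run or right after a clear
    total_runs = int(state.get("total_runs", 0)) + 1
    state.set("prev_watermark", prev)
    state.set("watermark", new_watermark)
    state.set("total_runs", total_runs)
    summary = {"total_runs": total_runs, "prev_watermark": prev, "watermark": new_watermark}
    state.set("last_run_summary", json.dumps(summary))
    return summary
```

Calling `run_producer` twice and then once more after `clear()` should reproduce the three checklist rows: `total_runs` goes 1, 2, then back to 1, and `prev_watermark` is `None` (serialized as `null`) after the clear.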
   
   I retested similar scenarios, this time with the task store and asset store example DAGs committed to the repo.
   
   1. The `Task State — Spark DAG` section still looks great, and I love the validation here.
   
   <img width="865" height="903" alt="image" src="https://github.com/user-attachments/assets/4af8c142-a3cf-4e12-a007-c80e7a3d886e" />
   
   2. The same goes for the `Asset State — Watermark DAG` section: it looks great, and the UI validation is great too.
   
   <img width="989" height="818" alt="image" src="https://github.com/user-attachments/assets/b6fa1c78-7af3-4c68-82dd-161512542b09" />
   
   
   3. Using this DAG for mapped tasks:
   
   ```python
   from __future__ import annotations
   
   import random
   from datetime import datetime, timezone
   
   from airflow.sdk import DAG, task
   
   TABLES = ["orders", "customers", "products"]
   
   with DAG(
       dag_id="example_task_store_mapped",
       schedule=None,
       start_date=datetime(2026, 1, 1),
       catchup=False,
       tags=["example", "aip-103", "task-state", "mapped"],
       doc_md=__doc__,
   ) as dag:
   
       @task
       def get_tables() -> list[str]:
           """Return the list of tables to process."""
           return TABLES
   
       @task
       def process_table(table: str, **context) -> dict:
           """Process one table — each mapped instance gets its own task state."""
           ts = context["task_store"]
           map_index = context["task_instance"].map_index
   
           row_count = random.randint(100, 10000)
           result = {
               "table": table,
               "map_index": map_index,
               "row_count": row_count,
               "processed_at": datetime.now(tz=timezone.utc).isoformat(timespec="seconds"),
           }
   
           ts.set("table", table)
           ts.set("status", "complete")
           ts.set("row_count", row_count)
           ts.set("result", result)
   
           print(f"[map_index={map_index}] Processed {table}: {row_count} rows")
           return result
   
       tables = get_tables()
       process_table.expand(table=tables)
   ```
   
   This looks great too: I tried get, set, clear, clear all, etc.
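
The get/set/clear checks on the mapped DAG depend on state being keyed per map index, with `all_map_indices` widening a clear to every mapped instance of the task. Here is a minimal in-memory model of that behaviour, assuming a scope made of (dag_id, run_id, task_id, map_index); `ToyScope` and `ToyTaskStore` are hypothetical names, and the real `StoreScope` may be shaped differently.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class ToyScope:
    """Hypothetical task scope; the real StoreScope may carry different fields."""

    dag_id: str
    run_id: str
    task_id: str
    map_index: int = -1


class ToyTaskStore:
    """In-memory model of per-task-instance state, keyed by (scope, key)."""

    def __init__(self) -> None:
        self._rows: dict[tuple[ToyScope, str], object] = {}

    def set(self, scope: ToyScope, key: str, value: object) -> None:
        self._rows[(scope, key)] = value

    def get(self, scope: ToyScope, key: str) -> object:
        return self._rows.get((scope, key))

    def clear(self, scope: ToyScope, *, all_map_indices: bool = False) -> None:
        def matches(other: ToyScope) -> bool:
            if all_map_indices:
                # Ignore map_index: wipe every mapped instance of this task in this run.
                return (other.dag_id, other.run_id, other.task_id) == (
                    scope.dag_id,
                    scope.run_id,
                    scope.task_id,
                )
            # Exact match: only the one mapped instance.
            return other == scope

        for row_key in [k for k in self._rows if matches(k[0])]:
            del self._rows[row_key]
```

Because the map index is part of the key, a plain clear on `map_index=0` leaves indices 1 and 2 untouched, which is the "no bleed between instances" behaviour checked above.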
   
   
   Next, for custom backends, I used this backend:
   
   ```python
   # Licensed to the Apache Software Foundation (ASF) under one
   # or more contributor license agreements.  See the NOTICE file
   # distributed with this work for additional information
   # regarding copyright ownership.  The ASF licenses this file
   # to you under the Apache License, Version 2.0 (the
   # "License"); you may not use this file except in compliance
   # with the License.  You may obtain a copy of the License at
   #
   #   http://www.apache.org/licenses/LICENSE-2.0
   #
   # Unless required by applicable law or agreed to in writing,
   # software distributed under the License is distributed on an
   # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
   # KIND, either express or implied.  See the License for the
   # specific language governing permissions and limitations
   # under the License.
   """
   File-based state backend for testing the ExternalState envelope.
   
   Workers write state values as JSON files under /tmp/airflow_state/ and store
   the file path as the external reference.  The DB therefore holds:
   
       {"__type": "ExternalState", "__var": "/tmp/airflow_state/ti_<id>/job_id.json"}
   
   instead of the raw value, which lets you verify the envelope behaviour end-to-end.
   
   Configure in airflow.cfg (or via env var) before starting a worker:
   
       [workers]
       state_backend = dev.file_state_backend.FileStateBackend
   
   The server-side abstract methods (get/set/delete/clear and their async variants)
   raise NotImplementedError; this backend is purely a worker-side serialization hook.
   """
   
   from __future__ import annotations
   
   import json
   from pathlib import Path
   from typing import TYPE_CHECKING
   
   from airflow.sdk.state import BaseStoreBackend
   
   if TYPE_CHECKING:
       from datetime import datetime
   
       from pydantic import JsonValue
       from sqlalchemy.ext.asyncio import AsyncSession
       from sqlalchemy.orm import Session
   
       from airflow_shared.state import StoreScope
   
   BASE_DIR = Path("/tmp/airflow_state")
   
   
   class FileStateBackend(BaseStoreBackend):
       """Stores task/asset state values as local JSON files; returns the path as the ref."""

       def serialize_task_store_to_ref(self, *, value: JsonValue, key: str, ti_id: str) -> str:
           path = BASE_DIR / f"ti_{ti_id}" / f"{key}.json"
           path.parent.mkdir(parents=True, exist_ok=True)
           path.write_text(json.dumps(value))
           return str(path)
   
       def deserialize_task_store_from_ref(self, stored: str) -> JsonValue:
           return json.loads(Path(stored).read_text())
   
       def serialize_asset_state_to_ref(self, *, value: JsonValue, key: str, asset_ref: str) -> str:
           safe = asset_ref.replace("/", "_").replace(":", "")
           path = BASE_DIR / "assets" / safe / f"{key}.json"
           path.parent.mkdir(parents=True, exist_ok=True)
           path.write_text(json.dumps(value))
           return str(path)
   
       def deserialize_asset_state_from_ref(self, stored: str) -> JsonValue:
           return json.loads(Path(stored).read_text())

       def get(self, scope: StoreScope, key: str, *, session: Session | None = None) -> str | None:
           raise NotImplementedError(
               "FileStateBackend is a worker-side backend; server uses MetastoreStateBackend"
           )
   
       def set(
           self,
           scope: StoreScope,
           key: str,
           value: str,
           *,
           expires_at: datetime | None = None,
           session: Session | None = None,
       ) -> None:
           raise NotImplementedError
   
       def delete(self, scope: StoreScope, key: str, *, session: Session | None = None) -> None:
           raise NotImplementedError
   
       def clear(
           self, scope: StoreScope, *, all_map_indices: bool = False, session: Session | None = None
       ) -> None:
           raise NotImplementedError
   
       async def aget(self, scope: StoreScope, key: str, *, session: AsyncSession | None = None) -> str | None:
           raise NotImplementedError
   
       async def aset(
           self,
           scope: StoreScope,
           key: str,
           value: str,
           *,
           expires_at: datetime | None = None,
           session: AsyncSession | None = None,
       ) -> None:
           raise NotImplementedError
   
       async def adelete(self, scope: StoreScope, key: str, *, session: AsyncSession | None = None) -> None:
           raise NotImplementedError
   
       async def aclear(
           self, scope: StoreScope, *, all_map_indices: bool = False, session: AsyncSession | None = None
       ) -> None:
           raise NotImplementedError
   
   ```
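
To exercise the envelope round trip without a running worker, here is a standalone sketch that reuses the backend's file logic with an injectable base directory. The `wrap_ref`/`unwrap_ref` helpers are hypothetical stand-ins for whatever the framework does around the backend; only the envelope shape is taken from the `FileStateBackend` docstring.

```python
from __future__ import annotations

import json
import tempfile
from pathlib import Path

# Assumed envelope tag, copied from the FileStateBackend docstring.
ENVELOPE_TYPE = "ExternalState"


def serialize_task_store_to_ref(base_dir: Path, *, value: object, key: str, ti_id: str) -> str:
    """Same logic as FileStateBackend.serialize_task_store_to_ref, with base_dir injectable."""
    path = base_dir / f"ti_{ti_id}" / f"{key}.json"
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(value))
    return str(path)


def deserialize_task_store_from_ref(stored: str) -> object:
    return json.loads(Path(stored).read_text())


def wrap_ref(ref: str) -> dict[str, str]:
    """Hypothetical helper: what the DB row is expected to hold instead of the raw value."""
    return {"__type": ENVELOPE_TYPE, "__var": ref}


def unwrap_ref(envelope: dict[str, str]) -> str:
    """Hypothetical helper: recover the ref before handing it back to the backend."""
    if envelope.get("__type") != ENVELOPE_TYPE:
        raise ValueError(f"not an {ENVELOPE_TYPE} envelope: {envelope!r}")
    return envelope["__var"]


def round_trip(value: object, *, key: str, ti_id: str) -> tuple[dict[str, str], object]:
    """Write the value, wrap the ref, then read the value back through the envelope."""
    with tempfile.TemporaryDirectory() as tmp:
        ref = serialize_task_store_to_ref(Path(tmp), value=value, key=key, ti_id=ti_id)
        envelope = wrap_ref(ref)
        return envelope, deserialize_task_store_from_ref(unwrap_ref(envelope))
```

Using a temporary directory keeps the check hermetic. Since the refs are local file paths, a backend like this only works when the reader shares the writer's filesystem, which is fine for a single-machine test like this one.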
   
   I also checked whether the custom ref envelope shows up, and it looks fine:
   <img width="1345" height="831" alt="image" src="https://github.com/user-attachments/assets/df3b64ad-5bef-4f3d-8ec3-c1f1e920be54" />
   

