This is an automated email from the ASF dual-hosted git repository.

amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7e22395c530 Add session param to BaseStateBackend interface to fix 
custom backends (#66708)
7e22395c530 is described below

commit 7e22395c5308fa7e8d39c3435f0c45147d43f7b0
Author: Amogh Desai <[email protected]>
AuthorDate: Thu May 14 10:33:30 2026 +0530

    Add session param to BaseStateBackend interface to fix custom backends 
(#66708)
---
 .../execution_api/routes/asset_state.py            | 16 +++---
 .../api_fastapi/execution_api/routes/task_state.py |  8 +--
 airflow-core/src/airflow/state/metastore.py        | 64 +++++++++++++++-------
 airflow-core/tests/unit/state/test_metastore.py    | 12 +++-
 shared/state/src/airflow_shared/state/__init__.py  | 57 +++++++++++++++----
 5 files changed, 113 insertions(+), 44 deletions(-)

diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py
index 3ff321def87..2351caa6dfa 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py
@@ -87,7 +87,7 @@ def get_asset_state_by_name(
 ) -> AssetStateResponse:
     """Get an asset state value by asset name."""
     asset_id = _resolve_asset_id_by_name(name, session)
-    value = get_state_backend().get(AssetScope(asset_id=asset_id), key, 
session=session)  # type: ignore[call-arg]  # @provide_session adds session 
kwarg at runtime; BaseStateBackend signature omits it so mypy can't see it
+    value = get_state_backend().get(AssetScope(asset_id=asset_id), key, 
session=session)
     if value is None:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
@@ -105,7 +105,7 @@ def set_asset_state_by_name(
 ) -> None:
     """Set an asset state value by asset name."""
     asset_id = _resolve_asset_id_by_name(name, session)
-    get_state_backend().set(AssetScope(asset_id=asset_id), key, body.value, 
session=session)  # type: ignore[call-arg]  # @provide_session adds session 
kwarg at runtime; BaseStateBackend signature omits it so mypy can't see it
+    get_state_backend().set(AssetScope(asset_id=asset_id), key, body.value, 
session=session)
 
 
 @router.delete("/by-name/value", status_code=status.HTTP_204_NO_CONTENT)
@@ -116,7 +116,7 @@ def delete_asset_state_by_name(
 ) -> None:
     """Delete a single asset state key by asset name."""
     asset_id = _resolve_asset_id_by_name(name, session)
-    get_state_backend().delete(AssetScope(asset_id=asset_id), key, 
session=session)  # type: ignore[call-arg]  # @provide_session adds session 
kwarg at runtime; BaseStateBackend signature omits it so mypy can't see it
+    get_state_backend().delete(AssetScope(asset_id=asset_id), key, 
session=session)
 
 
 @router.delete("/by-name/clear", status_code=status.HTTP_204_NO_CONTENT)
@@ -126,7 +126,7 @@ def clear_asset_state_by_name(
 ) -> None:
     """Delete all state keys for an asset by asset name."""
     asset_id = _resolve_asset_id_by_name(name, session)
-    get_state_backend().clear(AssetScope(asset_id=asset_id), session=session)  
# type: ignore[call-arg]  # @provide_session adds session kwarg at runtime; 
BaseStateBackend signature omits it so mypy can't see it
+    get_state_backend().clear(AssetScope(asset_id=asset_id), session=session)
 
 
 @router.get("/by-uri/value")
@@ -137,7 +137,7 @@ def get_asset_state_by_uri(
 ) -> AssetStateResponse:
     """Get an asset state value by asset URI."""
     asset_id = _resolve_asset_id_by_uri(uri, session)
-    value = get_state_backend().get(AssetScope(asset_id=asset_id), key, 
session=session)  # type: ignore[call-arg]  # @provide_session adds session 
kwarg at runtime; BaseStateBackend signature omits it so mypy can't see it
+    value = get_state_backend().get(AssetScope(asset_id=asset_id), key, 
session=session)
     if value is None:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
@@ -155,7 +155,7 @@ def set_asset_state_by_uri(
 ) -> None:
     """Set an asset state value by asset URI."""
     asset_id = _resolve_asset_id_by_uri(uri, session)
-    get_state_backend().set(AssetScope(asset_id=asset_id), key, body.value, 
session=session)  # type: ignore[call-arg]  # @provide_session adds session 
kwarg at runtime; BaseStateBackend signature omits it so mypy can't see it
+    get_state_backend().set(AssetScope(asset_id=asset_id), key, body.value, 
session=session)
 
 
 @router.delete("/by-uri/value", status_code=status.HTTP_204_NO_CONTENT)
@@ -166,7 +166,7 @@ def delete_asset_state_by_uri(
 ) -> None:
     """Delete a single asset state key by asset URI."""
     asset_id = _resolve_asset_id_by_uri(uri, session)
-    get_state_backend().delete(AssetScope(asset_id=asset_id), key, 
session=session)  # type: ignore[call-arg]  # @provide_session adds session 
kwarg at runtime; BaseStateBackend signature omits it so mypy can't see it
+    get_state_backend().delete(AssetScope(asset_id=asset_id), key, 
session=session)
 
 
 @router.delete("/by-uri/clear", status_code=status.HTTP_204_NO_CONTENT)
@@ -176,4 +176,4 @@ def clear_asset_state_by_uri(
 ) -> None:
     """Delete all state keys for an asset by asset URI."""
     asset_id = _resolve_asset_id_by_uri(uri, session)
-    get_state_backend().clear(AssetScope(asset_id=asset_id), session=session)  
# type: ignore[call-arg]  # @provide_session adds session kwarg at runtime; 
BaseStateBackend signature omits it so mypy can't see it
+    get_state_backend().clear(AssetScope(asset_id=asset_id), session=session)
diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py
index db24109969c..acdaa8c6a24 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py
@@ -65,7 +65,7 @@ def get_task_state(
 ) -> TaskStateResponse:
     """Get value for a task state."""
     scope = _get_task_scope_for_ti(task_instance_id, session)
-    value = get_state_backend().get(scope, key, session=session)  # type: 
ignore[call-arg]  # @provide_session adds session kwarg at runtime; 
BaseStateBackend signature omits it so mypy can't see it
+    value = get_state_backend().get(scope, key, session=session)
     if value is None:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
@@ -86,7 +86,7 @@ def set_task_state(
 ) -> None:
     """Set a task state key, creating or updating the row."""
     scope = _get_task_scope_for_ti(task_instance_id, session)
-    get_state_backend().set(scope, key, body.value, session=session)  # type: 
ignore[call-arg]  # @provide_session adds session kwarg at runtime; 
BaseStateBackend signature omits it so mypy can't see it
+    get_state_backend().set(scope, key, body.value, session=session)
 
 
 @router.delete("/{task_instance_id}/{key}", 
status_code=status.HTTP_204_NO_CONTENT)
@@ -97,7 +97,7 @@ def delete_task_state(
 ) -> None:
     """Delete a single task state key."""
     scope = _get_task_scope_for_ti(task_instance_id, session)
-    get_state_backend().delete(scope, key, session=session)  # type: 
ignore[call-arg]  # @provide_session adds session kwarg at runtime; 
BaseStateBackend signature omits it so mypy can't see it
+    get_state_backend().delete(scope, key, session=session)
 
 
 @router.delete("/{task_instance_id}", status_code=status.HTTP_204_NO_CONTENT)
@@ -125,4 +125,4 @@ def clear_task_state(
     accepted without error.
     """
     scope = _get_task_scope_for_ti(task_instance_id, session)
-    get_state_backend().clear(scope, all_map_indices=all_map_indices, 
session=session)  # type: ignore[call-arg]  # @provide_session adds session 
kwarg at runtime; BaseStateBackend signature omits it so mypy can't see it
+    get_state_backend().clear(scope, all_map_indices=all_map_indices, 
session=session)
diff --git a/airflow-core/src/airflow/state/metastore.py 
b/airflow-core/src/airflow/state/metastore.py
index 3382dad81fc..31b4de3158f 100644
--- a/airflow-core/src/airflow/state/metastore.py
+++ b/airflow-core/src/airflow/state/metastore.py
@@ -17,6 +17,8 @@
 # under the License.
 from __future__ import annotations
 
+from collections.abc import AsyncGenerator
+from contextlib import asynccontextmanager
 from typing import TYPE_CHECKING
 
 from sqlalchemy import delete, select
@@ -38,6 +40,16 @@ if TYPE_CHECKING:
     from sqlalchemy.orm import Session
 
 
+@asynccontextmanager
+async def _async_session(session: AsyncSession | None) -> 
AsyncGenerator[AsyncSession, None]:
+    """Use provided async session or create a new one."""
+    if session is not None:
+        yield session
+    else:
+        async with create_session_async() as s:
+            yield s
+
+
 def _build_upsert_stmt(
     dialect: str | None,
     model: type,
@@ -69,7 +81,9 @@ class MetastoreStateBackend(BaseStateBackend):
     """Default state backend for tasks and assets. Stores task and asset state 
in the Airflow metadata database."""
 
     @provide_session
-    def get(self, scope: StateScope, key: str, *, session: Session = 
NEW_SESSION) -> str | None:
+    def get(self, scope: StateScope, key: str, *, session: Session | None = 
NEW_SESSION) -> str | None:
+        if TYPE_CHECKING:
+            assert session is not None
         match scope:
             case TaskScope():
                 return self._get_task_state(scope, key, session=session)
@@ -79,7 +93,9 @@ class MetastoreStateBackend(BaseStateBackend):
                 assert_never(scope)
 
     @provide_session
-    def set(self, scope: StateScope, key: str, value: str, *, session: Session 
= NEW_SESSION) -> None:
+    def set(self, scope: StateScope, key: str, value: str, *, session: Session 
| None = NEW_SESSION) -> None:
+        if TYPE_CHECKING:
+            assert session is not None
         match scope:
             case TaskScope():
                 self._set_task_state(scope, key, value, session=session)
@@ -89,7 +105,9 @@ class MetastoreStateBackend(BaseStateBackend):
                 assert_never(scope)
 
     @provide_session
-    def delete(self, scope: StateScope, key: str, *, session: Session = 
NEW_SESSION) -> None:
+    def delete(self, scope: StateScope, key: str, *, session: Session | None = 
NEW_SESSION) -> None:
+        if TYPE_CHECKING:
+            assert session is not None
         match scope:
             case TaskScope():
                 self._delete_task_state(scope, key, session=session)
@@ -104,8 +122,10 @@ class MetastoreStateBackend(BaseStateBackend):
         scope: StateScope,
         *,
         all_map_indices: bool = False,
-        session: Session = NEW_SESSION,
+        session: Session | None = NEW_SESSION,
     ) -> None:
+        if TYPE_CHECKING:
+            assert session is not None
         match scope:
             case TaskScope():
                 self._clear_task_state(scope, all_map_indices=all_map_indices, 
session=session)
@@ -114,43 +134,47 @@ class MetastoreStateBackend(BaseStateBackend):
             case _:
                 assert_never(scope)
 
-    async def aget(self, scope: StateScope, key: str) -> str | None:
-        async with create_session_async() as session:
+    async def aget(self, scope: StateScope, key: str, *, session: AsyncSession 
| None = None) -> str | None:
+        async with _async_session(session) as s:
             match scope:
                 case TaskScope():
-                    return await self._aget_task_state(scope, key, 
session=session)
+                    return await self._aget_task_state(scope, key, session=s)
                 case AssetScope():
-                    return await self._aget_asset_state(scope, key, 
session=session)
+                    return await self._aget_asset_state(scope, key, session=s)
                 case _:
                     assert_never(scope)
 
-    async def aset(self, scope: StateScope, key: str, value: str) -> None:
-        async with create_session_async() as session:
+    async def aset(
+        self, scope: StateScope, key: str, value: str, *, session: 
AsyncSession | None = None
+    ) -> None:
+        async with _async_session(session) as s:
             match scope:
                 case TaskScope():
-                    await self._aset_task_state(scope, key, value, 
session=session)
+                    await self._aset_task_state(scope, key, value, session=s)
                 case AssetScope():
-                    await self._aset_asset_state(scope, key, value, 
session=session)
+                    await self._aset_asset_state(scope, key, value, session=s)
                 case _:
                     assert_never(scope)
 
-    async def adelete(self, scope: StateScope, key: str) -> None:
-        async with create_session_async() as session:
+    async def adelete(self, scope: StateScope, key: str, *, session: 
AsyncSession | None = None) -> None:
+        async with _async_session(session) as s:
             match scope:
                 case TaskScope():
-                    await self._adelete_task_state(scope, key, session=session)
+                    await self._adelete_task_state(scope, key, session=s)
                 case AssetScope():
-                    await self._adelete_asset_state(scope, key, 
session=session)
+                    await self._adelete_asset_state(scope, key, session=s)
                 case _:
                     assert_never(scope)
 
-    async def aclear(self, scope: StateScope, *, all_map_indices: bool = 
False) -> None:
-        async with create_session_async() as session:
+    async def aclear(
+        self, scope: StateScope, *, all_map_indices: bool = False, session: 
AsyncSession | None = None
+    ) -> None:
+        async with _async_session(session) as s:
             match scope:
                 case TaskScope():
-                    await self._aclear_task_state(scope, 
all_map_indices=all_map_indices, session=session)
+                    await self._aclear_task_state(scope, 
all_map_indices=all_map_indices, session=s)
                 case AssetScope():
-                    await self._aclear_asset_state(scope, session=session)
+                    await self._aclear_asset_state(scope, session=s)
                 case _:
                     assert_never(scope)
 
diff --git a/airflow-core/tests/unit/state/test_metastore.py 
b/airflow-core/tests/unit/state/test_metastore.py
index 98993d7133c..dfd154cc92a 100644
--- a/airflow-core/tests/unit/state/test_metastore.py
+++ b/airflow-core/tests/unit/state/test_metastore.py
@@ -28,7 +28,7 @@ from airflow.models.dagrun import DagRun, DagRunType
 from airflow.models.task_state import TaskStateModel
 from airflow.state import AssetScope, TaskScope, resolve_state_backend
 from airflow.state.metastore import MetastoreStateBackend
-from airflow.utils.session import create_session
+from airflow.utils.session import create_session, create_session_async
 
 from tests_common.test_utils.config import conf_vars
 from tests_common.test_utils.db import clear_db_assets, clear_db_dags, 
clear_db_runs
@@ -379,6 +379,16 @@ class TestMetastoreStateBackendAsync:
         with pytest.raises(ValueError, match="No DagRun found"):
             await backend.aset(scope, "job_id", "app_async")
 
+    async def test_aset_and_aget_with_provided_session(
+        self, backend: MetastoreStateBackend, dag_run_committed: DagRun
+    ):
+        """async methods use a provided AsyncSession when one is given."""
+        scope = TaskScope(dag_id=DAG_ID, run_id=RUN_ID, task_id=TASK_ID)
+        async with create_session_async() as session:
+            await backend.aset(scope, "job_id", "app_with_session", 
session=session)
+            result = await backend.aget(scope, "job_id", session=session)
+        assert result == "app_with_session"
+
 
 class TestResolveStateBackend:
     @conf_vars({("state_store", "backend"): 
"airflow.state.metastore.MetastoreStateBackend"})
diff --git a/shared/state/src/airflow_shared/state/__init__.py 
b/shared/state/src/airflow_shared/state/__init__.py
index 463d9f378f3..4920f66ae67 100644
--- a/shared/state/src/airflow_shared/state/__init__.py
+++ b/shared/state/src/airflow_shared/state/__init__.py
@@ -18,6 +18,11 @@ from __future__ import annotations
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from sqlalchemy.ext.asyncio import AsyncSession
+    from sqlalchemy.orm import Session
 
 
 @dataclass(frozen=True)
@@ -62,10 +67,16 @@ class BaseStateBackend(ABC):
                 ...  # asset-specific storage
 
     Custom backends are configured via ``[state_store] backend`` in 
``airflow.cfg``.
+
+    **The ``session`` parameter on ``get``, ``set``, ``delete``, and 
``clear``:**
+
+    The default ``MetastoreStateBackend`` passes a SQLAlchemy ``Session`` 
through
+    these methods. Custom backends that do not use SQLAlchemy should accept 
``session`` as a
+    keyword argument and ignore it.
     """
 
     @abstractmethod
-    def get(self, scope: StateScope, key: str) -> str | None:
+    def get(self, scope: StateScope, key: str, *, session: Session | None = 
None) -> str | None:
         """
         Return the stored value, or None if the key does not exist.
 
@@ -73,7 +84,7 @@ class BaseStateBackend(ABC):
         """
 
     @abstractmethod
-    def set(self, scope: StateScope, key: str, value: str) -> None:
+    def set(self, scope: StateScope, key: str, value: str, *, session: Session 
| None = None) -> None:
         """
         Write or overwrite the value for the given key.
 
@@ -81,7 +92,7 @@ class BaseStateBackend(ABC):
         """
 
     @abstractmethod
-    def delete(self, scope: StateScope, key: str) -> None:
+    def delete(self, scope: StateScope, key: str, *, session: Session | None = 
None) -> None:
         """
         Delete a single key. No-op if the key does not exist.
 
@@ -89,7 +100,9 @@ class BaseStateBackend(ABC):
         """
 
     @abstractmethod
-    def clear(self, scope: StateScope, *, all_map_indices: bool = False) -> 
None:
+    def clear(
+        self, scope: StateScope, *, all_map_indices: bool = False, session: 
Session | None = None
+    ) -> None:
         """
         Delete all keys under the given scope.
 
@@ -102,23 +115,45 @@ class BaseStateBackend(ABC):
         """
 
     @abstractmethod
-    async def aget(self, scope: StateScope, key: str) -> str | None:
-        """Async variant of get. Must handle both ``TaskScope`` and 
``AssetScope``."""
+    async def aget(self, scope: StateScope, key: str, *, session: AsyncSession 
| None = None) -> str | None:
+        """
+        Async variant of get. Must handle both ``TaskScope`` and 
``AssetScope``.
+
+        ``session`` is optional. If provided, implementations should use it 
directly.
+        If ``None``, implementations manage their own async session internally.
+        """
 
     @abstractmethod
-    async def aset(self, scope: StateScope, key: str, value: str) -> None:
-        """Async variant of set. Must handle both ``TaskScope`` and 
``AssetScope``."""
+    async def aset(
+        self, scope: StateScope, key: str, value: str, *, session: 
AsyncSession | None = None
+    ) -> None:
+        """
+        Async variant of set. Must handle both ``TaskScope`` and 
``AssetScope``.
+
+        ``session`` is optional. If provided, implementations should use it 
directly.
+        If ``None``, implementations manage their own async session internally.
+        """
 
     @abstractmethod
-    async def adelete(self, scope: StateScope, key: str) -> None:
-        """Async variant of delete. Must handle both ``TaskScope`` and 
``AssetScope``."""
+    async def adelete(self, scope: StateScope, key: str, *, session: 
AsyncSession | None = None) -> None:
+        """
+        Async variant of delete. Must handle both ``TaskScope`` and 
``AssetScope``.
+
+        ``session`` is optional. If provided, implementations should use it 
directly.
+        If ``None``, implementations manage their own async session internally.
+        """
 
     @abstractmethod
-    async def aclear(self, scope: StateScope, *, all_map_indices: bool = 
False) -> None:
+    async def aclear(
+        self, scope: StateScope, *, all_map_indices: bool = False, session: 
AsyncSession | None = None
+    ) -> None:
         """
         Async variant of clear. Must handle both ``TaskScope`` and 
``AssetScope``.
 
         For ``TaskScope``: by default, only keys for the exact ``map_index`` 
on the
         scope are cleared. Pass ``all_map_indices=True`` to wipe state across 
every
         mapped instance of the task. For ``AssetScope`` the flag has no effect.
+
+        ``session`` is optional. If provided, implementations should use it 
directly.
+        If ``None``, implementations manage their own async session internally.
         """

Reply via email to