This is an automated email from the ASF dual-hosted git repository.
amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7e22395c530 Add session param to BaseStateBackend interface to fix
custom backends (#66708)
7e22395c530 is described below
commit 7e22395c5308fa7e8d39c3435f0c45147d43f7b0
Author: Amogh Desai <[email protected]>
AuthorDate: Thu May 14 10:33:30 2026 +0530
Add session param to BaseStateBackend interface to fix custom backends
(#66708)
---
.../execution_api/routes/asset_state.py | 16 +++---
.../api_fastapi/execution_api/routes/task_state.py | 8 +--
airflow-core/src/airflow/state/metastore.py | 64 +++++++++++++++-------
airflow-core/tests/unit/state/test_metastore.py | 12 +++-
shared/state/src/airflow_shared/state/__init__.py | 57 +++++++++++++++----
5 files changed, 113 insertions(+), 44 deletions(-)
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py
index 3ff321def87..2351caa6dfa 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py
@@ -87,7 +87,7 @@ def get_asset_state_by_name(
) -> AssetStateResponse:
"""Get an asset state value by asset name."""
asset_id = _resolve_asset_id_by_name(name, session)
- value = get_state_backend().get(AssetScope(asset_id=asset_id), key,
session=session) # type: ignore[call-arg] # @provide_session adds session
kwarg at runtime; BaseStateBackend signature omits it so mypy can't see it
+ value = get_state_backend().get(AssetScope(asset_id=asset_id), key,
session=session)
if value is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@@ -105,7 +105,7 @@ def set_asset_state_by_name(
) -> None:
"""Set an asset state value by asset name."""
asset_id = _resolve_asset_id_by_name(name, session)
- get_state_backend().set(AssetScope(asset_id=asset_id), key, body.value,
session=session) # type: ignore[call-arg] # @provide_session adds session
kwarg at runtime; BaseStateBackend signature omits it so mypy can't see it
+ get_state_backend().set(AssetScope(asset_id=asset_id), key, body.value,
session=session)
@router.delete("/by-name/value", status_code=status.HTTP_204_NO_CONTENT)
@@ -116,7 +116,7 @@ def delete_asset_state_by_name(
) -> None:
"""Delete a single asset state key by asset name."""
asset_id = _resolve_asset_id_by_name(name, session)
- get_state_backend().delete(AssetScope(asset_id=asset_id), key,
session=session) # type: ignore[call-arg] # @provide_session adds session
kwarg at runtime; BaseStateBackend signature omits it so mypy can't see it
+ get_state_backend().delete(AssetScope(asset_id=asset_id), key,
session=session)
@router.delete("/by-name/clear", status_code=status.HTTP_204_NO_CONTENT)
@@ -126,7 +126,7 @@ def clear_asset_state_by_name(
) -> None:
"""Delete all state keys for an asset by asset name."""
asset_id = _resolve_asset_id_by_name(name, session)
- get_state_backend().clear(AssetScope(asset_id=asset_id), session=session)
# type: ignore[call-arg] # @provide_session adds session kwarg at runtime;
BaseStateBackend signature omits it so mypy can't see it
+ get_state_backend().clear(AssetScope(asset_id=asset_id), session=session)
@router.get("/by-uri/value")
@@ -137,7 +137,7 @@ def get_asset_state_by_uri(
) -> AssetStateResponse:
"""Get an asset state value by asset URI."""
asset_id = _resolve_asset_id_by_uri(uri, session)
- value = get_state_backend().get(AssetScope(asset_id=asset_id), key,
session=session) # type: ignore[call-arg] # @provide_session adds session
kwarg at runtime; BaseStateBackend signature omits it so mypy can't see it
+ value = get_state_backend().get(AssetScope(asset_id=asset_id), key,
session=session)
if value is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@@ -155,7 +155,7 @@ def set_asset_state_by_uri(
) -> None:
"""Set an asset state value by asset URI."""
asset_id = _resolve_asset_id_by_uri(uri, session)
- get_state_backend().set(AssetScope(asset_id=asset_id), key, body.value,
session=session) # type: ignore[call-arg] # @provide_session adds session
kwarg at runtime; BaseStateBackend signature omits it so mypy can't see it
+ get_state_backend().set(AssetScope(asset_id=asset_id), key, body.value,
session=session)
@router.delete("/by-uri/value", status_code=status.HTTP_204_NO_CONTENT)
@@ -166,7 +166,7 @@ def delete_asset_state_by_uri(
) -> None:
"""Delete a single asset state key by asset URI."""
asset_id = _resolve_asset_id_by_uri(uri, session)
- get_state_backend().delete(AssetScope(asset_id=asset_id), key,
session=session) # type: ignore[call-arg] # @provide_session adds session
kwarg at runtime; BaseStateBackend signature omits it so mypy can't see it
+ get_state_backend().delete(AssetScope(asset_id=asset_id), key,
session=session)
@router.delete("/by-uri/clear", status_code=status.HTTP_204_NO_CONTENT)
@@ -176,4 +176,4 @@ def clear_asset_state_by_uri(
) -> None:
"""Delete all state keys for an asset by asset URI."""
asset_id = _resolve_asset_id_by_uri(uri, session)
- get_state_backend().clear(AssetScope(asset_id=asset_id), session=session)
# type: ignore[call-arg] # @provide_session adds session kwarg at runtime;
BaseStateBackend signature omits it so mypy can't see it
+ get_state_backend().clear(AssetScope(asset_id=asset_id), session=session)
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py
index db24109969c..acdaa8c6a24 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py
@@ -65,7 +65,7 @@ def get_task_state(
) -> TaskStateResponse:
"""Get value for a task state."""
scope = _get_task_scope_for_ti(task_instance_id, session)
- value = get_state_backend().get(scope, key, session=session) # type:
ignore[call-arg] # @provide_session adds session kwarg at runtime;
BaseStateBackend signature omits it so mypy can't see it
+ value = get_state_backend().get(scope, key, session=session)
if value is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@@ -86,7 +86,7 @@ def set_task_state(
) -> None:
"""Set a task state key, creating or updating the row."""
scope = _get_task_scope_for_ti(task_instance_id, session)
- get_state_backend().set(scope, key, body.value, session=session) # type:
ignore[call-arg] # @provide_session adds session kwarg at runtime;
BaseStateBackend signature omits it so mypy can't see it
+ get_state_backend().set(scope, key, body.value, session=session)
@router.delete("/{task_instance_id}/{key}",
status_code=status.HTTP_204_NO_CONTENT)
@@ -97,7 +97,7 @@ def delete_task_state(
) -> None:
"""Delete a single task state key."""
scope = _get_task_scope_for_ti(task_instance_id, session)
- get_state_backend().delete(scope, key, session=session) # type:
ignore[call-arg] # @provide_session adds session kwarg at runtime;
BaseStateBackend signature omits it so mypy can't see it
+ get_state_backend().delete(scope, key, session=session)
@router.delete("/{task_instance_id}", status_code=status.HTTP_204_NO_CONTENT)
@@ -125,4 +125,4 @@ def clear_task_state(
accepted without error.
"""
scope = _get_task_scope_for_ti(task_instance_id, session)
- get_state_backend().clear(scope, all_map_indices=all_map_indices,
session=session) # type: ignore[call-arg] # @provide_session adds session
kwarg at runtime; BaseStateBackend signature omits it so mypy can't see it
+ get_state_backend().clear(scope, all_map_indices=all_map_indices,
session=session)
diff --git a/airflow-core/src/airflow/state/metastore.py
b/airflow-core/src/airflow/state/metastore.py
index 3382dad81fc..31b4de3158f 100644
--- a/airflow-core/src/airflow/state/metastore.py
+++ b/airflow-core/src/airflow/state/metastore.py
@@ -17,6 +17,8 @@
# under the License.
from __future__ import annotations
+from collections.abc import AsyncGenerator
+from contextlib import asynccontextmanager
from typing import TYPE_CHECKING
from sqlalchemy import delete, select
@@ -38,6 +40,16 @@ if TYPE_CHECKING:
from sqlalchemy.orm import Session
+@asynccontextmanager
+async def _async_session(session: AsyncSession | None) ->
AsyncGenerator[AsyncSession, None]:
+ """Use provided async session or create a new one."""
+ if session is not None:
+ yield session
+ else:
+ async with create_session_async() as s:
+ yield s
+
+
def _build_upsert_stmt(
dialect: str | None,
model: type,
@@ -69,7 +81,9 @@ class MetastoreStateBackend(BaseStateBackend):
"""Default state backend for tasks and assets. Stores task and asset state
in the Airflow metadata database."""
@provide_session
- def get(self, scope: StateScope, key: str, *, session: Session =
NEW_SESSION) -> str | None:
+ def get(self, scope: StateScope, key: str, *, session: Session | None =
NEW_SESSION) -> str | None:
+ if TYPE_CHECKING:
+ assert session is not None
match scope:
case TaskScope():
return self._get_task_state(scope, key, session=session)
@@ -79,7 +93,9 @@ class MetastoreStateBackend(BaseStateBackend):
assert_never(scope)
@provide_session
- def set(self, scope: StateScope, key: str, value: str, *, session: Session
= NEW_SESSION) -> None:
+ def set(self, scope: StateScope, key: str, value: str, *, session: Session
| None = NEW_SESSION) -> None:
+ if TYPE_CHECKING:
+ assert session is not None
match scope:
case TaskScope():
self._set_task_state(scope, key, value, session=session)
@@ -89,7 +105,9 @@ class MetastoreStateBackend(BaseStateBackend):
assert_never(scope)
@provide_session
- def delete(self, scope: StateScope, key: str, *, session: Session =
NEW_SESSION) -> None:
+ def delete(self, scope: StateScope, key: str, *, session: Session | None =
NEW_SESSION) -> None:
+ if TYPE_CHECKING:
+ assert session is not None
match scope:
case TaskScope():
self._delete_task_state(scope, key, session=session)
@@ -104,8 +122,10 @@ class MetastoreStateBackend(BaseStateBackend):
scope: StateScope,
*,
all_map_indices: bool = False,
- session: Session = NEW_SESSION,
+ session: Session | None = NEW_SESSION,
) -> None:
+ if TYPE_CHECKING:
+ assert session is not None
match scope:
case TaskScope():
self._clear_task_state(scope, all_map_indices=all_map_indices,
session=session)
@@ -114,43 +134,47 @@ class MetastoreStateBackend(BaseStateBackend):
case _:
assert_never(scope)
- async def aget(self, scope: StateScope, key: str) -> str | None:
- async with create_session_async() as session:
+ async def aget(self, scope: StateScope, key: str, *, session: AsyncSession
| None = None) -> str | None:
+ async with _async_session(session) as s:
match scope:
case TaskScope():
- return await self._aget_task_state(scope, key,
session=session)
+ return await self._aget_task_state(scope, key, session=s)
case AssetScope():
- return await self._aget_asset_state(scope, key,
session=session)
+ return await self._aget_asset_state(scope, key, session=s)
case _:
assert_never(scope)
- async def aset(self, scope: StateScope, key: str, value: str) -> None:
- async with create_session_async() as session:
+ async def aset(
+ self, scope: StateScope, key: str, value: str, *, session:
AsyncSession | None = None
+ ) -> None:
+ async with _async_session(session) as s:
match scope:
case TaskScope():
- await self._aset_task_state(scope, key, value,
session=session)
+ await self._aset_task_state(scope, key, value, session=s)
case AssetScope():
- await self._aset_asset_state(scope, key, value,
session=session)
+ await self._aset_asset_state(scope, key, value, session=s)
case _:
assert_never(scope)
- async def adelete(self, scope: StateScope, key: str) -> None:
- async with create_session_async() as session:
+ async def adelete(self, scope: StateScope, key: str, *, session:
AsyncSession | None = None) -> None:
+ async with _async_session(session) as s:
match scope:
case TaskScope():
- await self._adelete_task_state(scope, key, session=session)
+ await self._adelete_task_state(scope, key, session=s)
case AssetScope():
- await self._adelete_asset_state(scope, key,
session=session)
+ await self._adelete_asset_state(scope, key, session=s)
case _:
assert_never(scope)
- async def aclear(self, scope: StateScope, *, all_map_indices: bool =
False) -> None:
- async with create_session_async() as session:
+ async def aclear(
+ self, scope: StateScope, *, all_map_indices: bool = False, session:
AsyncSession | None = None
+ ) -> None:
+ async with _async_session(session) as s:
match scope:
case TaskScope():
- await self._aclear_task_state(scope,
all_map_indices=all_map_indices, session=session)
+ await self._aclear_task_state(scope,
all_map_indices=all_map_indices, session=s)
case AssetScope():
- await self._aclear_asset_state(scope, session=session)
+ await self._aclear_asset_state(scope, session=s)
case _:
assert_never(scope)
diff --git a/airflow-core/tests/unit/state/test_metastore.py
b/airflow-core/tests/unit/state/test_metastore.py
index 98993d7133c..dfd154cc92a 100644
--- a/airflow-core/tests/unit/state/test_metastore.py
+++ b/airflow-core/tests/unit/state/test_metastore.py
@@ -28,7 +28,7 @@ from airflow.models.dagrun import DagRun, DagRunType
from airflow.models.task_state import TaskStateModel
from airflow.state import AssetScope, TaskScope, resolve_state_backend
from airflow.state.metastore import MetastoreStateBackend
-from airflow.utils.session import create_session
+from airflow.utils.session import create_session, create_session_async
from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.db import clear_db_assets, clear_db_dags,
clear_db_runs
@@ -379,6 +379,16 @@ class TestMetastoreStateBackendAsync:
with pytest.raises(ValueError, match="No DagRun found"):
await backend.aset(scope, "job_id", "app_async")
+ async def test_aset_and_aget_with_provided_session(
+ self, backend: MetastoreStateBackend, dag_run_committed: DagRun
+ ):
+ """async methods use a provided AsyncSession when one is given."""
+ scope = TaskScope(dag_id=DAG_ID, run_id=RUN_ID, task_id=TASK_ID)
+ async with create_session_async() as session:
+ await backend.aset(scope, "job_id", "app_with_session",
session=session)
+ result = await backend.aget(scope, "job_id", session=session)
+ assert result == "app_with_session"
+
class TestResolveStateBackend:
@conf_vars({("state_store", "backend"):
"airflow.state.metastore.MetastoreStateBackend"})
diff --git a/shared/state/src/airflow_shared/state/__init__.py
b/shared/state/src/airflow_shared/state/__init__.py
index 463d9f378f3..4920f66ae67 100644
--- a/shared/state/src/airflow_shared/state/__init__.py
+++ b/shared/state/src/airflow_shared/state/__init__.py
@@ -18,6 +18,11 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from sqlalchemy.ext.asyncio import AsyncSession
+ from sqlalchemy.orm import Session
@dataclass(frozen=True)
@@ -62,10 +67,16 @@ class BaseStateBackend(ABC):
... # asset-specific storage
Custom backends are configured via ``[state_store] backend`` in
``airflow.cfg``.
+
+ **The ``session`` parameter on ``get``, ``set``, ``delete``, and
``clear``:**
+
+ Airflow's execution API always calls these methods with its SQLAlchemy ``Session`` as ``session=``,
+ and the default ``MetastoreStateBackend`` uses it. Custom backends that do not use SQLAlchemy
+ must still accept ``session`` as a keyword argument, but may ignore it.
"""
@abstractmethod
- def get(self, scope: StateScope, key: str) -> str | None:
+ def get(self, scope: StateScope, key: str, *, session: Session | None =
None) -> str | None:
"""
Return the stored value, or None if the key does not exist.
@@ -73,7 +84,7 @@ class BaseStateBackend(ABC):
"""
@abstractmethod
- def set(self, scope: StateScope, key: str, value: str) -> None:
+ def set(self, scope: StateScope, key: str, value: str, *, session: Session
| None = None) -> None:
"""
Write or overwrite the value for the given key.
@@ -81,7 +92,7 @@ class BaseStateBackend(ABC):
"""
@abstractmethod
- def delete(self, scope: StateScope, key: str) -> None:
+ def delete(self, scope: StateScope, key: str, *, session: Session | None =
None) -> None:
"""
Delete a single key. No-op if the key does not exist.
@@ -89,7 +100,9 @@ class BaseStateBackend(ABC):
"""
@abstractmethod
- def clear(self, scope: StateScope, *, all_map_indices: bool = False) ->
None:
+ def clear(
+ self, scope: StateScope, *, all_map_indices: bool = False, session:
Session | None = None
+ ) -> None:
"""
Delete all keys under the given scope.
@@ -102,23 +115,45 @@ class BaseStateBackend(ABC):
"""
@abstractmethod
- async def aget(self, scope: StateScope, key: str) -> str | None:
- """Async variant of get. Must handle both ``TaskScope`` and
``AssetScope``."""
+ async def aget(self, scope: StateScope, key: str, *, session: AsyncSession
| None = None) -> str | None:
+ """
+ Async variant of get. Must handle both ``TaskScope`` and
``AssetScope``.
+
+ ``session`` is optional. If provided, implementations should use it
directly.
+ If ``None``, implementations manage their own async session internally.
+ """
@abstractmethod
- async def aset(self, scope: StateScope, key: str, value: str) -> None:
- """Async variant of set. Must handle both ``TaskScope`` and
``AssetScope``."""
+ async def aset(
+ self, scope: StateScope, key: str, value: str, *, session:
AsyncSession | None = None
+ ) -> None:
+ """
+ Async variant of set. Must handle both ``TaskScope`` and
``AssetScope``.
+
+ ``session`` is optional. If provided, implementations should use it
directly.
+ If ``None``, implementations manage their own async session internally.
+ """
@abstractmethod
- async def adelete(self, scope: StateScope, key: str) -> None:
- """Async variant of delete. Must handle both ``TaskScope`` and
``AssetScope``."""
+ async def adelete(self, scope: StateScope, key: str, *, session:
AsyncSession | None = None) -> None:
+ """
+ Async variant of delete. Must handle both ``TaskScope`` and
``AssetScope``.
+
+ ``session`` is optional. If provided, implementations should use it
directly.
+ If ``None``, implementations manage their own async session internally.
+ """
@abstractmethod
- async def aclear(self, scope: StateScope, *, all_map_indices: bool =
False) -> None:
+ async def aclear(
+ self, scope: StateScope, *, all_map_indices: bool = False, session:
AsyncSession | None = None
+ ) -> None:
"""
Async variant of clear. Must handle both ``TaskScope`` and
``AssetScope``.
For ``TaskScope``: by default, only keys for the exact ``map_index``
on the
scope are cleared. Pass ``all_map_indices=True`` to wipe state across
every
mapped instance of the task. For ``AssetScope`` the flag has no effect.
+
+ ``session`` is optional. If provided, implementations should use it
directly.
+ If ``None``, implementations manage their own async session internally.
"""