This is an automated email from the ASF dual-hosted git repository.

amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0cf6462459e AIP-103: Adding periodic task state garbage collection and retention support (#66463)
0cf6462459e is described below

commit 0cf6462459e6f1dae463aedfa38b52fef31dbbb3
Author: Amogh Desai <[email protected]>
AuthorDate: Thu May 14 11:58:46 2026 +0530

    AIP-103: Adding periodic task state garbage collection and retention support (#66463)
---
 airflow-core/src/airflow/cli/cli_config.py         |  19 +++
 .../airflow/cli/commands/state_store_command.py    |  49 ++++++++
 .../src/airflow/config_templates/config.yml        |  18 +++
 .../src/airflow/jobs/scheduler_job_runner.py       |  32 ++++-
 ..._3_3_0_add_task_state_and_asset_state_tables.py |   7 +-
 airflow-core/src/airflow/models/task_state.py      |  21 +++-
 airflow-core/src/airflow/state/metastore.py        |  73 ++++++++++-
 .../unit/cli/commands/test_state_store_command.py  |  65 ++++++++++
 airflow-core/tests/unit/state/test_metastore.py    | 140 ++++++++++++++++++++-
 shared/state/src/airflow_shared/state/__init__.py  |   9 ++
 10 files changed, 421 insertions(+), 12 deletions(-)

diff --git a/airflow-core/src/airflow/cli/cli_config.py 
b/airflow-core/src/airflow/cli/cli_config.py
index 4c44ab39d67..81b9dcf0600 100644
--- a/airflow-core/src/airflow/cli/cli_config.py
+++ b/airflow-core/src/airflow/cli/cli_config.py
@@ -1531,6 +1531,20 @@ TEAMS_COMMANDS = (
         args=(ARG_VERBOSE,),
     ),
 )
+STATE_STORE_COMMANDS = (
+    ActionCommand(
+        name="cleanup-task-states",
+        help="Remove expired task state rows (MetastoreStateBackend only)",
+        description=(
+            "Reads [state_store] default_retention_days from config and 
deletes task_state rows "
+            "older than the configured threshold. Only applies when 
MetastoreStateBackend is configured; "
+            "custom backends are skipped. Use --dry-run to preview without 
deleting."
+        ),
+        
func=lazy_load_command("airflow.cli.commands.state_store_command.cleanup_task_states"),
+        args=(ARG_DB_DRY_RUN, ARG_VERBOSE),
+    ),
+)
+
 DB_COMMANDS = (
     ActionCommand(
         name="check-migrations",
@@ -2115,6 +2129,11 @@ core_commands: list[CLICommand] = [
         help="Display providers",
         subcommands=PROVIDERS_COMMANDS,
     ),
+    GroupCommand(
+        name="state-store",
+        help="Manage task and asset state storage",
+        subcommands=STATE_STORE_COMMANDS,
+    ),
     ActionCommand(
         name="rotate-fernet-key",
         
func=lazy_load_command("airflow.cli.commands.rotate_fernet_key_command.rotate_fernet_key"),
diff --git a/airflow-core/src/airflow/cli/commands/state_store_command.py 
b/airflow-core/src/airflow/cli/commands/state_store_command.py
new file mode 100644
index 00000000000..52bd0952561
--- /dev/null
+++ b/airflow-core/src/airflow/cli/commands/state_store_command.py
@@ -0,0 +1,49 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import logging
+
+from airflow.state import get_state_backend
+from airflow.state.metastore import MetastoreStateBackend
+
+log = logging.getLogger(__name__)
+
+# Other state operations (list, get, delete per key) will be added here in the 
future.
+
+
+def cleanup_task_states(args) -> None:
+    """Remove expired task state rows (MetastoreStateBackend only)."""
+    backend = get_state_backend()
+
+    if not isinstance(backend, MetastoreStateBackend):
+        print("Custom backend configured — skipping cleanup (not supported).")
+        return
+
+    if args.dry_run:
+        summary = backend._summary_dry_run()
+        expired = summary["expired"]
+        if not expired:
+            print("Nothing to delete.")
+            return
+        print(f"Would delete {len(expired)} task state row(s):\n")
+        for dag_id, run_id, task_id, map_index, key in expired:
+            print(f"  Dag {dag_id!r}, run {run_id!r}, task {task_id!r}, 
map_index {map_index!r}, key {key!r}")
+        return
+
+    log.info("Running task state cleanup")
+    backend.cleanup()
diff --git a/airflow-core/src/airflow/config_templates/config.yml 
b/airflow-core/src/airflow/config_templates/config.yml
index 8d5d6e5fd26..4b183f9c2b4 100644
--- a/airflow-core/src/airflow/config_templates/config.yml
+++ b/airflow-core/src/airflow/config_templates/config.yml
@@ -3025,6 +3025,24 @@ state_store:
       type: string
       example: "mypackage.state.CustomStateBackend"
       default: "airflow.state.metastore.MetastoreStateBackend"
+    default_retention_days:
+      description: |
+        Number of days a task state row is kept after its last write.
+        The expiry time is stored on each row when it is written; cleanup removes rows past it.
+        Changing this value only affects rows written afterwards. It does not affect asset_state rows.
+        Set to 0 to write new rows without an expiry, so cleanup never removes them.
+      version_added: 3.3.0
+      type: integer
+      example: "7"
+      default: "30"
+    state_cleanup_batch_size:
+      description: |
+        Maximum number of rows deleted per transaction during cleanup. 0 (the default) disables batching.
+        Set a positive value on deployments with large task_state tables to keep each transaction small.
+      version_added: 3.3.0
+      type: integer
+      example: "10000"
+      default: "0"
 
 profiling:
   description: |
diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py 
b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index 1a3f55b7f6f..9a650b110c9 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -33,7 +33,20 @@ from functools import lru_cache, partial
 from itertools import groupby
 from typing import TYPE_CHECKING, Any, cast
 
-from sqlalchemy import CTE, and_, case, delete, exists, func, inspect, or_, 
select, text, tuple_, update
+from sqlalchemy import (
+    CTE,
+    and_,
+    case,
+    delete,
+    exists,
+    func,
+    inspect,
+    or_,
+    select,
+    text,
+    tuple_,
+    update,
+)
 from sqlalchemy.exc import DBAPIError, OperationalError
 from sqlalchemy.orm import joinedload, lazyload, load_only, make_transient, 
selectinload
 from sqlalchemy.sql import expression
@@ -70,6 +83,7 @@ from airflow.models.asset import (
     TaskInletAssetReference,
     TaskOutletAssetReference,
 )
+from airflow.models.asset_state import AssetStateModel
 from airflow.models.backfill import Backfill, BackfillDagRun
 from airflow.models.callback import Callback, CallbackType, ExecutorCallback
 from airflow.models.dag import DagModel
@@ -3096,6 +3110,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
 
         self._orphan_unreferenced_assets(orphan_query, session=session)
         self._activate_referenced_assets(activate_query, session=session)
+        self._cleanup_orphaned_asset_state(session=session)
 
     @staticmethod
     def _orphan_unreferenced_assets(assets_query: CTE, *, session: Session) -> 
None:
@@ -3204,6 +3219,21 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
             session.add(warning)
             existing_warned_dag_ids.add(warning.dag_id)
 
+    @staticmethod
+    def _cleanup_orphaned_asset_state(*, session: Session) -> None:
+        """
+        Delete asset_state rows for assets no longer active in any Dag.
+
+        When _orphan_unreferenced_assets removes an asset from asset_active, 
its
+        asset_state rows become unreachable — no task can write to them 
anymore.
+        This runs in the same pass as asset orphanage to keep the table clean.
+        """
+        active_asset_ids = select(AssetModel.id).join(
+            AssetActive,
+            (AssetActive.name == AssetModel.name) & (AssetActive.uri == 
AssetModel.uri),
+        )
+        
session.execute(delete(AssetStateModel).where(AssetStateModel.asset_id.not_in(active_asset_ids)))
+
     def _executor_to_workloads(
         self,
         workloads: Iterable[SchedulerWorkload],
diff --git 
a/airflow-core/src/airflow/migrations/versions/0112_3_3_0_add_task_state_and_asset_state_tables.py
 
b/airflow-core/src/airflow/migrations/versions/0112_3_3_0_add_task_state_and_asset_state_tables.py
index 7f852d05c6c..e64f80a05b1 100644
--- 
a/airflow-core/src/airflow/migrations/versions/0112_3_3_0_add_task_state_and_asset_state_tables.py
+++ 
b/airflow-core/src/airflow/migrations/versions/0112_3_3_0_add_task_state_and_asset_state_tables.py
@@ -57,6 +57,7 @@ def upgrade():
     )
     op.create_table(
         "task_state",
+        sa.Column("id", sa.Integer(), nullable=False, autoincrement=True),
         sa.Column("dag_run_id", sa.Integer(), nullable=False),
         sa.Column("task_id", StringID(), nullable=False),
         sa.Column("map_index", sa.Integer(), server_default="-1", 
nullable=False),
@@ -65,20 +66,24 @@ def upgrade():
         sa.Column("run_id", StringID(), nullable=False),
         sa.Column("value", sa.Text().with_variant(mysql.MEDIUMTEXT(), 
"mysql"), nullable=False),
         sa.Column("updated_at", UtcDateTime(), nullable=False),
+        sa.Column("expires_at", UtcDateTime(), nullable=True),
         sa.ForeignKeyConstraint(
             ["dag_run_id"], ["dag_run.id"], name="task_state_dag_run_fkey", 
ondelete="CASCADE"
         ),
-        sa.PrimaryKeyConstraint("dag_run_id", "task_id", "map_index", "key", 
name="task_state_pkey"),
+        sa.PrimaryKeyConstraint("id", name="task_state_pkey"),
+        sa.UniqueConstraint("dag_run_id", "task_id", "map_index", "key", 
name="task_state_uq"),
     )
     with op.batch_alter_table("task_state", schema=None) as batch_op:
         batch_op.create_index(
             "idx_task_state_lookup", ["dag_id", "run_id", "task_id", 
"map_index"], unique=False
         )
+        batch_op.create_index("idx_task_state_expires_at", ["expires_at"], 
unique=False)
 
 
 def downgrade():
     """Unapply add task_state and asset_state tables."""
     with op.batch_alter_table("task_state", schema=None) as batch_op:
+        batch_op.drop_index("idx_task_state_expires_at")
         batch_op.drop_index("idx_task_state_lookup")
 
     op.drop_table("task_state")
diff --git a/airflow-core/src/airflow/models/task_state.py 
b/airflow-core/src/airflow/models/task_state.py
index dbc17e3b069..72a7624eddd 100644
--- a/airflow-core/src/airflow/models/task_state.py
+++ b/airflow-core/src/airflow/models/task_state.py
@@ -19,7 +19,7 @@ from __future__ import annotations
 
 from datetime import datetime
 
-from sqlalchemy import ForeignKeyConstraint, Index, Integer, 
PrimaryKeyConstraint, String, Text
+from sqlalchemy import ForeignKeyConstraint, Index, Integer, String, Text, 
UniqueConstraint
 from sqlalchemy.dialects.mysql import MEDIUMTEXT
 from sqlalchemy.orm import Mapped, mapped_column
 
@@ -39,19 +39,27 @@ class TaskStateModel(Base):
 
     __tablename__ = "task_state"
 
-    dag_run_id: Mapped[int] = mapped_column(Integer, nullable=False, 
primary_key=True)
-    task_id: Mapped[str] = mapped_column(StringID(), nullable=False, 
primary_key=True)
-    map_index: Mapped[int] = mapped_column(Integer, primary_key=True, 
nullable=False, server_default="-1")
-    key: Mapped[str] = mapped_column(String(512, **COLLATION_ARGS), 
nullable=False, primary_key=True)
+    id: Mapped[int] = mapped_column(Integer, primary_key=True, 
autoincrement=True)
+
+    dag_run_id: Mapped[int] = mapped_column(Integer, nullable=False)
+    task_id: Mapped[str] = mapped_column(StringID(), nullable=False)
+    map_index: Mapped[int] = mapped_column(Integer, nullable=False, 
server_default="-1")
+    key: Mapped[str] = mapped_column(String(512, **COLLATION_ARGS), 
nullable=False)
 
     dag_id: Mapped[str] = mapped_column(StringID(), nullable=False)
     run_id: Mapped[str] = mapped_column(StringID(), nullable=False)
 
     value: Mapped[str] = mapped_column(Text().with_variant(MEDIUMTEXT, 
"mysql"), nullable=False)
     updated_at: Mapped[datetime] = mapped_column(UtcDateTime, 
default=timezone.utcnow, nullable=False)
+    # Expiry timestamp, computed on every write as updated_at + [state_store] default_retention_days.
+    # cleanup() deletes rows whose expires_at < now(). NULL (written while
+    # default_retention_days=0) means the row never expires and cleanup leaves it in place.
+    # Because the value is fixed at write time, changing the config only affects rows
+    # written afterwards.
+    expires_at: Mapped[datetime | None] = mapped_column(UtcDateTime, 
nullable=True)
 
     __table_args__ = (
-        PrimaryKeyConstraint("dag_run_id", "task_id", "map_index", "key", 
name="task_state_pkey"),
+        UniqueConstraint("dag_run_id", "task_id", "map_index", "key", 
name="task_state_uq"),
         ForeignKeyConstraint(
             ["dag_run_id"],
             ["dag_run.id"],
@@ -59,4 +67,5 @@ class TaskStateModel(Base):
             ondelete="CASCADE",
         ),
         Index("idx_task_state_lookup", "dag_id", "run_id", "task_id", 
"map_index"),
+        Index("idx_task_state_expires_at", "expires_at"),
     )
diff --git a/airflow-core/src/airflow/state/metastore.py 
b/airflow-core/src/airflow/state/metastore.py
index 31b4de3158f..f58c69f5808 100644
--- a/airflow-core/src/airflow/state/metastore.py
+++ b/airflow-core/src/airflow/state/metastore.py
@@ -19,17 +19,20 @@ from __future__ import annotations
 
 from collections.abc import AsyncGenerator
 from contextlib import asynccontextmanager
+from datetime import datetime, timedelta
 from typing import TYPE_CHECKING
 
+import structlog
 from sqlalchemy import delete, select
 
 from airflow._shared.state import AssetScope, BaseStateBackend, StateScope, 
TaskScope
 from airflow._shared.timezones import timezone
+from airflow.configuration import conf
 from airflow.models.asset_state import AssetStateModel
 from airflow.models.dagrun import DagRun
 from airflow.models.task_state import TaskStateModel
 from airflow.typing_compat import assert_never
-from airflow.utils.session import NEW_SESSION, create_session_async, 
provide_session
+from airflow.utils.session import NEW_SESSION, create_session, 
create_session_async, provide_session
 from airflow.utils.sqlalchemy import get_dialect_name
 
 if TYPE_CHECKING:
@@ -40,6 +43,21 @@ if TYPE_CHECKING:
     from sqlalchemy.orm import Session
 
 
+log = structlog.get_logger(__name__)
+
+
+def _compute_expires_at(now: datetime) -> datetime | None:
+    """
+    Return the expiry timestamp for a new task state row based on config.
+
+    Returns None if default_retention_days is 0 (never expires).
+    """
+    retention_days = conf.getint("state_store", "default_retention_days")
+    if retention_days <= 0:
+        return None
+    return now + timedelta(days=retention_days)
+
+
 @asynccontextmanager
 async def _async_session(session: AsyncSession | None) -> 
AsyncGenerator[AsyncSession, None]:
     """Use provided async session or create a new one."""
@@ -200,6 +218,7 @@ class MetastoreStateBackend(BaseStateBackend):
         if dag_run_id is None:
             raise ValueError(f"No DagRun found for dag_id={scope.dag_id!r} 
run_id={scope.run_id!r}")
         now = timezone.utcnow()
+        expires_at = _compute_expires_at(now)
         values = dict(
             dag_run_id=dag_run_id,
             dag_id=scope.dag_id,
@@ -209,13 +228,14 @@ class MetastoreStateBackend(BaseStateBackend):
             key=key,
             value=value,
             updated_at=now,
+            expires_at=expires_at,
         )
         stmt = _build_upsert_stmt(
             get_dialect_name(session),
             TaskStateModel,
             ["dag_run_id", "task_id", "map_index", "key"],
             values,
-            dict(value=value, updated_at=now),
+            dict(value=value, updated_at=now, expires_at=expires_at),
         )
         session.execute(stmt)
 
@@ -276,6 +296,51 @@ class MetastoreStateBackend(BaseStateBackend):
             )
         )
 
+    def cleanup(self) -> None:
+        """
+        Remove expired task state rows.
+
+        ``expires_at`` is set at write time on every ``set()`` call, so cleanup only needs the
+        ``expires_at < now()`` filter. Rows with ``expires_at=NULL`` (default_retention_days=0)
+        are never deleted. ``[state_store] state_cleanup_batch_size`` caps rows per batch (0 = one batch).
+        """
+        batch_size = conf.getint("state_store", "state_cleanup_batch_size")
+        now = timezone.utcnow()
+
+        def _delete_batched(where_clause) -> int:
+            total = 0
+            with create_session() as session:
+                while True:
+                    id_query = select(TaskStateModel.id).where(where_clause)
+                    if batch_size > 0:
+                        id_query = id_query.limit(batch_size)
+                    ids = session.scalars(id_query).all()
+                    if not ids:
+                        break
+                    
session.execute(delete(TaskStateModel).where(TaskStateModel.id.in_(ids)))
+                    session.commit()
+                    total += len(ids)
+                    if batch_size <= 0 or len(ids) < batch_size:
+                        break
+            return total
+
+        deleted = _delete_batched(TaskStateModel.expires_at < now)
+        log.info("Deleted expired task_state rows", rows_deleted=deleted)
+
+    def _summary_dry_run(self) -> dict[str, list]:
+        """Return rows that would be deleted by cleanup() without deleting 
anything."""
+        now = timezone.utcnow()
+        cols = (
+            TaskStateModel.dag_id,
+            TaskStateModel.run_id,
+            TaskStateModel.task_id,
+            TaskStateModel.map_index,
+            TaskStateModel.key,
+        )
+        with create_session() as session:
+            expired = 
session.execute(select(*cols).where(TaskStateModel.expires_at < now)).all()
+        return {"expired": list(expired)}
+
     async def _aget_task_state(self, scope: TaskScope, key: str, *, session: 
AsyncSession) -> str | None:
         row = await session.scalar(
             select(TaskStateModel).where(
@@ -300,6 +365,7 @@ class MetastoreStateBackend(BaseStateBackend):
         if dag_run_id is None:
             raise ValueError(f"No DagRun found for dag_id={scope.dag_id!r} 
run_id={scope.run_id!r}")
         now = timezone.utcnow()
+        expires_at = _compute_expires_at(now)
         values = dict(
             dag_run_id=dag_run_id,
             dag_id=scope.dag_id,
@@ -309,6 +375,7 @@ class MetastoreStateBackend(BaseStateBackend):
             key=key,
             value=value,
             updated_at=now,
+            expires_at=expires_at,
         )
         # get_dialect_name expects a sync Session; sync_session is the 
underlying Session the async wrapper delegates to
         stmt = _build_upsert_stmt(
@@ -316,7 +383,7 @@ class MetastoreStateBackend(BaseStateBackend):
             TaskStateModel,
             ["dag_run_id", "task_id", "map_index", "key"],
             values,
-            dict(value=value, updated_at=now),
+            dict(value=value, updated_at=now, expires_at=expires_at),
         )
         await session.execute(stmt)
 
diff --git a/airflow-core/tests/unit/cli/commands/test_state_store_command.py 
b/airflow-core/tests/unit/cli/commands/test_state_store_command.py
new file mode 100644
index 00000000000..e4b44eee13f
--- /dev/null
+++ b/airflow-core/tests/unit/cli/commands/test_state_store_command.py
@@ -0,0 +1,65 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from argparse import Namespace
+from unittest import mock
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.cli.commands.state_store_command import cleanup_task_states
+from airflow.state.metastore import MetastoreStateBackend
+
+pytestmark = pytest.mark.db_test
+
+
+class TestStateStoreCleanupCommand:
+    def test_cleanup_calls_backend(self):
+        args = Namespace(dry_run=False, verbose=False)
+        backend = MetastoreStateBackend()
+        with (
+            
mock.patch("airflow.cli.commands.state_store_command.get_state_backend", 
return_value=backend),
+            patch.object(backend, "cleanup"),
+        ):
+            cleanup_task_states(args)
+
+            backend.cleanup.assert_called_once_with()
+
+    def test_dry_run_does_not_call_backend(self, capsys):
+        args = Namespace(dry_run=True, verbose=False)
+        backend = MetastoreStateBackend()
+        with (
+            
mock.patch("airflow.cli.commands.state_store_command.get_state_backend", 
return_value=backend),
+            patch.object(backend, "_summary_dry_run", return_value={"expired": 
[]}),
+        ):
+            cleanup_task_states(args)
+
+            captured = capsys.readouterr()
+            assert "Nothing to delete" in captured.out
+
+    def test_custom_backend_is_skipped(self, capsys):
+        args = Namespace(dry_run=False, verbose=False)
+        custom_backend = MagicMock(spec=[])
+        with mock.patch(
+            "airflow.cli.commands.state_store_command.get_state_backend", 
return_value=custom_backend
+        ):
+            cleanup_task_states(args)
+
+            captured = capsys.readouterr()
+            assert "Custom backend configured" in captured.out
+            assert not hasattr(custom_backend, "cleanup") or not 
custom_backend.cleanup.called
diff --git a/airflow-core/tests/unit/state/test_metastore.py 
b/airflow-core/tests/unit/state/test_metastore.py
index dfd154cc92a..d9e1ff33afd 100644
--- a/airflow-core/tests/unit/state/test_metastore.py
+++ b/airflow-core/tests/unit/state/test_metastore.py
@@ -17,13 +17,18 @@
 # under the License.
 from __future__ import annotations
 
+from contextlib import contextmanager
+from datetime import timedelta
 from typing import TYPE_CHECKING
+from unittest.mock import patch
 
 import pytest
-from sqlalchemy import select
+from sqlalchemy import Delete, select
 
 from airflow._shared.timezones import timezone
+from airflow.configuration import conf
 from airflow.models.asset import AssetModel
+from airflow.models.asset_state import AssetStateModel
 from airflow.models.dagrun import DagRun, DagRunType
 from airflow.models.task_state import TaskStateModel
 from airflow.state import AssetScope, TaskScope, resolve_state_backend
@@ -234,6 +239,113 @@ class TestMetastoreStateBackendTaskScope:
         assert backend.get(scope0, "job_id", session=session) is None
         assert backend.get(scope1, "job_id", session=session) is None
 
+    def test_set_populates_expires_at(
+        self, session: Session, backend: MetastoreStateBackend, dag_run: DagRun
+    ):
+        """set() always populates expires_at so cleanup has a single pass."""
+        scope = TaskScope(dag_id=DAG_ID, run_id=RUN_ID, task_id=TASK_ID)
+        backend.set(scope, "job_id", "app_1234", session=session)
+        session.flush()
+
+        row = session.scalar(select(TaskStateModel).where(TaskStateModel.key 
== "job_id"))
+        assert row is not None
+        assert row.expires_at is not None
+        assert row.expires_at > row.updated_at
+
+    def test_cleanup_removes_expired_rows(
+        self, session: Session, backend: MetastoreStateBackend, dag_run: DagRun
+    ):
+        scope = TaskScope(dag_id=DAG_ID, run_id=RUN_ID, task_id=TASK_ID)
+        backend.set(scope, "old_key", "old_value", session=session)
+        backend.set(scope, "new_key", "new_value", session=session)
+        session.flush()
+
+        # Backdate expires_at on old_key to simulate it having expired
+        old_row = session.scalar(
+            select(TaskStateModel).where(TaskStateModel.dag_id == DAG_ID, 
TaskStateModel.key == "old_key")
+        )
+        assert old_row is not None
+        old_row.expires_at = timezone.utcnow() - timedelta(hours=1)
+        session.flush()
+        session.commit()
+
+        backend.cleanup()
+
+        session.expire_all()
+        assert session.scalar(select(TaskStateModel).where(TaskStateModel.key 
== "old_key")) is None
+        assert session.scalar(select(TaskStateModel).where(TaskStateModel.key 
== "new_key")) is not None
+
+    def test_cleanup_removes_expires_at_rows(
+        self, session: Session, backend: MetastoreStateBackend, dag_run: DagRun
+    ):
+        scope = TaskScope(dag_id=DAG_ID, run_id=RUN_ID, task_id=TASK_ID)
+        backend.set(scope, "short_lived", "value", session=session)
+        session.flush()
+
+        row = session.scalar(
+            select(TaskStateModel).where(TaskStateModel.dag_id == DAG_ID, 
TaskStateModel.key == "short_lived")
+        )
+        assert row is not None
+        row.expires_at = timezone.utcnow() - timedelta(hours=1)
+        session.flush()
+        session.commit()
+
+        backend.cleanup()
+
+        session.expire_all()
+
+        # cleaned up via expires_at, even though updated_at is recent
+        assert session.scalar(select(TaskStateModel).where(TaskStateModel.key 
== "short_lived")) is None
+
+    @conf_vars({("state_store", "state_cleanup_batch_size"): "2"})
+    def test_cleanup_batches_deletes(self, session: Session, backend: 
MetastoreStateBackend, dag_run: DagRun):
+        """cleanup() issues one DELETE per batch, not one DELETE for all rows 
at once.
+
+        cleanup() opens its own internal session, so its statements cannot be inspected from the
+        test's session. Instead:
+
+        1. Patch `create_session` in the metastore module with a thin wrapper (`tracking_cs`) that
+           yields the real session but replaces `session.execute` with a spy.
+        2. The spy records every executed statement that is a SQLAlchemy `Delete` object.
+        3. After cleanup() returns, assert that exactly ceil(<number of rows> / <batch size>)
+           DELETE statements were recorded.
+        """
+        import airflow.state.metastore as metastore_mod
+
+        scope = TaskScope(dag_id=DAG_ID, run_id=RUN_ID, task_id=TASK_ID)
+        for key in ("k1", "k2", "k3", "k4", "k5"):
+            backend.set(scope, key, "v", session=session)
+            session.flush()
+
+        session.execute(
+            
TaskStateModel.__table__.update().values(expires_at=timezone.utcnow() - 
timedelta(hours=1))
+        )
+        session.commit()
+
+        deletes = []
+        original_cs = metastore_mod.create_session
+
+        @contextmanager
+        def tracking_cs(*args, **kwargs):
+            with original_cs(*args, **kwargs) as s:
+                orig_execute = s.execute
+
+                def tracked(stmt, *a, **kw):
+                    if isinstance(stmt, Delete):
+                        deletes.append(stmt)
+                    return orig_execute(stmt, *a, **kw)
+
+                s.execute = tracked
+                yield s
+
+        with patch.object(metastore_mod, "create_session", 
side_effect=tracking_cs):
+            backend.cleanup()
+
+        session.expire_all()
+
+        # batch_size=2, 5 rows -> delete runs 3 times (2+2+1)
+        assert len(deletes) == 3
+
 
 class TestMetastoreStateBackendAssetScope:
     def test_get_returns_none_for_missing_key(
@@ -306,6 +418,19 @@ class TestMetastoreStateBackendAssetScope:
 
         assert backend.get(scope2, "watermark", session=session) is None
 
+    def test_cleanup_does_not_touch_asset_state(
+        self, session: Session, backend: MetastoreStateBackend, asset: 
AssetModel
+    ):
+        scope = AssetScope(asset_id=asset.id)
+        backend.set(scope, "watermark", "2026-01-01", session=session)
+        session.flush()
+        session.commit()
+
+        backend.cleanup()
+
+        session.expire_all()
+        assert 
session.scalar(select(AssetStateModel).where(AssetStateModel.asset_id == 
asset.id)) is not None
+
 
 @pytest.mark.asyncio(loop_scope="class")
 class TestMetastoreStateBackendAsync:
@@ -390,6 +515,19 @@ class TestMetastoreStateBackendAsync:
         assert result == "app_with_session"
 
 
+class TestStateStoreConfig:
+    def test_defaults(self):
+        assert conf.getint("state_store", "default_retention_days") == 30
+        assert conf.getint("state_store", "state_cleanup_batch_size") == 0
+
+    @conf_vars(
+        {("state_store", "default_retention_days"): "7", ("state_store", 
"state_cleanup_batch_size"): "50"}
+    )
+    def test_overrides(self):
+        assert conf.getint("state_store", "default_retention_days") == 7
+        assert conf.getint("state_store", "state_cleanup_batch_size") == 50
+
+
 class TestResolveStateBackend:
     @conf_vars({("state_store", "backend"): 
"airflow.state.metastore.MetastoreStateBackend"})
     def test_resolve_returns_configured_backend(self):
diff --git a/shared/state/src/airflow_shared/state/__init__.py 
b/shared/state/src/airflow_shared/state/__init__.py
index 4920f66ae67..e231bdfd3bd 100644
--- a/shared/state/src/airflow_shared/state/__init__.py
+++ b/shared/state/src/airflow_shared/state/__init__.py
@@ -157,3 +157,12 @@ class BaseStateBackend(ABC):
         ``session`` is optional. If provided, implementations should use it 
directly.
         If ``None``, implementations manage their own async session internally.
         """
+
+    def cleanup(self) -> None:
+        """
+        Remove expired and orphaned state records.
+
+        This is a no-op by default. Custom backends override this to implement 
their own
+        retention policy. The backend is responsible for reading any relevant 
config (e.g.
+        ``[state_store] default_retention_days``) and deciding what to delete.
+        """

Reply via email to