This is an automated email from the ASF dual-hosted git repository.
amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 008cbe90e2a AIP-103: Adding ability for per task state key retention
from operators (#66699)
008cbe90e2a is described below
commit 008cbe90e2a8c3fc7e315ee61de0223854e88ec0
Author: Amogh Desai <[email protected]>
AuthorDate: Tue May 19 11:59:45 2026 +0530
AIP-103: Adding ability for per task state key retention from operators
(#66699)
---
.../execution_api/datamodels/task_state.py | 3 ++
.../api_fastapi/execution_api/routes/task_state.py | 2 +-
airflow-core/src/airflow/state/metastore.py | 56 ++++++++++++--------
.../execution_api/versions/head/test_task_state.py | 30 +++++++++++
airflow-core/tests/unit/state/test_metastore.py | 19 +++++--
shared/state/src/airflow_shared/state/__init__.py | 24 ++++++++-
task-sdk/docs/api.rst | 4 ++
task-sdk/src/airflow/sdk/__init__.py | 3 ++
task-sdk/src/airflow/sdk/api/client.py | 5 +-
.../src/airflow/sdk/api/datamodels/_generated.py | 1 +
task-sdk/src/airflow/sdk/execution_time/comms.py | 1 +
task-sdk/src/airflow/sdk/execution_time/context.py | 31 ++++++++++--
.../src/airflow/sdk/execution_time/supervisor.py | 2 +-
task-sdk/tests/task_sdk/api/test_client.py | 39 +++++++++++++-
.../tests/task_sdk/execution_time/test_context.py | 59 ++++++++++++++++++++--
.../task_sdk/execution_time/test_supervisor.py | 24 ++++++++-
.../task_sdk/execution_time/test_task_runner.py | 35 +++++++++++--
17 files changed, 295 insertions(+), 43 deletions(-)
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py
index 3200f3177af..20980b315c3 100644
---
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py
+++
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py
@@ -17,6 +17,8 @@
from __future__ import annotations
+from datetime import datetime
+
from airflow.api_fastapi.core_api.base import StrictBaseModel
@@ -30,3 +32,4 @@ class TaskStatePutBody(StrictBaseModel):
"""Request body for setting a task state value."""
value: str
+ expires_at: datetime | None = None
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py
index acdaa8c6a24..2f824e3ebb2 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py
@@ -86,7 +86,7 @@ def set_task_state(
) -> None:
"""Set a task state key, creating or updating the row."""
scope = _get_task_scope_for_ti(task_instance_id, session)
- get_state_backend().set(scope, key, body.value, session=session)
+ get_state_backend().set(scope, key, body.value,
expires_at=body.expires_at, session=session)
@router.delete("/{task_instance_id}/{key}",
status_code=status.HTTP_204_NO_CONTENT)
diff --git a/airflow-core/src/airflow/state/metastore.py
b/airflow-core/src/airflow/state/metastore.py
index f58c69f5808..e5e0a82be3e 100644
--- a/airflow-core/src/airflow/state/metastore.py
+++ b/airflow-core/src/airflow/state/metastore.py
@@ -19,7 +19,7 @@ from __future__ import annotations
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
-from datetime import datetime, timedelta
+from datetime import datetime
from typing import TYPE_CHECKING
import structlog
@@ -46,18 +46,6 @@ if TYPE_CHECKING:
log = structlog.get_logger(__name__)
-def _compute_expires_at(now: datetime) -> datetime | None:
- """
- Return the expiry timestamp for a new task state row based on config.
-
- Returns None if default_retention_days is 0 (never expires).
- """
- retention_days = conf.getint("state_store", "default_retention_days")
- if retention_days <= 0:
- return None
- return now + timedelta(days=retention_days)
-
-
@asynccontextmanager
async def _async_session(session: AsyncSession | None) ->
AsyncGenerator[AsyncSession, None]:
"""Use provided async session or create a new one."""
@@ -111,12 +99,20 @@ class MetastoreStateBackend(BaseStateBackend):
assert_never(scope)
@provide_session
- def set(self, scope: StateScope, key: str, value: str, *, session: Session
| None = NEW_SESSION) -> None:
+ def set(
+ self,
+ scope: StateScope,
+ key: str,
+ value: str,
+ *,
+ expires_at: datetime | None = None,
+ session: Session | None = NEW_SESSION,
+ ) -> None:
if TYPE_CHECKING:
assert session is not None
match scope:
case TaskScope():
- self._set_task_state(scope, key, value, session=session)
+ self._set_task_state(scope, key, value, expires_at=expires_at,
session=session)
case AssetScope():
self._set_asset_state(scope, key, value, session=session)
case _:
@@ -163,12 +159,18 @@ class MetastoreStateBackend(BaseStateBackend):
assert_never(scope)
async def aset(
- self, scope: StateScope, key: str, value: str, *, session:
AsyncSession | None = None
+ self,
+ scope: StateScope,
+ key: str,
+ value: str,
+ *,
+ expires_at: datetime | None = None,
+ session: AsyncSession | None = None,
) -> None:
async with _async_session(session) as s:
match scope:
case TaskScope():
- await self._aset_task_state(scope, key, value, session=s)
+ await self._aset_task_state(scope, key, value,
expires_at=expires_at, session=s)
case AssetScope():
await self._aset_asset_state(scope, key, value, session=s)
case _:
@@ -208,7 +210,15 @@ class MetastoreStateBackend(BaseStateBackend):
)
return row.value if row is not None else None
- def _set_task_state(self, scope: TaskScope, key: str, value: str, *,
session: Session) -> None:
+ def _set_task_state(
+ self,
+ scope: TaskScope,
+ key: str,
+ value: str,
+ *,
+ expires_at: datetime | None = None,
+ session: Session,
+ ) -> None:
dag_run_id = session.scalar(
select(DagRun.id).where(
DagRun.dag_id == scope.dag_id,
@@ -218,7 +228,6 @@ class MetastoreStateBackend(BaseStateBackend):
if dag_run_id is None:
raise ValueError(f"No DagRun found for dag_id={scope.dag_id!r}
run_id={scope.run_id!r}")
now = timezone.utcnow()
- expires_at = _compute_expires_at(now)
values = dict(
dag_run_id=dag_run_id,
dag_id=scope.dag_id,
@@ -354,7 +363,13 @@ class MetastoreStateBackend(BaseStateBackend):
return row.value if row is not None else None
async def _aset_task_state(
- self, scope: TaskScope, key: str, value: str, *, session: AsyncSession
+ self,
+ scope: TaskScope,
+ key: str,
+ value: str,
+ *,
+ expires_at: datetime | None = None,
+ session: AsyncSession,
) -> None:
dag_run_id = await session.scalar(
select(DagRun.id).where(
@@ -365,7 +380,6 @@ class MetastoreStateBackend(BaseStateBackend):
if dag_run_id is None:
raise ValueError(f"No DagRun found for dag_id={scope.dag_id!r}
run_id={scope.run_id!r}")
now = timezone.utcnow()
- expires_at = _compute_expires_at(now)
values = dict(
dag_run_id=dag_run_id,
dag_id=scope.dag_id,
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py
index 8a66a0a23c7..d83751050e7 100644
---
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py
+++
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py
@@ -16,9 +16,11 @@
# under the License.
from __future__ import annotations
+from datetime import datetime
from typing import TYPE_CHECKING
from uuid import uuid4
+import pendulum
import pytest
from fastapi import Request
from fastapi.testclient import TestClient
@@ -95,6 +97,34 @@ class TestPutTaskState:
assert row is not None
assert row.value == "spark_001"
+ def test_put_with_expires_at_creates_row(
+ self, client: TestClient, create_task_instance: CreateTaskInstance,
time_machine
+ ):
+
+ ti = create_task_instance()
+ time_machine.move_to(datetime(2026, 5, 5, 12, 0, 0), tick=False)
+ response = client.put(
+ _api_url(ti.id, "job_id"),
+ json={
+ "value": "spark_001",
+ "expires_at": datetime(2026, 5, 15, 12, 0, 0,
tzinfo=pendulum.UTC).isoformat(),
+ },
+ )
+
+ assert response.status_code == 204
+ with create_session() as session:
+ row = session.scalar(
+ select(TaskStateModel).where(
+ TaskStateModel.dag_id == ti.dag_id,
+ TaskStateModel.run_id == ti.run_id,
+ TaskStateModel.task_id == ti.task_id,
+ TaskStateModel.key == "job_id",
+ )
+ )
+ assert row is not None
+ assert row.value == "spark_001"
+ assert row.expires_at == datetime(2026, 5, 15, 12, 0, 0,
tzinfo=pendulum.UTC)
+
def test_put_overwrites_existing(self, client: TestClient,
create_task_instance: CreateTaskInstance):
ti = create_task_instance()
client.put(_api_url(ti.id, "job_id"), json={"value": "spark_001"})
diff --git a/airflow-core/tests/unit/state/test_metastore.py
b/airflow-core/tests/unit/state/test_metastore.py
index d9e1ff33afd..fbd37ddc30e 100644
--- a/airflow-core/tests/unit/state/test_metastore.py
+++ b/airflow-core/tests/unit/state/test_metastore.py
@@ -239,18 +239,29 @@ class TestMetastoreStateBackendTaskScope:
assert backend.get(scope0, "job_id", session=session) is None
assert backend.get(scope1, "job_id", session=session) is None
- def test_set_populates_expires_at(
+ def test_set_without_expires_at_stores_null(
self, session: Session, backend: MetastoreStateBackend, dag_run: DagRun
):
- """set() always populates expires_at so cleanup has a single pass."""
+ """set() without expires_at stores NULL — the worker is responsible
for computing expiry."""
scope = TaskScope(dag_id=DAG_ID, run_id=RUN_ID, task_id=TASK_ID)
backend.set(scope, "job_id", "app_1234", session=session)
session.flush()
row = session.scalar(select(TaskStateModel).where(TaskStateModel.key
== "job_id"))
assert row is not None
- assert row.expires_at is not None
- assert row.expires_at > row.updated_at
+ assert row.expires_at is None
+
+ def test_set_expires_at_none_stores_null(
+ self, session: Session, backend: MetastoreStateBackend, dag_run: DagRun
+ ):
+ """expires_at=None stores NULL — the key never expires regardless of
global config."""
+ scope = TaskScope(dag_id=DAG_ID, run_id=RUN_ID, task_id=TASK_ID)
+ backend.set(scope, "job_id", "app_1234", session=session)
+ session.flush()
+
+ row = session.scalar(select(TaskStateModel).where(TaskStateModel.key
== "job_id"))
+ assert row is not None
+ assert row.expires_at is None
def test_cleanup_removes_expired_rows(
self, session: Session, backend: MetastoreStateBackend, dag_run: DagRun
diff --git a/shared/state/src/airflow_shared/state/__init__.py
b/shared/state/src/airflow_shared/state/__init__.py
index e231bdfd3bd..bfce8db3328 100644
--- a/shared/state/src/airflow_shared/state/__init__.py
+++ b/shared/state/src/airflow_shared/state/__init__.py
@@ -21,6 +21,8 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING
if TYPE_CHECKING:
+ from datetime import datetime
+
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
@@ -84,11 +86,23 @@ class BaseStateBackend(ABC):
"""
@abstractmethod
- def set(self, scope: StateScope, key: str, value: str, *, session: Session
| None = None) -> None:
+ def set(
+ self,
+ scope: StateScope,
+ key: str,
+ value: str,
+ *,
+ expires_at: datetime | None = None,
+ session: Session | None = None,
+ ) -> None:
"""
Write or overwrite the value for the given key.
Must handle both ``TaskScope`` and ``AssetScope``.
+
+ ``expires_at`` is an absolute UTC datetime after which the row may be
deleted.
+ Pass ``None`` (default) for a key that should never expire — stored as
``NULL``,
+ skipped by garbage collection.
"""
@abstractmethod
@@ -125,7 +139,13 @@ class BaseStateBackend(ABC):
@abstractmethod
async def aset(
- self, scope: StateScope, key: str, value: str, *, session:
AsyncSession | None = None
+ self,
+ scope: StateScope,
+ key: str,
+ value: str,
+ *,
+ expires_at: datetime | None = None,
+ session: AsyncSession | None = None,
) -> None:
"""
Async variant of set. Must handle both ``TaskScope`` and
``AssetScope``.
diff --git a/task-sdk/docs/api.rst b/task-sdk/docs/api.rst
index cb9789f5bb6..222b28ad53d 100644
--- a/task-sdk/docs/api.rst
+++ b/task-sdk/docs/api.rst
@@ -261,6 +261,10 @@ For a complete list of available context variables (such
as ``dag_run``,
``task_instance``, ``logical_date``, etc.), see the
:ref:`Templates reference <templates-ref>`.
+.. rubric:: Task State
+
+.. autodata:: airflow.sdk.NEVER_EXPIRE
+
.. rubric:: Logging
.. autofunction:: airflow.sdk.log.mask_secret
diff --git a/task-sdk/src/airflow/sdk/__init__.py
b/task-sdk/src/airflow/sdk/__init__.py
index f304b068237..05ececc2956 100644
--- a/task-sdk/src/airflow/sdk/__init__.py
+++ b/task-sdk/src/airflow/sdk/__init__.py
@@ -55,6 +55,7 @@ __all__ = [
"IdentityMapper",
"Label",
"Metadata",
+ "NEVER_EXPIRE",
"MultipleCronTriggerTimetable",
"ObjectStoragePath",
"Param",
@@ -170,6 +171,7 @@ if TYPE_CHECKING:
from airflow.sdk.definitions.variable import Variable
from airflow.sdk.definitions.xcom_arg import XComArg
from airflow.sdk.execution_time import macros
+ from airflow.sdk.execution_time.context import NEVER_EXPIRE
from airflow.sdk.io.path import ObjectStoragePath
from airflow.sdk.types import TaskInstance
@@ -245,6 +247,7 @@ __lazy_imports: dict[str, str] = {
"conf": ".configuration",
"cross_downstream": ".bases.operator",
"dag": ".definitions.dag",
+ "NEVER_EXPIRE": ".execution_time.context",
"get_current_context": ".definitions.context",
"get_parsing_context": ".definitions.context",
"literal": ".definitions.template",
diff --git a/task-sdk/src/airflow/sdk/api/client.py
b/task-sdk/src/airflow/sdk/api/client.py
index 269978ac9dd..24074824b88 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -21,6 +21,7 @@ import logging
import ssl
import sys
import uuid
+from datetime import datetime
from functools import cache
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, TypeVar
@@ -693,9 +694,9 @@ class TaskStateOperations:
raise
return TaskStateResponse.model_validate_json(resp.read())
- def set(self, ti_id: uuid.UUID, key: str, value: str) -> OKResponse:
+ def set(self, ti_id: uuid.UUID, key: str, value: str, expires_at: datetime
| None) -> OKResponse:
"""Set a task state value via the API server."""
- body = TaskStatePutBody(value=value)
+ body = TaskStatePutBody(value=value, expires_at=expires_at)
self.client.put(f"state/ti/{ti_id}/{key}",
content=body.model_dump_json())
return OKResponse(ok=True)
diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
index 9f1dadeef51..fc966a76969 100644
--- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -376,6 +376,7 @@ class TaskStatePutBody(BaseModel):
extra="forbid",
)
value: Annotated[str, Field(title="Value")]
+ expires_at: Annotated[AwareDatetime | None, Field(title="Expires At")] =
None
class TaskStateResponse(BaseModel):
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py
b/task-sdk/src/airflow/sdk/execution_time/comms.py
index c56f5b23ab3..2364e942ed0 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -924,6 +924,7 @@ class SetTaskState(BaseModel):
ti_id: UUID
key: str
value: str
+ expires_at: AwareDatetime | None
type: Literal["SetTaskState"] = "SetTaskState"
diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py
b/task-sdk/src/airflow/sdk/execution_time/context.py
index cdd25803989..afc5a120c0c 100644
--- a/task-sdk/src/airflow/sdk/execution_time/context.py
+++ b/task-sdk/src/airflow/sdk/execution_time/context.py
@@ -21,7 +21,7 @@ import contextlib
import functools
import inspect
from collections.abc import Generator, Iterable, Iterator, Mapping, Sequence
-from datetime import datetime
+from datetime import datetime, timedelta, timezone
from functools import cache
from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
from uuid import UUID
@@ -29,6 +29,7 @@ from uuid import UUID
import attrs
import structlog
+from airflow.sdk.configuration import conf
from airflow.sdk.definitions._internal.contextmanager import _CURRENT_CONTEXT
from airflow.sdk.definitions._internal.types import NOTSET
from airflow.sdk.definitions.asset import (
@@ -108,6 +109,11 @@ AIRFLOW_VAR_NAME_FORMAT_MAPPING = {
log = structlog.get_logger(logger_name="task")
+#: Pass as ``retention`` to ``task_state.set()`` to store a key that never
expires,
+#: regardless of the global ``[state_store] default_retention_days`` config.
+#: Example: ``context["task_state"].set("job_id", job_id,
retention=NEVER_EXPIRE)``
+NEVER_EXPIRE: timedelta = timedelta.max
+
T = TypeVar("T")
@@ -467,12 +473,29 @@ class TaskStateAccessor:
return resp.value
return None
- def set(self, key: str, value: str) -> None:
- """Write or overwrite the value for the given key."""
+ def set(self, key: str, value: str, *, retention: timedelta | None = None)
-> None:
+ """
+ Write or overwrite the value for the given key.
+
+ ``retention`` is an optional argument that controls when this key expires:
+
+ - ``timedelta(...)`` — expire after the given duration (e.g.
``timedelta(hours=6)``).
+ - ``NEVER_EXPIRE`` — the key never expires, regardless of the global
config, and is skipped by garbage collection.
+ - ``None`` (default) — use the global ``[state_store]
default_retention_days`` config.
+ """
from airflow.sdk.execution_time.comms import SetTaskState
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
- SUPERVISOR_COMMS.send(SetTaskState(ti_id=self._ti_id, key=key,
value=value))
+ # expires_at is always resolved on the worker in UTC before being sent.
+ now = datetime.now(tz=timezone.utc)
+ if retention is NEVER_EXPIRE:
+ expires_at = None
+ elif retention is not None:
+ expires_at = now + retention
+ else:
+ days = conf.getint("state_store", "default_retention_days")
+ expires_at = None if days <= 0 else now + timedelta(days=days)
+ SUPERVISOR_COMMS.send(SetTaskState(ti_id=self._ti_id, key=key,
value=value, expires_at=expires_at))
def delete(self, key: str) -> None:
"""Delete a single key. No-op if the key does not exist."""
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index 3e6236c5786..5e46b6bc864 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -1660,7 +1660,7 @@ class ActivitySubprocess(WatchedSubprocess):
else TaskStateResult.from_task_state_response(task_state)
)
elif isinstance(msg, SetTaskState):
- self.client.task_state.set(msg.ti_id, msg.key, msg.value)
+ self.client.task_state.set(msg.ti_id, msg.key, msg.value,
expires_at=msg.expires_at)
resp = OKResponse(ok=True)
elif isinstance(msg, DeleteTaskState):
self.client.task_state.delete(msg.ti_id, msg.key)
diff --git a/task-sdk/tests/task_sdk/api/test_client.py
b/task-sdk/tests/task_sdk/api/test_client.py
index a179ff08436..805dac87934 100644
--- a/task-sdk/tests/task_sdk/api/test_client.py
+++ b/task-sdk/tests/task_sdk/api/test_client.py
@@ -1764,6 +1764,8 @@ class TestTaskStateOperations:
assert result.error == ErrorType.TASK_STATE_NOT_FOUND
def test_set_success(self):
+ expires = datetime(2026, 6, 13, 12, 0, 0, tzinfo=dt_timezone.utc)
+
def handle_request(request: httpx.Request) -> httpx.Response:
assert request.method == "PUT"
assert request.url.path == f"/state/ti/{self.TI_ID}/job_id"
@@ -1771,7 +1773,42 @@ class TestTaskStateOperations:
return httpx.Response(status_code=204)
client = make_client(transport=httpx.MockTransport(handle_request))
- result = client.task_state.set(ti_id=self.TI_ID, key="job_id",
value="spark_app_001")
+ result = client.task_state.set(
+ ti_id=self.TI_ID, key="job_id", value="spark_app_001",
expires_at=expires
+ )
+ assert result == OKResponse(ok=True)
+
+ def test_set_with_expires_at_sends_field(self):
+ """expires_at is forwarded as an ISO datetime string in the request
body."""
+ expires = datetime(2026, 5, 21, 12, 0, 0, tzinfo=dt_timezone.utc)
+
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ body = json.loads(request.content)
+ assert body["value"] == "spark_app_001"
+ assert body["expires_at"] == "2026-05-21T12:00:00Z"
+ return httpx.Response(status_code=204)
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.task_state.set(
+ ti_id=self.TI_ID, key="job_id", value="spark_app_001",
expires_at=expires
+ )
+ assert result == OKResponse(ok=True)
+
+ def test_set_with_never_expire_sends_null_expires_at(self):
+ """NEVER_EXPIRE sends expires_at=null — stored as NULL in DB, GC skips
it."""
+
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ body = json.loads(request.content)
+ assert body.get("expires_at") is None
+ return httpx.Response(status_code=204)
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.task_state.set(
+ ti_id=self.TI_ID,
+ key="job_id",
+ value="v",
+ expires_at=None,
+ )
assert result == OKResponse(ok=True)
def test_delete_success(self):
diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py
b/task-sdk/tests/task_sdk/execution_time/test_context.py
index ff0e6025c63..a5ff7be9ce8 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_context.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_context.py
@@ -17,6 +17,7 @@
from __future__ import annotations
+from datetime import datetime, timedelta, timezone as dt_timezone
from unittest import mock
from unittest.mock import MagicMock, patch
from uuid import UUID
@@ -73,6 +74,7 @@ from airflow.sdk.execution_time.comms import (
XComResult,
)
from airflow.sdk.execution_time.context import (
+ NEVER_EXPIRE,
AssetStateAccessor,
AssetStateAccessors,
ConnectionAccessor,
@@ -92,6 +94,8 @@ from airflow.sdk.execution_time.context import (
)
from airflow.sdk.execution_time.secrets import ExecutionAPISecretsBackend
+from tests_common.test_utils.config import conf_vars
+
def test_convert_connection_result_conn():
"""Test that the ConnectionResult is converted to a Connection object."""
@@ -1085,13 +1089,62 @@ class TestTaskStateAccessor:
with pytest.raises(AirflowRuntimeError):
TaskStateAccessor(ti_id=self.TI_ID).get("some_key")
- def test_set_operation(self, mock_supervisor_comms):
+ def test_set_operation_with_global_retention(self, mock_supervisor_comms,
time_machine):
+ """set() with no retention uses global default_retention_days
config."""
+
+ mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+ now = datetime(2026, 5, 14, 12, 0, 0, tzinfo=dt_timezone.utc)
+ time_machine.move_to(now, tick=False)
+
+ with conf_vars({("state_store", "default_retention_days"): "30"}):
+ TaskStateAccessor(ti_id=self.TI_ID).set("job_id", "app_001")
+
+ mock_supervisor_comms.send.assert_called_once_with(
+ SetTaskState(
+ ti_id=self.TI_ID,
+ key="job_id",
+ value="app_001",
+ expires_at=datetime(2026, 6, 13, 12, 0, 0,
tzinfo=dt_timezone.utc),
+ )
+ )
+
+ def test_set_with_retention_computes_expires_at(self,
mock_supervisor_comms, time_machine):
+ """set(retention=timedelta(...)) computes expires_at on the worker and
sends it."""
+ mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+ now = datetime(2026, 5, 14, 12, 0, 0, tzinfo=dt_timezone.utc)
+ time_machine.move_to(now, tick=False)
+
+ TaskStateAccessor(ti_id=self.TI_ID).set("job_id", "app_001",
retention=timedelta(days=7))
+
+ mock_supervisor_comms.send.assert_called_once_with(
+ SetTaskState(
+ ti_id=self.TI_ID,
+ key="job_id",
+ value="app_001",
+ expires_at=datetime(2026, 5, 21, 12, 0, 0,
tzinfo=dt_timezone.utc),
+ )
+ )
+
+ def test_set_with_never_expire_sends_null_expires_at(self,
mock_supervisor_comms):
+ """set(retention=NEVER_EXPIRE) sends expires_at=None"""
+
+ mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+ TaskStateAccessor(ti_id=self.TI_ID).set("job_id", "app_001",
retention=NEVER_EXPIRE)
+
+ mock_supervisor_comms.send.assert_called_once_with(
+ SetTaskState(ti_id=self.TI_ID, key="job_id", value="app_001",
expires_at=None)
+ )
+
+ def test_set_global_default_zero_sends_null_expires_at(self,
mock_supervisor_comms):
+ """When default_retention_days=0 (never expire globally),
expires_at=None (stored as NULL)."""
mock_supervisor_comms.send.return_value = OKResponse(ok=True)
- TaskStateAccessor(ti_id=self.TI_ID).set("job_id", "app_001")
+ with conf_vars({("state_store", "default_retention_days"): "0"}):
+ TaskStateAccessor(ti_id=self.TI_ID).set("job_id", "app_001")
mock_supervisor_comms.send.assert_called_once_with(
- SetTaskState(ti_id=self.TI_ID, key="job_id", value="app_001")
+ SetTaskState(ti_id=self.TI_ID, key="job_id", value="app_001",
expires_at=None)
)
def test_delete_operation(self, mock_supervisor_comms):
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 51131f3c48d..b4ba8de42e2 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -2722,11 +2722,33 @@ REQUEST_TEST_CASES = [
expected_body={"value": "spark_app_001", "type": "TaskStateResult"},
),
RequestTestCase(
- message=SetTaskState(ti_id=TI_ID, key="job_id", value="spark_app_001"),
+ message=SetTaskState(
+ ti_id=TI_ID,
+ key="job_id",
+ value="spark_app_001",
+ expires_at=datetime(2026, 6, 13, 12, 0, 0, tzinfo=dt_timezone.utc),
+ ),
test_id="set_task_state",
client_mock=ClientMock(
method_path="task_state.set",
args=(TI_ID, "job_id", "spark_app_001"),
+ kwargs={"expires_at": datetime(2026, 6, 13, 12, 0, 0,
tzinfo=dt_timezone.utc)},
+ response=OKResponse(ok=True),
+ ),
+ expected_body={"ok": True, "type": "OKResponse"},
+ ),
+ RequestTestCase(
+ message=SetTaskState(
+ ti_id=TI_ID,
+ key="job_id",
+ value="spark_app_001",
+ expires_at=datetime(2026, 5, 21, 12, 0, 0, tzinfo=dt_timezone.utc),
+ ),
+ test_id="set_task_state_with_expires_at",
+ client_mock=ClientMock(
+ method_path="task_state.set",
+ args=(TI_ID, "job_id", "spark_app_001"),
+ kwargs={"expires_at": datetime(2026, 5, 21, 12, 0, 0,
tzinfo=dt_timezone.utc)},
response=OKResponse(ok=True),
),
expected_body={"ok": True, "type": "OKResponse"},
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 5234b30b84c..54afc412f56 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -4966,7 +4966,7 @@ def test_dag_add_result(create_runtime_ti,
mock_supervisor_comms):
class TestTaskInstanceStateOperations:
"""Tests to verify that tasks can perform state operations (task / asset)
via the supervisor."""
- def test_task_can_set_and_get_state(self, create_runtime_ti,
mock_supervisor_comms):
+ def test_task_can_set_and_get_state(self, create_runtime_ti,
mock_supervisor_comms, time_machine):
class MyOperator(BaseOperator):
def execute(self, context):
ts = context["task_state"]
@@ -4975,14 +4975,43 @@ class TestTaskInstanceStateOperations:
task = MyOperator(task_id="t")
runtime_ti = create_runtime_ti(task=task)
+ frozen_dt = datetime(2026, 1, 1, 12, 0, 0, tzinfo=dt_timezone.utc)
+ time_machine.move_to(frozen_dt, tick=False)
- run(runtime_ti, context=runtime_ti.get_template_context(),
log=mock.MagicMock())
+ with conf_vars({("state_store", "default_retention_days"): "30"}):
+ run(runtime_ti, context=runtime_ti.get_template_context(),
log=mock.MagicMock())
mock_supervisor_comms.send.assert_any_call(
- SetTaskState(ti_id=runtime_ti.id, key="job_id",
value="spark_app_001")
+ SetTaskState(
+ ti_id=runtime_ti.id,
+ key="job_id",
+ value="spark_app_001",
+ expires_at=frozen_dt + timedelta(days=30),
+ )
)
mock_supervisor_comms.send.assert_any_call(GetTaskState(ti_id=runtime_ti.id,
key="job_id"))
+ def test_task_can_set_state_with_retention(self, create_runtime_ti,
mock_supervisor_comms, time_machine):
+ class MyOperator(BaseOperator):
+ def execute(self, context):
+ context["task_state"].set("job_id", "spark_app_001",
retention=timedelta(days=7))
+
+ task = MyOperator(task_id="t")
+ runtime_ti = create_runtime_ti(task=task)
+ frozen_dt = datetime(2026, 1, 1, 12, 0, 0, tzinfo=dt_timezone.utc)
+ time_machine.move_to(frozen_dt, tick=False)
+
+ run(runtime_ti, context=runtime_ti.get_template_context(),
log=mock.MagicMock())
+
+ mock_supervisor_comms.send.assert_any_call(
+ SetTaskState(
+ ti_id=runtime_ti.id,
+ key="job_id",
+ value="spark_app_001",
+ expires_at=frozen_dt + timedelta(days=7),
+ )
+ )
+
def test_task_can_delete_state(self, create_runtime_ti,
mock_supervisor_comms):
class MyOperator(BaseOperator):
def execute(self, context):