This is an automated email from the ASF dual-hosted git repository.
amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6d5174494df AIP-103: Wiring up task SDK comms and context accessors (#66160)
6d5174494df is described below
commit 6d5174494dfae7d81bc32b4d0edd4c06e8314392
Author: Amogh Desai <[email protected]>
AuthorDate: Thu May 7 16:12:26 2026 +0530
AIP-103: Wiring up task SDK comms and context accessors (#66160)
---
.../api_fastapi/execution_api/routes/assets.py | 11 +-
.../api_fastapi/execution_api/versions/__init__.py | 7 +-
.../execution_api/versions/v2026_04_17.py | 8 +
.../tests/unit/dag_processing/test_processor.py | 18 ++
airflow-core/tests/unit/jobs/test_triggerer_job.py | 18 ++
.../tests/unit/standard/operators/test_python.py | 6 +
.../check_template_context_variable_in_sync.py | 3 +
task-sdk/src/airflow/sdk/api/client.py | 113 ++++++++
task-sdk/src/airflow/sdk/definitions/context.py | 8 +-
task-sdk/src/airflow/sdk/exceptions.py | 2 +
task-sdk/src/airflow/sdk/execution_time/comms.py | 132 +++++++++-
task-sdk/src/airflow/sdk/execution_time/context.py | 226 +++++++++++++++-
.../src/airflow/sdk/execution_time/supervisor.py | 66 +++++
.../src/airflow/sdk/execution_time/task_runner.py | 7 +
task-sdk/tests/task_sdk/api/test_client.py | 220 ++++++++++++++++
.../tests/task_sdk/execution_time/test_context.py | 287 ++++++++++++++++++++-
.../task_sdk/execution_time/test_supervisor.py | 181 +++++++++++++
.../task_sdk/execution_time/test_task_runner.py | 238 ++++++++++++++++-
18 files changed, 1543 insertions(+), 8 deletions(-)
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/assets.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/assets.py
index 40397e44f43..385ec509f9f 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/assets.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/assets.py
@@ -24,7 +24,7 @@ from sqlalchemy import select
from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.execution_api.datamodels.asset import AssetResponse
-from airflow.models.asset import AssetModel
+from airflow.models.asset import AssetModel, expand_alias_to_assets
router = APIRouter(
responses={
@@ -58,6 +58,15 @@ def get_asset_by_uri(
return AssetResponse.model_validate(asset)
[email protected]("/by-alias")
+def get_assets_by_alias(
+ alias_name: Annotated[str, Query(description="The name of the
AssetAlias")],
+ session: SessionDep,
+) -> list[AssetResponse]:
+ """Get all Airflow Assets resolved from an AssetAlias by `alias_name`."""
+ return [AssetResponse.model_validate(a) for a in
expand_alias_to_assets(alias_name, session=session)]
+
+
def _raise_if_not_found(asset, msg):
if asset is None:
raise HTTPException(
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
index dfa27f53ebd..e05bd22c273 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
@@ -40,7 +40,11 @@ from airflow.api_fastapi.execution_api.versions.v2026_04_06 import (
MovePreviousRunEndpoint,
RemoveUpstreamMapIndexesField,
)
-from airflow.api_fastapi.execution_api.versions.v2026_04_17 import
AddStateEndpoints, AddTeamNameField
+from airflow.api_fastapi.execution_api.versions.v2026_04_17 import (
+ AddAssetsByAliasEndpoint,
+ AddStateEndpoints,
+ AddTeamNameField,
+)
from airflow.api_fastapi.execution_api.versions.v2026_06_16 import
AddRetryPolicyFields
bundle = VersionBundle(
@@ -50,6 +54,7 @@ bundle = VersionBundle(
"2026-04-17",
AddTeamNameField,
AddStateEndpoints,
+ AddAssetsByAliasEndpoint,
),
Version(
"2026-04-06",
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_17.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_17.py
index dd45b11825d..41e63cf858c 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_17.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_17.py
@@ -36,6 +36,14 @@ class AddTeamNameField(VersionChange):
response.body["dag_run"].pop("team_name", None)
+class AddAssetsByAliasEndpoint(VersionChange):
+ """Add endpoint to resolve assets from an AssetAlias."""
+
+ description = __doc__
+
+ instructions_to_migrate_to_previous_version =
(endpoint("/assets/by-alias", ["GET"]).didnt_exist,)
+
+
class AddStateEndpoints(VersionChange):
"""Add task state and asset state CRUD endpoints."""
diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py
index be5005e4b23..d2ac085b0b7 100644
--- a/airflow-core/tests/unit/dag_processing/test_processor.py
+++ b/airflow-core/tests/unit/dag_processing/test_processor.py
@@ -1947,6 +1947,7 @@ class TestDagProcessingMessageTypes:
"DeleteXCom",
"GetAssetByName",
"GetAssetByUri",
+ "GetAssetsByAlias",
"GetAssetEventByAsset",
"GetAssetEventByAssetAlias",
"GetDagRun",
@@ -1971,10 +1972,24 @@ class TestDagProcessingMessageTypes:
"UpdateHITLDetail",
"GetHITLDetailResponse",
"SetRenderedMapIndex",
+ # AIP-103 task/asset state — Dag processor has no task execution
context.
+ "GetTaskState",
+ "SetTaskState",
+ "DeleteTaskState",
+ "ClearTaskState",
+ "GetAssetStateByName",
+ "GetAssetStateByUri",
+ "SetAssetStateByName",
+ "SetAssetStateByUri",
+ "DeleteAssetStateByName",
+ "DeleteAssetStateByUri",
+ "ClearAssetStateByName",
+ "ClearAssetStateByUri",
}
in_task_runner_but_not_in_dag_processing_process = {
"AssetResult",
+ "AssetsByAliasResult",
"AssetEventsResult",
"DagResult",
"DagRunResult",
@@ -1989,6 +2004,9 @@ class TestDagProcessingMessageTypes:
"InactiveAssetsResult",
"CreateHITLDetailPayload",
"HITLDetailRequestResult",
+ # AIP-103 task/asset state results — worker-only responses to the
above messages.
+ "TaskStateResult",
+ "AssetStateResult",
}
supervisor_diff = supervisor_types - manager_types -
in_supervisor_but_not_in_manager
diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py
index 752247c6d17..968858be336 100644
--- a/airflow-core/tests/unit/jobs/test_triggerer_job.py
+++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py
@@ -1758,6 +1758,7 @@ class TestTriggererMessageTypes:
"DeferTask",
"GetAssetByName",
"GetAssetByUri",
+ "GetAssetsByAlias",
"GetAssetEventByAsset",
"GetAssetEventByAssetAlias",
"GetDagRun",
@@ -1780,10 +1781,24 @@ class TestTriggererMessageTypes:
"CreateHITLDetailPayload",
"SetRenderedMapIndex",
"GetDag",
+ # AIP-103 task/asset state — triggerer has no task execution
context.
+ "GetTaskState",
+ "SetTaskState",
+ "DeleteTaskState",
+ "ClearTaskState",
+ "GetAssetStateByName",
+ "GetAssetStateByUri",
+ "SetAssetStateByName",
+ "SetAssetStateByUri",
+ "DeleteAssetStateByName",
+ "DeleteAssetStateByUri",
+ "ClearAssetStateByName",
+ "ClearAssetStateByUri",
}
in_task_but_not_in_trigger_runner = {
"AssetResult",
+ "AssetsByAliasResult",
"AssetEventsResult",
"DagRunResult",
"SentFDs",
@@ -1800,6 +1815,9 @@ class TestTriggererMessageTypes:
"PreviousTIResult",
"HITLDetailRequestResult",
"DagResult",
+ # AIP-103 task/asset state results — worker-only responses to the
above messages.
+ "TaskStateResult",
+ "AssetStateResult",
}
supervisor_diff = (
diff --git a/providers/standard/tests/unit/standard/operators/test_python.py b/providers/standard/tests/unit/standard/operators/test_python.py
index 437ae89ac26..8142157af47 100644
--- a/providers/standard/tests/unit/standard/operators/test_python.py
+++ b/providers/standard/tests/unit/standard/operators/test_python.py
@@ -71,6 +71,7 @@ from tests_common.test_utils.version_compat import (
AIRFLOW_V_3_0_PLUS,
AIRFLOW_V_3_1_PLUS,
AIRFLOW_V_3_2_PLUS,
+ AIRFLOW_V_3_3_PLUS,
NOTSET,
)
@@ -1131,6 +1132,11 @@ class BaseTestPythonVirtualenvOperator(BasePythonTest):
"inlet_events",
"outlet_events",
}
+ if AIRFLOW_V_3_3_PLUS:
+ # AIP-103: task_state is a live accessor backed by the supervisor
pipe —
+ # not serializable and meaningless in a virtualenv subprocess.
+ # asset_state is excluded via its absence: only present when a
task has inlets.
+ intentionally_excluded_context_keys.add("task_state")
ti = create_task_instance(dag_id=self.dag_id, task_id=self.task_id,
schedule=None)
context = ti.get_template_context()
diff --git a/scripts/ci/prek/check_template_context_variable_in_sync.py b/scripts/ci/prek/check_template_context_variable_in_sync.py
index 96d661d5aab..26508dd5d29 100755
--- a/scripts/ci/prek/check_template_context_variable_in_sync.py
+++ b/scripts/ci/prek/check_template_context_variable_in_sync.py
@@ -48,6 +48,9 @@ IGNORE = {
"data_interval_start",
"prev_data_interval_start_success",
"prev_data_interval_end_success",
+ # AIP-103: task_state/asset_state aren't documented yet. Will be done in a
later PR.
+ "task_state",
+ "asset_state",
}
diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py
index 15513010062..493225b4699 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -46,6 +46,8 @@ from airflow.sdk.api.datamodels._generated import (
API_VERSION,
AssetEventsResponse,
AssetResponse,
+ AssetStatePutBody,
+ AssetStateResponse,
ConnectionResponse,
DagResponse,
DagRun,
@@ -58,6 +60,8 @@ from airflow.sdk.api.datamodels._generated import (
PrevSuccessfulDagRunResponse,
TaskBreadcrumbsResponse,
TaskInstanceState,
+ TaskStatePutBody,
+ TaskStateResponse,
TaskStatesResponse,
TerminalStateNonSuccess,
TIDeferredStatePayload,
@@ -80,6 +84,7 @@ from airflow.sdk.api.datamodels._generated import (
from airflow.sdk.configuration import conf
from airflow.sdk.exceptions import ErrorType, TaskAlreadyRunningError
from airflow.sdk.execution_time.comms import (
+ AssetsByAliasResult,
CreateHITLDetailPayload,
DRCount,
ErrorResponse,
@@ -662,6 +667,95 @@ class XComOperations:
return XComSequenceSliceResponse.model_validate_json(resp.read())
+class TaskStateOperations:
+ __slots__ = ("client",)
+
+ def __init__(self, client: Client):
+ self.client = client
+
+ def get(self, ti_id: uuid.UUID, key: str) -> TaskStateResponse |
ErrorResponse:
+ """Get a task state value from the API server."""
+ try:
+ resp = self.client.get(f"state/ti/{ti_id}/{key}")
+ except ServerResponseError as e:
+ if e.response.status_code == HTTPStatus.NOT_FOUND:
+ log.debug("Task state key not found", ti_id=ti_id, key=key)
+ return ErrorResponse(error=ErrorType.TASK_STATE_NOT_FOUND,
detail={"key": key})
+ raise
+ return TaskStateResponse.model_validate_json(resp.read())
+
+ def set(self, ti_id: uuid.UUID, key: str, value: str) -> OKResponse:
+ """Set a task state value via the API server."""
+ body = TaskStatePutBody(value=value)
+ self.client.put(f"state/ti/{ti_id}/{key}",
content=body.model_dump_json())
+ return OKResponse(ok=True)
+
+ def delete(self, ti_id: uuid.UUID, key: str) -> OKResponse:
+ """Delete a single task state key via the API server."""
+ self.client.delete(f"state/ti/{ti_id}/{key}")
+ return OKResponse(ok=True)
+
+ def clear(self, ti_id: uuid.UUID, all_map_indices: bool = False) ->
OKResponse:
+ """Clear all task state keys for a task instance via the API server."""
+ params = {"all_map_indices": "true"} if all_map_indices else {}
+ self.client.delete(f"state/ti/{ti_id}", params=params)
+ return OKResponse(ok=True)
+
+
+class AssetStateOperations:
+ __slots__ = ("client",)
+
+ def __init__(self, client: Client):
+ self.client = client
+
+ def _resolve_endpoint(
+ self, op: str, *, key: str | None = None, name: str | None = None,
uri: str | None = None
+ ) -> tuple[str, dict[str, str]]:
+ if name:
+ params: dict[str, str] = {"name": name}
+ endpoint = f"state/asset/by-name/{op}"
+ elif uri:
+ params = {"uri": uri}
+ endpoint = f"state/asset/by-uri/{op}"
+ else:
+ raise ValueError("Either `name` or `uri` must be provided")
+ if key is not None:
+ params["key"] = key
+ return endpoint, params
+
+ def get(
+ self, key: str, *, name: str | None = None, uri: str | None = None
+ ) -> AssetStateResponse | ErrorResponse:
+ """Get an asset state value from the API server."""
+ endpoint, params = self._resolve_endpoint("value", key=key, name=name,
uri=uri)
+ try:
+ resp = self.client.get(endpoint, params=params)
+ except ServerResponseError as e:
+ if e.response.status_code == HTTPStatus.NOT_FOUND:
+ log.debug("Asset state key not found", name=name, uri=uri,
key=key)
+ return ErrorResponse(error=ErrorType.ASSET_STATE_NOT_FOUND,
detail={"key": key})
+ raise
+ return AssetStateResponse.model_validate_json(resp.read())
+
+ def set(self, key: str, value: str, *, name: str | None = None, uri: str |
None = None) -> OKResponse:
+ """Set an asset state value via the API server."""
+ endpoint, params = self._resolve_endpoint("value", key=key, name=name,
uri=uri)
+ self.client.put(endpoint, params=params,
content=AssetStatePutBody(value=value).model_dump_json())
+ return OKResponse(ok=True)
+
+ def delete(self, key: str, *, name: str | None = None, uri: str | None =
None) -> OKResponse:
+ """Delete a single asset state key via the API server."""
+ endpoint, params = self._resolve_endpoint("value", key=key, name=name,
uri=uri)
+ self.client.delete(endpoint, params=params)
+ return OKResponse(ok=True)
+
+ def clear(self, *, name: str | None = None, uri: str | None = None) ->
OKResponse:
+ """Clear all state keys for an asset via the API server."""
+ endpoint, params = self._resolve_endpoint("clear", name=name, uri=uri)
+ self.client.delete(endpoint, params=params)
+ return OKResponse(ok=True)
+
+
class AssetOperations:
__slots__ = ("client",)
@@ -694,6 +788,13 @@ class AssetOperations:
return AssetResponse.model_validate_json(resp.read())
+ def get_by_alias(self, alias_name: str) -> AssetsByAliasResult:
+ """Get all Assets resolved from an AssetAlias."""
+ resp = self.client.get("assets/by-alias", params={"alias_name":
alias_name})
+ return AssetsByAliasResult.from_asset_responses(
+ [AssetResponse.model_validate(a) for a in resp.json()]
+ )
+
class AssetEventOperations:
__slots__ = ("client",)
@@ -1078,6 +1179,18 @@ class Client(httpx.Client):
"""Operations related to Asset Events."""
return AssetEventOperations(self)
+ @lru_cache() # type: ignore[misc]
+ @property
+ def task_state(self) -> TaskStateOperations:
+ """Operations related to task state."""
+ return TaskStateOperations(self)
+
+ @lru_cache() # type: ignore[misc]
+ @property
+ def asset_state(self) -> AssetStateOperations:
+ """Operations related to asset state."""
+ return AssetStateOperations(self)
+
@lru_cache() # type: ignore[misc]
@property
def hitl(self):
diff --git a/task-sdk/src/airflow/sdk/definitions/context.py b/task-sdk/src/airflow/sdk/definitions/context.py
index b7a63284608..c422c346298 100644
--- a/task-sdk/src/airflow/sdk/definitions/context.py
+++ b/task-sdk/src/airflow/sdk/definitions/context.py
@@ -30,7 +30,11 @@ if TYPE_CHECKING:
from airflow.sdk.bases.operator import BaseOperator
from airflow.sdk.definitions.dag import DAG
- from airflow.sdk.execution_time.context import InletEventsAccessors
+ from airflow.sdk.execution_time.context import (
+ AssetStateAccessors,
+ InletEventsAccessors,
+ TaskStateAccessor,
+ )
from airflow.sdk.types import (
DagRunProtocol,
Operator,
@@ -72,6 +76,8 @@ class Context(TypedDict, total=False):
task_reschedule_count: int
task_instance: RuntimeTaskInstanceProtocol
task_instance_key_str: str
+ task_state: TaskStateAccessor
+ asset_state: AssetStateAccessors
# `templates_dict` is only set in PythonOperator
templates_dict: NotRequired[dict[str, Any] | None]
test_mode: bool
diff --git a/task-sdk/src/airflow/sdk/exceptions.py b/task-sdk/src/airflow/sdk/exceptions.py
index 7d42dad5d85..ed3bb3f1493 100644
--- a/task-sdk/src/airflow/sdk/exceptions.py
+++ b/task-sdk/src/airflow/sdk/exceptions.py
@@ -80,6 +80,8 @@ class ErrorType(enum.Enum):
VARIABLE_NOT_FOUND = "VARIABLE_NOT_FOUND"
XCOM_NOT_FOUND = "XCOM_NOT_FOUND"
ASSET_NOT_FOUND = "ASSET_NOT_FOUND"
+ TASK_STATE_NOT_FOUND = "TASK_STATE_NOT_FOUND"
+ ASSET_STATE_NOT_FOUND = "ASSET_STATE_NOT_FOUND"
DAGRUN_ALREADY_EXISTS = "DAGRUN_ALREADY_EXISTS"
GENERIC_ERROR = "GENERIC_ERROR"
API_SERVER_ERROR = "API_SERVER_ERROR"
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py
index 1e11e9636e5..01528c728b1 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -69,6 +69,7 @@ from airflow.sdk.api.datamodels._generated import (
AssetEventResponse,
AssetEventsResponse,
AssetResponse,
+ AssetStateResponse,
BundleInfo,
ConnectionResponse,
DagResponse,
@@ -81,6 +82,7 @@ from airflow.sdk.api.datamodels._generated import (
TaskBreadcrumbsResponse,
TaskInstance,
TaskInstanceState,
+ TaskStateResponse,
TaskStatesResponse,
TIDeferredStatePayload,
TIRescheduleStatePayload,
@@ -561,6 +563,40 @@ class VariableResult(VariableResponse):
return cls(**variable_response.model_dump(exclude_defaults=True),
type="VariableResult")
+class TaskStateResult(TaskStateResponse):
+ """Response to GetTaskState; wraps the generated API response for
supervisor to worker comms."""
+
+ type: Literal["TaskStateResult"] = "TaskStateResult"
+
+ @classmethod
+ def from_task_state_response(cls, resp: TaskStateResponse) ->
TaskStateResult:
+ return cls(**resp.model_dump(exclude_defaults=True),
type="TaskStateResult")
+
+
+class AssetStateResult(AssetStateResponse):
+ """Response to GetAssetState; wraps the generated API response for
supervisor to worker comms."""
+
+ type: Literal["AssetStateResult"] = "AssetStateResult"
+
+ @classmethod
+ def from_asset_state_response(cls, resp: AssetStateResponse) ->
AssetStateResult:
+ return cls(**resp.model_dump(exclude_defaults=True),
type="AssetStateResult")
+
+
+class AssetsByAliasResult(BaseModel):
+ """Response to GetAssetsByAlias; list of concrete assets resolved from an
alias."""
+
+ assets: list[AssetResult]
+ type: Literal["AssetsByAliasResult"] = "AssetsByAliasResult"
+
+ @classmethod
+ def from_asset_responses(cls, asset_responses: list[AssetResponse]) ->
AssetsByAliasResult:
+ return cls(
+ assets=[AssetResult.from_asset_response(a) for a in
asset_responses],
+ type="AssetsByAliasResult",
+ )
+
+
class DagRunResult(DagRun):
type: Literal["DagRunResult"] = "DagRunResult"
@@ -728,7 +764,9 @@ class DagResult(DagResponse):
ToTask = Annotated[
AssetResult
+ | AssetsByAliasResult
| AssetEventsResult
+ | AssetStateResult
| ConnectionResult
| DagRunResult
| DagRunStateResult
@@ -740,6 +778,7 @@ ToTask = Annotated[
| SentFDs
| StartupDetails
| TaskRescheduleStartDate
+ | TaskStateResult
| TICount
| TaskBreadcrumbsResult
| TaskStatesResult
@@ -868,6 +907,79 @@ class DeleteXCom(BaseModel):
type: Literal["DeleteXCom"] = "DeleteXCom"
+class GetTaskState(BaseModel):
+ ti_id: UUID
+ key: str
+ type: Literal["GetTaskState"] = "GetTaskState"
+
+
+class SetTaskState(BaseModel):
+ ti_id: UUID
+ key: str
+ value: str
+ type: Literal["SetTaskState"] = "SetTaskState"
+
+
+class DeleteTaskState(BaseModel):
+ ti_id: UUID
+ key: str
+ type: Literal["DeleteTaskState"] = "DeleteTaskState"
+
+
+class ClearTaskState(BaseModel):
+ ti_id: UUID
+ all_map_indices: bool = False
+ type: Literal["ClearTaskState"] = "ClearTaskState"
+
+
+class GetAssetStateByName(BaseModel):
+ name: str
+ key: str
+ type: Literal["GetAssetStateByName"] = "GetAssetStateByName"
+
+
+class GetAssetStateByUri(BaseModel):
+ uri: str
+ key: str
+ type: Literal["GetAssetStateByUri"] = "GetAssetStateByUri"
+
+
+class SetAssetStateByName(BaseModel):
+ name: str
+ key: str
+ value: str
+ type: Literal["SetAssetStateByName"] = "SetAssetStateByName"
+
+
+class SetAssetStateByUri(BaseModel):
+ uri: str
+ key: str
+ value: str
+ type: Literal["SetAssetStateByUri"] = "SetAssetStateByUri"
+
+
+class DeleteAssetStateByName(BaseModel):
+ name: str
+ key: str
+ type: Literal["DeleteAssetStateByName"] = "DeleteAssetStateByName"
+
+
+class DeleteAssetStateByUri(BaseModel):
+ uri: str
+ key: str
+ type: Literal["DeleteAssetStateByUri"] = "DeleteAssetStateByUri"
+
+
+class ClearAssetStateByName(BaseModel):
+ name: str
+ type: Literal["ClearAssetStateByName"] = "ClearAssetStateByName"
+
+
+class ClearAssetStateByUri(BaseModel):
+ uri: str
+ type: Literal["ClearAssetStateByUri"] = "ClearAssetStateByUri"
+
+
class GetConnection(BaseModel):
conn_id: str
type: Literal["GetConnection"] = "GetConnection"
@@ -957,6 +1069,11 @@ class GetAssetByUri(BaseModel):
type: Literal["GetAssetByUri"] = "GetAssetByUri"
+class GetAssetsByAlias(BaseModel):
+ alias_name: str
+ type: Literal["GetAssetsByAlias"] = "GetAssetsByAlias"
+
+
class GetAssetEventByAsset(BaseModel):
name: str | None
uri: str | None
@@ -1058,12 +1175,21 @@ class GetDag(BaseModel):
ToSupervisor = Annotated[
- DeferTask
+ ClearAssetStateByName
+ | ClearAssetStateByUri
+ | ClearTaskState
+ | DeferTask
+ | DeleteAssetStateByName
+ | DeleteAssetStateByUri
+ | DeleteTaskState
| DeleteXCom
| GetAssetByName
| GetAssetByUri
+ | GetAssetsByAlias
| GetAssetEventByAsset
| GetAssetEventByAssetAlias
+ | GetAssetStateByName
+ | GetAssetStateByUri
| GetConnection
| GetDagRun
| GetDagRunState
@@ -1073,6 +1199,7 @@ ToSupervisor = Annotated[
| GetPreviousDagRun
| GetPreviousTI
| GetTaskRescheduleStartDate
+ | GetTaskState
| GetTICount
| GetTaskBreadcrumbs
| GetTaskStates
@@ -1084,8 +1211,11 @@ ToSupervisor = Annotated[
| PutVariable
| RescheduleTask
| RetryTask
+ | SetAssetStateByName
+ | SetAssetStateByUri
| SetRenderedFields
| SetRenderedMapIndex
+ | SetTaskState
| SetXCom
| SkipDownstreamTasks
| SucceedTask
diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py
index 66c1f3aa8b7..7fdd8bddc7f 100644
--- a/task-sdk/src/airflow/sdk/execution_time/context.py
+++ b/task-sdk/src/airflow/sdk/execution_time/context.py
@@ -24,6 +24,7 @@ from collections.abc import Generator, Iterable, Iterator, Mapping, Sequence
from datetime import datetime
from functools import cache
from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
+from uuid import UUID
import attrs
import structlog
@@ -45,8 +46,6 @@ from airflow.sdk.exceptions import AirflowNotFoundException, AirflowRuntimeError
from airflow.sdk.log import mask_secret
if TYPE_CHECKING:
- from uuid import UUID
-
from pydantic.types import JsonValue
from typing_extensions import Self
@@ -406,6 +405,229 @@ class VariableAccessor:
raise
+class TaskStateAccessor:
+ """Accessor for task state scoped to the current task instance. Available
as ``context['task_state']`` at task execution time."""
+
+ def __init__(self, ti_id: UUID) -> None:
+ self._ti_id = ti_id
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, TaskStateAccessor):
+ return False
+ return self._ti_id == other._ti_id
+
+ def __hash__(self) -> int:
+ return hash(self._ti_id)
+
+ def __repr__(self) -> str:
+ return f"<TaskStateAccessor ti_id={self._ti_id}>"
+
+ # TODO: ``__getattr__`` for jinja template access like ``{{
task_state.job_id }}``
+ # is not implemented yet cos it's unclear whether task state values will be
+ # used in templates.
+
+ def get(self, key: str) -> str | None:
+ """Return the stored value, or ``None`` if the key does not exist."""
+ from airflow.sdk.execution_time.comms import ErrorResponse,
GetTaskState, TaskStateResult
+ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+
+ resp = SUPERVISOR_COMMS.send(GetTaskState(ti_id=self._ti_id, key=key))
+ if isinstance(resp, ErrorResponse) and resp.error !=
ErrorType.TASK_STATE_NOT_FOUND:
+ raise AirflowRuntimeError(resp)
+ if isinstance(resp, TaskStateResult):
+ return resp.value
+ return None
+
+ def set(self, key: str, value: str) -> None:
+ """Write or overwrite the value for the given key."""
+ from airflow.sdk.execution_time.comms import SetTaskState
+ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+
+ SUPERVISOR_COMMS.send(SetTaskState(ti_id=self._ti_id, key=key,
value=value))
+
+ def delete(self, key: str) -> None:
+ """Delete a single key. No-op if the key does not exist."""
+ from airflow.sdk.execution_time.comms import DeleteTaskState
+ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+
+ SUPERVISOR_COMMS.send(DeleteTaskState(ti_id=self._ti_id, key=key))
+
+ def clear(self, all_map_indices: bool = False) -> None:
+ """
+ Delete all keys for this task instance.
+
+ Pass ``all_map_indices=True`` to wipe state across every mapped
+ instance of the task (fleet-wide reset). Defaults to clearing only
+ this task instance's own state.
+ """
+ from airflow.sdk.execution_time.comms import ClearTaskState
+ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+
+ SUPERVISOR_COMMS.send(ClearTaskState(ti_id=self._ti_id,
all_map_indices=all_map_indices))
+
+
+class AssetStateAccessor:
+ """
+ Accessor for asset state scoped to a single asset.
+
+ Obtained via ``context['asset_state'][MY_ASSET]`` or, as sugar for
single-inlet
+ tasks, directly as ``context['asset_state']``.
+ """
+
+ def __init__(self, *, name: str | None = None, uri: str | None = None) ->
None:
+ if not name and not uri:
+ raise ValueError("Either `name` or `uri` must be provided")
+ self._name = name
+ self._uri = uri
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, AssetStateAccessor):
+ return False
+ return self._name == other._name and self._uri == other._uri
+
+ def __hash__(self) -> int:
+ return hash((self._name, self._uri))
+
+ def __repr__(self) -> str:
+ if self._name is not None:
+ return f"<AssetStateAccessor name={self._name!r}>"
+ return f"<AssetStateAccessor uri={self._uri!r}>"
+
+ def get(self, key: str) -> str | None:
+ """Return the stored value, or ``None`` if the key does not exist."""
+ from airflow.sdk.execution_time.comms import (
+ AssetStateResult,
+ ErrorResponse,
+ GetAssetStateByName,
+ GetAssetStateByUri,
+ ToSupervisor,
+ )
+ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+
+ msg: ToSupervisor
+ if self._name:
+ msg = GetAssetStateByName(name=self._name, key=key)
+ elif self._uri:
+ msg = GetAssetStateByUri(uri=self._uri, key=key)
+ resp = SUPERVISOR_COMMS.send(msg)
+ if isinstance(resp, ErrorResponse) and resp.error !=
ErrorType.ASSET_STATE_NOT_FOUND:
+ raise AirflowRuntimeError(resp)
+ if isinstance(resp, AssetStateResult):
+ return resp.value
+ return None
+
+ def set(self, key: str, value: str) -> None:
+ """Write or overwrite the value for the given key."""
+ from airflow.sdk.execution_time.comms import SetAssetStateByName,
SetAssetStateByUri, ToSupervisor
+ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+
+ msg: ToSupervisor
+ if self._name:
+ msg = SetAssetStateByName(name=self._name, key=key, value=value)
+ elif self._uri:
+ msg = SetAssetStateByUri(uri=self._uri, key=key, value=value)
+ SUPERVISOR_COMMS.send(msg)
+
+ def delete(self, key: str) -> None:
+ """Delete a single key. No-op if the key does not exist."""
+ from airflow.sdk.execution_time.comms import (
+ DeleteAssetStateByName,
+ DeleteAssetStateByUri,
+ ToSupervisor,
+ )
+ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+
+ msg: ToSupervisor
+ if self._name:
+ msg = DeleteAssetStateByName(name=self._name, key=key)
+ elif self._uri:
+ msg = DeleteAssetStateByUri(uri=self._uri, key=key)
+ SUPERVISOR_COMMS.send(msg)
+
+ def clear(self) -> None:
+ """Delete all state keys for this asset."""
+ from airflow.sdk.execution_time.comms import ClearAssetStateByName,
ClearAssetStateByUri, ToSupervisor
+ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+
+ msg: ToSupervisor
+ if self._name:
+ msg = ClearAssetStateByName(name=self._name)
+ elif self._uri:
+ msg = ClearAssetStateByUri(uri=self._uri)
+ SUPERVISOR_COMMS.send(msg)
+
+
+class AssetStateAccessors:
+ """
+ Mapping of asset state accessors for all concrete inlets of a task.
+
+ Available as ``context['asset_state']``. Subscript by asset to get a per
asset
+ accessor as: ``context['asset_state'][MY_ASSET].get('watermark')``.
+
+ For tasks with exactly one concrete inlet, the accessor methods (``get``,
``set``,
+ ``delete``, ``clear``) can be called directly without subscripting.
+ """
+
+ def __init__(self, inlets: list) -> None:
+ self._by_name: dict[str, AssetStateAccessor] = {}
+ self._by_uri: dict[str, AssetStateAccessor] = {}
+
+ for inlet in inlets:
+ if isinstance(inlet, (Asset, AssetNameRef)):
+ self._by_name[inlet.name] = AssetStateAccessor(name=inlet.name)
+ elif isinstance(inlet, AssetUriRef):
+ self._by_uri[inlet.uri] = AssetStateAccessor(uri=inlet.uri)
+ elif isinstance(inlet, AssetAlias):
+ from airflow.sdk.execution_time.comms import
AssetsByAliasResult, GetAssetsByAlias
+ from airflow.sdk.execution_time.task_runner import
SUPERVISOR_COMMS
+
+ resp =
SUPERVISOR_COMMS.send(GetAssetsByAlias(alias_name=inlet.name))
+ if isinstance(resp, AssetsByAliasResult):
+ for asset in resp.assets:
+ self._by_name[asset.name] =
AssetStateAccessor(name=asset.name)
+
+ self._total = len(self._by_name) + len(self._by_uri)
+
+ def __getitem__(self, key: Asset | AssetNameRef | AssetUriRef) ->
AssetStateAccessor:
+ try:
+ if isinstance(key, (Asset, AssetNameRef)):
+ return self._by_name[key.name]
+ if isinstance(key, AssetUriRef):
+ return self._by_uri[key.uri]
+ except KeyError:
+ raise KeyError(f"{key!r} is not in this task's inlets")
+ raise TypeError(f"Expected Asset, AssetNameRef, or AssetUriRef; got
{type(key).__name__}")
+
+ def _single_accessor(self) -> AssetStateAccessor:
+ if self._total != 1:
+ raise ValueError(
+ f"Task has {self._total} concrete inlets — use
context['asset_state'][MY_ASSET] to specify which"
+ )
+ if self._by_name:
+ return next(iter(self._by_name.values()))
+ return next(iter(self._by_uri.values()))
+
+ def get(self, key: str) -> str | None:
+ """Return the stored value for the single-inlet task, or ``None`` if
not found."""
+ return self._single_accessor().get(key)
+
+ def set(self, key: str, value: str) -> None:
+ """Write or overwrite the value for the single-inlet task."""
+ self._single_accessor().set(key, value)
+
+ def delete(self, key: str) -> None:
+ """Delete a single key for the single-inlet task."""
+ self._single_accessor().delete(key)
+
+ def clear(self) -> None:
+ """Delete all state keys for the single-inlet task."""
+ self._single_accessor().clear()
+
+ def __repr__(self) -> str:
+ parts = [f"name={k!r}" for k in self._by_name] + [f"uri={k!r}" for k
in self._by_uri]
+ return f"<AssetStateAccessors [{', '.join(parts)}]>"
+
+
class MacrosAccessor:
"""Wrapper to access Macros module lazily."""
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index cd25d927957..b8f2eeff045 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -64,12 +64,19 @@ from airflow.sdk.execution_time import comms
from airflow.sdk.execution_time.comms import (
AssetEventsResult,
AssetResult,
+ AssetStateResult,
+ ClearAssetStateByName,
+ ClearAssetStateByUri,
+ ClearTaskState,
ConnectionResult,
CreateHITLDetailPayload,
DagResult,
DagRunResult,
DagRunStateResult,
DeferTask,
+ DeleteAssetStateByName,
+ DeleteAssetStateByUri,
+ DeleteTaskState,
DeleteVariable,
DeleteXCom,
ErrorResponse,
@@ -77,6 +84,9 @@ from airflow.sdk.execution_time.comms import (
GetAssetByUri,
GetAssetEventByAsset,
GetAssetEventByAssetAlias,
+ GetAssetsByAlias,
+ GetAssetStateByName,
+ GetAssetStateByUri,
GetConnection,
GetDag,
GetDagRun,
@@ -87,6 +97,7 @@ from airflow.sdk.execution_time.comms import (
GetPrevSuccessfulDagRun,
GetTaskBreadcrumbs,
GetTaskRescheduleStartDate,
+ GetTaskState,
GetTaskStates,
GetTICount,
GetVariable,
@@ -97,20 +108,25 @@ from airflow.sdk.execution_time.comms import (
HITLDetailRequestResult,
InactiveAssetsResult,
MaskSecret,
+ OKResponse,
PrevSuccessfulDagRunResult,
PutVariable,
RescheduleTask,
ResendLoggingFD,
RetryTask,
SentFDs,
+ SetAssetStateByName,
+ SetAssetStateByUri,
SetRenderedFields,
SetRenderedMapIndex,
+ SetTaskState,
SetXCom,
SkipDownstreamTasks,
StartupDetails,
SucceedTask,
TaskBreadcrumbsResult,
TaskState,
+ TaskStateResult,
TaskStatesResult,
ToSupervisor,
TriggerDagRun,
@@ -1509,6 +1525,8 @@ class ActivitySubprocess(WatchedSubprocess):
dump_opts = {"exclude_unset": True}
else:
resp = asset_resp
+ elif isinstance(msg, GetAssetsByAlias):
+ resp = self.client.assets.get_by_alias(alias_name=msg.alias_name)
elif isinstance(msg, GetAssetEventByAsset):
asset_event_resp = self.client.asset_events.get(
uri=msg.uri,
@@ -1628,6 +1646,54 @@ class ActivitySubprocess(WatchedSubprocess):
dag_id=msg.dag_id,
)
resp = DagResult.from_api_response(dag)
+ elif isinstance(msg, GetTaskState):
+ task_state = self.client.task_state.get(msg.ti_id, msg.key)
+ resp = (
+ task_state
+ if isinstance(task_state, ErrorResponse)
+ else TaskStateResult.from_task_state_response(task_state)
+ )
+ elif isinstance(msg, SetTaskState):
+ self.client.task_state.set(msg.ti_id, msg.key, msg.value)
+ resp = OKResponse(ok=True)
+ elif isinstance(msg, DeleteTaskState):
+ self.client.task_state.delete(msg.ti_id, msg.key)
+ resp = OKResponse(ok=True)
+ elif isinstance(msg, ClearTaskState):
+ self.client.task_state.clear(msg.ti_id,
all_map_indices=msg.all_map_indices)
+ resp = OKResponse(ok=True)
+ elif isinstance(msg, GetAssetStateByName):
+ asset_state = self.client.asset_state.get(msg.key, name=msg.name)
+ resp = (
+ asset_state
+ if isinstance(asset_state, ErrorResponse)
+ else AssetStateResult.from_asset_state_response(asset_state)
+ )
+ elif isinstance(msg, GetAssetStateByUri):
+ asset_state = self.client.asset_state.get(msg.key, uri=msg.uri)
+ resp = (
+ asset_state
+ if isinstance(asset_state, ErrorResponse)
+ else AssetStateResult.from_asset_state_response(asset_state)
+ )
+ elif isinstance(msg, SetAssetStateByName):
+ self.client.asset_state.set(msg.key, msg.value, name=msg.name)
+ resp = OKResponse(ok=True)
+ elif isinstance(msg, SetAssetStateByUri):
+ self.client.asset_state.set(msg.key, msg.value, uri=msg.uri)
+ resp = OKResponse(ok=True)
+ elif isinstance(msg, DeleteAssetStateByName):
+ self.client.asset_state.delete(msg.key, name=msg.name)
+ resp = OKResponse(ok=True)
+ elif isinstance(msg, DeleteAssetStateByUri):
+ self.client.asset_state.delete(msg.key, uri=msg.uri)
+ resp = OKResponse(ok=True)
+ elif isinstance(msg, ClearAssetStateByName):
+ self.client.asset_state.clear(name=msg.name)
+ resp = OKResponse(ok=True)
+ elif isinstance(msg, ClearAssetStateByUri):
+ self.client.asset_state.clear(uri=msg.uri)
+ resp = OKResponse(ok=True)
else:
log.error("Unhandled request", msg=msg)
self.send_msg(
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 56ba8343c64..7c318fc499e 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -111,10 +111,12 @@ from airflow.sdk.execution_time.comms import (
ValidateInletsAndOutlets,
)
from airflow.sdk.execution_time.context import (
+ AssetStateAccessors,
ConnectionAccessor,
InletEventsAccessors,
MacrosAccessor,
OutletEventAccessors,
+ TaskStateAccessor,
TriggeringAssetEventsAccessor,
VariableAccessor,
context_get_outlet_events,
@@ -249,7 +251,12 @@ class RuntimeTaskInstance(TaskInstance):
"value": VariableAccessor(deserialize_json=False),
},
"conn": ConnectionAccessor(),
+ "task_state": TaskStateAccessor(ti_id=self.id),
}
+ if any(isinstance(i, (Asset, AssetNameRef, AssetUriRef,
AssetAlias)) for i in self.task.inlets):
+ self._cached_template_context["asset_state"] =
AssetStateAccessors(self.task.inlets)
+ # AssetAlias inlets are resolved to their concrete assets at
context build time
+ # via GetAssetsByAlias comms. If an alias maps to no active
assets, it doesn't contribute to asset_state.
if TYPE_CHECKING:
assert self._cached_template_context is not None
if from_server:
diff --git a/task-sdk/tests/task_sdk/api/test_client.py
b/task-sdk/tests/task_sdk/api/test_client.py
index 26e2a7e66bb..a179ff08436 100644
--- a/task-sdk/tests/task_sdk/api/test_client.py
+++ b/task-sdk/tests/task_sdk/api/test_client.py
@@ -36,6 +36,7 @@ from airflow.sdk.api.client import Client,
RemoteValidationError, ServerResponse
from airflow.sdk.api.datamodels._generated import (
AssetEventsResponse,
AssetResponse,
+ AssetStateResponse,
ConnectionResponse,
DagResponse,
DagRunState,
@@ -43,12 +44,14 @@ from airflow.sdk.api.datamodels._generated import (
HITLDetailRequest,
HITLDetailResponse,
HITLUser,
+ TaskStateResponse,
TerminalTIState,
VariableResponse,
XComResponse,
)
from airflow.sdk.exceptions import ErrorType, TaskAlreadyRunningError
from airflow.sdk.execution_time.comms import (
+ AssetsByAliasResult,
DeferTask,
ErrorResponse,
OKResponse,
@@ -1221,6 +1224,37 @@ class TestAssetOperations:
assert isinstance(result, ErrorResponse)
assert result.error == ErrorType.ASSET_NOT_FOUND
+ def test_get_by_alias_returns_list(self):
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ assert request.url.path == "/assets/by-alias"
+ assert request.url.params["alias_name"] == "my_alias"
+ return httpx.Response(
+ status_code=200,
+ json=[
+ {"name": "asset_a", "uri": "s3://bucket/a", "group":
"asset", "extra": {}},
+ {"name": "asset_b", "uri": "s3://bucket/b", "group":
"asset", "extra": {}},
+ ],
+ )
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.assets.get_by_alias("my_alias")
+
+ assert isinstance(result, AssetsByAliasResult)
+ assert len(result.assets) == 2
+ assert isinstance(result.assets[0], AssetResponse)
+ assert isinstance(result.assets[1], AssetResponse)
+ assert result.assets[0].name == "asset_a"
+ assert result.assets[1].name == "asset_b"
+
+ def test_get_by_alias_returns_empty_for_unknown_alias(self):
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ return httpx.Response(status_code=200, json=[])
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.assets.get_by_alias("unknown_alias")
+
+ assert result.assets == []
+
class TestDagRunOperations:
def test_trigger(self):
@@ -1700,3 +1734,189 @@ class TestDagsOperations:
with pytest.raises(ServerResponseError):
client.dags.get(dag_id="test_dag")
+
+
+class TestTaskStateOperations:
+ TI_ID = uuid7()
+
+ def test_get_success(self):
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if request.url.path == f"/state/ti/{self.TI_ID}/job_id":
+ return httpx.Response(status_code=200, json={"value":
"spark_app_001"})
+ return httpx.Response(status_code=400)
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.task_state.get(ti_id=self.TI_ID, key="job_id")
+
+ assert isinstance(result, TaskStateResponse)
+ assert result.value == "spark_app_001"
+
+ def test_get_returns_error_response_on_404(self):
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ return httpx.Response(
+ status_code=404,
+ json={"detail": {"reason": "not_found", "message": "Task state
key 'job_id' not found"}},
+ )
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.task_state.get(ti_id=self.TI_ID, key="job_id")
+ assert isinstance(result, ErrorResponse)
+ assert result.error == ErrorType.TASK_STATE_NOT_FOUND
+
+ def test_set_success(self):
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ assert request.method == "PUT"
+ assert request.url.path == f"/state/ti/{self.TI_ID}/job_id"
+ assert b'"value":"spark_app_001"' in request.content
+ return httpx.Response(status_code=204)
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.task_state.set(ti_id=self.TI_ID, key="job_id",
value="spark_app_001")
+ assert result == OKResponse(ok=True)
+
+ def test_delete_success(self):
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ assert request.method == "DELETE"
+ assert request.url.path == f"/state/ti/{self.TI_ID}/job_id"
+ return httpx.Response(status_code=204)
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.task_state.delete(ti_id=self.TI_ID, key="job_id")
+ assert result == OKResponse(ok=True)
+
+ def test_clear_default_no_query_param(self):
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ assert request.method == "DELETE"
+ assert request.url.path == f"/state/ti/{self.TI_ID}"
+ assert "all_map_indices" not in str(request.url.query)
+ return httpx.Response(status_code=204)
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.task_state.clear(ti_id=self.TI_ID)
+ assert result == OKResponse(ok=True)
+
+ def test_clear_all_map_indices_sends_query_param(self):
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ assert "all_map_indices=true" in str(request.url.query)
+ return httpx.Response(status_code=204)
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.task_state.clear(ti_id=self.TI_ID,
all_map_indices=True)
+ assert result == OKResponse(ok=True)
+
+
+class TestAssetStateOperations:
+ def test_get_by_name_success(self):
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if (
+ request.url.path == "/state/asset/by-name/value"
+ and request.url.params["name"] == "test_asset"
+ and request.url.params["key"] == "watermark"
+ ):
+ return httpx.Response(status_code=200, json={"value":
"2026-04-30T00:00:00Z"})
+ return httpx.Response(status_code=400)
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.asset_state.get(key="watermark", name="test_asset")
+
+ assert isinstance(result, AssetStateResponse)
+ assert result.value == "2026-04-30T00:00:00Z"
+
+ def test_get_by_uri_success(self):
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if (
+ request.url.path == "/state/asset/by-uri/value"
+ and request.url.params["uri"] == "s3://bucket/key"
+ and request.url.params["key"] == "watermark"
+ ):
+ return httpx.Response(status_code=200, json={"value":
"2026-04-30T00:00:00Z"})
+ return httpx.Response(status_code=400)
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.asset_state.get(key="watermark", uri="s3://bucket/key")
+
+ assert isinstance(result, AssetStateResponse)
+ assert result.value == "2026-04-30T00:00:00Z"
+
+ def test_get_returns_error_response_on_404(self):
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ return httpx.Response(
+ status_code=404,
+ json={"detail": {"reason": "not_found", "message": "Asset
state key 'watermark' not found"}},
+ )
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.asset_state.get(key="watermark", name="test_asset")
+ assert isinstance(result, ErrorResponse)
+ assert result.error == ErrorType.ASSET_STATE_NOT_FOUND
+
+ def test_set_by_name_success(self):
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ assert request.method == "PUT"
+ assert request.url.path == "/state/asset/by-name/value"
+ assert request.url.params["name"] == "test_asset"
+ assert request.url.params["key"] == "watermark"
+ assert b'"value":"2026-04-30T00:00:00Z"' in request.content
+ return httpx.Response(status_code=204)
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.asset_state.set(key="watermark",
value="2026-04-30T00:00:00Z", name="test_asset")
+ assert result == OKResponse(ok=True)
+
+ def test_set_by_uri_success(self):
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ assert request.method == "PUT"
+ assert request.url.path == "/state/asset/by-uri/value"
+ assert request.url.params["uri"] == "s3://bucket/key"
+ assert request.url.params["key"] == "watermark"
+ return httpx.Response(status_code=204)
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.asset_state.set(key="watermark",
value="2026-04-30T00:00:00Z", uri="s3://bucket/key")
+ assert result == OKResponse(ok=True)
+
+ def test_delete_by_name_success(self):
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ assert request.method == "DELETE"
+ assert request.url.path == "/state/asset/by-name/value"
+ assert request.url.params["name"] == "test_asset"
+ assert request.url.params["key"] == "watermark"
+ return httpx.Response(status_code=204)
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.asset_state.delete(key="watermark", name="test_asset")
+ assert result == OKResponse(ok=True)
+
+ def test_delete_by_uri_success(self):
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ assert request.method == "DELETE"
+ assert request.url.path == "/state/asset/by-uri/value"
+ assert request.url.params["uri"] == "s3://bucket/key"
+ assert request.url.params["key"] == "watermark"
+ return httpx.Response(status_code=204)
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.asset_state.delete(key="watermark",
uri="s3://bucket/key")
+ assert result == OKResponse(ok=True)
+
+ def test_clear_by_name_success(self):
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ assert request.method == "DELETE"
+ assert request.url.path == "/state/asset/by-name/clear"
+ assert request.url.params["name"] == "test_asset"
+ return httpx.Response(status_code=204)
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.asset_state.clear(name="test_asset")
+ assert result == OKResponse(ok=True)
+
+ def test_clear_by_uri_success(self):
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ assert request.method == "DELETE"
+ assert request.url.path == "/state/asset/by-uri/clear"
+ assert request.url.params["uri"] == "s3://bucket/key"
+ return httpx.Response(status_code=204)
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.asset_state.clear(uri="s3://bucket/key")
+ assert result == OKResponse(ok=True)
diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py
b/task-sdk/tests/task_sdk/execution_time/test_context.py
index 1b0be13ab3a..ff0e6025c63 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_context.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_context.py
@@ -19,6 +19,7 @@ from __future__ import annotations
from unittest import mock
from unittest.mock import MagicMock, patch
+from uuid import UUID
import pytest
@@ -30,33 +31,55 @@ from airflow.sdk.definitions.asset import (
AssetAlias,
AssetAliasEvent,
AssetAliasUniqueKey,
+ AssetNameRef,
AssetUniqueKey,
+ AssetUriRef,
)
from airflow.sdk.definitions.connection import Connection
from airflow.sdk.definitions.variable import Variable
-from airflow.sdk.exceptions import AirflowNotFoundException, ErrorType
+from airflow.sdk.exceptions import AirflowNotFoundException,
AirflowRuntimeError, ErrorType
from airflow.sdk.execution_time.comms import (
AssetEventDagRunReferenceResult,
AssetEventResult,
AssetEventSourceTaskInstance,
AssetEventsResult,
AssetResult,
+ AssetsByAliasResult,
+ AssetStateResult,
+ ClearAssetStateByName,
+ ClearAssetStateByUri,
+ ClearTaskState,
ConnectionResult,
DagRunResult,
+ DeleteAssetStateByName,
+ DeleteAssetStateByUri,
+ DeleteTaskState,
ErrorResponse,
GetAssetByName,
GetAssetByUri,
GetAssetEventByAsset,
+ GetAssetsByAlias,
+ GetAssetStateByName,
+ GetAssetStateByUri,
GetDagRun,
+ GetTaskState,
GetXCom,
+ OKResponse,
+ SetAssetStateByName,
+ SetAssetStateByUri,
+ SetTaskState,
+ TaskStateResult,
VariableResult,
XComResult,
)
from airflow.sdk.execution_time.context import (
+ AssetStateAccessor,
+ AssetStateAccessors,
ConnectionAccessor,
InletEventsAccessors,
OutletEventAccessor,
OutletEventAccessors,
+ TaskStateAccessor,
TriggeringAssetEventsAccessor,
VariableAccessor,
_AssetRefResolutionMixin,
@@ -1032,3 +1055,265 @@ class TestSecretsBackend:
with pytest.raises(AirflowNotFoundException, match="isn't
defined"):
_get_connection("nonexistent_conn")
+
+
+class TestTaskStateAccessor:
+ TI_ID = UUID("01900000-0000-0000-0000-000000000001")
+
+ def test_get_returns_value(self, mock_supervisor_comms):
+ mock_supervisor_comms.send.return_value =
TaskStateResult(value="app_001")
+
+ result = TaskStateAccessor(ti_id=self.TI_ID).get("job_id")
+
+ assert result == "app_001"
+
mock_supervisor_comms.send.assert_called_once_with(GetTaskState(ti_id=self.TI_ID,
key="job_id"))
+
+ def test_get_returns_none_on_404(self, mock_supervisor_comms):
+ mock_supervisor_comms.send.return_value = ErrorResponse(
+ error=ErrorType.TASK_STATE_NOT_FOUND, detail={"key": "missing_key"}
+ )
+
+ result = TaskStateAccessor(ti_id=self.TI_ID).get("missing_key")
+
+ assert result is None
+
+ def test_get_raises_on_error(self, mock_supervisor_comms):
+ mock_supervisor_comms.send.return_value = ErrorResponse(
+ error=ErrorType.GENERIC_ERROR, detail={"message": "server error"}
+ )
+
+ with pytest.raises(AirflowRuntimeError):
+ TaskStateAccessor(ti_id=self.TI_ID).get("some_key")
+
+ def test_set_operation(self, mock_supervisor_comms):
+ mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+ TaskStateAccessor(ti_id=self.TI_ID).set("job_id", "app_001")
+
+ mock_supervisor_comms.send.assert_called_once_with(
+ SetTaskState(ti_id=self.TI_ID, key="job_id", value="app_001")
+ )
+
+ def test_delete_operation(self, mock_supervisor_comms):
+ mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+ TaskStateAccessor(ti_id=self.TI_ID).delete("job_id")
+
+
mock_supervisor_comms.send.assert_called_once_with(DeleteTaskState(ti_id=self.TI_ID,
key="job_id"))
+
+ def test_clear_default_sends_all_map_indices_false(self,
mock_supervisor_comms):
+ mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+ TaskStateAccessor(ti_id=self.TI_ID).clear()
+
+ mock_supervisor_comms.send.assert_called_once_with(
+ ClearTaskState(ti_id=self.TI_ID, all_map_indices=False)
+ )
+
+ def test_clear_all_map_indices_sends_flag_true(self,
mock_supervisor_comms):
+ mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+ TaskStateAccessor(ti_id=self.TI_ID).clear(all_map_indices=True)
+
+ mock_supervisor_comms.send.assert_called_once_with(
+ ClearTaskState(ti_id=self.TI_ID, all_map_indices=True)
+ )
+
+
+class TestAssetStateAccessor:
+ ASSET_NAME = "debug_watcher_asset"
+ ASSET_URI = "s3://bucket/key"
+
+ def test_get_returns_value(self, mock_supervisor_comms):
+ mock_supervisor_comms.send.return_value =
AssetStateResult(value="2026-04-30T00:00:00Z")
+
+ result = AssetStateAccessor(name=self.ASSET_NAME).get("watermark")
+
+ assert result == "2026-04-30T00:00:00Z"
+ mock_supervisor_comms.send.assert_called_once_with(
+ GetAssetStateByName(name=self.ASSET_NAME, key="watermark")
+ )
+
+ def test_get_returns_none_on_404(self, mock_supervisor_comms):
+ mock_supervisor_comms.send.return_value = ErrorResponse(
+ error=ErrorType.ASSET_STATE_NOT_FOUND, detail={"key":
"missing_key"}
+ )
+
+ result = AssetStateAccessor(name=self.ASSET_NAME).get("missing_key")
+
+ assert result is None
+
+ def test_get_raises_on_error(self, mock_supervisor_comms):
+ mock_supervisor_comms.send.return_value = ErrorResponse(
+ error=ErrorType.GENERIC_ERROR, detail={"message": "server error"}
+ )
+
+ with pytest.raises(AirflowRuntimeError):
+ AssetStateAccessor(name=self.ASSET_NAME).get("some_key")
+
+ def test_set_operation(self, mock_supervisor_comms):
+ mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+ AssetStateAccessor(name=self.ASSET_NAME).set("watermark",
"2026-04-30T00:00:00Z")
+
+ mock_supervisor_comms.send.assert_called_once_with(
+ SetAssetStateByName(name=self.ASSET_NAME, key="watermark",
value="2026-04-30T00:00:00Z")
+ )
+
+ def test_delete_operation(self, mock_supervisor_comms):
+ mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+ AssetStateAccessor(name=self.ASSET_NAME).delete("watermark")
+
+ mock_supervisor_comms.send.assert_called_once_with(
+ DeleteAssetStateByName(name=self.ASSET_NAME, key="watermark")
+ )
+
+ def test_clear_operation(self, mock_supervisor_comms):
+ mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+ AssetStateAccessor(name=self.ASSET_NAME).clear()
+
+
mock_supervisor_comms.send.assert_called_once_with(ClearAssetStateByName(name=self.ASSET_NAME))
+
+ def test_get_by_uri(self, mock_supervisor_comms):
+ mock_supervisor_comms.send.return_value =
AssetStateResult(value="2026-04-30T00:00:00Z")
+
+ result = AssetStateAccessor(uri=self.ASSET_URI).get("watermark")
+
+ assert result == "2026-04-30T00:00:00Z"
+ mock_supervisor_comms.send.assert_called_once_with(
+ GetAssetStateByUri(uri=self.ASSET_URI, key="watermark")
+ )
+
+ def test_set_by_uri(self, mock_supervisor_comms):
+ mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+ AssetStateAccessor(uri=self.ASSET_URI).set("watermark",
"2026-04-30T00:00:00Z")
+
+ mock_supervisor_comms.send.assert_called_once_with(
+ SetAssetStateByUri(uri=self.ASSET_URI, key="watermark",
value="2026-04-30T00:00:00Z")
+ )
+
+ def test_delete_by_uri(self, mock_supervisor_comms):
+ mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+ AssetStateAccessor(uri=self.ASSET_URI).delete("watermark")
+
+ mock_supervisor_comms.send.assert_called_once_with(
+ DeleteAssetStateByUri(uri=self.ASSET_URI, key="watermark")
+ )
+
+ def test_clear_by_uri(self, mock_supervisor_comms):
+ mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+ AssetStateAccessor(uri=self.ASSET_URI).clear()
+
+
mock_supervisor_comms.send.assert_called_once_with(ClearAssetStateByUri(uri=self.ASSET_URI))
+
+
+class TestAssetStateAccessors:
+ ASSET_NAME = "my_asset"
+ ASSET_URI = "s3://bucket/key"
+
+ def test_subscript_by_asset_routes_by_name(self, mock_supervisor_comms):
+ asset = Asset(name=self.ASSET_NAME, uri=f"s3://{self.ASSET_NAME}")
+ mock_supervisor_comms.send.return_value = AssetStateResult(value="v1")
+
+ result = AssetStateAccessors([asset])[asset].get("watermark")
+
+ assert result == "v1"
+ mock_supervisor_comms.send.assert_called_once_with(
+ GetAssetStateByName(name=self.ASSET_NAME, key="watermark")
+ )
+
+ def test_subscript_by_asset_name_ref(self, mock_supervisor_comms):
+ ref = AssetNameRef(name=self.ASSET_NAME)
+ mock_supervisor_comms.send.return_value = AssetStateResult(value="v2")
+
+ result = AssetStateAccessors([ref])[ref].get("watermark")
+
+ assert result == "v2"
+ mock_supervisor_comms.send.assert_called_once_with(
+ GetAssetStateByName(name=self.ASSET_NAME, key="watermark")
+ )
+
+ def test_subscript_by_uri_ref(self, mock_supervisor_comms):
+ ref = AssetUriRef(uri=self.ASSET_URI)
+ mock_supervisor_comms.send.return_value = AssetStateResult(value="v3")
+
+ result = AssetStateAccessors([ref])[ref].get("watermark")
+
+ assert result == "v3"
+ mock_supervisor_comms.send.assert_called_once_with(
+ GetAssetStateByUri(uri=self.ASSET_URI, key="watermark")
+ )
+
+ def test_get_single_inlet_simplified(self, mock_supervisor_comms):
+ asset = Asset(name=self.ASSET_NAME, uri=f"s3://{self.ASSET_NAME}")
+ mock_supervisor_comms.send.return_value = AssetStateResult(value="v4")
+
+ result = AssetStateAccessors([asset]).get("watermark")
+
+ assert result == "v4"
+ mock_supervisor_comms.send.assert_called_once_with(
+ GetAssetStateByName(name=self.ASSET_NAME, key="watermark")
+ )
+
+ def test_set_single_inlet_simplified(self, mock_supervisor_comms):
+ asset = Asset(name=self.ASSET_NAME, uri=f"s3://{self.ASSET_NAME}")
+ mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+ AssetStateAccessors([asset]).set("watermark", "2026-05-01")
+
+ mock_supervisor_comms.send.assert_called_once_with(
+ SetAssetStateByName(name=self.ASSET_NAME, key="watermark",
value="2026-05-01")
+ )
+
+ def test_delete_single_inlet_simplified(self, mock_supervisor_comms):
+ asset = Asset(name=self.ASSET_NAME, uri=f"s3://{self.ASSET_NAME}")
+ mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+ AssetStateAccessors([asset]).delete("watermark")
+
+ mock_supervisor_comms.send.assert_called_once_with(
+ DeleteAssetStateByName(name=self.ASSET_NAME, key="watermark")
+ )
+
+ def test_clear_single_inlet_simplified(self, mock_supervisor_comms):
+ asset = Asset(name=self.ASSET_NAME, uri=f"s3://{self.ASSET_NAME}")
+ mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+ AssetStateAccessors([asset]).clear()
+
+
mock_supervisor_comms.send.assert_called_once_with(ClearAssetStateByName(name=self.ASSET_NAME))
+
+ def test_double_reference_raises(self):
+ a1 = Asset(name="asset_one", uri="s3://one")
+ a2 = Asset(name="asset_two", uri="s3://two")
+
+ with pytest.raises(ValueError, match="2 concrete inlets"):
+ AssetStateAccessors([a1, a2]).get("watermark")
+
+ def test_alias_inlet_resolves_to_concrete_assets(self,
mock_supervisor_comms):
+ alias = AssetAlias(name="my_alias")
+ mock_supervisor_comms.send.return_value = AssetsByAliasResult(
+ assets=[AssetResult(name="resolved_asset",
uri="s3://bucket/resolved", group="asset")]
+ )
+ mock_supervisor_comms.send.return_value = AssetsByAliasResult(
+ assets=[AssetResult(name="resolved_asset",
uri="s3://bucket/resolved", group="asset")]
+ )
+
+ accessors = AssetStateAccessors([alias])
+
+
mock_supervisor_comms.send.assert_called_once_with(GetAssetsByAlias(alias_name="my_alias"))
+ resolved = Asset(name="resolved_asset", uri="s3://bucket/resolved")
+ assert resolved.name in accessors._by_name
+
+ def test_alias_inlet_no_resolved_assets_contributes_nothing(self,
mock_supervisor_comms):
+ alias = AssetAlias(name="empty_alias")
+ mock_supervisor_comms.send.return_value =
AssetsByAliasResult(assets=[])
+
+ accessors = AssetStateAccessors([alias])
+
+ assert accessors._total == 0
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index f0d8e1a0b65..a8f97f81ac2 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -72,6 +72,11 @@ from airflow.sdk.execution_time import supervisor,
task_runner
from airflow.sdk.execution_time.comms import (
AssetEventsResult,
AssetResult,
+ AssetsByAliasResult,
+ AssetStateResult,
+ ClearAssetStateByName,
+ ClearAssetStateByUri,
+ ClearTaskState,
CommsDecoder,
ConnectionResult,
CreateHITLDetailPayload,
@@ -79,6 +84,9 @@ from airflow.sdk.execution_time.comms import (
DagRunResult,
DagRunStateResult,
DeferTask,
+ DeleteAssetStateByName,
+ DeleteAssetStateByUri,
+ DeleteTaskState,
DeleteVariable,
DeleteXCom,
DRCount,
@@ -87,6 +95,9 @@ from airflow.sdk.execution_time.comms import (
GetAssetByUri,
GetAssetEventByAsset,
GetAssetEventByAssetAlias,
+ GetAssetsByAlias,
+ GetAssetStateByName,
+ GetAssetStateByUri,
GetConnection,
GetDag,
GetDagRun,
@@ -98,6 +109,7 @@ from airflow.sdk.execution_time.comms import (
GetPrevSuccessfulDagRun,
GetTaskBreadcrumbs,
GetTaskRescheduleStartDate,
+ GetTaskState,
GetTaskStates,
GetTICount,
GetVariable,
@@ -117,14 +129,18 @@ from airflow.sdk.execution_time.comms import (
ResendLoggingFD,
RetryTask,
SentFDs,
+ SetAssetStateByName,
+ SetAssetStateByUri,
SetRenderedFields,
SetRenderedMapIndex,
+ SetTaskState,
SetXCom,
SkipDownstreamTasks,
SucceedTask,
TaskBreadcrumbsResult,
TaskRescheduleStartDate,
TaskState,
+ TaskStateResult,
TaskStatesResult,
TICount,
ToSupervisor,
@@ -1807,6 +1823,29 @@ REQUEST_TEST_CASES = [
),
test_id="get_asset_by_uri",
),
+ RequestTestCase(
+ message=GetAssetsByAlias(alias_name="my_alias"),
+ expected_body={
+ "assets": [
+ {
+ "name": "asset_a",
+ "uri": "s3://bucket/a",
+ "group": "asset",
+ "extra": None,
+ "type": "AssetResult",
+ }
+ ],
+ "type": "AssetsByAliasResult",
+ },
+ client_mock=ClientMock(
+ method_path="assets.get_by_alias",
+ kwargs={"alias_name": "my_alias"},
+ response=AssetsByAliasResult(
+ assets=[AssetResult(name="asset_a", uri="s3://bucket/a",
group="asset", extra=None)]
+ ),
+ ),
+ test_id="get_assets_by_alias",
+ ),
RequestTestCase(
message=GetAssetEventByAsset(uri="s3://bucket/obj", name="test"),
expected_body={
@@ -2656,6 +2695,148 @@ REQUEST_TEST_CASES = [
),
test_id="get_dag",
),
+ RequestTestCase(
+ message=GetTaskState(ti_id=TI_ID, key="job_id"),
+ test_id="get_task_state",
+ client_mock=ClientMock(
+ method_path="task_state.get",
+ args=(TI_ID, "job_id"),
+ response=TaskStateResult(value="spark_app_001"),
+ ),
+ expected_body={"value": "spark_app_001", "type": "TaskStateResult"},
+ ),
+ RequestTestCase(
+ message=SetTaskState(ti_id=TI_ID, key="job_id", value="spark_app_001"),
+ test_id="set_task_state",
+ client_mock=ClientMock(
+ method_path="task_state.set",
+ args=(TI_ID, "job_id", "spark_app_001"),
+ response=OKResponse(ok=True),
+ ),
+ expected_body={"ok": True, "type": "OKResponse"},
+ ),
+ RequestTestCase(
+ message=DeleteTaskState(ti_id=TI_ID, key="job_id"),
+ test_id="delete_task_state",
+ client_mock=ClientMock(
+ method_path="task_state.delete",
+ args=(TI_ID, "job_id"),
+ response=OKResponse(ok=True),
+ ),
+ expected_body={"ok": True, "type": "OKResponse"},
+ ),
+ RequestTestCase(
+ message=ClearTaskState(ti_id=TI_ID),
+ test_id="clear_task_state",
+ client_mock=ClientMock(
+ method_path="task_state.clear",
+ args=(TI_ID,),
+ kwargs={"all_map_indices": False},
+ response=OKResponse(ok=True),
+ ),
+ expected_body={"ok": True, "type": "OKResponse"},
+ ),
+ RequestTestCase(
+ message=ClearTaskState(ti_id=TI_ID, all_map_indices=True),
+ test_id="clear_task_state_all_map_indices",
+ client_mock=ClientMock(
+ method_path="task_state.clear",
+ args=(TI_ID,),
+ kwargs={"all_map_indices": True},
+ response=OKResponse(ok=True),
+ ),
+ expected_body={"ok": True, "type": "OKResponse"},
+ ),
+ RequestTestCase(
+ message=GetAssetStateByName(name="debug_watcher_asset",
key="watermark"),
+ test_id="get_asset_state_by_name",
+ client_mock=ClientMock(
+ method_path="asset_state.get",
+ args=("watermark",),
+ kwargs={"name": "debug_watcher_asset"},
+ response=AssetStateResult(value="2026-04-30T00:00:00Z"),
+ ),
+ expected_body={"value": "2026-04-30T00:00:00Z", "type":
"AssetStateResult"},
+ ),
+ RequestTestCase(
+ message=GetAssetStateByUri(uri="s3://bucket/key", key="watermark"),
+ test_id="get_asset_state_by_uri",
+ client_mock=ClientMock(
+ method_path="asset_state.get",
+ args=("watermark",),
+ kwargs={"uri": "s3://bucket/key"},
+ response=AssetStateResult(value="2026-04-30T00:00:00Z"),
+ ),
+ expected_body={"value": "2026-04-30T00:00:00Z", "type":
"AssetStateResult"},
+ ),
+ RequestTestCase(
+ message=SetAssetStateByName(
+ name="debug_watcher_asset", key="watermark",
value="2026-04-30T00:00:00Z"
+ ),
+ test_id="set_asset_state_by_name",
+ client_mock=ClientMock(
+ method_path="asset_state.set",
+ args=("watermark", "2026-04-30T00:00:00Z"),
+ kwargs={"name": "debug_watcher_asset"},
+ response=OKResponse(ok=True),
+ ),
+ expected_body={"ok": True, "type": "OKResponse"},
+ ),
+ RequestTestCase(
+ message=SetAssetStateByUri(uri="s3://bucket/key", key="watermark",
value="2026-04-30T00:00:00Z"),
+ test_id="set_asset_state_by_uri",
+ client_mock=ClientMock(
+ method_path="asset_state.set",
+ args=("watermark", "2026-04-30T00:00:00Z"),
+ kwargs={"uri": "s3://bucket/key"},
+ response=OKResponse(ok=True),
+ ),
+ expected_body={"ok": True, "type": "OKResponse"},
+ ),
+ RequestTestCase(
+ message=DeleteAssetStateByName(name="debug_watcher_asset",
key="watermark"),
+ test_id="delete_asset_state_by_name",
+ client_mock=ClientMock(
+ method_path="asset_state.delete",
+ args=("watermark",),
+ kwargs={"name": "debug_watcher_asset"},
+ response=OKResponse(ok=True),
+ ),
+ expected_body={"ok": True, "type": "OKResponse"},
+ ),
+ RequestTestCase(
+ message=DeleteAssetStateByUri(uri="s3://bucket/key", key="watermark"),
+ test_id="delete_asset_state_by_uri",
+ client_mock=ClientMock(
+ method_path="asset_state.delete",
+ args=("watermark",),
+ kwargs={"uri": "s3://bucket/key"},
+ response=OKResponse(ok=True),
+ ),
+ expected_body={"ok": True, "type": "OKResponse"},
+ ),
+ RequestTestCase(
+ message=ClearAssetStateByName(name="debug_watcher_asset"),
+ test_id="clear_asset_state_by_name",
+ client_mock=ClientMock(
+ method_path="asset_state.clear",
+ args=(),
+ kwargs={"name": "debug_watcher_asset"},
+ response=OKResponse(ok=True),
+ ),
+ expected_body={"ok": True, "type": "OKResponse"},
+ ),
+ RequestTestCase(
+ message=ClearAssetStateByUri(uri="s3://bucket/key"),
+ test_id="clear_asset_state_by_uri",
+ client_mock=ClientMock(
+ method_path="asset_state.clear",
+ args=(),
+ kwargs={"uri": "s3://bucket/key"},
+ response=OKResponse(ok=True),
+ ),
+ expected_body={"ok": True, "type": "OKResponse"},
+ ),
]
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 630aff9094e..723ca42d93a 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -67,7 +67,7 @@ from airflow.sdk.api.datamodels._generated import (
)
from airflow.sdk.bases.xcom import BaseXCom
from airflow.sdk.definitions._internal.types import NOTSET,
SET_DURING_EXECUTION, is_arg_set
-from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey,
Dataset, Model
+from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey,
AssetUriRef, Dataset, Model
from airflow.sdk.definitions.param import DagParam
from airflow.sdk.exceptions import (
AirflowException,
@@ -86,31 +86,46 @@ from airflow.sdk.execution_time import task_runner
from airflow.sdk.execution_time.comms import (
AssetEventResult,
AssetEventsResult,
+ AssetResult,
+ AssetsByAliasResult,
BundleInfo,
+ ClearAssetStateByName,
+ ClearTaskState,
ConnectionResult,
DagResult,
DagRunStateResult,
DeferTask,
+ DeleteAssetStateByName,
+ DeleteTaskState,
DRCount,
ErrorResponse,
+ GetAssetByUri,
+ GetAssetsByAlias,
+ GetAssetStateByName,
+ GetAssetStateByUri,
GetConnection,
GetDag,
GetDagRunState,
GetDRCount,
GetPreviousDagRun,
GetPreviousTI,
+ GetTaskState,
GetTaskStates,
GetTICount,
GetVariable,
GetXCom,
GetXComSequenceSlice,
+ InactiveAssetsResult,
MaskSecret,
OKResponse,
PreviousDagRunResult,
PreviousTIResult,
PrevSuccessfulDagRunResult,
RescheduleTask,
+ SetAssetStateByName,
+ SetAssetStateByUri,
SetRenderedFields,
+ SetTaskState,
SetXCom,
SkipDownstreamTasks,
StartupDetails,
@@ -120,6 +135,7 @@ from airflow.sdk.execution_time.comms import (
TaskStatesResult,
TICount,
TriggerDagRun,
+ ValidateInletsAndOutlets,
VariableResult,
XComResult,
XComSequenceSliceResult,
@@ -129,6 +145,7 @@ from airflow.sdk.execution_time.context import (
InletEventsAccessors,
MacrosAccessor,
OutletEventAccessors,
+ TaskStateAccessor,
TriggeringAssetEventsAccessor,
VariableAccessor,
)
@@ -1764,6 +1781,7 @@ class TestRuntimeTaskInstance:
"run_id": "test_run",
"task": task,
"task_instance": runtime_ti,
+ "task_state": TaskStateAccessor(ti_id=ti_id),
"ti": runtime_ti,
}
@@ -1809,6 +1827,7 @@ class TestRuntimeTaskInstance:
"run_id": "test_run",
"task": task,
"task_instance": runtime_ti,
+ "task_state": TaskStateAccessor(ti_id=runtime_ti.id),
"ti": runtime_ti,
"dag_run": dr,
"data_interval_end": timezone.datetime(2024, 12, 1, 1, 0, 0),
@@ -4868,3 +4887,220 @@ def test_dag_add_result(create_runtime_ti,
mock_supervisor_comms):
dag_result=True,
)
)
+
+
+class TestTaskInstanceStateOperations:
+ """Tests to verify that tasks can perform state operations (task / asset)
via the supervisor."""
+
+ def test_task_can_set_and_get_state(self, create_runtime_ti,
mock_supervisor_comms):
+ class MyOperator(BaseOperator):
+ def execute(self, context):
+ ts = context["task_state"]
+ ts.set("job_id", "spark_app_001")
+ return ts.get("job_id")
+
+ task = MyOperator(task_id="t")
+ runtime_ti = create_runtime_ti(task=task)
+
+ run(runtime_ti, context=runtime_ti.get_template_context(),
log=mock.MagicMock())
+
+ mock_supervisor_comms.send.assert_any_call(
+ SetTaskState(ti_id=runtime_ti.id, key="job_id",
value="spark_app_001")
+ )
+
mock_supervisor_comms.send.assert_any_call(GetTaskState(ti_id=runtime_ti.id,
key="job_id"))
+
+ def test_task_can_delete_state(self, create_runtime_ti,
mock_supervisor_comms):
+ class MyOperator(BaseOperator):
+ def execute(self, context):
+ context["task_state"].delete("job_id")
+
+ task = MyOperator(task_id="t")
+ runtime_ti = create_runtime_ti(task=task)
+
+ run(runtime_ti, context=runtime_ti.get_template_context(),
log=mock.MagicMock())
+
+
mock_supervisor_comms.send.assert_any_call(DeleteTaskState(ti_id=runtime_ti.id,
key="job_id"))
+
+ @pytest.mark.parametrize(
+ ("call_kwargs", "expected_flag"),
+ [
+ pytest.param({}, False, id="default"),
+ pytest.param({"all_map_indices": True}, True, id="fleet-wipe"),
+ ],
+ )
+ def test_task_can_clear_state(self, call_kwargs, expected_flag,
create_runtime_ti, mock_supervisor_comms):
+ class MyOperator(BaseOperator):
+ def execute(self, context):
+ context["task_state"].clear(**call_kwargs)
+
+ task = MyOperator(task_id="t")
+ runtime_ti = create_runtime_ti(task=task)
+ run(runtime_ti, context=runtime_ti.get_template_context(),
log=mock.MagicMock())
+ mock_supervisor_comms.send.assert_any_call(
+ ClearTaskState(ti_id=runtime_ti.id, all_map_indices=expected_flag)
+ )
+
+ @staticmethod
+ def _watcher_side_effect(msg=None, *args, **kwargs):
+ actual = msg or (args[0] if args else None)
+ if isinstance(actual, ValidateInletsAndOutlets):
+ return InactiveAssetsResult(inactive_assets=[])
+ if isinstance(actual, GetAssetByUri):
+ # GetAssetByUri has no .name field. Mirroring AssetModel behaviour:
+ # when only uri is provided, name defaults to uri.
+ return AssetResult(name=actual.uri, uri=actual.uri, group="asset")
+ return OKResponse(ok=True)
+
+ def test_asset_state_get_and_set(self, create_runtime_ti,
mock_supervisor_comms):
+ watched = Asset(name="my_asset", uri="s3://bucket/data")
+
+ class WatcherOperator(BaseOperator):
+ def execute(self, context):
+ context["asset_state"].set("watermark", "2026-04-30")
+ context["asset_state"].get("watermark")
+
+ task = WatcherOperator(task_id="t", inlets=[watched])
+ runtime_ti = create_runtime_ti(task=task)
+ mock_supervisor_comms.send.side_effect =
TestTaskInstanceStateOperations._watcher_side_effect
+
+ run(runtime_ti, context=runtime_ti.get_template_context(),
log=mock.MagicMock())
+
+ mock_supervisor_comms.send.assert_any_call(
+ SetAssetStateByName(name="my_asset", key="watermark",
value="2026-04-30")
+ )
+
mock_supervisor_comms.send.assert_any_call(GetAssetStateByName(name="my_asset",
key="watermark"))
+
+ def test_asset_state_delete(self, create_runtime_ti,
mock_supervisor_comms):
+ watched = Asset(name="my_asset", uri="s3://bucket/data")
+
+ class WatcherOperator(BaseOperator):
+ def execute(self, context):
+ context["asset_state"].delete("watermark")
+
+ task = WatcherOperator(task_id="t", inlets=[watched])
+ runtime_ti = create_runtime_ti(task=task)
+ mock_supervisor_comms.send.side_effect =
TestTaskInstanceStateOperations._watcher_side_effect
+
+ run(runtime_ti, context=runtime_ti.get_template_context(),
log=mock.MagicMock())
+
+
mock_supervisor_comms.send.assert_any_call(DeleteAssetStateByName(name="my_asset",
key="watermark"))
+
+ def test_asset_state_clear(self, create_runtime_ti, mock_supervisor_comms):
+ watched = Asset(name="my_asset", uri="s3://bucket/data")
+
+ class WatcherOperator(BaseOperator):
+ def execute(self, context):
+ context["asset_state"].clear()
+
+ task = WatcherOperator(task_id="t", inlets=[watched])
+ runtime_ti = create_runtime_ti(task=task)
+ mock_supervisor_comms.send.side_effect =
TestTaskInstanceStateOperations._watcher_side_effect
+
+ run(runtime_ti, context=runtime_ti.get_template_context(),
log=mock.MagicMock())
+
+
mock_supervisor_comms.send.assert_any_call(ClearAssetStateByName(name="my_asset"))
+
+ def test_asset_state_uri_ref_inlet(self, create_runtime_ti,
mock_supervisor_comms):
+ watched = AssetUriRef(uri="s3://bucket/data")
+
+ class WatcherOperator(BaseOperator):
+ def execute(self, context):
+ context["asset_state"].set("watermark", "2026-04-30")
+ context["asset_state"].get("watermark")
+
+ task = WatcherOperator(task_id="t", inlets=[watched])
+ runtime_ti = create_runtime_ti(task=task)
+ mock_supervisor_comms.send.side_effect =
TestTaskInstanceStateOperations._watcher_side_effect
+
+ run(runtime_ti, context=runtime_ti.get_template_context(),
log=mock.MagicMock())
+
+ mock_supervisor_comms.send.assert_any_call(
+ SetAssetStateByUri(uri="s3://bucket/data", key="watermark",
value="2026-04-30")
+ )
+ mock_supervisor_comms.send.assert_any_call(
+ GetAssetStateByUri(uri="s3://bucket/data", key="watermark")
+ )
+
+ def test_asset_state_alias_as_inlet(self, create_runtime_ti,
mock_supervisor_comms):
+ alias = AssetAlias(name="my_alias")
+ resolved = Asset(name="resolved_asset", uri="s3://bucket/resolved")
+
+ class WatcherOperator(BaseOperator):
+ def execute(self, context):
+ context["asset_state"][resolved].set("watermark", "2026-05-01")
+
+ def side_effect(msg):
+ if isinstance(msg, GetAssetsByAlias):
+ return AssetsByAliasResult(
+ assets=[AssetResult(name=resolved.name, uri=resolved.uri,
group="asset")]
+ )
+ return TestTaskInstanceStateOperations._watcher_side_effect(msg)
+
+ task = WatcherOperator(task_id="t", inlets=[alias])
+ runtime_ti = create_runtime_ti(task=task)
+ mock_supervisor_comms.send.side_effect = side_effect
+
+ run(runtime_ti, context=runtime_ti.get_template_context(),
log=mock.MagicMock())
+
+ mock_supervisor_comms.send.assert_any_call(
+ SetAssetStateByName(name="resolved_asset", key="watermark",
value="2026-05-01")
+ )
+
+ def test_asset_state_alias_inlet_no_resolved_assets(self,
create_runtime_ti, mock_supervisor_comms):
+ alias = AssetAlias(name="empty_alias")
+
+ class WatcherOperator(BaseOperator):
+ def execute(self, context):
+ # asset_state is in context but it is empty because alias
resolved to nothing
+ assert "asset_state" in context
+
+ def side_effect(msg):
+ if isinstance(msg, GetAssetsByAlias):
+ return AssetsByAliasResult(assets=[])
+ return TestTaskInstanceStateOperations._watcher_side_effect(msg)
+
+ task = WatcherOperator(task_id="t", inlets=[alias])
+ runtime_ti = create_runtime_ti(task=task)
+ mock_supervisor_comms.send.side_effect = side_effect
+
+ run(runtime_ti, context=runtime_ti.get_template_context(),
log=mock.MagicMock())
+
+ def test_asset_state_keyed_access_single_inlet(self, create_runtime_ti,
mock_supervisor_comms):
+ watched = Asset(name="my_asset", uri="s3://bucket/data")
+
+ class WatcherOperator(BaseOperator):
+ def execute(self, context):
+ # accessing via asset name key
+ context["asset_state"][watched].set("watermark", "2026-05-01")
+
+ task = WatcherOperator(task_id="t", inlets=[watched])
+ runtime_ti = create_runtime_ti(task=task)
+ mock_supervisor_comms.send.side_effect =
TestTaskInstanceStateOperations._watcher_side_effect
+
+ run(runtime_ti, context=runtime_ti.get_template_context(),
log=mock.MagicMock())
+
+ mock_supervisor_comms.send.assert_any_call(
+ SetAssetStateByName(name="my_asset", key="watermark",
value="2026-05-01")
+ )
+
+ def test_asset_state_multi_inlet(self, create_runtime_ti,
mock_supervisor_comms):
+ asset_a = Asset(name="asset_a", uri="s3://bucket/a")
+ asset_b = Asset(name="asset_b", uri="s3://bucket/b")
+
+ class MultiInletOperator(BaseOperator):
+ def execute(self, context):
+ context["asset_state"][asset_a].set("watermark_a",
"2026-05-01")
+ context["asset_state"][asset_b].set("watermark_b",
"2026-05-02")
+
+ task = MultiInletOperator(task_id="t", inlets=[asset_a, asset_b])
+ runtime_ti = create_runtime_ti(task=task)
+ mock_supervisor_comms.send.side_effect =
TestTaskInstanceStateOperations._watcher_side_effect
+
+ run(runtime_ti, context=runtime_ti.get_template_context(),
log=mock.MagicMock())
+
+ mock_supervisor_comms.send.assert_any_call(
+ SetAssetStateByName(name="asset_a", key="watermark_a",
value="2026-05-01")
+ )
+ mock_supervisor_comms.send.assert_any_call(
+ SetAssetStateByName(name="asset_b", key="watermark_b",
value="2026-05-02")
+ )