This is an automated email from the ASF dual-hosted git repository.
amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d04712b86c2 AIP-103: Allowing outlets to be added/accessed in `AssetStateAccessor` (#67619)
d04712b86c2 is described below
commit d04712b86c2612cccd8fcb3c95f31505db227b08
Author: Jake McGrath <[email protected]>
AuthorDate: Tue Jun 2 01:39:31 2026 -0400
AIP-103: Allowing outlets to be added/accessed in `AssetStateAccessor` (#67619)
---
task-sdk/src/airflow/sdk/execution_time/context.py | 19 ++++---
.../src/airflow/sdk/execution_time/task_runner.py | 12 +++--
.../tests/task_sdk/execution_time/test_context.py | 58 +++++++++++++++++++++-
3 files changed, 78 insertions(+), 11 deletions(-)
diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py
index 159995b840c..209731e6142 100644
--- a/task-sdk/src/airflow/sdk/execution_time/context.py
+++ b/task-sdk/src/airflow/sdk/execution_time/context.py
@@ -742,16 +742,16 @@ class AssetStateAccessor:
class AssetStateAccessors:
"""
- Mapping of asset state accessors for all concrete inlets of a task.
+ Mapping of asset state accessors for all concrete inlets and outlets of a task.
Available as ``context['asset_state']``. Subscript by asset to get a per asset
accessor as: ``context['asset_state'][MY_ASSET].get('watermark')``.
- For tasks with exactly one concrete inlet, the accessor methods (``get``, ``set``,
- ``delete``, ``clear``) can be called directly without subscripting.
+ For tasks with exactly one concrete inlet or outlet, the accessor methods (``get``,
+ ``set``, ``delete``, ``clear``) can be called directly without subscripting.
"""
- def __init__(self, inlets: list) -> None:
+ def __init__(self, inlets: list, outlets: list | None = None) -> None:
self._by_name: dict[str, AssetStateAccessor] = {}
self._by_uri: dict[str, AssetStateAccessor] = {}
@@ -769,6 +769,13 @@ class AssetStateAccessors:
for asset in resp.assets:
self._by_name[asset.name] = AssetStateAccessor(name=asset.name)
+ for outlet in outlets or []:
+ # AssetAlias outlets are for dynamic event emission, not state access, so skip them
+ if isinstance(outlet, (Asset, AssetNameRef)) and outlet.name not in self._by_name:
+ self._by_name[outlet.name] = AssetStateAccessor(name=outlet.name)
+ elif isinstance(outlet, AssetUriRef) and outlet.uri not in self._by_uri:
+ self._by_uri[outlet.uri] = AssetStateAccessor(uri=outlet.uri)
+
self._total = len(self._by_name) + len(self._by_uri)
def __getitem__(self, key: Asset | AssetNameRef | AssetUriRef) -> AssetStateAccessor:
@@ -778,13 +785,13 @@ class AssetStateAccessors:
if isinstance(key, AssetUriRef):
return self._by_uri[key.uri]
except KeyError:
- raise KeyError(f"{key!r} is not in this task's inlets")
+ raise KeyError(f"{key!r} is not in this task's inlets or outlets")
raise TypeError(f"Expected Asset, AssetNameRef, or AssetUriRef; got {type(key).__name__}")
def _single_accessor(self) -> AssetStateAccessor:
if self._total != 1:
raise ValueError(
- f"Task has {self._total} concrete inlets — use context['asset_state'][MY_ASSET] to specify which"
+ f"Task has {self._total} concrete inlets and outlets — use context['asset_state'][MY_ASSET] to specify which"
)
if self._by_name:
return next(iter(self._by_name.values()))
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 9f2599a1963..0ff9be4e8c9 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -307,10 +307,14 @@ class RuntimeTaskInstance(TaskInstance):
),
),
}
- if any(isinstance(i, (Asset, AssetNameRef, AssetUriRef, AssetAlias)) for i in self.task.inlets):
- self._cached_template_context["asset_state"] = AssetStateAccessors(self.task.inlets)
- # AssetAlias inlets are resolved to their concrete assets at context build time
- # via GetAssetsByAlias comms. If an alias maps to no active assets, it doesnt contribute to asset_state.
+ _asset_types = (Asset, AssetNameRef, AssetUriRef, AssetAlias)
+ if any(isinstance(i, _asset_types) for i in self.task.inlets + self.task.outlets):
+ self._cached_template_context["asset_state"] = AssetStateAccessors(
+ self.task.inlets, self.task.outlets
+ )
+ # AssetAlias inlets are resolved to their concrete assets at context build time via
+ # GetAssetsByAlias comms. If an alias maps to no active assets, it doesn't contribute to
+ # asset_state. AssetAlias outlets are skipped downstream in context.py
if TYPE_CHECKING:
assert self._cached_template_context is not None
if from_server:
diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py
index 8b69d69b3bd..1f4f684adf7 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_context.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_context.py
@@ -1471,7 +1471,7 @@ class TestAssetStateAccessors:
a1 = Asset(name="asset_one", uri="s3://one")
a2 = Asset(name="asset_two", uri="s3://two")
- with pytest.raises(ValueError, match="2 concrete inlets"):
+ with pytest.raises(ValueError, match="2 concrete inlets and outlets"):
AssetStateAccessors([a1, a2]).get("watermark")
def test_alias_inlet_resolves_to_concrete_assets(self, mock_supervisor_comms):
@@ -1497,6 +1497,62 @@ class TestAssetStateAccessors:
assert accessors._total == 0
+ def test_outlet_only_asset_is_accessible(self, mock_supervisor_comms):
+ asset = Asset(name=self.ASSET_NAME, uri=f"s3://{self.ASSET_NAME}")
+ mock_supervisor_comms.send.return_value = AssetStateResult(value="v1")
+
+ result = AssetStateAccessors([], [asset])[asset].get("watermark")
+
+ assert result == "v1"
+ mock_supervisor_comms.send.assert_called_once_with(
+ GetAssetStateByName(name=self.ASSET_NAME, key="watermark")
+ )
+
+ def test_outlet_only_name_ref_is_accessible(self, mock_supervisor_comms):
+ ref = AssetNameRef(name=self.ASSET_NAME)
+ mock_supervisor_comms.send.return_value = AssetStateResult(value="v2")
+
+ result = AssetStateAccessors([], [ref])[ref].get("watermark")
+
+ assert result == "v2"
+ mock_supervisor_comms.send.assert_called_once_with(
+ GetAssetStateByName(name=self.ASSET_NAME, key="watermark")
+ )
+
+ def test_outlet_only_uri_ref_is_accessible(self, mock_supervisor_comms):
+ ref = AssetUriRef(uri=self.ASSET_URI)
+ mock_supervisor_comms.send.return_value = AssetStateResult(value="v2")
+
+ result = AssetStateAccessors([], [ref])[ref].get("watermark")
+
+ assert result == "v2"
+ mock_supervisor_comms.send.assert_called_once_with(
+ GetAssetStateByUri(uri=self.ASSET_URI, key="watermark")
+ )
+
+ def test_outlet_only_single_shorthand_works(self, mock_supervisor_comms):
+ asset = Asset(name=self.ASSET_NAME, uri=f"s3://{self.ASSET_NAME}")
+ mock_supervisor_comms.send.return_value = AssetStateResult(value="v3")
+
+ result = AssetStateAccessors([], [asset]).get("watermark")
+
+ assert result == "v3"
+
+ def test_asset_in_both_inlets_and_outlets_not_duplicated(self, mock_supervisor_comms):
+ asset = Asset(name=self.ASSET_NAME, uri=f"s3://{self.ASSET_NAME}")
+
+ accessors = AssetStateAccessors([asset], [asset])
+
+ assert accessors._total == 1
+
+ def test_outlet_alias_is_ignored(self, mock_supervisor_comms):
+ alias = AssetAlias(name="my_alias")
+
+ accessors = AssetStateAccessors([], [alias])
+
+ assert accessors._total == 0
+ mock_supervisor_comms.send.assert_not_called()
+
class InMemoryStateBackend(BaseStateBackend):
"""Simple in-memory test backend."""