This is an automated email from the ASF dual-hosted git repository.

amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d04712b86c2 AIP-103: Allowing outlets to be added/accessed in `AssetStateAccessor` (#67619)
d04712b86c2 is described below

commit d04712b86c2612cccd8fcb3c95f31505db227b08
Author: Jake McGrath <[email protected]>
AuthorDate: Tue Jun 2 01:39:31 2026 -0400

    AIP-103: Allowing outlets to be added/accessed in `AssetStateAccessor` (#67619)
---
 task-sdk/src/airflow/sdk/execution_time/context.py | 19 ++++---
 .../src/airflow/sdk/execution_time/task_runner.py  | 12 +++--
 .../tests/task_sdk/execution_time/test_context.py  | 58 +++++++++++++++++++++-
 3 files changed, 78 insertions(+), 11 deletions(-)

diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py
index 159995b840c..209731e6142 100644
--- a/task-sdk/src/airflow/sdk/execution_time/context.py
+++ b/task-sdk/src/airflow/sdk/execution_time/context.py
@@ -742,16 +742,16 @@ class AssetStateAccessor:
 
 class AssetStateAccessors:
     """
-    Mapping of asset state accessors for all concrete inlets of a task.
+    Mapping of asset state accessors for all concrete inlets and outlets of a task.
 
     Available as ``context['asset_state']``. Subscript by asset to get a per asset
     accessor as: ``context['asset_state'][MY_ASSET].get('watermark')``.
 
-    For tasks with exactly one concrete inlet, the accessor methods (``get``, ``set``,
-    ``delete``, ``clear``) can be called directly without subscripting.
+    For tasks with exactly one concrete inlet or outlet, the accessor methods (``get``,
+    ``set``, ``delete``, ``clear``) can be called directly without subscripting.
     """
 
-    def __init__(self, inlets: list) -> None:
+    def __init__(self, inlets: list, outlets: list | None = None) -> None:
         self._by_name: dict[str, AssetStateAccessor] = {}
         self._by_uri: dict[str, AssetStateAccessor] = {}
 
@@ -769,6 +769,13 @@ class AssetStateAccessors:
                     for asset in resp.assets:
                         self._by_name[asset.name] = AssetStateAccessor(name=asset.name)
 
+        for outlet in outlets or []:
+            # AssetAlias outlets are for dynamic event emission, not state access, so skip them
+            if isinstance(outlet, (Asset, AssetNameRef)) and outlet.name not in self._by_name:
+                self._by_name[outlet.name] = AssetStateAccessor(name=outlet.name)
+            elif isinstance(outlet, AssetUriRef) and outlet.uri not in self._by_uri:
+                self._by_uri[outlet.uri] = AssetStateAccessor(uri=outlet.uri)
+
         self._total = len(self._by_name) + len(self._by_uri)
 
     def __getitem__(self, key: Asset | AssetNameRef | AssetUriRef) -> AssetStateAccessor:
@@ -778,13 +785,13 @@ class AssetStateAccessors:
             if isinstance(key, AssetUriRef):
                 return self._by_uri[key.uri]
         except KeyError:
-            raise KeyError(f"{key!r} is not in this task's inlets")
+            raise KeyError(f"{key!r} is not in this task's inlets or outlets")
         raise TypeError(f"Expected Asset, AssetNameRef, or AssetUriRef; got {type(key).__name__}")
 
     def _single_accessor(self) -> AssetStateAccessor:
         if self._total != 1:
             raise ValueError(
-                f"Task has {self._total} concrete inlets — use context['asset_state'][MY_ASSET] to specify which"
+                f"Task has {self._total} concrete inlets and outlets — use context['asset_state'][MY_ASSET] to specify which"
             )
         if self._by_name:
             return next(iter(self._by_name.values()))
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 9f2599a1963..0ff9be4e8c9 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -307,10 +307,14 @@ class RuntimeTaskInstance(TaskInstance):
                     ),
                 ),
             }
-            if any(isinstance(i, (Asset, AssetNameRef, AssetUriRef, AssetAlias)) for i in self.task.inlets):
-                self._cached_template_context["asset_state"] = AssetStateAccessors(self.task.inlets)
-                # AssetAlias inlets are resolved to their concrete assets at context build time
-                # via GetAssetsByAlias comms. If an alias maps to no active assets, it doesnt contribute to asset_state.
+            _asset_types = (Asset, AssetNameRef, AssetUriRef, AssetAlias)
+            if any(isinstance(i, _asset_types) for i in self.task.inlets + self.task.outlets):
+                self._cached_template_context["asset_state"] = AssetStateAccessors(
+                    self.task.inlets, self.task.outlets
+                )
+                # AssetAlias inlets are resolved to their concrete assets at context build time via
+                # GetAssetsByAlias comms. If an alias maps to no active assets, it doesn't contribute to
+                # asset_state. AssetAlias outlets are skipped downstream in context.py
         if TYPE_CHECKING:
             assert self._cached_template_context is not None
         if from_server:
diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py
index 8b69d69b3bd..1f4f684adf7 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_context.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_context.py
@@ -1471,7 +1471,7 @@ class TestAssetStateAccessors:
         a1 = Asset(name="asset_one", uri="s3://one")
         a2 = Asset(name="asset_two", uri="s3://two")
 
-        with pytest.raises(ValueError, match="2 concrete inlets"):
+        with pytest.raises(ValueError, match="2 concrete inlets and outlets"):
             AssetStateAccessors([a1, a2]).get("watermark")
 
     def test_alias_inlet_resolves_to_concrete_assets(self, mock_supervisor_comms):
@@ -1497,6 +1497,62 @@ class TestAssetStateAccessors:
 
         assert accessors._total == 0
 
+    def test_outlet_only_asset_is_accessible(self, mock_supervisor_comms):
+        asset = Asset(name=self.ASSET_NAME, uri=f"s3://{self.ASSET_NAME}")
+        mock_supervisor_comms.send.return_value = AssetStateResult(value="v1")
+
+        result = AssetStateAccessors([], [asset])[asset].get("watermark")
+
+        assert result == "v1"
+        mock_supervisor_comms.send.assert_called_once_with(
+            GetAssetStateByName(name=self.ASSET_NAME, key="watermark")
+        )
+
+    def test_outlet_only_name_ref_is_accessible(self, mock_supervisor_comms):
+        ref = AssetNameRef(name=self.ASSET_NAME)
+        mock_supervisor_comms.send.return_value = AssetStateResult(value="v2")
+
+        result = AssetStateAccessors([], [ref])[ref].get("watermark")
+
+        assert result == "v2"
+        mock_supervisor_comms.send.assert_called_once_with(
+            GetAssetStateByName(name=self.ASSET_NAME, key="watermark")
+        )
+
+    def test_outlet_only_uri_ref_is_accessible(self, mock_supervisor_comms):
+        ref = AssetUriRef(uri=self.ASSET_URI)
+        mock_supervisor_comms.send.return_value = AssetStateResult(value="v2")
+
+        result = AssetStateAccessors([], [ref])[ref].get("watermark")
+
+        assert result == "v2"
+        mock_supervisor_comms.send.assert_called_once_with(
+            GetAssetStateByUri(uri=self.ASSET_URI, key="watermark")
+        )
+
+    def test_outlet_only_single_shorthand_works(self, mock_supervisor_comms):
+        asset = Asset(name=self.ASSET_NAME, uri=f"s3://{self.ASSET_NAME}")
+        mock_supervisor_comms.send.return_value = AssetStateResult(value="v3")
+
+        result = AssetStateAccessors([], [asset]).get("watermark")
+
+        assert result == "v3"
+
+    def test_asset_in_both_inlets_and_outlets_not_duplicated(self, mock_supervisor_comms):
+        asset = Asset(name=self.ASSET_NAME, uri=f"s3://{self.ASSET_NAME}")
+
+        accessors = AssetStateAccessors([asset], [asset])
+
+        assert accessors._total == 1
+
+    def test_outlet_alias_is_ignored(self, mock_supervisor_comms):
+        alias = AssetAlias(name="my_alias")
+
+        accessors = AssetStateAccessors([], [alias])
+
+        assert accessors._total == 0
+        mock_supervisor_comms.send.assert_not_called()
+
 
 class InMemoryStateBackend(BaseStateBackend):
     """Simple in-memory test backend."""

Reply via email to