This is an automated email from the ASF dual-hosted git repository.

weilee pushed a commit to branch v3-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v3-1-test by this push:
     new ee9278ea8e7 [v3-1-test] Fix Outlet Event Extra Data is Empty in Task Instance Success Listener (#54568) (#57031)
ee9278ea8e7 is described below

commit ee9278ea8e714e0809593aa00056ec886c261b9f
Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Wed Oct 22 16:51:47 2025 +0800

    [v3-1-test] Fix Outlet Event Extra Data is Empty in Task Instance Success Listener (#54568) (#57031)
    
    Co-authored-by: Kevin Yang <[email protected]>
---
 .../check_template_context_variable_in_sync.py     |  18 ++-
 .../src/airflow/sdk/execution_time/task_runner.py  |  13 +-
 .../task_sdk/execution_time/test_task_runner.py    | 140 ++++++++++++++++++++-
 3 files changed, 161 insertions(+), 10 deletions(-)

diff --git a/scripts/ci/prek/check_template_context_variable_in_sync.py b/scripts/ci/prek/check_template_context_variable_in_sync.py
index 0b74e4beedb..1c55fbd1920 100755
--- a/scripts/ci/prek/check_template_context_variable_in_sync.py
+++ b/scripts/ci/prek/check_template_context_variable_in_sync.py
@@ -83,17 +83,25 @@ def _iter_template_context_keys_from_original_return() -> typing.Iterator[str]:
             yield key.value
 
     # Extract keys from the main `context` dictionary assignment
-    context_assignment = next(
+    context_assignment: ast.AnnAssign = next(
         stmt
         for stmt in fn_get_template_context.body
         if isinstance(stmt, ast.AnnAssign)
-        and isinstance(stmt.target, ast.Name)
-        and stmt.target.id == "context"
+        and isinstance(stmt.target, ast.Attribute)
+        and isinstance(stmt.target.value, ast.Name)
+        and stmt.target.value.id == "self"
+        and stmt.target.attr == "_context"
     )
 
-    if not isinstance(context_assignment.value, ast.Dict):
+    if not isinstance(context_assignment.value, ast.BoolOp):
+        raise TypeError("Expected a BoolOp like 'self._context or {...}'.")
+
+    context_assignment_op = context_assignment.value
+    _, context_assignment_value = context_assignment_op.values
+
+    if not isinstance(context_assignment_value, ast.Dict):
         raise ValueError("'context' is not assigned a dictionary literal")
-    yield from extract_keys_from_dict(context_assignment.value)
+    yield from extract_keys_from_dict(context_assignment_value)
 
     # Handle keys added conditionally in `if from_server`
     for stmt in fn_get_template_context.body:
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 409982d1a6b..71385980271 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -131,6 +131,9 @@ class RuntimeTaskInstance(TaskInstance):
 
     task: BaseOperator
     bundle_instance: BaseDagBundle
+    _context: Context | None = None
+    """The Task Instance context."""
+
     _ti_context_from_server: Annotated[TIRunContext | None, Field(repr=False)] 
= None
     """The Task Instance context from the API server, if any."""
 
@@ -173,7 +176,9 @@ class RuntimeTaskInstance(TaskInstance):
 
         validated_params = process_params(self.task.dag, self.task, 
dag_run_conf, suppress_exception=False)
 
-        context: Context = {
+        # Cache the context object, which ensures that all calls to 
get_template_context
+        # are operating on the same context object.
+        self._context: Context = self._context or {
             # From the Task Execution interface
             "dag": self.task.dag,
             "inlets": self.task.inlets,
@@ -213,7 +218,7 @@ class RuntimeTaskInstance(TaskInstance):
                     lambda: 
coerce_datetime(get_previous_dagrun_success(self.id).end_date)
                 ),
             }
-            context.update(context_from_server)
+            self._context.update(context_from_server)
 
             if logical_date := coerce_datetime(dag_run.logical_date):
                 if TYPE_CHECKING:
@@ -224,7 +229,7 @@ class RuntimeTaskInstance(TaskInstance):
                 ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S")
                 ts_nodash_with_tz = ts.replace("-", "").replace(":", "")
                 # logical_date and data_interval either coexist or be None 
together
-                context.update(
+                self._context.update(
                     {
                         # keys that depend on logical_date
                         "logical_date": logical_date,
@@ -251,7 +256,7 @@ class RuntimeTaskInstance(TaskInstance):
                 # existence. Should this be a private attribute on RuntimeTI 
instead perhaps?
                 setattr(self, "_upstream_map_indexes", 
from_server.upstream_map_indexes)
 
-        return context
+        return self._context
 
     def render_templates(
         self, context: Context | None = None, jinja_env: jinja2.Environment | 
None = None
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 930d80590ea..a2a19266ae5 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -65,7 +65,7 @@ from airflow.sdk.api.datamodels._generated import (
 )
 from airflow.sdk.bases.xcom import BaseXCom
 from airflow.sdk.definitions._internal.types import NOTSET, 
SET_DURING_EXECUTION, ArgNotSet
-from airflow.sdk.definitions.asset import Asset, AssetAlias, Dataset, Model
+from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey, 
Dataset, Model
 from airflow.sdk.definitions.param import DagParam
 from airflow.sdk.exceptions import ErrorType
 from airflow.sdk.execution_time.comms import (
@@ -2482,6 +2482,32 @@ class TestTaskRunnerCallsListeners:
         def before_stopping(self, component):
             self.component = component
 
+    class CustomOutletEventsListener:
+        def __init__(self):
+            self.outlet_events = []
+            self.error = None
+
+        def _add_outlet_events(self, context):
+            outlets = context["outlets"]
+            for outlet in outlets:
+                self.outlet_events.append(context["outlet_events"][outlet])
+
+        @hookimpl
+        def on_task_instance_running(self, previous_state, task_instance):
+            context = task_instance.get_template_context()
+            self._add_outlet_events(context)
+
+        @hookimpl
+        def on_task_instance_success(self, previous_state, task_instance):
+            context = task_instance.get_template_context()
+            self._add_outlet_events(context)
+
+        @hookimpl
+        def on_task_instance_failed(self, previous_state, task_instance, 
error):
+            context = task_instance.get_template_context()
+            self._add_outlet_events(context)
+            self.error = error
+
     @pytest.fixture(autouse=True)
     def clean_listener_manager(self):
         lm = get_listener_manager()
@@ -2601,6 +2627,118 @@ class TestTaskRunnerCallsListeners:
         assert listener.state == [TaskInstanceState.RUNNING, 
TaskInstanceState.FAILED]
         assert listener.error == error
 
+    def test_listener_access_outlet_event_on_running_and_success(self, 
mocked_parse, mock_supervisor_comms):
+        """Test listener can access outlet events through invoking 
get_template_context() while task running and success"""
+        listener = self.CustomOutletEventsListener()
+        get_listener_manager().add_listener(listener)
+
+        test_asset = Asset("test-asset")
+        test_key = AssetUniqueKey(name="test-asset", uri="test-asset")
+        test_extra = {"name1": "value1", "nested_obj": {"name2": "value2"}}
+
+        class Producer(BaseOperator):
+            def execute(self, context):
+                outlet_events = context["outlet_events"]
+                outlet_events[test_asset].extra = test_extra
+
+        task = Producer(
+            
task_id="test_listener_access_outlet_event_on_running_and_success", 
outlets=[test_asset]
+        )
+        dag = get_inline_dag(dag_id="test_dag", task=task)
+        ti = TaskInstance(
+            id=uuid7(),
+            task_id=task.task_id,
+            dag_id=dag.dag_id,
+            run_id="test_run",
+            try_number=1,
+            dag_version_id=uuid7(),
+        )
+
+        runtime_ti = RuntimeTaskInstance.model_construct(
+            **ti.model_dump(exclude_unset=True), task=task, 
start_date=timezone.utcnow()
+        )
+
+        log = mock.MagicMock()
+        context = runtime_ti.get_template_context()
+
+        with mock.patch(
+            
"airflow.sdk.execution_time.task_runner._validate_task_inlets_and_outlets"
+        ) as validate_mock:
+            state, _, _ = run(runtime_ti, context, log)
+
+        validate_mock.assert_called_once()
+
+        outlet_event_accessor = listener.outlet_events.pop()
+        assert outlet_event_accessor.key == test_key
+        assert outlet_event_accessor.extra == test_extra
+
+        finalize(runtime_ti, state, context, log)
+
+        outlet_event_accessor = listener.outlet_events.pop()
+        assert outlet_event_accessor.key == test_key
+        assert outlet_event_accessor.extra == test_extra
+
+    @pytest.mark.parametrize(
+        "exception",
+        [
+            ValueError("oops"),
+            SystemExit("oops"),
+            AirflowException("oops"),
+        ],
+        ids=["ValueError", "SystemExit", "AirflowException"],
+    )
+    def test_listener_access_outlet_event_on_failed(self, mocked_parse, 
mock_supervisor_comms, exception):
+        """Test listener can access outlet events through invoking 
get_template_context() while task failed"""
+        listener = self.CustomOutletEventsListener()
+        get_listener_manager().add_listener(listener)
+
+        test_asset = Asset("test-asset")
+        test_key = AssetUniqueKey(name="test-asset", uri="test-asset")
+        test_extra = {"name1": "value1", "nested_obj": {"name2": "value2"}}
+
+        class Producer(BaseOperator):
+            def execute(self, context):
+                outlet_events = context["outlet_events"]
+                outlet_events[test_asset].extra = test_extra
+                raise exception
+
+        task = Producer(task_id="test_listener_access_outlet_event_on_failed", 
outlets=[test_asset])
+        dag = get_inline_dag(dag_id="test_dag", task=task)
+        ti = TaskInstance(
+            id=uuid7(),
+            task_id=task.task_id,
+            dag_id=dag.dag_id,
+            run_id="test_run",
+            try_number=1,
+            dag_version_id=uuid7(),
+        )
+
+        runtime_ti = RuntimeTaskInstance.model_construct(
+            **ti.model_dump(exclude_unset=True), task=task, 
start_date=timezone.utcnow()
+        )
+
+        log = mock.MagicMock()
+        context = runtime_ti.get_template_context()
+
+        with mock.patch(
+            
"airflow.sdk.execution_time.task_runner._validate_task_inlets_and_outlets"
+        ) as validate_mock:
+            state, _, error = run(runtime_ti, context, log)
+
+        validate_mock.assert_called_once()
+
+        outlet_event_accessor = listener.outlet_events.pop()
+        assert outlet_event_accessor.key == test_key
+        assert outlet_event_accessor.extra == test_extra
+
+        finalize(runtime_ti, state, context, log, error)
+
+        outlet_event_accessor = listener.outlet_events.pop()
+        assert outlet_event_accessor.key == test_key
+        assert outlet_event_accessor.extra == test_extra
+
+        assert listener.error == error
+
 
 @pytest.mark.usefixtures("mock_supervisor_comms")
 class TestTaskRunnerCallsCallbacks:

Reply via email to