This is an automated email from the ASF dual-hosted git repository.
weilee pushed a commit to branch v3-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v3-1-test by this push:
new ee9278ea8e7 [v3-1-test] Fix Outlet Event Extra Data is Empty in Task Instance Success Listener (#54568) (#57031)
ee9278ea8e7 is described below
commit ee9278ea8e714e0809593aa00056ec886c261b9f
Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Wed Oct 22 16:51:47 2025 +0800
[v3-1-test] Fix Outlet Event Extra Data is Empty in Task Instance Success Listener (#54568) (#57031)
Co-authored-by: Kevin Yang <[email protected]>
---
.../check_template_context_variable_in_sync.py | 18 ++-
.../src/airflow/sdk/execution_time/task_runner.py | 13 +-
.../task_sdk/execution_time/test_task_runner.py | 140 ++++++++++++++++++++-
3 files changed, 161 insertions(+), 10 deletions(-)
diff --git a/scripts/ci/prek/check_template_context_variable_in_sync.py b/scripts/ci/prek/check_template_context_variable_in_sync.py
index 0b74e4beedb..1c55fbd1920 100755
--- a/scripts/ci/prek/check_template_context_variable_in_sync.py
+++ b/scripts/ci/prek/check_template_context_variable_in_sync.py
@@ -83,17 +83,25 @@ def _iter_template_context_keys_from_original_return() ->
typing.Iterator[str]:
yield key.value
# Extract keys from the main `context` dictionary assignment
- context_assignment = next(
+ context_assignment: ast.AnnAssign = next(
stmt
for stmt in fn_get_template_context.body
if isinstance(stmt, ast.AnnAssign)
- and isinstance(stmt.target, ast.Name)
- and stmt.target.id == "context"
+ and isinstance(stmt.target, ast.Attribute)
+ and isinstance(stmt.target.value, ast.Name)
+ and stmt.target.value.id == "self"
+ and stmt.target.attr == "_context"
)
- if not isinstance(context_assignment.value, ast.Dict):
+ if not isinstance(context_assignment.value, ast.BoolOp):
+ raise TypeError("Expected a BoolOp like 'self._context or {...}'.")
+
+ context_assignment_op = context_assignment.value
+ _, context_assignment_value = context_assignment_op.values
+
+ if not isinstance(context_assignment_value, ast.Dict):
raise ValueError("'context' is not assigned a dictionary literal")
- yield from extract_keys_from_dict(context_assignment.value)
+ yield from extract_keys_from_dict(context_assignment_value)
# Handle keys added conditionally in `if from_server`
for stmt in fn_get_template_context.body:
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 409982d1a6b..71385980271 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -131,6 +131,9 @@ class RuntimeTaskInstance(TaskInstance):
task: BaseOperator
bundle_instance: BaseDagBundle
+ _context: Context | None = None
+ """The Task Instance context."""
+
_ti_context_from_server: Annotated[TIRunContext | None, Field(repr=False)]
= None
"""The Task Instance context from the API server, if any."""
@@ -173,7 +176,9 @@ class RuntimeTaskInstance(TaskInstance):
validated_params = process_params(self.task.dag, self.task,
dag_run_conf, suppress_exception=False)
- context: Context = {
+ # Cache the context object, which ensures that all calls to
get_template_context
+ # are operating on the same context object.
+ self._context: Context = self._context or {
# From the Task Execution interface
"dag": self.task.dag,
"inlets": self.task.inlets,
@@ -213,7 +218,7 @@ class RuntimeTaskInstance(TaskInstance):
lambda:
coerce_datetime(get_previous_dagrun_success(self.id).end_date)
),
}
- context.update(context_from_server)
+ self._context.update(context_from_server)
if logical_date := coerce_datetime(dag_run.logical_date):
if TYPE_CHECKING:
@@ -224,7 +229,7 @@ class RuntimeTaskInstance(TaskInstance):
ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S")
ts_nodash_with_tz = ts.replace("-", "").replace(":", "")
# logical_date and data_interval either coexist or be None
together
- context.update(
+ self._context.update(
{
# keys that depend on logical_date
"logical_date": logical_date,
@@ -251,7 +256,7 @@ class RuntimeTaskInstance(TaskInstance):
# existence. Should this be a private attribute on RuntimeTI
instead perhaps?
setattr(self, "_upstream_map_indexes",
from_server.upstream_map_indexes)
- return context
+ return self._context
def render_templates(
self, context: Context | None = None, jinja_env: jinja2.Environment |
None = None
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 930d80590ea..a2a19266ae5 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -65,7 +65,7 @@ from airflow.sdk.api.datamodels._generated import (
)
from airflow.sdk.bases.xcom import BaseXCom
from airflow.sdk.definitions._internal.types import NOTSET,
SET_DURING_EXECUTION, ArgNotSet
-from airflow.sdk.definitions.asset import Asset, AssetAlias, Dataset, Model
+from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey,
Dataset, Model
from airflow.sdk.definitions.param import DagParam
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import (
@@ -2482,6 +2482,32 @@ class TestTaskRunnerCallsListeners:
def before_stopping(self, component):
self.component = component
+ class CustomOutletEventsListener:
+ def __init__(self):
+ self.outlet_events = []
+ self.error = None
+
+ def _add_outlet_events(self, context):
+ outlets = context["outlets"]
+ for outlet in outlets:
+ self.outlet_events.append(context["outlet_events"][outlet])
+
+ @hookimpl
+ def on_task_instance_running(self, previous_state, task_instance):
+ context = task_instance.get_template_context()
+ self._add_outlet_events(context)
+
+ @hookimpl
+ def on_task_instance_success(self, previous_state, task_instance):
+ context = task_instance.get_template_context()
+ self._add_outlet_events(context)
+
+ @hookimpl
+ def on_task_instance_failed(self, previous_state, task_instance,
error):
+ context = task_instance.get_template_context()
+ self._add_outlet_events(context)
+ self.error = error
+
@pytest.fixture(autouse=True)
def clean_listener_manager(self):
lm = get_listener_manager()
@@ -2601,6 +2627,118 @@ class TestTaskRunnerCallsListeners:
assert listener.state == [TaskInstanceState.RUNNING,
TaskInstanceState.FAILED]
assert listener.error == error
+ def test_listener_access_outlet_event_on_running_and_success(self,
mocked_parse, mock_supervisor_comms):
+ """Test listener can access outlet events through invoking
get_template_context() while task running and success"""
+ listener = self.CustomOutletEventsListener()
+ get_listener_manager().add_listener(listener)
+
+ test_asset = Asset("test-asset")
+ test_key = AssetUniqueKey(name="test-asset", uri="test-asset")
+ test_extra = {"name1": "value1", "nested_obj": {"name2": "value2"}}
+
+ class Producer(BaseOperator):
+ def execute(self, context):
+ outlet_events = context["outlet_events"]
+ outlet_events[test_asset].extra = test_extra
+
+ task = Producer(
+
task_id="test_listener_access_outlet_event_on_running_and_success",
outlets=[test_asset]
+ )
+ dag = get_inline_dag(dag_id="test_dag", task=task)
+ ti = TaskInstance(
+ id=uuid7(),
+ task_id=task.task_id,
+ dag_id=dag.dag_id,
+ run_id="test_run",
+ try_number=1,
+ dag_version_id=uuid7(),
+ )
+
+ runtime_ti = RuntimeTaskInstance.model_construct(
+ **ti.model_dump(exclude_unset=True), task=task,
start_date=timezone.utcnow()
+ )
+
+ log = mock.MagicMock()
+ context = runtime_ti.get_template_context()
+
+ with mock.patch(
+
"airflow.sdk.execution_time.task_runner._validate_task_inlets_and_outlets"
+ ) as validate_mock:
+ state, _, _ = run(runtime_ti, context, log)
+
+ validate_mock.assert_called_once()
+
+ outlet_event_accessor = listener.outlet_events.pop()
+ assert outlet_event_accessor.key == test_key
+ assert outlet_event_accessor.extra == test_extra
+
+ finalize(runtime_ti, state, context, log)
+
+ outlet_event_accessor = listener.outlet_events.pop()
+ assert outlet_event_accessor.key == test_key
+ assert outlet_event_accessor.extra == test_extra
+
+ @pytest.mark.parametrize(
+ "exception",
+ [
+ ValueError("oops"),
+ SystemExit("oops"),
+ AirflowException("oops"),
+ ],
+ ids=["ValueError", "SystemExit", "AirflowException"],
+ )
+ def test_listener_access_outlet_event_on_failed(self, mocked_parse,
mock_supervisor_comms, exception):
+ """Test listener can access outlet events through invoking
get_template_context() while task failed"""
+ listener = self.CustomOutletEventsListener()
+ get_listener_manager().add_listener(listener)
+
+ test_asset = Asset("test-asset")
+ test_key = AssetUniqueKey(name="test-asset", uri="test-asset")
+ test_extra = {"name1": "value1", "nested_obj": {"name2": "value2"}}
+
+ class Producer(BaseOperator):
+ def execute(self, context):
+ outlet_events = context["outlet_events"]
+ outlet_events[test_asset].extra = test_extra
+ raise exception
+
+ task = Producer(task_id="test_listener_access_outlet_event_on_failed",
outlets=[test_asset])
+ dag = get_inline_dag(dag_id="test_dag", task=task)
+ ti = TaskInstance(
+ id=uuid7(),
+ task_id=task.task_id,
+ dag_id=dag.dag_id,
+ run_id="test_run",
+ try_number=1,
+ dag_version_id=uuid7(),
+ )
+
+ runtime_ti = RuntimeTaskInstance.model_construct(
+ **ti.model_dump(exclude_unset=True), task=task,
start_date=timezone.utcnow()
+ )
+
+ log = mock.MagicMock()
+ context = runtime_ti.get_template_context()
+
+ with mock.patch(
+
"airflow.sdk.execution_time.task_runner._validate_task_inlets_and_outlets"
+ ) as validate_mock:
+ state, _, error = run(runtime_ti, context, log)
+
+ validate_mock.assert_called_once()
+
+ outlet_event_accessor = listener.outlet_events.pop()
+ assert outlet_event_accessor.key == test_key
+ assert outlet_event_accessor.extra == test_extra
+
+ finalize(runtime_ti, state, context, log, error)
+
+ outlet_event_accessor = listener.outlet_events.pop()
+ assert outlet_event_accessor.key == test_key
+ assert outlet_event_accessor.extra == test_extra
+
+ assert listener.error == error
+
@pytest.mark.usefixtures("mock_supervisor_comms")
class TestTaskRunnerCallsCallbacks: