This is an automated email from the ASF dual-hosted git repository.

Lee-W pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 964870ee704 add error as context["exception"] in InProcessTestSupervisor (#64568)
964870ee704 is described below

commit 964870ee704dc337c1290cdf6108c19a2e6ccf08
Author: Anatoli <[email protected]>
AuthorDate: Wed May 27 04:21:28 2026 +0300

    add error as context["exception"] in InProcessTestSupervisor (#64568)
    
    Co-authored-by: Copilot Autofix powered by AI <[email protected]>
    Co-authored-by: Amogh Desai <[email protected]>
---
 .../src/airflow/sdk/execution_time/supervisor.py   |  1 +
 .../task_sdk/execution_time/test_supervisor.py     | 53 +++++++++++++++++++++-
 2 files changed, 53 insertions(+), 1 deletion(-)

diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index 45d486a688d..ac479dcebe6 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -1952,6 +1952,7 @@ class InProcessTestSupervisor(ActivitySubprocess):
                 log = structlog.get_logger(logger_name="task")
 
                 state, msg, error = run(ti, context, log)
+                context["exception"] = error
                 finalize(ti, state, context, log, error)
 
                 # In the normal subprocess model, the task runner calls this before exiting.
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 811101dbf4c..11d22865c0d 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -53,7 +53,7 @@ from task_sdk import FAKE_BUNDLE, make_client
 from uuid6 import uuid7
 
 from airflow.executors.workloads import BundleInfo
-from airflow.sdk import BaseOperator, timezone
+from airflow.sdk import DAG, BaseOperator, timezone
 from airflow.sdk.api import client as sdk_client
 from airflow.sdk.api.client import ServerResponseError
 from airflow.sdk.api.datamodels._generated import (
@@ -175,6 +175,8 @@ from tests_common.test_utils.config import conf_vars
 if TYPE_CHECKING:
     import kgb
 
+    from airflow.sdk.definitions.context import Context
+
 log = logging.getLogger(__name__)
 TI_ID = uuid7()
 
@@ -3289,6 +3291,55 @@ class TestInProcessTestSupervisor:
         assert isinstance(response, VariableResult)
         assert response.value == "value"
 
+    def test_inprocess_failure_callback_receives_exception(
+        self,
+        monkeypatch,
+        make_ti_context,
+    ):
+        """Run a failing task via InProcessTestSupervisor and ensure the
+        `on_failure_callback` receives `context['exception']`.
+        """
+        collected: list[BaseException | None] = [None]
+
+        class _Failure(Exception):
+            pass
+
+        def failure_callback(context):
+            collected[0] = context.get("exception")
+
+        class FailingOperator(BaseOperator):
+            def execute(self, context: Context):
+                raise _Failure("boom")
+
+        task = FailingOperator(task_id="failing", on_failure_callback=failure_callback)
+
+        task.dag = DAG(dag_id="test_dag")
+
+        # Create a simple TaskInstance datamodel to pass to the supervisor
+        ti = TaskInstance(
+            id=uuid7(),
+            task_id=task.task_id,
+            dag_id="test_dag",
+            run_id="r",
+            try_number=1,
+            dag_version_id=uuid7(),
+        )
+
+        # Patch the API client used by InProcessTestSupervisor to return a predictable TI context
+        fake_task_instances = mock.MagicMock(spec_set=["start", "finish"])
+        fake_task_instances.start.return_value = make_ti_context()
+        fake_client = mock.MagicMock(spec_set=["task_instances"])
+        fake_client.task_instances = fake_task_instances
+        monkeypatch.setattr(
+            InProcessTestSupervisor, "_api_client", staticmethod(lambda dag=None: fake_client)
+        )
+
+        result = InProcessTestSupervisor.start(what=ti, task=task)
+
+        # Ensure the task failed and the callback saw the exception
+        assert isinstance(result.error, _Failure)
+        assert isinstance(collected[0], _Failure)
+
 
 class TestInProcessClient:
     def test_no_retries(self):

Reply via email to