This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v3-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 161030e0675c1b10b531b3dbffcf390900c75c3e
Author: Kaxil Naik <[email protected]>
AuthorDate: Thu Jul 10 22:08:57 2025 +0530

    Run Task failure callbacks on DAG Processor when task is externally killed 
(#53058) (#53143)
    
    Until https://github.com/apache/airflow/issues/44354 is implemented, when 
tasks are killed externally or the supervisor process dies unexpectedly, users 
have no way of knowing this happened.
    
    This has been a blocker for Airflow 3.0 adoption for some:
    
    - https://github.com/apache/airflow/issues/44354
    - https://apache-airflow.slack.com/archives/C07813CNKA8/p1751057525231389
    
    https://github.com/apache/airflow/issues/44354 is more involved and we 
might not get to it for Airflow 3.1 -- so, until then, this is a good fix, 
similar to how we run the Dag Run callback.
    
    (cherry picked from commit a5211f2efd5ccc565cbc16baee6144dba09918bc)
---
 .../execution_api/datamodels/taskinstance.py       |   4 +-
 .../src/airflow/callbacks/callback_requests.py     |   2 +
 .../src/airflow/dag_processing/processor.py        |  72 ++++-
 .../src/airflow/jobs/scheduler_job_runner.py       |  18 +-
 .../tests/unit/callbacks/test_callback_requests.py |  29 +-
 .../tests/unit/dag_processing/test_processor.py    | 328 ++++++++++++++++++---
 airflow-core/tests/unit/jobs/test_scheduler_job.py | 141 ++++-----
 .../src/airflow/sdk/api/datamodels/_generated.py   |   2 +-
 8 files changed, 472 insertions(+), 124 deletions(-)

diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index c43c931f3e2..2d7968bbc62 100644
--- 
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++ 
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -302,7 +302,7 @@ class TIRunContext(BaseModel):
     dag_run: DagRun
     """DAG run information for the task instance."""
 
-    task_reschedule_count: Annotated[int, Field(default=0)]
+    task_reschedule_count: int = 0
     """How many times the task has been rescheduled."""
 
     max_tries: int
@@ -328,7 +328,7 @@ class TIRunContext(BaseModel):
     xcom_keys_to_clear: Annotated[list[str], Field(default_factory=list)]
     """List of Xcom keys that need to be cleared and purged on by the 
worker."""
 
-    should_retry: bool
+    should_retry: bool = False
     """If the ti encounters an error, whether it should enter retry or failed 
state."""
 
 
diff --git a/airflow-core/src/airflow/callbacks/callback_requests.py 
b/airflow-core/src/airflow/callbacks/callback_requests.py
index 8cf8c770357..3220497a209 100644
--- a/airflow-core/src/airflow/callbacks/callback_requests.py
+++ b/airflow-core/src/airflow/callbacks/callback_requests.py
@@ -61,6 +61,8 @@ class TaskCallbackRequest(BaseCallbackRequest):
     """Simplified Task Instance representation"""
     task_callback_type: TaskInstanceState | None = None
     """Whether on success, on failure, on retry"""
+    context_from_server: ti_datamodel.TIRunContext | None = None
+    """Task execution context from the Server"""
     type: Literal["TaskCallbackRequest"] = "TaskCallbackRequest"
 
     @property
diff --git a/airflow-core/src/airflow/dag_processing/processor.py 
b/airflow-core/src/airflow/dag_processing/processor.py
index a54c56ddbe2..36022c61dfd 100644
--- a/airflow-core/src/airflow/dag_processing/processor.py
+++ b/airflow-core/src/airflow/dag_processing/processor.py
@@ -16,12 +16,14 @@
 # under the License.
 from __future__ import annotations
 
+import contextlib
 import importlib
 import os
 import sys
 import traceback
+from collections.abc import Callable, Sequence
 from pathlib import Path
-from typing import TYPE_CHECKING, Annotated, BinaryIO, Callable, ClassVar, 
Literal, Union
+from typing import TYPE_CHECKING, Annotated, BinaryIO, ClassVar, Literal, Union
 
 import attrs
 from pydantic import BaseModel, Field, TypeAdapter
@@ -44,9 +46,11 @@ from airflow.sdk.execution_time.comms import (
     VariableResult,
 )
 from airflow.sdk.execution_time.supervisor import WatchedSubprocess
+from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
 from airflow.serialization.serialized_objects import LazyDeserializedDAG, 
SerializedDAG
 from airflow.stats import Stats
 from airflow.utils.file import iter_airflow_imports
+from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from structlog.typing import FilteringBoundLogger
@@ -200,10 +204,7 @@ def _execute_callbacks(
     for request in callback_requests:
         log.debug("Processing Callback Request", request=request.to_json())
         if isinstance(request, TaskCallbackRequest):
-            raise NotImplementedError(
-                "Haven't coded Task callback yet - 
https://github.com/apache/airflow/issues/44354!";
-            )
-            # _execute_task_callbacks(dagbag, request)
+            _execute_task_callbacks(dagbag, request, log)
         if isinstance(request, DagCallbackRequest):
             _execute_dag_callbacks(dagbag, request, log)
 
@@ -237,6 +238,67 @@ def _execute_dag_callbacks(dagbag: DagBag, request: 
DagCallbackRequest, log: Fil
             Stats.incr("dag.callback_exceptions", tags={"dag_id": 
request.dag_id})
 
 
+def _execute_task_callbacks(dagbag: DagBag, request: TaskCallbackRequest, log: 
FilteringBoundLogger) -> None:
+    if not request.is_failure_callback:
+        log.warning(
+            "Task callback requested but is not a failure callback",
+            dag_id=request.ti.dag_id,
+            task_id=request.ti.task_id,
+            run_id=request.ti.run_id,
+        )
+        return
+
+    dag = dagbag.dags[request.ti.dag_id]
+    task = dag.get_task(request.ti.task_id)
+
+    if request.task_callback_type is TaskInstanceState.UP_FOR_RETRY:
+        callbacks = task.on_retry_callback
+    else:
+        callbacks = task.on_failure_callback
+
+    if not callbacks:
+        log.warning(
+            "Callback requested but no callback found",
+            dag_id=request.ti.dag_id,
+            task_id=request.ti.task_id,
+            run_id=request.ti.run_id,
+            ti_id=request.ti.id,
+        )
+        return
+
+    callbacks = callbacks if isinstance(callbacks, Sequence) else [callbacks]
+    ctx_from_server = request.context_from_server
+
+    if ctx_from_server is not None:
+        runtime_ti = RuntimeTaskInstance.model_construct(
+            **request.ti.model_dump(exclude_unset=True),
+            task=task,
+            _ti_context_from_server=ctx_from_server,
+            max_tries=ctx_from_server.max_tries,
+        )
+    else:
+        runtime_ti = RuntimeTaskInstance.model_construct(
+            **request.ti.model_dump(exclude_unset=True),
+            task=task,
+        )
+    context = runtime_ti.get_template_context()
+
+    def get_callback_representation(callback):
+        with contextlib.suppress(AttributeError):
+            return callback.__name__
+        with contextlib.suppress(AttributeError):
+            return callback.__class__.__name__
+        return callback
+
+    for idx, callback in enumerate(callbacks):
+        callback_repr = get_callback_representation(callback)
+        log.info("Executing Task callback at index %d: %s", idx, callback_repr)
+        try:
+            callback(context)
+        except Exception:
+            log.exception("Error in callback at index %d: %s", idx, 
callback_repr)
+
+
 def in_process_api_server() -> InProcessExecutionAPI:
     from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI
 
diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py 
b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index 4185cbf6d44..0fa598be412 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -38,6 +38,7 @@ from sqlalchemy.orm import joinedload, lazyload, load_only, 
make_transient, sele
 from sqlalchemy.sql import expression
 
 from airflow import settings
+from airflow.api_fastapi.execution_api.datamodels.taskinstance import 
TIRunContext
 from airflow.callbacks.callback_requests import DagCallbackRequest, 
TaskCallbackRequest
 from airflow.configuration import conf
 from airflow.dag_processing.bundles.base import BundleUsageTrackingManager
@@ -945,10 +946,16 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                         bundle_version=ti.dag_version.bundle_version,
                         ti=ti,
                         msg=msg,
+                        context_from_server=TIRunContext(
+                            dag_run=ti.dag_run,
+                            max_tries=ti.max_tries,
+                            variables=[],
+                            connections=[],
+                            xcom_keys_to_clear=[],
+                        ),
                     )
                     executor.send_callback(request)
-                else:
-                    ti.handle_failure(error=msg, session=session)
+                ti.handle_failure(error=msg, session=session)
 
         return len(event_buffer)
 
@@ -2296,6 +2303,13 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                 bundle_version=ti.dag_run.bundle_version,
                 ti=ti,
                 msg=str(task_instance_heartbeat_timeout_message_details),
+                context_from_server=TIRunContext(
+                    dag_run=ti.dag_run,
+                    max_tries=ti.max_tries,
+                    variables=[],
+                    connections=[],
+                    xcom_keys_to_clear=[],
+                ),
             )
             session.add(
                 Log(
diff --git a/airflow-core/tests/unit/callbacks/test_callback_requests.py 
b/airflow-core/tests/unit/callbacks/test_callback_requests.py
index 37a7a3023d8..d27b7ee343c 100644
--- a/airflow-core/tests/unit/callbacks/test_callback_requests.py
+++ b/airflow-core/tests/unit/callbacks/test_callback_requests.py
@@ -28,7 +28,7 @@ from airflow.models.dag import DAG
 from airflow.models.taskinstance import TaskInstance
 from airflow.providers.standard.operators.bash import BashOperator
 from airflow.utils import timezone
-from airflow.utils.state import State
+from airflow.utils.state import State, TaskInstanceState
 
 pytestmark = pytest.mark.db_test
 
@@ -85,3 +85,30 @@ class TestCallbackRequest:
         json_str = input.to_json()
         result = TaskCallbackRequest.from_json(json_str)
         assert input == result
+
+    @pytest.mark.parametrize(
+        "task_callback_type,expected_is_failure",
+        [
+            (None, True),
+            (TaskInstanceState.FAILED, True),
+            (TaskInstanceState.UP_FOR_RETRY, True),
+            (TaskInstanceState.UPSTREAM_FAILED, True),
+            (TaskInstanceState.SUCCESS, False),
+            (TaskInstanceState.RUNNING, False),
+        ],
+    )
+    def test_is_failure_callback_property(
+        self, task_callback_type, expected_is_failure, create_task_instance
+    ):
+        """Test is_failure_callback property with different task callback 
types"""
+        ti = create_task_instance()
+
+        request = TaskCallbackRequest(
+            filepath="filepath",
+            ti=ti,
+            bundle_name="testing",
+            bundle_version=None,
+            task_callback_type=task_callback_type,
+        )
+
+        assert request.is_failure_callback == expected_is_failure
diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py 
b/airflow-core/tests/unit/dag_processing/test_processor.py
index e8ab15a0222..71b601dfdd8 100644
--- a/airflow-core/tests/unit/dag_processing/test_processor.py
+++ b/airflow-core/tests/unit/dag_processing/test_processor.py
@@ -21,8 +21,10 @@ import inspect
 import pathlib
 import sys
 import textwrap
+import uuid
+from collections.abc import Callable
 from socket import socketpair
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING
 from unittest.mock import MagicMock, patch
 
 import pytest
@@ -30,24 +32,27 @@ import structlog
 from pydantic import TypeAdapter
 
 from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI
+from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
+    TaskInstance as TIDataModel,
+    TIRunContext,
+)
 from airflow.callbacks.callback_requests import CallbackRequest, 
DagCallbackRequest, TaskCallbackRequest
-from airflow.configuration import conf
 from airflow.dag_processing.processor import (
     DagFileParseRequest,
     DagFileParsingResult,
     DagFileProcessorProcess,
+    _execute_task_callbacks,
     _parse_file,
     _pre_import_airflow_modules,
 )
-from airflow.models import DagBag, TaskInstance
+from airflow.models import DagBag, DagRun
 from airflow.models.baseoperator import BaseOperator
-from airflow.models.serialized_dag import SerializedDagModel
+from airflow.sdk import DAG
 from airflow.sdk.api.client import Client
 from airflow.sdk.execution_time import comms
 from airflow.utils import timezone
 from airflow.utils.session import create_session
-from airflow.utils.state import DagRunState, TaskInstanceState
-from airflow.utils.types import DagRunTriggeredByType, DagRunType
+from airflow.utils.state import TaskInstanceState
 
 from tests_common.test_utils.config import conf_vars, env_vars
 
@@ -93,42 +98,6 @@ class TestDagFileProcessor:
             log=structlog.get_logger(),
         )
 
-    @pytest.mark.xfail(reason="TODO: AIP-72")
-    @pytest.mark.parametrize(
-        ["has_serialized_dag"],
-        [pytest.param(True, id="dag_in_db"), pytest.param(False, 
id="no_dag_found")],
-    )
-    @patch.object(TaskInstance, "handle_failure")
-    def test_execute_on_failure_callbacks_without_dag(self, 
mock_ti_handle_failure, has_serialized_dag):
-        dagbag = DagBag(dag_folder="/dev/null", include_examples=True, 
read_dags_from_db=False)
-        with create_session() as session:
-            session.query(TaskInstance).delete()
-            dag = dagbag.get_dag("example_branch_operator")
-            assert dag is not None
-            dag.sync_to_db()
-            dagrun = dag.create_dagrun(
-                state=DagRunState.RUNNING,
-                logical_date=DEFAULT_DATE,
-                run_type=DagRunType.SCHEDULED,
-                data_interval=dag.infer_automated_data_interval(DEFAULT_DATE),
-                run_after=DEFAULT_DATE,
-                triggered_by=DagRunTriggeredByType.TEST,
-                session=session,
-            )
-            task = dag.get_task(task_id="run_this_first")
-            ti = TaskInstance(task, run_id=dagrun.run_id, 
state=TaskInstanceState.QUEUED)
-            session.add(ti)
-
-            if has_serialized_dag:
-                assert SerializedDagModel.write_dag(dag, 
bundle_name="testing", session=session) is True
-                session.flush()
-
-        requests = [TaskCallbackRequest(full_filepath="A", ti=ti, 
msg="Message")]
-        self._process_file(dag.fileloc, requests)
-        mock_ti_handle_failure.assert_called_once_with(
-            error="Message", test_mode=conf.getboolean("core", 
"unit_test_mode"), session=session
-        )
-
     def test_dagbag_import_errors_captured(self, spy_agency: SpyAgency):
         @spy_agency.spy_for(DagBag.collect_dags, owner=DagBag)
         def fake_collect_dags(dagbag: DagBag, *args, **kwargs):
@@ -554,10 +523,7 @@ def test_parse_file_with_dag_callbacks(spy_agency):
     assert called is True
 
 
[email protected](reason="TODO: AIP-72: Task level callbacks not yet 
supported")
 def test_parse_file_with_task_callbacks(spy_agency):
-    from airflow import DAG
-
     called = False
 
     def on_failure(context):
@@ -572,15 +538,283 @@ def test_parse_file_with_task_callbacks(spy_agency):
 
     spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, 
owner=DagBag)
 
+    # Create a minimal TaskInstance for the request
+    ti_data = TIDataModel(
+        id=uuid.uuid4(),
+        dag_id="a",
+        task_id="b",
+        run_id="test_run",
+        map_index=-1,
+        try_number=1,
+        dag_version_id=uuid.uuid4(),
+    )
+
     requests = [
         TaskCallbackRequest(
             filepath="A",
             msg="Message",
-            ti=None,
+            ti=ti_data,
             bundle_name="testing",
             bundle_version=None,
         )
     ]
-    _parse_file(DagFileParseRequest(file="A", callback_requests=requests), 
log=structlog.get_logger())
+    _parse_file(
+        DagFileParseRequest(file="A", bundle_path="test", 
callback_requests=requests),
+        log=structlog.get_logger(),
+    )
 
     assert called is True
+
+
+class TestExecuteTaskCallbacks:
+    """Test the _execute_task_callbacks function"""
+
+    def test_execute_task_callbacks_failure_callback(self, spy_agency):
+        """Test _execute_task_callbacks executes failure callbacks"""
+        called = False
+        context_received = None
+
+        def on_failure(context):
+            nonlocal called, context_received
+            called = True
+            context_received = context
+
+        with DAG(dag_id="test_dag") as dag:
+            BaseOperator(task_id="test_task", on_failure_callback=on_failure)
+
+        def fake_collect_dags(self, *args, **kwargs):
+            self.dags[dag.dag_id] = dag
+
+        spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, 
owner=DagBag)
+
+        dagbag = DagBag()
+        dagbag.collect_dags()
+
+        ti_data = TIDataModel(
+            id=uuid.uuid4(),
+            dag_id="test_dag",
+            task_id="test_task",
+            run_id="test_run",
+            try_number=1,
+            dag_version_id=uuid.uuid4(),
+        )
+
+        request = TaskCallbackRequest(
+            filepath="test.py",
+            msg="Task failed",
+            ti=ti_data,
+            bundle_name="testing",
+            bundle_version=None,
+            task_callback_type=TaskInstanceState.FAILED,
+        )
+
+        log = structlog.get_logger()
+        _execute_task_callbacks(dagbag, request, log)
+
+        assert called is True
+        assert context_received is not None
+        assert context_received["dag"] == dag
+        assert "ti" in context_received
+
+    def test_execute_task_callbacks_retry_callback(self, spy_agency):
+        """Test _execute_task_callbacks executes retry callbacks"""
+        called = False
+        context_received = None
+
+        def on_retry(context):
+            nonlocal called, context_received
+            called = True
+            context_received = context
+
+        with DAG(dag_id="test_dag") as dag:
+            BaseOperator(task_id="test_task", on_retry_callback=on_retry)
+
+        def fake_collect_dags(self, *args, **kwargs):
+            self.dags[dag.dag_id] = dag
+
+        spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, 
owner=DagBag)
+
+        dagbag = DagBag()
+        dagbag.collect_dags()
+
+        ti_data = TIDataModel(
+            id=uuid.uuid4(),
+            dag_id="test_dag",
+            task_id="test_task",
+            run_id="test_run",
+            map_index=-1,
+            try_number=1,
+            dag_version_id=uuid.uuid4(),
+            state=TaskInstanceState.UP_FOR_RETRY,
+        )
+
+        request = TaskCallbackRequest(
+            filepath="test.py",
+            msg="Task retrying",
+            ti=ti_data,
+            bundle_name="testing",
+            bundle_version=None,
+            task_callback_type=TaskInstanceState.UP_FOR_RETRY,
+        )
+
+        log = structlog.get_logger()
+        _execute_task_callbacks(dagbag, request, log)
+
+        assert called is True
+        assert context_received is not None
+        assert context_received["dag"] == dag
+        assert "ti" in context_received
+
+    def test_execute_task_callbacks_with_context_from_server(self, spy_agency):
+        """Test _execute_task_callbacks with context_from_server creates full 
context"""
+        called = False
+        context_received = None
+
+        def on_failure(context):
+            nonlocal called, context_received
+            called = True
+            context_received = context
+
+        with DAG(dag_id="test_dag") as dag:
+            BaseOperator(task_id="test_task", on_failure_callback=on_failure)
+
+        def fake_collect_dags(self, *args, **kwargs):
+            self.dags[dag.dag_id] = dag
+
+        spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, 
owner=DagBag)
+
+        dagbag = DagBag()
+        dagbag.collect_dags()
+
+        dag_run = DagRun(
+            dag_id="test_dag",
+            run_id="test_run",
+            logical_date=timezone.utcnow(),
+            start_date=timezone.utcnow(),
+            run_type="manual",
+        )
+        dag_run.run_after = timezone.utcnow()
+
+        ti_data = TIDataModel(
+            id=uuid.uuid4(),
+            dag_id="test_dag",
+            task_id="test_task",
+            run_id="test_run",
+            try_number=1,
+            dag_version_id=uuid.uuid4(),
+        )
+
+        context_from_server = TIRunContext(
+            dag_run=dag_run,
+            max_tries=3,
+        )
+
+        request = TaskCallbackRequest(
+            filepath="test.py",
+            msg="Task failed",
+            ti=ti_data,
+            bundle_name="testing",
+            bundle_version=None,
+            task_callback_type=TaskInstanceState.FAILED,
+            context_from_server=context_from_server,
+        )
+
+        log = structlog.get_logger()
+        _execute_task_callbacks(dagbag, request, log)
+
+        assert called is True
+        assert context_received is not None
+        # When context_from_server is provided, we get a full 
RuntimeTaskInstance context
+        assert "dag_run" in context_received
+        assert "logical_date" in context_received
+
+    def test_execute_task_callbacks_not_failure_callback(self, spy_agency):
+        """Test _execute_task_callbacks when request is not a failure 
callback"""
+        called = False
+
+        def on_failure(context):
+            nonlocal called
+            called = True
+
+        with DAG(dag_id="test_dag") as dag:
+            BaseOperator(task_id="test_task", on_failure_callback=on_failure)
+
+        def fake_collect_dags(self, *args, **kwargs):
+            self.dags[dag.dag_id] = dag
+
+        spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, 
owner=DagBag)
+
+        dagbag = DagBag()
+        dagbag.collect_dags()
+
+        ti_data = TIDataModel(
+            id=uuid.uuid4(),
+            dag_id="test_dag",
+            task_id="test_task",
+            run_id="test_run",
+            try_number=1,
+            dag_version_id=uuid.uuid4(),
+            state=TaskInstanceState.SUCCESS,
+        )
+
+        request = TaskCallbackRequest(
+            filepath="test.py",
+            msg="Task succeeded",
+            ti=ti_data,
+            bundle_name="testing",
+            bundle_version=None,
+            task_callback_type=TaskInstanceState.SUCCESS,
+        )
+
+        log = structlog.get_logger()
+        _execute_task_callbacks(dagbag, request, log)
+
+        # Should not call the callback since it's not a failure callback
+        assert called is False
+
+    def test_execute_task_callbacks_multiple_callbacks(self, spy_agency):
+        """Test _execute_task_callbacks with multiple callbacks"""
+        call_count = 0
+
+        def on_failure_1(context):
+            nonlocal call_count
+            call_count += 1
+
+        def on_failure_2(context):
+            nonlocal call_count
+            call_count += 1
+
+        with DAG(dag_id="test_dag") as dag:
+            BaseOperator(task_id="test_task", 
on_failure_callback=[on_failure_1, on_failure_2])
+
+        def fake_collect_dags(self, *args, **kwargs):
+            self.dags[dag.dag_id] = dag
+
+        spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, 
owner=DagBag)
+
+        dagbag = DagBag()
+        dagbag.collect_dags()
+
+        ti_data = TIDataModel(
+            id=uuid.uuid4(),
+            dag_id="test_dag",
+            task_id="test_task",
+            run_id="test_run",
+            try_number=1,
+            dag_version_id=uuid.uuid4(),
+            state=TaskInstanceState.FAILED,
+        )
+
+        request = TaskCallbackRequest(
+            filepath="test.py",
+            msg="Task failed",
+            ti=ti_data,
+            bundle_name="testing",
+            bundle_version=None,
+            task_callback_type=TaskInstanceState.FAILED,
+        )
+
+        log = structlog.get_logger()
+        _execute_task_callbacks(dagbag, request, log)
+
+        assert call_count == 2
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py 
b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index bced9e15be4..30a3ba98007 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -408,9 +408,7 @@ class TestSchedulerJob:
 
         self.job_runner._process_executor_events(executor=executor, 
session=session)
         ti1.refresh_from_db()
-        # The state will remain in queued here and
-        # will be set to failed in dag parsing process
-        assert ti1.state == State.QUEUED
+        assert ti1.state == State.FAILED
         mock_task_callback.assert_called_once_with(
             filepath=dag.relative_fileloc,
             ti=mock.ANY,
@@ -420,10 +418,11 @@ class TestSchedulerJob:
             "<TaskInstance: 
test_process_executor_events_with_callback.dummy_task test [queued]> "
             "finished with state failed, but the task instance's state 
attribute is queued. "
             "Learn more: 
https://airflow.apache.org/docs/apache-airflow/stable/troubleshooting.html#task-state-changed-externally";,
+            context_from_server=mock.ANY,
         )
         
scheduler_job.executor.callback_sink.send.assert_called_once_with(task_callback)
         scheduler_job.executor.callback_sink.reset_mock()
-        mock_stats_incr.assert_called_once_with(
+        mock_stats_incr.assert_any_call(
             "scheduler.tasks.killed_externally",
             tags={
                 "dag_id": "test_process_executor_events_with_callback",
@@ -5880,6 +5879,11 @@ class TestSchedulerJob:
         assert callback_request.ti.run_id == ti.run_id
         assert callback_request.ti.map_index == ti.map_index
 
+        # Verify context_from_server is passed
+        assert callback_request.context_from_server is not None
+        assert callback_request.context_from_server.dag_run.logical_date == 
ti.dag_run.logical_date
+        assert callback_request.context_from_server.max_tries == ti.max_tries
+
     @pytest.mark.usefixtures("testing_dag_bundle")
     def test_task_instance_heartbeat_timeout_message(self, session, 
create_dagrun):
         """
@@ -5947,68 +5951,6 @@ class TestSchedulerJob:
             "External Executor Id": "abcdefg",
         }
 
-    @pytest.mark.usefixtures("testing_dag_bundle")
-    def 
test_find_task_instances_without_heartbeats_handle_failure_callbacks_are_correctly_passed_to_dag_processor(
-        self, create_dagrun, session
-    ):
-        """
-        Check that the same set of failure callbacks for task instances 
without heartbeats are passed to the dag
-        file processors until the next task instance heartbeat timeout 
detection logic is invoked.
-        """
-        with conf_vars({("core", "load_examples"): "False"}):
-            dagbag = DagBag(
-                dag_folder=os.path.join(settings.DAGS_FOLDER, 
"test_example_bash_operator.py"),
-                read_dags_from_db=False,
-            )
-            session.query(Job).delete()
-            dag = dagbag.get_dag("test_example_bash_operator")
-            DAG.bulk_write_to_db("testing", None, [dag])
-            SerializedDagModel.write_dag(dag=dag, bundle_name="testing")
-            data_interval = 
dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE)
-            dag_run = create_dagrun(
-                dag,
-                state=DagRunState.RUNNING,
-                logical_date=DEFAULT_DATE,
-                run_type=DagRunType.SCHEDULED,
-                data_interval=data_interval,
-            )
-            task = dag.get_task(task_id="run_this_last")
-            dag_version_id = DagVersion.get_latest_version(dag.dag_id).id
-            ti = TaskInstance(task, run_id=dag_run.run_id, 
state=State.RUNNING, dag_version_id=dag_version_id)
-            ti.last_heartbeat_at = timezone.utcnow() - timedelta(minutes=6)
-            ti.start_date = timezone.utcnow() - timedelta(minutes=10)
-
-            # TODO: If there was an actual Relationship between TI and Job
-            # we wouldn't need this extra commit
-            session.add(ti)
-            session.flush()
-
-        scheduler_job = Job(executor=self.null_exec)
-        self.job_runner = SchedulerJobRunner(job=scheduler_job)
-
-        self.job_runner._find_and_purge_task_instances_without_heartbeats()
-
-        scheduler_job.executor.callback_sink.send.assert_called_once()
-
-        expected_failure_callback_requests = [
-            TaskCallbackRequest(
-                filepath=dag.relative_fileloc,
-                ti=ti,
-                
msg=str(self.job_runner._generate_task_instance_heartbeat_timeout_message_details(ti)),
-                bundle_name="testing",
-                bundle_version=dag_run.bundle_version,
-            )
-        ]
-        callback_requests = 
scheduler_job.executor.callback_sink.send.call_args.args
-        assert len(callback_requests) == 1
-        assert {
-            task_instances_without_heartbeats.ti.id
-            for task_instances_without_heartbeats in 
expected_failure_callback_requests
-        } == {result.ti.id for result in callback_requests}
-        expected_failure_callback_requests[0].ti = None
-        callback_requests[0].ti = None
-        assert expected_failure_callback_requests[0] == callback_requests[0]
-
     @mock.patch.object(settings, "USE_JOB_SCHEDULE", False)
     def run_scheduler_until_dagrun_terminal(self):
         """
@@ -6519,6 +6461,73 @@ class TestSchedulerJob:
         for i in range(100):
             assert f"it's duplicate {i}" in dag_warning.message
 
+    def test_scheduler_passes_context_from_server_on_heartbeat_timeout(self, 
dag_maker, session):
+        """Test that scheduler passes context_from_server when handling 
heartbeat timeouts."""
+        with dag_maker(dag_id="test_dag", session=session):
+            EmptyOperator(task_id="test_task")
+
+        dag_run = dag_maker.create_dagrun(run_id="test_run", 
state=DagRunState.RUNNING)
+
+        mock_executor = MagicMock()
+        scheduler_job = Job(executor=mock_executor)
+        self.job_runner = SchedulerJobRunner(scheduler_job)
+
+        # Create a task instance that appears to be running but hasn't 
heartbeat
+        ti = dag_run.get_task_instance(task_id="test_task")
+        ti.state = TaskInstanceState.RUNNING
+        ti.queued_by_job_id = scheduler_job.id
+        # Set last_heartbeat_at to a time that would trigger timeout
+        ti.last_heartbeat_at = timezone.utcnow() - timedelta(seconds=600)  # 
10 minutes ago
+        session.merge(ti)
+        session.commit()
+
+        # Run the heartbeat timeout check
+        self.job_runner._find_and_purge_task_instances_without_heartbeats()
+
+        # Verify TaskCallbackRequest was created with context_from_server
+        mock_executor.send_callback.assert_called_once()
+        callback_request = mock_executor.send_callback.call_args[0][0]
+
+        assert isinstance(callback_request, TaskCallbackRequest)
+        assert callback_request.context_from_server is not None
+        assert callback_request.context_from_server.dag_run.logical_date == 
dag_run.logical_date
+        assert callback_request.context_from_server.max_tries == ti.max_tries
+
+    def test_scheduler_passes_context_from_server_on_task_failure(self, 
dag_maker, session):
+        """Test that scheduler passes context_from_server when handling task 
failures."""
+        with dag_maker(dag_id="test_dag", session=session):
+            EmptyOperator(task_id="test_task", on_failure_callback=lambda: 
print("failure"))
+
+        dag_run = dag_maker.create_dagrun(run_id="test_run", 
state=DagRunState.RUNNING)
+
+        # Create a task instance that's running
+        ti = dag_run.get_task_instance(task_id="test_task")
+        ti.state = TaskInstanceState.RUNNING
+        session.merge(ti)
+        session.commit()
+
+        # Mock the executor to simulate a task failure
+        mock_executor = MagicMock(spec=BaseExecutor)
+        mock_executor.has_task = mock.MagicMock(return_value=False)
+        scheduler_job = Job(executor=mock_executor)
+        self.job_runner = SchedulerJobRunner(scheduler_job)
+
+        # Simulate executor reporting task as failed
+        executor_event = {ti.key: (TaskInstanceState.FAILED, None)}
+        mock_executor.get_event_buffer.return_value = executor_event
+
+        # Process the executor events
+        self.job_runner._process_executor_events(mock_executor, session)
+
+        # Verify TaskCallbackRequest was created with context_from_server
+        mock_executor.send_callback.assert_called_once()
+        callback_request = mock_executor.send_callback.call_args[0][0]
+
+        assert isinstance(callback_request, TaskCallbackRequest)
+        assert callback_request.context_from_server is not None
+        assert callback_request.context_from_server.dag_run.logical_date == 
dag_run.logical_date
+        assert callback_request.context_from_server.max_tries == ti.max_tries
+
 
 @pytest.mark.need_serialized_dag
 def test_schedule_dag_run_with_upstream_skip(dag_maker, session):
diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py 
b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
index ac1e51d5e55..1dabd8c9022 100644
--- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -518,7 +518,7 @@ class TIRunContext(BaseModel):
     next_method: Annotated[str | None, Field(title="Next Method")] = None
     next_kwargs: Annotated[dict[str, Any] | str | None, Field(title="Next 
Kwargs")] = None
     xcom_keys_to_clear: Annotated[list[str] | None, Field(title="Xcom Keys To 
Clear")] = None
-    should_retry: Annotated[bool, Field(title="Should Retry")]
+    should_retry: Annotated[bool | None, Field(title="Should Retry")] = False
 
 
 class TITerminalStatePayload(BaseModel):


Reply via email to