jedcunningham commented on code in PR #53058:
URL: https://github.com/apache/airflow/pull/53058#discussion_r2193907922


##########
airflow-core/tests/unit/dag_processing/test_processor.py:
##########
@@ -571,15 +536,284 @@ def fake_collect_dags(self, *args, **kwargs):
 
     spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, 
owner=DagBag)
 
+    # Create a minimal TaskInstance for the request
+    ti_data = TIDataModel(
+        id=uuid.uuid4(),
+        dag_id="a",
+        task_id="b",
+        run_id="test_run",
+        map_index=-1,
+        try_number=1,
+        dag_version_id=uuid.uuid4(),
+    )
+
     requests = [
         TaskCallbackRequest(
             filepath="A",
             msg="Message",
-            ti=None,
+            ti=ti_data,
             bundle_name="testing",
             bundle_version=None,
         )
     ]
-    _parse_file(DagFileParseRequest(file="A", callback_requests=requests), 
log=structlog.get_logger())
+    _parse_file(
+        DagFileParseRequest(file="A", bundle_path="test", 
callback_requests=requests),
+        log=structlog.get_logger(),
+    )
 
     assert called is True
+
+
+class TestExecuteTaskCallbacks:
+    """Test the _execute_task_callbacks function"""
+
+    def test_execute_task_callbacks_failure_callback(self, spy_agency):
+        """Test _execute_task_callbacks executes failure callbacks"""
+        called = False
+        context_received = None
+
+        def on_failure(context):
+            nonlocal called, context_received
+            called = True
+            context_received = context
+
+        with DAG(dag_id="test_dag") as dag:
+            _ = BaseOperator(task_id="test_task", 
on_failure_callback=on_failure)
+
+        def fake_collect_dags(self, *args, **kwargs):
+            self.dags[dag.dag_id] = dag
+
+        spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, 
owner=DagBag)
+
+        dagbag = DagBag()
+        dagbag.collect_dags()
+
+        ti_data = TIDataModel(
+            id=uuid.uuid4(),
+            dag_id="test_dag",
+            task_id="test_task",
+            run_id="test_run",
+            try_number=1,
+            dag_version_id=uuid.uuid4(),
+        )
+
+        request = TaskCallbackRequest(
+            filepath="test.py",
+            msg="Task failed",
+            ti=ti_data,
+            bundle_name="testing",
+            bundle_version=None,
+            task_callback_type=TaskInstanceState.FAILED,
+        )
+
+        log = structlog.get_logger()
+        _execute_task_callbacks(dagbag, request, log)
+
+        assert called is True
+        assert context_received is not None
+        assert context_received["dag"] == dag
+        assert "ti" in context_received
+
+    def test_execute_task_callbacks_retry_callback(self, spy_agency):
+        """Test _execute_task_callbacks executes retry callbacks"""
+        called = False
+        context_received = None
+
+        def on_retry(context):
+            nonlocal called, context_received
+            called = True
+            context_received = context
+
+        with DAG(dag_id="test_dag") as dag:
+            _ = BaseOperator(task_id="test_task", on_retry_callback=on_retry)
+
+        def fake_collect_dags(self, *args, **kwargs):
+            self.dags[dag.dag_id] = dag
+
+        spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, 
owner=DagBag)
+
+        dagbag = DagBag()
+        dagbag.collect_dags()
+
+        ti_data = TIDataModel(
+            id=uuid.uuid4(),
+            dag_id="test_dag",
+            task_id="test_task",
+            run_id="test_run",
+            map_index=-1,
+            try_number=1,
+            dag_version_id=uuid.uuid4(),
+            state=TaskInstanceState.UP_FOR_RETRY,
+        )
+
+        request = TaskCallbackRequest(
+            filepath="test.py",
+            msg="Task retrying",
+            ti=ti_data,
+            bundle_name="testing",
+            bundle_version=None,
+            task_callback_type=TaskInstanceState.UP_FOR_RETRY,
+        )
+
+        log = structlog.get_logger()
+        _execute_task_callbacks(dagbag, request, log)
+
+        assert called is True
+        assert context_received is not None
+        assert context_received["dag"] == dag
+        assert "ti" in context_received
+
+    def test_execute_task_callbacks_with_context_from_server(self, spy_agency):
+        """Test _execute_task_callbacks with context_from_server creates full 
context"""
+        called = False
+        context_received = None
+
+        def on_failure(context):
+            nonlocal called, context_received
+            called = True
+            context_received = context
+
+        with DAG(dag_id="test_dag") as dag:
+            _ = BaseOperator(task_id="test_task", 
on_failure_callback=on_failure)
+
+        def fake_collect_dags(self, *args, **kwargs):
+            self.dags[dag.dag_id] = dag
+
+        spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, 
owner=DagBag)
+
+        dagbag = DagBag()
+        dagbag.collect_dags()
+
+        # Create a mock DagRun

Review Comment:
   ```suggestion
   ```



##########
airflow-core/tests/unit/dag_processing/test_processor.py:
##########
@@ -571,15 +536,284 @@ def fake_collect_dags(self, *args, **kwargs):
 
     spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, 
owner=DagBag)
 
+    # Create a minimal TaskInstance for the request
+    ti_data = TIDataModel(
+        id=uuid.uuid4(),
+        dag_id="a",
+        task_id="b",
+        run_id="test_run",
+        map_index=-1,
+        try_number=1,
+        dag_version_id=uuid.uuid4(),
+    )
+
     requests = [
         TaskCallbackRequest(
             filepath="A",
             msg="Message",
-            ti=None,
+            ti=ti_data,
             bundle_name="testing",
             bundle_version=None,
         )
     ]
-    _parse_file(DagFileParseRequest(file="A", callback_requests=requests), 
log=structlog.get_logger())
+    _parse_file(
+        DagFileParseRequest(file="A", bundle_path="test", 
callback_requests=requests),
+        log=structlog.get_logger(),
+    )
 
     assert called is True
+
+
+class TestExecuteTaskCallbacks:
+    """Test the _execute_task_callbacks function"""
+
+    def test_execute_task_callbacks_failure_callback(self, spy_agency):
+        """Test _execute_task_callbacks executes failure callbacks"""
+        called = False
+        context_received = None
+
+        def on_failure(context):
+            nonlocal called, context_received
+            called = True
+            context_received = context
+
+        with DAG(dag_id="test_dag") as dag:
+            _ = BaseOperator(task_id="test_task", 
on_failure_callback=on_failure)
+
+        def fake_collect_dags(self, *args, **kwargs):
+            self.dags[dag.dag_id] = dag
+
+        spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, 
owner=DagBag)
+
+        dagbag = DagBag()
+        dagbag.collect_dags()
+
+        ti_data = TIDataModel(
+            id=uuid.uuid4(),
+            dag_id="test_dag",
+            task_id="test_task",
+            run_id="test_run",
+            try_number=1,
+            dag_version_id=uuid.uuid4(),
+        )
+
+        request = TaskCallbackRequest(
+            filepath="test.py",
+            msg="Task failed",
+            ti=ti_data,
+            bundle_name="testing",
+            bundle_version=None,
+            task_callback_type=TaskInstanceState.FAILED,
+        )
+
+        log = structlog.get_logger()
+        _execute_task_callbacks(dagbag, request, log)
+
+        assert called is True
+        assert context_received is not None
+        assert context_received["dag"] == dag
+        assert "ti" in context_received
+
+    def test_execute_task_callbacks_retry_callback(self, spy_agency):
+        """Test _execute_task_callbacks executes retry callbacks"""
+        called = False
+        context_received = None
+
+        def on_retry(context):
+            nonlocal called, context_received
+            called = True
+            context_received = context
+
+        with DAG(dag_id="test_dag") as dag:
+            _ = BaseOperator(task_id="test_task", on_retry_callback=on_retry)

Review Comment:
   ```suggestion
               BaseOperator(task_id="test_task", on_retry_callback=on_retry)
   ```



##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -927,10 +928,13 @@ def process_executor_events(
                         bundle_version=ti.dag_version.bundle_version,
                         ti=ti,
                         msg=msg,
+                        context_from_server=TIRunContext(
+                            dag_run=ti.dag_run,
+                            max_tries=ti.max_tries,
+                        ),
                     )
                     executor.send_callback(request)
-                else:
-                    ti.handle_failure(error=msg, session=session)
+                ti.handle_failure(error=msg, session=session)

Review Comment:
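   Was this intentional? I don't think so: removing the `else:` means `ti.handle_failure(...)` now runs unconditionally, even when a callback request has already been sent to the executor.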



##########
airflow-core/tests/unit/dag_processing/test_processor.py:
##########
@@ -571,15 +536,284 @@ def fake_collect_dags(self, *args, **kwargs):
 
     spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, 
owner=DagBag)
 
+    # Create a minimal TaskInstance for the request
+    ti_data = TIDataModel(
+        id=uuid.uuid4(),
+        dag_id="a",
+        task_id="b",
+        run_id="test_run",
+        map_index=-1,
+        try_number=1,
+        dag_version_id=uuid.uuid4(),
+    )
+
     requests = [
         TaskCallbackRequest(
             filepath="A",
             msg="Message",
-            ti=None,
+            ti=ti_data,
             bundle_name="testing",
             bundle_version=None,
         )
     ]
-    _parse_file(DagFileParseRequest(file="A", callback_requests=requests), 
log=structlog.get_logger())
+    _parse_file(
+        DagFileParseRequest(file="A", bundle_path="test", 
callback_requests=requests),
+        log=structlog.get_logger(),
+    )
 
     assert called is True
+
+
+class TestExecuteTaskCallbacks:
+    """Test the _execute_task_callbacks function"""
+
+    def test_execute_task_callbacks_failure_callback(self, spy_agency):
+        """Test _execute_task_callbacks executes failure callbacks"""
+        called = False
+        context_received = None
+
+        def on_failure(context):
+            nonlocal called, context_received
+            called = True
+            context_received = context
+
+        with DAG(dag_id="test_dag") as dag:
+            _ = BaseOperator(task_id="test_task", 
on_failure_callback=on_failure)
+
+        def fake_collect_dags(self, *args, **kwargs):
+            self.dags[dag.dag_id] = dag
+
+        spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, 
owner=DagBag)
+
+        dagbag = DagBag()
+        dagbag.collect_dags()
+
+        ti_data = TIDataModel(
+            id=uuid.uuid4(),
+            dag_id="test_dag",
+            task_id="test_task",
+            run_id="test_run",
+            try_number=1,
+            dag_version_id=uuid.uuid4(),
+        )
+
+        request = TaskCallbackRequest(
+            filepath="test.py",
+            msg="Task failed",
+            ti=ti_data,
+            bundle_name="testing",
+            bundle_version=None,
+            task_callback_type=TaskInstanceState.FAILED,
+        )
+
+        log = structlog.get_logger()
+        _execute_task_callbacks(dagbag, request, log)
+
+        assert called is True
+        assert context_received is not None
+        assert context_received["dag"] == dag
+        assert "ti" in context_received
+
+    def test_execute_task_callbacks_retry_callback(self, spy_agency):
+        """Test _execute_task_callbacks executes retry callbacks"""
+        called = False
+        context_received = None
+
+        def on_retry(context):
+            nonlocal called, context_received
+            called = True
+            context_received = context
+
+        with DAG(dag_id="test_dag") as dag:
+            _ = BaseOperator(task_id="test_task", on_retry_callback=on_retry)
+
+        def fake_collect_dags(self, *args, **kwargs):
+            self.dags[dag.dag_id] = dag
+
+        spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, 
owner=DagBag)
+
+        dagbag = DagBag()
+        dagbag.collect_dags()
+
+        ti_data = TIDataModel(
+            id=uuid.uuid4(),
+            dag_id="test_dag",
+            task_id="test_task",
+            run_id="test_run",
+            map_index=-1,
+            try_number=1,
+            dag_version_id=uuid.uuid4(),
+            state=TaskInstanceState.UP_FOR_RETRY,
+        )
+
+        request = TaskCallbackRequest(
+            filepath="test.py",
+            msg="Task retrying",
+            ti=ti_data,
+            bundle_name="testing",
+            bundle_version=None,
+            task_callback_type=TaskInstanceState.UP_FOR_RETRY,
+        )
+
+        log = structlog.get_logger()
+        _execute_task_callbacks(dagbag, request, log)
+
+        assert called is True
+        assert context_received is not None
+        assert context_received["dag"] == dag
+        assert "ti" in context_received
+
+    def test_execute_task_callbacks_with_context_from_server(self, spy_agency):
+        """Test _execute_task_callbacks with context_from_server creates full 
context"""
+        called = False
+        context_received = None
+
+        def on_failure(context):
+            nonlocal called, context_received
+            called = True
+            context_received = context
+
+        with DAG(dag_id="test_dag") as dag:
+            _ = BaseOperator(task_id="test_task", 
on_failure_callback=on_failure)
+
+        def fake_collect_dags(self, *args, **kwargs):
+            self.dags[dag.dag_id] = dag
+
+        spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, 
owner=DagBag)
+
+        dagbag = DagBag()
+        dagbag.collect_dags()
+
+        # Create a mock DagRun
+        dag_run = DagRun(
+            dag_id="test_dag",
+            run_id="test_run",
+            logical_date=timezone.utcnow(),
+            start_date=timezone.utcnow(),
+            run_type="manual",
+        )
+        dag_run.run_after = timezone.utcnow()
+
+        ti_data = TIDataModel(
+            id=uuid.uuid4(),
+            dag_id="test_dag",
+            task_id="test_task",
+            run_id="test_run",
+            try_number=1,
+            dag_version_id=uuid.uuid4(),
+        )
+
+        context_from_server = TIRunContext(
+            dag_run=dag_run,
+            max_tries=3,
+        )
+
+        request = TaskCallbackRequest(
+            filepath="test.py",
+            msg="Task failed",
+            ti=ti_data,
+            bundle_name="testing",
+            bundle_version=None,
+            task_callback_type=TaskInstanceState.FAILED,
+            context_from_server=context_from_server,
+        )
+
+        log = structlog.get_logger()
+        _execute_task_callbacks(dagbag, request, log)
+
+        assert called is True
+        assert context_received is not None
+        # When context_from_server is provided, we get a full 
RuntimeTaskInstance context
+        assert "dag_run" in context_received
+        assert "logical_date" in context_received
+
+    def test_execute_task_callbacks_not_failure_callback(self, spy_agency):
+        """Test _execute_task_callbacks when request is not a failure 
callback"""
+        called = False
+
+        def on_failure(context):
+            nonlocal called
+            called = True
+
+        with DAG(dag_id="test_dag") as dag:
+            _ = BaseOperator(task_id="test_task", 
on_failure_callback=on_failure)
+
+        def fake_collect_dags(self, *args, **kwargs):
+            self.dags[dag.dag_id] = dag
+
+        spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, 
owner=DagBag)
+
+        dagbag = DagBag()
+        dagbag.collect_dags()
+
+        ti_data = TIDataModel(
+            id=uuid.uuid4(),
+            dag_id="test_dag",
+            task_id="test_task",
+            run_id="test_run",
+            try_number=1,
+            dag_version_id=uuid.uuid4(),
+            state=TaskInstanceState.SUCCESS,
+        )
+
+        request = TaskCallbackRequest(
+            filepath="test.py",
+            msg="Task succeeded",
+            ti=ti_data,
+            bundle_name="testing",
+            bundle_version=None,
+            task_callback_type=TaskInstanceState.SUCCESS,
+        )
+
+        log = structlog.get_logger()
+        _execute_task_callbacks(dagbag, request, log)
+
+        # Should not call the callback since it's not a failure callback
+        assert called is False
+
+    def test_execute_task_callbacks_multiple_callbacks(self, spy_agency):
+        """Test _execute_task_callbacks with multiple callbacks"""
+        call_count = 0
+
+        def on_failure_1(context):
+            nonlocal call_count
+            call_count += 1
+
+        def on_failure_2(context):
+            nonlocal call_count
+            call_count += 1
+
+        with DAG(dag_id="test_dag") as dag:
+            _ = BaseOperator(task_id="test_task", 
on_failure_callback=[on_failure_1, on_failure_2])

Review Comment:
   ```suggestion
               BaseOperator(task_id="test_task", on_failure_callback=[on_failure_1, on_failure_2])
   ```



##########
airflow-core/tests/unit/jobs/test_scheduler_job.py:
##########
@@ -6621,6 +6625,75 @@ def 
test_activate_referenced_assets_with_multiple_conflict_asset_in_one_dag(
         for i in range(100):
             assert f"it's duplicate {i}" in dag_warning.message
 
+    def test_scheduler_passes_context_from_server_on_heartbeat_timeout(self, 
dag_maker, session):
+        """Test that scheduler passes context_from_server when handling 
heartbeat timeouts."""
+        with dag_maker(dag_id="test_dag", session=session):
+            _ = EmptyOperator(task_id="test_task")
+
+        dag_run = dag_maker.create_dagrun(run_id="test_run", 
state=DagRunState.RUNNING)
+
+        mock_executor = MagicMock()
+        scheduler_job = Job(executor=mock_executor)
+        self.job_runner = SchedulerJobRunner(scheduler_job)
+
+        # Create a task instance that appears to be running but hasn't 
heartbeat
+        ti = dag_run.get_task_instance(task_id="test_task")
+        ti.state = TaskInstanceState.RUNNING
+        ti.queued_by_job_id = scheduler_job.id
+        # Set last_heartbeat_at to a time that would trigger timeout
+        ti.last_heartbeat_at = timezone.utcnow() - timedelta(seconds=600)  # 
10 minutes ago
+        session.merge(ti)
+        session.commit()
+
+        # Run the heartbeat timeout check
+        self.job_runner._find_and_purge_task_instances_without_heartbeats()
+
+        # Verify TaskCallbackRequest was created with context_from_server
+        mock_executor.send_callback.assert_called_once()
+        callback_request = mock_executor.send_callback.call_args[0][0]
+
+        assert isinstance(callback_request, TaskCallbackRequest)
+        assert callback_request.context_from_server is not None
+        assert callback_request.context_from_server.dag_run.logical_date == 
dag_run.logical_date
+        assert callback_request.context_from_server.max_tries == ti.max_tries
+
+    def test_scheduler_passes_context_from_server_on_task_failure(self, 
dag_maker, session):
+        """Test that scheduler passes context_from_server when handling task 
failures."""
+        with dag_maker(dag_id="test_dag", session=session):
+            _ = EmptyOperator(task_id="test_task", on_failure_callback=lambda: 
print("failure"))

Review Comment:
   ```suggestion
               EmptyOperator(task_id="test_task", on_failure_callback=lambda: print("failure"))
   ```



##########
airflow-core/tests/unit/jobs/test_scheduler_job.py:
##########
@@ -6621,6 +6625,75 @@ def 
test_activate_referenced_assets_with_multiple_conflict_asset_in_one_dag(
         for i in range(100):
             assert f"it's duplicate {i}" in dag_warning.message
 
+    def test_scheduler_passes_context_from_server_on_heartbeat_timeout(self, 
dag_maker, session):
+        """Test that scheduler passes context_from_server when handling 
heartbeat timeouts."""
+        with dag_maker(dag_id="test_dag", session=session):
+            _ = EmptyOperator(task_id="test_task")
+
+        dag_run = dag_maker.create_dagrun(run_id="test_run", 
state=DagRunState.RUNNING)
+
+        mock_executor = MagicMock()
+        scheduler_job = Job(executor=mock_executor)
+        self.job_runner = SchedulerJobRunner(scheduler_job)
+
+        # Create a task instance that appears to be running but hasn't 
heartbeat
+        ti = dag_run.get_task_instance(task_id="test_task")
+        ti.state = TaskInstanceState.RUNNING
+        ti.queued_by_job_id = scheduler_job.id
+        # Set last_heartbeat_at to a time that would trigger timeout
+        ti.last_heartbeat_at = timezone.utcnow() - timedelta(seconds=600)  # 
10 minutes ago
+        session.merge(ti)
+        session.commit()
+
+        # Run the heartbeat timeout check
+        self.job_runner._find_and_purge_task_instances_without_heartbeats()
+
+        # Verify TaskCallbackRequest was created with context_from_server
+        mock_executor.send_callback.assert_called_once()
+        callback_request = mock_executor.send_callback.call_args[0][0]
+
+        assert isinstance(callback_request, TaskCallbackRequest)
+        assert callback_request.context_from_server is not None
+        assert callback_request.context_from_server.dag_run.logical_date == 
dag_run.logical_date
+        assert callback_request.context_from_server.max_tries == ti.max_tries
+
+    def test_scheduler_passes_context_from_server_on_task_failure(self, 
dag_maker, session):
+        """Test that scheduler passes context_from_server when handling task 
failures."""
+        with dag_maker(dag_id="test_dag", session=session):
+            _ = EmptyOperator(task_id="test_task", on_failure_callback=lambda: 
print("failure"))
+
+        dag_run = dag_maker.create_dagrun(run_id="test_run", 
state=DagRunState.RUNNING)
+
+        # Create a task instance that's running
+        ti = dag_run.get_task_instance(task_id="test_task")
+        ti.state = TaskInstanceState.RUNNING
+        # ti.queued_by_job_id = 90000

Review Comment:
   ```suggestion
   ```



##########
airflow-core/tests/unit/dag_processing/test_processor.py:
##########
@@ -571,15 +536,284 @@ def fake_collect_dags(self, *args, **kwargs):
 
     spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, 
owner=DagBag)
 
+    # Create a minimal TaskInstance for the request
+    ti_data = TIDataModel(
+        id=uuid.uuid4(),
+        dag_id="a",
+        task_id="b",
+        run_id="test_run",
+        map_index=-1,
+        try_number=1,
+        dag_version_id=uuid.uuid4(),
+    )
+
     requests = [
         TaskCallbackRequest(
             filepath="A",
             msg="Message",
-            ti=None,
+            ti=ti_data,
             bundle_name="testing",
             bundle_version=None,
         )
     ]
-    _parse_file(DagFileParseRequest(file="A", callback_requests=requests), 
log=structlog.get_logger())
+    _parse_file(
+        DagFileParseRequest(file="A", bundle_path="test", 
callback_requests=requests),
+        log=structlog.get_logger(),
+    )
 
     assert called is True
+
+
+class TestExecuteTaskCallbacks:
+    """Test the _execute_task_callbacks function"""
+
+    def test_execute_task_callbacks_failure_callback(self, spy_agency):
+        """Test _execute_task_callbacks executes failure callbacks"""
+        called = False
+        context_received = None
+
+        def on_failure(context):
+            nonlocal called, context_received
+            called = True
+            context_received = context
+
+        with DAG(dag_id="test_dag") as dag:
+            _ = BaseOperator(task_id="test_task", 
on_failure_callback=on_failure)
+
+        def fake_collect_dags(self, *args, **kwargs):
+            self.dags[dag.dag_id] = dag
+
+        spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, 
owner=DagBag)
+
+        dagbag = DagBag()
+        dagbag.collect_dags()
+
+        ti_data = TIDataModel(
+            id=uuid.uuid4(),
+            dag_id="test_dag",
+            task_id="test_task",
+            run_id="test_run",
+            try_number=1,
+            dag_version_id=uuid.uuid4(),
+        )
+
+        request = TaskCallbackRequest(
+            filepath="test.py",
+            msg="Task failed",
+            ti=ti_data,
+            bundle_name="testing",
+            bundle_version=None,
+            task_callback_type=TaskInstanceState.FAILED,
+        )
+
+        log = structlog.get_logger()
+        _execute_task_callbacks(dagbag, request, log)
+
+        assert called is True
+        assert context_received is not None
+        assert context_received["dag"] == dag
+        assert "ti" in context_received
+
+    def test_execute_task_callbacks_retry_callback(self, spy_agency):
+        """Test _execute_task_callbacks executes retry callbacks"""
+        called = False
+        context_received = None
+
+        def on_retry(context):
+            nonlocal called, context_received
+            called = True
+            context_received = context
+
+        with DAG(dag_id="test_dag") as dag:
+            _ = BaseOperator(task_id="test_task", on_retry_callback=on_retry)
+
+        def fake_collect_dags(self, *args, **kwargs):
+            self.dags[dag.dag_id] = dag
+
+        spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, 
owner=DagBag)
+
+        dagbag = DagBag()
+        dagbag.collect_dags()
+
+        ti_data = TIDataModel(
+            id=uuid.uuid4(),
+            dag_id="test_dag",
+            task_id="test_task",
+            run_id="test_run",
+            map_index=-1,
+            try_number=1,
+            dag_version_id=uuid.uuid4(),
+            state=TaskInstanceState.UP_FOR_RETRY,
+        )
+
+        request = TaskCallbackRequest(
+            filepath="test.py",
+            msg="Task retrying",
+            ti=ti_data,
+            bundle_name="testing",
+            bundle_version=None,
+            task_callback_type=TaskInstanceState.UP_FOR_RETRY,
+        )
+
+        log = structlog.get_logger()
+        _execute_task_callbacks(dagbag, request, log)
+
+        assert called is True
+        assert context_received is not None
+        assert context_received["dag"] == dag
+        assert "ti" in context_received
+
+    def test_execute_task_callbacks_with_context_from_server(self, spy_agency):
+        """Test _execute_task_callbacks with context_from_server creates full 
context"""
+        called = False
+        context_received = None
+
+        def on_failure(context):
+            nonlocal called, context_received
+            called = True
+            context_received = context
+
+        with DAG(dag_id="test_dag") as dag:
+            _ = BaseOperator(task_id="test_task", 
on_failure_callback=on_failure)

Review Comment:
   ```suggestion
               BaseOperator(task_id="test_task", on_failure_callback=on_failure)
   ```



##########
airflow-core/tests/unit/dag_processing/test_processor.py:
##########
@@ -571,15 +536,284 @@ def fake_collect_dags(self, *args, **kwargs):
 
     spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, 
owner=DagBag)
 
+    # Create a minimal TaskInstance for the request
+    ti_data = TIDataModel(
+        id=uuid.uuid4(),
+        dag_id="a",
+        task_id="b",
+        run_id="test_run",
+        map_index=-1,
+        try_number=1,
+        dag_version_id=uuid.uuid4(),
+    )
+
     requests = [
         TaskCallbackRequest(
             filepath="A",
             msg="Message",
-            ti=None,
+            ti=ti_data,
             bundle_name="testing",
             bundle_version=None,
         )
     ]
-    _parse_file(DagFileParseRequest(file="A", callback_requests=requests), 
log=structlog.get_logger())
+    _parse_file(
+        DagFileParseRequest(file="A", bundle_path="test", 
callback_requests=requests),
+        log=structlog.get_logger(),
+    )
 
     assert called is True
+
+
+class TestExecuteTaskCallbacks:
+    """Test the _execute_task_callbacks function"""
+
+    def test_execute_task_callbacks_failure_callback(self, spy_agency):
+        """Test _execute_task_callbacks executes failure callbacks"""
+        called = False
+        context_received = None
+
+        def on_failure(context):
+            nonlocal called, context_received
+            called = True
+            context_received = context
+
+        with DAG(dag_id="test_dag") as dag:
+            _ = BaseOperator(task_id="test_task", 
on_failure_callback=on_failure)

Review Comment:
   ```suggestion
               BaseOperator(task_id="test_task", on_failure_callback=on_failure)
   ```
   
   The `_ =` assignment shouldn't be needed here.



##########
airflow-core/tests/unit/jobs/test_scheduler_job.py:
##########
@@ -6621,6 +6625,75 @@ def 
test_activate_referenced_assets_with_multiple_conflict_asset_in_one_dag(
         for i in range(100):
             assert f"it's duplicate {i}" in dag_warning.message
 
+    def test_scheduler_passes_context_from_server_on_heartbeat_timeout(self, 
dag_maker, session):
+        """Test that scheduler passes context_from_server when handling 
heartbeat timeouts."""
+        with dag_maker(dag_id="test_dag", session=session):
+            _ = EmptyOperator(task_id="test_task")

Review Comment:
   ```suggestion
               EmptyOperator(task_id="test_task")
   ```



##########
airflow-core/tests/unit/jobs/test_scheduler_job.py:
##########
@@ -6621,6 +6625,75 @@ def 
test_activate_referenced_assets_with_multiple_conflict_asset_in_one_dag(
         for i in range(100):
             assert f"it's duplicate {i}" in dag_warning.message
 
+    def test_scheduler_passes_context_from_server_on_heartbeat_timeout(self, 
dag_maker, session):
+        """Test that scheduler passes context_from_server when handling 
heartbeat timeouts."""
+        with dag_maker(dag_id="test_dag", session=session):
+            _ = EmptyOperator(task_id="test_task")
+
+        dag_run = dag_maker.create_dagrun(run_id="test_run", 
state=DagRunState.RUNNING)
+
+        mock_executor = MagicMock()
+        scheduler_job = Job(executor=mock_executor)
+        self.job_runner = SchedulerJobRunner(scheduler_job)
+
+        # Create a task instance that appears to be running but hasn't 
heartbeat
+        ti = dag_run.get_task_instance(task_id="test_task")
+        ti.state = TaskInstanceState.RUNNING
+        ti.queued_by_job_id = scheduler_job.id
+        # Set last_heartbeat_at to a time that would trigger timeout
+        ti.last_heartbeat_at = timezone.utcnow() - timedelta(seconds=600)  # 
10 minutes ago
+        session.merge(ti)
+        session.commit()
+
+        # Run the heartbeat timeout check
+        self.job_runner._find_and_purge_task_instances_without_heartbeats()
+
+        # Verify TaskCallbackRequest was created with context_from_server
+        mock_executor.send_callback.assert_called_once()
+        callback_request = mock_executor.send_callback.call_args[0][0]
+
+        assert isinstance(callback_request, TaskCallbackRequest)
+        assert callback_request.context_from_server is not None
+        assert callback_request.context_from_server.dag_run.logical_date == 
dag_run.logical_date
+        assert callback_request.context_from_server.max_tries == ti.max_tries
+
+    def test_scheduler_passes_context_from_server_on_task_failure(self, 
dag_maker, session):
+        """Test that scheduler passes context_from_server when handling task 
failures."""
+        with dag_maker(dag_id="test_dag", session=session):
+            _ = EmptyOperator(task_id="test_task", on_failure_callback=lambda: 
print("failure"))
+
+        dag_run = dag_maker.create_dagrun(run_id="test_run", 
state=DagRunState.RUNNING)
+
+        # Create a task instance that's running
+        ti = dag_run.get_task_instance(task_id="test_task")
+        ti.state = TaskInstanceState.RUNNING
+        # ti.queued_by_job_id = 90000
+        session.merge(ti)
+        session.commit()
+
+        # Mock the executor to simulate a task failure
+        mock_executor = MagicMock(spec=BaseExecutor)
+        mock_executor.has_task = mock.MagicMock(return_value=False)
+        scheduler_job = Job(executor=mock_executor)
+        self.job_runner = SchedulerJobRunner(scheduler_job)
+
+        # Simulate executor reporting task as failed
+        # from airflow.executors.base_executor import TaskInstanceStateType

Review Comment:
   ```suggestion
   ```



##########
airflow-core/tests/unit/dag_processing/test_processor.py:
##########
@@ -571,15 +536,284 @@ def fake_collect_dags(self, *args, **kwargs):
 
     spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag)
 
+    # Create a minimal TaskInstance for the request
+    ti_data = TIDataModel(
+        id=uuid.uuid4(),
+        dag_id="a",
+        task_id="b",
+        run_id="test_run",
+        map_index=-1,
+        try_number=1,
+        dag_version_id=uuid.uuid4(),
+    )
+
     requests = [
         TaskCallbackRequest(
             filepath="A",
             msg="Message",
-            ti=None,
+            ti=ti_data,
             bundle_name="testing",
             bundle_version=None,
         )
     ]
-    _parse_file(DagFileParseRequest(file="A", callback_requests=requests), log=structlog.get_logger())
+    _parse_file(
+        DagFileParseRequest(file="A", bundle_path="test", callback_requests=requests),
+        log=structlog.get_logger(),
+    )
 
     assert called is True
+
+
+class TestExecuteTaskCallbacks:
+    """Test the _execute_task_callbacks function"""
+
+    def test_execute_task_callbacks_failure_callback(self, spy_agency):
+        """Test _execute_task_callbacks executes failure callbacks"""
+        called = False
+        context_received = None
+
+        def on_failure(context):
+            nonlocal called, context_received
+            called = True
+            context_received = context
+
+        with DAG(dag_id="test_dag") as dag:
+            _ = BaseOperator(task_id="test_task", on_failure_callback=on_failure)
+
+        def fake_collect_dags(self, *args, **kwargs):
+            self.dags[dag.dag_id] = dag
+
+        spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag)
+
+        dagbag = DagBag()
+        dagbag.collect_dags()
+
+        ti_data = TIDataModel(
+            id=uuid.uuid4(),
+            dag_id="test_dag",
+            task_id="test_task",
+            run_id="test_run",
+            try_number=1,
+            dag_version_id=uuid.uuid4(),
+        )
+
+        request = TaskCallbackRequest(
+            filepath="test.py",
+            msg="Task failed",
+            ti=ti_data,
+            bundle_name="testing",
+            bundle_version=None,
+            task_callback_type=TaskInstanceState.FAILED,
+        )
+
+        log = structlog.get_logger()
+        _execute_task_callbacks(dagbag, request, log)
+
+        assert called is True
+        assert context_received is not None
+        assert context_received["dag"] == dag
+        assert "ti" in context_received
+
+    def test_execute_task_callbacks_retry_callback(self, spy_agency):
+        """Test _execute_task_callbacks executes retry callbacks"""
+        called = False
+        context_received = None
+
+        def on_retry(context):
+            nonlocal called, context_received
+            called = True
+            context_received = context
+
+        with DAG(dag_id="test_dag") as dag:
+            _ = BaseOperator(task_id="test_task", on_retry_callback=on_retry)
+
+        def fake_collect_dags(self, *args, **kwargs):
+            self.dags[dag.dag_id] = dag
+
+        spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag)
+
+        dagbag = DagBag()
+        dagbag.collect_dags()
+
+        ti_data = TIDataModel(
+            id=uuid.uuid4(),
+            dag_id="test_dag",
+            task_id="test_task",
+            run_id="test_run",
+            map_index=-1,
+            try_number=1,
+            dag_version_id=uuid.uuid4(),
+            state=TaskInstanceState.UP_FOR_RETRY,
+        )
+
+        request = TaskCallbackRequest(
+            filepath="test.py",
+            msg="Task retrying",
+            ti=ti_data,
+            bundle_name="testing",
+            bundle_version=None,
+            task_callback_type=TaskInstanceState.UP_FOR_RETRY,
+        )
+
+        log = structlog.get_logger()
+        _execute_task_callbacks(dagbag, request, log)
+
+        assert called is True
+        assert context_received is not None
+        assert context_received["dag"] == dag
+        assert "ti" in context_received
+
+    def test_execute_task_callbacks_with_context_from_server(self, spy_agency):
+        """Test _execute_task_callbacks with context_from_server creates full context"""
+        called = False
+        context_received = None
+
+        def on_failure(context):
+            nonlocal called, context_received
+            called = True
+            context_received = context
+
+        with DAG(dag_id="test_dag") as dag:
+            _ = BaseOperator(task_id="test_task", on_failure_callback=on_failure)
+
+        def fake_collect_dags(self, *args, **kwargs):
+            self.dags[dag.dag_id] = dag
+
+        spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag)
+
+        dagbag = DagBag()
+        dagbag.collect_dags()
+
+        # Create a mock DagRun
+        dag_run = DagRun(
+            dag_id="test_dag",
+            run_id="test_run",
+            logical_date=timezone.utcnow(),
+            start_date=timezone.utcnow(),
+            run_type="manual",
+        )
+        dag_run.run_after = timezone.utcnow()
+
+        ti_data = TIDataModel(
+            id=uuid.uuid4(),
+            dag_id="test_dag",
+            task_id="test_task",
+            run_id="test_run",
+            try_number=1,
+            dag_version_id=uuid.uuid4(),
+        )
+
+        context_from_server = TIRunContext(
+            dag_run=dag_run,
+            max_tries=3,
+        )
+
+        request = TaskCallbackRequest(
+            filepath="test.py",
+            msg="Task failed",
+            ti=ti_data,
+            bundle_name="testing",
+            bundle_version=None,
+            task_callback_type=TaskInstanceState.FAILED,
+            context_from_server=context_from_server,
+        )
+
+        log = structlog.get_logger()
+        _execute_task_callbacks(dagbag, request, log)
+
+        assert called is True
+        assert context_received is not None
+        # When context_from_server is provided, we get a full RuntimeTaskInstance context
+        assert "dag_run" in context_received
+        assert "logical_date" in context_received
+
+    def test_execute_task_callbacks_not_failure_callback(self, spy_agency):
+        """Test _execute_task_callbacks when request is not a failure callback"""
+        called = False
+
+        def on_failure(context):
+            nonlocal called
+            called = True
+
+        with DAG(dag_id="test_dag") as dag:
+            _ = BaseOperator(task_id="test_task", on_failure_callback=on_failure)

Review Comment:
   ```suggestion
               BaseOperator(task_id="test_task", on_failure_callback=on_failure)
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@airflow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to