This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 43c17c6730f fix: Always provide a relevant TI context in Dag callback 
(#61274)
43c17c6730f is described below

commit 43c17c6730f37fcc7c37302343616973434ae4c1
Author: Asquator <[email protected]>
AuthorDate: Sun Mar 8 19:46:48 2026 +0200

    fix: Always provide a relevant TI context in Dag callback (#61274)
    
    * Skeleton solution
    
    * Handle more cases
    
    * Refactored None handling
    
    * Fixed task name
    
    * Materialized tasks for double iteration
    
    * Use default argument in min
    
    * Use default argument in max
    
    * Use default argument in max
    
    * Changed end_date to start_date
    
    * Fix logic for on_success_callback
    
    * Handle the case where ti has no start_date
    
    * Logic for DAG testing
    
    * Cosmetics
    
    * Simplified logic
    
    * Pass last succeeded TI on success
    
    * Added template method
    
    * Cosmetics
    
    * More cosmetics
    
    * Lint
    
    * Added default keys for mypy check to succeed
    
    * Updated the docs
    
    * Fixed timezone aware datetime comparison
    
    * Clarified Dag run timeouts in the docs
    
    * Fix condition to look for failures across all relevant TIs
    
    * Mypy & ruff
    
    * Timezone awareness for deadlocked tasks case
    
    * Select from all tis for success callback
    
    * Improved deadlock callback logic
    
    * Mypy & ruff
    
    * Made the test more explicit
    
    * Format
    
    * Updated backward docs version to 3.1.9 according to milestone
    
    * Revert "Updated backward docs version to 3.1.9 according to milestone"
    
    This reverts commit 4defea7c29c2c5b9c2a82c53b506247f52d3221e.
    
    * Type hint
    
    Thank you
    
    Co-authored-by: Kaxil Naik <[email protected]>
    
    * Reapply "Updated backward docs version to 3.1.9 according to milestone"
    
    This reverts commit a31476a21162378ca1f228bdc439d148a3449db4.
    
    * Removed redundant iter() call
    
    * Removed redundant logic in deadlock case
    
    * Updated backward docs version to 3.2.0 according to milestone
    
    Co-authored-by: Elad Kalif <[email protected]>
    
    * Added a newsfragment entry
    
    * Removed trailing whitespaces
    
    * Shortened newsfragment to just one line
    
    ---------
    
    Co-authored-by: TheoS <[email protected]>
    Co-authored-by: Kaxil Naik <[email protected]>
    Co-authored-by: Elad Kalif <[email protected]>
---
 .../logging-monitoring/callbacks.rst               |  60 ++++--
 airflow-core/newsfragments/61274.improvement.rst   |   1 +
 .../src/airflow/jobs/scheduler_job_runner.py       |  26 ++-
 airflow-core/src/airflow/models/dagrun.py          | 149 +++++++-------
 airflow-core/tests/unit/jobs/test_scheduler_job.py |   4 +-
 airflow-core/tests/unit/models/test_dag.py         |   8 +-
 airflow-core/tests/unit/models/test_dagrun.py      | 213 ++++++++++-----------
 7 files changed, 234 insertions(+), 227 deletions(-)

diff --git 
a/airflow-core/docs/administration-and-deployment/logging-monitoring/callbacks.rst
 
b/airflow-core/docs/administration-and-deployment/logging-monitoring/callbacks.rst
index 040b278f097..25838377e5c 100644
--- 
a/airflow-core/docs/administration-and-deployment/logging-monitoring/callbacks.rst
+++ 
b/airflow-core/docs/administration-and-deployment/logging-monitoring/callbacks.rst
@@ -53,23 +53,47 @@ Callback Types
 
 There are six types of events that can trigger a callback:
 
-=========================================== 
================================================================
-Name                                        Description
-=========================================== 
================================================================
-``on_success_callback``                     Invoked when the :ref:`Dag 
succeeds <dag-run:dag-run-status>` or :ref:`task succeeds 
<concepts:task-instances>`.
-                                            Available at the Dag or task level.
-``on_failure_callback``                     Invoked when the task :ref:`fails 
<concepts:task-instances>`.
-                                            Available at the Dag or task level.
-``on_retry_callback``                       Invoked when the task is :ref:`up 
for retry <concepts:task-instances>`.
-                                            Available only at the task level.
-``on_execute_callback``                     Invoked right before the task 
begins executing.
-                                            Available only at the task level.
-``on_skipped_callback``                     Invoked when the task is 
:ref:`running <concepts:task-instances>` and  AirflowSkipException raised.
-                                            Explicitly it is NOT called if a 
task is not started to be executed because of a preceding branching
-                                            decision in the Dag or a trigger 
rule which causes execution to skip so that the task execution
-                                            is never scheduled.
-                                            Available only at the task level.
-=========================================== 
================================================================
+=========================================== 
======================================================================= 
=================
+Name                                        Description                        
                                     Availability
+=========================================== 
======================================================================= 
=================
+``on_success_callback``                     Invoked when the :ref:`Dag 
succeeds <dag-run:dag-run-status>`           Dag or Task
+                                            or :ref:`task succeeds 
<concepts:task-instances>`.
+``on_failure_callback``                     Invoked when the :ref:`Dag fails 
<dag-run:dag-run-status>`              Dag or Task
+                                            or task :ref:`fails 
<concepts:task-instances>`.
+``on_retry_callback``                       Invoked when the task is :ref:`up 
for retry <concepts:task-instances>`. Task
+``on_execute_callback``                     Invoked right before the task 
begins executing.                         Task
+``on_skipped_callback``                     Invoked when the task is 
:ref:`running <concepts:task-instances>`       Task
+                                            and AirflowSkipException raised. 
Explicitly it is NOT called if a task
+                                            is not started to be executed 
because of a preceding branching
+                                            decision in the Dag or a trigger 
rule which causes execution
+                                            to skip so that the task execution 
is never scheduled.
+=========================================== 
======================================================================= 
=================
+
+
+Context Mapping
+---------------
+
+A context mapping that contains runtime information about a task instance is 
passed to every callback.
+The full list of variables available in ``context`` can be found in the :doc:`docs 
<../../templates-ref>` and in the `code 
<https://github.com/apache/airflow/blob/main/task-sdk/src/airflow/sdk/definitions/context.py>`_.
+
+
+Dag Callbacks
+^^^^^^^^^^^^^
+
+Because the context mapping describes the execution of a task instance, contexts passed 
to Dag callbacks also contain task instance variables,
+and which task instance is selected depends on the final state of the Dag run:
+
+#. On task failure, the task that failed most recently is passed.
+#. On Dag run timeout, the most recently started task that had not yet finished is passed.
+#. If tasks are deadlocked, a task that should have run next but could not is 
passed.
+#. On success, the task that succeeded most recently is passed.
+
+Relying on task instance variables in Dag callbacks for anything other than 
human analysis is not recommended, as they reflect only partial information about the 
Dag run's state.
+For example, a timeout may be caused by several stalled tasks, but only 
one of them is selected for the context.
+
+.. note::
+    Before Airflow 3.2.0, the rules above did not apply: the task instance 
passed to Dag callbacks was unrelated to the Dag run's state and was instead 
the last task instance of the Dag run
+    when ordered by task ID.
 
 
 Examples
@@ -109,8 +133,6 @@ Before each task begins to execute, the 
``task_execute_callback`` function will
         task3 = EmptyOperator(task_id="task3")
         task1 >> task2 >> task3
 
-Full list of variables available in ``context`` in :doc:`docs 
<../../templates-ref>` and `code 
<https://github.com/apache/airflow/blob/main/task-sdk/src/airflow/sdk/definitions/context.py>`_.
-
 
 Using Notifiers
 ^^^^^^^^^^^^^^^
diff --git a/airflow-core/newsfragments/61274.improvement.rst 
b/airflow-core/newsfragments/61274.improvement.rst
new file mode 100644
index 00000000000..ab9f4f4c0f0
--- /dev/null
+++ b/airflow-core/newsfragments/61274.improvement.rst
@@ -0,0 +1 @@
+Improve Dag callback relevancy by passing a context-relevant task instance 
based on the Dag's final state (e.g., the last failed, timed out, or successful 
task) instead of an arbitrary lexicographical selection.
diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py 
b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index 367447968c2..9f9aa78cd74 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -58,7 +58,6 @@ from 
airflow.api_fastapi.execution_api.datamodels.taskinstance import DagRun as
 from airflow.assets.evaluation import AssetEvaluator
 from airflow.callbacks.callback_requests import (
     DagCallbackRequest,
-    DagRunContext,
     EmailRequest,
     TaskCallbackRequest,
 )
@@ -2437,7 +2436,12 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                 select(TI)
                 .where(TI.dag_id == dag_run.dag_id)
                 .where(TI.run_id == dag_run.run_id)
-                .where(TI.state.in_(State.unfinished))
+                .where(TI.state.in_(State.unfinished) | (TI.state.is_(None)))
+            ).all()
+            last_unfinished_ti = max(
+                unfinished_task_instances,
+                key=lambda ti: ti.start_date or 
timezone.make_aware(datetime.min),
+                default=None,
             )
             for task_instance in unfinished_task_instances:
                 task_instance.state = TaskInstanceState.SKIPPED
@@ -2465,18 +2469,12 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                 self.log.error("DagRun %s was deleted unexpectedly", 
dag_run.id)
                 return None
             dag_run = dag_run_reloaded
-            callback_to_execute = DagCallbackRequest(
-                filepath=dag_model.relative_fileloc or "",
-                dag_id=dag.dag_id,
-                run_id=dag_run.run_id,
-                bundle_name=dag_model.bundle_name,
-                bundle_version=dag_run.bundle_version,
-                context_from_server=DagRunContext(
-                    dag_run=dag_run,
-                    last_ti=dag_run.get_last_ti(dag=dag, session=session),
-                ),
-                is_failure_callback=True,
-                msg="timed_out",
+            callback_to_execute = dag_run.produce_dag_callback(
+                dag=dag,
+                success=False,
+                relevant_ti=last_unfinished_ti,
+                reason="timed_out",
+                execute=False,
             )
 
             dag_run.notify_dagrun_state_changed(msg="timed_out")
diff --git a/airflow-core/src/airflow/models/dagrun.py 
b/airflow-core/src/airflow/models/dagrun.py
index 0de1a784fa4..4bc47dddea7 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -28,7 +28,6 @@ from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar, 
cast, overload
 from uuid import UUID
 
 import structlog
-from natsort import natsorted
 from sqlalchemy import (
     JSON,
     Enum,
@@ -1220,21 +1219,18 @@ class DagRun(Base, LoggingMixin):
             self.set_state(DagRunState.FAILED)
             self.notify_dagrun_state_changed(msg="task_failure")
 
-            if execute_callbacks and dag.has_on_failure_callback:
-                self.handle_dag_callback(dag=cast("SDKDAG", dag), 
success=False, reason="task_failure")
-            elif dag.has_on_failure_callback:
-                callback = DagCallbackRequest(
-                    filepath=self.dag_model.relative_fileloc,
-                    dag_id=self.dag_id,
-                    run_id=self.run_id,
-                    bundle_name=self.dag_model.bundle_name,
-                    bundle_version=self.bundle_version,
-                    context_from_server=DagRunContext(
-                        dag_run=self,
-                        last_ti=self.get_last_ti(dag=dag, session=session),
-                    ),
-                    is_failure_callback=True,
-                    msg="task_failure",
+            if dag.has_on_failure_callback:
+                ti_causing_failure = max(
+                    (ti for ti in tis if ti.state == TaskInstanceState.FAILED),
+                    key=lambda ti: ti.end_date or 
timezone.make_aware(datetime.min),
+                    default=None,
+                )
+                callback = self.produce_dag_callback(
+                    dag=dag,
+                    success=False,
+                    relevant_ti=ti_causing_failure,
+                    reason="task_failure",
+                    execute=execute_callbacks,
                 )
 
             # Check if the max_consecutive_failed_dag_runs has been provided 
and not 0
@@ -1253,21 +1249,18 @@ class DagRun(Base, LoggingMixin):
             self.set_state(DagRunState.SUCCESS)
             self.notify_dagrun_state_changed(msg="success")
 
-            if execute_callbacks and dag.has_on_success_callback:
-                self.handle_dag_callback(dag=cast("SDKDAG", dag), 
success=True, reason="success")
-            elif dag.has_on_success_callback:
-                callback = DagCallbackRequest(
-                    filepath=self.dag_model.relative_fileloc,
-                    dag_id=self.dag_id,
-                    run_id=self.run_id,
-                    bundle_name=self.dag_model.bundle_name,
-                    bundle_version=self.bundle_version,
-                    context_from_server=DagRunContext(
-                        dag_run=self,
-                        last_ti=self.get_last_ti(dag=dag, session=session),
-                    ),
-                    is_failure_callback=False,
-                    msg="success",
+            if dag.has_on_success_callback:
+                last_succeeded_ti: TI | None = max(
+                    (ti for ti in tis if ti.state == 
TaskInstanceState.SUCCESS),
+                    key=lambda ti: ti.end_date or 
timezone.make_aware(datetime.min),
+                    default=None,
+                )
+                callback = self.produce_dag_callback(
+                    dag=dag,
+                    success=True,
+                    relevant_ti=last_succeeded_ti,
+                    reason="success",
+                    execute=execute_callbacks,
                 )
 
             if dag.deadline:
@@ -1288,25 +1281,23 @@ class DagRun(Base, LoggingMixin):
             self.set_state(DagRunState.FAILED)
             self.notify_dagrun_state_changed(msg="all_tasks_deadlocked")
 
-            if execute_callbacks and dag.has_on_failure_callback:
-                self.handle_dag_callback(
-                    dag=cast("SDKDAG", dag),
+            if dag.has_on_failure_callback:
+                finished_task_ids = {ti.task_id for ti in finished_tis}
+                blocking_ti = next(
+                    (
+                        ti
+                        for ti in unfinished.tis
+                        if ti.task
+                        and not 
(ti.task.get_direct_relative_ids(upstream=True).isdisjoint(finished_task_ids))
+                    ),
+                    None,
+                )
+                callback = self.produce_dag_callback(
+                    dag=dag,
                     success=False,
+                    relevant_ti=blocking_ti,
                     reason="all_tasks_deadlocked",
-                )
-            elif dag.has_on_failure_callback:
-                callback = DagCallbackRequest(
-                    filepath=self.dag_model.relative_fileloc,
-                    dag_id=self.dag_id,
-                    run_id=self.run_id,
-                    bundle_name=self.dag_model.bundle_name,
-                    bundle_version=self.bundle_version,
-                    context_from_server=DagRunContext(
-                        dag_run=self,
-                        last_ti=self.get_last_ti(dag=dag, session=session),
-                    ),
-                    is_failure_callback=True,
-                    msg="all_tasks_deadlocked",
+                    execute=execute_callbacks,
                 )
 
         # finally, if the leaves aren't done, the dag is still running
@@ -1417,27 +1408,40 @@ class DagRun(Base, LoggingMixin):
         # we can't get all the state changes on SchedulerJob,
         # or LocalTaskJob, so we don't want to "falsely advertise" we notify 
about that
 
-    @provide_session
-    def get_last_ti(self, dag: SerializedDAG, session: Session = NEW_SESSION) 
-> TI | None:
-        """Get Last TI from the dagrun to build and pass Execution context 
object from server to then run callbacks."""
-        tis = self.get_task_instances(session=session)
-        # tis from a dagrun may not be a part of dag.partial_subset,
-        # since dag.partial_subset is a subset of the dag.
-        # This ensures that we will only use the accessible TI
-        # context for the callback.
-        if dag.partial:
-            tis = [ti for ti in tis if not ti.state == State.NONE]
-        # filter out removed tasks
-        tis = natsorted(
-            (ti for ti in tis if ti.state != TaskInstanceState.REMOVED),
-            key=lambda ti: ti.task_id,
+    def produce_dag_callback(
+        self,
+        dag: SerializedDAG,
+        success: bool = True,
+        relevant_ti: TI | None = None,
+        reason: str = "success",
+        execute: bool = False,
+    ) -> DagCallbackRequest | None:
+        """Create a callback request for the DAG, or execute the callbacks 
directly if instructed, and return None."""
+        if not execute:
+            return DagCallbackRequest(
+                filepath=self.dag_model.relative_fileloc,
+                dag_id=self.dag_id,
+                run_id=self.run_id,
+                bundle_name=self.dag_model.bundle_name,
+                bundle_version=self.bundle_version,
+                context_from_server=DagRunContext(
+                    dag_run=self,
+                    last_ti=relevant_ti,
+                ),
+                is_failure_callback=(not success),
+                msg=reason,
+            )
+        self.execute_dag_callbacks(
+            dag=cast("SDKDAG", dag),
+            success=success,
+            relevant_ti=relevant_ti,
+            reason=reason,
         )
-        if not tis:
-            return None
-        ti = tis[-1]  # get last TaskInstance of DagRun
-        return ti
+        return None
 
-    def handle_dag_callback(self, dag: SDKDAG, success: bool = True, reason: 
str = "success"):
+    def execute_dag_callbacks(
+        self, dag: SDKDAG, success: bool = True, relevant_ti: TI | None = 
None, reason: str = "success"
+    ):
         """Only needed for `dag.test` where `execute_callbacks=True` is passed 
to `update_state`."""
         from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
             DagRun as DRDataModel,
@@ -1446,10 +1450,9 @@ class DagRun(Base, LoggingMixin):
         )
         from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
 
-        last_ti = self.get_last_ti(cast("SerializedDAG", dag))
-        if last_ti:
-            last_ti_model = TIDataModel.model_validate(last_ti, 
from_attributes=True)
-            task = dag.get_task(last_ti.task_id)
+        if relevant_ti:
+            last_ti_model = TIDataModel.model_validate(relevant_ti, 
from_attributes=True)
+            task = dag.get_task(relevant_ti.task_id)
 
             dag_run_data = DRDataModel(
                 dag_id=self.dag_id,
@@ -1472,12 +1475,12 @@ class DagRun(Base, LoggingMixin):
                 task=task,
                 _ti_context_from_server=TIRunContext(
                     dag_run=dag_run_data,
-                    max_tries=last_ti.max_tries,
+                    max_tries=relevant_ti.max_tries,
                     variables=[],
                     connections=[],
                     xcom_keys_to_clear=[],
                 ),
-                max_tries=last_ti.max_tries,
+                max_tries=relevant_ti.max_tries,
             )
             context = runtime_ti.get_template_context()
         else:
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py 
b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index ae6388f1b8a..57536416cab 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -3524,7 +3524,7 @@ class TestSchedulerJob:
             bundle_version=orm_dag.bundle_version,
             context_from_server=DagRunContext(
                 dag_run=dr,
-                last_ti=dr.get_last_ti(dag, session),
+                last_ti=dr.get_task_instance("dummy", session),
             ),
             msg="timed_out",
         )
@@ -3720,7 +3720,7 @@ class TestSchedulerJob:
             bundle_version=None,
             context_from_server=DagRunContext(
                 dag_run=dr,
-                last_ti=dr.get_last_ti(dag, session),
+                last_ti=dr.get_task_instance("empty", session),
             ),
         )
 
diff --git a/airflow-core/tests/unit/models/test_dag.py 
b/airflow-core/tests/unit/models/test_dag.py
index e801fb57936..046c85ea799 100644
--- a/airflow-core/tests/unit/models/test_dag.py
+++ b/airflow-core/tests/unit/models/test_dag.py
@@ -922,8 +922,8 @@ class TestDag:
             )
 
             # should not raise any exception
-        dag_run.handle_dag_callback(dag=dag, success=False)
-        dag_run.handle_dag_callback(dag=dag, success=True)
+        dag_run.execute_dag_callbacks(dag=dag, success=False)
+        dag_run.execute_dag_callbacks(dag=dag, success=True)
 
         mock_stats.incr.assert_called_with(
             "dag.callback_exceptions",
@@ -963,8 +963,8 @@ class TestDag:
             assert dag_run.get_task_instance(task_removed.task_id).state == 
TaskInstanceState.REMOVED
 
             # should not raise any exception
-            dag_run.handle_dag_callback(dag=dag, success=False)
-            dag_run.handle_dag_callback(dag=dag, success=True)
+            dag_run.execute_dag_callbacks(dag=dag, success=False)
+            dag_run.execute_dag_callbacks(dag=dag, success=True)
 
     @time_machine.travel(timezone.datetime(2025, 11, 11))
     @pytest.mark.parametrize(("catchup", "expected_next_dagrun"), [(True, 
DEFAULT_DATE), (False, None)])
diff --git a/airflow-core/tests/unit/models/test_dagrun.py 
b/airflow-core/tests/unit/models/test_dagrun.py
index 3e62554901c..f3de13422fa 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -23,7 +23,7 @@ from collections.abc import Mapping
 from functools import reduce
 from typing import TYPE_CHECKING
 from unittest import mock
-from unittest.mock import call
+from unittest.mock import ANY, call
 
 import pendulum
 import pytest
@@ -434,9 +434,16 @@ class TestDagRun:
         }
 
         dag_run = self.create_dag_run(dag=dag, 
task_states=initial_task_states, session=session)
-        with mock.patch.object(dag_run, "handle_dag_callback") as 
handle_dag_callback:
+        with mock.patch.object(dag_run, "execute_dag_callbacks") as 
execute_dag_callbacks:
             _, callback = dag_run.update_state()
-        assert handle_dag_callback.mock_calls == [mock.call(dag=dag, 
success=True, reason="success")]
+        assert execute_dag_callbacks.mock_calls == [
+            mock.call(dag=dag, success=True, relevant_ti=ANY, reason="success")
+        ]
+        # Make sure the correct TI is passed on success
+        call_args = execute_dag_callbacks.call_args
+        ti_passed = call_args.kwargs["relevant_ti"]
+        assert ti_passed.task_id == "test_state_succeeded2"
+
         assert dag_run.state == DagRunState.SUCCESS
         # Callbacks are not added until handle_callback = False is passed to 
dag_run.update_state()
         assert callback is None
@@ -461,13 +468,62 @@ class TestDagRun:
         dag_task1.set_downstream(dag_task2)
 
         dag_run = self.create_dag_run(dag=dag, 
task_states=initial_task_states, session=session)
-        with mock.patch.object(dag_run, "handle_dag_callback") as 
handle_dag_callback:
+        with mock.patch.object(dag_run, "execute_dag_callbacks") as 
execute_dag_callbacks:
             _, callback = dag_run.update_state()
-        assert handle_dag_callback.mock_calls == [mock.call(dag=dag, 
success=False, reason="task_failure")]
+        assert execute_dag_callbacks.mock_calls == [
+            mock.call(dag=dag, success=False, relevant_ti=ANY, 
reason="task_failure")
+        ]
+        # Make sure the correct TI is passed on failure
+        call_args = execute_dag_callbacks.call_args
+        ti_passed = call_args.kwargs["relevant_ti"]
+        assert ti_passed.task_id == "test_state_failed2"
+
         assert dag_run.state == DagRunState.FAILED
         # Callbacks are not added until handle_callback = False is passed to 
dag_run.update_state()
         assert callback is None
 
+    def test_dagrun_failure_callback_on_tasks_deadlocked(self, dag_maker, 
session):
+        def on_failure_callable(context):
+            assert context["dag_run"].dag_id == 
"test_dagrun_failure_callback_on_tasks_deadlocked"
+
+        with dag_maker(
+            dag_id="test_dagrun_failure_callback_on_tasks_deadlocked",
+            schedule=datetime.timedelta(days=1),
+            start_date=datetime.datetime(2017, 1, 1),
+            on_failure_callback=on_failure_callable,
+        ):
+            up = EmptyOperator(task_id="upstream")
+            middle = EmptyOperator(task_id="wrong")
+            down = EmptyOperator(task_id="downstream")
+
+            middle.trigger_rule = TriggerRule.ONE_FAILED
+            middle.set_upstream(up)
+            middle.set_downstream(down)
+
+        dr = dag_maker.create_dagrun()
+
+        ti_up: TI = dr.get_task_instance(task_id=up.task_id, session=session)
+        ti_middle: TI = dr.get_task_instance(task_id=middle.task_id, 
session=session)
+        ti_up.set_state(state=TaskInstanceState.SUCCESS, session=session)
+        ti_middle.set_state(state=None, session=session)
+        ti_middle.task.trigger_rule = "invalid"
+
+        serialized_dag = dr.get_dag()
+
+        with mock.patch.object(dr, "execute_dag_callbacks") as 
execute_dag_callbacks:
+            _, callback = dr.update_state(execute_callbacks=True)
+        assert execute_dag_callbacks.mock_calls == [
+            mock.call(dag=serialized_dag, success=False, 
relevant_ti=ti_middle, reason="all_tasks_deadlocked")
+        ]
+        # Make sure the correct TI is passed on deadlock
+        call_args = execute_dag_callbacks.call_args
+        ti_passed = call_args.kwargs["relevant_ti"]
+        assert ti_passed.task_id == "wrong"
+
+        assert dr.state == DagRunState.FAILED
+        # Callbacks is None as execute_callbacks=True
+        assert callback is None
+
     def test_on_success_callback_when_task_skipped(self, session, 
testing_dag_bundle):
         mock_on_success = mock.MagicMock()
         mock_on_success.__name__ = "mock_on_success"
@@ -682,7 +738,7 @@ class TestDagRun:
             bundle_version=None,
             context_from_server=DagRunContext(
                 dag_run=dag_run,
-                last_ti=dag_run.get_last_ti(dag, session),
+                
last_ti=dag_run.get_task_instance(task_id="test_state_succeeded2"),
             ),
             msg="success",
         )
@@ -732,7 +788,7 @@ class TestDagRun:
             bundle_version=None,
             context_from_server=DagRunContext(
                 dag_run=dag_run,
-                last_ti=dag_run.get_last_ti(dag, session),
+                
last_ti=dag_run.get_task_instance(task_id="test_state_failed2"),
             ),
         )
 
@@ -1339,11 +1395,16 @@ class TestDagRun:
         )
         dag_run.dag = scheduler_dag
 
-        with mock.patch.object(dag_run, "handle_dag_callback") as 
handle_dag_callback:
+        with mock.patch.object(dag_run, "execute_dag_callbacks") as 
execute_dag_callbacks:
             _, callback = dag_run.update_state()
-        assert handle_dag_callback.mock_calls == [
-            mock.call(dag=scheduler_dag, success=True, reason="success")
+        assert execute_dag_callbacks.mock_calls == [
+            mock.call(dag=scheduler_dag, success=True, relevant_ti=ANY, 
reason="success")
         ]
+        # Make sure the correct TI is passed on success
+        call_args = execute_dag_callbacks.call_args
+        ti_passed = call_args.kwargs["relevant_ti"]
+        assert ti_passed.task_id == "task_2"
+
         assert dag_run.state == DagRunState.SUCCESS
         # Callbacks are not added until handle_callback = False is passed to 
dag_run.update_state()
         assert callback is None
@@ -3107,103 +3168,11 @@ def test_teardown_and_fail_fast(dag_maker):
     }
 
 
-class TestDagRunGetLastTi:
-    def test_get_last_ti_with_multiple_tis(self, dag_maker, session):
-        """Test get_last_ti returns the last TI (first created) when multiple 
TIs exist"""
-        with dag_maker("test_dag", session=session) as dag:
-            BashOperator(task_id="task1", bash_command="echo 1")
-            BashOperator(task_id="task2", bash_command="echo 2")
-            BashOperator(task_id="task3", bash_command="echo 3")
-
-        dr = dag_maker.create_dagrun()
-
-        tis = dr.get_task_instances(session=session)
-        assert len(tis) == 3
-
-        # Mark some TIs with different states
-        tis[0].state = TaskInstanceState.SUCCESS
-        tis[1].state = TaskInstanceState.FAILED
-        tis[2].state = TaskInstanceState.RUNNING
-        session.commit()
-
-        last_ti = dr.get_last_ti(dag, session=session)
-
-        # Should return the last TI in the list (index -1)
-        assert last_ti is not None
-        assert last_ti == tis[-1]
-        assert last_ti.task_id == "task3"
-
-    def test_get_last_ti_filters_none_state_in_partial_dag(self, dag_maker, 
session):
-        """Test get_last_ti filters out NONE state TIs when dag is partial"""
-        with dag_maker("test_dag", session=session) as dag:
-            BashOperator(task_id="task1", bash_command="echo 1")
-            BashOperator(task_id="task2", bash_command="echo 2")
-
-        dr = dag_maker.create_dagrun()
-
-        dag.partial = True
-
-        # Create task instances with different states
-        tis = dr.get_task_instances(session=session)
-        tis[0].state = State.NONE  # Should be filtered out in partial DAG
-        tis[1].state = TaskInstanceState.RUNNING
-        session.commit()
-
-        last_ti = dr.get_last_ti(dag, session=session)
-
-        assert last_ti is not None
-        assert last_ti.state != State.NONE
-        assert last_ti.task_id == "task2"
-
-    def test_get_last_ti_filters_removed_tasks(self, dag_maker, session):
-        """Test get_last_ti filters out REMOVED task instances"""
-        with dag_maker("test_dag", session=session) as dag:
-            BashOperator(task_id="task1", bash_command="echo 1")
-            BashOperator(task_id="task2", bash_command="echo 2")
-            BashOperator(task_id="task3", bash_command="echo 3")
-
-        dr = dag_maker.create_dagrun()
-
-        tis = dr.get_task_instances(session=session)
-        assert len(tis) == 3
-
-        ti_by_id = {ti.task_id: ti for ti in tis}
-
-        # Mark some TIs as removed
-        ti_by_id["task1"].state = TaskInstanceState.REMOVED
-        ti_by_id["task2"].state = TaskInstanceState.REMOVED
-        ti_by_id["task3"].state = TaskInstanceState.SUCCESS
-        session.commit()
-
-        last_ti = dr.get_last_ti(dag, session=session)
-
-        # Should return the TI that is not REMOVED
-        assert last_ti is not None
-        assert last_ti.state != TaskInstanceState.REMOVED
-        assert last_ti.task_id == "task3"
-
-    def test_get_last_ti_with_single_ti(self, dag_maker, session):
-        """Test get_last_ti works with single task instance"""
-        with dag_maker("test_dag", session=session) as dag:
-            BashOperator(task_id="single_task", bash_command="echo 1")
-
-        dr = dag_maker.create_dagrun()
-
-        tis = dr.get_task_instances(session=session)
-        assert len(tis) == 1
-
-        last_ti = dr.get_last_ti(dag, session=session)
-
-        assert last_ti is not None
-        assert last_ti == tis[0]
-        assert last_ti.task_id == "single_task"
-
-
 class TestDagRunHandleDagCallback:
-    """Test the handle_dag_callback method (only uses in dag.test)."""
+    """Test the execute_dag_callbacks method (only uses in dag.test)."""
 
-    def test_handle_dag_callback_success(self, dag_maker, session):
-        """Test handle_dag_callback executes success callback with 
RuntimeTaskInstance context"""
+    def test_execute_dag_callbacks_success(self, dag_maker, session):
+        """Test execute_dag_callbacks executes success callback with 
RuntimeTaskInstance context"""
         called = False
         context_received = None
 
@@ -3220,7 +3189,9 @@ class TestDagRunHandleDagCallback:
         dag.on_success_callback = on_success
         dag.has_on_success_callback = True
 
-        dr.handle_dag_callback(dag, success=True, reason="test_success")
+        dr.execute_dag_callbacks(
+            dag, success=True, relevant_ti=dr.get_task_instance("test_task"), 
reason="test_success"
+        )
 
         assert called is True
         assert context_received is not None
@@ -3232,8 +3203,8 @@ class TestDagRunHandleDagCallback:
         assert "ts" in context_received
         assert "params" in context_received
 
-    def test_handle_dag_callback_failure(self, dag_maker, session):
-        """Test handle_dag_callback executes failure callback with 
RuntimeTaskInstance context"""
+    def test_execute_dag_callbacks_failure(self, dag_maker, session):
+        """Test execute_dag_callbacks executes failure callback with 
RuntimeTaskInstance context"""
         called = False
         context_received = None
 
@@ -3250,7 +3221,9 @@ class TestDagRunHandleDagCallback:
         dag.on_failure_callback = on_failure
         dag.has_on_failure_callback = True
 
-        dr.handle_dag_callback(dag, success=False, reason="test_failure")
+        dr.execute_dag_callbacks(
+            dag, success=False, relevant_ti=dr.get_task_instance("test_task"), 
reason="test_failure"
+        )
 
         assert called is True
         assert context_received is not None
@@ -3262,8 +3235,8 @@ class TestDagRunHandleDagCallback:
         assert "ts" in context_received
         assert "params" in context_received
 
-    def test_handle_dag_callback_multiple_callbacks(self, dag_maker, session):
-        """Test handle_dag_callback executes multiple callbacks"""
+    def test_execute_dag_callbacks_multiple_callbacks(self, dag_maker, 
session):
+        """Test execute_dag_callbacks executes multiple callbacks"""
         call_count = 0
 
         def on_failure_1(context):
@@ -3282,12 +3255,17 @@ class TestDagRunHandleDagCallback:
         dag.on_failure_callback = [on_failure_1, on_failure_2]
         dag.has_on_failure_callback = True
 
-        dr.handle_dag_callback(dag, success=False, reason="test_failure")
+        dr.execute_dag_callbacks(
+            dag,
+            success=False,
+            relevant_ti=dr.get_task_instance("test_task"),
+            reason="test_failure",
+        )
 
         assert call_count == 2
 
-    def test_handle_dag_callback_context_has_correct_ti_info(self, dag_maker, 
session):
-        """Test handle_dag_callback context contains correct task instance 
information"""
+    def test_execute_dag_callbacks_context_has_correct_ti_info(self, 
dag_maker, session):
+        """Test execute_dag_callbacks context contains correct task instance 
information"""
         context_received = None
 
         def on_failure(context):
@@ -3302,7 +3280,12 @@ class TestDagRunHandleDagCallback:
         dag.on_failure_callback = on_failure
         dag.has_on_failure_callback = True
 
-        dr.handle_dag_callback(dag, success=False, reason="test_failure")
+        dr.execute_dag_callbacks(
+            dag,
+            success=False,
+            relevant_ti=dr.get_task_instance("test_task"),
+            reason="test_failure",
+        )
 
         assert context_received is not None
         # Check that context contains correct task info


Reply via email to