This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 43c17c6730f fix: Always provide a relevant TI context in Dag callback (#61274)
43c17c6730f is described below
commit 43c17c6730f37fcc7c37302343616973434ae4c1
Author: Asquator <[email protected]>
AuthorDate: Sun Mar 8 19:46:48 2026 +0200
fix: Always provide a relevant TI context in Dag callback (#61274)
* Skeleton solution
* Handle more cases
* Refactored None handling
* Fixed task name
* Materialized tasks for double iteration
* Use default argument in min
* Use default argument in max
* Use default argument in max
* Changed end_date to start_date
* Fix logic for on_success_callback
* Handle the case where ti has no start_date
* Logic for DAG testing
* Cosmetics
* Simplified logic
* Pass last succeeded TI on success
* Added template method
* Cosmetics
* More cosmetics
* Lint
* Added default keys for mypy check to succeed
* Updated the docs
* Fixed timezone aware datetime comparison
* Clarified Dag run timeouts in the docs
* Fix condition to look for failures across all relevant TIs
* Mypy & ruff
* Timezone awareness for deadlocked tasks case
* Select from all tis for success callback
* Improved deadlock callback logic
* Mypy & ruff
* Made the test more explicit
* Format
* Updated backward docs version to 3.1.9 according to milestone
* Revert "Updated backward docs version to 3.1.9 according to milestone"
This reverts commit 4defea7c29c2c5b9c2a82c53b506247f52d3221e.
* Type hint
Thank you
Co-authored-by: Kaxil Naik <[email protected]>
* Reapply "Updated backward docs version to 3.1.9 according to milestone"
This reverts commit a31476a21162378ca1f228bdc439d148a3449db4.
* Removed redundant iter() call
* Removed redundant logic in deadlock case
* Updated backward docs version to 3.2.0 according to milestone
Co-authored-by: Elad Kalif <[email protected]>
* Added a newsfragment entry
* Removed trailing whitespaces
* Shortened newsfragment to just one line
---------
Co-authored-by: TheoS <[email protected]>
Co-authored-by: Kaxil Naik <[email protected]>
Co-authored-by: Elad Kalif <[email protected]>
---
.../logging-monitoring/callbacks.rst | 60 ++++--
airflow-core/newsfragments/61274.improvement.rst | 1 +
.../src/airflow/jobs/scheduler_job_runner.py | 26 ++-
airflow-core/src/airflow/models/dagrun.py | 149 +++++++-------
airflow-core/tests/unit/jobs/test_scheduler_job.py | 4 +-
airflow-core/tests/unit/models/test_dag.py | 8 +-
airflow-core/tests/unit/models/test_dagrun.py | 213 ++++++++++-----------
7 files changed, 234 insertions(+), 227 deletions(-)
diff --git a/airflow-core/docs/administration-and-deployment/logging-monitoring/callbacks.rst b/airflow-core/docs/administration-and-deployment/logging-monitoring/callbacks.rst
index 040b278f097..25838377e5c 100644
--- a/airflow-core/docs/administration-and-deployment/logging-monitoring/callbacks.rst
+++ b/airflow-core/docs/administration-and-deployment/logging-monitoring/callbacks.rst
@@ -53,23 +53,47 @@ Callback Types
There are six types of events that can trigger a callback:
-=========================================== ================================================================
-Name                                        Description
-=========================================== ================================================================
-``on_success_callback``                     Invoked when the :ref:`Dag succeeds <dag-run:dag-run-status>` or :ref:`task succeeds <concepts:task-instances>`.
-                                            Available at the Dag or task level.
-``on_failure_callback``                     Invoked when the task :ref:`fails <concepts:task-instances>`.
-                                            Available at the Dag or task level.
-``on_retry_callback``                       Invoked when the task is :ref:`up for retry <concepts:task-instances>`.
-                                            Available only at the task level.
-``on_execute_callback``                     Invoked right before the task begins executing.
-                                            Available only at the task level.
-``on_skipped_callback``                     Invoked when the task is :ref:`running <concepts:task-instances>` and AirflowSkipException raised.
-                                            Explicitly it is NOT called if a task is not started to be executed because of a preceding branching
-                                            decision in the Dag or a trigger rule which causes execution to skip so that the task execution
-                                            is never scheduled.
-                                            Available only at the task level.
-=========================================== ================================================================
+=========================================== ======================================================================= =================
+Name                                        Description                                                             Availability
+=========================================== ======================================================================= =================
+``on_success_callback``                     Invoked when the :ref:`Dag succeeds <dag-run:dag-run-status>`           Dag or Task
+                                            or :ref:`task succeeds <concepts:task-instances>`.
+``on_failure_callback``                     Invoked when the :ref:`Dag fails <dag-run:dag-run-status>`              Dag or Task
+                                            or task :ref:`fails <concepts:task-instances>`.
+``on_retry_callback``                       Invoked when the task is :ref:`up for retry <concepts:task-instances>`. Task
+``on_execute_callback``                     Invoked right before the task begins executing.                         Task
+``on_skipped_callback``                     Invoked when the task is :ref:`running <concepts:task-instances>`       Task
+                                            and AirflowSkipException raised. Explicitly it is NOT called if a task
+                                            is not started to be executed because of a preceding branching
+                                            decision in the Dag or a trigger rule which causes execution
+                                            to skip so that the task execution is never scheduled.
+=========================================== ======================================================================= =================
+
+
+Context Mapping
+---------------
+
+A context mapping that contains runtime information about a task instance is passed to every callback.
+Full list of variables available in ``context`` are in :doc:`docs <../../templates-ref>` and `code <https://github.com/apache/airflow/blob/main/task-sdk/src/airflow/sdk/definitions/context.py>`_.
+
+
+Dag Callbacks
+^^^^^^^^^^^^^
+
+As the context mapping describes execution of a task instance, contexts passed to Dag callbacks will also contain task instance variables,
+and the task selected depends on the state of a Dag:
+
+#. On regular failure, the latest failed task is selected.
+#. On Dag run timeout, the latest started but not finished task is passed.
+#. If tasks are deadlocked, a task that should have run next but couldn't is passed.
+#. On success, the latest succeeded task is passed.
+
+It's not recommended to rely on task instance variables in Dag callbacks except for human analysis, as they reflect only partial information about the Dag's state.
+For example, a timeout may be caused by a number of stalling tasks, but only one will eventually be selected for context.
+
+.. note::
+    Before Airflow 3.2.0, the rules above did not apply and the task instance passed to Dag callback was not related to Dag state, rather being selected as the latest task in the Dag lexicographically.
Examples
@@ -109,8 +133,6 @@ Before each task begins to execute, the ``task_execute_callback`` function will
task3 = EmptyOperator(task_id="task3")
task1 >> task2 >> task3
-Full list of variables available in ``context`` in :doc:`docs <../../templates-ref>` and `code <https://github.com/apache/airflow/blob/main/task-sdk/src/airflow/sdk/definitions/context.py>`_.
-
Using Notifiers
^^^^^^^^^^^^^^^
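The "Dag Callbacks" rules added in the docs above can be illustrated with a minimal, Airflow-free sketch: the dict mirrors the documented ``context`` mapping, and ``SimpleNamespace`` objects stand in for the real ``DagRun`` and ``TaskInstance`` (all names here are illustrative, not Airflow APIs):

```python
# Minimal sketch of a Dag-level on_failure_callback consuming its context.
# Plain-Python stand-ins; nothing here imports Airflow itself.
from types import SimpleNamespace

def on_failure(context):
    # Per the docs above (Airflow >= 3.2.0), "ti" is the task instance
    # relevant to the Dag's final state, e.g. the latest failed task.
    ti = context["ti"]
    return f"Dag {context['dag_run'].dag_id} failed at task {ti.task_id}"

context = {
    "dag_run": SimpleNamespace(dag_id="example_dag"),
    "ti": SimpleNamespace(task_id="load_data"),
}
print(on_failure(context))  # Dag example_dag failed at task load_data
```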
diff --git a/airflow-core/newsfragments/61274.improvement.rst b/airflow-core/newsfragments/61274.improvement.rst
new file mode 100644
index 00000000000..ab9f4f4c0f0
--- /dev/null
+++ b/airflow-core/newsfragments/61274.improvement.rst
@@ -0,0 +1 @@
+Improve Dag callback relevancy by passing a context-relevant task instance based on the Dag's final state (e.g., the last failed, timed out, or successful task) instead of an arbitrary lexicographical selection.
diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index 367447968c2..9f9aa78cd74 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -58,7 +58,6 @@ from airflow.api_fastapi.execution_api.datamodels.taskinstance import DagRun as
from airflow.assets.evaluation import AssetEvaluator
from airflow.callbacks.callback_requests import (
DagCallbackRequest,
- DagRunContext,
EmailRequest,
TaskCallbackRequest,
)
@@ -2437,7 +2436,12 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
select(TI)
.where(TI.dag_id == dag_run.dag_id)
.where(TI.run_id == dag_run.run_id)
- .where(TI.state.in_(State.unfinished))
+ .where(TI.state.in_(State.unfinished) | (TI.state.is_(None)))
+ ).all()
+ last_unfinished_ti = max(
+ unfinished_task_instances,
+            key=lambda ti: ti.start_date or timezone.make_aware(datetime.min),
+ default=None,
)
for task_instance in unfinished_task_instances:
task_instance.state = TaskInstanceState.SKIPPED
@@ -2465,18 +2469,12 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                self.log.error("DagRun %s was deleted unexpectedly", dag_run.id)
return None
dag_run = dag_run_reloaded
- callback_to_execute = DagCallbackRequest(
- filepath=dag_model.relative_fileloc or "",
- dag_id=dag.dag_id,
- run_id=dag_run.run_id,
- bundle_name=dag_model.bundle_name,
- bundle_version=dag_run.bundle_version,
- context_from_server=DagRunContext(
- dag_run=dag_run,
- last_ti=dag_run.get_last_ti(dag=dag, session=session),
- ),
- is_failure_callback=True,
- msg="timed_out",
+ callback_to_execute = dag_run.produce_dag_callback(
+ dag=dag,
+ success=False,
+ relevant_ti=last_unfinished_ti,
+ reason="timed_out",
+ execute=False,
)
dag_run.notify_dagrun_state_changed(msg="timed_out")
diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py
index 0de1a784fa4..4bc47dddea7 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -28,7 +28,6 @@ from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar, cast, overload
from uuid import UUID
import structlog
-from natsort import natsorted
from sqlalchemy import (
JSON,
Enum,
@@ -1220,21 +1219,18 @@ class DagRun(Base, LoggingMixin):
self.set_state(DagRunState.FAILED)
self.notify_dagrun_state_changed(msg="task_failure")
- if execute_callbacks and dag.has_on_failure_callback:
-            self.handle_dag_callback(dag=cast("SDKDAG", dag), success=False, reason="task_failure")
- elif dag.has_on_failure_callback:
- callback = DagCallbackRequest(
- filepath=self.dag_model.relative_fileloc,
- dag_id=self.dag_id,
- run_id=self.run_id,
- bundle_name=self.dag_model.bundle_name,
- bundle_version=self.bundle_version,
- context_from_server=DagRunContext(
- dag_run=self,
- last_ti=self.get_last_ti(dag=dag, session=session),
- ),
- is_failure_callback=True,
- msg="task_failure",
+ if dag.has_on_failure_callback:
+ ti_causing_failure = max(
+ (ti for ti in tis if ti.state == TaskInstanceState.FAILED),
+                key=lambda ti: ti.end_date or timezone.make_aware(datetime.min),
+ default=None,
+ )
+ callback = self.produce_dag_callback(
+ dag=dag,
+ success=False,
+ relevant_ti=ti_causing_failure,
+ reason="task_failure",
+ execute=execute_callbacks,
)
        # Check if the max_consecutive_failed_dag_runs has been provided and not 0
@@ -1253,21 +1249,18 @@ class DagRun(Base, LoggingMixin):
self.set_state(DagRunState.SUCCESS)
self.notify_dagrun_state_changed(msg="success")
- if execute_callbacks and dag.has_on_success_callback:
-            self.handle_dag_callback(dag=cast("SDKDAG", dag), success=True, reason="success")
- elif dag.has_on_success_callback:
- callback = DagCallbackRequest(
- filepath=self.dag_model.relative_fileloc,
- dag_id=self.dag_id,
- run_id=self.run_id,
- bundle_name=self.dag_model.bundle_name,
- bundle_version=self.bundle_version,
- context_from_server=DagRunContext(
- dag_run=self,
- last_ti=self.get_last_ti(dag=dag, session=session),
- ),
- is_failure_callback=False,
- msg="success",
+ if dag.has_on_success_callback:
+ last_succeeded_ti: TI | None = max(
+                (ti for ti in tis if ti.state == TaskInstanceState.SUCCESS),
+                key=lambda ti: ti.end_date or timezone.make_aware(datetime.min),
+ default=None,
+ )
+ callback = self.produce_dag_callback(
+ dag=dag,
+ success=True,
+ relevant_ti=last_succeeded_ti,
+ reason="success",
+ execute=execute_callbacks,
)
if dag.deadline:
@@ -1288,25 +1281,23 @@ class DagRun(Base, LoggingMixin):
self.set_state(DagRunState.FAILED)
self.notify_dagrun_state_changed(msg="all_tasks_deadlocked")
- if execute_callbacks and dag.has_on_failure_callback:
- self.handle_dag_callback(
- dag=cast("SDKDAG", dag),
+ if dag.has_on_failure_callback:
+ finished_task_ids = {ti.task_id for ti in finished_tis}
+ blocking_ti = next(
+ (
+ ti
+ for ti in unfinished.tis
+ if ti.task
+                    and not (ti.task.get_direct_relative_ids(upstream=True).isdisjoint(finished_task_ids))
+ ),
+ None,
+ )
+ callback = self.produce_dag_callback(
+ dag=dag,
success=False,
+ relevant_ti=blocking_ti,
reason="all_tasks_deadlocked",
- )
- elif dag.has_on_failure_callback:
- callback = DagCallbackRequest(
- filepath=self.dag_model.relative_fileloc,
- dag_id=self.dag_id,
- run_id=self.run_id,
- bundle_name=self.dag_model.bundle_name,
- bundle_version=self.bundle_version,
- context_from_server=DagRunContext(
- dag_run=self,
- last_ti=self.get_last_ti(dag=dag, session=session),
- ),
- is_failure_callback=True,
- msg="all_tasks_deadlocked",
+ execute=execute_callbacks,
)
# finally, if the leaves aren't done, the dag is still running
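The deadlock branch above selects, via ``next(..., None)``, an unfinished task whose direct upstream ids are not disjoint from the finished set — i.e. a task that was unblocked by a finished parent yet still could not run. A reduced sketch with illustrative names in place of Airflow's task model:

```python
# Plain-Python sketch of the deadlock-case selection: among unfinished
# tasks, pick the first whose direct upstreams intersect the finished set.
from dataclasses import dataclass, field

@dataclass
class Task:
    task_id: str
    upstream_ids: set[str] = field(default_factory=set)

def find_blocking_task(unfinished, finished_ids):
    return next(
        (t for t in unfinished if not t.upstream_ids.isdisjoint(finished_ids)),
        None,  # mirrors the explicit None default in the real code
    )

# Mirrors the scenario in the new deadlock test: "upstream" finished,
# "wrong" should have run next but could not.
unfinished = [Task("downstream", {"wrong"}), Task("wrong", {"upstream"})]
blocking = find_blocking_task(unfinished, finished_ids={"upstream"})
print(blocking.task_id)  # wrong
```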
@@ -1417,27 +1408,40 @@ class DagRun(Base, LoggingMixin):
# we can't get all the state changes on SchedulerJob,
        # or LocalTaskJob, so we don't want to "falsely advertise" we notify about that
- @provide_session
-    def get_last_ti(self, dag: SerializedDAG, session: Session = NEW_SESSION) -> TI | None:
-        """Get Last TI from the dagrun to build and pass Execution context object from server to then run callbacks."""
- tis = self.get_task_instances(session=session)
- # tis from a dagrun may not be a part of dag.partial_subset,
- # since dag.partial_subset is a subset of the dag.
- # This ensures that we will only use the accessible TI
- # context for the callback.
- if dag.partial:
- tis = [ti for ti in tis if not ti.state == State.NONE]
- # filter out removed tasks
- tis = natsorted(
- (ti for ti in tis if ti.state != TaskInstanceState.REMOVED),
- key=lambda ti: ti.task_id,
+ def produce_dag_callback(
+ self,
+ dag: SerializedDAG,
+ success: bool = True,
+ relevant_ti: TI | None = None,
+ reason: str = "success",
+ execute: bool = False,
+ ) -> DagCallbackRequest | None:
+        """Create a callback request for the DAG, or execute the callbacks directly if instructed, and return None."""
+ if not execute:
+ return DagCallbackRequest(
+ filepath=self.dag_model.relative_fileloc,
+ dag_id=self.dag_id,
+ run_id=self.run_id,
+ bundle_name=self.dag_model.bundle_name,
+ bundle_version=self.bundle_version,
+ context_from_server=DagRunContext(
+ dag_run=self,
+ last_ti=relevant_ti,
+ ),
+ is_failure_callback=(not success),
+ msg=reason,
+ )
+ self.execute_dag_callbacks(
+ dag=cast("SDKDAG", dag),
+ success=success,
+ relevant_ti=relevant_ti,
+ reason=reason,
)
- if not tis:
- return None
- ti = tis[-1] # get last TaskInstance of DagRun
- return ti
+ return None
-    def handle_dag_callback(self, dag: SDKDAG, success: bool = True, reason: str = "success"):
+    def execute_dag_callbacks(
+        self, dag: SDKDAG, success: bool = True, relevant_ti: TI | None = None, reason: str = "success"
+    ):
        """Only needed for `dag.test` where `execute_callbacks=True` is passed to `update_state`."""
from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
DagRun as DRDataModel,
@@ -1446,10 +1450,9 @@ class DagRun(Base, LoggingMixin):
)
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
- last_ti = self.get_last_ti(cast("SerializedDAG", dag))
- if last_ti:
-            last_ti_model = TIDataModel.model_validate(last_ti, from_attributes=True)
- task = dag.get_task(last_ti.task_id)
+ if relevant_ti:
+            last_ti_model = TIDataModel.model_validate(relevant_ti, from_attributes=True)
+ task = dag.get_task(relevant_ti.task_id)
dag_run_data = DRDataModel(
dag_id=self.dag_id,
@@ -1472,12 +1475,12 @@ class DagRun(Base, LoggingMixin):
task=task,
_ti_context_from_server=TIRunContext(
dag_run=dag_run_data,
- max_tries=last_ti.max_tries,
+ max_tries=relevant_ti.max_tries,
variables=[],
connections=[],
xcom_keys_to_clear=[],
),
- max_tries=last_ti.max_tries,
+ max_tries=relevant_ti.max_tries,
)
context = runtime_ti.get_template_context()
else:
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index ae6388f1b8a..57536416cab 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -3524,7 +3524,7 @@ class TestSchedulerJob:
bundle_version=orm_dag.bundle_version,
context_from_server=DagRunContext(
dag_run=dr,
- last_ti=dr.get_last_ti(dag, session),
+ last_ti=dr.get_task_instance("dummy", session),
),
msg="timed_out",
)
@@ -3720,7 +3720,7 @@ class TestSchedulerJob:
bundle_version=None,
context_from_server=DagRunContext(
dag_run=dr,
- last_ti=dr.get_last_ti(dag, session),
+ last_ti=dr.get_task_instance("empty", session),
),
)
diff --git a/airflow-core/tests/unit/models/test_dag.py b/airflow-core/tests/unit/models/test_dag.py
index e801fb57936..046c85ea799 100644
--- a/airflow-core/tests/unit/models/test_dag.py
+++ b/airflow-core/tests/unit/models/test_dag.py
@@ -922,8 +922,8 @@ class TestDag:
)
# should not raise any exception
- dag_run.handle_dag_callback(dag=dag, success=False)
- dag_run.handle_dag_callback(dag=dag, success=True)
+ dag_run.execute_dag_callbacks(dag=dag, success=False)
+ dag_run.execute_dag_callbacks(dag=dag, success=True)
mock_stats.incr.assert_called_with(
"dag.callback_exceptions",
@@ -963,8 +963,8 @@ class TestDag:
        assert dag_run.get_task_instance(task_removed.task_id).state == TaskInstanceState.REMOVED
# should not raise any exception
- dag_run.handle_dag_callback(dag=dag, success=False)
- dag_run.handle_dag_callback(dag=dag, success=True)
+ dag_run.execute_dag_callbacks(dag=dag, success=False)
+ dag_run.execute_dag_callbacks(dag=dag, success=True)
@time_machine.travel(timezone.datetime(2025, 11, 11))
    @pytest.mark.parametrize(("catchup", "expected_next_dagrun"), [(True, DEFAULT_DATE), (False, None)])
diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py
index 3e62554901c..f3de13422fa 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -23,7 +23,7 @@ from collections.abc import Mapping
from functools import reduce
from typing import TYPE_CHECKING
from unittest import mock
-from unittest.mock import call
+from unittest.mock import ANY, call
import pendulum
import pytest
@@ -434,9 +434,16 @@ class TestDagRun:
}
        dag_run = self.create_dag_run(dag=dag, task_states=initial_task_states, session=session)
-        with mock.patch.object(dag_run, "handle_dag_callback") as handle_dag_callback:
+        with mock.patch.object(dag_run, "execute_dag_callbacks") as execute_dag_callbacks:
_, callback = dag_run.update_state()
-        assert handle_dag_callback.mock_calls == [mock.call(dag=dag, success=True, reason="success")]
+ assert execute_dag_callbacks.mock_calls == [
+ mock.call(dag=dag, success=True, relevant_ti=ANY, reason="success")
+ ]
+ # Make sure the correct TI is passed on success
+ call_args = execute_dag_callbacks.call_args
+ ti_passed = call_args.kwargs["relevant_ti"]
+ assert ti_passed.task_id == "test_state_succeeded2"
+
assert dag_run.state == DagRunState.SUCCESS
        # Callbacks are not added until handle_callback = False is passed to dag_run.update_state()
assert callback is None
@@ -461,13 +468,62 @@ class TestDagRun:
dag_task1.set_downstream(dag_task2)
        dag_run = self.create_dag_run(dag=dag, task_states=initial_task_states, session=session)
-        with mock.patch.object(dag_run, "handle_dag_callback") as handle_dag_callback:
+        with mock.patch.object(dag_run, "execute_dag_callbacks") as execute_dag_callbacks:
_, callback = dag_run.update_state()
-        assert handle_dag_callback.mock_calls == [mock.call(dag=dag, success=False, reason="task_failure")]
+ assert execute_dag_callbacks.mock_calls == [
+            mock.call(dag=dag, success=False, relevant_ti=ANY, reason="task_failure")
+ ]
+ # Make sure the correct TI is passed on failure
+ call_args = execute_dag_callbacks.call_args
+ ti_passed = call_args.kwargs["relevant_ti"]
+ assert ti_passed.task_id == "test_state_failed2"
+
assert dag_run.state == DagRunState.FAILED
        # Callbacks are not added until handle_callback = False is passed to dag_run.update_state()
assert callback is None
    def test_dagrun_failure_callback_on_tasks_deadlocked(self, dag_maker, session):
+ def on_failure_callable(context):
+            assert context["dag_run"].dag_id == "test_dagrun_failure_callback_on_tasks_deadlocked"
+
+ with dag_maker(
+ dag_id="test_dagrun_failure_callback_on_tasks_deadlocked",
+ schedule=datetime.timedelta(days=1),
+ start_date=datetime.datetime(2017, 1, 1),
+ on_failure_callback=on_failure_callable,
+ ):
+ up = EmptyOperator(task_id="upstream")
+ middle = EmptyOperator(task_id="wrong")
+ down = EmptyOperator(task_id="downstream")
+
+ middle.trigger_rule = TriggerRule.ONE_FAILED
+ middle.set_upstream(up)
+ middle.set_downstream(down)
+
+ dr = dag_maker.create_dagrun()
+
+ ti_up: TI = dr.get_task_instance(task_id=up.task_id, session=session)
+        ti_middle: TI = dr.get_task_instance(task_id=middle.task_id, session=session)
+ ti_up.set_state(state=TaskInstanceState.SUCCESS, session=session)
+ ti_middle.set_state(state=None, session=session)
+ ti_middle.task.trigger_rule = "invalid"
+
+ serialized_dag = dr.get_dag()
+
+        with mock.patch.object(dr, "execute_dag_callbacks") as execute_dag_callbacks:
+ _, callback = dr.update_state(execute_callbacks=True)
+ assert execute_dag_callbacks.mock_calls == [
+            mock.call(dag=serialized_dag, success=False, relevant_ti=ti_middle, reason="all_tasks_deadlocked")
+ ]
+ # Make sure the correct TI is passed on deadlock
+ call_args = execute_dag_callbacks.call_args
+ ti_passed = call_args.kwargs["relevant_ti"]
+ assert ti_passed.task_id == "wrong"
+
+ assert dr.state == DagRunState.FAILED
+        # Callback is None as execute_callbacks=True
+ assert callback is None
+
    def test_on_success_callback_when_task_skipped(self, session, testing_dag_bundle):
mock_on_success = mock.MagicMock()
mock_on_success.__name__ = "mock_on_success"
@@ -682,7 +738,7 @@ class TestDagRun:
bundle_version=None,
context_from_server=DagRunContext(
dag_run=dag_run,
- last_ti=dag_run.get_last_ti(dag, session),
+                last_ti=dag_run.get_task_instance(task_id="test_state_succeeded2"),
),
msg="success",
)
@@ -732,7 +788,7 @@ class TestDagRun:
bundle_version=None,
context_from_server=DagRunContext(
dag_run=dag_run,
- last_ti=dag_run.get_last_ti(dag, session),
+                last_ti=dag_run.get_task_instance(task_id="test_state_failed2"),
),
)
@@ -1339,11 +1395,16 @@ class TestDagRun:
)
dag_run.dag = scheduler_dag
-        with mock.patch.object(dag_run, "handle_dag_callback") as handle_dag_callback:
+        with mock.patch.object(dag_run, "execute_dag_callbacks") as execute_dag_callbacks:
_, callback = dag_run.update_state()
- assert handle_dag_callback.mock_calls == [
- mock.call(dag=scheduler_dag, success=True, reason="success")
+ assert execute_dag_callbacks.mock_calls == [
+            mock.call(dag=scheduler_dag, success=True, relevant_ti=ANY, reason="success")
]
+ # Make sure the correct TI is passed on success
+ call_args = execute_dag_callbacks.call_args
+ ti_passed = call_args.kwargs["relevant_ti"]
+ assert ti_passed.task_id == "task_2"
+
assert dag_run.state == DagRunState.SUCCESS
        # Callbacks are not added until handle_callback = False is passed to dag_run.update_state()
assert callback is None
@@ -3107,103 +3168,11 @@ def test_teardown_and_fail_fast(dag_maker):
}
-class TestDagRunGetLastTi:
- def test_get_last_ti_with_multiple_tis(self, dag_maker, session):
-        """Test get_last_ti returns the last TI (first created) when multiple TIs exist"""
- with dag_maker("test_dag", session=session) as dag:
- BashOperator(task_id="task1", bash_command="echo 1")
- BashOperator(task_id="task2", bash_command="echo 2")
- BashOperator(task_id="task3", bash_command="echo 3")
-
- dr = dag_maker.create_dagrun()
-
- tis = dr.get_task_instances(session=session)
- assert len(tis) == 3
-
- # Mark some TIs with different states
- tis[0].state = TaskInstanceState.SUCCESS
- tis[1].state = TaskInstanceState.FAILED
- tis[2].state = TaskInstanceState.RUNNING
- session.commit()
-
- last_ti = dr.get_last_ti(dag, session=session)
-
- # Should return the last TI in the list (index -1)
- assert last_ti is not None
- assert last_ti == tis[-1]
- assert last_ti.task_id == "task3"
-
-    def test_get_last_ti_filters_none_state_in_partial_dag(self, dag_maker, session):
- """Test get_last_ti filters out NONE state TIs when dag is partial"""
- with dag_maker("test_dag", session=session) as dag:
- BashOperator(task_id="task1", bash_command="echo 1")
- BashOperator(task_id="task2", bash_command="echo 2")
-
- dr = dag_maker.create_dagrun()
-
- dag.partial = True
-
- # Create task instances with different states
- tis = dr.get_task_instances(session=session)
- tis[0].state = State.NONE # Should be filtered out in partial DAG
- tis[1].state = TaskInstanceState.RUNNING
- session.commit()
-
- last_ti = dr.get_last_ti(dag, session=session)
-
- assert last_ti is not None
- assert last_ti.state != State.NONE
- assert last_ti.task_id == "task2"
-
- def test_get_last_ti_filters_removed_tasks(self, dag_maker, session):
- """Test get_last_ti filters out REMOVED task instances"""
- with dag_maker("test_dag", session=session) as dag:
- BashOperator(task_id="task1", bash_command="echo 1")
- BashOperator(task_id="task2", bash_command="echo 2")
- BashOperator(task_id="task3", bash_command="echo 3")
-
- dr = dag_maker.create_dagrun()
-
- tis = dr.get_task_instances(session=session)
- assert len(tis) == 3
-
- ti_by_id = {ti.task_id: ti for ti in tis}
-
- # Mark some TIs as removed
- ti_by_id["task1"].state = TaskInstanceState.REMOVED
- ti_by_id["task2"].state = TaskInstanceState.REMOVED
- ti_by_id["task3"].state = TaskInstanceState.SUCCESS
- session.commit()
-
- last_ti = dr.get_last_ti(dag, session=session)
-
- # Should return the TI that is not REMOVED
- assert last_ti is not None
- assert last_ti.state != TaskInstanceState.REMOVED
- assert last_ti.task_id == "task3"
-
- def test_get_last_ti_with_single_ti(self, dag_maker, session):
- """Test get_last_ti works with single task instance"""
- with dag_maker("test_dag", session=session) as dag:
- BashOperator(task_id="single_task", bash_command="echo 1")
-
- dr = dag_maker.create_dagrun()
-
- tis = dr.get_task_instances(session=session)
- assert len(tis) == 1
-
- last_ti = dr.get_last_ti(dag, session=session)
-
- assert last_ti is not None
- assert last_ti == tis[0]
- assert last_ti.task_id == "single_task"
-
-
class TestDagRunHandleDagCallback:
- """Test the handle_dag_callback method (only uses in dag.test)."""
+ """Test the execute_dag_callbacks method (only uses in dag.test)."""
- def test_handle_dag_callback_success(self, dag_maker, session):
-        """Test handle_dag_callback executes success callback with RuntimeTaskInstance context"""
+ def test_execute_dag_callbacks_success(self, dag_maker, session):
+        """Test execute_dag_callbacks executes success callback with RuntimeTaskInstance context"""
called = False
context_received = None
@@ -3220,7 +3189,9 @@ class TestDagRunHandleDagCallback:
dag.on_success_callback = on_success
dag.has_on_success_callback = True
- dr.handle_dag_callback(dag, success=True, reason="test_success")
+ dr.execute_dag_callbacks(
+        dag, success=True, relevant_ti=dr.get_task_instance("test_task"), reason="test_success"
+ )
assert called is True
assert context_received is not None
@@ -3232,8 +3203,8 @@ class TestDagRunHandleDagCallback:
assert "ts" in context_received
assert "params" in context_received
- def test_handle_dag_callback_failure(self, dag_maker, session):
-        """Test handle_dag_callback executes failure callback with RuntimeTaskInstance context"""
+ def test_execute_dag_callbacks_failure(self, dag_maker, session):
+        """Test execute_dag_callbacks executes failure callback with RuntimeTaskInstance context"""
called = False
context_received = None
@@ -3250,7 +3221,9 @@ class TestDagRunHandleDagCallback:
dag.on_failure_callback = on_failure
dag.has_on_failure_callback = True
- dr.handle_dag_callback(dag, success=False, reason="test_failure")
+ dr.execute_dag_callbacks(
+        dag, success=False, relevant_ti=dr.get_task_instance("test_task"), reason="test_failure"
+ )
assert called is True
assert context_received is not None
@@ -3262,8 +3235,8 @@ class TestDagRunHandleDagCallback:
assert "ts" in context_received
assert "params" in context_received
- def test_handle_dag_callback_multiple_callbacks(self, dag_maker, session):
- """Test handle_dag_callback executes multiple callbacks"""
+    def test_execute_dag_callbacks_multiple_callbacks(self, dag_maker, session):
+ """Test execute_dag_callbacks executes multiple callbacks"""
call_count = 0
def on_failure_1(context):
@@ -3282,12 +3255,17 @@ class TestDagRunHandleDagCallback:
dag.on_failure_callback = [on_failure_1, on_failure_2]
dag.has_on_failure_callback = True
- dr.handle_dag_callback(dag, success=False, reason="test_failure")
+ dr.execute_dag_callbacks(
+ dag,
+ success=False,
+ relevant_ti=dr.get_task_instance("test_task"),
+ reason="test_failure",
+ )
assert call_count == 2
-    def test_handle_dag_callback_context_has_correct_ti_info(self, dag_maker, session):
-        """Test handle_dag_callback context contains correct task instance information"""
+    def test_execute_dag_callbacks_context_has_correct_ti_info(self, dag_maker, session):
+        """Test execute_dag_callbacks context contains correct task instance information"""
context_received = None
def on_failure(context):
@@ -3302,7 +3280,12 @@ class TestDagRunHandleDagCallback:
dag.on_failure_callback = on_failure
dag.has_on_failure_callback = True
- dr.handle_dag_callback(dag, success=False, reason="test_failure")
+ dr.execute_dag_callbacks(
+ dag,
+ success=False,
+ relevant_ti=dr.get_task_instance("test_task"),
+ reason="test_failure",
+ )
assert context_received is not None
# Check that context contains correct task info