This is an automated email from the ASF dual-hosted git repository.
rom pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6889b6b49a Fix running on_success_callback if `AirflowSkipException`
raised (#40936)
6889b6b49a is described below
commit 6889b6b49aa004fb7dd2f5f3ecbe1c42a7b104f9
Author: rom sharon <[email protected]>
AuthorDate: Tue Aug 20 15:24:27 2024 +0300
Fix running on_success_callback if `AirflowSkipException` raised (#40936)
* finish flow if an AirflowSkipException is raised
* run on_success_callback only if the task succeeded
* remove deprecated line
* add regression test
* add test for dag run on success callback if task is skipped
* validate task instance finish with status skipped
* fix logic to depend on the given operator so as not to introduce a breaking change
* add execute_on_success_callback_when_skipped to MappedOperator
* fix test
* remove parameter for compatibility
* add newsfragment
* update newsfragments
* fix newsfragments
* fix warning in test
---
airflow/models/taskinstance.py | 3 ++-
newsfragments/40936.bugfix.rst | 1 +
tests/models/test_dagrun.py | 25 +++++++++++++++++++++++++
tests/models/test_taskinstance.py | 14 +++++++++-----
4 files changed, 37 insertions(+), 6 deletions(-)
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index cedc254239..640446f8e6 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -354,7 +354,8 @@ def _run_raw_task(
# run on_success_callback before db committing
# otherwise, the LocalTaskJob sees the state is changed to `success`,
# but the task_runner is still running, LocalTaskJob then treats the
state is set externally!
- _run_finished_callback(callbacks=ti.task.on_success_callback,
context=context)
+ if ti.state == TaskInstanceState.SUCCESS:
+ _run_finished_callback(callbacks=ti.task.on_success_callback,
context=context)
if not test_mode:
_add_log(event=ti.state, task_instance=ti, session=session)
diff --git a/newsfragments/40936.bugfix.rst b/newsfragments/40936.bugfix.rst
new file mode 100644
index 0000000000..207aeb9752
--- /dev/null
+++ b/newsfragments/40936.bugfix.rst
@@ -0,0 +1 @@
+Fix: ``on_success_callback`` will no longer execute if a task is skipped.
Previously, this callback was triggered even when the task was skipped, which
could lead to unintended behavior or inconsistencies in downstream processes.
This is a breaking change because workflows that rely on
``on_success_callback`` running for skipped tasks will need to be updated.
Consider updating your DAGs to handle cases where the callback is not invoked
due to task skipping.
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 0117103dbb..3c94c098a2 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -477,6 +477,31 @@ class TestDagRun:
# Callbacks are not added until handle_callback = False is passed to
dag_run.update_state()
assert callback is None
+ def test_on_success_callback_when_task_skipped(self, session):
+ mock_on_success = mock.MagicMock()
+ mock_on_success.__name__ = "mock_on_success"
+
+ dag = DAG(
+ dag_id="test_dagrun_update_state_with_handle_callback_success",
+ start_date=datetime.datetime(2017, 1, 1),
+ on_success_callback=mock_on_success,
+ schedule=datetime.timedelta(days=1),
+ )
+
+ _ = EmptyOperator(task_id="test_state_succeeded1", dag=dag)
+
+ initial_task_states = {
+ "test_state_succeeded1": TaskInstanceState.SKIPPED,
+ }
+
+ dag_run = self.create_dag_run(dag=dag,
task_states=initial_task_states, session=session)
+ _, _ = dag_run.update_state(execute_callbacks=True)
+ task = dag_run.get_task_instances()[0]
+
+ assert task.state == TaskInstanceState.SKIPPED
+ assert DagRunState.SUCCESS == dag_run.state
+ mock_on_success.assert_called_once()
+
def test_dagrun_update_state_with_handle_callback_success(self, session):
def on_success_callable(context):
assert context["dag_run"].dag_id ==
"test_dagrun_update_state_with_handle_callback_success"
diff --git a/tests/models/test_taskinstance.py
b/tests/models/test_taskinstance.py
index 86357bad8f..ed9599d08d 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -4042,22 +4042,26 @@ class TestTaskInstance:
def raise_skip_exception():
raise AirflowSkipException
- callback_function = mock.MagicMock()
- callback_function.__name__ = "callback_function"
+ on_skipped_callback_function = mock.MagicMock()
+ on_skipped_callback_function.__name__ = "on_skipped_callback_function"
+
+ on_success_callback_function = mock.MagicMock()
+ on_success_callback_function.__name__ = "on_success_callback_function"
with dag_maker(dag_id="test_skipped_task", serialized=True):
task = PythonOperator(
task_id="test_skipped_task",
python_callable=raise_skip_exception,
- on_skipped_callback=callback_function,
+ on_skipped_callback=on_skipped_callback_function,
+ on_success_callback=on_success_callback_function,
)
-
dr = dag_maker.create_dagrun(execution_date=timezone.utcnow())
ti = dr.task_instances[0]
ti.task = task
ti.run()
assert State.SKIPPED == ti.state
- assert callback_function.called
+ on_skipped_callback_function.assert_called_once()
+ on_success_callback_function.assert_not_called()
def test_task_instance_history_is_created_when_ti_goes_for_retry(self,
dag_maker, session):
with dag_maker(serialized=True):