This is an automated email from the ASF dual-hosted git repository.

rom pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6889b6b49a Fix running on_success_callback if `AirflowSkipException` raised (#40936)
6889b6b49a is described below

commit 6889b6b49aa004fb7dd2f5f3ecbe1c42a7b104f9
Author: rom sharon <[email protected]>
AuthorDate: Tue Aug 20 15:24:27 2024 +0300

    Fix running on_success_callback if `AirflowSkipException` raised (#40936)
    
    * finish flow if airflow skip exception raise
    
    * run on success callback only if task is success
    
    * remove deprecated line
    
    * add regression test
    
    * add test for dag run on success callback if task is skipped
    
    * validate task instance finish with status skipped
    
    * fix logic to depend on the given operator to avoid a breaking change
    
    * add execute_on_success_callback_when_skipped to MappedOperator
    
    * fix test
    
    * remove parameter for compatibility
    
    * add newsfragment
    
    * update newsfragments
    
    * fix newsfragments
    
    * fix warning in test
---
 airflow/models/taskinstance.py    |  3 ++-
 newsfragments/40936.bugfix.rst    |  1 +
 tests/models/test_dagrun.py       | 25 +++++++++++++++++++++++++
 tests/models/test_taskinstance.py | 14 +++++++++-----
 4 files changed, 37 insertions(+), 6 deletions(-)

diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index cedc254239..640446f8e6 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -354,7 +354,8 @@ def _run_raw_task(
         # run on_success_callback before db committing
         # otherwise, the LocalTaskJob sees the state is changed to `success`,
         # but the task_runner is still running, LocalTaskJob then treats the state is set externally!
-        _run_finished_callback(callbacks=ti.task.on_success_callback, context=context)
+        if ti.state == TaskInstanceState.SUCCESS:
+            _run_finished_callback(callbacks=ti.task.on_success_callback, context=context)
 
         if not test_mode:
             _add_log(event=ti.state, task_instance=ti, session=session)
diff --git a/newsfragments/40936.bugfix.rst b/newsfragments/40936.bugfix.rst
new file mode 100644
index 0000000000..207aeb9752
--- /dev/null
+++ b/newsfragments/40936.bugfix.rst
@@ -0,0 +1 @@
+Fix: ``on_success_callback`` will no longer execute if a task is skipped. Previously, this callback was triggered even when the task was skipped, which could lead to unintended behavior or inconsistencies in downstream processes. This is a breaking change because workflows that rely on ``on_success_callback`` running for skipped tasks will need to be updated. Consider updating your DAGs to handle cases where the callback is not invoked due to task skipping.
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 0117103dbb..3c94c098a2 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -477,6 +477,31 @@ class TestDagRun:
         # Callbacks are not added until handle_callback = False is passed to dag_run.update_state()
         assert callback is None
 
+    def test_on_success_callback_when_task_skipped(self, session):
+        mock_on_success = mock.MagicMock()
+        mock_on_success.__name__ = "mock_on_success"
+
+        dag = DAG(
+            dag_id="test_dagrun_update_state_with_handle_callback_success",
+            start_date=datetime.datetime(2017, 1, 1),
+            on_success_callback=mock_on_success,
+            schedule=datetime.timedelta(days=1),
+        )
+
+        _ = EmptyOperator(task_id="test_state_succeeded1", dag=dag)
+
+        initial_task_states = {
+            "test_state_succeeded1": TaskInstanceState.SKIPPED,
+        }
+
+        dag_run = self.create_dag_run(dag=dag, task_states=initial_task_states, session=session)
+        _, _ = dag_run.update_state(execute_callbacks=True)
+        task = dag_run.get_task_instances()[0]
+
+        assert task.state == TaskInstanceState.SKIPPED
+        assert DagRunState.SUCCESS == dag_run.state
+        mock_on_success.assert_called_once()
+
     def test_dagrun_update_state_with_handle_callback_success(self, session):
         def on_success_callable(context):
             assert context["dag_run"].dag_id == "test_dagrun_update_state_with_handle_callback_success"
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 86357bad8f..ed9599d08d 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -4042,22 +4042,26 @@ class TestTaskInstance:
         def raise_skip_exception():
             raise AirflowSkipException
 
-        callback_function = mock.MagicMock()
-        callback_function.__name__ = "callback_function"
+        on_skipped_callback_function = mock.MagicMock()
+        on_skipped_callback_function.__name__ = "on_skipped_callback_function"
+
+        on_success_callback_function = mock.MagicMock()
+        on_success_callback_function.__name__ = "on_success_callback_function"
 
         with dag_maker(dag_id="test_skipped_task", serialized=True):
             task = PythonOperator(
                 task_id="test_skipped_task",
                 python_callable=raise_skip_exception,
-                on_skipped_callback=callback_function,
+                on_skipped_callback=on_skipped_callback_function,
+                on_success_callback=on_success_callback_function,
             )
-
         dr = dag_maker.create_dagrun(execution_date=timezone.utcnow())
         ti = dr.task_instances[0]
         ti.task = task
         ti.run()
         assert State.SKIPPED == ti.state
-        assert callback_function.called
+        on_skipped_callback_function.assert_called_once()
+        on_success_callback_function.assert_not_called()
 
     def test_task_instance_history_is_created_when_ti_goes_for_retry(self, dag_maker, session):
         with dag_maker(serialized=True):

Reply via email to