This is an automated email from the ASF dual-hosted git repository.

pierrejeambrun pushed a commit to branch v2-5-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 53e956c98ef8b8a14f73077729b787a84f8335ab
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Fri Mar 17 11:52:08 2023 +0100

    Revert fix for on_failure_callback when task receives a SIGTERM (#30165)
    
    * Revert fix for on_failure_callback when task receives a SIGTERM
    
    As the code comment on exception handling notes, callbacks are not run
    when a task is killed externally, so the above fix was made in error.
    Here is the comment in the code:
    
     for case when task is marked as success/failed externally
     or dagrun timed out and task is marked as skipped
     current behavior doesn't hit the callbacks
    
    
https://github.com/apache/airflow/blob/b65dbaaf3f21ea5396da121bbfa7f895d0ab8516/airflow/models/taskinstance.py#L1468-L1470
    
    * Update tests/models/test_taskinstance.py
    
    Co-authored-by: Emil Ejbyfeldt <[email protected]>
    
    ---------
    
    Co-authored-by: Emil Ejbyfeldt <[email protected]>
    (cherry picked from commit 869c1e3581fa163bbaad11a2d5ddaf8cf433296d)
---
 airflow/exceptions.py             |  6 ------
 airflow/models/taskinstance.py    | 16 +++++-----------
 tests/models/test_taskinstance.py | 32 ++++++++++++++++----------------
 3 files changed, 21 insertions(+), 33 deletions(-)

diff --git a/airflow/exceptions.py b/airflow/exceptions.py
index e6ef9bd4e1..4bf946fd8e 100644
--- a/airflow/exceptions.py
+++ b/airflow/exceptions.py
@@ -29,12 +29,6 @@ if TYPE_CHECKING:
     from airflow.models import DagRun
 
 
-class AirflowTermSignal(Exception):
-    """Raise when we receive a TERM signal"""
-
-    status_code = HTTPStatus.INTERNAL_SERVER_ERROR
-
-
 class AirflowException(Exception):
     """
     Base class for all Airflow's errors.
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 03ab48d46a..7f4c91ddb8 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -75,7 +75,6 @@ from airflow.exceptions import (
     AirflowSensorTimeout,
     AirflowSkipException,
     AirflowTaskTimeout,
-    AirflowTermSignal,
     DagRunNotFound,
     RemovedInAirflow3Warning,
     TaskDeferralError,
@@ -1487,7 +1486,8 @@ class TaskInstance(Base, LoggingMixin):
                 os._exit(1)
                 return
             self.log.error("Received SIGTERM. Terminating subprocesses.")
-            raise AirflowTermSignal("Task received SIGTERM signal")
+            self.task.on_kill()
+            raise AirflowException("Task received SIGTERM signal")
 
         signal.signal(signal.SIGTERM, signal_handler)
 
@@ -1526,15 +1526,9 @@ class TaskInstance(Base, LoggingMixin):
 
             # Execute the task
             with set_current_context(context):
-                try:
-                    result = self._execute_task(context, task_orig)
-                    # Run post_execute callback
-                    self.task.post_execute(context=context, result=result)
-                except AirflowTermSignal:
-                    self.task.on_kill()
-                    if self.task.on_failure_callback:
-                        
self._run_finished_callback(self.task.on_failure_callback, context, 
"on_failure")
-                    raise AirflowException("Task received SIGTERM signal")
+                result = self._execute_task(context, task_orig)
+            # Run post_execute callback
+            self.task.post_execute(context=context, result=result)
 
         Stats.incr(f"operator_successes_{self.task.task_type}", 1, 1)
         Stats.incr("ti_successes")
diff --git a/tests/models/test_taskinstance.py 
b/tests/models/test_taskinstance.py
index d54bfddc8d..96a8f51a2e 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -447,51 +447,51 @@ class TestTaskInstance:
         ti.run()
         assert State.SKIPPED == ti.state
 
-    def test_task_sigterm_works_with_retries(self, dag_maker):
+    def test_task_sigterm_calls_on_failure_callback(self, dag_maker, caplog):
         """
-        Test that ensures that tasks are retried when they receive sigterm
+        Test that ensures that tasks call on_failure_callback when they 
receive sigterm
         """
 
         def task_function(ti):
             os.kill(ti.pid, signal.SIGTERM)
 
-        with dag_maker("test_mark_failure_2"):
-            task = PythonOperator(
+        with dag_maker():
+            task_ = PythonOperator(
                 task_id="test_on_failure",
                 python_callable=task_function,
-                retries=1,
-                retry_delay=datetime.timedelta(seconds=2),
+                on_failure_callback=lambda context: 
context["ti"].log.info("on_failure_callback called"),
             )
 
         dr = dag_maker.create_dagrun()
         ti = dr.task_instances[0]
-        ti.task = task
+        ti.task = task_
         with pytest.raises(AirflowException):
             ti.run()
-        ti.refresh_from_db()
-        assert ti.state == State.UP_FOR_RETRY
+        assert "on_failure_callback called" in caplog.text
 
-    def test_task_sigterm_calls_on_failure_callack(self, dag_maker, caplog):
+    def test_task_sigterm_works_with_retries(self, dag_maker):
         """
-        Test that ensures that tasks call on_failure_callback when they 
receive sigterm
+        Test that ensures that tasks are retried when they receive sigterm
         """
 
         def task_function(ti):
             os.kill(ti.pid, signal.SIGTERM)
 
-        with dag_maker():
-            task_ = PythonOperator(
+        with dag_maker("test_mark_failure_2"):
+            task = PythonOperator(
                 task_id="test_on_failure",
                 python_callable=task_function,
-                on_failure_callback=lambda context: 
context["ti"].log.info("on_failure_callback called"),
+                retries=1,
+                retry_delay=datetime.timedelta(seconds=2),
             )
 
         dr = dag_maker.create_dagrun()
         ti = dr.task_instances[0]
-        ti.task = task_
+        ti.task = task
         with pytest.raises(AirflowException):
             ti.run()
-        assert "on_failure_callback called" in caplog.text
+        ti.refresh_from_db()
+        assert ti.state == State.UP_FOR_RETRY
 
     @pytest.mark.parametrize("state", [State.SUCCESS, State.FAILED, 
State.SKIPPED])
     def test_task_sigterm_doesnt_change_state_of_finished_tasks(self, state, 
dag_maker):

Reply via email to