This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-8-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit b7b8fd763f1eb8e16cf7e6c735af613c6b41c4a2
Author: Aleksey Kirilishin <[email protected]>
AuthorDate: Fri Jan 26 19:20:27 2024 +0300

    Handle SystemExit raised in the task. (#36986)
    
    * Handle SystemExit raised in the task.
    
    * Add handling of system exit in tasks:
    * Exiting with a zero or None code signifies success, and the task does not return any value.
    * Exiting with other codes signifies an error.
    
    (cherry picked from commit 574d90f2178ea746960b0cb71b30b59d6b4fb668)
---
 airflow/models/taskinstance.py    | 22 ++++++++++++++++++++--
 tests/models/test_taskinstance.py | 31 +++++++++++++++++++++++++++++++
 2 files changed, 51 insertions(+), 2 deletions(-)

diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index cbbc2b726c..dfc1a1d5cf 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -408,6 +408,17 @@ def _execute_task(task_instance, context, task_orig):
             execute_callable_kwargs["next_kwargs"] = task_instance.next_kwargs
     else:
         execute_callable = task_to_execute.execute
+
+    def _execute_callable(context, **execute_callable_kwargs):
+        try:
+            return execute_callable(context=context, **execute_callable_kwargs)
+        except SystemExit as e:
+            # Handle only successful cases here. Failure cases will be handled upper
+            # in the exception chain.
+            if e.code is not None and e.code != 0:
+                raise
+            return None
+
     # If a timeout is specified for the task, make it fail
     # if it goes beyond
     if task_to_execute.execution_timeout:
@@ -425,12 +436,12 @@ def _execute_task(task_instance, context, task_orig):
                 raise AirflowTaskTimeout()
             # Run task in timeout wrapper
             with timeout(timeout_seconds):
-                result = execute_callable(context=context, **execute_callable_kwargs)
+                result = _execute_callable(context=context, **execute_callable_kwargs)
         except AirflowTaskTimeout:
             task_to_execute.on_kill()
             raise
     else:
-        result = execute_callable(context=context, **execute_callable_kwargs)
+        result = _execute_callable(context=context, **execute_callable_kwargs)
     with create_session() as session:
         if task_to_execute.do_xcom_push:
             xcom_value = result
@@ -2402,6 +2413,13 @@ class TaskInstance(Base, LoggingMixin):
                 self.handle_failure(e, test_mode, context, session=session)
                 session.commit()
                 raise
+            except SystemExit as e:
+                # We have already handled SystemExit with success codes (0 and None) in the `_execute_task`.
+                msg = f"Task failed due to SystemExit({e.code})"
+                self.handle_failure(msg, test_mode, context, session=session)
+                session.commit()
+                raise Exception(msg)
             finally:
                 Stats.incr(f"ti.finish.{self.dag_id}.{self.task_id}.{self.state}", tags=self.stats_tags)
                 # Same metric with tagging
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 27ce80df1a..016bd75eaf 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -2924,6 +2924,37 @@ class TestTaskInstance:
         ti.refresh_from_db()
         assert ti.state == State.SUCCESS
 
+    @pytest.mark.parametrize(
+        "code, expected_state",
+        [
+            (1, State.FAILED),
+            (-1, State.FAILED),
+            ("error", State.FAILED),
+            (0, State.SUCCESS),
+            (None, State.SUCCESS),
+        ],
+    )
+    def test_handle_system_exit(self, dag_maker, code, expected_state):
+        with dag_maker():
+
+            def f(*args, **kwargs):
+                exit(code)
+
+            task = PythonOperator(task_id="mytask", python_callable=f)
+
+        dr = dag_maker.create_dagrun()
+        ti = TI(task=task, run_id=dr.run_id)
+        ti.state = State.RUNNING
+        session = settings.Session()
+        session.merge(ti)
+        session.commit()
+        try:
+            ti._run_raw_task()
+        except Exception:
+            ...
+        ti.refresh_from_db()
+        assert ti.state == expected_state
+
     def test_get_current_context_works_in_template(self, dag_maker):
         def user_defined_macro():
             from airflow.operators.python import get_current_context

Reply via email to