This is an automated email from the ASF dual-hosted git repository. ephraimanierobi pushed a commit to branch v2-8-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit b7b8fd763f1eb8e16cf7e6c735af613c6b41c4a2 Author: Aleksey Kirilishin <[email protected]> AuthorDate: Fri Jan 26 19:20:27 2024 +0300 Handle SystemExit raised in the task. (#36986) * Handle SystemExit raised in the task. * Add handling of system exit in tasks: * Exiting with a zero or None code signifies success, and the task does not return any value. * Exiting with other codes signifies an error. (cherry picked from commit 574d90f2178ea746960b0cb71b30b59d6b4fb668) --- airflow/models/taskinstance.py | 22 ++++++++++++++++++++-- tests/models/test_taskinstance.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index cbbc2b726c..dfc1a1d5cf 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -408,6 +408,17 @@ def _execute_task(task_instance, context, task_orig): execute_callable_kwargs["next_kwargs"] = task_instance.next_kwargs else: execute_callable = task_to_execute.execute + + def _execute_callable(context, **execute_callable_kwargs): + try: + return execute_callable(context=context, **execute_callable_kwargs) + except SystemExit as e: + # Handle only successful cases here. Failure cases will be handled upper + # in the exception chain. + if e.code is not None and e.code != 0: + raise + return None + # If a timeout is specified for the task, make it fail # if it goes beyond if task_to_execute.execution_timeout: @@ -425,12 +436,12 @@ def _execute_task(task_instance, context, task_orig): raise AirflowTaskTimeout() # Run task in timeout wrapper with timeout(timeout_seconds): - result = execute_callable(context=context, **execute_callable_kwargs) + result = _execute_callable(context=context, **execute_callable_kwargs) except AirflowTaskTimeout: task_to_execute.on_kill() raise else: - result = execute_callable(context=context, **execute_callable_kwargs) + result = _execute_callable(context=context, **execute_callable_kwargs) with create_session() as session: if task_to_execute.do_xcom_push: xcom_value = result @@ -2402,6 +2413,13 @@ class TaskInstance(Base, LoggingMixin): self.handle_failure(e, test_mode, context, session=session) session.commit() raise + except SystemExit as e: + # We have already handled SystemExit with success codes (0 and None) in the `_execute_task`. + # Therefore, here we must handle only error codes. + msg = f"Task failed due to SystemExit({e.code})" + self.handle_failure(msg, test_mode, context, session=session) + session.commit() + raise Exception(msg) finally: Stats.incr(f"ti.finish.{self.dag_id}.{self.task_id}.{self.state}", tags=self.stats_tags) # Same metric with tagging diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 27ce80df1a..016bd75eaf 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -2924,6 +2924,37 @@ class TestTaskInstance: ti.refresh_from_db() assert ti.state == State.SUCCESS + @pytest.mark.parametrize( + "code, expected_state", + [ + (1, State.FAILED), + (-1, State.FAILED), + ("error", State.FAILED), + (0, State.SUCCESS), + (None, State.SUCCESS), + ], + ) + def test_handle_system_exit(self, dag_maker, code, expected_state): + with dag_maker(): + + def f(*args, **kwargs): + exit(code) + + task = PythonOperator(task_id="mytask", python_callable=f) + + dr = dag_maker.create_dagrun() + ti = TI(task=task, run_id=dr.run_id) + ti.state = State.RUNNING + session = settings.Session() + session.merge(ti) + session.commit() + try: + ti._run_raw_task() + except Exception: + ... + ti.refresh_from_db() + assert ti.state == expected_state + def test_get_current_context_works_in_template(self, dag_maker): def user_defined_macro(): from airflow.operators.python import get_current_context
