This is an automated email from the ASF dual-hosted git repository. ephraimanierobi pushed a commit to branch v2-6-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit e41c549b5c3f420767eb11b66be570bb894a7490 Author: Hussein Awala <[email protected]> AuthorDate: Tue Apr 18 01:40:29 2023 +0200 Skip PythonVirtualenvOperator task when it returns a provided exit code (#30690) * Add a new argument to raise a skip exception when the python callable exits with the same value * add unit tests for skip_exit_code (cherry picked from commit e624b37fd04b7fe9c9c2c34e494c14f90a6aa5c1) --- airflow/operators/python.py | 22 +++++++++++++++++++--- tests/operators/test_python.py | 34 ++++++++++++++++++++++++++++++++-- 2 files changed, 51 insertions(+), 5 deletions(-) diff --git a/airflow/operators/python.py b/airflow/operators/python.py index 8a3fa58123..6f565f3492 100644 --- a/airflow/operators/python.py +++ b/airflow/operators/python.py @@ -33,7 +33,12 @@ from typing import Any, Callable, Collection, Iterable, Mapping, Sequence import dill -from airflow.exceptions import AirflowConfigException, AirflowException, RemovedInAirflow3Warning +from airflow.exceptions import ( + AirflowConfigException, + AirflowException, + AirflowSkipException, + RemovedInAirflow3Warning, +) from airflow.models.baseoperator import BaseOperator from airflow.models.skipmixin import SkipMixin from airflow.models.taskinstance import _CURRENT_CONTEXT @@ -466,6 +471,9 @@ class PythonVirtualenvOperator(_BasePythonVirtualenvOperator): :param expect_airflow: expect Airflow to be installed in the target environment. If true, the operator will raise warning if Airflow is not installed, and it will attempt to load Airflow macros when starting. + :param skip_exit_code: If python_callable exits with this exit code, leave the task + in ``skipped`` state (default: None). If set to ``None``, any non-zero + exit code will be treated as a failure.
""" template_fields: Sequence[str] = tuple({"requirements"} | set(PythonOperator.template_fields)) @@ -486,6 +494,7 @@ class PythonVirtualenvOperator(_BasePythonVirtualenvOperator): templates_dict: dict | None = None, templates_exts: list[str] | None = None, expect_airflow: bool = True, + skip_exit_code: int | None = None, **kwargs, ): if ( @@ -509,6 +518,7 @@ class PythonVirtualenvOperator(_BasePythonVirtualenvOperator): self.python_version = python_version self.system_site_packages = system_site_packages self.pip_install_options = pip_install_options + self.skip_exit_code = skip_exit_code super().__init__( python_callable=python_callable, use_dill=use_dill, @@ -544,8 +554,14 @@ class PythonVirtualenvOperator(_BasePythonVirtualenvOperator): pip_install_options=self.pip_install_options, ) python_path = tmp_path / "bin" / "python" - - return self._execute_python_callable_in_subprocess(python_path, tmp_path) + try: + result = self._execute_python_callable_in_subprocess(python_path, tmp_path) + except subprocess.CalledProcessError as e: + if self.skip_exit_code and e.returncode == self.skip_exit_code: + raise AirflowSkipException(f"Process exited with code {self.skip_exit_code}. 
Skipping.") + else: + raise + return result def _iter_serializable_context_keys(self): yield from self.BASE_SERIALIZABLE_CONTEXT_KEYS diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py index 64d824f8f8..067e07d6c3 100644 --- a/tests/operators/test_python.py +++ b/tests/operators/test_python.py @@ -47,7 +47,7 @@ from airflow.utils import timezone from airflow.utils.context import AirflowContextDeprecationWarning, Context from airflow.utils.python_virtualenv import prepare_virtualenv from airflow.utils.session import create_session -from airflow.utils.state import DagRunState, State +from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.trigger_rule import TriggerRule from airflow.utils.types import NOTSET, DagRunType from tests.test_utils import AIRFLOW_MAIN_FOLDER @@ -131,10 +131,12 @@ class BasePythonTest: task.run(start_date=self.default_date, end_date=self.default_date) return task - def run_as_task(self, fn, **kwargs): + def run_as_task(self, fn, return_ti=False, **kwargs): """Create TaskInstance and run it.""" ti = self.create_ti(fn, **kwargs) ti.run() + if return_ti: + return ti return ti.task def render_templates(self, fn, **kwargs): @@ -932,6 +934,34 @@ class TestPythonVirtualenvOperator(BasePythonTest): } assert set(context) == declared_keys + @pytest.mark.parametrize( + "extra_kwargs, actual_exit_code, expected_state", + [ + (None, 99, TaskInstanceState.FAILED), + ({"skip_exit_code": 100}, 100, TaskInstanceState.SKIPPED), + ({"skip_exit_code": 100}, 101, TaskInstanceState.FAILED), + ({"skip_exit_code": None}, 0, TaskInstanceState.SUCCESS), + ], + ) + def test_skip_exit_code(self, extra_kwargs, actual_exit_code, expected_state): + def f(exit_code): + if exit_code != 0: + raise SystemExit(exit_code) + + if expected_state == TaskInstanceState.FAILED: + with pytest.raises(CalledProcessError): + self.run_as_task( + f, op_kwargs={"exit_code": actual_exit_code}, **(extra_kwargs if extra_kwargs else {}) 
+ ) + else: + ti = self.run_as_task( + f, + return_ti=True, + op_kwargs={"exit_code": actual_exit_code}, + **(extra_kwargs if extra_kwargs else {}), + ) + assert ti.state == expected_state + class TestCurrentContext: def test_current_context_no_context_raise(self):
