This is an automated email from the ASF dual-hosted git repository.

husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e624b37fd0 Skip PythonVirtualenvOperator task when it returns a 
provided exit code (#30690)
e624b37fd0 is described below

commit e624b37fd04b7fe9c9c2c34e494c14f90a6aa5c1
Author: Hussein Awala <[email protected]>
AuthorDate: Tue Apr 18 01:40:29 2023 +0200

    Skip PythonVirtualenvOperator task when it returns a provided exit code 
(#30690)
    
    * Add a new argument to raise a skip exception when the python callable exits 
with the same value
    
    * add unit tests for skip_exit_code
---
 airflow/operators/python.py    | 22 +++++++++++++++++++---
 tests/operators/test_python.py | 34 ++++++++++++++++++++++++++++++++--
 2 files changed, 51 insertions(+), 5 deletions(-)

diff --git a/airflow/operators/python.py b/airflow/operators/python.py
index 8a3fa58123..6f565f3492 100644
--- a/airflow/operators/python.py
+++ b/airflow/operators/python.py
@@ -33,7 +33,12 @@ from typing import Any, Callable, Collection, Iterable, 
Mapping, Sequence
 
 import dill
 
-from airflow.exceptions import AirflowConfigException, AirflowException, 
RemovedInAirflow3Warning
+from airflow.exceptions import (
+    AirflowConfigException,
+    AirflowException,
+    AirflowSkipException,
+    RemovedInAirflow3Warning,
+)
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.skipmixin import SkipMixin
 from airflow.models.taskinstance import _CURRENT_CONTEXT
@@ -466,6 +471,9 @@ class 
PythonVirtualenvOperator(_BasePythonVirtualenvOperator):
     :param expect_airflow: expect Airflow to be installed in the target 
environment. If true, the operator
         will raise warning if Airflow is not installed, and it will attempt to 
load Airflow
         macros when starting.
+    :param skip_exit_code: If python_callable exits with this exit code, leave 
the task
+        in ``skipped`` state (default: None). If set to ``None``, any non-zero
+        exit code will be treated as a failure.
     """
 
     template_fields: Sequence[str] = tuple({"requirements"} | 
set(PythonOperator.template_fields))
@@ -486,6 +494,7 @@ class 
PythonVirtualenvOperator(_BasePythonVirtualenvOperator):
         templates_dict: dict | None = None,
         templates_exts: list[str] | None = None,
         expect_airflow: bool = True,
+        skip_exit_code: int | None = None,
         **kwargs,
     ):
         if (
@@ -509,6 +518,7 @@ class 
PythonVirtualenvOperator(_BasePythonVirtualenvOperator):
         self.python_version = python_version
         self.system_site_packages = system_site_packages
         self.pip_install_options = pip_install_options
+        self.skip_exit_code = skip_exit_code
         super().__init__(
             python_callable=python_callable,
             use_dill=use_dill,
@@ -544,8 +554,14 @@ class 
PythonVirtualenvOperator(_BasePythonVirtualenvOperator):
                 pip_install_options=self.pip_install_options,
             )
             python_path = tmp_path / "bin" / "python"
-
-            return self._execute_python_callable_in_subprocess(python_path, 
tmp_path)
+            try:
+                result = 
self._execute_python_callable_in_subprocess(python_path, tmp_path)
+            except subprocess.CalledProcessError as e:
+                if self.skip_exit_code and e.returncode == self.skip_exit_code:
+                    raise AirflowSkipException(f"Process exited with code 
{self.skip_exit_code}. Skipping.")
+                else:
+                    raise
+            return result
 
     def _iter_serializable_context_keys(self):
         yield from self.BASE_SERIALIZABLE_CONTEXT_KEYS
diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py
index 64d824f8f8..067e07d6c3 100644
--- a/tests/operators/test_python.py
+++ b/tests/operators/test_python.py
@@ -47,7 +47,7 @@ from airflow.utils import timezone
 from airflow.utils.context import AirflowContextDeprecationWarning, Context
 from airflow.utils.python_virtualenv import prepare_virtualenv
 from airflow.utils.session import create_session
-from airflow.utils.state import DagRunState, State
+from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.types import NOTSET, DagRunType
 from tests.test_utils import AIRFLOW_MAIN_FOLDER
@@ -131,10 +131,12 @@ class BasePythonTest:
         task.run(start_date=self.default_date, end_date=self.default_date)
         return task
 
-    def run_as_task(self, fn, **kwargs):
+    def run_as_task(self, fn, return_ti=False, **kwargs):
         """Create TaskInstance and run it."""
         ti = self.create_ti(fn, **kwargs)
         ti.run()
+        if return_ti:
+            return ti
         return ti.task
 
     def render_templates(self, fn, **kwargs):
@@ -932,6 +934,34 @@ class TestPythonVirtualenvOperator(BasePythonTest):
         }
         assert set(context) == declared_keys
 
+    @pytest.mark.parametrize(
+        "extra_kwargs, actual_exit_code, expected_state",
+        [
+            (None, 99, TaskInstanceState.FAILED),
+            ({"skip_exit_code": 100}, 100, TaskInstanceState.SKIPPED),
+            ({"skip_exit_code": 100}, 101, TaskInstanceState.FAILED),
+            ({"skip_exit_code": None}, 0, TaskInstanceState.SUCCESS),
+        ],
+    )
+    def test_skip_exit_code(self, extra_kwargs, actual_exit_code, 
expected_state):
+        def f(exit_code):
+            if exit_code != 0:
+                raise SystemExit(exit_code)
+
+        if expected_state == TaskInstanceState.FAILED:
+            with pytest.raises(CalledProcessError):
+                self.run_as_task(
+                    f, op_kwargs={"exit_code": actual_exit_code}, 
**(extra_kwargs if extra_kwargs else {})
+                )
+        else:
+            ti = self.run_as_task(
+                f,
+                return_ti=True,
+                op_kwargs={"exit_code": actual_exit_code},
+                **(extra_kwargs if extra_kwargs else {}),
+            )
+            assert ti.state == expected_state
+
 
 class TestCurrentContext:
     def test_current_context_no_context_raise(self):

Reply via email to