This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-6-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 986409682ad007a28f134f9e4010c9abce6e94c0
Author: Jarek Potiuk <[email protected]>
AuthorDate: Wed Apr 19 17:16:16 2023 +0200

    Add skip_on_exit_code also to ExternalPythonOperator (#30738)
    
    The changes #30690 and #30692 added skip_on_exit_code to the
    PythonVirtualenvOperator, but they skipped the - very closely related
    - ExternalPythonOperator.
    
    This change brings the same functionality to ExternalPythonOperator
    by moving it to the base class for both operators. It also adds a
    separate test class for ExternalPythonOperator, introducing
    a common base test class and moving the test methods that are common
    to both operators there.
    
    (cherry picked from commit 5ed466958a755085a351e0a593ec705e001723c5)
---
 airflow/operators/python.py    |  54 +++++----
 tests/operators/test_python.py | 260 +++++++++++++++++++++--------------------
 2 files changed, 167 insertions(+), 147 deletions(-)

diff --git a/airflow/operators/python.py b/airflow/operators/python.py
index b3744c7c85..e9ce98fd9a 100644
--- a/airflow/operators/python.py
+++ b/airflow/operators/python.py
@@ -330,6 +330,7 @@ class _BasePythonVirtualenvOperator(PythonOperator, 
metaclass=ABCMeta):
         templates_dict: dict | None = None,
         templates_exts: list[str] | None = None,
         expect_airflow: bool = True,
+        skip_on_exit_code: int | Container[int] | None = None,
         **kwargs,
     ):
         if (
@@ -350,6 +351,13 @@ class _BasePythonVirtualenvOperator(PythonOperator, 
metaclass=ABCMeta):
         self.use_dill = use_dill
         self.pickling_library = dill if self.use_dill else pickle
         self.expect_airflow = expect_airflow
+        self.skip_on_exit_code = (
+            skip_on_exit_code
+            if isinstance(skip_on_exit_code, Container)
+            else [skip_on_exit_code]
+            if skip_on_exit_code
+            else []
+        )
 
     @abstractmethod
     def _iter_serializable_context_keys(self):
@@ -411,15 +419,22 @@ class _BasePythonVirtualenvOperator(PythonOperator, 
metaclass=ABCMeta):
             
render_template_as_native_obj=self.dag.render_template_as_native_obj,
         )
 
-        execute_in_subprocess(
-            cmd=[
-                os.fspath(python_path),
-                os.fspath(script_path),
-                os.fspath(input_path),
-                os.fspath(output_path),
-                os.fspath(string_args_path),
-            ]
-        )
+        try:
+            execute_in_subprocess(
+                cmd=[
+                    os.fspath(python_path),
+                    os.fspath(script_path),
+                    os.fspath(input_path),
+                    os.fspath(output_path),
+                    os.fspath(string_args_path),
+                ]
+            )
+        except subprocess.CalledProcessError as e:
+            if e.returncode in self.skip_on_exit_code:
+                raise AirflowSkipException(f"Process exited with code 
{e.returncode}. Skipping.")
+            else:
+                raise
+
         return self._read_result(output_path)
 
     def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, 
Any]:
@@ -519,13 +534,6 @@ class 
PythonVirtualenvOperator(_BasePythonVirtualenvOperator):
         self.python_version = python_version
         self.system_site_packages = system_site_packages
         self.pip_install_options = pip_install_options
-        self.skip_on_exit_code = (
-            skip_on_exit_code
-            if isinstance(skip_on_exit_code, Container)
-            else [skip_on_exit_code]
-            if skip_on_exit_code
-            else []
-        )
         super().__init__(
             python_callable=python_callable,
             use_dill=use_dill,
@@ -535,6 +543,7 @@ class 
PythonVirtualenvOperator(_BasePythonVirtualenvOperator):
             templates_dict=templates_dict,
             templates_exts=templates_exts,
             expect_airflow=expect_airflow,
+            skip_on_exit_code=skip_on_exit_code,
             **kwargs,
         )
 
@@ -561,13 +570,7 @@ class 
PythonVirtualenvOperator(_BasePythonVirtualenvOperator):
                 pip_install_options=self.pip_install_options,
             )
             python_path = tmp_path / "bin" / "python"
-            try:
-                result = 
self._execute_python_callable_in_subprocess(python_path, tmp_path)
-            except subprocess.CalledProcessError as e:
-                if e.returncode in self.skip_on_exit_code:
-                    raise AirflowSkipException(f"Process exited with code 
{e.returncode}. Skipping.")
-                else:
-                    raise
+            result = self._execute_python_callable_in_subprocess(python_path, 
tmp_path)
             return result
 
     def _iter_serializable_context_keys(self):
@@ -624,6 +627,9 @@ class ExternalPythonOperator(_BasePythonVirtualenvOperator):
     :param expect_airflow: expect Airflow to be installed in the target 
environment. If true, the operator
         will raise warning if Airflow is not installed, and it will attempt to 
load Airflow
         macros when starting.
+    :param skip_on_exit_code: If python_callable exits with this exit code, 
leave the task
+        in ``skipped`` state (default: None). If set to ``None``, any non-zero
+        exit code will be treated as a failure.
     """
 
     template_fields: Sequence[str] = tuple({"python"} | 
set(PythonOperator.template_fields))
@@ -641,6 +647,7 @@ class ExternalPythonOperator(_BasePythonVirtualenvOperator):
         templates_exts: list[str] | None = None,
         expect_airflow: bool = True,
         expect_pendulum: bool = False,
+        skip_on_exit_code: int | Container[int] | None = None,
         **kwargs,
     ):
         if not python:
@@ -656,6 +663,7 @@ class ExternalPythonOperator(_BasePythonVirtualenvOperator):
             templates_dict=templates_dict,
             templates_exts=templates_exts,
             expect_airflow=expect_airflow,
+            skip_on_exit_code=skip_on_exit_code,
             **kwargs,
         )
 
diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py
index 32c3b30b64..5ff18746df 100644
--- a/tests/operators/test_python.py
+++ b/tests/operators/test_python.py
@@ -38,6 +38,7 @@ from airflow.models.taskinstance import clear_task_instances, 
set_current_contex
 from airflow.operators.empty import EmptyOperator
 from airflow.operators.python import (
     BranchPythonOperator,
+    ExternalPythonOperator,
     PythonOperator,
     PythonVirtualenvOperator,
     ShortCircuitOperator,
@@ -607,7 +608,135 @@ class TestShortCircuitOperator(BasePythonTest):
 virtualenv_string_args: list[str] = []
 
 
-class TestPythonVirtualenvOperator(BasePythonTest):
+class BaseTestPythonVirtualenvOperator(BasePythonTest):
+    def test_template_fields(self):
+        assert 
set(PythonOperator.template_fields).issubset(PythonVirtualenvOperator.template_fields)
+
+    def test_fail(self):
+        def f():
+            raise Exception
+
+        with pytest.raises(CalledProcessError):
+            self.run_as_task(f)
+
+    def test_string_args(self):
+        def f():
+            global virtualenv_string_args
+            print(virtualenv_string_args)
+            if virtualenv_string_args[0] != virtualenv_string_args[2]:
+                raise Exception
+
+        self.run_as_task(f, string_args=[1, 2, 1])
+
+    def test_with_args(self):
+        def f(a, b, c=False, d=False):
+            if a == 0 and b == 1 and c and not d:
+                return True
+            else:
+                raise Exception
+
+        self.run_as_task(f, op_args=[0, 1], op_kwargs={"c": True})
+
+    def test_return_none(self):
+        def f():
+            return None
+
+        task = self.run_as_task(f)
+        assert task.execute_callable() is None
+
+    def test_return_false(self):
+        def f():
+            return False
+
+        task = self.run_as_task(f)
+        assert task.execute_callable() is False
+
+    def test_lambda(self):
+        with pytest.raises(AirflowException):
+            PythonVirtualenvOperator(python_callable=lambda x: 4, 
task_id=self.task_id)
+
+    def test_nonimported_as_arg(self):
+        def f(_):
+            return None
+
+        self.run_as_task(f, op_args=[datetime.utcnow()])
+
+    def test_context(self):
+        def f(templates_dict):
+            return templates_dict["ds"]
+
+        task = self.run_as_task(f, templates_dict={"ds": "{{ ds }}"})
+        assert task.templates_dict == {"ds": self.ds_templated}
+
+    def test_deepcopy(self):
+        """Test that PythonVirtualenvOperator are deep-copyable."""
+
+        def f():
+            return 1
+
+        task = PythonVirtualenvOperator(python_callable=f, task_id="task")
+        copy.deepcopy(task)
+
+    def test_virtualenv_serializable_context_fields(self, 
create_task_instance):
+        """Ensure all template context fields are listed in the operator.
+
+        This exists mainly so when a field is added to the context, we 
remember to
+        also add it to PythonVirtualenvOperator.
+        """
+        # These are intentionally NOT serialized into the virtual environment:
+        # * Variables pointing to the task instance itself.
+        # * Variables that are accessor instances.
+        intentionally_excluded_context_keys = [
+            "task_instance",
+            "ti",
+            "var",  # Accessor for Variable; var->json and var->value.
+            "conn",  # Accessor for Connection.
+        ]
+
+        ti = create_task_instance(dag_id=self.dag_id, task_id=self.task_id, 
schedule=None)
+        context = ti.get_template_context()
+
+        declared_keys = {
+            *PythonVirtualenvOperator.BASE_SERIALIZABLE_CONTEXT_KEYS,
+            *PythonVirtualenvOperator.PENDULUM_SERIALIZABLE_CONTEXT_KEYS,
+            *PythonVirtualenvOperator.AIRFLOW_SERIALIZABLE_CONTEXT_KEYS,
+            *intentionally_excluded_context_keys,
+        }
+        assert set(context) == declared_keys
+
+    @pytest.mark.parametrize(
+        "extra_kwargs, actual_exit_code, expected_state",
+        [
+            (None, 99, TaskInstanceState.FAILED),
+            ({"skip_on_exit_code": 100}, 100, TaskInstanceState.SKIPPED),
+            ({"skip_on_exit_code": [100]}, 100, TaskInstanceState.SKIPPED),
+            ({"skip_on_exit_code": (100, 101)}, 100, 
TaskInstanceState.SKIPPED),
+            ({"skip_on_exit_code": 100}, 101, TaskInstanceState.FAILED),
+            ({"skip_on_exit_code": [100, 102]}, 101, TaskInstanceState.FAILED),
+            ({"skip_on_exit_code": None}, 0, TaskInstanceState.SUCCESS),
+        ],
+    )
+    def test_on_skip_exit_code(self, extra_kwargs, actual_exit_code, 
expected_state):
+        def f(exit_code):
+            if exit_code != 0:
+                raise SystemExit(exit_code)
+
+        if expected_state == TaskInstanceState.FAILED:
+            with pytest.raises(CalledProcessError):
+                self.run_as_task(
+                    f, op_kwargs={"exit_code": actual_exit_code}, 
**(extra_kwargs if extra_kwargs else {})
+                )
+        else:
+            ti = self.run_as_task(
+                f,
+                return_ti=True,
+                op_kwargs={"exit_code": actual_exit_code},
+                **(extra_kwargs if extra_kwargs else {}),
+            )
+            assert ti.state == expected_state
+
+
+class TestPythonVirtualenvOperator(BaseTestPythonVirtualenvOperator):
     opcls = PythonVirtualenvOperator
 
     @staticmethod
@@ -615,9 +744,6 @@ class TestPythonVirtualenvOperator(BasePythonTest):
         kwargs["python_version"] = python_version
         return kwargs
 
-    def test_template_fields(self):
-        assert 
set(PythonOperator.template_fields).issubset(PythonVirtualenvOperator.template_fields)
-
     def test_add_dill(self):
         def f():
             """Ensure dill is correctly installed."""
@@ -711,13 +837,6 @@ class TestPythonVirtualenvOperator(BasePythonTest):
             system_site_packages=False,
         )
 
-    def test_fail(self):
-        def f():
-            raise Exception
-
-        with pytest.raises(CalledProcessError):
-            self.run_as_task(f)
-
     def test_python_3(self):
         def f():
             import sys
@@ -737,55 +856,6 @@ class TestPythonVirtualenvOperator(BasePythonTest):
 
         self.run_as_task(f, system_site_packages=False, use_dill=False, 
op_args=[4])
 
-    def test_string_args(self):
-        def f():
-            global virtualenv_string_args
-            print(virtualenv_string_args)
-            if virtualenv_string_args[0] != virtualenv_string_args[2]:
-                raise Exception
-
-        self.run_as_task(f, string_args=[1, 2, 1])
-
-    def test_with_args(self):
-        def f(a, b, c=False, d=False):
-            if a == 0 and b == 1 and c and not d:
-                return True
-            else:
-                raise Exception
-
-        self.run_as_task(f, op_args=[0, 1], op_kwargs={"c": True})
-
-    def test_return_none(self):
-        def f():
-            return None
-
-        task = self.run_as_task(f)
-        assert task.execute_callable() is None
-
-    def test_return_false(self):
-        def f():
-            return False
-
-        task = self.run_as_task(f)
-        assert task.execute_callable() is False
-
-    def test_lambda(self):
-        with pytest.raises(AirflowException):
-            PythonVirtualenvOperator(python_callable=lambda x: 4, 
task_id=self.task_id)
-
-    def test_nonimported_as_arg(self):
-        def f(_):
-            return None
-
-        self.run_as_task(f, op_args=[datetime.utcnow()])
-
-    def test_context(self):
-        def f(templates_dict):
-            return templates_dict["ds"]
-
-        task = self.run_as_task(f, templates_dict={"ds": "{{ ds }}"})
-        assert task.templates_dict == {"ds": self.ds_templated}
-
     # This tests might take longer than default 60 seconds as it is 
serializing a lot of
     # context using dill (which is slow apparently).
     @pytest.mark.execution_timeout(120)
@@ -898,72 +968,14 @@ class TestPythonVirtualenvOperator(BasePythonTest):
 
         self.run_as_task(f, use_dill=True, system_site_packages=False, 
requirements=None)
 
-    def test_deepcopy(self):
-        """Test that PythonVirtualenvOperator are deep-copyable."""
-
-        def f():
-            return 1
-
-        task = PythonVirtualenvOperator(python_callable=f, task_id="task")
-        copy.deepcopy(task)
-
-    def test_virtualenv_serializable_context_fields(self, 
create_task_instance):
-        """Ensure all template context fields are listed in the operator.
-
-        This exists mainly so when a field is added to the context, we 
remember to
-        also add it to PythonVirtualenvOperator.
-        """
-        # These are intentionally NOT serialized into the virtual environment:
-        # * Variables pointing to the task instance itself.
-        # * Variables that are accessor instances.
-        intentionally_excluded_context_keys = [
-            "task_instance",
-            "ti",
-            "var",  # Accessor for Variable; var->json and var->value.
-            "conn",  # Accessor for Connection.
-        ]
-
-        ti = create_task_instance(dag_id=self.dag_id, task_id=self.task_id, 
schedule=None)
-        context = ti.get_template_context()
 
-        declared_keys = {
-            *PythonVirtualenvOperator.BASE_SERIALIZABLE_CONTEXT_KEYS,
-            *PythonVirtualenvOperator.PENDULUM_SERIALIZABLE_CONTEXT_KEYS,
-            *PythonVirtualenvOperator.AIRFLOW_SERIALIZABLE_CONTEXT_KEYS,
-            *intentionally_excluded_context_keys,
-        }
-        assert set(context) == declared_keys
+class TestExternalPythonOperator(BaseTestPythonVirtualenvOperator):
+    opcls = ExternalPythonOperator
 
-    @pytest.mark.parametrize(
-        "extra_kwargs, actual_exit_code, expected_state",
-        [
-            (None, 99, TaskInstanceState.FAILED),
-            ({"skip_on_exit_code": 100}, 100, TaskInstanceState.SKIPPED),
-            ({"skip_on_exit_code": [100]}, 100, TaskInstanceState.SKIPPED),
-            ({"skip_on_exit_code": (100, 101)}, 100, 
TaskInstanceState.SKIPPED),
-            ({"skip_on_exit_code": 100}, 101, TaskInstanceState.FAILED),
-            ({"skip_on_exit_code": [100, 102]}, 101, TaskInstanceState.FAILED),
-            ({"skip_on_exit_code": None}, 0, TaskInstanceState.SUCCESS),
-        ],
-    )
-    def test_on_skip_exit_code(self, extra_kwargs, actual_exit_code, 
expected_state):
-        def f(exit_code):
-            if exit_code != 0:
-                raise SystemExit(exit_code)
-
-        if expected_state == TaskInstanceState.FAILED:
-            with pytest.raises(CalledProcessError):
-                self.run_as_task(
-                    f, op_kwargs={"exit_code": actual_exit_code}, 
**(extra_kwargs if extra_kwargs else {})
-                )
-        else:
-            ti = self.run_as_task(
-                f,
-                return_ti=True,
-                op_kwargs={"exit_code": actual_exit_code},
-                **(extra_kwargs if extra_kwargs else {}),
-            )
-            assert ti.state == expected_state
+    @staticmethod
+    def default_kwargs(*, python_version=sys.version_info[0], **kwargs):
+        kwargs["python"] = sys.executable
+        return kwargs
 
 
 class TestCurrentContext:

Reply via email to