This is an automated email from the ASF dual-hosted git repository. ephraimanierobi pushed a commit to branch v2-6-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 986409682ad007a28f134f9e4010c9abce6e94c0 Author: Jarek Potiuk <[email protected]> AuthorDate: Wed Apr 19 17:16:16 2023 +0200 Add skip_on_exit_code also to ExternalPythonOperator (#30738) Changes #30690 and #30692 added skip_on_exit_code to the PythonVirtualenvOperator, but skipped the closely related ExternalPythonOperator. This change brings the same functionality to ExternalPythonOperator by moving it to the base class shared by both operators. It also adds a separate test class for ExternalPythonOperator, introducing a common test base class and moving the test methods shared by both operators into it. (cherry picked from commit 5ed466958a755085a351e0a593ec705e001723c5) --- airflow/operators/python.py | 54 +++++---- tests/operators/test_python.py | 260 +++++++++++++++++++++-------------------- 2 files changed, 167 insertions(+), 147 deletions(-) diff --git a/airflow/operators/python.py b/airflow/operators/python.py index b3744c7c85..e9ce98fd9a 100644 --- a/airflow/operators/python.py +++ b/airflow/operators/python.py @@ -330,6 +330,7 @@ class _BasePythonVirtualenvOperator(PythonOperator, metaclass=ABCMeta): templates_dict: dict | None = None, templates_exts: list[str] | None = None, expect_airflow: bool = True, + skip_on_exit_code: int | Container[int] | None = None, **kwargs, ): if ( @@ -350,6 +351,13 @@ class _BasePythonVirtualenvOperator(PythonOperator, metaclass=ABCMeta): self.use_dill = use_dill self.pickling_library = dill if self.use_dill else pickle self.expect_airflow = expect_airflow + self.skip_on_exit_code = ( + skip_on_exit_code + if isinstance(skip_on_exit_code, Container) + else [skip_on_exit_code] + if skip_on_exit_code + else [] + ) @abstractmethod def _iter_serializable_context_keys(self): @@ -411,15 +419,22 @@ class _BasePythonVirtualenvOperator(PythonOperator, metaclass=ABCMeta): render_template_as_native_obj=self.dag.render_template_as_native_obj, ) - execute_in_subprocess( - cmd=[ - os.fspath(python_path), -
os.fspath(script_path), - os.fspath(input_path), - os.fspath(output_path), - os.fspath(string_args_path), - ] - ) + try: + execute_in_subprocess( + cmd=[ + os.fspath(python_path), + os.fspath(script_path), + os.fspath(input_path), + os.fspath(output_path), + os.fspath(string_args_path), + ] + ) + except subprocess.CalledProcessError as e: + if e.returncode in self.skip_on_exit_code: + raise AirflowSkipException(f"Process exited with code {e.returncode}. Skipping.") + else: + raise + return self._read_result(output_path) def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]: @@ -519,13 +534,6 @@ class PythonVirtualenvOperator(_BasePythonVirtualenvOperator): self.python_version = python_version self.system_site_packages = system_site_packages self.pip_install_options = pip_install_options - self.skip_on_exit_code = ( - skip_on_exit_code - if isinstance(skip_on_exit_code, Container) - else [skip_on_exit_code] - if skip_on_exit_code - else [] - ) super().__init__( python_callable=python_callable, use_dill=use_dill, @@ -535,6 +543,7 @@ class PythonVirtualenvOperator(_BasePythonVirtualenvOperator): templates_dict=templates_dict, templates_exts=templates_exts, expect_airflow=expect_airflow, + skip_on_exit_code=skip_on_exit_code, **kwargs, ) @@ -561,13 +570,7 @@ class PythonVirtualenvOperator(_BasePythonVirtualenvOperator): pip_install_options=self.pip_install_options, ) python_path = tmp_path / "bin" / "python" - try: - result = self._execute_python_callable_in_subprocess(python_path, tmp_path) - except subprocess.CalledProcessError as e: - if e.returncode in self.skip_on_exit_code: - raise AirflowSkipException(f"Process exited with code {e.returncode}. 
Skipping.") - else: - raise + result = self._execute_python_callable_in_subprocess(python_path, tmp_path) return result def _iter_serializable_context_keys(self): @@ -624,6 +627,9 @@ class ExternalPythonOperator(_BasePythonVirtualenvOperator): :param expect_airflow: expect Airflow to be installed in the target environment. If true, the operator will raise warning if Airflow is not installed, and it will attempt to load Airflow macros when starting. + :param skip_on_exit_code: If python_callable exits with this exit code, leave the task + in ``skipped`` state (default: None). If set to ``None``, any non-zero + exit code will be treated as a failure. """ template_fields: Sequence[str] = tuple({"python"} | set(PythonOperator.template_fields)) @@ -641,6 +647,7 @@ class ExternalPythonOperator(_BasePythonVirtualenvOperator): templates_exts: list[str] | None = None, expect_airflow: bool = True, expect_pendulum: bool = False, + skip_on_exit_code: int | Container[int] | None = None, **kwargs, ): if not python: @@ -656,6 +663,7 @@ class ExternalPythonOperator(_BasePythonVirtualenvOperator): templates_dict=templates_dict, templates_exts=templates_exts, expect_airflow=expect_airflow, + skip_on_exit_code=skip_on_exit_code, **kwargs, ) diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py index 32c3b30b64..5ff18746df 100644 --- a/tests/operators/test_python.py +++ b/tests/operators/test_python.py @@ -38,6 +38,7 @@ from airflow.models.taskinstance import clear_task_instances, set_current_contex from airflow.operators.empty import EmptyOperator from airflow.operators.python import ( BranchPythonOperator, + ExternalPythonOperator, PythonOperator, PythonVirtualenvOperator, ShortCircuitOperator, @@ -607,7 +608,135 @@ class TestShortCircuitOperator(BasePythonTest): virtualenv_string_args: list[str] = [] -class TestPythonVirtualenvOperator(BasePythonTest): +class BaseTestPythonVirtualenvOperator(BasePythonTest): + def test_template_fields(self): + assert 
set(PythonOperator.template_fields).issubset(PythonVirtualenvOperator.template_fields) + + def test_fail(self): + def f(): + raise Exception + + with pytest.raises(CalledProcessError): + self.run_as_task(f) + + def test_string_args(self): + def f(): + global virtualenv_string_args + print(virtualenv_string_args) + if virtualenv_string_args[0] != virtualenv_string_args[2]: + raise Exception + + self.run_as_task(f, string_args=[1, 2, 1]) + + def test_with_args(self): + def f(a, b, c=False, d=False): + if a == 0 and b == 1 and c and not d: + return True + else: + raise Exception + + self.run_as_task(f, op_args=[0, 1], op_kwargs={"c": True}) + + def test_return_none(self): + def f(): + return None + + task = self.run_as_task(f) + assert task.execute_callable() is None + + def test_return_false(self): + def f(): + return False + + task = self.run_as_task(f) + assert task.execute_callable() is False + + def test_lambda(self): + with pytest.raises(AirflowException): + PythonVirtualenvOperator(python_callable=lambda x: 4, task_id=self.task_id) + + def test_nonimported_as_arg(self): + def f(_): + return None + + self.run_as_task(f, op_args=[datetime.utcnow()]) + + def test_context(self): + def f(templates_dict): + return templates_dict["ds"] + + task = self.run_as_task(f, templates_dict={"ds": "{{ ds }}"}) + assert task.templates_dict == {"ds": self.ds_templated} + + def test_deepcopy(self): + """Test that PythonVirtualenvOperator are deep-copyable.""" + + def f(): + return 1 + + task = PythonVirtualenvOperator(python_callable=f, task_id="task") + copy.deepcopy(task) + + def test_virtualenv_serializable_context_fields(self, create_task_instance): + """Ensure all template context fields are listed in the operator. + + This exists mainly so when a field is added to the context, we remember to + also add it to PythonVirtualenvOperator. + """ + # These are intentionally NOT serialized into the virtual environment: + # * Variables pointing to the task instance itself. 
+ # * Variables that are accessor instances. + intentionally_excluded_context_keys = [ + "task_instance", + "ti", + "var", # Accessor for Variable; var->json and var->value. + "conn", # Accessor for Connection. + ] + + ti = create_task_instance(dag_id=self.dag_id, task_id=self.task_id, schedule=None) + context = ti.get_template_context() + + declared_keys = { + *PythonVirtualenvOperator.BASE_SERIALIZABLE_CONTEXT_KEYS, + *PythonVirtualenvOperator.PENDULUM_SERIALIZABLE_CONTEXT_KEYS, + *PythonVirtualenvOperator.AIRFLOW_SERIALIZABLE_CONTEXT_KEYS, + *intentionally_excluded_context_keys, + } + assert set(context) == declared_keys + + @pytest.mark.parametrize( + "extra_kwargs, actual_exit_code, expected_state", + [ + (None, 99, TaskInstanceState.FAILED), + ({"skip_on_exit_code": 100}, 100, TaskInstanceState.SKIPPED), + ({"skip_on_exit_code": [100]}, 100, TaskInstanceState.SKIPPED), + ({"skip_on_exit_code": (100, 101)}, 100, TaskInstanceState.SKIPPED), + ({"skip_on_exit_code": 100}, 101, TaskInstanceState.FAILED), + ({"skip_on_exit_code": [100, 102]}, 101, TaskInstanceState.FAILED), + ({"skip_on_exit_code": None}, 0, TaskInstanceState.SUCCESS), + ], + ) + def test_on_skip_exit_code(self, extra_kwargs, actual_exit_code, expected_state): + def f(exit_code): + if exit_code != 0: + raise SystemExit(exit_code) + + if expected_state == TaskInstanceState.FAILED: + with pytest.raises(CalledProcessError): + self.run_as_task( + f, op_kwargs={"exit_code": actual_exit_code}, **(extra_kwargs if extra_kwargs else {}) + ) + else: + ti = self.run_as_task( + f, + return_ti=True, + op_kwargs={"exit_code": actual_exit_code}, + **(extra_kwargs if extra_kwargs else {}), + ) + assert ti.state == expected_state + + +class TestPythonVirtualenvOperator(BaseTestPythonVirtualenvOperator): opcls = PythonVirtualenvOperator @staticmethod @@ -615,9 +744,6 @@ class TestPythonVirtualenvOperator(BasePythonTest): kwargs["python_version"] = python_version return kwargs - def test_template_fields(self): - 
assert set(PythonOperator.template_fields).issubset(PythonVirtualenvOperator.template_fields) - def test_add_dill(self): def f(): """Ensure dill is correctly installed.""" @@ -711,13 +837,6 @@ class TestPythonVirtualenvOperator(BasePythonTest): system_site_packages=False, ) - def test_fail(self): - def f(): - raise Exception - - with pytest.raises(CalledProcessError): - self.run_as_task(f) - def test_python_3(self): def f(): import sys @@ -737,55 +856,6 @@ class TestPythonVirtualenvOperator(BasePythonTest): self.run_as_task(f, system_site_packages=False, use_dill=False, op_args=[4]) - def test_string_args(self): - def f(): - global virtualenv_string_args - print(virtualenv_string_args) - if virtualenv_string_args[0] != virtualenv_string_args[2]: - raise Exception - - self.run_as_task(f, string_args=[1, 2, 1]) - - def test_with_args(self): - def f(a, b, c=False, d=False): - if a == 0 and b == 1 and c and not d: - return True - else: - raise Exception - - self.run_as_task(f, op_args=[0, 1], op_kwargs={"c": True}) - - def test_return_none(self): - def f(): - return None - - task = self.run_as_task(f) - assert task.execute_callable() is None - - def test_return_false(self): - def f(): - return False - - task = self.run_as_task(f) - assert task.execute_callable() is False - - def test_lambda(self): - with pytest.raises(AirflowException): - PythonVirtualenvOperator(python_callable=lambda x: 4, task_id=self.task_id) - - def test_nonimported_as_arg(self): - def f(_): - return None - - self.run_as_task(f, op_args=[datetime.utcnow()]) - - def test_context(self): - def f(templates_dict): - return templates_dict["ds"] - - task = self.run_as_task(f, templates_dict={"ds": "{{ ds }}"}) - assert task.templates_dict == {"ds": self.ds_templated} - # This tests might take longer than default 60 seconds as it is serializing a lot of # context using dill (which is slow apparently). 
@pytest.mark.execution_timeout(120) @@ -898,72 +968,14 @@ class TestPythonVirtualenvOperator(BasePythonTest): self.run_as_task(f, use_dill=True, system_site_packages=False, requirements=None) - def test_deepcopy(self): - """Test that PythonVirtualenvOperator are deep-copyable.""" - - def f(): - return 1 - - task = PythonVirtualenvOperator(python_callable=f, task_id="task") - copy.deepcopy(task) - - def test_virtualenv_serializable_context_fields(self, create_task_instance): - """Ensure all template context fields are listed in the operator. - - This exists mainly so when a field is added to the context, we remember to - also add it to PythonVirtualenvOperator. - """ - # These are intentionally NOT serialized into the virtual environment: - # * Variables pointing to the task instance itself. - # * Variables that are accessor instances. - intentionally_excluded_context_keys = [ - "task_instance", - "ti", - "var", # Accessor for Variable; var->json and var->value. - "conn", # Accessor for Connection. 
- ] - - ti = create_task_instance(dag_id=self.dag_id, task_id=self.task_id, schedule=None) - context = ti.get_template_context() - declared_keys = { - *PythonVirtualenvOperator.BASE_SERIALIZABLE_CONTEXT_KEYS, - *PythonVirtualenvOperator.PENDULUM_SERIALIZABLE_CONTEXT_KEYS, - *PythonVirtualenvOperator.AIRFLOW_SERIALIZABLE_CONTEXT_KEYS, - *intentionally_excluded_context_keys, - } - assert set(context) == declared_keys +class TestExternalPythonOperator(BaseTestPythonVirtualenvOperator): + opcls = ExternalPythonOperator - @pytest.mark.parametrize( - "extra_kwargs, actual_exit_code, expected_state", - [ - (None, 99, TaskInstanceState.FAILED), - ({"skip_on_exit_code": 100}, 100, TaskInstanceState.SKIPPED), - ({"skip_on_exit_code": [100]}, 100, TaskInstanceState.SKIPPED), - ({"skip_on_exit_code": (100, 101)}, 100, TaskInstanceState.SKIPPED), - ({"skip_on_exit_code": 100}, 101, TaskInstanceState.FAILED), - ({"skip_on_exit_code": [100, 102]}, 101, TaskInstanceState.FAILED), - ({"skip_on_exit_code": None}, 0, TaskInstanceState.SUCCESS), - ], - ) - def test_on_skip_exit_code(self, extra_kwargs, actual_exit_code, expected_state): - def f(exit_code): - if exit_code != 0: - raise SystemExit(exit_code) - - if expected_state == TaskInstanceState.FAILED: - with pytest.raises(CalledProcessError): - self.run_as_task( - f, op_kwargs={"exit_code": actual_exit_code}, **(extra_kwargs if extra_kwargs else {}) - ) - else: - ti = self.run_as_task( - f, - return_ti=True, - op_kwargs={"exit_code": actual_exit_code}, - **(extra_kwargs if extra_kwargs else {}), - ) - assert ti.state == expected_state + @staticmethod + def default_kwargs(*, python_version=sys.version_info[0], **kwargs): + kwargs["python"] = sys.executable + return kwargs class TestCurrentContext:
