This is an automated email from the ASF dual-hosted git repository. feluelle pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push: new 3f2eee1 Fix PythonVirtualenvOperator not working with Airflow context (#9394) 3f2eee1 is described below commit 3f2eee15f9278d4c187ed5cde2772bc26dffdcbc Author: Felix Uellendall <felue...@users.noreply.github.com> AuthorDate: Thu Jul 30 10:32:52 2020 +0200 Fix PythonVirtualenvOperator not working with Airflow context (#9394) - automatically add dill requirement if use_dill=True - add howto docs - refactor Co-authored-by: Luis Magana <maganal...@users.noreply.github.com> --- airflow/example_dags/example_python_operator.py | 2 + airflow/operators/python.py | 213 ++++++++++++------------ airflow/utils/python_virtualenv.py | 22 +++ airflow/utils/python_virtualenv_script.jinja2 | 42 +++++ docs/howto/operator/python.rst | 31 ++++ tests/operators/test_python.py | 157 ++++++++++++++++- 6 files changed, 352 insertions(+), 115 deletions(-) diff --git a/airflow/example_dags/example_python_operator.py b/airflow/example_dags/example_python_operator.py index 5476bd5..5b6d7b5 100644 --- a/airflow/example_dags/example_python_operator.py +++ b/airflow/example_dags/example_python_operator.py @@ -72,6 +72,7 @@ for i in range(5): # [END howto_operator_python_kwargs] +# [START howto_operator_python_venv] def callable_virtualenv(): """ Example function that will be performed in a virtual environment. @@ -101,3 +102,4 @@ virtualenv_task = PythonVirtualenvOperator( system_site_packages=False, dag=dag, ) +# [END howto_operator_python_venv] diff --git a/airflow/operators/python.py b/airflow/operators/python.py index 92cb2bb..05f2f0b 100644 --- a/airflow/operators/python.py +++ b/airflow/operators/python.py @@ -26,7 +26,7 @@ from inspect import signature from itertools import islice from tempfile import TemporaryDirectory from textwrap import dedent -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, cast +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union, cast import dill @@ -38,7 +38,7 @@ from airflow.models.taskinstance import _CURRENT_CONTEXT from airflow.models.xcom_arg import XComArg from airflow.utils.decorators import apply_defaults from airflow.utils.process_utils import execute_in_subprocess -from airflow.utils.python_virtualenv import prepare_virtualenv +from airflow.utils.python_virtualenv import prepare_virtualenv, write_python_script class PythonOperator(BaseOperator): @@ -363,6 +363,10 @@ class PythonVirtualenvOperator(PythonOperator): Note that if your virtualenv runs in a different Python major version than Airflow, you cannot use return values, op_args, or op_kwargs. You can use string_args though. + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:PythonVirtualenvOperator` + :param python_callable: A python function with no references to outside variables, defined with def, which will be run in a virtualenv :type python_callable: function @@ -370,7 +374,7 @@ class PythonVirtualenvOperator(PythonOperator): :type requirements: list[str] :param python_version: The Python version to run the virtualenv with. Note that both 2 and 2.7 are acceptable forms. - :type python_version: str + :type python_version: Optional[Union[str, int, float]] :param use_dill: Whether to use dill to serialize the args and result (pickle is default). This allow more complex types but requires you to include dill in your requirements. @@ -397,13 +401,48 @@ class PythonVirtualenvOperator(PythonOperator): :type templates_exts: list[str] """ + BASE_SERIALIZABLE_CONTEXT_KEYS = { + 'ds_nodash', + 'inlets', + 'next_ds', + 'next_ds_nodash', + 'outlets', + 'params', + 'prev_ds', + 'prev_ds_nodash', + 'run_id', + 'task_instance_key_str', + 'test_mode', + 'tomorrow_ds', + 'tomorrow_ds_nodash', + 'ts', + 'ts_nodash', + 'ts_nodash_with_tz', + 'yesterday_ds', + 'yesterday_ds_nodash' + } + PENDULUM_SERIALIZABLE_CONTEXT_KEYS = { + 'execution_date', + 'next_execution_date', + 'prev_execution_date', + 'prev_execution_date_success', + 'prev_start_date_success' + } + AIRFLOW_SERIALIZABLE_CONTEXT_KEYS = { + 'macros', + 'conf', + 'dag', + 'dag_run', + 'task' + } + @apply_defaults def __init__( # pylint: disable=too-many-arguments self, *, python_callable: Callable, requirements: Optional[Iterable[str]] = None, - python_version: Optional[str] = None, + python_version: Optional[Union[str, int, float]] = None, use_dill: bool = False, system_site_packages: bool = True, op_args: Optional[List] = None, @@ -413,6 +452,17 @@ class PythonVirtualenvOperator(PythonOperator): templates_exts: Optional[List[str]] = None, **kwargs ): + if ( + not isinstance(python_callable, types.FunctionType) or + isinstance(python_callable, types.LambdaType) and python_callable.__name__ == "<lambda>" + ): + raise AirflowException('PythonVirtualenvOperator only supports functions for python_callable arg') + if ( + python_version and str(python_version)[0] != str(sys.version_info.major) and + (op_args or op_kwargs) + ): + raise AirflowException("Passing op_args or op_kwargs is not supported across different Python " + "major versions for PythonVirtualenvOperator. Please use string_args.") super().__init__( python_callable=python_callable, op_args=op_args, @@ -420,144 +470,93 @@ class PythonVirtualenvOperator(PythonOperator): templates_dict=templates_dict, templates_exts=templates_exts, **kwargs) - self.requirements = requirements or [] + self.requirements = list(requirements or []) self.string_args = string_args or [] self.python_version = python_version self.use_dill = use_dill self.system_site_packages = system_site_packages - # check that dill is present if needed - dill_in_requirements = map(lambda x: x.lower().startswith('dill'), - self.requirements) - if (not system_site_packages) and use_dill and not any(dill_in_requirements): - raise AirflowException('If using dill, dill must be in the environment ' + - 'either via system_site_packages or requirements') - # check that a function is passed, and that it is not a lambda - if (not isinstance(self.python_callable, - types.FunctionType) or (self.python_callable.__name__ == - (lambda x: 0).__name__)): - raise AirflowException('{} only supports functions for python_callable arg'.format( - self.__class__.__name__)) - # check that args are passed iff python major version matches - if (python_version is not None and - str(python_version)[0] != str(sys.version_info[0]) and - self._pass_op_args()): - raise AirflowException("Passing op_args or op_kwargs is not supported across " - "different Python major versions " - "for PythonVirtualenvOperator. " - "Please use string_args.") + if not self.system_site_packages and self.use_dill and 'dill' not in self.requirements: + self.requirements.append('dill') + self.pickling_library = dill if self.use_dill else pickle + + def execute(self, context: Dict): + serializable_context = {key: context[key] for key in self._get_serializable_context_keys()} + super().execute(context=serializable_context) def execute_callable(self): with TemporaryDirectory(prefix='venv') as tmp_dir: if self.templates_dict: self.op_kwargs['templates_dict'] = self.templates_dict - # generate filenames + input_filename = os.path.join(tmp_dir, 'script.in') output_filename = os.path.join(tmp_dir, 'script.out') string_args_filename = os.path.join(tmp_dir, 'string_args.txt') script_filename = os.path.join(tmp_dir, 'script.py') - # set up virtualenv - python_bin = 'python' + str(self.python_version) if self.python_version else None prepare_virtualenv( venv_directory=tmp_dir, - python_bin=python_bin, + python_bin=f'python{self.python_version}' if self.python_version else None, system_site_packages=self.system_site_packages, - requirements=self.requirements, + requirements=self.requirements ) self._write_args(input_filename) - self._write_script(script_filename) self._write_string_args(string_args_filename) + write_python_script( + jinja_context=dict( + op_args=self.op_args, + op_kwargs=self.op_kwargs, + pickling_library=self.pickling_library.__name__, + python_callable=self.python_callable.__name__, + python_callable_source=dedent(inspect.getsource(self.python_callable)) + ), + filename=script_filename + ) + + execute_in_subprocess(cmd=[ + f'{tmp_dir}/bin/python', + script_filename, + input_filename, + output_filename, + string_args_filename + ]) - # execute command in virtualenv - execute_in_subprocess( - self._generate_python_cmd(tmp_dir, - script_filename, - input_filename, - output_filename, - string_args_filename)) return self._read_result(output_filename) - def _pass_op_args(self): - # we should only pass op_args if any are given to us - return len(self.op_args) + len(self.op_kwargs) > 0 + def _write_args(self, filename): + if self.op_args or self.op_kwargs: + with open(filename, 'wb') as file: + self.pickling_library.dump({'args': self.op_args, 'kwargs': self.op_kwargs}, file) + + def _get_serializable_context_keys(self): + def _is_airflow_env(): + return self.system_site_packages or 'apache-airflow' in self.requirements + + def _is_pendulum_env(): + return 'pendulum' in self.requirements and 'lazy_object_proxy' in self.requirements + + serializable_context_keys = self.BASE_SERIALIZABLE_CONTEXT_KEYS.copy() + if _is_airflow_env(): + serializable_context_keys.update(self.AIRFLOW_SERIALIZABLE_CONTEXT_KEYS) + if _is_pendulum_env() or _is_airflow_env(): + serializable_context_keys.update(self.PENDULUM_SERIALIZABLE_CONTEXT_KEYS) + return serializable_context_keys def _write_string_args(self, filename): - # writes string_args to a file, which are read line by line with open(filename, 'w') as file: file.write('\n'.join(map(str, self.string_args))) - def _write_args(self, input_filename): - # serialize args to file - if self._pass_op_args(): - with open(input_filename, 'wb') as file: - arg_dict = ({'args': self.op_args, 'kwargs': self.op_kwargs}) - if self.use_dill: - dill.dump(arg_dict, file) - else: - pickle.dump(arg_dict, file) - - def _read_result(self, output_filename): - if os.stat(output_filename).st_size == 0: + def _read_result(self, filename): + if os.stat(filename).st_size == 0: return None - with open(output_filename, 'rb') as file: + with open(filename, 'rb') as file: try: - if self.use_dill: - return dill.load(file) - else: - return pickle.load(file) + return self.pickling_library.load(file) except ValueError: - self.log.error("Error deserializing result. " - "Note that result deserialization " + self.log.error("Error deserializing result. Note that result deserialization " "is not supported across major Python versions.") raise - def _write_script(self, script_filename): - with open(script_filename, 'w') as file: - python_code = self._generate_python_code() - self.log.debug('Writing code to file\n %s', python_code) - file.write(python_code) - - @staticmethod - def _generate_python_cmd(tmp_dir, script_filename, - input_filename, output_filename, string_args_filename): - # direct path alleviates need to activate - return ['{}/bin/python'.format(tmp_dir), script_filename, - input_filename, output_filename, string_args_filename] - - def _generate_python_code(self): - if self.use_dill: - pickling_library = 'dill' - else: - pickling_library = 'pickle' - - # dont try to read pickle if we didnt pass anything - if self._pass_op_args(): - load_args_line = 'with open(sys.argv[1], "rb") as file: arg_dict = {}.load(file)' \ - .format(pickling_library) - else: - load_args_line = 'arg_dict = {"args": [], "kwargs": {}}' - - # no indents in original code so we can accept - # any type of indents in the original function - # we deserialize args, call function, serialize result if necessary - return dedent("""\ - import {pickling_library} - import sys - {load_args_code} - args = arg_dict["args"] - kwargs = arg_dict["kwargs"] - with open(sys.argv[3], 'r') as file: - virtualenv_string_args = list(map(lambda x: x.strip(), list(file))) - {python_callable_lines} - res = {python_callable_name}(*args, **kwargs) - with open(sys.argv[2], 'wb') as file: - res is not None and {pickling_library}.dump(res, file) - """).format(load_args_code=load_args_line, - python_callable_lines=dedent(inspect.getsource(self.python_callable)), - python_callable_name=self.python_callable.__name__, - pickling_library=pickling_library) - def get_current_context() -> Dict[str, Any]: """ diff --git a/airflow/utils/python_virtualenv.py b/airflow/utils/python_virtualenv.py index 09b6d37..ff94644 100644 --- a/airflow/utils/python_virtualenv.py +++ b/airflow/utils/python_virtualenv.py @@ -19,8 +19,11 @@ """ Utilities for creating a virtual environment """ +import os from typing import List, Optional +import jinja2 + from airflow.utils.process_utils import execute_in_subprocess @@ -69,3 +72,22 @@ def prepare_virtualenv( execute_in_subprocess(pip_cmd) return '{}/bin/python'.format(venv_directory) + + +def write_python_script(jinja_context: dict, filename: str): + """ + Renders the python script to a file to execute in the virtual environment. + + :param jinja_context: The jinja context variables to unpack and replace with its placeholders in the + template file. + :type jinja_context: dict + :param filename: The name of the file to dump the rendered script to. + :type filename: str + """ + template_loader = jinja2.FileSystemLoader(searchpath=os.path.dirname(__file__)) + template_env = jinja2.Environment( + loader=template_loader, + undefined=jinja2.StrictUndefined + ) + template = template_env.get_template('python_virtualenv_script.jinja2') + template.stream(**jinja_context).dump(filename) diff --git a/airflow/utils/python_virtualenv_script.jinja2 b/airflow/utils/python_virtualenv_script.jinja2 new file mode 100644 index 0000000..f2dd875 --- /dev/null +++ b/airflow/utils/python_virtualenv_script.jinja2 @@ -0,0 +1,42 @@ +{# + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +#} + +import {{ pickling_library }} +import sys + +# Read args +{% if op_args or op_kwargs %} +with open(sys.argv[1], "rb") as file: + arg_dict = {{ pickling_library }}.load(file) +{% else %} +arg_dict = {"args": [], "kwargs": {}} +{% endif %} + +# Read string args +with open(sys.argv[3], "r") as file: + virtualenv_string_args = list(map(lambda x: x.strip(), list(file))) + +# Script +{{ python_callable_source }} +res = {{ python_callable }}(*arg_dict["args"], **arg_dict["kwargs"]) + +# Write output +with open(sys.argv[2], "wb") as file: + if res: + {{ pickling_library }}.dump(res, file) diff --git a/docs/howto/operator/python.rst b/docs/howto/operator/python.rst index 46e7d4c..8b6d726 100644 --- a/docs/howto/operator/python.rst +++ b/docs/howto/operator/python.rst @@ -50,3 +50,34 @@ argument. The ``templates_dict`` argument is templated, so each value in the dictionary is evaluated as a :ref:`Jinja template <jinja-templating>`. + + + +.. _howto/operator:PythonVirtualenvOperator: + +PythonVirtualenvOperator +======================== + +Use the :class:`~airflow.operators.python.PythonVirtualenvOperator` to execute +Python callables inside a new Python virtual environment. + +.. exampleinclude:: ../../../airflow/example_dags/example_python_operator.py + :language: python + :start-after: [START howto_operator_python_venv] + :end-before: [END howto_operator_python_venv] + +Passing in arguments +^^^^^^^^^^^^^^^^^^^^ + +You can use the ``op_args`` and ``op_kwargs`` arguments the same way you use it in the PythonOperator. +Unfortunately we currently do not support to serialize ``var`` and ``ti`` / ``task_instance`` due to incompatibilities +with the underlying library. For airflow context variables make sure that you either have access to Airflow through +setting ``system_site_packages`` to ``True`` or add ``apache-airflow`` to the ``requirements`` argument. +Otherwise you won't have access to the most context variables of Airflow in ``op_kwargs``. +If you want the context related to datetime objects like ``execution_date`` you can add ``pendulum`` and +``lazy_object_proxy``. + +Templating +^^^^^^^^^^ + +You can use jinja Templating the same way you use it in PythonOperator. diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py index 6bcf712..a8a04bb 100644 --- a/tests/operators/test_python.py +++ b/tests/operators/test_python.py @@ -68,13 +68,16 @@ def build_recording_function(calls_collection): Then using this custom function recording custom Call objects for further testing (replacing Mock.assert_called_with assertion method) """ + def recording_function(*args, **kwargs): calls_collection.append(Call(*args, **kwargs)) + return recording_function class TestPythonBase(unittest.TestCase): """Base test class for TestPythonOperator and TestPythonSensor classes""" + @classmethod def setUpClass(cls): super().setUpClass() @@ -326,9 +329,11 @@ class TestAirflowTaskDecorator(TestPythonBase): def test_fails_bad_signature(self): """Tests that @task will fail if signature is not binding.""" + @task_decorator def add_number(num: int) -> int: return num + 2 + with pytest.raises(TypeError): add_number(2, 3) # pylint: disable=too-many-function-args with pytest.raises(TypeError): @@ -345,12 +350,14 @@ class TestAirflowTaskDecorator(TestPythonBase): @task_decorator def add_number(self, num: int) -> int: return self.num + num + Test().add_number(2) def test_fail_multiple_outputs_key_type(self): @task_decorator(multiple_outputs=True) def add_number(num: int): return {2: num} + with self.dag: ret = add_number(2) self.dag.create_dagrun( @@ -450,6 +457,7 @@ class TestAirflowTaskDecorator(TestPythonBase): @task_decorator(task_id='some_name') def do_run(): return 4 + with self.dag: do_run() assert ['some_name'] == self.dag.task_ids @@ -460,6 +468,7 @@ class TestAirflowTaskDecorator(TestPythonBase): @task_decorator def do_run(): return 4 + with self.dag: do_run() assert ['do_run'] == self.dag.task_ids @@ -472,6 +481,7 @@ class TestAirflowTaskDecorator(TestPythonBase): def test_call_20(self): """Test calling decorated function 21 times in a DAG""" + @task_decorator def __do_run(): return 4 @@ -513,6 +523,7 @@ class TestAirflowTaskDecorator(TestPythonBase): def test_default_args(self): """Test that default_args are captured when calling the function correctly""" + @task_decorator def do_run(): return 4 @@ -1060,22 +1071,21 @@ class TestPythonVirtualenvOperator(unittest.TestCase): dag=self.dag, **kwargs) task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + return task - def test_dill_warning(self): + def test_add_dill(self): def f(): pass - with self.assertRaises(AirflowException): - PythonVirtualenvOperator( - python_callable=f, - task_id='task', - dag=self.dag, - use_dill=True, - system_site_packages=False) + + task = self._run_as_operator(f, use_dill=True, system_site_packages=False) + assert 'dill' in task.requirements def test_no_requirements(self): """Tests that the python callable is invoked on task run.""" + def f(): pass + self._run_as_operator(f) def test_no_system_site_packages(self): @@ -1085,11 +1095,13 @@ class TestPythonVirtualenvOperator(unittest.TestCase): except ImportError: return True raise Exception + self._run_as_operator(f, system_site_packages=False, requirements=['dill']) def test_system_site_packages(self): def f(): import funcsigs # noqa: F401 # pylint: disable=redefined-outer-name,reimported,unused-import + self._run_as_operator(f, requirements=['funcsigs'], system_site_packages=True) def test_with_requirements_pinned(self): @@ -1106,30 +1118,35 @@ class TestPythonVirtualenvOperator(unittest.TestCase): def test_unpinned_requirements(self): def f(): import funcsigs # noqa: F401 # pylint: disable=redefined-outer-name,reimported,unused-import + self._run_as_operator( f, requirements=['funcsigs', 'dill'], system_site_packages=False) def test_range_requirements(self): def f(): import funcsigs # noqa: F401 # pylint: disable=redefined-outer-name,reimported,unused-import + self._run_as_operator( f, requirements=['funcsigs>1.0', 'dill'], system_site_packages=False) def test_fail(self): def f(): raise Exception + with self.assertRaises(CalledProcessError): self._run_as_operator(f) def test_python_2(self): def f(): {}.iteritems() # pylint: disable=no-member + self._run_as_operator(f, python_version=2, requirements=['dill']) def test_python_2_7(self): def f(): {}.iteritems() # pylint: disable=no-member return True + self._run_as_operator(f, python_version='2.7', requirements=['dill']) def test_python_3(self): @@ -1141,6 +1158,7 @@ class TestPythonVirtualenvOperator(unittest.TestCase): except AttributeError: return raise Exception + self._run_as_operator(f, python_version=3, use_dill=False, requirements=['dill']) @staticmethod @@ -1165,6 +1183,7 @@ class TestPythonVirtualenvOperator(unittest.TestCase): def test_without_dill(self): def f(a): return a + self._run_as_operator(f, system_site_packages=False, use_dill=False, op_args=[4]) def test_string_args(self): @@ -1173,6 +1192,7 @@ class TestPythonVirtualenvOperator(unittest.TestCase): print(virtualenv_string_args) if virtualenv_string_args[0] != virtualenv_string_args[2]: raise Exception + self._run_as_operator( f, python_version=self._invert_python_major_version(), string_args=[1, 2, 1]) @@ -1182,11 +1202,13 @@ class TestPythonVirtualenvOperator(unittest.TestCase): return True else: raise Exception + self._run_as_operator(f, op_args=[0, 1], op_kwargs={'c': True}) def test_return_none(self): def f(): return None + self._run_as_operator(f) def test_lambda(self): @@ -1199,13 +1221,132 @@ class TestPythonVirtualenvOperator(unittest.TestCase): def test_nonimported_as_arg(self): def f(_): return None + self._run_as_operator(f, op_args=[datetime.utcnow()]) def test_context(self): def f(templates_dict): return templates_dict['ds'] + self._run_as_operator(f, templates_dict={'ds': '{{ ds }}'}) + def test_airflow_context(self): + def f( + # basic + ds_nodash, + inlets, + next_ds, + next_ds_nodash, + outlets, + params, + prev_ds, + prev_ds_nodash, + run_id, + task_instance_key_str, + test_mode, + tomorrow_ds, + tomorrow_ds_nodash, + ts, + ts_nodash, + ts_nodash_with_tz, + yesterday_ds, + yesterday_ds_nodash, + # pendulum-specific + execution_date, + next_execution_date, + prev_execution_date, + prev_execution_date_success, + prev_start_date_success, + # airflow-specific + macros, + conf, + dag, + dag_run, + task, + # other + **context + ): # pylint: disable=unused-argument,too-many-arguments,too-many-locals + pass + + self._run_as_operator( + f, + use_dill=True, + system_site_packages=True, + requirements=None + ) + + def test_pendulum_context(self): + def f( + # basic + ds_nodash, + inlets, + next_ds, + next_ds_nodash, + outlets, + params, + prev_ds, + prev_ds_nodash, + run_id, + task_instance_key_str, + test_mode, + tomorrow_ds, + tomorrow_ds_nodash, + ts, + ts_nodash, + ts_nodash_with_tz, + yesterday_ds, + yesterday_ds_nodash, + # pendulum-specific + execution_date, + next_execution_date, + prev_execution_date, + prev_execution_date_success, + prev_start_date_success, + # other + **context + ): # pylint: disable=unused-argument,too-many-arguments,too-many-locals + pass + + self._run_as_operator( + f, + use_dill=True, + system_site_packages=False, + requirements=['pendulum', 'lazy_object_proxy'] + ) + + def test_base_context(self): + def f( + # basic + ds_nodash, + inlets, + next_ds, + next_ds_nodash, + outlets, + params, + prev_ds, + prev_ds_nodash, + run_id, + task_instance_key_str, + test_mode, + tomorrow_ds, + tomorrow_ds_nodash, + ts, + ts_nodash, + ts_nodash_with_tz, + yesterday_ds, + yesterday_ds_nodash, + # other + **context + ): # pylint: disable=unused-argument,too-many-arguments,too-many-locals + pass + + self._run_as_operator( + f, + use_dill=True, + system_site_packages=False, + requirements=None + ) + DEFAULT_ARGS = { "owner": "test",