This is an automated email from the ASF dual-hosted git repository. ash pushed a commit to branch v2-1-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 4e431ec87509a19eb6db330434197161e97b860a Author: Jun <[email protected]> AuthorDate: Thu May 27 23:34:03 2021 +0800 Fix apply defaults for task decorator (#16085) (cherry picked from commit 9d06ee8019ecbc07d041ccede15d0e322aa797a3) --- airflow/decorators/base.py | 14 ++++++++++++++ airflow/models/baseoperator.py | 6 ++++++ tests/decorators/test_python.py | 16 ++++++++++++++++ 3 files changed, 36 insertions(+) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 47fb0d2..3307f05 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -162,6 +162,20 @@ class DecoratedOperator(BaseOperator): ) return return_value + def _hook_apply_defaults(self, *args, **kwargs): + if 'python_callable' not in kwargs: + return args, kwargs + + python_callable = kwargs['python_callable'] + default_args = kwargs.get('default_args') or {} + op_kwargs = kwargs.get('op_kwargs') or {} + f_sig = signature(python_callable) + for arg in f_sig.parameters: + if arg not in op_kwargs and arg in default_args: + op_kwargs[arg] = default_args[arg] + kwargs['op_kwargs'] = op_kwargs + return args, kwargs + T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index f6fec77..e243b5e 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -176,6 +176,12 @@ class BaseOperatorMeta(abc.ABCMeta): if dag_params: kwargs['params'] = dag_params + if default_args: + kwargs['default_args'] = default_args + + if hasattr(self, '_hook_apply_defaults'): + args, kwargs = self._hook_apply_defaults(*args, **kwargs) # pylint: disable=protected-access + result = func(self, *args, **kwargs) # Here we set upstream task defined by XComArgs passed to template fields of the operator diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py index a829863..59849fc 100644 --- a/tests/decorators/test_python.py +++ b/tests/decorators/test_python.py @@ -411,6 +411,22 @@ class TestAirflowTaskDecorator(TestPythonBase): ret = do_run() assert ret.operator.owner == 'airflow' # pylint: disable=maybe-no-member + @task_decorator + def test_apply_default_raise(unknow): + return unknow + + with pytest.raises(TypeError): + with self.dag: + test_apply_default_raise() # pylint: disable=no-value-for-parameter + + @task_decorator + def test_apply_default(owner): + return owner + + with self.dag: + ret = test_apply_default() # pylint: disable=no-value-for-parameter + assert 'owner' in ret.operator.op_kwargs + def test_xcom_arg(self): """Tests that returned key in XComArg is returned correctly"""
