uranusjr commented on a change in pull request #19965:
URL: https://github.com/apache/airflow/pull/19965#discussion_r771429030
##########
File path: airflow/decorators/base.py
##########
@@ -176,11 +178,110 @@ def _hook_apply_defaults(self, *args, **kwargs):
T = TypeVar("T", bound=Callable)
+OperatorSubclass = TypeVar("OperatorSubclass", bound="BaseOperator")
+
+
[email protected]
+class OperatorWrapper(Generic[T, OperatorSubclass]):
+ """
+ Helper class for providing dynamic task mapping to decorated functions.
+
+ ``task_decorator_factory`` returns an instance of this, instead of just a
plain wrapped function.
+
+ :meta private:
+ """
+
+ function: T = attr.ib(validator=attr.validators.is_callable())
+ operator_class: Type[OperatorSubclass]
+ multiple_outputs: bool = attr.ib()
+ kwargs: Dict[str, Any] = attr.ib(factory=dict)
+
+ decorator_name: str = attr.ib(repr=False, default="task")
+ function_arg_names: Set[str] = attr.ib(repr=False)
+
+ @function_arg_names.default
+ def _get_arg_names(self):
+ return set(inspect.signature(self.function).parameters)
+
+ @function.validator
+ def _validate_function(self, _, f):
+ if 'self' in self.function_arg_names:
+ raise TypeError(f'@{self.decorator_name} does not support methods')
+
+ @multiple_outputs.default
+ def _infer_multiple_outputs(self):
+ sig = inspect.signature(self.function).return_annotation
+ ttype = getattr(sig, "__origin__", None)
+
+ return sig is not inspect.Signature.empty and ttype in (dict, Dict)
Review comment:
This would probably be better as a cached property. Also, the `ttype` line
should be `getattr(sig, "__origin__", sig)`: a function annotated with a bare
`dict` (e.g. `def foo() -> dict`) has no `__origin__`, but it still returns
multiple outputs.
##########
File path: airflow/decorators/base.py
##########
@@ -176,11 +178,92 @@ def _hook_apply_defaults(self, *args, **kwargs):
T = TypeVar("T", bound=Callable)
+OperatorSubclass = TypeVar("OperatorSubclass", bound="BaseOperator")
+
+
[email protected]
+class OperatorWrapper(Generic[T, OperatorSubclass]):
+ """
+ Helper class for providing dynamic task mapping to decorated functions.
Review comment:
The name doesn’t make it clear what this class actually does, which is
more specific than wrapping an operator (although it does that too). Maybe
something like `DecoratedTaskWrapper`?
`task_decorator_factory` might not be a very descriptive name to begin with,
but if this class is designed mostly just for that function, then neither
`TaskDecoratorFactory` nor plain `TaskDecorator` would be a bad name either.
##########
File path: airflow/decorators/base.py
##########
@@ -176,11 +178,110 @@ def _hook_apply_defaults(self, *args, **kwargs):
T = TypeVar("T", bound=Callable)
+OperatorSubclass = TypeVar("OperatorSubclass", bound="BaseOperator")
+
+
[email protected]
+class OperatorWrapper(Generic[T, OperatorSubclass]):
+ """
+ Helper class for providing dynamic task mapping to decorated functions.
+
+ ``task_decorator_factory`` returns an instance of this, instead of just a
plain wrapped function.
+
+ :meta private:
+ """
+
+ function: T = attr.ib(validator=attr.validators.is_callable())
+ operator_class: Type[OperatorSubclass]
+ multiple_outputs: bool = attr.ib()
+ kwargs: Dict[str, Any] = attr.ib(factory=dict)
+
+ decorator_name: str = attr.ib(repr=False, default="task")
+ function_arg_names: Set[str] = attr.ib(repr=False)
+
+ @function_arg_names.default
+ def _get_arg_names(self):
+ return set(inspect.signature(self.function).parameters)
Review comment:
Does this need to exist? From what I can tell it is only used once in
the validator (which already has access to the original `function`).
##########
File path: airflow/decorators/base.py
##########
@@ -176,11 +178,110 @@ def _hook_apply_defaults(self, *args, **kwargs):
T = TypeVar("T", bound=Callable)
+OperatorSubclass = TypeVar("OperatorSubclass", bound="BaseOperator")
+
+
[email protected]
+class OperatorWrapper(Generic[T, OperatorSubclass]):
+ """
+ Helper class for providing dynamic task mapping to decorated functions.
+
+ ``task_decorator_factory`` returns an instance of this, instead of just a
plain wrapped function.
+
+ :meta private:
+ """
+
+ function: T = attr.ib(validator=attr.validators.is_callable())
+ operator_class: Type[OperatorSubclass]
+ multiple_outputs: bool = attr.ib()
+ kwargs: Dict[str, Any] = attr.ib(factory=dict)
+
+ decorator_name: str = attr.ib(repr=False, default="task")
+ function_arg_names: Set[str] = attr.ib(repr=False)
+
+ @function_arg_names.default
+ def _get_arg_names(self):
+ return set(inspect.signature(self.function).parameters)
+
+ @function.validator
+ def _validate_function(self, _, f):
+ if 'self' in self.function_arg_names:
+ raise TypeError(f'@{self.decorator_name} does not support methods')
+
+ @multiple_outputs.default
+ def _infer_multiple_outputs(self):
+ sig = inspect.signature(self.function).return_annotation
+ ttype = getattr(sig, "__origin__", None)
+
+ return sig is not inspect.Signature.empty and ttype in (dict, Dict)
+
+ def __attrs_post_init__(self):
+ self.kwargs.setdefault('task_id', self.function.__name__)
+
+ def __call__(self, *args, **kwargs) -> XComArg:
+ op = self.operator_class(
+ python_callable=self.function,
+ op_args=args,
+ op_kwargs=kwargs,
+ multiple_outputs=self.multiple_outputs,
+ **self.kwargs,
+ )
+ if self.function.__doc__:
+ op.doc_md = self.function.__doc__
+ return XComArg(op)
+
+ def _validate_arg_names(self, funcname: str, kwargs: Dict[str, Any],
valid_names: Set[str] = set()):
+ unknown_args = kwargs.copy()
+ for name in itertools.chain(self.function_arg_names, valid_names):
+ unknown_args.pop(name, None)
+
+ if not unknown_args:
+ # If we have no args left ot check, we are valid
+ return
+
+ if len(unknown_args) == 1:
+ raise TypeError(f'{funcname} got unexpected keyword argument
{unknown_args.popitem()[0]!r}')
+ else:
+ names = ", ".join(repr(n) for n in unknown_args)
+ raise TypeError(f'{funcname} got unexpected keyword arguments
{names}')
+
+ def map(
+ self, *, dag: Optional["DAG"] = None, task_group:
Optional["TaskGroup"] = None, **kwargs
+ ) -> XComArg:
+
+ dag = dag or DagContext.get_current_dag()
+ task_group = task_group or TaskGroupContext.get_current_task_group(dag)
+ task_id = get_unique_task_id(self.kwargs['task_id'], dag, task_group)
+
+ self._validate_arg_names("map", kwargs)
+
+ operator = MappedOperator(
+ operator_class=self.operator_class,
+ task_id=task_id,
+ dag=dag,
+ task_group=task_group,
+ partial_kwargs=self.kwargs,
+ # Set them to empty to bypass the validation, as for decorated
stuff we validate ourselves
+ mapped_kwargs={},
+ )
+ operator.mapped_kwargs.update(kwargs)
+
+ return XComArg(operator=operator)
+
+ def partial(
+ self, *, dag: Optional["DAG"] = None, task_group:
Optional["TaskGroup"] = None, **kwargs
+ ) -> "OperatorWrapper[T, OperatorSubclass]":
+ self._validate_arg_names("partial", kwargs, {'task_id'})
+ partial_kwargs = self.kwargs.copy()
+ partial_kwargs.update(kwargs)
Review comment:
This makes me wonder: what should
`MyOperator.map(x=something).map(x=another)` do? If I understand this
correctly, it would currently discard `something` and map only over `another`.
We should probably add a check to prevent this from happening, perhaps in
`_validate_arg_names`?
##########
File path: airflow/decorators/base.py
##########
@@ -176,11 +178,110 @@ def _hook_apply_defaults(self, *args, **kwargs):
T = TypeVar("T", bound=Callable)
+OperatorSubclass = TypeVar("OperatorSubclass", bound="BaseOperator")
+
+
[email protected]
+class OperatorWrapper(Generic[T, OperatorSubclass]):
+ """
+ Helper class for providing dynamic task mapping to decorated functions.
+
+ ``task_decorator_factory`` returns an instance of this, instead of just a
plain wrapped function.
+
+ :meta private:
+ """
+
+ function: T = attr.ib(validator=attr.validators.is_callable())
+ operator_class: Type[OperatorSubclass]
+ multiple_outputs: bool = attr.ib()
+ kwargs: Dict[str, Any] = attr.ib(factory=dict)
+
+ decorator_name: str = attr.ib(repr=False, default="task")
+ function_arg_names: Set[str] = attr.ib(repr=False)
+
+ @function_arg_names.default
+ def _get_arg_names(self):
+ return set(inspect.signature(self.function).parameters)
+
+ @function.validator
+ def _validate_function(self, _, f):
+ if 'self' in self.function_arg_names:
+ raise TypeError(f'@{self.decorator_name} does not support methods')
+
+ @multiple_outputs.default
+ def _infer_multiple_outputs(self):
+ sig = inspect.signature(self.function).return_annotation
+ ttype = getattr(sig, "__origin__", None)
+
+ return sig is not inspect.Signature.empty and ttype in (dict, Dict)
+
+ def __attrs_post_init__(self):
+ self.kwargs.setdefault('task_id', self.function.__name__)
+
+ def __call__(self, *args, **kwargs) -> XComArg:
+ op = self.operator_class(
+ python_callable=self.function,
+ op_args=args,
+ op_kwargs=kwargs,
+ multiple_outputs=self.multiple_outputs,
+ **self.kwargs,
+ )
+ if self.function.__doc__:
+ op.doc_md = self.function.__doc__
+ return XComArg(op)
+
+ def _validate_arg_names(self, funcname: str, kwargs: Dict[str, Any],
valid_names: Set[str] = set()):
+ unknown_args = kwargs.copy()
+ for name in itertools.chain(self.function_arg_names, valid_names):
+ unknown_args.pop(name, None)
+
+ if not unknown_args:
+ # If we have no args left ot check, we are valid
+ return
+
+ if len(unknown_args) == 1:
+ raise TypeError(f'{funcname} got unexpected keyword argument
{unknown_args.popitem()[0]!r}')
Review comment:
I’d use `next(iter(unknown_args))` here 🙂 — it reads the single key without mutating the dict the way `popitem()` does.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]