ashb commented on a change in pull request #19965:
URL: https://github.com/apache/airflow/pull/19965#discussion_r771724522
##########
File path: airflow/decorators/base.py
##########
@@ -176,11 +178,110 @@ def _hook_apply_defaults(self, *args, **kwargs):
T = TypeVar("T", bound=Callable)
+OperatorSubclass = TypeVar("OperatorSubclass", bound="BaseOperator")
+
+
[email protected]
+class OperatorWrapper(Generic[T, OperatorSubclass]):
+ """
+ Helper class for providing dynamic task mapping to decorated functions.
+
+ ``task_decorator_factory`` returns an instance of this, instead of just a
plain wrapped function.
+
+ :meta private:
+ """
+
+ function: T = attr.ib(validator=attr.validators.is_callable())
+ operator_class: Type[OperatorSubclass]
+ multiple_outputs: bool = attr.ib()
+ kwargs: Dict[str, Any] = attr.ib(factory=dict)
+
+ decorator_name: str = attr.ib(repr=False, default="task")
+ function_arg_names: Set[str] = attr.ib(repr=False)
+
+ @function_arg_names.default
+ def _get_arg_names(self):
+ return set(inspect.signature(self.function).parameters)
+
+ @function.validator
+ def _validate_function(self, _, f):
+ if 'self' in self.function_arg_names:
+ raise TypeError(f'@{self.decorator_name} does not support methods')
+
+ @multiple_outputs.default
+ def _infer_multiple_outputs(self):
+ sig = inspect.signature(self.function).return_annotation
+ ttype = getattr(sig, "__origin__", None)
+
+ return sig is not inspect.Signature.empty and ttype in (dict, Dict)
+
+ def __attrs_post_init__(self):
+ self.kwargs.setdefault('task_id', self.function.__name__)
+
+ def __call__(self, *args, **kwargs) -> XComArg:
+ op = self.operator_class(
+ python_callable=self.function,
+ op_args=args,
+ op_kwargs=kwargs,
+ multiple_outputs=self.multiple_outputs,
+ **self.kwargs,
+ )
+ if self.function.__doc__:
+ op.doc_md = self.function.__doc__
+ return XComArg(op)
+
+ def _validate_arg_names(self, funcname: str, kwargs: Dict[str, Any],
valid_names: Set[str] = set()):
+ unknown_args = kwargs.copy()
+ for name in itertools.chain(self.function_arg_names, valid_names):
+ unknown_args.pop(name, None)
+
+ if not unknown_args:
+ # If we have no args left ot check, we are valid
+ return
+
+ if len(unknown_args) == 1:
+ raise TypeError(f'{funcname} got unexpected keyword argument
{unknown_args.popitem()[0]!r}')
+ else:
+ names = ", ".join(repr(n) for n in unknown_args)
+ raise TypeError(f'{funcname} got unexpected keyword arguments
{names}')
+
+ def map(
+ self, *, dag: Optional["DAG"] = None, task_group:
Optional["TaskGroup"] = None, **kwargs
+ ) -> XComArg:
+
+ dag = dag or DagContext.get_current_dag()
+ task_group = task_group or TaskGroupContext.get_current_task_group(dag)
+ task_id = get_unique_task_id(self.kwargs['task_id'], dag, task_group)
+
+ self._validate_arg_names("map", kwargs)
+
+ operator = MappedOperator(
+ operator_class=self.operator_class,
+ task_id=task_id,
+ dag=dag,
+ task_group=task_group,
+ partial_kwargs=self.kwargs,
+ # Set them to empty to bypass the validation, as for decorated
stuff we validate ourselves
+ mapped_kwargs={},
+ )
+ operator.mapped_kwargs.update(kwargs)
+
+ return XComArg(operator=operator)
+
+ def partial(
+ self, *, dag: Optional["DAG"] = None, task_group:
Optional["TaskGroup"] = None, **kwargs
+ ) -> "OperatorWrapper[T, OperatorSubclass]":
+ self._validate_arg_names("partial", kwargs, {'task_id'})
+ partial_kwargs = self.kwargs.copy()
+ partial_kwargs.update(kwargs)
Review comment:
I'm unlikely to be able to make this change, so either someone else can
make it, or we can merge this as-is and fix it later.
(I think having a `.map()` method that raises an error would be clearer
than having no `map` method at all, similar to how I have kept `.partial()`
on a Task object raising an error.)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]