potiuk commented on a change in pull request #19965:
URL: https://github.com/apache/airflow/pull/19965#discussion_r771564179
##########
File path: airflow/models/baseoperator.py
##########
@@ -110,12 +126,13 @@ def _apply_defaults(cls, func: T) -> T:
# per decoration, i.e. each function decorated using apply_defaults
will
# have a different sig_cache.
sig_cache = signature(func)
- non_optional_args = {
- name
+ non_varaidc_params = {
+ name: param
for (name, param) in sig_cache.parameters.items()
- if param.default == param.empty
- and param.name != 'self'
- and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
+ if param.name != 'self' and param.kind not in
(param.VAR_POSITIONAL, param.VAR_KEYWORD)
+ }
+ non_optional_args = {
+ name for (name, param) in non_varaidc_params.items() if
param.default == param.empty
Review comment:
```suggestion
name for (name, param) in non_variadic_params.items() if
param.default == param.empty
```
##########
File path: airflow/decorators/base.py
##########
@@ -199,38 +304,23 @@ def task_decorator_factory(
:type decorated_operator_class: BaseDecoratedOperator
"""
- # try to infer from type annotation
- if python_callable and multiple_outputs is None:
- sig = signature(python_callable).return_annotation
- ttype = getattr(sig, "__origin__", None)
-
- multiple_outputs = sig != inspect.Signature.empty and ttype in (dict,
Dict)
-
- def wrapper(f: T):
- """
- Python wrapper to generate PythonDecoratedOperator out of simple
python functions.
- Used for Airflow Decorated interface
- """
- validate_python_callable(f)
- kwargs.setdefault('task_id', f.__name__)
-
- @functools.wraps(f)
- def factory(*args, **f_kwargs):
- op = decorated_operator_class(
- python_callable=f,
- op_args=args,
- op_kwargs=f_kwargs,
- multiple_outputs=multiple_outputs,
- **kwargs,
- )
- if f.__doc__:
- op.doc_md = f.__doc__
- return XComArg(op)
-
- return cast(T, factory)
-
- if callable(python_callable):
- return wrapper(python_callable)
+ if multiple_outputs is None:
+ multiple_outputs = cast(bool, attr.NOTHING)
+ if python_callable:
+ return _TaskDecorator( # type: ignore
Review comment:
Hmm... `_TaskDecorator` is bound to `Callable`, so I guess the ignore should not be
needed? Is the `type: ignore` here about the generics?
##########
File path: airflow/models/baseoperator.py
##########
@@ -110,12 +126,13 @@ def _apply_defaults(cls, func: T) -> T:
# per decoration, i.e. each function decorated using apply_defaults
will
# have a different sig_cache.
sig_cache = signature(func)
- non_optional_args = {
- name
+ non_varaidc_params = {
Review comment:
```suggestion
non_variadic_params = {
```
##########
File path: airflow/models/taskmixin.py
##########
@@ -88,3 +99,167 @@ def __init_subclass__(cls) -> None:
stacklevel=2,
)
return super().__init_subclass__()
+
+
+class DAGNode(DependencyMixin):
+ """
+ A base class for a node in the graph of a workflow -- an Operator or a
Task Group, either mapped or
+ unmapped.
+ """
+
+ dag: Optional["DAG"] = None
+
+ @property
+ @abstractmethod
Review comment:
I'd vote for adding `ABCMeta`. `DAGNode` is becoming an important class for
us, and maybe some day we'll add more node types, who knows.
##########
File path: airflow/models/baseoperator.py
##########
@@ -792,6 +854,8 @@ def __setattr__(self, key, value):
if self._lock_for_execution:
# Skip any custom behaviour during execute
return
+ if key in self.__init_kwargs:
Review comment:
How does `self.__init_kwargs` get populated? I could not see it in the
change.
##########
File path: airflow/models/baseoperator.py
##########
@@ -196,16 +216,37 @@ def apply_defaults(self, *args: Any, **kwargs: Any) ->
Any:
self._BaseOperator__instantiated = True
return result
+ apply_defaults.__non_optional_args = non_optional_args # type: ignore
+ apply_defaults.__param_names = set(non_varaidc_params.keys()) # type:
ignore
Review comment:
```suggestion
apply_defaults.__param_names = set(non_variadic_params) # type:
ignore
```
##########
File path: airflow/models/baseoperator.py
##########
@@ -196,16 +216,37 @@ def apply_defaults(self, *args: Any, **kwargs: Any) ->
Any:
self._BaseOperator__instantiated = True
return result
+ apply_defaults.__non_optional_args = non_optional_args # type: ignore
+ apply_defaults.__param_names = set(non_varaidc_params.keys()) # type:
ignore
+
return cast(T, apply_defaults)
def __new__(cls, name, bases, namespace, **kwargs):
new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
+ try:
+ # Update the partial descriptor with the class method so it call
call the actual function (but let
+ # subclasses override it if they need to)
+ partial_desc = vars(new_cls)['partial']
Review comment:
Catch-22 :)
##########
File path: airflow/decorators/base.py
##########
@@ -176,11 +178,110 @@ def _hook_apply_defaults(self, *args, **kwargs):
T = TypeVar("T", bound=Callable)
+OperatorSubclass = TypeVar("OperatorSubclass", bound="BaseOperator")
+
+
[email protected]
+class OperatorWrapper(Generic[T, OperatorSubclass]):
+ """
+ Helper class for providing dynamic task mapping to decorated functions.
+
+ ``task_decorator_factory`` returns an instance of this, instead of just a
plain wrapped function.
+
+ :meta private:
+ """
+
+ function: T = attr.ib(validator=attr.validators.is_callable())
+ operator_class: Type[OperatorSubclass]
+ multiple_outputs: bool = attr.ib()
+ kwargs: Dict[str, Any] = attr.ib(factory=dict)
+
+ decorator_name: str = attr.ib(repr=False, default="task")
+ function_arg_names: Set[str] = attr.ib(repr=False)
+
+ @function_arg_names.default
+ def _get_arg_names(self):
+ return set(inspect.signature(self.function).parameters)
+
+ @function.validator
+ def _validate_function(self, _, f):
+ if 'self' in self.function_arg_names:
+ raise TypeError(f'@{self.decorator_name} does not support methods')
+
+ @multiple_outputs.default
+ def _infer_multiple_outputs(self):
+ sig = inspect.signature(self.function).return_annotation
+ ttype = getattr(sig, "__origin__", None)
+
+ return sig is not inspect.Signature.empty and ttype in (dict, Dict)
+
+ def __attrs_post_init__(self):
+ self.kwargs.setdefault('task_id', self.function.__name__)
+
+ def __call__(self, *args, **kwargs) -> XComArg:
+ op = self.operator_class(
+ python_callable=self.function,
+ op_args=args,
+ op_kwargs=kwargs,
+ multiple_outputs=self.multiple_outputs,
+ **self.kwargs,
+ )
+ if self.function.__doc__:
+ op.doc_md = self.function.__doc__
+ return XComArg(op)
+
+ def _validate_arg_names(self, funcname: str, kwargs: Dict[str, Any],
valid_names: Set[str] = set()):
+ unknown_args = kwargs.copy()
+ for name in itertools.chain(self.function_arg_names, valid_names):
+ unknown_args.pop(name, None)
+
+ if not unknown_args:
+ # If we have no args left ot check, we are valid
+ return
+
+ if len(unknown_args) == 1:
+ raise TypeError(f'{funcname} got unexpected keyword argument
{unknown_args.popitem()[0]!r}')
+ else:
+ names = ", ".join(repr(n) for n in unknown_args)
+ raise TypeError(f'{funcname} got unexpected keyword arguments
{names}')
+
+ def map(
+ self, *, dag: Optional["DAG"] = None, task_group:
Optional["TaskGroup"] = None, **kwargs
+ ) -> XComArg:
+
+ dag = dag or DagContext.get_current_dag()
+ task_group = task_group or TaskGroupContext.get_current_task_group(dag)
+ task_id = get_unique_task_id(self.kwargs['task_id'], dag, task_group)
+
+ self._validate_arg_names("map", kwargs)
+
+ operator = MappedOperator(
+ operator_class=self.operator_class,
+ task_id=task_id,
+ dag=dag,
+ task_group=task_group,
+ partial_kwargs=self.kwargs,
+ # Set them to empty to bypass the validation, as for decorated
stuff we validate ourselves
+ mapped_kwargs={},
+ )
+ operator.mapped_kwargs.update(kwargs)
+
+ return XComArg(operator=operator)
+
+ def partial(
+ self, *, dag: Optional["DAG"] = None, task_group:
Optional["TaskGroup"] = None, **kwargs
+ ) -> "OperatorWrapper[T, OperatorSubclass]":
+ self._validate_arg_names("partial", kwargs, {'task_id'})
+ partial_kwargs = self.kwargs.copy()
+ partial_kwargs.update(kwargs)
Review comment:
I also think `map().map()` should be an error. We already agreed on
supporting `map(arg1, arg2)`, and:
> There should be one-- and preferably only one --obvious way to do it.
##########
File path: airflow/models/baseoperator.py
##########
@@ -196,16 +216,37 @@ def apply_defaults(self, *args: Any, **kwargs: Any) ->
Any:
self._BaseOperator__instantiated = True
return result
+ apply_defaults.__non_optional_args = non_optional_args # type: ignore
+ apply_defaults.__param_names = set(non_varaidc_params.keys()) # type:
ignore
Review comment:
And should we be able to annotate `__param_names` as `Set[str]`?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]