uranusjr commented on a change in pull request #19965:
URL: https://github.com/apache/airflow/pull/19965#discussion_r771446198



##########
File path: airflow/decorators/task_group.py
##########
@@ -20,16 +20,93 @@
 together when the DAG is displayed graphically.
 """
 import functools
+import warnings
 from inspect import signature
-from typing import Callable, Optional, TypeVar, cast
+from typing import Any, Callable, Dict, Generic, Optional, TypeVar, cast
 
-from airflow.utils.task_group import TaskGroup
+import attr
+
+from airflow.utils.task_group import MappedTaskGroup, TaskGroup
 
 T = TypeVar("T", bound=Callable)
+R = TypeVar("R")
 
 task_group_sig = signature(TaskGroup.__init__)
 
 
[email protected]
+class TaskGroupDecorator(Generic[R]):
+    """:meta private:"""
+
+    function: Callable[..., R] = 
attr.ib(validator=attr.validators.is_callable())
+    kwargs: Dict[str, Any] = attr.ib(factory=dict)
+    """kwargs for the TaskGroup"""
+
+    @function.validator
+    def _validate_function(self, _, f):
+        if 'self' in signature(f).parameters:
+            raise TypeError('@task_group does not support methods')
+
+    @kwargs.validator
+    def _validate(self, _, kwargs):
+        task_group_sig.bind_partial(**kwargs)
+
+    def __attrs_post_init__(self):
+        self.kwargs.setdefault('group_id', self.function.__name__)
+
+    def _make_task_group(self, **kwargs) -> TaskGroup:
+        return TaskGroup(**kwargs)
+
+    def __call__(self, *args, **kwargs) -> R:
+        with self._make_task_group(add_suffix_on_collision=True, 
**self.kwargs):
+            # Invoke function to run Tasks inside the TaskGroup
+            return self.function(*args, **kwargs)
+
+    def partial(self, **kwargs) -> "MappedTaskGroupDecorator[R]":
+        return MappedTaskGroupDecorator(function=self.function, 
kwargs=self.kwargs).partial(**kwargs)
+
+    def map(self, **kwargs) -> R:
+        return MappedTaskGroupDecorator(function=self.function, 
kwargs=self.kwargs).map(**kwargs)
+
+
[email protected]
+class MappedTaskGroupDecorator(TaskGroupDecorator[R]):
+    """:meta private:"""
+
+    partial_kwargs: Dict[str, Any] = attr.ib(factory=dict)
+    """static kwargs for the decorated function"""
+    mapped_kwargs: Dict[str, Any] = attr.ib(factory=dict)
+    """kwargs for the decorated function"""
+
+    _invoked: bool = attr.ib(init=False, default=False, repr=False)
+
+    def __call__(self, *args, **kwargs):
+        raise RuntimeError("Mapped @task_group's cannot be called. Use `.map` 
and `.partial` instead")
+
+    def _make_task_group(self, **kwargs) -> MappedTaskGroup:
+        tg = MappedTaskGroup(**kwargs)
+        tg.partial_kwargs = self.partial_kwargs
+        tg.mapped_kwargs = self.mapped_kwargs
+        return tg
+
+    def partial(self, **kwargs) -> "MappedTaskGroupDecorator[R]":
+        self.partial_kwargs.update(kwargs)
+        return self
+
+    def map(self, **kwargs) -> R:
+        self.mapped_kwargs.update(kwargs)
+
+        call_kwargs = self.partial_kwargs.copy()
+        call_kwargs.update({k: object() for k in self.mapped_kwargs})
+
+        self._invoked = True
+        return super().__call__(**call_kwargs)
+
+    def __del__(self):
+        if not self._invoked:
+            warnings.warn(f"Partial task group {self.function.__name__} was 
never mapped!")

Review comment:
       What’s the purpose of this?

##########
File path: airflow/models/baseoperator.py
##########
@@ -91,6 +90,23 @@
 T = TypeVar('T', bound=FunctionType)
 
 
+class _PartialDescriptor:
+    """A descriptor that guards against ``.partial`` being called on Task 
objects."""
+
+    class_method = None
+
+    def __get__(
+        self, obj: "BaseOperator", cls: "Optional[Type[BaseOperator]]" = None
+    ) -> Callable[..., "MappedOperator"]:
+        # Call this "partial" so it looks nicer in stack traces
+        def partial(*, task_id: str, **kwargs):
+            raise TypeError("partial can only be called on Operator classes, 
not Tasks themselves")

Review comment:
       ```suggestion
           def partial(**kwargs):
               raise TypeError("partial can only be called on Operator classes, 
not Tasks themselves")
   ```

##########
File path: airflow/models/baseoperator.py
##########
@@ -1659,6 +1629,115 @@ def defer(
         """
         raise TaskDeferred(trigger=trigger, method_name=method_name, 
kwargs=kwargs, timeout=timeout)
 
+    def map(self, **kwargs) -> "MappedOperator":
+        return MappedOperator(
+            operator_class=type(self),
+            operator=self,
+            task_id=self.task_id,
+            task_group=getattr(self, 'task_group', None),
+            dag=getattr(self, '_dag', None),
+            start_date=self.start_date,
+            end_date=self.end_date,
+            partial_kwargs=self.__init_kwargs,
+            mapped_kwargs=kwargs,
+        )
+
+
+def _validate_kwarg_names_for_mapping(cls: Type[BaseOperator], func_name: str, 
value: Dict[str, Any]):

Review comment:
       I wonder if `_validate_arg_names` and this have similar logic to be 
refactored out.

##########
File path: tests/serialization/test_dag_serialization.py
##########
@@ -1093,14 +1093,12 @@ def test_no_new_fields_added_to_base_operator(self):
         tests should be added for it.
         """
         base_operator = BaseOperator(task_id="10")
-        fields = base_operator.__dict__
+        fields = {k: v for (k, v) in vars(base_operator).items() if k in 
BaseOperator.get_serialized_fields()}

Review comment:
       Do we have a test to ensure `get_serialized_fields()` is up-to-date? 
Otherwise we may accidentally add a field on BaseOperator but forget to 
serialize it, and this test would fail to reflect that.

##########
File path: airflow/models/baseoperator.py
##########
@@ -1659,6 +1629,115 @@ def defer(
         """
         raise TaskDeferred(trigger=trigger, method_name=method_name, 
kwargs=kwargs, timeout=timeout)
 
+    def map(self, **kwargs) -> "MappedOperator":
+        return MappedOperator(
+            operator_class=type(self),
+            operator=self,
+            task_id=self.task_id,
+            task_group=getattr(self, 'task_group', None),
+            dag=getattr(self, '_dag', None),
+            start_date=self.start_date,
+            end_date=self.end_date,
+            partial_kwargs=self.__init_kwargs,
+            mapped_kwargs=kwargs,
+        )
+
+
+def _validate_kwarg_names_for_mapping(cls: Type[BaseOperator], func_name: str, 
value: Dict[str, Any]):
+    if isinstance(str, cls):
+        # Serialized version -- would have been validated at parse time
+        return
+
+    # use a dict so order of args is same as code order
+    unknown_args = value.copy()
+    for clazz in cls.mro():
+        # Mypy doesn't like doing `clas.__init__`, Error is: Cannot access 
"__init__" directly
+        init = clazz.__init__  # type: ignore
+
+        if not hasattr(init, '_BaseOperatorMeta__param_names'):
+            continue
+
+        for name in init._BaseOperatorMeta__param_names:
+            unknown_args.pop(name, None)
+
+        if not unknown_args:
+            # If we have no args left to check: stop looking at the MRO chain
+            return
+
+    if len(unknown_args) == 1:
+        raise TypeError(
+            f'{cls.__name__}.{func_name} got unexpected keyword argument 
{unknown_args.popitem()[0]!r}'
+        )
+    else:
+        names = ", ".join(repr(n) for n in unknown_args)
+        raise TypeError(f'{cls.__name__}.{func_name} got unexpected keyword 
arguments {names}')
+
+
[email protected](kw_only=True)
+class MappedOperator(DAGNode):
+    """Object representing a mapped operator in a DAG"""
+
+    operator_class: Type[BaseOperator] = attr.ib(repr=lambda c: c.__name__)

Review comment:
       We should have a validator to ensure `operator_class` and `operator` 
don’t disagree.

##########
File path: airflow/models/baseoperator.py
##########
@@ -196,16 +216,37 @@ def apply_defaults(self, *args: Any, **kwargs: Any) -> 
Any:
             self._BaseOperator__instantiated = True
             return result
 
+        apply_defaults.__non_optional_args = non_optional_args  # type: ignore
+        apply_defaults.__param_names = set(non_varaidc_params.keys())  # type: 
ignore

Review comment:
       ```suggestion
           apply_defaults.__param_names = set(non_varaidc_params)  # type: 
ignore
   ```

##########
File path: airflow/models/baseoperator.py
##########
@@ -1659,6 +1629,115 @@ def defer(
         """
         raise TaskDeferred(trigger=trigger, method_name=method_name, 
kwargs=kwargs, timeout=timeout)
 
+    def map(self, **kwargs) -> "MappedOperator":
+        return MappedOperator(
+            operator_class=type(self),
+            operator=self,
+            task_id=self.task_id,
+            task_group=getattr(self, 'task_group', None),
+            dag=getattr(self, '_dag', None),
+            start_date=self.start_date,
+            end_date=self.end_date,
+            partial_kwargs=self.__init_kwargs,
+            mapped_kwargs=kwargs,
+        )
+
+
+def _validate_kwarg_names_for_mapping(cls: Type[BaseOperator], func_name: str, 
value: Dict[str, Any]):
+    if isinstance(str, cls):
+        # Serialized version -- would have been validated at parse time
+        return
+
+    # use a dict so order of args is same as code order
+    unknown_args = value.copy()
+    for clazz in cls.mro():
+        # Mypy doesn't like doing `clas.__init__`, Error is: Cannot access 
"__init__" directly

Review comment:
       ```suggestion
           # Mypy doesn't like doing `class.__init__`, Error is: Cannot access 
"__init__" directly
   ```

##########
File path: airflow/models/baseoperator.py
##########
@@ -196,16 +216,37 @@ def apply_defaults(self, *args: Any, **kwargs: Any) -> 
Any:
             self._BaseOperator__instantiated = True
             return result
 
+        apply_defaults.__non_optional_args = non_optional_args  # type: ignore
+        apply_defaults.__param_names = set(non_varaidc_params.keys())  # type: 
ignore
+
         return cast(T, apply_defaults)
 
     def __new__(cls, name, bases, namespace, **kwargs):
         new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
+        try:
+            # Update the partial descriptor with the class method so it can 
call the actual function (but let
+            # subclasses override it if they need to)
+            partial_desc = vars(new_cls)['partial']

Review comment:
       How is this different from `new_cls.partial`?

##########
File path: airflow/models/dag.py
##########
@@ -2075,10 +2075,10 @@ def filter_task_group(group, parent_group):
         # the cut.
         subdag_task_groups = dag.task_group.get_task_group_dict()
         for group in subdag_task_groups.values():
-            group.upstream_group_ids = 
group.upstream_group_ids.intersection(subdag_task_groups.keys())
-            group.downstream_group_ids = 
group.downstream_group_ids.intersection(subdag_task_groups.keys())
-            group.upstream_task_ids = 
group.upstream_task_ids.intersection(dag.task_dict.keys())
-            group.downstream_task_ids = 
group.downstream_task_ids.intersection(dag.task_dict.keys())
+            
group.upstream_group_ids.intersection_update(subdag_task_groups.keys())
+            
group.downstream_group_ids.intersection_update(subdag_task_groups.keys())
+            group.upstream_task_ids.intersection_update(dag.task_dict.keys())
+            group.downstream_task_ids.intersection_update(dag.task_dict.keys())

Review comment:
       ```suggestion
               group.upstream_group_ids.intersection_update(subdag_task_groups)
               
group.downstream_group_ids.intersection_update(subdag_task_groups)
               group.upstream_task_ids.intersection_update(dag.task_dict)
               group.downstream_task_ids.intersection_update(dag.task_dict)
   ```

##########
File path: airflow/models/baseoperator.py
##########
@@ -1659,6 +1629,115 @@ def defer(
         """
         raise TaskDeferred(trigger=trigger, method_name=method_name, 
kwargs=kwargs, timeout=timeout)
 
+    def map(self, **kwargs) -> "MappedOperator":
+        return MappedOperator(
+            operator_class=type(self),
+            operator=self,
+            task_id=self.task_id,
+            task_group=getattr(self, 'task_group', None),
+            dag=getattr(self, '_dag', None),
+            start_date=self.start_date,
+            end_date=self.end_date,
+            partial_kwargs=self.__init_kwargs,
+            mapped_kwargs=kwargs,
+        )
+
+
+def _validate_kwarg_names_for_mapping(cls: Type[BaseOperator], func_name: str, 
value: Dict[str, Any]):
+    if isinstance(str, cls):
+        # Serialized version -- would have been validated at parse time
+        return
+
+    # use a dict so order of args is same as code order
+    unknown_args = value.copy()
+    for clazz in cls.mro():
+        # Mypy doesn't like doing `clas.__init__`, Error is: Cannot access 
"__init__" directly
+        init = clazz.__init__  # type: ignore
+
+        if not hasattr(init, '_BaseOperatorMeta__param_names'):
+            continue
+
+        for name in init._BaseOperatorMeta__param_names:
+            unknown_args.pop(name, None)
+
+        if not unknown_args:
+            # If we have no args left to check: stop looking at the MRO chain
+            return
+
+    if len(unknown_args) == 1:
+        raise TypeError(
+            f'{cls.__name__}.{func_name} got unexpected keyword argument 
{unknown_args.popitem()[0]!r}'
+        )
+    else:
+        names = ", ".join(repr(n) for n in unknown_args)
+        raise TypeError(f'{cls.__name__}.{func_name} got unexpected keyword 
arguments {names}')
+
+
[email protected](kw_only=True)
+class MappedOperator(DAGNode):
+    """Object representing a mapped operator in a DAG"""
+
+    operator_class: Type[BaseOperator] = attr.ib(repr=lambda c: c.__name__)
+    task_id: str
+    partial_kwargs: Dict[str, Any]
+    mapped_kwargs: Dict[str, Any] = attr.ib(
+        validator=lambda self, _, v: 
_validate_kwarg_names_for_mapping(self.operator_class, "map", v)
+    )
+    operator: Optional[BaseOperator] = None
+    dag: Optional["DAG"] = None
+    upstream_task_ids: Set[str] = attr.ib(factory=set, repr=False)
+    downstream_task_ids: Set[str] = attr.ib(factory=set, repr=False)
+
+    task_group: Optional["TaskGroup"] = attr.ib(repr=False)
+    # BaseOperator-like interface -- needed so we can add ourselves to the 
dag.tasks
+    start_date: Optional[pendulum.DateTime] = attr.ib(repr=False, default=None)
+    end_date: Optional[pendulum.DateTime] = attr.ib(repr=False, default=None)
+
+    def __attrs_post_init__(self):
+        if self.dag and self.operator:
+            # When BaseOperator() was called with a DAG, it would have been 
added straight away, but now we
+            # are mapped, we want to _remove_ that task (`self.operator`) from 
the dag
+            self.dag._remove_task(self.task_id)

Review comment:
       Do we need to do the same for task group? Or is it OK for the unmapped 
task to stay there?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to