uranusjr commented on a change in pull request #19965:
URL: https://github.com/apache/airflow/pull/19965#discussion_r771497603



##########
File path: airflow/decorators/task_group.py
##########
@@ -20,16 +20,93 @@
 together when the DAG is displayed graphically.
 """
 import functools
+import warnings
 from inspect import signature
-from typing import Callable, Optional, TypeVar, cast
+from typing import Any, Callable, Dict, Generic, Optional, TypeVar, cast
 
-from airflow.utils.task_group import TaskGroup
+import attr
+
+from airflow.utils.task_group import MappedTaskGroup, TaskGroup
 
 T = TypeVar("T", bound=Callable)
+R = TypeVar("R")
 
 task_group_sig = signature(TaskGroup.__init__)
 
 
+@attr.define
+class TaskGroupDecorator(Generic[R]):
+    """:meta private:"""
+
+    function: Callable[..., R] = attr.ib(validator=attr.validators.is_callable())
+    kwargs: Dict[str, Any] = attr.ib(factory=dict)
+    """kwargs for the TaskGroup"""
+
+    @function.validator
+    def _validate_function(self, _, f):
+        if 'self' in signature(f).parameters:
+            raise TypeError('@task_group does not support methods')
+
+    @kwargs.validator
+    def _validate(self, _, kwargs):
+        task_group_sig.bind_partial(**kwargs)
+
+    def __attrs_post_init__(self):
+        self.kwargs.setdefault('group_id', self.function.__name__)
+
+    def _make_task_group(self, **kwargs) -> TaskGroup:
+        return TaskGroup(**kwargs)
+
+    def __call__(self, *args, **kwargs) -> R:
+        with self._make_task_group(add_suffix_on_collision=True, **self.kwargs):
+            # Invoke function to run Tasks inside the TaskGroup
+            return self.function(*args, **kwargs)
+
+    def partial(self, **kwargs) -> "MappedTaskGroupDecorator[R]":
+        return MappedTaskGroupDecorator(function=self.function, kwargs=self.kwargs).partial(**kwargs)
+
+    def map(self, **kwargs) -> R:
+        return MappedTaskGroupDecorator(function=self.function, kwargs=self.kwargs).map(**kwargs)
+
+
+@attr.define
+class MappedTaskGroupDecorator(TaskGroupDecorator[R]):
+    """:meta private:"""
+
+    partial_kwargs: Dict[str, Any] = attr.ib(factory=dict)
+    """static kwargs for the decorated function"""
+    mapped_kwargs: Dict[str, Any] = attr.ib(factory=dict)
+    """kwargs for the decorated function"""
+
+    _invoked: bool = attr.ib(init=False, default=False, repr=False)
+
+    def __call__(self, *args, **kwargs):
+        raise RuntimeError("Mapped @task_group's cannot be called. Use `.map` and `.partial` instead")
+
+    def _make_task_group(self, **kwargs) -> MappedTaskGroup:
+        tg = MappedTaskGroup(**kwargs)
+        tg.partial_kwargs = self.partial_kwargs
+        tg.mapped_kwargs = self.mapped_kwargs
+        return tg
+
+    def partial(self, **kwargs) -> "MappedTaskGroupDecorator[R]":
+        self.partial_kwargs.update(kwargs)
+        return self
+
+    def map(self, **kwargs) -> R:
+        self.mapped_kwargs.update(kwargs)
+
+        call_kwargs = self.partial_kwargs.copy()
+        call_kwargs.update({k: object() for k in self.mapped_kwargs})
+
+        self._invoked = True
+        return super().__call__(**call_kwargs)
+
+    def __del__(self):
+        if not self._invoked:
+            warnings.warn(f"Partial task group {self.function.__name__} was never mapped!")

Review comment:
       What would happen in this case, nothing? (i.e. this is likely a user 
mistake)




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to