This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1a1753c724 Support setting dependencies for tasks called outside 
TaskGroup context manager (#32351)
1a1753c724 is described below

commit 1a1753c7246a2b35b993aad659f5551afd7e0215
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Fri Jul 14 16:44:33 2023 +0100

    Support setting dependencies for tasks called outside TaskGroup context 
manager (#32351)
    
    * Support setting dependencies for tasks called outside TaskGroup context 
manager
    
    Currently, a classic operator must be instantiated (or a decorated
    operator called) inside the TaskGroup context manager for the task to
    be added to that group.
    For example, task1 and task2 below end up outside the group1 context:
    ```
       task1 = BashOperator(task_id="task1", bash_command="echo task1")
       task2 = BashOperator(task_id="task2", bash_command="echo task2")
       with TaskGroup('group1'):
            task1 >> task2
    ```
    With this change, setting dependencies between tasks inside the
    context manager adds those tasks to the active group, so task1 and
    task2 above end up inside the group1 context.
    For a single task, you can do:
    ```
       task1 = BashOperator(task_id="task1", bash_command="echo task1")
       with TaskGroup('group1') as scope:
            scope.add_task(task1)
    ```
    
    * make a refactor
    
    * fixup! make a refactor
    
    * apply suggestions from code review
    
    * Relax reference resolution type check
    
    * fixup! Relax reference resolution type check
    
    ---------
    
    Co-authored-by: Tzu-ping Chung <[email protected]>
---
 airflow/models/abstractoperator.py |  9 ++++
 airflow/models/taskmixin.py        | 45 ++++++++++++++------
 airflow/models/xcom_arg.py         | 10 +++++
 airflow/utils/edgemodifier.py      |  9 ++++
 airflow/utils/task_group.py        | 30 ++++++++++++-
 tests/utils/test_task_group.py     | 87 +++++++++++++++++++++++++++++++++++++-
 6 files changed, 173 insertions(+), 17 deletions(-)

diff --git a/airflow/models/abstractoperator.py 
b/airflow/models/abstractoperator.py
index 73d9204aab..9a8f88ce7d 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -51,6 +51,7 @@ if TYPE_CHECKING:
     from airflow.models.mappedoperator import MappedOperator
     from airflow.models.operator import Operator
     from airflow.models.taskinstance import TaskInstance
+    from airflow.utils.task_group import TaskGroup
 
 DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner")
 DEFAULT_POOL_SLOTS: int = 1
@@ -382,6 +383,14 @@ class AbstractOperator(Templater, DAGNode):
                 yield parent
             parent = parent.task_group
 
+    def add_to_taskgroup(self, task_group: TaskGroup) -> None:
+        """Add the task to the given task group.
+
+        :meta private:
+        """
+        if self.node_id not in task_group.children:
+            task_group.add(self)
+
     def get_closest_mapped_task_group(self) -> MappedTaskGroup | None:
         """Get the mapped task group "closest" to this task in the DAG.
 
diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py
index 658759a9b7..3a4f17974b 100644
--- a/airflow/models/taskmixin.py
+++ b/airflow/models/taskmixin.py
@@ -97,12 +97,14 @@ class DependencyMixin:
         """Implements Task << Task."""
         self.set_upstream(other)
         self.set_setup_teardown_ctx_dependencies(other)
+        self.set_taskgroup_ctx_dependencies(other)
         return other
 
     def __rshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
         """Implements Task >> Task."""
         self.set_downstream(other)
         self.set_setup_teardown_ctx_dependencies(other)
+        self.set_taskgroup_ctx_dependencies(other)
         return other
 
     def __rrshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
@@ -115,24 +117,39 @@ class DependencyMixin:
         self.__rshift__(other)
         return self
 
+    @abstractmethod
+    def add_to_taskgroup(self, task_group: TaskGroup) -> None:
+        """Add the task to the given task group."""
+        raise NotImplementedError()
+
+    @classmethod
+    def _iter_references(cls, obj: Any) -> Iterable[tuple[DependencyMixin, 
str]]:
+        from airflow.models.baseoperator import AbstractOperator
+        from airflow.utils.mixins import ResolveMixin
+
+        if isinstance(obj, AbstractOperator):
+            yield obj, "operator"
+        elif isinstance(obj, ResolveMixin):
+            yield from obj.iter_references()
+        elif isinstance(obj, Sequence):
+            for o in obj:
+                yield from cls._iter_references(o)
+
     def set_setup_teardown_ctx_dependencies(self, other: DependencyMixin | 
Sequence[DependencyMixin]):
         if not SetupTeardownContext.active:
             return
-        from airflow.models.xcom_arg import PlainXComArg
-
-        op1 = self
-        if isinstance(self, PlainXComArg):
-            op1 = self.operator
-        SetupTeardownContext.update_context_map(op1)
-        if isinstance(other, Sequence):
-            for op in other:
-                if isinstance(op, PlainXComArg):
-                    op = op.operator
-                SetupTeardownContext.update_context_map(op)
+        for op, _ in self._iter_references([self, other]):
+            SetupTeardownContext.update_context_map(op)
+
+    def set_taskgroup_ctx_dependencies(self, other: DependencyMixin | 
Sequence[DependencyMixin]):
+        from airflow.utils.task_group import TaskGroupContext
+
+        if not TaskGroupContext.active:
             return
-        if isinstance(other, PlainXComArg):
-            other = other.operator
-        SetupTeardownContext.update_context_map(other)
+        task_group = TaskGroupContext.get_current_task_group(None)
+        for op, _ in self._iter_references([self, other]):
+            if task_group:
+                op.add_to_taskgroup(task_group)
 
 
 class TaskMixin(DependencyMixin):
diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index c7778ac676..154878a4db 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -42,6 +42,7 @@ from airflow.utils.xcom import XCOM_RETURN_KEY
 if TYPE_CHECKING:
     from airflow.models.dag import DAG
     from airflow.models.operator import Operator
+    from airflow.utils.task_group import TaskGroup
 
 # Callable objects contained by MapXComArg. We only accept callables from
 # the user, but deserialize them into strings in a serialized XComArg for
@@ -207,6 +208,15 @@ class XComArg(ResolveMixin, DependencyMixin):
         """
         raise NotImplementedError()
 
+    def add_to_taskgroup(self, task_group: TaskGroup) -> None:
+        """Add the task to the given task group.
+
+        :meta private:
+        """
+        for op, _ in self.iter_references():
+            if op.node_id not in task_group.children:
+                task_group.add(op)
+
     def __enter__(self):
         if not self.operator.is_setup and not self.operator.is_teardown:
             raise AirflowException("Only setup/teardown tasks can be used as 
context managers.")
diff --git a/airflow/utils/edgemodifier.py b/airflow/utils/edgemodifier.py
index b693e1a1be..c6d8065c45 100644
--- a/airflow/utils/edgemodifier.py
+++ b/airflow/utils/edgemodifier.py
@@ -172,6 +172,15 @@ class EdgeModifier(DependencyMixin):
         """
         dag.set_edge_info(upstream_id, downstream_id, {"label": self.label})
 
+    def add_to_taskgroup(self, task_group: TaskGroup) -> None:
+        """No-op, since we're not a task.
+
+        We only add tasks to TaskGroups and not EdgeModifiers, but we need
+        this to satisfy the interface.
+
+        :meta private:
+        """
+
 
 # Factory functions
 def Label(label: str):
diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py
index d59ecaf40b..e55a4abbe1 100644
--- a/airflow/utils/task_group.py
+++ b/airflow/utils/task_group.py
@@ -207,13 +207,17 @@ class TaskGroup(DAGNode):
             else:
                 yield child
 
-    def add(self, task: DAGNode) -> None:
+    def add(self, task: DAGNode) -> DAGNode:
         """Add a task to this TaskGroup.
 
         :meta private:
         """
         from airflow.models.abstractoperator import AbstractOperator
 
+        if TaskGroupContext.active:
+            if task.task_group and task.task_group != self:
+                task.task_group.children.pop(task.node_id, None)
+                task.task_group = self
         existing_tg = task.task_group
         if isinstance(task, AbstractOperator) and existing_tg is not None and 
existing_tg != self:
             raise TaskAlreadyInTaskGroup(task.node_id, existing_tg.node_id, 
self.node_id)
@@ -237,6 +241,7 @@ class TaskGroup(DAGNode):
                 raise AirflowException("Cannot add a non-empty TaskGroup")
 
         self.children[key] = task
+        return task
 
     def _remove(self, task: DAGNode) -> None:
         key = task.node_id
@@ -543,6 +548,26 @@ class TaskGroup(DAGNode):
                         f"Encountered a DAGNode that is not a TaskGroup or an 
AbstractOperator: {type(child)}"
                     )
 
+    def add_task(self, task: AbstractOperator) -> None:
+        """Add a task to the task group.
+
+        :param task: the task to add
+        """
+        if not TaskGroupContext.active:
+            raise AirflowException(
+                "Using this method on a task group that's not a context 
manager is not supported."
+            )
+        task.add_to_taskgroup(self)
+
+    def add_to_taskgroup(self, task_group: TaskGroup) -> None:
+        """No-op, since we're not a task.
+
+        We only add tasks to TaskGroups and not TaskGroup, but we need
+        this to satisfy the interface.
+
+        :meta private:
+        """
+
 
 class MappedTaskGroup(TaskGroup):
     """A mapped task group.
@@ -613,6 +638,7 @@ class MappedTaskGroup(TaskGroup):
 class TaskGroupContext:
     """TaskGroup context is used to keep the current TaskGroup when TaskGroup 
is used as ContextManager."""
 
+    active: bool = False
     _context_managed_task_group: TaskGroup | None = None
     _previous_context_managed_task_groups: list[TaskGroup] = []
 
@@ -622,6 +648,7 @@ class TaskGroupContext:
         if cls._context_managed_task_group:
             
cls._previous_context_managed_task_groups.append(cls._context_managed_task_group)
         cls._context_managed_task_group = task_group
+        cls.active = True
 
     @classmethod
     def pop_context_managed_task_group(cls) -> TaskGroup | None:
@@ -631,6 +658,7 @@ class TaskGroupContext:
             cls._context_managed_task_group = 
cls._previous_context_managed_task_groups.pop()
         else:
             cls._context_managed_task_group = None
+        cls.active = False
         return old_task_group
 
     @classmethod
diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py
index c2baec25d3..4475e7141a 100644
--- a/tests/utils/test_task_group.py
+++ b/tests/utils/test_task_group.py
@@ -22,8 +22,8 @@ from datetime import timedelta
 import pendulum
 import pytest
 
-from airflow.decorators import dag, task_group as task_group_decorator
-from airflow.exceptions import TaskAlreadyInTaskGroup
+from airflow.decorators import dag, task as task_decorator, task_group as 
task_group_decorator
+from airflow.exceptions import AirflowException, TaskAlreadyInTaskGroup
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.dag import DAG
 from airflow.models.xcom_arg import XComArg
@@ -34,6 +34,20 @@ from airflow.utils.dag_edges import dag_edges
 from airflow.utils.task_group import TaskGroup, task_group_to_dict
 from tests.models import DEFAULT_DATE
 
+
+def make_task(name, type_="classic"):
+    if type_ == "classic":
+        return BashOperator(task_id=name, bash_command="echo 1")
+
+    else:
+
+        @task_decorator
+        def my_task():
+            pass
+
+        return my_task.override(task_id=name)()
+
+
 EXPECTED_JSON = {
     "id": None,
     "value": {
@@ -1465,3 +1479,72 @@ def test_task_group_arrow_with_setups_teardowns():
         tg1 >> w2
     assert t1.downstream_task_ids == set()
     assert w1.downstream_task_ids == {"tg1.t1", "w2"}
+
+
+def test_tasks_defined_outside_taskgrooup(dag_maker):
+    # Test that classic tasks defined outside a task group are added to the
+    # task group when the relationships are defined inside the task group
+    with dag_maker() as dag:
+        t1 = make_task("t1")
+        t2 = make_task("t2")
+        t3 = make_task("t3")
+        with TaskGroup(group_id="tg1"):
+            t1 >> t2 >> t3
+    dag.validate()
+    assert dag.task_group.children.keys() == {"tg1"}
+    assert dag.task_group.children["tg1"].children.keys() == {"t1", "t2", "t3"}
+    assert dag.task_group.children["tg1"].children["t1"].upstream_task_ids == 
set()
+    assert dag.task_group.children["tg1"].children["t1"].downstream_task_ids 
== {"t2"}
+    assert dag.task_group.children["tg1"].children["t2"].upstream_task_ids == 
{"t1"}
+    assert dag.task_group.children["tg1"].children["t2"].downstream_task_ids 
== {"t3"}
+    assert dag.task_group.children["tg1"].children["t3"].upstream_task_ids == 
{"t2"}
+    assert dag.task_group.children["tg1"].children["t3"].downstream_task_ids 
== set()
+
+    # Test that decorated tasks defined outside a task group are added to the
+    # task group when relationships are defined inside the task group
+    with dag_maker() as dag:
+        t1 = make_task("t1", type_="decorated")
+        t2 = make_task("t2", type_="decorated")
+        t3 = make_task("t3", type_="decorated")
+        with TaskGroup(group_id="tg1"):
+            t1 >> t2 >> t3
+    dag.validate()
+    assert dag.task_group.children.keys() == {"tg1"}
+    assert dag.task_group.children["tg1"].children.keys() == {"t1", "t2", "t3"}
+    assert dag.task_group.children["tg1"].children["t1"].upstream_task_ids == 
set()
+    assert dag.task_group.children["tg1"].children["t1"].downstream_task_ids 
== {"t2"}
+    assert dag.task_group.children["tg1"].children["t2"].upstream_task_ids == 
{"t1"}
+    assert dag.task_group.children["tg1"].children["t2"].downstream_task_ids 
== {"t3"}
+    assert dag.task_group.children["tg1"].children["t3"].upstream_task_ids == 
{"t2"}
+    assert dag.task_group.children["tg1"].children["t3"].downstream_task_ids 
== set()
+
+    # Test adding single decorated task defined outside a task group to a task 
group
+    with dag_maker() as dag:
+        t1 = make_task("t1", type_="decorated")
+        with TaskGroup(group_id="tg1") as tg1:
+            tg1.add_task(t1)
+    dag.validate()
+    assert dag.task_group.children.keys() == {"tg1"}
+    assert dag.task_group.children["tg1"].children.keys() == {"t1"}
+    assert dag.task_group.children["tg1"].children["t1"].upstream_task_ids == 
set()
+    assert dag.task_group.children["tg1"].children["t1"].downstream_task_ids 
== set()
+
+    # Test adding single classic task defined outside a task group to a task 
group
+    with dag_maker() as dag:
+        t1 = make_task("t1")
+        with TaskGroup(group_id="tg1") as tg1:
+            tg1.add_task(t1)
+    dag.validate()
+    assert dag.task_group.children.keys() == {"tg1"}
+    assert dag.task_group.children["tg1"].children.keys() == {"t1"}
+    assert dag.task_group.children["tg1"].children["t1"].upstream_task_ids == 
set()
+    assert dag.task_group.children["tg1"].children["t1"].downstream_task_ids 
== set()
+
+    with pytest.raises(
+        AirflowException,
+        match="Using this method on a task group that's not a context manager 
is not supported.",
+    ):
+        with dag_maker():
+            t1 = make_task("t1")
+            tg1 = TaskGroup(group_id="tg1")
+            tg1.add_task(t1)

Reply via email to