This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1a1753c724 Support setting dependencies for tasks called outside
TaskGroup context manager (#32351)
1a1753c724 is described below
commit 1a1753c7246a2b35b993aad659f5551afd7e0215
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Fri Jul 14 16:44:33 2023 +0100
Support setting dependencies for tasks called outside TaskGroup context
manager (#32351)
* Support setting dependencies for tasks called outside TaskGroup context
manager
Currently, you must instantiate a classic operator or call a decorated
operator inside the TaskGroup context manager for it to be added to that
task group.
For example, task1 and task2 below will be outside the group1 context:
```
task1 = BashOperator(task_id="task1", bash_command="echo task1")
task2 = BashOperator(task_id="task2", bash_command="echo task2")
with TaskGroup('group1'):
    task1 >> task2
```
This PR changes that: when dependencies between such tasks are set inside
the context manager, the tasks are moved into the group1 context.
For a single task with no dependencies to set, use `add_task`:
```
task1 = BashOperator(task_id="task1", bash_command="echo task1")
with TaskGroup('group1') as scope:
    scope.add_task(task1)
```
* make a refactor
* fixup! make a refactor
* apply suggestions from code review
* Relax reference resolution type check
* fixup! Relax reference resolution type check
---------
Co-authored-by: Tzu-ping Chung <[email protected]>
---
airflow/models/abstractoperator.py | 9 ++++
airflow/models/taskmixin.py | 45 ++++++++++++++------
airflow/models/xcom_arg.py | 10 +++++
airflow/utils/edgemodifier.py | 9 ++++
airflow/utils/task_group.py | 30 ++++++++++++-
tests/utils/test_task_group.py | 87 +++++++++++++++++++++++++++++++++++++-
6 files changed, 173 insertions(+), 17 deletions(-)
diff --git a/airflow/models/abstractoperator.py
b/airflow/models/abstractoperator.py
index 73d9204aab..9a8f88ce7d 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -51,6 +51,7 @@ if TYPE_CHECKING:
from airflow.models.mappedoperator import MappedOperator
from airflow.models.operator import Operator
from airflow.models.taskinstance import TaskInstance
+ from airflow.utils.task_group import TaskGroup
DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner")
DEFAULT_POOL_SLOTS: int = 1
@@ -382,6 +383,14 @@ class AbstractOperator(Templater, DAGNode):
yield parent
parent = parent.task_group
+ def add_to_taskgroup(self, task_group: TaskGroup) -> None:
+ """Add the task to the given task group.
+
+ :meta private:
+ """
+ if self.node_id not in task_group.children:
+ task_group.add(self)
+
def get_closest_mapped_task_group(self) -> MappedTaskGroup | None:
"""Get the mapped task group "closest" to this task in the DAG.
diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py
index 658759a9b7..3a4f17974b 100644
--- a/airflow/models/taskmixin.py
+++ b/airflow/models/taskmixin.py
@@ -97,12 +97,14 @@ class DependencyMixin:
"""Implements Task << Task."""
self.set_upstream(other)
self.set_setup_teardown_ctx_dependencies(other)
+ self.set_taskgroup_ctx_dependencies(other)
return other
def __rshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
"""Implements Task >> Task."""
self.set_downstream(other)
self.set_setup_teardown_ctx_dependencies(other)
+ self.set_taskgroup_ctx_dependencies(other)
return other
def __rrshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
@@ -115,24 +117,39 @@ class DependencyMixin:
self.__rshift__(other)
return self
+ @abstractmethod
+ def add_to_taskgroup(self, task_group: TaskGroup) -> None:
+ """Add the task to the given task group."""
+ raise NotImplementedError()
+
+ @classmethod
+ def _iter_references(cls, obj: Any) -> Iterable[tuple[DependencyMixin,
str]]:
+ from airflow.models.baseoperator import AbstractOperator
+ from airflow.utils.mixins import ResolveMixin
+
+ if isinstance(obj, AbstractOperator):
+ yield obj, "operator"
+ elif isinstance(obj, ResolveMixin):
+ yield from obj.iter_references()
+ elif isinstance(obj, Sequence):
+ for o in obj:
+ yield from cls._iter_references(o)
+
def set_setup_teardown_ctx_dependencies(self, other: DependencyMixin |
Sequence[DependencyMixin]):
if not SetupTeardownContext.active:
return
- from airflow.models.xcom_arg import PlainXComArg
-
- op1 = self
- if isinstance(self, PlainXComArg):
- op1 = self.operator
- SetupTeardownContext.update_context_map(op1)
- if isinstance(other, Sequence):
- for op in other:
- if isinstance(op, PlainXComArg):
- op = op.operator
- SetupTeardownContext.update_context_map(op)
+ for op, _ in self._iter_references([self, other]):
+ SetupTeardownContext.update_context_map(op)
+
+ def set_taskgroup_ctx_dependencies(self, other: DependencyMixin |
Sequence[DependencyMixin]):
+ from airflow.utils.task_group import TaskGroupContext
+
+ if not TaskGroupContext.active:
return
- if isinstance(other, PlainXComArg):
- other = other.operator
- SetupTeardownContext.update_context_map(other)
+ task_group = TaskGroupContext.get_current_task_group(None)
+ for op, _ in self._iter_references([self, other]):
+ if task_group:
+ op.add_to_taskgroup(task_group)
class TaskMixin(DependencyMixin):
diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index c7778ac676..154878a4db 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -42,6 +42,7 @@ from airflow.utils.xcom import XCOM_RETURN_KEY
if TYPE_CHECKING:
from airflow.models.dag import DAG
from airflow.models.operator import Operator
+ from airflow.utils.task_group import TaskGroup
# Callable objects contained by MapXComArg. We only accept callables from
# the user, but deserialize them into strings in a serialized XComArg for
@@ -207,6 +208,15 @@ class XComArg(ResolveMixin, DependencyMixin):
"""
raise NotImplementedError()
+ def add_to_taskgroup(self, task_group: TaskGroup) -> None:
+ """Add the task to the given task group.
+
+ :meta private:
+ """
+ for op, _ in self.iter_references():
+ if op.node_id not in task_group.children:
+ task_group.add(op)
+
def __enter__(self):
if not self.operator.is_setup and not self.operator.is_teardown:
raise AirflowException("Only setup/teardown tasks can be used as
context managers.")
diff --git a/airflow/utils/edgemodifier.py b/airflow/utils/edgemodifier.py
index b693e1a1be..c6d8065c45 100644
--- a/airflow/utils/edgemodifier.py
+++ b/airflow/utils/edgemodifier.py
@@ -172,6 +172,15 @@ class EdgeModifier(DependencyMixin):
"""
dag.set_edge_info(upstream_id, downstream_id, {"label": self.label})
+ def add_to_taskgroup(self, task_group: TaskGroup) -> None:
+ """No-op, since we're not a task.
+
+ We only add tasks to TaskGroups and not EdgeModifiers, but we need
+ this to satisfy the interface.
+
+ :meta private:
+ """
+
# Factory functions
def Label(label: str):
diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py
index d59ecaf40b..e55a4abbe1 100644
--- a/airflow/utils/task_group.py
+++ b/airflow/utils/task_group.py
@@ -207,13 +207,17 @@ class TaskGroup(DAGNode):
else:
yield child
- def add(self, task: DAGNode) -> None:
+ def add(self, task: DAGNode) -> DAGNode:
"""Add a task to this TaskGroup.
:meta private:
"""
from airflow.models.abstractoperator import AbstractOperator
+ if TaskGroupContext.active:
+ if task.task_group and task.task_group != self:
+ task.task_group.children.pop(task.node_id, None)
+ task.task_group = self
existing_tg = task.task_group
if isinstance(task, AbstractOperator) and existing_tg is not None and
existing_tg != self:
raise TaskAlreadyInTaskGroup(task.node_id, existing_tg.node_id,
self.node_id)
@@ -237,6 +241,7 @@ class TaskGroup(DAGNode):
raise AirflowException("Cannot add a non-empty TaskGroup")
self.children[key] = task
+ return task
def _remove(self, task: DAGNode) -> None:
key = task.node_id
@@ -543,6 +548,26 @@ class TaskGroup(DAGNode):
f"Encountered a DAGNode that is not a TaskGroup or an
AbstractOperator: {type(child)}"
)
+ def add_task(self, task: AbstractOperator) -> None:
+ """Add a task to the task group.
+
+ :param task: the task to add
+ """
+ if not TaskGroupContext.active:
+ raise AirflowException(
+ "Using this method on a task group that's not a context
manager is not supported."
+ )
+ task.add_to_taskgroup(self)
+
+ def add_to_taskgroup(self, task_group: TaskGroup) -> None:
+ """No-op, since we're not a task.
+
+ We only add tasks to TaskGroups and not TaskGroup, but we need
+ this to satisfy the interface.
+
+ :meta private:
+ """
+
class MappedTaskGroup(TaskGroup):
"""A mapped task group.
@@ -613,6 +638,7 @@ class MappedTaskGroup(TaskGroup):
class TaskGroupContext:
"""TaskGroup context is used to keep the current TaskGroup when TaskGroup
is used as ContextManager."""
+ active: bool = False
_context_managed_task_group: TaskGroup | None = None
_previous_context_managed_task_groups: list[TaskGroup] = []
@@ -622,6 +648,7 @@ class TaskGroupContext:
if cls._context_managed_task_group:
cls._previous_context_managed_task_groups.append(cls._context_managed_task_group)
cls._context_managed_task_group = task_group
+ cls.active = True
@classmethod
def pop_context_managed_task_group(cls) -> TaskGroup | None:
@@ -631,6 +658,7 @@ class TaskGroupContext:
cls._context_managed_task_group =
cls._previous_context_managed_task_groups.pop()
else:
cls._context_managed_task_group = None
+ cls.active = False
return old_task_group
@classmethod
diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py
index c2baec25d3..4475e7141a 100644
--- a/tests/utils/test_task_group.py
+++ b/tests/utils/test_task_group.py
@@ -22,8 +22,8 @@ from datetime import timedelta
import pendulum
import pytest
-from airflow.decorators import dag, task_group as task_group_decorator
-from airflow.exceptions import TaskAlreadyInTaskGroup
+from airflow.decorators import dag, task as task_decorator, task_group as
task_group_decorator
+from airflow.exceptions import AirflowException, TaskAlreadyInTaskGroup
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DAG
from airflow.models.xcom_arg import XComArg
@@ -34,6 +34,20 @@ from airflow.utils.dag_edges import dag_edges
from airflow.utils.task_group import TaskGroup, task_group_to_dict
from tests.models import DEFAULT_DATE
+
+def make_task(name, type_="classic"):
+ if type_ == "classic":
+ return BashOperator(task_id=name, bash_command="echo 1")
+
+ else:
+
+ @task_decorator
+ def my_task():
+ pass
+
+ return my_task.override(task_id=name)()
+
+
EXPECTED_JSON = {
"id": None,
"value": {
@@ -1465,3 +1479,72 @@ def test_task_group_arrow_with_setups_teardowns():
tg1 >> w2
assert t1.downstream_task_ids == set()
assert w1.downstream_task_ids == {"tg1.t1", "w2"}
+
+
+def test_tasks_defined_outside_taskgrooup(dag_maker):
+ # Test that classic tasks defined outside a task group are added to the
root task group
+ # when the relationships are defined inside the task group
+ with dag_maker() as dag:
+ t1 = make_task("t1")
+ t2 = make_task("t2")
+ t3 = make_task("t3")
+ with TaskGroup(group_id="tg1"):
+ t1 >> t2 >> t3
+ dag.validate()
+ assert dag.task_group.children.keys() == {"tg1"}
+ assert dag.task_group.children["tg1"].children.keys() == {"t1", "t2", "t3"}
+ assert dag.task_group.children["tg1"].children["t1"].upstream_task_ids ==
set()
+ assert dag.task_group.children["tg1"].children["t1"].downstream_task_ids
== {"t2"}
+ assert dag.task_group.children["tg1"].children["t2"].upstream_task_ids ==
{"t1"}
+ assert dag.task_group.children["tg1"].children["t2"].downstream_task_ids
== {"t3"}
+ assert dag.task_group.children["tg1"].children["t3"].upstream_task_ids ==
{"t2"}
+ assert dag.task_group.children["tg1"].children["t3"].downstream_task_ids
== set()
+
+ # Test that decorated tasks defined outside a task group are added to the
root task group
+ # when relationships are defined inside the task group
+ with dag_maker() as dag:
+ t1 = make_task("t1", type_="decorated")
+ t2 = make_task("t2", type_="decorated")
+ t3 = make_task("t3", type_="decorated")
+ with TaskGroup(group_id="tg1"):
+ t1 >> t2 >> t3
+ dag.validate()
+ assert dag.task_group.children.keys() == {"tg1"}
+ assert dag.task_group.children["tg1"].children.keys() == {"t1", "t2", "t3"}
+ assert dag.task_group.children["tg1"].children["t1"].upstream_task_ids ==
set()
+ assert dag.task_group.children["tg1"].children["t1"].downstream_task_ids
== {"t2"}
+ assert dag.task_group.children["tg1"].children["t2"].upstream_task_ids ==
{"t1"}
+ assert dag.task_group.children["tg1"].children["t2"].downstream_task_ids
== {"t3"}
+ assert dag.task_group.children["tg1"].children["t3"].upstream_task_ids ==
{"t2"}
+ assert dag.task_group.children["tg1"].children["t3"].downstream_task_ids
== set()
+
+ # Test adding single decorated task defined outside a task group to a task
group
+ with dag_maker() as dag:
+ t1 = make_task("t1", type_="decorated")
+ with TaskGroup(group_id="tg1") as tg1:
+ tg1.add_task(t1)
+ dag.validate()
+ assert dag.task_group.children.keys() == {"tg1"}
+ assert dag.task_group.children["tg1"].children.keys() == {"t1"}
+ assert dag.task_group.children["tg1"].children["t1"].upstream_task_ids ==
set()
+ assert dag.task_group.children["tg1"].children["t1"].downstream_task_ids
== set()
+
+ # Test adding single classic task defined outside a task group to a task
group
+ with dag_maker() as dag:
+ t1 = make_task("t1")
+ with TaskGroup(group_id="tg1") as tg1:
+ tg1.add_task(t1)
+ dag.validate()
+ assert dag.task_group.children.keys() == {"tg1"}
+ assert dag.task_group.children["tg1"].children.keys() == {"t1"}
+ assert dag.task_group.children["tg1"].children["t1"].upstream_task_ids ==
set()
+ assert dag.task_group.children["tg1"].children["t1"].downstream_task_ids
== set()
+
+ with pytest.raises(
+ AirflowException,
+ match="Using this method on a task group that's not a context manager
is not supported.",
+ ):
+ with dag_maker():
+ t1 = make_task("t1")
+ tg1 = TaskGroup(group_id="tg1")
+ tg1.add_task(t1)