This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new 49c193f [AIP-34] TaskGroup: A UI task grouping concept as an
alternative to SubDagOperator (#10153)
49c193f is described below
commit 49c193fb872856500d8919facf45b9ab5207a093
Author: yuqian90 <[email protected]>
AuthorDate: Sat Sep 19 08:51:37 2020 +0800
[AIP-34] TaskGroup: A UI task grouping concept as an alternative to
SubDagOperator (#10153)
This commit introduces TaskGroup, which is a simple UI task grouping
concept.
- TaskGroups can be collapsed/expanded in Graph View when clicked
- TaskGroups can be nested
- TaskGroups can be put upstream/downstream of tasks or other TaskGroups
with >> and << operators
- Search box, hovering, focusing in Graph View treats TaskGroup properly.
E.g. searching for tasks also highlights TaskGroup that contains matching
task_id. When TaskGroup is expanded/collapsed, the affected TaskGroup is put in
focus and moved to the centre of the graph.
What this commit does not do:
- This commit does not change or remove SubDagOperator. Although TaskGroup
is intended as an alternative for SubDagOperator, deprecating SubDagOperator
will need to be discussed/implemented in the future.
- This PR only implemented TaskGroup handling in the Graph View. In places
such as Tree View, it will look as if
TaskGroup does not exist and all tasks are in the same flat DAG.
GitHub Issue: #8078
AIP:
https://cwiki.apache.org/confluence/display/AIRFLOW/AIP-34+TaskGroup%3A+A+UI+task+grouping+concept+as+an+alternative+to+SubDagOperator
---
airflow/example_dags/example_task_group.py | 57 +++
airflow/models/baseoperator.py | 34 +-
airflow/models/dag.py | 51 ++-
airflow/models/taskmixin.py | 11 +
airflow/models/xcom_arg.py | 5 +
airflow/serialization/enums.py | 1 +
airflow/serialization/schema.json | 48 ++-
airflow/serialization/serialized_objects.py | 92 +++++
airflow/utils/task_group.py | 379 +++++++++++++++++
airflow/www/static/css/graph.css | 14 +
airflow/www/templates/airflow/graph.html | 458 ++++++++++++++++-----
airflow/www/views.py | 186 +++++++--
docs/concepts.rst | 42 ++
docs/img/task_group.gif | Bin 0 -> 609981 bytes
tests/serialization/test_dag_serialization.py | 63 +++
tests/utils/test_task_group.py | 561 ++++++++++++++++++++++++++
16 files changed, 1857 insertions(+), 145 deletions(-)
diff --git a/airflow/example_dags/example_task_group.py
b/airflow/example_dags/example_task_group.py
new file mode 100644
index 0000000..17134df5
--- /dev/null
+++ b/airflow/example_dags/example_task_group.py
@@ -0,0 +1,57 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Example DAG demonstrating the usage of the TaskGroup."""
+
+from airflow.models.dag import DAG
+from airflow.operators.dummy_operator import DummyOperator
+from airflow.utils.dates import days_ago
+from airflow.utils.task_group import TaskGroup
+
+# [START howto_task_group]
+with DAG(dag_id="example_task_group", start_date=days_ago(2)) as dag:
+ start = DummyOperator(task_id="start")
+
+ # [START howto_task_group_section_1]
+ with TaskGroup("section_1", tooltip="Tasks for section_1") as section_1:
+ task_1 = DummyOperator(task_id="task_1")
+ task_2 = DummyOperator(task_id="task_2")
+ task_3 = DummyOperator(task_id="task_3")
+
+ task_1 >> [task_2, task_3]
+ # [END howto_task_group_section_1]
+
+ # [START howto_task_group_section_2]
+ with TaskGroup("section_2", tooltip="Tasks for section_2") as section_2:
+ task_1 = DummyOperator(task_id="task_1")
+
+ # [START howto_task_group_inner_section_2]
+ with TaskGroup("inner_section_2", tooltip="Tasks for inner_section2")
as inner_section_2:
+ task_2 = DummyOperator(task_id="task_2")
+ task_3 = DummyOperator(task_id="task_3")
+ task_4 = DummyOperator(task_id="task_4")
+
+ [task_2, task_3] >> task_4
+ # [END howto_task_group_inner_section_2]
+
+ # [END howto_task_group_section_2]
+
+ end = DummyOperator(task_id='end')
+
+ start >> section_1 >> section_2 >> end
+# [END howto_task_group]
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 4058f05..6d48a27 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -27,7 +27,8 @@ import warnings
from abc import ABCMeta, abstractmethod
from datetime import datetime, timedelta
from typing import (
- Any, Callable, ClassVar, Dict, FrozenSet, Iterable, List, Optional,
Sequence, Set, Tuple, Type, Union,
+ TYPE_CHECKING, Any, Callable, ClassVar, Dict, FrozenSet, Iterable, List,
Optional, Sequence, Set, Tuple,
+ Type, Union,
)
import attr
@@ -58,6 +59,9 @@ from airflow.utils.session import provide_session
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.weight_rule import WeightRule
+if TYPE_CHECKING:
+ from airflow.utils.task_group import TaskGroup # pylint:
disable=cyclic-import
+
ScheduleInterval = Union[str, timedelta, relativedelta]
TaskStateChangeCallback = Callable[[Context], None]
@@ -360,9 +364,12 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin,
metaclass=BaseOperatorMeta
do_xcom_push: bool = True,
inlets: Optional[Any] = None,
outlets: Optional[Any] = None,
+ task_group: Optional["TaskGroup"] = None,
**kwargs
):
from airflow.models.dag import DagContext
+ from airflow.utils.task_group import TaskGroupContext
+
super().__init__()
if kwargs:
if not conf.getboolean('operators', 'ALLOW_ILLEGAL_ARGUMENTS'):
@@ -382,6 +389,11 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin,
metaclass=BaseOperatorMeta
)
validate_key(task_id)
self.task_id = task_id
+ self.label = task_id
+ task_group = task_group or TaskGroupContext.get_current_task_group(dag)
+ if task_group:
+ self.task_id = task_group.child_id(task_id)
+ task_group.add(self)
self.owner = owner
self.email = email
self.email_on_retry = email_on_retry
@@ -609,7 +621,7 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin,
metaclass=BaseOperatorMeta
elif self.task_id in dag.task_dict and dag.task_dict[self.task_id] is
not self:
dag.add_task(self)
- self._dag = dag # pylint: disable=attribute-defined-outside-init
+ self._dag = dag
def has_dag(self):
"""
@@ -1120,21 +1132,25 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin,
metaclass=BaseOperatorMeta
"""Required by TaskMixin"""
return [self]
+ @property
+ def leaves(self) -> List["BaseOperator"]:
+ """Required by TaskMixin"""
+ return [self]
+
def _set_relatives(
self,
task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]],
upstream: bool = False,
) -> None:
"""Sets relatives for the task or task list."""
-
- if isinstance(task_or_task_list, Sequence):
- task_like_object_list = task_or_task_list
- else:
- task_like_object_list = [task_or_task_list]
+ if not isinstance(task_or_task_list, Sequence):
+ task_or_task_list = [task_or_task_list]
task_list: List["BaseOperator"] = []
- for task_object in task_like_object_list:
- task_list.extend(task_object.roots)
+ for task_object in task_or_task_list:
+ task_object.update_relative(self, not upstream)
+ relatives = task_object.leaves if upstream else task_object.roots
+ task_list.extend(relatives)
for task in task_list:
if not isinstance(task, BaseOperator):
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index eecf6b4..886837c 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -27,7 +27,9 @@ import traceback
import warnings
from collections import OrderedDict
from datetime import datetime, timedelta
-from typing import Callable, Collection, Dict, FrozenSet, Iterable, List,
Optional, Set, Type, Union, cast
+from typing import (
+ TYPE_CHECKING, Callable, Collection, Dict, FrozenSet, Iterable, List,
Optional, Set, Type, Union, cast,
+)
import jinja2
import pendulum
@@ -59,6 +61,9 @@ from airflow.utils.sqlalchemy import Interval, UtcDateTime
from airflow.utils.state import State
from airflow.utils.types import DagRunType
+if TYPE_CHECKING:
+ from airflow.utils.task_group import TaskGroup
+
log = logging.getLogger(__name__)
ScheduleInterval = Union[str, timedelta, relativedelta]
@@ -238,6 +243,8 @@ class DAG(BaseDag, LoggingMixin):
jinja_environment_kwargs: Optional[Dict] = None,
tags: Optional[List[str]] = None
):
+ from airflow.utils.task_group import TaskGroup
+
self.user_defined_macros = user_defined_macros
self.user_defined_filters = user_defined_filters
self.default_args = copy.deepcopy(default_args or {})
@@ -329,6 +336,7 @@ class DAG(BaseDag, LoggingMixin):
self.jinja_environment_kwargs = jinja_environment_kwargs
self.tags = tags
+ self._task_group = TaskGroup.create_root(self)
def __repr__(self):
return "<DAG: {self.dag_id}>".format(self=self)
@@ -571,6 +579,10 @@ class DAG(BaseDag, LoggingMixin):
return list(self.task_dict.keys())
@property
+ def task_group(self) -> "TaskGroup":
+ return self._task_group
+
+ @property
def filepath(self) -> str:
"""
File location of where the dag object is instantiated
@@ -1240,7 +1252,6 @@ class DAG(BaseDag, LoggingMixin):
based on a regex that should match one or many tasks, and includes
upstream and downstream neighbours based on the flag passed.
"""
-
# deep-copying self.task_dict takes a long time, and we don't want all
# the tasks anyway, so we copy the tasks manually later
task_dict = self.task_dict
@@ -1261,9 +1272,38 @@ class DAG(BaseDag, LoggingMixin):
# Make sure to not recursively deepcopy the dag while copying the task
dag.task_dict = {t.task_id: copy.deepcopy(t, {id(t.dag): dag})
for t in regex_match + also_include}
+
+ # Remove tasks not included in the subdag from task_group
+ def remove_excluded(group):
+ for child in list(group.children.values()):
+ if isinstance(child, BaseOperator):
+ if child.task_id not in dag.task_dict:
+ group.children.pop(child.task_id)
+ else:
+ # The tasks in the subdag are a copy of tasks in the
original dag
+ # so update the reference in the TaskGroups too.
+ group.children[child.task_id] =
dag.task_dict[child.task_id]
+ else:
+ remove_excluded(child)
+
+ # Remove this TaskGroup if it doesn't contain any tasks in
this subdag
+ if not child.children:
+ group.children.pop(child.group_id)
+
+ remove_excluded(dag.task_group)
+
+ # Removing upstream/downstream references to tasks and TaskGroups that
did not make
+ # the cut.
+ subdag_task_groups = dag.task_group.get_task_group_dict()
+ for group in subdag_task_groups.values():
+ group.upstream_group_ids =
group.upstream_group_ids.intersection(subdag_task_groups.keys())
+ group.downstream_group_ids =
group.downstream_group_ids.intersection(subdag_task_groups.keys())
+ group.upstream_task_ids =
group.upstream_task_ids.intersection(dag.task_dict.keys())
+ group.downstream_task_ids =
group.downstream_task_ids.intersection(dag.task_dict.keys())
+
for t in dag.tasks:
# Removing upstream/downstream references to tasks that did not
- # made the cut
+ # make the cut
t._upstream_task_ids =
t.upstream_task_ids.intersection(dag.task_dict.keys())
t._downstream_task_ids = t.downstream_task_ids.intersection(
dag.task_dict.keys())
@@ -1357,12 +1397,15 @@ class DAG(BaseDag, LoggingMixin):
elif task.end_date and self.end_date:
task.end_date = min(task.end_date, self.end_date)
- if task.task_id in self.task_dict and self.task_dict[task.task_id] is
not task:
+ if ((task.task_id in self.task_dict and self.task_dict[task.task_id]
is not task)
+ or task.task_id in self._task_group.used_group_ids):
raise DuplicateTaskIdFound(
"Task id '{}' has already been added to the
DAG".format(task.task_id))
else:
self.task_dict[task.task_id] = task
task.dag = self
+ # Add task_id to used_group_ids to prevent group_id and task_id
collisions.
+ self._task_group.used_group_ids.add(task.task_id)
self.task_count = len(self.task_dict)
diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py
index a3d4224..cfdc714 100644
--- a/airflow/models/taskmixin.py
+++ b/airflow/models/taskmixin.py
@@ -33,6 +33,11 @@ class TaskMixin:
"""Should return list of root operator List[BaseOperator]"""
raise NotImplementedError()
+ @property
+ def leaves(self):
+ """Should return list of leaf operator List[BaseOperator]"""
+ raise NotImplementedError()
+
@abstractmethod
def set_upstream(self, other: Union["TaskMixin", Sequence["TaskMixin"]]):
"""
@@ -47,6 +52,12 @@ class TaskMixin:
"""
raise NotImplementedError()
+ def update_relative(self, other: "TaskMixin", upstream=True) -> None:
+ """
+ Update relationship information about another TaskMixin. Default is
no-op.
+ Override if necessary.
+ """
+
def __lshift__(self, other: Union["TaskMixin", Sequence["TaskMixin"]]):
"""
Implements Task << Task
diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index 0f647bf..b9faaaba 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -103,6 +103,11 @@ class XComArg(TaskMixin):
return [self._operator]
@property
+ def leaves(self) -> List[BaseOperator]:
+ """Required by TaskMixin"""
+ return [self._operator]
+
+ @property
def key(self) -> str:
"""Returns keys of this XComArg"""
return self._key
diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py
index b6fae5e..9a30231 100644
--- a/airflow/serialization/enums.py
+++ b/airflow/serialization/enums.py
@@ -43,3 +43,4 @@ class DagAttributeTypes(str, Enum):
SET = 'set'
TUPLE = 'tuple'
POD = 'k8s.V1Pod'
+ TASK_GROUP = 'taskgroup'
diff --git a/airflow/serialization/schema.json
b/airflow/serialization/schema.json
index 49de949..9056eaa 100644
--- a/airflow/serialization/schema.json
+++ b/airflow/serialization/schema.json
@@ -96,7 +96,11 @@
"_default_view": { "type" : "string"},
"_access_control": {"$ref": "#/definitions/dict" },
"is_paused_upon_creation": { "type": "boolean" },
- "tags": { "type": "array" }
+ "tags": { "type": "array" },
+ "_task_group": {"anyOf": [
+ { "type": "null" },
+ { "$ref": "#/definitions/task_group" }
+ ]}
},
"required": [
"_dag_id",
@@ -125,6 +129,7 @@
"_task_module": { "type": "string" },
"_operator_extra_links": { "$ref": "#/definitions/extra_links" },
"task_id": { "type": "string" },
+ "label": { "type": "string" },
"owner": { "type": "string" },
"start_date": { "$ref": "#/definitions/datetime" },
"end_date": { "$ref": "#/definitions/datetime" },
@@ -156,6 +161,47 @@
}
},
"additionalProperties": true
+ },
+ "task_group": {
+ "$comment": "A TaskGroup containing tasks",
+ "type": "object",
+ "required": [
+ "_group_id",
+ "prefix_group_id",
+ "children",
+ "tooltip",
+ "ui_color",
+ "ui_fgcolor",
+ "upstream_group_ids",
+ "downstream_group_ids",
+ "upstream_task_ids",
+ "downstream_task_ids"
+ ],
+ "properties": {
+ "_group_id": {"anyOf": [{"type": "null"}, { "type": "string" }]},
+ "prefix_group_id": { "type": "boolean" },
+ "children": { "$ref": "#/definitions/dict" },
+ "tooltip": { "type": "string" },
+ "ui_color": { "type": "string" },
+ "ui_fgcolor": { "type": "string" },
+ "upstream_group_ids": {
+ "type": "array",
+ "items": { "type": "string" }
+ },
+ "downstream_group_ids": {
+ "type": "array",
+ "items": { "type": "string" }
+ },
+ "upstream_task_ids": {
+ "type": "array",
+ "items": { "type": "string" }
+ },
+ "downstream_task_ids": {
+ "type": "array",
+ "items": { "type": "string" }
+ }
+ },
+ "additionalProperties": false
}
},
diff --git a/airflow/serialization/serialized_objects.py
b/airflow/serialization/serialized_objects.py
index 6129a9d..41c6bc7 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -44,6 +44,7 @@ from airflow.serialization.json_schema import Validator,
load_dag_schema
from airflow.settings import json
from airflow.utils.code_utils import get_python_source
from airflow.utils.module_loading import import_string
+from airflow.utils.task_group import TaskGroup
log = logging.getLogger(__name__)
FAILED = 'serialization_failed'
@@ -221,6 +222,8 @@ class BaseSerialization:
# FIXME: casts tuple to list in customized serialization in
future.
return cls._encode(
[cls._serialize(v) for v in var], type_=DAT.TUPLE)
+ elif isinstance(var, TaskGroup):
+ return SerializedTaskGroup.serialize_task_group(var)
else:
log.debug('Cast type %s to str in serialization.', type(var))
return str(var)
@@ -376,6 +379,10 @@ class SerializedBaseOperator(BaseOperator,
BaseSerialization):
# Extra Operator Links defined in Plugins
op_extra_links_from_plugin = {}
+ if "label" not in encoded_op:
+ # Handle deserialization of old data before the introduction of
TaskGroup
+ encoded_op["label"] = encoded_op["task_id"]
+
for ope in plugins_manager.operator_extra_links:
for operator in ope.operators:
if operator.__name__ == encoded_op["_task_type"] and \
@@ -570,6 +577,7 @@ class SerializedDAG(DAG, BaseSerialization):
serialize_dag = cls.serialize_to_json(dag, cls._decorated_fields)
serialize_dag["tasks"] = [cls._serialize(task) for _, task in
dag.task_dict.items()]
+ serialize_dag['_task_group'] =
SerializedTaskGroup.serialize_task_group(dag.task_group)
return serialize_dag
@classmethod
@@ -598,6 +606,22 @@ class SerializedDAG(DAG, BaseSerialization):
setattr(dag, k, v)
+ # Set _task_group
+ # pylint: disable=protected-access
+ if "_task_group" in encoded_dag:
+ dag._task_group = SerializedTaskGroup.deserialize_task_group( #
type: ignore
+ encoded_dag["_task_group"],
+ None,
+ dag.task_dict
+ )
+ else:
+ # This must be old data that had no task_group. Create a root
TaskGroup and add
+ # all tasks to it.
+ dag._task_group = TaskGroup.create_root(dag)
+ for task in dag.tasks:
+ dag.task_group.add(task)
+ # pylint: enable=protected-access
+
keys_to_set_none = dag.get_serialized_fields() - encoded_dag.keys() -
cls._CONSTRUCTOR_PARAMS.keys()
for k in keys_to_set_none:
setattr(dag, k, None)
@@ -641,3 +665,71 @@ class SerializedDAG(DAG, BaseSerialization):
if ver != cls.SERIALIZER_VERSION:
raise ValueError("Unsure how to deserialize version
{!r}".format(ver))
return cls.deserialize_dag(serialized_obj['dag'])
+
+
+class SerializedTaskGroup(TaskGroup, BaseSerialization):
+ """
+ A JSON serializable representation of TaskGroup.
+ """
+ @classmethod
+ def serialize_task_group(cls, task_group: TaskGroup) ->
Optional[Union[Dict[str, Any]]]:
+ """
+ Serializes TaskGroup into a JSON object.
+ """
+ if not task_group:
+ return None
+
+ serialize_group = {
+ "_group_id": task_group._group_id, # pylint:
disable=protected-access
+ "prefix_group_id": task_group.prefix_group_id,
+ "tooltip": task_group.tooltip,
+ "ui_color": task_group.ui_color,
+ "ui_fgcolor": task_group.ui_fgcolor,
+ "children": {
+ label: (DAT.OP, child.task_id)
+ if isinstance(child, BaseOperator) else
+ (DAT.TASK_GROUP,
SerializedTaskGroup.serialize_task_group(child))
+ for label, child in task_group.children.items()
+ },
+ "upstream_group_ids":
cls._serialize(list(task_group.upstream_group_ids)),
+ "downstream_group_ids":
cls._serialize(list(task_group.downstream_group_ids)),
+ "upstream_task_ids":
cls._serialize(list(task_group.upstream_task_ids)),
+ "downstream_task_ids":
cls._serialize(list(task_group.downstream_task_ids)),
+
+ }
+
+ return serialize_group
+
+ @classmethod
+ def deserialize_task_group(
+ cls,
+ encoded_group: Dict[str, Any],
+ parent_group: Optional[TaskGroup],
+ task_dict: Dict[str, BaseOperator]
+ ) -> Optional[TaskGroup]:
+ """
+ Deserializes a TaskGroup from a JSON object.
+ """
+ if not encoded_group:
+ return None
+
+ group_id = cls._deserialize(encoded_group["_group_id"])
+ kwargs = {
+ key: cls._deserialize(encoded_group[key])
+ for key in ["prefix_group_id", "tooltip", "ui_color", "ui_fgcolor"]
+ }
+ group = SerializedTaskGroup(
+ group_id=group_id,
+ parent_group=parent_group,
+ **kwargs
+ )
+ group.children = {
+ label: task_dict[val] if _type == DAT.OP # type: ignore
+ else SerializedTaskGroup.deserialize_task_group(val, group,
task_dict) for label, (_type, val)
+ in encoded_group["children"].items()
+ }
+ group.upstream_group_ids =
set(cls._deserialize(encoded_group["upstream_group_ids"]))
+ group.downstream_group_ids =
set(cls._deserialize(encoded_group["downstream_group_ids"]))
+ group.upstream_task_ids =
set(cls._deserialize(encoded_group["upstream_task_ids"]))
+ group.downstream_task_ids =
set(cls._deserialize(encoded_group["downstream_task_ids"]))
+ return group
diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py
new file mode 100644
index 0000000..84cc540
--- /dev/null
+++ b/airflow/utils/task_group.py
@@ -0,0 +1,379 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+A TaskGroup is a collection of closely related tasks on the same DAG that
should be grouped
+together when the DAG is displayed graphically.
+"""
+
+from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence,
Set, Union
+
+from airflow.exceptions import AirflowException, DuplicateTaskIdFound
+from airflow.models.taskmixin import TaskMixin
+
+if TYPE_CHECKING:
+ from airflow.models.baseoperator import BaseOperator
+ from airflow.models.dag import DAG
+
+
+class TaskGroup(TaskMixin):
+ """
+ A collection of tasks. When set_downstream() or set_upstream() are called
on the
+ TaskGroup, it is applied across all tasks within the group if necessary.
+
+ :param group_id: a unique, meaningful id for the TaskGroup. group_id must
not conflict
+ with group_id of TaskGroup or task_id of tasks in the DAG. Root
TaskGroup has group_id
+ set to None.
+ :type group_id: str
+ :param prefix_group_id: If set to True, child task_id and group_id will be
prefixed with
+ this TaskGroup's group_id. If set to False, child task_id and group_id
are not prefixed.
+ Default is True.
+ :type prefix_group_id: bool
+ :param parent_group: The parent TaskGroup of this TaskGroup. parent_group
is set to None
+ for the root TaskGroup.
+ :type parent_group: TaskGroup
+ :param dag: The DAG that this TaskGroup belongs to.
+ :type dag: airflow.models.DAG
+ :param tooltip: The tooltip of the TaskGroup node when displayed in the UI
+ :type tooltip: str
+ :param ui_color: The fill color of the TaskGroup node when displayed in
the UI
+ :type ui_color: str
+ :param ui_fgcolor: The label color of the TaskGroup node when displayed in
the UI
+ :type ui_fgcolor: str
+ """
+
+ def __init__(
+ self,
+ group_id: Optional[str],
+ prefix_group_id: bool = True,
+ parent_group: Optional["TaskGroup"] = None,
+ dag: Optional["DAG"] = None,
+ tooltip: str = "",
+ ui_color: str = "CornflowerBlue",
+ ui_fgcolor: str = "#000",
+ ):
+ from airflow.models.dag import DagContext
+
+ self.prefix_group_id = prefix_group_id
+
+ if group_id is None:
+ # This creates a root TaskGroup.
+ if parent_group:
+ raise AirflowException("Root TaskGroup cannot have
parent_group")
+ # used_group_ids is shared across all TaskGroups in the same DAG
to keep track
+ # of used group_id to avoid duplication.
+ self.used_group_ids: Set[Optional[str]] = set()
+ self._parent_group = None
+ else:
+ if not isinstance(group_id, str):
+ raise ValueError("group_id must be str")
+ if not group_id:
+ raise ValueError("group_id must not be empty")
+
+ dag = dag or DagContext.get_current_dag()
+
+ if not parent_group and not dag:
+ raise AirflowException("TaskGroup can only be used inside a
dag")
+
+ self._parent_group = parent_group or
TaskGroupContext.get_current_task_group(dag)
+ if not self._parent_group:
+ raise AirflowException("TaskGroup must have a parent_group
except for the root TaskGroup")
+ self.used_group_ids = self._parent_group.used_group_ids
+
+ self._group_id = group_id
+ if self.group_id in self.used_group_ids:
+ raise DuplicateTaskIdFound(f"group_id '{self.group_id}' has
already been added to the DAG")
+ self.used_group_ids.add(self.group_id)
+ self.used_group_ids.add(self.downstream_join_id)
+ self.used_group_ids.add(self.upstream_join_id)
+ self.children: Dict[str, Union["BaseOperator", "TaskGroup"]] = {}
+ if self._parent_group:
+ self._parent_group.add(self)
+
+ self.tooltip = tooltip
+ self.ui_color = ui_color
+ self.ui_fgcolor = ui_fgcolor
+
+ # Keep track of TaskGroups or tasks that depend on this entire
TaskGroup separately
+ # so that we can optimize the number of edges when entire TaskGroups
depend on each other.
+ self.upstream_group_ids: Set[Optional[str]] = set()
+ self.downstream_group_ids: Set[Optional[str]] = set()
+ self.upstream_task_ids: Set[Optional[str]] = set()
+ self.downstream_task_ids: Set[Optional[str]] = set()
+
+ @classmethod
+ def create_root(cls, dag: "DAG") -> "TaskGroup":
+ """
+ Create a root TaskGroup with no group_id or parent.
+ """
+ return cls(group_id=None, dag=dag)
+
+ @property
+ def is_root(self) -> bool:
+ """
+ Returns True if this TaskGroup is the root TaskGroup. Otherwise False
+ """
+ return not self.group_id
+
+ def __iter__(self):
+ for child in self.children.values():
+ if isinstance(child, TaskGroup):
+ for inner_task in child:
+ yield inner_task
+ else:
+ yield child
+
+ def add(self, task: Union["BaseOperator", "TaskGroup"]) -> None:
+ """
+ Add a task to this TaskGroup.
+ """
+ key = task.group_id if isinstance(task, TaskGroup) else task.task_id
+
+ if key in self.children:
+ raise DuplicateTaskIdFound(f"Task id '{key}' has already been
added to the DAG")
+
+ if isinstance(task, TaskGroup):
+ if task.children:
+ raise AirflowException("Cannot add a non-empty TaskGroup")
+
+ self.children[key] = task # type: ignore
+
+ @property
+ def group_id(self) -> Optional[str]:
+ """
+ group_id of this TaskGroup.
+ """
+ if self._parent_group and self._parent_group.prefix_group_id and
self._parent_group.group_id:
+ return self._parent_group.child_id(self._group_id)
+
+ return self._group_id
+
+ @property
+ def label(self) -> Optional[str]:
+ """
+ group_id excluding parent's group_id used as the node label in UI.
+ """
+ return self._group_id
+
+ def update_relative(self, other: "TaskMixin", upstream=True) -> None:
+ """
+ Overrides TaskMixin.update_relative.
+
+ Update
upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids
+ accordingly so that we can reduce the number of edges when displaying
Graph View.
+ """
+ from airflow.models.baseoperator import BaseOperator
+
+ if isinstance(other, TaskGroup):
+ # Handles setting relationship between a TaskGroup and another
TaskGroup
+ if upstream:
+ parent, child = (self, other)
+ else:
+ parent, child = (other, self)
+
+ parent.upstream_group_ids.add(child.group_id)
+ child.downstream_group_ids.add(parent.group_id)
+ else:
+ # Handles setting relationship between a TaskGroup and a task
+ for task in other.roots:
+ if not isinstance(task, BaseOperator):
+ raise AirflowException("Relationships can only be set
between TaskGroup "
+ f"or operators; received
{task.__class__.__name__}")
+
+ if upstream:
+ self.upstream_task_ids.add(task.task_id)
+ else:
+ self.downstream_task_ids.add(task.task_id)
+
+ def _set_relative(
+ self,
+ task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]],
+ upstream: bool = False
+ ) -> None:
+ """
+ Call set_upstream/set_downstream for all root/leaf tasks within this
TaskGroup.
+ Update
upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids.
+ """
+ if upstream:
+ for task in self.get_roots():
+ task.set_upstream(task_or_task_list)
+ else:
+ for task in self.get_leaves():
+ task.set_downstream(task_or_task_list)
+
+ if not isinstance(task_or_task_list, Sequence):
+ task_or_task_list = [task_or_task_list]
+
+ for task_like in task_or_task_list:
+ self.update_relative(task_like, upstream)
+
+ def set_downstream(
+ self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]]
+ ) -> None:
+ """
+ Set a TaskGroup/task/list of tasks downstream of this TaskGroup.
+ """
+ self._set_relative(task_or_task_list, upstream=False)
+
+ def set_upstream(
+ self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]]
+ ) -> None:
+ """
+ Set a TaskGroup/task/list of tasks upstream of this TaskGroup.
+ """
+ self._set_relative(task_or_task_list, upstream=True)
+
+ def __enter__(self):
+ TaskGroupContext.push_context_managed_task_group(self)
+ return self
+
+ def __exit__(self, _type, _value, _tb):
+ TaskGroupContext.pop_context_managed_task_group()
+
+ def has_task(self, task: "BaseOperator") -> bool:
+ """
+ Returns True if this TaskGroup or its children TaskGroups contains the
given task.
+ """
+ if task.task_id in self.children:
+ return True
+
+ return any(child.has_task(task) for child in self.children.values() if
isinstance(child, TaskGroup))
+
+ @property
+ def roots(self) -> List["BaseOperator"]:
+ """Required by TaskMixin"""
+ return list(self.get_roots())
+
+ @property
+ def leaves(self) -> List["BaseOperator"]:
+ """Required by TaskMixin"""
+ return list(self.get_leaves())
+
+ def get_roots(self) -> Generator["BaseOperator", None, None]:
+ """
+ Returns a generator of tasks that are root tasks, i.e. those with no
upstream
+ dependencies within the TaskGroup.
+ """
+ for task in self:
+ if not any(self.has_task(parent) for parent in
task.get_direct_relatives(upstream=True)):
+ yield task
+
+ def get_leaves(self) -> Generator["BaseOperator", None, None]:
+ """
+ Returns a generator of tasks that are leaf tasks, i.e. those with no
downstream
+ dependencies within the TaskGroup
+ """
+ for task in self:
+ if not any(self.has_task(child) for child in
task.get_direct_relatives(upstream=False)):
+ yield task
+
+ def child_id(self, label):
+ """
+ Prefix label with group_id if prefix_group_id is True. Otherwise
return the label
+ as-is.
+ """
+ if self.prefix_group_id and self.group_id:
+ return f"{self.group_id}.{label}"
+
+ return label
+
+ @property
+ def upstream_join_id(self) -> str:
+ """
+ If this TaskGroup has immediate upstream TaskGroups or tasks, a dummy
node called
+ upstream_join_id will be created in Graph View to join the outgoing
edges from this
+ TaskGroup to reduce the total number of edges needed to be displayed.
+ """
+ return f"{self.group_id}.upstream_join_id"
+
+ @property
+ def downstream_join_id(self) -> str:
+ """
+ If this TaskGroup has immediate downstream TaskGroups or tasks, a
dummy node called
+ downstream_join_id will be created in Graph View to join the outgoing
edges from this
+ TaskGroup to reduce the total number of edges needed to be displayed.
+ """
+ return f"{self.group_id}.downstream_join_id"
+
+ def get_task_group_dict(self) -> Dict[str, "TaskGroup"]:
+ """
+ Returns a flat dictionary of group_id: TaskGroup
+ """
+ task_group_map = {}
+
+ def build_map(task_group):
+ if not isinstance(task_group, TaskGroup):
+ return
+
+ task_group_map[task_group.group_id] = task_group
+
+ for child in task_group.children.values():
+ build_map(child)
+
+ build_map(self)
+ return task_group_map
+
+ def get_child_by_label(self, label: str) -> Union["BaseOperator",
"TaskGroup"]:
+ """
+ Get a child task/TaskGroup by its label (i.e. task_id/group_id without
the group_id prefix)
+ """
+ return self.children[self.child_id(label)]
+
+
+class TaskGroupContext:
+ """
+ TaskGroup context is used to keep the current TaskGroup when TaskGroup is
used as ContextManager.
+ """
+
+ _context_managed_task_group: Optional[TaskGroup] = None
+ _previous_context_managed_task_groups: List[TaskGroup] = []
+
+ @classmethod
+ def push_context_managed_task_group(cls, task_group: TaskGroup):
+ """
+ Push a TaskGroup into the list of managed TaskGroups.
+ """
+ if cls._context_managed_task_group:
+
cls._previous_context_managed_task_groups.append(cls._context_managed_task_group)
+ cls._context_managed_task_group = task_group
+
+ @classmethod
+ def pop_context_managed_task_group(cls) -> Optional[TaskGroup]:
+ """
+ Pops the last TaskGroup from the list of managed TaskGroups and updates
the current TaskGroup.
+ """
+ old_task_group = cls._context_managed_task_group
+ if cls._previous_context_managed_task_groups:
+ cls._context_managed_task_group =
cls._previous_context_managed_task_groups.pop()
+ else:
+ cls._context_managed_task_group = None
+ return old_task_group
+
+ @classmethod
+ def get_current_task_group(cls, dag: Optional["DAG"]) ->
Optional[TaskGroup]:
+ """
+ Get the current TaskGroup.
+ """
+ from airflow.models.dag import DagContext
+
+ if not cls._context_managed_task_group:
+ dag = dag or DagContext.get_current_dag()
+ if dag:
+ # If there's currently a DAG but no TaskGroup, return the root
TaskGroup of the dag.
+ return dag.task_group
+
+ return cls._context_managed_task_group
diff --git a/airflow/www/static/css/graph.css b/airflow/www/static/css/graph.css
index 7f1b818..ee1d093 100644
--- a/airflow/www/static/css/graph.css
+++ b/airflow/www/static/css/graph.css
@@ -36,12 +36,26 @@ svg {
stroke-width: 1px;
}
+g.cluster rect {
+ stroke: white;
+ stroke-dasharray: 5;
+ rx: 5;
+ ry: 5;
+ opacity: 0.5;
+}
+
g.node rect {
stroke: #fff;
stroke-width: 3px;
cursor: pointer;
}
+g.node circle {
+ stroke: black;
+ stroke-width: 3px;
+ cursor: pointer;
+}
+
g.node .label {
font-size: inherit;
font-weight: normal;
diff --git a/airflow/www/templates/airflow/graph.html
b/airflow/www/templates/airflow/graph.html
index f4ec0b6..6d6d566 100644
--- a/airflow/www/templates/airflow/graph.html
+++ b/airflow/www/templates/airflow/graph.html
@@ -101,6 +101,8 @@
<script src="{{ url_for_asset('d3.min.js') }}"></script>
<script src="{{ url_for_asset('dagre-d3.min.js') }}"></script>
<script src="{{ url_for_asset('d3-tip.js') }}"></script>
+
+<script src="{{ url_for_asset('task-instances.js') }}"></script>
<script>
var highlight_color = "#000000";
@@ -113,6 +115,10 @@
var edges = {{ edges|tojson }};
var execution_date = "{{ execution_date }}";
var arrange = "{{ arrange }}";
+ var task_group_tips = get_task_group_tips(nodes);
+ // This maps the actual task_id to the current graph node id that contains
the task
+ // (because tasks may be grouped into a group node)
+ var map_task_to_node = new Map()
// Below variables are being used in dag.js
var tasks = {{ tasks|tojson }};
@@ -140,33 +146,94 @@
});
// Preparation of DagreD3 data structures
- var g = new dagreD3.graphlib.Graph().setGraph({
+ // "compound" is set to true to make use of clusters to display TaskGroup.
+ var g = new dagreD3.graphlib.Graph({compound: true}).setGraph({
nodesep: 15,
ranksep: 15,
rankdir: arrange,
})
.setDefaultEdgeLabel(function() { return { lineInterpolate: 'basis' } });
- // Set all nodes and styles
- nodes.forEach(function(node) {
- g.setNode(node.id, node.value)
- });
-
- // Set edges
- edges.forEach(function(edge) {
- g.setEdge(edge.source_id, edge.target_id);
- });
-
var render = dagreD3.render(),
svg = d3.select("svg"),
innerSvg = d3.select("svg g");
- innerSvg.call(render, g);
- innerSvg.call(taskTip);
+ // Update the page to show the latest DAG.
+ function draw() {
+ innerSvg.remove()
+ innerSvg = svg.append("g")
+ // Run the renderer. This is what draws the final graph.
+ innerSvg.call(render, g);
+ innerSvg.call(taskTip)
+
+ // When an expanded group is clicked, collapse it.
+ d3.selectAll("g.cluster").on("click", function (node_id) {
+ if (d3.event.defaultPrevented) // Ignore dragging actions.
+ return;
+ node = g.node(node_id)
+ collapse_group(node_id, node)
+ })
+ // When a node is clicked, action depends on the node type.
+ d3.selectAll("g.node").on("click", function (node_id) {
+ node = g.node(node_id)
+ if (node.children != undefined && Object.keys(node.children).length >
0) {
+ // A group node
+ if (d3.event.defaultPrevented) // Ignore dragging actions.
+ return;
+ expand_group(node_id, node)
+ } else if (node_id in tasks) {
+ // A task node
+ task = tasks[node_id];
+ if (node_id in task_instances)
+ try_number = task_instances[node_id].try_number;
+ else
+ try_number = 0;
+
+ if (task.task_type == "SubDagOperator")
+ call_modal(node_id, execution_date, task.extra_links, try_number,
true);
+ else
+ call_modal(node_id, execution_date, task.extra_links, try_number,
undefined);
+ } else {
+ // A join node between TaskGroups. Ignore.
+ }
+ });
+
+ d3.selectAll("g.node").on("mouseover", function (d) {
+ d3.select(this).selectAll("rect").style("stroke", highlight_color);
+ highlight_nodes(g.predecessors(d), upstream_color,
highlightStrokeWidth);
+ highlight_nodes(g.successors(d), downstream_color,
highlightStrokeWidth)
+ adjacent_node_names = [d, ...g.predecessors(d), ...g.successors(d)]
+ d3.selectAll("g.nodes g.node")
+ .filter(x => !adjacent_node_names.includes(x))
+ .style("opacity", 0.2);
+ adjacent_edges = g.nodeEdges(d)
+ d3.selectAll("g.edgePath")[0]
+ .filter(x => !adjacent_edges.includes(x.__data__))
+ .forEach(function (x) {
+ d3.select(x).style('opacity', .2)
+ })
+ });
+
+ d3.selectAll("g.node").on("mouseout", function (d) {
+ d3.select(this).selectAll("rect").style("stroke", null);
+ highlight_nodes(g.predecessors(d), null, initialStrokeWidth)
+ highlight_nodes(g.successors(d), null, initialStrokeWidth)
+ d3.selectAll("g.node")
+ .style("opacity", 1);
+ d3.selectAll("g.node rect")
+ .style("stroke-width", initialStrokeWidth);
+ d3.selectAll("g.edgePath")
+ .style("opacity", 1);
+ });
+ updateNodesStates(task_instances);
+ setUpZoomSupport();
+ }
+
+ var zoom = null;
function setUpZoomSupport() {
// Set up zoom support for Graph
- var zoom = d3.behavior.zoom().on("zoom", function() {
+ zoom = d3.behavior.zoom().on("zoom", function() {
innerSvg.attr("transform", "translate(" + d3.event.translate + ")"
+
"scale(" + d3.event.scale + ")");
});
@@ -193,61 +260,16 @@
zoom.event(innerSvg);
}
- setUpZoomSupport();
- inject_node_ids(tasks);
-
- d3.selectAll("g.node").on("click", function(d){
- task = tasks[d];
- if (d in task_instances)
- try_number = task_instances[d].try_number;
- else
- try_number = 0;
-
- if (task.task_type == "SubDagOperator")
- call_modal(d, execution_date, task.extra_links, try_number, true);
- else
- call_modal(d, execution_date, task.extra_links, try_number,
undefined);
- });
-
-
function highlight_nodes(nodes, color, stroke_width) {
nodes.forEach (function (nodeid) {
- my_node = d3.select('[id="' + nodeid + '"]').node().parentNode;
+ const my_node = g.node(nodeid).elem
d3.select(my_node)
- .selectAll("rect")
+ .selectAll("rect,circle")
.style("stroke", color)
.style("stroke-width", stroke_width) ;
})
}
- d3.selectAll("g.node").on("mouseover", function(d){
- d3.select(this).selectAll("rect").style("stroke", highlight_color) ;
- highlight_nodes(g.predecessors(d), upstream_color,
highlightStrokeWidth);
- highlight_nodes(g.successors(d), downstream_color,
highlightStrokeWidth)
- adjacent_node_names = [d, ...g.predecessors(d), ...g.successors(d)]
- d3.selectAll("g.nodes g.node")
- .filter(x => !adjacent_node_names.includes(x))
- .style("opacity", 0.2);
- adjacent_edges = g.nodeEdges(d)
- d3.selectAll("g.edgePath")[0]
- .filter(x => !adjacent_edges.includes(x.__data__))
- .forEach(function(x) {
- d3.select(x).style('opacity', .2)
- })
- });
-
- d3.selectAll("g.node").on("mouseout", function(d){
- d3.select(this).selectAll("rect").style("stroke", null) ;
- highlight_nodes(g.predecessors(d), null, initialStrokeWidth)
- highlight_nodes(g.successors(d), null, initialStrokeWidth)
- d3.selectAll("g.node")
- .style("opacity", 1);
- d3.selectAll("g.node rect")
- .style("stroke-width", initialStrokeWidth);
- d3.selectAll("g.edgePath")
- .style("opacity", 1);
- });
-
{% if blur %}
d3.selectAll("text").attr("class", "blur");
@@ -283,11 +305,31 @@
}
});
- d3.select("#searchbox").on("keyup", function(){
+ // Returns true if a node's id or its children's id matches search_text
+ function node_matches(node_id, search_text) {
+ if (node_id.indexOf(search_text) > -1)
+ return true;
+
+ // The node's own id does not match, it may have children that match
+ var node = g.node(node_id)
+ if (node.children != undefined) {
+ var children = get_children_ids(node);
+ for(const child of children) {
+ if(child.indexOf(search_text) > -1)
+ return true
+ }
+ }
+ }
+
+ d3.select("#searchbox").on("keyup", function() {
var s = document.getElementById("searchbox").value;
+
+ if(s == "")
+ return;
+
var match = null;
- if (stateIsSet){
+ if (stateIsSet()){
clearFocus();
setFocusMap();
}
@@ -307,7 +349,7 @@
d3.select("g.edgePaths")
.transition().duration(duration)
.style("opacity", 0.2);
- if (d.indexOf(s) > -1) {
+ if (node_matches(d, s)) {
if (!match)
match = this;
d3.select(this)
@@ -315,8 +357,7 @@
.style("opacity", 1)
.selectAll("rect")
.style("stroke-width", highlightStrokeWidth);
- }
- else {
+ } else {
d3.select(this)
.transition()
.style("opacity", 0.2).duration(duration)
@@ -326,37 +367,25 @@
}
});
- // This moves the matched node in the center of the graph area
- // ToDo: Should we keep this here as it has no added value
- // and does not fit the graph on small screens, and has to scroll
+ // This moves the matched node to the center of the graph area
if(match) {
var transform = d3.transform(d3.select(match).attr("transform"));
+
+ var svgBb = svg.node().getBoundingClientRect();
transform.translate = [
- -transform.translate[0] + 520,
- -(transform.translate[1] - 400)
+ svgBb.width / 2 - transform.translate[0],
+ svgBb.height / 2 - transform.translate[1]
];
transform.scale = [1, 1];
- d3.select("g.zoom")
- .transition()
- .attr("transform", transform.toString());
- innerSvg.attr("transform", "translate(" + transform.translate +
")" +
- "scale(1)");
+ if(zoom != null) {
+ zoom.translate(transform.translate);
+ zoom.scale(1);
+ zoom.event(innerSvg);
+ }
}
});
-
- // Injecting ids to be used for parent/child highlighting
- // Separated from updateNodeStates since it must work even
- // when there is no valid task instance available
- function inject_node_ids(tasks) {
- $.each(tasks, function(task_id, task) {
- $('tspan').filter(function(index) { return $(this).text() ===
task_id; })
- .parent().parent().parent()
- .attr("id", task_id);
- });
- }
-
function clearFocus(){
d3.selectAll("g.node")
.transition(duration)
@@ -413,8 +442,9 @@
$("div#svg_container").css("opacity", "0.2");
$.get(getTaskInstanceURL)
.done(
- (task_instances) => {
- updateNodesStates(JSON.parse(task_instances));
+ (tis) => {
+ task_instances = JSON.parse(tis)
+ updateNodesStates(task_instances);
$("#loading").hide();
$("div#svg_container").css("opacity", "1");
$('#error').hide();
@@ -429,31 +459,249 @@
});
}
+ // Generate tooltip for a group node
+ function group_tooltip(node_id, tis) {
+ var num_map = new Map([["success", 0],
+ ["failed", 0],
+ ["upstream_failed", 0],
+ ["up_for_retry", 0],
+ ["running", 0],
+ ["no_status", 0]]
+ );
+ for(const child of get_children_ids(g.node(node_id))) {
+ if(child in tis) {
+ const ti = tis[child];
+ const state_key = ti.state == null ? "no_status" : ti.state;
+ if(num_map.has(state_key))
+ num_map.set(state_key, num_map.get(state_key) + 1);
+ }
+ }
+
+ const tip = task_group_tips.get(node_id);
+ let tt = `${escapeHtml(tip)}<br><br>`;
+ for(const [key, val] of num_map.entries())
+ tt += `<strong>${escapeHtml(key)}:</strong> ${val} <br>`;
+
+ return tt;
+ }
+
+ // Build a map mapping node id to tooltip for all the TaskGroups.
+ function get_task_group_tips(node) {
+ var tips = new Map();
+ if(node.children != undefined) {
+ tips.set(node.id, node.tooltip);
+
+ for(const child of node.children.values()) {
+ for(const [key, val] of get_task_group_tips(child))
+ tips.set(key, val);
+ }
+ }
+ return tips;
+ }
+
// Assigning css classes based on state to nodes
// Initiating the tooltips
- function updateNodesStates(task_instances) {
- $.each(task_instances, (task_id, ti) => {
- $('tspan').filter((index, el) => {
- return $(el).text() === task_id;
- })
- .parent().parent().parent().parent()
- .attr("class", "node enter " + (ti.state ? ti.state : "no_status"))
- .attr("data-toggle", "tooltip")
- .on("mouseover", (evt) => {
- const task = tasks[task_id];
- const tt = tiTooltip(ti);
+ function updateNodesStates(tis) {
+ for(const node_id of g.nodes())
+ {
+ elem = g.node(node_id).elem;
+ elem.setAttribute("class", "node enter " + get_node_state(node_id,
tis));
+ elem.setAttribute("data-toggle", "tooltip");
+
+ const task_id = node_id;
+ elem.onmouseover = (evt) => {
+ if(task_id in tis) {
+ const tt = tiTooltip(tis[task_id]);
taskTip.show(tt, evt.target); // taskTip is defined in graph.html
+ } else if(task_group_tips.has(task_id)) {
+ const tt = group_tooltip(task_id, tis)
+ taskTip.show(tt, evt.target);
+ }
+ };
+ elem.onmouseout = taskTip.hide;
+ elem.onclick = taskTip.hide;
+ }
+ }
+
+
+ // Returns list of children id of the given task group
+ function get_children_ids(group) {
+ var children = []
+ for(const [key, val] of Object.entries(group.children)) {
+ if(val.children == undefined) {
+ // node
+ children.push(val.id)
+ } else {
+ // group
+ const sub_group_children = get_children_ids(val)
+ for(const id of sub_group_children) {
+ children.push(id)
+ }
+ }
+ }
+ return children
+ }
+
+
+ // Return the state for the node based on the state of its taskinstance or
that of its
+ // children if it's a group node
+ function get_node_state(node_id, tis) {
+ node = g.node(node_id)
+
+ if (node.children == undefined) {
+ if(node_id in tis)
+ return tis[node_id].state
+
+ return "no_status"
+ }
+ var children = get_children_ids(node)
+
+ children_states = new Set()
+ children.forEach(function(task_id) {
+ if (task_id in tis) {
+ var state = tis[task_id].state
+ children_states.add(state == null ? "no_status" : state)
+ }
+ })
+
+ // In this order, if any of these states appeared in children_states,
return it as
+ // the group state.
+ var priority = ["failed", "upstream_failed",
"up_for_retry","up_for_reschedule",
+ "queued", "no_status", "success", "skipped"]
+ for(const state of priority) {
+ if (children_states.has(state))
+ return state
+ }
+ return "no_status"
+ }
+
+ // Focus the graph on the expanded/collapsed node
+ function focus_group(node_id) {
+ if(node_id != null && zoom != null) {
+ const x = g.node(node_id).x;
+ const y = g.node(node_id).y;
+ // This is the total canvas size.
+ const svg_box = svg.node().getBoundingClientRect();
+ const width = svg_box.width;
+ const height = svg_box.height;
+
+ // This is the size of the node or the cluster (i.e. group)
+ var rect = d3.selectAll("g.node").filter(x => {return x ==
node_id}).select('rect');
+ if (rect.empty())
+ rect = d3.selectAll("g.cluster").filter(x => {return x ==
node_id}).select('rect');
+
+ // Is there a better way to get node_width and node_height ?
+ const [node_width, node_height] =
[rect[0][0].attributes.width.value, rect[0][0].attributes.height.value];
+
+ // Calculate zoom scale to fill most of the canvas with the
node/cluster in focus.
+ const scale = Math.min(
+ Math.min(width / node_width, height / node_height),
+ 1.5, // cap zoom level to 1.5 so nodes are not too large
+ ) * 0.9;
+
+ var [delta_x, delta_y] = [width / 2 - x * scale, height / 2 - y *
scale];
+ zoom.translate([delta_x, delta_y]);
+ zoom.scale(scale);
+ zoom.event(innerSvg.transition().duration(duration));
+
+ const children = new Set(g.children(node_id))
+ // Change opacity to highlight the focused group.
+ d3.selectAll("g.nodes g.node").filter(function(d, i){
+ if (d == node_id || children.has(d)) {
+ d3.select(this)
+ .transition().duration(duration)
+ .style("opacity", 1)
+ } else {
+ d3.select(this)
+ .transition()
+ .style("opacity", 0.2).duration(duration)
+ }
+ });
+ }
+ }
+
+ // Expands a group node
+ function expand_group(node_id, node) {
+ node.children.forEach(function (val) {
+ // Set children nodes
+ g.setNode(val.id, val.value)
+ map_task_to_node.set(val.id, val.id)
+ g.node(val.id).id = val.id
+ if (val.children != undefined) {
+ // Set children attribute so that the group can be expanded later
when needed.
+ group_node = g.node(val.id)
+ group_node.children = val.children
+ // Map tasks that are under this node to this node's id
+ for(const child_id of get_children_ids(val))
+ map_task_to_node.set(child_id, val.id)
+ }
+ // Only call setParent if node is not the root node.
+ if (node_id != null)
+ g.setParent(val.id, node_id)
+ })
+
+ // Add edges
+ edges.forEach(function(edge) {
+ source_id = map_task_to_node.get(edge.source_id)
+ target_id = map_task_to_node.get(edge.target_id)
+ if(source_id != target_id && !g.hasEdge(source_id, target_id))
+ g.setEdge(source_id, target_id)
+ })
+
+ g.edges().forEach(function (edge) {
+ // Remove edges that were associated with the expanded group node.
+ if(node_id == edge.v || node_id == edge.w)
+ g.removeEdge(edge.v, edge.w)
+ })
+
+ draw()
+ focus_group(node_id)
+ }
+
+ // Remove the node with this node_id from g.
+ function remove_node(node_id) {
+ if(g.hasNode(node_id)) {
+ node = g.node(node_id)
+ if(node.children != undefined) {
+ // If the child is an expanded group node, remove children too.
+ node.children.forEach(function (child) {
+ remove_node(child.id)
})
- .on('mouseout', taskTip.hide);
- });
+ }
}
+ g.removeNode(node_id)
+ }
+
+ // Collapse the children of the given group node.
+ function collapse_group(node_id, node) {
+ // Remove children nodes
+ node.children.forEach(function(child) {
+ remove_node(child.id)
+ })
+ // Map tasks that are under this node to this node's id
+ for(const child_id of get_children_ids(node))
+ map_task_to_node.set(child_id, node_id)
+
+ node = g.node(node_id)
+
+ // Set children edges onto the group edge
+ edges.forEach(function(edge) {
+ source_id = map_task_to_node.get(edge.source_id)
+ target_id = map_task_to_node.get(edge.target_id)
+ if(source_id != target_id && !g.hasEdge(source_id, target_id))
+ g.setEdge(source_id, target_id)
+ })
+
+ draw()
+ focus_group(node_id)
+ }
+
+ expand_group(null, nodes)
- updateNodesStates(task_instances);
initRefreshButton();
</script>
<script src="{{ url_for_asset('graph.js') }}"></script>
-<script src="{{ url_for_asset('task-instances.js') }}"></script>
{% endblock %}
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 07c8d2f..4c5e2b1 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -58,6 +58,7 @@ from airflow.executors.executor_loader import ExecutorLoader
from airflow.jobs.base_job import BaseJob
from airflow.jobs.scheduler_job import SchedulerJob
from airflow.models import Connection, DagModel, DagTag, Log, SlaMiss,
TaskFail, XCom, errors
+from airflow.models.baseoperator import BaseOperator
from airflow.models.dagcode import DagCode
from airflow.models.dagrun import DagRun, DagRunType
from airflow.models.taskinstance import TaskInstance
@@ -147,6 +148,163 @@ def
get_date_time_num_runs_dag_runs_form_data(www_request, session, dag):
}
def task_group_to_dict(task_group):
    """
    Create a nested dict representation of this TaskGroup and its children used to construct
    the Graph View.
    """
    if isinstance(task_group, BaseOperator):
        # Leaf node: an individual task.
        return {
            'id': task_group.task_id,
            'value': {
                'label': task_group.label,
                'labelStyle': f"fill:{task_group.ui_fgcolor};",
                'style': f"fill:{task_group.ui_color};",
                'rx': 5,
                'ry': 5,
            }
        }

    def join_node(join_id):
        # Circle-shaped proxy node used to reduce the number of edges drawn
        # between this TaskGroup and its neighbours.
        return {
            'id': join_id,
            'value': {
                'label': '',
                'labelStyle': f"fill:{task_group.ui_fgcolor};",
                'style': f"fill:{task_group.ui_color};",
                'shape': 'circle',
            }
        }

    sorted_children = sorted(task_group.children.values(), key=lambda t: t.label)
    children = [task_group_to_dict(child) for child in sorted_children]

    if task_group.upstream_group_ids or task_group.upstream_task_ids:
        children.append(join_node(task_group.upstream_join_id))

    if task_group.downstream_group_ids or task_group.downstream_task_ids:
        # This is the join node used to reduce the number of edges between two TaskGroups.
        children.append(join_node(task_group.downstream_join_id))

    return {
        'id': task_group.group_id,
        'value': {
            'label': task_group.label,
            'labelStyle': f"fill:{task_group.ui_fgcolor};",
            'style': f"fill:{task_group.ui_color}",
            'rx': 5,
            'ry': 5,
            'clusterLabelPos': 'top',
        },
        'tooltip': task_group.tooltip,
        'children': children,
    }
+
+
def dag_edges(dag):
    """
    Create the list of edges needed to construct the Graph View.

    A special case is made when a TaskGroup is immediately upstream/downstream of
    another TaskGroup or task. Two dummy nodes named upstream_join_id and
    downstream_join_id are created for the TaskGroup. Instead of drawing an edge
    onto every task in the TaskGroup, all edges are directed onto the dummy nodes.
    This cuts down the number of edges on the graph.

    For example: A DAG with TaskGroups group1 and group2:
        group1: task1, task2, task3
        group2: task4, task5, task6

    group2 is downstream of group1:
        group1 >> group2

    Edges to add (this avoids having to create edges between every task in group1 and group2):
        task1 >> downstream_join_id
        task2 >> downstream_join_id
        task3 >> downstream_join_id
        downstream_join_id >> upstream_join_id
        upstream_join_id >> task4
        upstream_join_id >> task5
        upstream_join_id >> task6
    """
    # Edges to add between TaskGroups (via the dummy join nodes).
    edges_to_add = set()
    # Edges between individual tasks that are superseded by edges_to_add.
    edges_to_skip = set()

    group_by_id = dag.task_group.get_task_group_dict()

    def collect_group_edges(task_group):
        """Update edges_to_add and edges_to_skip according to TaskGroups."""
        if isinstance(task_group, BaseOperator):
            return

        # For every TaskGroup immediately downstream, add an edge between
        # downstream_join_id and upstream_join_id; skip the edges between the
        # individual tasks of the two TaskGroups.
        for target_id in task_group.downstream_group_ids:
            target_group = group_by_id[target_id]
            edges_to_add.add((task_group.downstream_join_id, target_group.upstream_join_id))

            for leaf in task_group.get_leaves():
                edges_to_add.add((leaf.task_id, task_group.downstream_join_id))
                for target_root in target_group.get_roots():
                    edges_to_skip.add((leaf.task_id, target_root.task_id))
                edges_to_skip.add((leaf.task_id, target_group.upstream_join_id))

            for target_root in target_group.get_roots():
                edges_to_add.add((target_group.upstream_join_id, target_root.task_id))
                edges_to_skip.add((task_group.downstream_join_id, target_root.task_id))

        # For every individual task immediately downstream, add an edge from
        # downstream_join_id to it; skip the direct edges from this group's
        # leaf tasks to the downstream task.
        for target_id in task_group.downstream_task_ids:
            edges_to_add.add((task_group.downstream_join_id, target_id))

            for leaf in task_group.get_leaves():
                edges_to_add.add((leaf.task_id, task_group.downstream_join_id))
                edges_to_skip.add((leaf.task_id, target_id))

        # For every individual task immediately upstream, add an edge from it to
        # upstream_join_id; skip the direct edges from the upstream task to this
        # group's root tasks.
        for source_id in task_group.upstream_task_ids:
            edges_to_add.add((source_id, task_group.upstream_join_id))

            for root in task_group.get_roots():
                edges_to_add.add((task_group.upstream_join_id, root.task_id))
                edges_to_skip.add((source_id, root.task_id))

        for child in task_group.children.values():
            collect_group_edges(child)

    collect_group_edges(dag.task_group)

    # Collect all the edges between individual tasks with a depth-first walk,
    # memoising visited edges to stay linear in the number of edges.
    edges = set()

    def walk_downstream(task):
        for downstream in task.downstream_list:
            edge = (task.task_id, downstream.task_id)
            if edge not in edges:
                edges.add(edge)
                walk_downstream(downstream)

    for root in dag.roots:
        walk_downstream(root)

    return [{'source_id': source_id, 'target_id': target_id}
            for source_id, target_id in sorted(edges.union(edges_to_add) - edges_to_skip)]
+
+
######################################################################################
# Error handlers
######################################################################################
@@ -1608,32 +1766,8 @@ class Airflow(AirflowBaseView): # noqa: D101 pylint:
disable=too-many-public-m
arrange = request.args.get('arrange', dag.orientation)
- nodes = []
- edges = []
- for dag_task in dag.tasks:
- nodes.append({
- 'id': dag_task.task_id,
- 'value': {
- 'label': dag_task.task_id,
- 'labelStyle': "fill:{0};".format(dag_task.ui_fgcolor),
- 'style': "fill:{0};".format(dag_task.ui_color),
- 'rx': 5,
- 'ry': 5,
- }
- })
-
- def get_downstream(task):
- for downstream_task in task.downstream_list:
- edge = {
- 'source_id': task.task_id,
- 'target_id': downstream_task.task_id,
- }
- if edge not in edges:
- edges.append(edge)
- get_downstream(downstream_task)
-
- for dag_task in dag.roots:
- get_downstream(dag_task)
+ nodes = task_group_to_dict(dag.task_group)
+ edges = dag_edges(dag)
dt_nr_dr_data = get_date_time_num_runs_dag_runs_form_data(request,
session, dag)
dt_nr_dr_data['arrange'] = arrange
diff --git a/docs/concepts.rst b/docs/concepts.rst
index 5d2be97..8e37582 100644
--- a/docs/concepts.rst
+++ b/docs/concepts.rst
@@ -939,6 +939,48 @@ See ``airflow/example_dags`` for a demonstration.
Note that airflow pool is not honored by SubDagOperator. Hence resources could
be
consumed by SubdagOperators.
+
+TaskGroup
+=========
+TaskGroup can be used to organize tasks into hierarchical groups in Graph
View. It is
+useful for creating repeating patterns and cutting down visual clutter. Unlike
SubDagOperator,
+TaskGroup is a UI grouping concept. Tasks in TaskGroups live on the same
original DAG. They
+honor all the pool configurations.
+
+Dependency relationships can be applied across all tasks in a TaskGroup with
the ``>>`` and ``<<``
+operators. For example, the following code puts ``task1`` and ``task2`` in
TaskGroup ``group1``
+and then puts both tasks upstream of ``task3``:
+
+.. code-block:: python
+
+ with TaskGroup("group1") as group1:
+ task1 = DummyOperator(task_id="task1")
+ task2 = DummyOperator(task_id="task2")
+
+ task3 = DummyOperator(task_id="task3")
+
+ group1 >> task3
+
+.. note::
+ By default, child tasks and TaskGroups have their task_id and group_id
prefixed with the
+ group_id of their parent TaskGroup. This ensures uniqueness of group_id and
task_id throughout
+ the DAG. To disable the prefixing, pass ``prefix_group_id=False`` when
creating the TaskGroup.
+ This then gives the user full control over the actual group_id and task_id.
They have to ensure
+ group_id and task_id are unique throughout the DAG. The option
``prefix_group_id=False`` is
+ mainly useful for putting tasks on existing DAGs into TaskGroup without
altering their task_id.
+
+Here is a more complicated example DAG with multiple levels of nested
TaskGroups:
+
+.. exampleinclude:: /../airflow/example_dags/example_task_group.py
+ :language: python
+ :start-after: [START howto_task_group]
+ :end-before: [END howto_task_group]
+
+This animated gif shows the UI interactions. TaskGroups are expanded or
collapsed when clicked:
+
+.. image:: img/task_group.gif
+
+
SLAs
====
diff --git a/docs/img/task_group.gif b/docs/img/task_group.gif
new file mode 100644
index 0000000..ac4f6e9
Binary files /dev/null and b/docs/img/task_group.gif differ
diff --git a/tests/serialization/test_dag_serialization.py
b/tests/serialization/test_dag_serialization.py
index 75c35e8..1b3a993 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -67,6 +67,17 @@ serialized_simple_dag_ground_truth = {
}
},
"start_date": 1564617600.0,
+ '_task_group': {'_group_id': None,
+ 'prefix_group_id': True,
+ 'children': {'bash_task': ('operator', 'bash_task'),
+ 'custom_task': ('operator',
'custom_task')},
+ 'tooltip': '',
+ 'ui_color': 'CornflowerBlue',
+ 'ui_fgcolor': '#000',
+ 'upstream_group_ids': [],
+ 'downstream_group_ids': [],
+ 'upstream_task_ids': [],
+ 'downstream_task_ids': []},
"is_paused_upon_creation": False,
"_dag_id": "simple_dag",
"fileloc": None,
@@ -83,6 +94,7 @@ serialized_simple_dag_ground_truth = {
"ui_fgcolor": "#000",
"template_fields": ['bash_command', 'env'],
"bash_command": "echo {{ task.task_id }}",
+ 'label': 'bash_task',
"_task_type": "BashOperator",
"_task_module": "airflow.operators.bash",
"pool": "default_pool",
@@ -107,6 +119,7 @@ serialized_simple_dag_ground_truth = {
"_task_type": "CustomOperator",
"_task_module": "tests.test_utils.mock_operators",
"pool": "default_pool",
+ 'label': 'custom_task',
},
],
"timezone": "UTC",
@@ -329,6 +342,7 @@ class TestStringifiedDAGs(unittest.TestCase):
# Need to check fields in it, to exclude functions
'default_args',
+ "_task_group"
}
for field in fields_to_check:
assert getattr(serialized_dag, field) == getattr(dag, field), \
@@ -765,6 +779,7 @@ class TestStringifiedDAGs(unittest.TestCase):
'execution_timeout': None,
'executor_config': {},
'inlets': [],
+ 'label': '10',
'max_retry_delay': None,
'on_execute_callback': None,
'on_failure_callback': None,
@@ -804,3 +819,51 @@ class TestStringifiedDAGs(unittest.TestCase):
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
"""
)
+
+ def test_task_group_serialization(self):
+ """
+ Test TaskGroup serialization/deserialization.
+ """
+ from airflow.operators.dummy_operator import DummyOperator
+ from airflow.utils.task_group import TaskGroup
+
+ execution_date = datetime(2020, 1, 1)
+ with DAG("test_task_group_serialization", start_date=execution_date)
as dag:
+ task1 = DummyOperator(task_id="task1")
+ with TaskGroup("group234") as group234:
+ _ = DummyOperator(task_id="task2")
+
+ with TaskGroup("group34") as group34:
+ _ = DummyOperator(task_id="task3")
+ _ = DummyOperator(task_id="task4")
+
+ task5 = DummyOperator(task_id="task5")
+ task1 >> group234
+ group34 >> task5
+
+ dag_dict = SerializedDAG.to_dict(dag)
+ SerializedDAG.validate_schema(dag_dict)
+ json_dag = SerializedDAG.from_json(SerializedDAG.to_json(dag))
+ self.validate_deserialized_dag(json_dag, dag)
+
+ serialized_dag =
SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag))
+
+ assert serialized_dag.task_group.children
+ assert serialized_dag.task_group.children.keys() ==
dag.task_group.children.keys()
+
+ def check_task_group(node):
+ try:
+ children = node.children.values()
+ except AttributeError:
+ # Round-trip serialization and check the result
+ expected_serialized =
SerializedBaseOperator.serialize_operator(dag.get_task(node.task_id))
+ expected_deserialized =
SerializedBaseOperator.deserialize_operator(expected_serialized)
+ expected_dict =
SerializedBaseOperator.serialize_operator(expected_deserialized)
+ assert node
+ assert SerializedBaseOperator.serialize_operator(node) ==
expected_dict
+ return
+
+ for child in children:
+ check_task_group(child)
+
+ check_task_group(serialized_dag.task_group)
diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py
new file mode 100644
index 0000000..c4f7a12
--- /dev/null
+++ b/tests/utils/test_task_group.py
@@ -0,0 +1,561 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pendulum
+import pytest
+
+from airflow.models import DAG
+from airflow.operators.dummy_operator import DummyOperator
+from airflow.utils.task_group import TaskGroup
+from airflow.www.views import dag_edges, task_group_to_dict
+
# Expected output of task_group_to_dict() for the DAG built in
# test_build_task_group_context_manager / test_build_task_group:
#   task1 >> group234(task2, group34(task3, task4)); group34 >> task5
# TaskGroups render as clusters (rx/ry/clusterLabelPos/tooltip); plain tasks
# as simple nodes. The circular '*_join_id' nodes appear to correspond to the
# group-level edges (task1 >> group234, group34 >> task5) — they act as the
# group's single entry/exit point in the rendered graph.
EXPECTED_JSON = {
    # Root TaskGroup: no id/label.
    'id': None,
    'value': {
        'label': None,
        'labelStyle': 'fill:#000;',
        'style': 'fill:CornflowerBlue',
        'rx': 5,
        'ry': 5,
        'clusterLabelPos': 'top',
    },
    'tooltip': '',
    'children': [
        {
            'id': 'group234',
            'value': {
                'label': 'group234',
                'labelStyle': 'fill:#000;',
                'style': 'fill:CornflowerBlue',
                'rx': 5,
                'ry': 5,
                'clusterLabelPos': 'top',
            },
            'tooltip': '',
            'children': [
                {
                    # Nested group: id is prefixed with the parent group id.
                    'id': 'group234.group34',
                    'value': {
                        'label': 'group34',
                        'labelStyle': 'fill:#000;',
                        'style': 'fill:CornflowerBlue',
                        'rx': 5,
                        'ry': 5,
                        'clusterLabelPos': 'top',
                    },
                    'tooltip': '',
                    'children': [
                        {
                            'id': 'group234.group34.task3',
                            'value': {
                                'label': 'task3',
                                'labelStyle': 'fill:#000;',
                                'style': 'fill:#e8f7e4;',
                                'rx': 5,
                                'ry': 5,
                            },
                        },
                        {
                            'id': 'group234.group34.task4',
                            'value': {
                                'label': 'task4',
                                'labelStyle': 'fill:#000;',
                                'style': 'fill:#e8f7e4;',
                                'rx': 5,
                                'ry': 5,
                            },
                        },
                        {
                            # Exit proxy node for group34 >> task5.
                            'id': 'group234.group34.downstream_join_id',
                            'value': {
                                'label': '',
                                'labelStyle': 'fill:#000;',
                                'style': 'fill:CornflowerBlue;',
                                'shape': 'circle',
                            },
                        },
                    ],
                },
                {
                    'id': 'group234.task2',
                    'value': {
                        'label': 'task2',
                        'labelStyle': 'fill:#000;',
                        'style': 'fill:#e8f7e4;',
                        'rx': 5,
                        'ry': 5,
                    },
                },
                {
                    # Entry proxy node for task1 >> group234.
                    'id': 'group234.upstream_join_id',
                    'value': {
                        'label': '',
                        'labelStyle': 'fill:#000;',
                        'style': 'fill:CornflowerBlue;',
                        'shape': 'circle',
                    },
                },
            ],
        },
        {
            'id': 'task1',
            'value': {
                'label': 'task1',
                'labelStyle': 'fill:#000;',
                'style': 'fill:#e8f7e4;',
                'rx': 5,
                'ry': 5,
            },
        },
        {
            'id': 'task5',
            'value': {
                'label': 'task5',
                'labelStyle': 'fill:#000;',
                'style': 'fill:#e8f7e4;',
                'rx': 5,
                'ry': 5,
            },
        },
    ],
}
+
+
def test_build_task_group_context_manager():
    """
    Test TaskGroup construction via the ``with`` context-manager syntax:
    nesting, task_id prefixing, setting dependencies on a whole group with
    ``>>``, and the resulting ``task_group_to_dict()`` output.
    """
    execution_date = pendulum.parse("20200101")
    with DAG("test_build_task_group_context_manager", start_date=execution_date) as dag:
        task1 = DummyOperator(task_id="task1")
        with TaskGroup("group234") as group234:
            _ = DummyOperator(task_id="task2")

            with TaskGroup("group34") as group34:
                _ = DummyOperator(task_id="task3")
                _ = DummyOperator(task_id="task4")

        task5 = DummyOperator(task_id="task5")
        task1 >> group234
        group34 >> task5

    # task1 >> group234 fans out to every task in the group, including those
    # in the nested group34; group34 >> task5 fans in from group34's tasks.
    assert task1.get_direct_relative_ids(upstream=False) == {
        'group234.group34.task4',
        'group234.group34.task3',
        'group234.task2',
    }
    assert task5.get_direct_relative_ids(upstream=True) == {
        'group234.group34.task4',
        'group234.group34.task3',
    }

    # The DAG's implicit root TaskGroup has no group_id and holds the
    # top-level tasks and groups as its children.
    assert dag.task_group.group_id is None
    assert dag.task_group.is_root
    assert set(dag.task_group.children.keys()) == {"task1", "group234", "task5"}
    # Nested group ids are dot-prefixed with the parent group id.
    assert group34.group_id == "group234.group34"

    assert task_group_to_dict(dag.task_group) == EXPECTED_JSON
+
+
def test_build_task_group():
    """
    Build the same DAG as the context-manager test, but by passing ``dag``,
    ``task_group`` and ``parent_group`` arguments explicitly instead of using
    ``with`` blocks. The resulting structure must be identical.
    """
    start = pendulum.parse("20200101")
    dag = DAG("test_build_task_group", start_date=start)

    entry = DummyOperator(task_id="task1", dag=dag)
    outer_group = TaskGroup("group234", dag=dag)
    DummyOperator(task_id="task2", dag=dag, task_group=outer_group)
    inner_group = TaskGroup("group34", dag=dag, parent_group=outer_group)
    DummyOperator(task_id="task3", dag=dag, task_group=inner_group)
    DummyOperator(task_id="task4", dag=dag, task_group=inner_group)
    final = DummyOperator(task_id="task5", dag=dag)

    entry >> outer_group
    inner_group >> final

    # Must serialize to exactly the same graph dict as the context-manager
    # variant of this DAG.
    assert task_group_to_dict(dag.task_group) == EXPECTED_JSON
+
+
def extract_node_id(node, include_label=False):
    """
    Reduce a node dict produced by ``task_group_to_dict`` to a minimal tree
    for compact structural comparison in tests.

    :param node: a dict with ``"id"``, ``"value"`` (holding ``"label"``) and
        optionally ``"children"`` keys.
    :param include_label: if True, also copy ``node["value"]["label"]`` into
        the result of every node.
    :return: a dict with only ``id`` (and optionally ``label``), recursing
        into ``children`` when present; a ``children`` key is emitted exactly
        when the input node has one (even if empty).
    """
    ret = {"id": node["id"]}
    if include_label:
        ret["label"] = node["value"]["label"]
    if "children" in node:
        # Comprehension instead of a manual append loop (same order, same
        # recursion) — clearer and avoids the loop-and-append idiom.
        ret["children"] = [
            extract_node_id(child, include_label=include_label) for child in node["children"]
        ]
    return ret
+
+
def test_build_task_group_with_prefix():
    """
    Tests that prefix_group_id turns on/off prefixing of task_id with group_id.

    ``prefix_group_id=False`` stops the dot-prefixing at that group: its own
    children keep bare ids, and the prefix chain restarts from any nested
    group that re-enables it.
    """
    execution_date = pendulum.parse("20200101")
    with DAG("test_build_task_group_with_prefix", start_date=execution_date) as dag:
        task1 = DummyOperator(task_id="task1")
        with TaskGroup("group234", prefix_group_id=False) as group234:
            task2 = DummyOperator(task_id="task2")

            with TaskGroup("group34") as group34:
                task3 = DummyOperator(task_id="task3")

                with TaskGroup("group4", prefix_group_id=False) as group4:
                    task4 = DummyOperator(task_id="task4")

        task5 = DummyOperator(task_id="task5")
        task1 >> group234
        group34 >> task5

    # group234 has prefix_group_id=False, so task2 and group34 are unprefixed;
    # group34 prefixes its own children; group4 again suppresses prefixing.
    assert task2.task_id == "task2"
    assert group34.group_id == "group34"
    assert task3.task_id == "group34.task3"
    assert group4.group_id == "group34.group4"
    assert task4.task_id == "task4"
    assert task5.task_id == "task5"
    # Lookup by label works regardless of the id-prefixing setting.
    assert group234.get_child_by_label("task2") == task2
    assert group234.get_child_by_label("group34") == group34
    assert group4.get_child_by_label("task4") == task4

    assert extract_node_id(task_group_to_dict(dag.task_group), include_label=True) == {
        'id': None,
        'label': None,
        'children': [
            {
                'id': 'group234',
                'label': 'group234',
                'children': [
                    {
                        'id': 'group34',
                        'label': 'group34',
                        'children': [
                            {
                                'id': 'group34.group4',
                                'label': 'group4',
                                'children': [{'id': 'task4', 'label': 'task4'}],
                            },
                            {'id': 'group34.task3', 'label': 'task3'},
                            {'id': 'group34.downstream_join_id', 'label': ''},
                        ],
                    },
                    {'id': 'task2', 'label': 'task2'},
                    {'id': 'group234.upstream_join_id', 'label': ''},
                ],
            },
            {'id': 'task1', 'label': 'task1'},
            {'id': 'task5', 'label': 'task5'},
        ],
    }
+
+
def test_build_task_group_with_task_decorator():
    """
    Test that TaskGroup can be used with the @task decorator.
    """
    from airflow.operators.python import task

    @task
    def task_1():
        print("task_1")

    @task
    def task_2():
        return "task_2"

    @task
    def task_3():
        return "task_3"

    @task
    def task_4(task_2_output, task_3_output):
        print(task_2_output, task_3_output)

    @task
    def task_5():
        print("task_5")

    execution_date = pendulum.parse("20200101")
    with DAG("test_build_task_group_with_task_decorator", start_date=execution_date) as dag:
        tsk_1 = task_1()

        with TaskGroup("group234") as group234:
            tsk_2 = task_2()
            tsk_3 = task_3()
            # task_4 consumes the outputs of task_2/task_3, which wires the
            # data dependency edges inside the group.
            tsk_4 = task_4(tsk_2, tsk_3)

        tsk_5 = task_5()

        # Chaining through the group: upstream of all group tasks, downstream
        # of all group tasks.
        tsk_1 >> group234 >> tsk_5

    # The decorated calls return wrapper objects exposing the created
    # operator via ``.operator`` (hence the pylint no-member suppression).
    # pylint: disable=no-member
    assert tsk_1.operator in tsk_2.operator.upstream_list
    assert tsk_1.operator in tsk_3.operator.upstream_list
    assert tsk_5.operator in tsk_4.operator.downstream_list
    # pylint: enable=no-member

    assert extract_node_id(task_group_to_dict(dag.task_group)) == {
        'id': None,
        'children': [
            {
                'id': 'group234',
                'children': [
                    {'id': 'group234.task_2'},
                    {'id': 'group234.task_3'},
                    {'id': 'group234.task_4'},
                    {'id': 'group234.upstream_join_id'},
                    {'id': 'group234.downstream_join_id'},
                ],
            },
            {'id': 'task_1'},
            {'id': 'task_5'},
        ],
    }

    # Group-level edges are routed through the join proxy nodes.
    edges = dag_edges(dag)
    assert sorted((e["source_id"], e["target_id"]) for e in edges) == [
        ('group234.downstream_join_id', 'task_5'),
        ('group234.task_2', 'group234.task_4'),
        ('group234.task_3', 'group234.task_4'),
        ('group234.task_4', 'group234.downstream_join_id'),
        ('group234.upstream_join_id', 'group234.task_2'),
        ('group234.upstream_join_id', 'group234.task_3'),
        ('task_1', 'group234.upstream_join_id'),
    ]
+
+
def test_sub_dag_task_group():
    """
    Tests dag.sub_dag() updates task_group correctly.

    Filters the DAG down to task5 and its upstream tasks, then verifies that
    the task_group tree and the group-level relationship sets are pruned to
    only the surviving tasks/groups.
    """
    execution_date = pendulum.parse("20200101")
    with DAG("test_test_task_group_sub_dag", start_date=execution_date) as dag:
        task1 = DummyOperator(task_id="task1")
        with TaskGroup("group234") as group234:
            _ = DummyOperator(task_id="task2")

            with TaskGroup("group34") as group34:
                _ = DummyOperator(task_id="task3")
                _ = DummyOperator(task_id="task4")

        with TaskGroup("group6") as group6:
            _ = DummyOperator(task_id="task6")

        task7 = DummyOperator(task_id="task7")
        task5 = DummyOperator(task_id="task5")

        task1 >> group234
        group34 >> task5
        group234 >> group6
        group234 >> task7

    subdag = dag.sub_dag(task_regex="task5", include_upstream=True, include_downstream=False)

    # Only task5's upstream chain survives: task1 -> group34(task3, task4)
    # -> task5. task2, group6/task6 and task7 are filtered out.
    assert extract_node_id(task_group_to_dict(subdag.task_group)) == {
        'id': None,
        'children': [
            {
                'id': 'group234',
                'children': [
                    {
                        'id': 'group234.group34',
                        'children': [
                            {'id': 'group234.group34.task3'},
                            {'id': 'group234.group34.task4'},
                            {'id': 'group234.group34.downstream_join_id'},
                        ],
                    },
                    {'id': 'group234.upstream_join_id'},
                ],
            },
            {'id': 'task1'},
            {'id': 'task5'},
        ],
    }

    edges = dag_edges(subdag)
    assert sorted((e["source_id"], e["target_id"]) for e in edges) == [
        ('group234.group34.downstream_join_id', 'task5'),
        ('group234.group34.task3', 'group234.group34.downstream_join_id'),
        ('group234.group34.task4', 'group234.group34.downstream_join_id'),
        ('group234.upstream_join_id', 'group234.group34.task3'),
        ('group234.upstream_join_id', 'group234.group34.task4'),
        ('task1', 'group234.upstream_join_id'),
    ]

    # get_task_group_dict() maps group_id -> TaskGroup; None is the root.
    subdag_task_groups = subdag.task_group.get_task_group_dict()
    assert subdag_task_groups.keys() == {None, "group234", "group234.group34"}

    included_group_ids = {"group234", "group234.group34"}
    included_task_ids = {'group234.group34.task3', 'group234.group34.task4', 'task1', 'task5'}

    # Every group/task-level relationship in the subdag must reference only
    # groups and tasks that survived the filtering.
    for task_group in subdag_task_groups.values():
        assert task_group.upstream_group_ids.issubset(included_group_ids)
        assert task_group.downstream_group_ids.issubset(included_group_ids)
        assert task_group.upstream_task_ids.issubset(included_task_ids)
        assert task_group.downstream_task_ids.issubset(included_task_ids)

    for task in subdag.task_group:
        assert task.upstream_task_ids.issubset(included_task_ids)
        assert task.downstream_task_ids.issubset(included_task_ids)
+
+
def test_dag_edges():
    """
    Test dag_edges() on a DAG mixing nested TaskGroups, group-to-group,
    group-to-task and task-to-task dependencies set with both ``>>`` and
    ``<<``. Edges crossing a group boundary must be routed through that
    group's upstream/downstream join proxy nodes.
    """
    execution_date = pendulum.parse("20200101")
    with DAG("test_dag_edges", start_date=execution_date) as dag:
        task1 = DummyOperator(task_id="task1")
        with TaskGroup("group_a") as group_a:
            with TaskGroup("group_b") as group_b:
                task2 = DummyOperator(task_id="task2")
                task3 = DummyOperator(task_id="task3")
                task4 = DummyOperator(task_id="task4")
                task2 >> [task3, task4]

            task5 = DummyOperator(task_id="task5")

            # ``<<`` direction: group_b feeds into task5.
            task5 << group_b

        task1 >> group_a

        with TaskGroup("group_c") as group_c:
            task6 = DummyOperator(task_id="task6")
            task7 = DummyOperator(task_id="task7")
            task8 = DummyOperator(task_id="task8")
            [task6, task7] >> task8
            # Group-to-group dependency declared inside the group_c block.
            group_a >> group_c

        # Direct cross-group task-to-task edge (bypasses the join nodes).
        task5 >> task8

        task9 = DummyOperator(task_id="task9")
        task10 = DummyOperator(task_id="task10")

        group_c >> [task9, task10]

        with TaskGroup("group_d") as group_d:
            task11 = DummyOperator(task_id="task11")
            task12 = DummyOperator(task_id="task12")
            task11 >> task12

        group_d << group_c

    nodes = task_group_to_dict(dag.task_group)
    edges = dag_edges(dag)

    # Join nodes only appear where a group-level edge exists: e.g. group_d
    # has only an upstream_join_id (nothing downstream of the whole group).
    assert extract_node_id(nodes) == {
        'id': None,
        'children': [
            {
                'id': 'group_a',
                'children': [
                    {
                        'id': 'group_a.group_b',
                        'children': [
                            {'id': 'group_a.group_b.task2'},
                            {'id': 'group_a.group_b.task3'},
                            {'id': 'group_a.group_b.task4'},
                            {'id': 'group_a.group_b.downstream_join_id'},
                        ],
                    },
                    {'id': 'group_a.task5'},
                    {'id': 'group_a.upstream_join_id'},
                    {'id': 'group_a.downstream_join_id'},
                ],
            },
            {
                'id': 'group_c',
                'children': [
                    {'id': 'group_c.task6'},
                    {'id': 'group_c.task7'},
                    {'id': 'group_c.task8'},
                    {'id': 'group_c.upstream_join_id'},
                    {'id': 'group_c.downstream_join_id'},
                ],
            },
            {
                'id': 'group_d',
                'children': [
                    {'id': 'group_d.task11'},
                    {'id': 'group_d.task12'},
                    {'id': 'group_d.upstream_join_id'},
                ],
            },
            {'id': 'task1'},
            {'id': 'task10'},
            {'id': 'task9'},
        ],
    }

    assert sorted((e["source_id"], e["target_id"]) for e in edges) == [
        ('group_a.downstream_join_id', 'group_c.upstream_join_id'),
        ('group_a.group_b.downstream_join_id', 'group_a.task5'),
        ('group_a.group_b.task2', 'group_a.group_b.task3'),
        ('group_a.group_b.task2', 'group_a.group_b.task4'),
        ('group_a.group_b.task3', 'group_a.group_b.downstream_join_id'),
        ('group_a.group_b.task4', 'group_a.group_b.downstream_join_id'),
        ('group_a.task5', 'group_a.downstream_join_id'),
        ('group_a.task5', 'group_c.task8'),
        ('group_a.upstream_join_id', 'group_a.group_b.task2'),
        ('group_c.downstream_join_id', 'group_d.upstream_join_id'),
        ('group_c.downstream_join_id', 'task10'),
        ('group_c.downstream_join_id', 'task9'),
        ('group_c.task6', 'group_c.task8'),
        ('group_c.task7', 'group_c.task8'),
        ('group_c.task8', 'group_c.downstream_join_id'),
        ('group_c.upstream_join_id', 'group_c.task6'),
        ('group_c.upstream_join_id', 'group_c.task7'),
        ('group_d.task11', 'group_d.task12'),
        ('group_d.upstream_join_id', 'group_d.task11'),
        ('task1', 'group_a.upstream_join_id'),
    ]
+
+
def test_duplicate_group_id():
    """
    Any clash between group ids and task ids within one DAG must raise
    DuplicateTaskIdFound — including clashes with the implicit
    ``upstream_join_id`` / ``downstream_join_id`` proxy task ids a group
    reserves for itself.
    """
    from airflow.exceptions import DuplicateTaskIdFound

    start = pendulum.parse("20200101")

    # A group id may not reuse an existing task id.
    with pytest.raises(DuplicateTaskIdFound, match=r".* 'task1' .*"):
        with DAG("test_duplicate_group_id", start_date=start):
            DummyOperator(task_id="task1")
            with TaskGroup("task1"):
                pass

    # With prefixing disabled, a nested group may collide with its parent.
    with pytest.raises(DuplicateTaskIdFound, match=r".* 'group1' .*"):
        with DAG("test_duplicate_group_id", start_date=start):
            DummyOperator(task_id="task1")
            with TaskGroup("group1", prefix_group_id=False):
                with TaskGroup("group1"):
                    pass

    # ...or a task inside the group may collide with the group itself.
    with pytest.raises(DuplicateTaskIdFound, match=r".* 'group1' .*"):
        with DAG("test_duplicate_group_id", start_date=start):
            with TaskGroup("group1", prefix_group_id=False):
                DummyOperator(task_id="group1")

    # The reserved join proxy ids are off-limits as task ids too.
    with pytest.raises(DuplicateTaskIdFound, match=r".* 'group1.downstream_join_id' .*"):
        with DAG("test_duplicate_group_id", start_date=start):
            DummyOperator(task_id="task1")
            with TaskGroup("group1"):
                DummyOperator(task_id="downstream_join_id")

    with pytest.raises(DuplicateTaskIdFound, match=r".* 'group1.upstream_join_id' .*"):
        with DAG("test_duplicate_group_id", start_date=start):
            DummyOperator(task_id="task1")
            with TaskGroup("group1"):
                DummyOperator(task_id="upstream_join_id")