This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 49c193f  [AIP-34] TaskGroup: A UI task grouping concept as an 
alternative to SubDagOperator (#10153)
49c193f is described below

commit 49c193fb872856500d8919facf45b9ab5207a093
Author: yuqian90 <[email protected]>
AuthorDate: Sat Sep 19 08:51:37 2020 +0800

    [AIP-34] TaskGroup: A UI task grouping concept as an alternative to 
SubDagOperator (#10153)
    
    This commit introduces TaskGroup, which is a simple UI task grouping 
concept.
    
    - TaskGroups can be collapsed/expanded in Graph View when clicked
    - TaskGroups can be nested
    - TaskGroups can be put upstream/downstream of tasks or other TaskGroups 
with >> and << operators
    - Search box, hovering, and focusing in Graph View treat TaskGroups properly. 
E.g. searching for tasks also highlights the TaskGroup that contains a matching 
task_id. When a TaskGroup is expanded/collapsed, the affected TaskGroup is put in 
focus and moved to the centre of the graph.
    
    
    What this commit does not do:
    
    - This commit does not change or remove SubDagOperator. Although TaskGroup 
is intended as an alternative to SubDagOperator, deprecating SubDagOperator 
will need to be discussed/implemented in the future.
    - This PR only implemented TaskGroup handling in the Graph View. In places 
such as Tree View, it will look as if TaskGroup does not exist and all tasks 
are in the same flat DAG.
    
    GitHub Issue: #8078
    AIP: 
https://cwiki.apache.org/confluence/display/AIRFLOW/AIP-34+TaskGroup%3A+A+UI+task+grouping+concept+as+an+alternative+to+SubDagOperator
---
 airflow/example_dags/example_task_group.py    |  57 +++
 airflow/models/baseoperator.py                |  34 +-
 airflow/models/dag.py                         |  51 ++-
 airflow/models/taskmixin.py                   |  11 +
 airflow/models/xcom_arg.py                    |   5 +
 airflow/serialization/enums.py                |   1 +
 airflow/serialization/schema.json             |  48 ++-
 airflow/serialization/serialized_objects.py   |  92 +++++
 airflow/utils/task_group.py                   | 379 +++++++++++++++++
 airflow/www/static/css/graph.css              |  14 +
 airflow/www/templates/airflow/graph.html      | 458 ++++++++++++++++-----
 airflow/www/views.py                          | 186 +++++++--
 docs/concepts.rst                             |  42 ++
 docs/img/task_group.gif                       | Bin 0 -> 609981 bytes
 tests/serialization/test_dag_serialization.py |  63 +++
 tests/utils/test_task_group.py                | 561 ++++++++++++++++++++++++++
 16 files changed, 1857 insertions(+), 145 deletions(-)

diff --git a/airflow/example_dags/example_task_group.py 
b/airflow/example_dags/example_task_group.py
new file mode 100644
index 0000000..17134df5
--- /dev/null
+++ b/airflow/example_dags/example_task_group.py
@@ -0,0 +1,57 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Example DAG demonstrating the usage of the TaskGroup."""
+
+from airflow.models.dag import DAG
+from airflow.operators.dummy_operator import DummyOperator
+from airflow.utils.dates import days_ago
+from airflow.utils.task_group import TaskGroup
+
+# [START howto_task_group]
+with DAG(dag_id="example_task_group", start_date=days_ago(2)) as dag:
+    start = DummyOperator(task_id="start")
+
+    # [START howto_task_group_section_1]
+    with TaskGroup("section_1", tooltip="Tasks for section_1") as section_1:
+        task_1 = DummyOperator(task_id="task_1")
+        task_2 = DummyOperator(task_id="task_2")
+        task_3 = DummyOperator(task_id="task_3")
+
+        task_1 >> [task_2, task_3]
+    # [END howto_task_group_section_1]
+
+    # [START howto_task_group_section_2]
+    with TaskGroup("section_2", tooltip="Tasks for section_2") as section_2:
+        task_1 = DummyOperator(task_id="task_1")
+
+        # [START howto_task_group_inner_section_2]
+        with TaskGroup("inner_section_2", tooltip="Tasks for inner_section2") 
as inner_section_2:
+            task_2 = DummyOperator(task_id="task_2")
+            task_3 = DummyOperator(task_id="task_3")
+            task_4 = DummyOperator(task_id="task_4")
+
+            [task_2, task_3] >> task_4
+        # [END howto_task_group_inner_section_2]
+
+    # [END howto_task_group_section_2]
+
+    end = DummyOperator(task_id='end')
+
+    start >> section_1 >> section_2 >> end
+# [END howto_task_group]
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 4058f05..6d48a27 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -27,7 +27,8 @@ import warnings
 from abc import ABCMeta, abstractmethod
 from datetime import datetime, timedelta
 from typing import (
-    Any, Callable, ClassVar, Dict, FrozenSet, Iterable, List, Optional, 
Sequence, Set, Tuple, Type, Union,
+    TYPE_CHECKING, Any, Callable, ClassVar, Dict, FrozenSet, Iterable, List, 
Optional, Sequence, Set, Tuple,
+    Type, Union,
 )
 
 import attr
@@ -58,6 +59,9 @@ from airflow.utils.session import provide_session
 from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.weight_rule import WeightRule
 
+if TYPE_CHECKING:
+    from airflow.utils.task_group import TaskGroup  # pylint: 
disable=cyclic-import
+
 ScheduleInterval = Union[str, timedelta, relativedelta]
 
 TaskStateChangeCallback = Callable[[Context], None]
@@ -360,9 +364,12 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, 
metaclass=BaseOperatorMeta
         do_xcom_push: bool = True,
         inlets: Optional[Any] = None,
         outlets: Optional[Any] = None,
+        task_group: Optional["TaskGroup"] = None,
         **kwargs
     ):
         from airflow.models.dag import DagContext
+        from airflow.utils.task_group import TaskGroupContext
+
         super().__init__()
         if kwargs:
             if not conf.getboolean('operators', 'ALLOW_ILLEGAL_ARGUMENTS'):
@@ -382,6 +389,11 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, 
metaclass=BaseOperatorMeta
             )
         validate_key(task_id)
         self.task_id = task_id
+        self.label = task_id
+        task_group = task_group or TaskGroupContext.get_current_task_group(dag)
+        if task_group:
+            self.task_id = task_group.child_id(task_id)
+            task_group.add(self)
         self.owner = owner
         self.email = email
         self.email_on_retry = email_on_retry
@@ -609,7 +621,7 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, 
metaclass=BaseOperatorMeta
         elif self.task_id in dag.task_dict and dag.task_dict[self.task_id] is 
not self:
             dag.add_task(self)
 
-        self._dag = dag  # pylint: disable=attribute-defined-outside-init
+        self._dag = dag
 
     def has_dag(self):
         """
@@ -1120,21 +1132,25 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, 
metaclass=BaseOperatorMeta
         """Required by TaskMixin"""
         return [self]
 
+    @property
+    def leaves(self) -> List["BaseOperator"]:
+        """Required by TaskMixin"""
+        return [self]
+
     def _set_relatives(
         self,
         task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]],
         upstream: bool = False,
     ) -> None:
         """Sets relatives for the task or task list."""
-
-        if isinstance(task_or_task_list, Sequence):
-            task_like_object_list = task_or_task_list
-        else:
-            task_like_object_list = [task_or_task_list]
+        if not isinstance(task_or_task_list, Sequence):
+            task_or_task_list = [task_or_task_list]
 
         task_list: List["BaseOperator"] = []
-        for task_object in task_like_object_list:
-            task_list.extend(task_object.roots)
+        for task_object in task_or_task_list:
+            task_object.update_relative(self, not upstream)
+            relatives = task_object.leaves if upstream else task_object.roots
+            task_list.extend(relatives)
 
         for task in task_list:
             if not isinstance(task, BaseOperator):
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index eecf6b4..886837c 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -27,7 +27,9 @@ import traceback
 import warnings
 from collections import OrderedDict
 from datetime import datetime, timedelta
-from typing import Callable, Collection, Dict, FrozenSet, Iterable, List, 
Optional, Set, Type, Union, cast
+from typing import (
+    TYPE_CHECKING, Callable, Collection, Dict, FrozenSet, Iterable, List, 
Optional, Set, Type, Union, cast,
+)
 
 import jinja2
 import pendulum
@@ -59,6 +61,9 @@ from airflow.utils.sqlalchemy import Interval, UtcDateTime
 from airflow.utils.state import State
 from airflow.utils.types import DagRunType
 
+if TYPE_CHECKING:
+    from airflow.utils.task_group import TaskGroup
+
 log = logging.getLogger(__name__)
 
 ScheduleInterval = Union[str, timedelta, relativedelta]
@@ -238,6 +243,8 @@ class DAG(BaseDag, LoggingMixin):
         jinja_environment_kwargs: Optional[Dict] = None,
         tags: Optional[List[str]] = None
     ):
+        from airflow.utils.task_group import TaskGroup
+
         self.user_defined_macros = user_defined_macros
         self.user_defined_filters = user_defined_filters
         self.default_args = copy.deepcopy(default_args or {})
@@ -329,6 +336,7 @@ class DAG(BaseDag, LoggingMixin):
 
         self.jinja_environment_kwargs = jinja_environment_kwargs
         self.tags = tags
+        self._task_group = TaskGroup.create_root(self)
 
     def __repr__(self):
         return "<DAG: {self.dag_id}>".format(self=self)
@@ -571,6 +579,10 @@ class DAG(BaseDag, LoggingMixin):
         return list(self.task_dict.keys())
 
     @property
+    def task_group(self) -> "TaskGroup":
+        return self._task_group
+
+    @property
     def filepath(self) -> str:
         """
         File location of where the dag object is instantiated
@@ -1240,7 +1252,6 @@ class DAG(BaseDag, LoggingMixin):
         based on a regex that should match one or many tasks, and includes
         upstream and downstream neighbours based on the flag passed.
         """
-
         # deep-copying self.task_dict takes a long time, and we don't want all
         # the tasks anyway, so we copy the tasks manually later
         task_dict = self.task_dict
@@ -1261,9 +1272,38 @@ class DAG(BaseDag, LoggingMixin):
         # Make sure to not recursively deepcopy the dag while copying the task
         dag.task_dict = {t.task_id: copy.deepcopy(t, {id(t.dag): dag})
                          for t in regex_match + also_include}
+
+        # Remove tasks not included in the subdag from task_group
+        def remove_excluded(group):
+            for child in list(group.children.values()):
+                if isinstance(child, BaseOperator):
+                    if child.task_id not in dag.task_dict:
+                        group.children.pop(child.task_id)
+                    else:
+                        # The tasks in the subdag are a copy of tasks in the 
original dag
+                        # so update the reference in the TaskGroups too.
+                        group.children[child.task_id] = 
dag.task_dict[child.task_id]
+                else:
+                    remove_excluded(child)
+
+                    # Remove this TaskGroup if it doesn't contain any tasks in 
this subdag
+                    if not child.children:
+                        group.children.pop(child.group_id)
+
+        remove_excluded(dag.task_group)
+
+        # Removing upstream/downstream references to tasks and TaskGroups that 
did not make
+        # the cut.
+        subdag_task_groups = dag.task_group.get_task_group_dict()
+        for group in subdag_task_groups.values():
+            group.upstream_group_ids = 
group.upstream_group_ids.intersection(subdag_task_groups.keys())
+            group.downstream_group_ids = 
group.downstream_group_ids.intersection(subdag_task_groups.keys())
+            group.upstream_task_ids = 
group.upstream_task_ids.intersection(dag.task_dict.keys())
+            group.downstream_task_ids = 
group.downstream_task_ids.intersection(dag.task_dict.keys())
+
         for t in dag.tasks:
             # Removing upstream/downstream references to tasks that did not
-            # made the cut
+            # make the cut
             t._upstream_task_ids = 
t.upstream_task_ids.intersection(dag.task_dict.keys())
             t._downstream_task_ids = t.downstream_task_ids.intersection(
                 dag.task_dict.keys())
@@ -1357,12 +1397,15 @@ class DAG(BaseDag, LoggingMixin):
         elif task.end_date and self.end_date:
             task.end_date = min(task.end_date, self.end_date)
 
-        if task.task_id in self.task_dict and self.task_dict[task.task_id] is 
not task:
+        if ((task.task_id in self.task_dict and self.task_dict[task.task_id] 
is not task)
+                or task.task_id in self._task_group.used_group_ids):
             raise DuplicateTaskIdFound(
                 "Task id '{}' has already been added to the 
DAG".format(task.task_id))
         else:
             self.task_dict[task.task_id] = task
             task.dag = self
+            # Add task_id to used_group_ids to prevent group_id and task_id 
collisions.
+            self._task_group.used_group_ids.add(task.task_id)
 
         self.task_count = len(self.task_dict)
 
diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py
index a3d4224..cfdc714 100644
--- a/airflow/models/taskmixin.py
+++ b/airflow/models/taskmixin.py
@@ -33,6 +33,11 @@ class TaskMixin:
         """Should return list of root operator List[BaseOperator]"""
         raise NotImplementedError()
 
+    @property
+    def leaves(self):
+        """Should return list of leaf operator List[BaseOperator]"""
+        raise NotImplementedError()
+
     @abstractmethod
     def set_upstream(self, other: Union["TaskMixin", Sequence["TaskMixin"]]):
         """
@@ -47,6 +52,12 @@ class TaskMixin:
         """
         raise NotImplementedError()
 
+    def update_relative(self, other: "TaskMixin", upstream=True) -> None:
+        """
+        Update relationship information about another TaskMixin. Default is 
no-op.
+        Override if necessary.
+        """
+
     def __lshift__(self, other: Union["TaskMixin", Sequence["TaskMixin"]]):
         """
         Implements Task << Task
diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index 0f647bf..b9faaaba 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -103,6 +103,11 @@ class XComArg(TaskMixin):
         return [self._operator]
 
     @property
+    def leaves(self) -> List[BaseOperator]:
+        """Required by TaskMixin"""
+        return [self._operator]
+
+    @property
     def key(self) -> str:
         """Returns keys of this XComArg"""
         return self._key
diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py
index b6fae5e..9a30231 100644
--- a/airflow/serialization/enums.py
+++ b/airflow/serialization/enums.py
@@ -43,3 +43,4 @@ class DagAttributeTypes(str, Enum):
     SET = 'set'
     TUPLE = 'tuple'
     POD = 'k8s.V1Pod'
+    TASK_GROUP = 'taskgroup'
diff --git a/airflow/serialization/schema.json 
b/airflow/serialization/schema.json
index 49de949..9056eaa 100644
--- a/airflow/serialization/schema.json
+++ b/airflow/serialization/schema.json
@@ -96,7 +96,11 @@
         "_default_view": { "type" : "string"},
         "_access_control": {"$ref": "#/definitions/dict" },
         "is_paused_upon_creation":  { "type": "boolean" },
-        "tags": { "type": "array" }
+        "tags": { "type": "array" },
+        "_task_group": {"anyOf": [
+          { "type": "null" },
+          { "$ref": "#/definitions/task_group" }
+        ]}
       },
       "required": [
         "_dag_id",
@@ -125,6 +129,7 @@
         "_task_module": { "type": "string" },
         "_operator_extra_links": { "$ref":  "#/definitions/extra_links" },
         "task_id": { "type": "string" },
+        "label": { "type": "string" },
         "owner": { "type": "string" },
         "start_date": { "$ref": "#/definitions/datetime" },
         "end_date": { "$ref": "#/definitions/datetime" },
@@ -156,6 +161,47 @@
         }
       },
       "additionalProperties": true
+    },
+    "task_group": {
+      "$comment": "A TaskGroup containing tasks",
+      "type": "object",
+      "required": [
+        "_group_id",
+        "prefix_group_id",
+        "children",
+        "tooltip",
+        "ui_color",
+        "ui_fgcolor",
+        "upstream_group_ids",
+        "downstream_group_ids",
+        "upstream_task_ids",
+        "downstream_task_ids"
+      ],
+      "properties": {
+        "_group_id": {"anyOf": [{"type": "null"}, { "type": "string" }]},
+        "prefix_group_id": { "type": "boolean" },
+        "children":  { "$ref": "#/definitions/dict" },
+        "tooltip": { "type": "string" },
+        "ui_color": { "type": "string" },
+        "ui_fgcolor": { "type": "string" },
+        "upstream_group_ids": {
+          "type": "array",
+          "items": { "type": "string" }
+        },
+        "downstream_group_ids": {
+          "type": "array",
+          "items": { "type": "string" }
+        },
+        "upstream_task_ids": {
+          "type": "array",
+          "items": { "type": "string" }
+        },
+        "downstream_task_ids": {
+          "type": "array",
+          "items": { "type": "string" }
+        }
+      },
+      "additionalProperties": false
     }
   },
 
diff --git a/airflow/serialization/serialized_objects.py 
b/airflow/serialization/serialized_objects.py
index 6129a9d..41c6bc7 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -44,6 +44,7 @@ from airflow.serialization.json_schema import Validator, 
load_dag_schema
 from airflow.settings import json
 from airflow.utils.code_utils import get_python_source
 from airflow.utils.module_loading import import_string
+from airflow.utils.task_group import TaskGroup
 
 log = logging.getLogger(__name__)
 FAILED = 'serialization_failed'
@@ -221,6 +222,8 @@ class BaseSerialization:
                 # FIXME: casts tuple to list in customized serialization in 
future.
                 return cls._encode(
                     [cls._serialize(v) for v in var], type_=DAT.TUPLE)
+            elif isinstance(var, TaskGroup):
+                return SerializedTaskGroup.serialize_task_group(var)
             else:
                 log.debug('Cast type %s to str in serialization.', type(var))
                 return str(var)
@@ -376,6 +379,10 @@ class SerializedBaseOperator(BaseOperator, 
BaseSerialization):
         # Extra Operator Links defined in Plugins
         op_extra_links_from_plugin = {}
 
+        if "label" not in encoded_op:
+            # Handle deserialization of old data before the introduction of 
TaskGroup
+            encoded_op["label"] = encoded_op["task_id"]
+
         for ope in plugins_manager.operator_extra_links:
             for operator in ope.operators:
                 if operator.__name__ == encoded_op["_task_type"] and \
@@ -570,6 +577,7 @@ class SerializedDAG(DAG, BaseSerialization):
         serialize_dag = cls.serialize_to_json(dag, cls._decorated_fields)
 
         serialize_dag["tasks"] = [cls._serialize(task) for _, task in 
dag.task_dict.items()]
+        serialize_dag['_task_group'] = 
SerializedTaskGroup.serialize_task_group(dag.task_group)
         return serialize_dag
 
     @classmethod
@@ -598,6 +606,22 @@ class SerializedDAG(DAG, BaseSerialization):
 
             setattr(dag, k, v)
 
+        # Set _task_group
+        # pylint: disable=protected-access
+        if "_task_group" in encoded_dag:
+            dag._task_group = SerializedTaskGroup.deserialize_task_group(  # 
type: ignore
+                encoded_dag["_task_group"],
+                None,
+                dag.task_dict
+            )
+        else:
+            # This must be old data that had no task_group. Create a root 
TaskGroup and add
+            # all tasks to it.
+            dag._task_group = TaskGroup.create_root(dag)
+            for task in dag.tasks:
+                dag.task_group.add(task)
+        # pylint: enable=protected-access
+
         keys_to_set_none = dag.get_serialized_fields() - encoded_dag.keys() - 
cls._CONSTRUCTOR_PARAMS.keys()
         for k in keys_to_set_none:
             setattr(dag, k, None)
@@ -641,3 +665,71 @@ class SerializedDAG(DAG, BaseSerialization):
         if ver != cls.SERIALIZER_VERSION:
             raise ValueError("Unsure how to deserialize version 
{!r}".format(ver))
         return cls.deserialize_dag(serialized_obj['dag'])
+
+
+class SerializedTaskGroup(TaskGroup, BaseSerialization):
+    """
+    A JSON serializable representation of TaskGroup.
+    """
+    @classmethod
+    def serialize_task_group(cls, task_group: TaskGroup) -> 
Optional[Union[Dict[str, Any]]]:
+        """
+        Serializes TaskGroup into a JSON object.
+        """
+        if not task_group:
+            return None
+
+        serialize_group = {
+            "_group_id": task_group._group_id,  # pylint: 
disable=protected-access
+            "prefix_group_id": task_group.prefix_group_id,
+            "tooltip": task_group.tooltip,
+            "ui_color": task_group.ui_color,
+            "ui_fgcolor": task_group.ui_fgcolor,
+            "children": {
+                label: (DAT.OP, child.task_id)
+                if isinstance(child, BaseOperator) else
+                (DAT.TASK_GROUP, 
SerializedTaskGroup.serialize_task_group(child))
+                for label, child in task_group.children.items()
+            },
+            "upstream_group_ids": 
cls._serialize(list(task_group.upstream_group_ids)),
+            "downstream_group_ids": 
cls._serialize(list(task_group.downstream_group_ids)),
+            "upstream_task_ids": 
cls._serialize(list(task_group.upstream_task_ids)),
+            "downstream_task_ids": 
cls._serialize(list(task_group.downstream_task_ids)),
+
+        }
+
+        return serialize_group
+
+    @classmethod
+    def deserialize_task_group(
+        cls,
+        encoded_group: Dict[str, Any],
+        parent_group: Optional[TaskGroup],
+        task_dict: Dict[str, BaseOperator]
+    ) -> Optional[TaskGroup]:
+        """
+        Deserializes a TaskGroup from a JSON object.
+        """
+        if not encoded_group:
+            return None
+
+        group_id = cls._deserialize(encoded_group["_group_id"])
+        kwargs = {
+            key: cls._deserialize(encoded_group[key])
+            for key in ["prefix_group_id", "tooltip", "ui_color", "ui_fgcolor"]
+        }
+        group = SerializedTaskGroup(
+            group_id=group_id,
+            parent_group=parent_group,
+            **kwargs
+        )
+        group.children = {
+            label: task_dict[val] if _type == DAT.OP  # type: ignore
+            else SerializedTaskGroup.deserialize_task_group(val, group, 
task_dict) for label, (_type, val)
+            in encoded_group["children"].items()
+        }
+        group.upstream_group_ids = 
set(cls._deserialize(encoded_group["upstream_group_ids"]))
+        group.downstream_group_ids = 
set(cls._deserialize(encoded_group["downstream_group_ids"]))
+        group.upstream_task_ids = 
set(cls._deserialize(encoded_group["upstream_task_ids"]))
+        group.downstream_task_ids = 
set(cls._deserialize(encoded_group["downstream_task_ids"]))
+        return group
diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py
new file mode 100644
index 0000000..84cc540
--- /dev/null
+++ b/airflow/utils/task_group.py
@@ -0,0 +1,379 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+A TaskGroup is a collection of closely related tasks on the same DAG that 
should be grouped
+together when the DAG is displayed graphically.
+"""
+
+from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, 
Set, Union
+
+from airflow.exceptions import AirflowException, DuplicateTaskIdFound
+from airflow.models.taskmixin import TaskMixin
+
+if TYPE_CHECKING:
+    from airflow.models.baseoperator import BaseOperator
+    from airflow.models.dag import DAG
+
+
+class TaskGroup(TaskMixin):
+    """
+    A collection of tasks. When set_downstream() or set_upstream() are called 
on the
+    TaskGroup, it is applied across all tasks within the group if necessary.
+
+    :param group_id: a unique, meaningful id for the TaskGroup. group_id must 
not conflict
+        with group_id of TaskGroup or task_id of tasks in the DAG. Root 
TaskGroup has group_id
+        set to None.
+    :type group_id: str
+    :param prefix_group_id: If set to True, child task_id and group_id will be 
prefixed with
+        this TaskGroup's group_id. If set to False, child task_id and group_id 
are not prefixed.
+        Default is True.
+    :type prefix_group_id: bool
+    :param parent_group: The parent TaskGroup of this TaskGroup. parent_group 
is set to None
+        for the root TaskGroup.
+    :type parent_group: TaskGroup
+    :param dag: The DAG that this TaskGroup belongs to.
+    :type dag: airflow.models.DAG
+    :param tooltip: The tooltip of the TaskGroup node when displayed in the UI
+    :type tooltip: str
+    :param ui_color: The fill color of the TaskGroup node when displayed in 
the UI
+    :type ui_color: str
+    :param ui_fgcolor: The label color of the TaskGroup node when displayed in 
the UI
+    :type ui_fgcolor: str
+    """
+
+    def __init__(
+        self,
+        group_id: Optional[str],
+        prefix_group_id: bool = True,
+        parent_group: Optional["TaskGroup"] = None,
+        dag: Optional["DAG"] = None,
+        tooltip: str = "",
+        ui_color: str = "CornflowerBlue",
+        ui_fgcolor: str = "#000",
+    ):
+        from airflow.models.dag import DagContext
+
+        self.prefix_group_id = prefix_group_id
+
+        if group_id is None:
+            # This creates a root TaskGroup.
+            if parent_group:
+                raise AirflowException("Root TaskGroup cannot have 
parent_group")
+            # used_group_ids is shared across all TaskGroups in the same DAG 
to keep track
+            # of used group_id to avoid duplication.
+            self.used_group_ids: Set[Optional[str]] = set()
+            self._parent_group = None
+        else:
+            if not isinstance(group_id, str):
+                raise ValueError("group_id must be str")
+            if not group_id:
+                raise ValueError("group_id must not be empty")
+
+            dag = dag or DagContext.get_current_dag()
+
+            if not parent_group and not dag:
+                raise AirflowException("TaskGroup can only be used inside a 
dag")
+
+            self._parent_group = parent_group or 
TaskGroupContext.get_current_task_group(dag)
+            if not self._parent_group:
+                raise AirflowException("TaskGroup must have a parent_group 
except for the root TaskGroup")
+            self.used_group_ids = self._parent_group.used_group_ids
+
+        self._group_id = group_id
+        if self.group_id in self.used_group_ids:
+            raise DuplicateTaskIdFound(f"group_id '{self.group_id}' has 
already been added to the DAG")
+        self.used_group_ids.add(self.group_id)
+        self.used_group_ids.add(self.downstream_join_id)
+        self.used_group_ids.add(self.upstream_join_id)
+        self.children: Dict[str, Union["BaseOperator", "TaskGroup"]] = {}
+        if self._parent_group:
+            self._parent_group.add(self)
+
+        self.tooltip = tooltip
+        self.ui_color = ui_color
+        self.ui_fgcolor = ui_fgcolor
+
+        # Keep track of TaskGroups or tasks that depend on this entire 
TaskGroup separately
+        # so that we can optimize the number of edges when entire TaskGroups 
depend on each other.
+        self.upstream_group_ids: Set[Optional[str]] = set()
+        self.downstream_group_ids: Set[Optional[str]] = set()
+        self.upstream_task_ids: Set[Optional[str]] = set()
+        self.downstream_task_ids: Set[Optional[str]] = set()
+
+    @classmethod
+    def create_root(cls, dag: "DAG") -> "TaskGroup":
+        """
+        Create a root TaskGroup with no group_id or parent.
+        """
+        return cls(group_id=None, dag=dag)
+
+    @property
+    def is_root(self) -> bool:
+        """
+        Returns True if this TaskGroup is the root TaskGroup. Otherwise False
+        """
+        return not self.group_id
+
+    def __iter__(self):
+        for child in self.children.values():
+            if isinstance(child, TaskGroup):
+                for inner_task in child:
+                    yield inner_task
+            else:
+                yield child
+
+    def add(self, task: Union["BaseOperator", "TaskGroup"]) -> None:
+        """
+        Add a task to this TaskGroup.
+        """
+        key = task.group_id if isinstance(task, TaskGroup) else task.task_id
+
+        if key in self.children:
+            raise DuplicateTaskIdFound(f"Task id '{key}' has already been 
added to the DAG")
+
+        if isinstance(task, TaskGroup):
+            if task.children:
+                raise AirflowException("Cannot add a non-empty TaskGroup")
+
+        self.children[key] = task  # type: ignore
+
+    @property
+    def group_id(self) -> Optional[str]:
+        """
+        group_id of this TaskGroup.
+        """
+        if self._parent_group and self._parent_group.prefix_group_id and 
self._parent_group.group_id:
+            return self._parent_group.child_id(self._group_id)
+
+        return self._group_id
+
+    @property
+    def label(self) -> Optional[str]:
+        """
+        group_id excluding parent's group_id used as the node label in UI.
+        """
+        return self._group_id
+
+    def update_relative(self, other: "TaskMixin", upstream=True) -> None:
+        """
+        Overrides TaskMixin.update_relative.
+
+        Update 
upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids
+        accordingly so that we can reduce the number of edges when displaying 
Graph View.
+        """
+        from airflow.models.baseoperator import BaseOperator
+
+        if isinstance(other, TaskGroup):
+            # Handles setting relationship between a TaskGroup and another 
TaskGroup
+            if upstream:
+                parent, child = (self, other)
+            else:
+                parent, child = (other, self)
+
+            parent.upstream_group_ids.add(child.group_id)
+            child.downstream_group_ids.add(parent.group_id)
+        else:
+            # Handles setting relationship between a TaskGroup and a task
+            for task in other.roots:
+                if not isinstance(task, BaseOperator):
+                    raise AirflowException("Relationships can only be set 
between TaskGroup "
+                                           f"or operators; received 
{task.__class__.__name__}")
+
+                if upstream:
+                    self.upstream_task_ids.add(task.task_id)
+                else:
+                    self.downstream_task_ids.add(task.task_id)
+
+    def _set_relative(
+            self,
+            task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]],
+            upstream: bool = False
+    ) -> None:
+        """
+        Call set_upstream/set_downstream for all root/leaf tasks within this 
TaskGroup.
+        Update 
upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids.
+        """
+        if upstream:
+            for task in self.get_roots():
+                task.set_upstream(task_or_task_list)
+        else:
+            for task in self.get_leaves():
+                task.set_downstream(task_or_task_list)
+
+        if not isinstance(task_or_task_list, Sequence):
+            task_or_task_list = [task_or_task_list]
+
+        for task_like in task_or_task_list:
+            self.update_relative(task_like, upstream)
+
+    def set_downstream(
+        self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]]
+    ) -> None:
+        """
+        Set a TaskGroup/task/list of tasks downstream of this TaskGroup.
+        """
+        self._set_relative(task_or_task_list, upstream=False)
+
+    def set_upstream(
+        self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]]
+    ) -> None:
+        """
+        Set a TaskGroup/task/list of tasks upstream of this TaskGroup.
+        """
+        self._set_relative(task_or_task_list, upstream=True)
+
+    def __enter__(self):
+        TaskGroupContext.push_context_managed_task_group(self)
+        return self
+
+    def __exit__(self, _type, _value, _tb):
+        TaskGroupContext.pop_context_managed_task_group()
+
+    def has_task(self, task: "BaseOperator") -> bool:
+        """
+        Returns True if this TaskGroup or its children TaskGroups contains the 
given task.
+        """
+        if task.task_id in self.children:
+            return True
+
+        return any(child.has_task(task) for child in self.children.values() if 
isinstance(child, TaskGroup))
+
+    @property
+    def roots(self) -> List["BaseOperator"]:
+        """Required by TaskMixin"""
+        return list(self.get_roots())
+
+    @property
+    def leaves(self) -> List["BaseOperator"]:
+        """Required by TaskMixin"""
+        return list(self.get_leaves())
+
+    def get_roots(self) -> Generator["BaseOperator", None, None]:
+        """
+        Returns a generator of tasks that are root tasks, i.e. those with no 
upstream
+        dependencies within the TaskGroup.
+        """
+        for task in self:
+            if not any(self.has_task(parent) for parent in 
task.get_direct_relatives(upstream=True)):
+                yield task
+
+    def get_leaves(self) -> Generator["BaseOperator", None, None]:
+        """
+        Returns a generator of tasks that are leaf tasks, i.e. those with no 
downstream
+        dependencies within the TaskGroup
+        """
+        for task in self:
+            if not any(self.has_task(child) for child in 
task.get_direct_relatives(upstream=False)):
+                yield task
+
+    def child_id(self, label):
+        """
+        Prefix label with group_id if prefix_group_id is True. Otherwise 
return the label
+        as-is.
+        """
+        if self.prefix_group_id and self.group_id:
+            return f"{self.group_id}.{label}"
+
+        return label
+
+    @property
+    def upstream_join_id(self) -> str:
+        """
+        If this TaskGroup has immediate upstream TaskGroups or tasks, a dummy 
node called
+        upstream_join_id will be created in Graph View to join the incoming 
edges to this
+        TaskGroup to reduce the total number of edges needed to be displayed.
+        """
+        return f"{self.group_id}.upstream_join_id"
+
+    @property
+    def downstream_join_id(self) -> str:
+        """
+        If this TaskGroup has immediate downstream TaskGroups or tasks, a 
dummy node called
+        downstream_join_id will be created in Graph View to join the outgoing 
edges from this
+        TaskGroup to reduce the total number of edges needed to be displayed.
+        """
+        return f"{self.group_id}.downstream_join_id"
+
+    def get_task_group_dict(self) -> Dict[str, "TaskGroup"]:
+        """
+        Returns a flat dictionary of group_id: TaskGroup
+        """
+        task_group_map = {}
+
+        def build_map(task_group):
+            if not isinstance(task_group, TaskGroup):
+                return
+
+            task_group_map[task_group.group_id] = task_group
+
+            for child in task_group.children.values():
+                build_map(child)
+
+        build_map(self)
+        return task_group_map
+
+    def get_child_by_label(self, label: str) -> Union["BaseOperator", 
"TaskGroup"]:
+        """
+        Get a child task/TaskGroup by its label (i.e. task_id/group_id without 
the group_id prefix)
+        """
+        return self.children[self.child_id(label)]
+
+
+class TaskGroupContext:
+    """
+    TaskGroup context is used to keep the current TaskGroup when TaskGroup is 
used as ContextManager.
+    """
+
+    _context_managed_task_group: Optional[TaskGroup] = None
+    _previous_context_managed_task_groups: List[TaskGroup] = []
+
+    @classmethod
+    def push_context_managed_task_group(cls, task_group: TaskGroup):
+        """
+        Push a TaskGroup into the list of managed TaskGroups.
+        """
+        if cls._context_managed_task_group:
+            
cls._previous_context_managed_task_groups.append(cls._context_managed_task_group)
+        cls._context_managed_task_group = task_group
+
+    @classmethod
+    def pop_context_managed_task_group(cls) -> Optional[TaskGroup]:
+        """
+        Pops the last TaskGroup from the list of managed TaskGroups and updates 
the current TaskGroup.
+        """
+        old_task_group = cls._context_managed_task_group
+        if cls._previous_context_managed_task_groups:
+            cls._context_managed_task_group = 
cls._previous_context_managed_task_groups.pop()
+        else:
+            cls._context_managed_task_group = None
+        return old_task_group
+
+    @classmethod
+    def get_current_task_group(cls, dag: Optional["DAG"]) -> 
Optional[TaskGroup]:
+        """
+        Get the current TaskGroup.
+        """
+        from airflow.models.dag import DagContext
+
+        if not cls._context_managed_task_group:
+            dag = dag or DagContext.get_current_dag()
+            if dag:
+                # If there's currently a DAG but no TaskGroup, return the root 
TaskGroup of the dag.
+                return dag.task_group
+
+        return cls._context_managed_task_group
diff --git a/airflow/www/static/css/graph.css b/airflow/www/static/css/graph.css
index 7f1b818..ee1d093 100644
--- a/airflow/www/static/css/graph.css
+++ b/airflow/www/static/css/graph.css
@@ -36,12 +36,26 @@ svg {
   stroke-width: 1px;
 }
 
+g.cluster rect {
+  stroke: white;
+  stroke-dasharray: 5;
+  rx: 5;
+  ry: 5;
+  opacity: 0.5;
+}
+
 g.node rect {
   stroke: #fff;
   stroke-width: 3px;
   cursor: pointer;
 }
 
+g.node circle {
+  stroke: black;
+  stroke-width: 3px;
+  cursor: pointer;
+}
+
 g.node .label {
   font-size: inherit;
   font-weight: normal;
diff --git a/airflow/www/templates/airflow/graph.html 
b/airflow/www/templates/airflow/graph.html
index f4ec0b6..6d6d566 100644
--- a/airflow/www/templates/airflow/graph.html
+++ b/airflow/www/templates/airflow/graph.html
@@ -101,6 +101,8 @@
 <script src="{{ url_for_asset('d3.min.js') }}"></script>
 <script src="{{ url_for_asset('dagre-d3.min.js') }}"></script>
 <script src="{{ url_for_asset('d3-tip.js') }}"></script>
+
+<script src="{{ url_for_asset('task-instances.js') }}"></script>
 <script>
 
     var highlight_color = "#000000";
@@ -113,6 +115,10 @@
     var edges = {{ edges|tojson }};
     var execution_date = "{{ execution_date }}";
     var arrange = "{{ arrange }}";
+    var task_group_tips = get_task_group_tips(nodes);
+    // This maps the actual task_id to the current graph node id that contains 
the task
+    // (because tasks may be grouped into a group node)
+    var map_task_to_node = new Map()
 
     // Below variables are being used in dag.js
     var tasks = {{ tasks|tojson }};
@@ -140,33 +146,94 @@
       });
 
     // Preparation of DagreD3 data structures
-    var g = new dagreD3.graphlib.Graph().setGraph({
+    // "compound" is set to true to make use of clusters to display TaskGroup.
+    var g = new dagreD3.graphlib.Graph({compound: true}).setGraph({
         nodesep: 15,
         ranksep: 15,
         rankdir: arrange,
       })
       .setDefaultEdgeLabel(function() { return { lineInterpolate: 'basis' } });
 
-    // Set all nodes and styles
-    nodes.forEach(function(node) {
-      g.setNode(node.id, node.value)
-    });
-
-    // Set edges
-    edges.forEach(function(edge) {
-      g.setEdge(edge.source_id, edge.target_id);
-    });
-
     var render = dagreD3.render(),
       svg = d3.select("svg"),
       innerSvg = d3.select("svg g");
 
-    innerSvg.call(render, g);
-    innerSvg.call(taskTip);
+    // Update the page to show the latest DAG.
+    function draw() {
+      innerSvg.remove()
+      innerSvg = svg.append("g")
+      // Run the renderer. This is what draws the final graph.
+      innerSvg.call(render, g);
+      innerSvg.call(taskTip)
+
+      // When an expanded group is clicked, collapse it.
+      d3.selectAll("g.cluster").on("click", function (node_id) {
+        if (d3.event.defaultPrevented) // Ignore dragging actions.
+            return;
+        node = g.node(node_id)
+        collapse_group(node_id, node)
+      })
+      // When a node is clicked, action depends on the node type.
+      d3.selectAll("g.node").on("click", function (node_id) {
+        node = g.node(node_id)
+        if (node.children != undefined && Object.keys(node.children).length > 
0) {
+          // A group node
+          if (d3.event.defaultPrevented) // Ignore dragging actions.
+            return;
+          expand_group(node_id, node)
+        } else if (node_id in tasks) {
+          // A task node
+          task = tasks[node_id];
+          if (node_id in task_instances)
+            try_number = task_instances[node_id].try_number;
+          else
+            try_number = 0;
+
+          if (task.task_type == "SubDagOperator")
+            call_modal(node_id, execution_date, task.extra_links, try_number, 
true);
+          else
+            call_modal(node_id, execution_date, task.extra_links, try_number, 
undefined);
+        } else {
+          // join node between TaskGroup. Ignore.
+        }
+      });
+
+      d3.selectAll("g.node").on("mouseover", function (d) {
+        d3.select(this).selectAll("rect").style("stroke", highlight_color);
+        highlight_nodes(g.predecessors(d), upstream_color, 
highlightStrokeWidth);
+        highlight_nodes(g.successors(d), downstream_color, 
highlightStrokeWidth)
+        adjacent_node_names = [d, ...g.predecessors(d), ...g.successors(d)]
+        d3.selectAll("g.nodes g.node")
+          .filter(x => !adjacent_node_names.includes(x))
+          .style("opacity", 0.2);
+        adjacent_edges = g.nodeEdges(d)
+        d3.selectAll("g.edgePath")[0]
+          .filter(x => !adjacent_edges.includes(x.__data__))
+          .forEach(function (x) {
+            d3.select(x).style('opacity', .2)
+          })
+      });
+
+      d3.selectAll("g.node").on("mouseout", function (d) {
+        d3.select(this).selectAll("rect").style("stroke", null);
+        highlight_nodes(g.predecessors(d), null, initialStrokeWidth)
+        highlight_nodes(g.successors(d), null, initialStrokeWidth)
+        d3.selectAll("g.node")
+          .style("opacity", 1);
+        d3.selectAll("g.node rect")
+          .style("stroke-width", initialStrokeWidth);
+        d3.selectAll("g.edgePath")
+          .style("opacity", 1);
+      });
+      updateNodesStates(task_instances);
+      setUpZoomSupport();
+    }
+
+    var zoom = null;
 
     function setUpZoomSupport() {
       // Set up zoom support for Graph
-      var zoom = d3.behavior.zoom().on("zoom", function() {
+      zoom = d3.behavior.zoom().on("zoom", function() {
             innerSvg.attr("transform", "translate(" + d3.event.translate + ")" 
+
                                         "scale(" + d3.event.scale + ")");
           });
@@ -193,61 +260,16 @@
       zoom.event(innerSvg);
     }
 
-    setUpZoomSupport();
-    inject_node_ids(tasks);
-
-    d3.selectAll("g.node").on("click", function(d){
-        task = tasks[d];
-        if (d in task_instances)
-            try_number = task_instances[d].try_number;
-        else
-            try_number = 0;
-
-        if (task.task_type == "SubDagOperator")
-            call_modal(d, execution_date, task.extra_links, try_number, true);
-        else
-            call_modal(d, execution_date, task.extra_links, try_number, 
undefined);
-    });
-
-
     function highlight_nodes(nodes, color, stroke_width) {
         nodes.forEach (function (nodeid) {
-            my_node = d3.select('[id="' + nodeid + '"]').node().parentNode;
+            const my_node = g.node(nodeid).elem
             d3.select(my_node)
-                .selectAll("rect")
+                .selectAll("rect,circle")
                 .style("stroke", color)
                 .style("stroke-width", stroke_width) ;
         })
     }
 
-    d3.selectAll("g.node").on("mouseover", function(d){
-        d3.select(this).selectAll("rect").style("stroke", highlight_color) ;
-        highlight_nodes(g.predecessors(d), upstream_color, 
highlightStrokeWidth);
-        highlight_nodes(g.successors(d), downstream_color, 
highlightStrokeWidth)
-        adjacent_node_names = [d, ...g.predecessors(d), ...g.successors(d)]
-        d3.selectAll("g.nodes g.node")
-            .filter(x => !adjacent_node_names.includes(x))
-            .style("opacity", 0.2);
-        adjacent_edges = g.nodeEdges(d)
-        d3.selectAll("g.edgePath")[0]
-            .filter(x => !adjacent_edges.includes(x.__data__))
-            .forEach(function(x) {
-                d3.select(x).style('opacity', .2)
-            })
-    });
-
-    d3.selectAll("g.node").on("mouseout", function(d){
-        d3.select(this).selectAll("rect").style("stroke", null) ;
-        highlight_nodes(g.predecessors(d), null, initialStrokeWidth)
-        highlight_nodes(g.successors(d), null, initialStrokeWidth)
-        d3.selectAll("g.node")
-            .style("opacity", 1);
-        d3.selectAll("g.node rect")
-            .style("stroke-width", initialStrokeWidth);
-        d3.selectAll("g.edgePath")
-            .style("opacity", 1);
-    });
-
 
     {% if blur %}
     d3.selectAll("text").attr("class", "blur");
@@ -283,11 +305,31 @@
             }
         });
 
-    d3.select("#searchbox").on("keyup", function(){
+    // Returns true if a node's id or its children's id matches search_text
+    function node_matches(node_id, search_text) {
+      if (node_id.indexOf(search_text) > -1)
+        return true;
+
+      // The node's own id does not match, it may have children that match
+      var node = g.node(node_id)
+      if (node.children != undefined) {
+        var children = get_children_ids(node);
+        for(const child of children) {
+          if(child.indexOf(search_text) > -1)
+            return true
+        }
+      }
+    }
+
+    d3.select("#searchbox").on("keyup", function() {
         var s = document.getElementById("searchbox").value;
+
+        if(s == "")
+          return;
+
         var match = null;
 
-        if (stateIsSet){
+        if (stateIsSet()){
             clearFocus();
             setFocusMap();
         }
@@ -307,7 +349,7 @@
                 d3.select("g.edgePaths")
                     .transition().duration(duration)
                     .style("opacity", 0.2);
-                if (d.indexOf(s) > -1) {
+                if (node_matches(d, s)) {
                     if (!match)
                         match = this;
                     d3.select(this)
@@ -315,8 +357,7 @@
                         .style("opacity", 1)
                         .selectAll("rect")
                         .style("stroke-width", highlightStrokeWidth);
-                }
-                else {
+                } else {
                     d3.select(this)
                         .transition()
                         .style("opacity", 0.2).duration(duration)
@@ -326,37 +367,25 @@
             }
         });
 
-        // This moves the matched node in the center of the graph area
-        // ToDo: Should we keep this here as it has no added value
-        // and does not fit the graph on small screens, and has to scroll
+        // This moves the matched node to the center of the graph area
         if(match) {
             var transform = d3.transform(d3.select(match).attr("transform"));
+
+            var svgBb = svg.node().getBoundingClientRect();
             transform.translate = [
-                -transform.translate[0] + 520,
-                -(transform.translate[1] - 400)
+              svgBb.width / 2 - transform.translate[0],
+              svgBb.height / 2 - transform.translate[1]
             ];
             transform.scale = [1, 1];
 
-            d3.select("g.zoom")
-                .transition()
-                .attr("transform", transform.toString());
-            innerSvg.attr("transform", "translate(" + transform.translate + 
")" +
-                                        "scale(1)");
+            if(zoom != null) {
+              zoom.translate(transform.translate);
+              zoom.scale(1);
+              zoom.event(innerSvg);
+            }
         }
     });
 
-
-    // Injecting ids to be used for parent/child highlighting
-    // Separated from updateNodeStates since it must work even
-    // when there is no valid task instance available
-    function inject_node_ids(tasks) {
-        $.each(tasks, function(task_id, task) {
-            $('tspan').filter(function(index) { return $(this).text() === 
task_id; })
-                    .parent().parent().parent()
-                    .attr("id", task_id);
-        });
-    }
-
     function clearFocus(){
         d3.selectAll("g.node")
             .transition(duration)
@@ -413,8 +442,9 @@
         $("div#svg_container").css("opacity", "0.2");
         $.get(getTaskInstanceURL)
           .done(
-            (task_instances) => {
-              updateNodesStates(JSON.parse(task_instances));
+            (tis) => {
+              task_instances = JSON.parse(tis)
+              updateNodesStates(task_instances);
               $("#loading").hide();
               $("div#svg_container").css("opacity", "1");
               $('#error').hide();
@@ -429,31 +459,249 @@
       });
     }
 
+    // Generate tooltip for a group node
+    function group_tooltip(node_id, tis) {
+      var num_map = new Map([["success", 0],
+                             ["failed", 0],
+                             ["upstream_failed", 0],
+                             ["up_for_retry", 0],
+                             ["running", 0],
+                             ["no_status", 0]]
+                           );
+      for(const child of get_children_ids(g.node(node_id))) {
+        if(child in tis) {
+          const ti = tis[child];
+          const state_key = ti.state == null ? "no_status" : ti.state;
+          if(num_map.has(state_key))
+            num_map.set(state_key, num_map.get(state_key) + 1);
+        }
+      }
+
+      const tip = task_group_tips.get(node_id);
+      let tt = `${escapeHtml(tip)}<br><br>`;
+      for(const [key, val] of num_map.entries())
+        tt += `<strong>${escapeHtml(key)}:</strong> ${val} <br>`;
+
+      return tt;
+    }
+
+    // Build a map mapping node id to tooltip for all the TaskGroups.
+    function get_task_group_tips(node) {
+      var tips = new Map();
+      if(node.children != undefined) {
+        tips.set(node.id, node.tooltip);
+
+        for(const child of node.children.values()) {
+          for(const [key, val] of get_task_group_tips(child))
+            tips.set(key, val);
+        }
+      }
+      return tips;
+    }
+
     // Assigning css classes based on state to nodes
     // Initiating the tooltips
-    function updateNodesStates(task_instances) {
-      $.each(task_instances, (task_id, ti) => {
-        $('tspan').filter((index, el) => {
-          return $(el).text() === task_id;
-        })
-          .parent().parent().parent().parent()
-          .attr("class", "node enter " + (ti.state ? ti.state : "no_status"))
-          .attr("data-toggle", "tooltip")
-          .on("mouseover", (evt) => {
-            const task = tasks[task_id];
-            const tt = tiTooltip(ti);
+    function updateNodesStates(tis) {
+      for(const node_id of g.nodes())
+      {
+        elem = g.node(node_id).elem;
+        elem.setAttribute("class", "node enter " + get_node_state(node_id, 
tis));
+        elem.setAttribute("data-toggle", "tooltip");
+
+        const task_id = node_id;
+        elem.onmouseover = (evt) => {
+          if(task_id in tis) {
+            const tt = tiTooltip(tis[task_id]);
             taskTip.show(tt, evt.target); // taskTip is defined in graph.html
+          } else if(task_group_tips.has(task_id)) {
+            const tt = group_tooltip(task_id, tis)
+            taskTip.show(tt, evt.target);
+          }
+        };
+        elem.onmouseout = taskTip.hide;
+        elem.onclick = taskTip.hide;
+      }
+    }
+
+
+    // Returns list of children id of the given task group
+    function get_children_ids(group) {
+      var children = []
+      for(const [key, val] of Object.entries(group.children)) {
+        if(val.children == undefined) {
+          // node
+          children.push(val.id)
+        } else {
+          // group
+          const sub_group_children = get_children_ids(val)
+          for(const id of sub_group_children) {
+            children.push(id)
+          }
+        }
+      }
+      return children
+    }
+
+
+    // Return the state for the node based on the state of its taskinstance or 
that of its
+    // children if it's a group node
+    function get_node_state(node_id, tis) {
+      node = g.node(node_id)
+
+      if (node.children == undefined) {
+        if(node_id in tis)
+          return tis[node_id].state
+
+        return "no_status"
+      }
+      var children = get_children_ids(node)
+
+      children_states = new Set()
+      children.forEach(function(task_id) {
+        if (task_id in tis) {
+          var state = tis[task_id].state
+          children_states.add(state == null ? "no_status" : state)
+        }
+      })
+
+      // In this order, if any of these states appeared in children_states, 
return it as
+      // the group state.
+      var priority = ["failed", "upstream_failed", 
"up_for_retry","up_for_reschedule",
+                      "queued", "no_status", "success", "skipped"]
+      for(const state of priority) {
+        if (children_states.has(state))
+          return state
+      }
+      return "no_status"
+    }
+
+    // Focus the graph on the expanded/collapsed node
+    function focus_group(node_id) {
+      if(node_id != null && zoom != null) {
+          const x = g.node(node_id).x;
+          const y = g.node(node_id).y;
+          // This is the total canvas size.
+          const svg_box = svg.node().getBoundingClientRect();
+          const width = svg_box.width;
+          const height = svg_box.height;
+
+          // This is the size of the node or the cluster (i.e. group)
+          var rect = d3.selectAll("g.node").filter(x => {return x == 
node_id}).select('rect');
+          if (rect.empty())
+            rect = d3.selectAll("g.cluster").filter(x => {return x == 
node_id}).select('rect');
+
+          // Is there a better way to get node_width and node_height ?
+          const [node_width, node_height] = 
[rect[0][0].attributes.width.value, rect[0][0].attributes.height.value];
+
+          // Calculate zoom scale to fill most of the canvas with the 
node/cluster in focus.
+          const scale = Math.min(
+            Math.min(width / node_width, height / node_height),
+            1.5,  // cap zoom level to 1.5 so nodes are not too large
+          ) * 0.9;
+
+          var [delta_x, delta_y] = [width / 2 - x * scale, height / 2 - y * 
scale];
+          zoom.translate([delta_x, delta_y]);
+          zoom.scale(scale);
+          zoom.event(innerSvg.transition().duration(duration));
+
+          const children = new Set(g.children(node_id))
+          // Change opacity to highlight the focused group.
+          d3.selectAll("g.nodes g.node").filter(function(d, i){
+            if (d == node_id || children.has(d)) {
+                d3.select(this)
+                    .transition().duration(duration)
+                    .style("opacity", 1)
+            } else {
+                d3.select(this)
+                    .transition()
+                    .style("opacity", 0.2).duration(duration)
+            }
+          });
+      }
+    }
+
+    // Expands a group node
+    function expand_group(node_id, node) {
+      node.children.forEach(function (val) {
+        // Set children nodes
+        g.setNode(val.id, val.value)
+        map_task_to_node.set(val.id, val.id)
+        g.node(val.id).id = val.id
+        if (val.children != undefined) {
+          // Set children attribute so that the group can be expanded later 
when needed.
+          group_node = g.node(val.id)
+          group_node.children = val.children
+          // Map tasks that are under this node to this node's id
+          for(const child_id of get_children_ids(val))
+            map_task_to_node.set(child_id, val.id)
+        }
+        // Only call setParent if node is not the root node.
+        if (node_id != null)
+          g.setParent(val.id, node_id)
+      })
+
+      // Add edges
+      edges.forEach(function(edge) {
+        source_id = map_task_to_node.get(edge.source_id)
+        target_id = map_task_to_node.get(edge.target_id)
+        if(source_id != target_id && !g.hasEdge(source_id, target_id))
+          g.setEdge(source_id, target_id)
+      })
+
+      g.edges().forEach(function (edge) {
+        // Remove edges that were associated with the expanded group node.
+        if(node_id == edge.v || node_id == edge.w)
+          g.removeEdge(edge.v, edge.w)
+      })
+
+      draw()
+      focus_group(node_id)
+  }
+
+  // Remove the node with this node_id from g.
+  function remove_node(node_id) {
+    if(g.hasNode(node_id)) {
+        node = g.node(node_id)
+        if(node.children != undefined) {
+          // If the child is an expanded group node, remove children too.
+          node.children.forEach(function (child) {
+            remove_node(child.id)
           })
-          .on('mouseout', taskTip.hide);
-      });
+        }
     }
+    g.removeNode(node_id)
+  }
+
+  // Collapse the children of the given group node.
+  function collapse_group(node_id, node) {
+      // Remove children nodes
+      node.children.forEach(function(child) {
+        remove_node(child.id)
+      })
+      // Map tasks that are under this node to this node's id
+      for(const child_id of get_children_ids(node))
+        map_task_to_node.set(child_id, node_id)
+
+      node = g.node(node_id)
+
+      // Set children edges onto the group edge
+      edges.forEach(function(edge) {
+        source_id = map_task_to_node.get(edge.source_id)
+        target_id = map_task_to_node.get(edge.target_id)
+        if(source_id != target_id && !g.hasEdge(source_id, target_id))
+          g.setEdge(source_id, target_id)
+      })
+
+      draw()
+      focus_group(node_id)
+    }
+
+    expand_group(null, nodes)
 
-    updateNodesStates(task_instances);
     initRefreshButton();
 
 </script>
 <script src="{{ url_for_asset('graph.js') }}"></script>
-<script src="{{ url_for_asset('task-instances.js') }}"></script>
 
 
 {% endblock %}
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 07c8d2f..4c5e2b1 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -58,6 +58,7 @@ from airflow.executors.executor_loader import ExecutorLoader
 from airflow.jobs.base_job import BaseJob
 from airflow.jobs.scheduler_job import SchedulerJob
 from airflow.models import Connection, DagModel, DagTag, Log, SlaMiss, 
TaskFail, XCom, errors
+from airflow.models.baseoperator import BaseOperator
 from airflow.models.dagcode import DagCode
 from airflow.models.dagrun import DagRun, DagRunType
 from airflow.models.taskinstance import TaskInstance
@@ -147,6 +148,163 @@ def 
get_date_time_num_runs_dag_runs_form_data(www_request, session, dag):
     }
 
 
+def task_group_to_dict(task_group):
+    """
+    Create a nested dict representation of this TaskGroup and its children 
used to construct
+    the Graph View.
+    """
+    if isinstance(task_group, BaseOperator):
+        return {
+            'id': task_group.task_id,
+            'value': {
+                'label': task_group.label,
+                'labelStyle': f"fill:{task_group.ui_fgcolor};",
+                'style': f"fill:{task_group.ui_color};",
+                'rx': 5,
+                'ry': 5,
+            }
+        }
+
+    children = [task_group_to_dict(child) for child in
+                sorted(task_group.children.values(), key=lambda t: t.label)]
+
+    if task_group.upstream_group_ids or task_group.upstream_task_ids:
+        children.append({
+            'id': task_group.upstream_join_id,
+            'value': {
+                'label': '',
+                'labelStyle': f"fill:{task_group.ui_fgcolor};",
+                'style': f"fill:{task_group.ui_color};",
+                'shape': 'circle',
+            }
+        })
+
+    if task_group.downstream_group_ids or task_group.downstream_task_ids:
+        # This is the join node used to reduce the number of edges between two 
TaskGroups.
+        children.append({
+            'id': task_group.downstream_join_id,
+            'value': {
+                'label': '',
+                'labelStyle': f"fill:{task_group.ui_fgcolor};",
+                'style': f"fill:{task_group.ui_color};",
+                'shape': 'circle',
+            }
+        })
+
+    return {
+        "id": task_group.group_id,
+        'value': {
+            'label': task_group.label,
+            'labelStyle': f"fill:{task_group.ui_fgcolor};",
+            'style': f"fill:{task_group.ui_color}",
+            'rx': 5,
+            'ry': 5,
+            'clusterLabelPos': 'top',
+        },
+        'tooltip': task_group.tooltip,
+        'children': children
+    }
+
+
+def dag_edges(dag):
+    """
+    Create the list of edges needed to construct the Graph View.
+
+    A special case is made if a TaskGroup is immediately upstream/downstream 
of another
+    TaskGroup or task. Two dummy nodes named upstream_join_id and 
downstream_join_id are
+    created for the TaskGroup. Instead of drawing an edge onto every task in 
the TaskGroup,
+    all edges are directed onto the dummy nodes. This is to cut down the 
number of edges on
+    the graph.
+
+    For example: A DAG with TaskGroups group1 and group2:
+        group1: task1, task2, task3
+        group2: task4, task5, task6
+
+    group2 is downstream of group1:
+        group1 >> group2
+
+    Edges to add (This avoids having to create edges between every task in 
group1 and group2):
+        task1 >> downstream_join_id
+        task2 >> downstream_join_id
+        task3 >> downstream_join_id
+        downstream_join_id >> upstream_join_id
+        upstream_join_id >> task4
+        upstream_join_id >> task5
+        upstream_join_id >> task6
+    """
+
+    # Edges to add between TaskGroups
+    edges_to_add = set()
+    # Edges to remove between individual tasks that are replaced by 
edges_to_add.
+    edges_to_skip = set()
+
+    task_group_map = dag.task_group.get_task_group_dict()
+
+    def collect_edges(task_group):
+        """
+        Update edges_to_add and edges_to_skip according to TaskGroups.
+        """
+        if isinstance(task_group, BaseOperator):
+            return
+
+        for target_id in task_group.downstream_group_ids:
+            # For every TaskGroup immediately downstream, add edges between 
downstream_join_id
+            # and upstream_join_id. Skip edges between individual tasks of the 
TaskGroups.
+            target_group = task_group_map[target_id]
+            edges_to_add.add((task_group.downstream_join_id, 
target_group.upstream_join_id))
+
+            for child in task_group.get_leaves():
+                edges_to_add.add((child.task_id, 
task_group.downstream_join_id))
+                for target in target_group.get_roots():
+                    edges_to_skip.add((child.task_id, target.task_id))
+                edges_to_skip.add((child.task_id, 
target_group.upstream_join_id))
+
+            for child in target_group.get_roots():
+                edges_to_add.add((target_group.upstream_join_id, 
child.task_id))
+                edges_to_skip.add((task_group.downstream_join_id, 
child.task_id))
+
+        # For every individual task immediately downstream, add edges between 
downstream_join_id and
+        # the downstream task. Skip edges between individual tasks of the 
TaskGroup and the
+        # downstream task.
+        for target_id in task_group.downstream_task_ids:
+            edges_to_add.add((task_group.downstream_join_id, target_id))
+
+            for child in task_group.get_leaves():
+                edges_to_add.add((child.task_id, 
task_group.downstream_join_id))
+                edges_to_skip.add((child.task_id, target_id))
+
+        # For every individual task immediately upstream, add edges between 
the upstream task
+        # and upstream_join_id. Skip edges between the upstream task and 
individual tasks
+        # of the TaskGroup.
+        for source_id in task_group.upstream_task_ids:
+            edges_to_add.add((source_id, task_group.upstream_join_id))
+            for child in task_group.get_roots():
+                edges_to_add.add((task_group.upstream_join_id, child.task_id))
+                edges_to_skip.add((source_id, child.task_id))
+
+        for child in task_group.children.values():
+            collect_edges(child)
+
+    collect_edges(dag.task_group)
+
+    # Collect all the edges between individual tasks
+    edges = set()
+
+    def get_downstream(task):
+        for child in task.downstream_list:
+            edge = (task.task_id, child.task_id)
+            if edge not in edges:
+                edges.add(edge)
+                get_downstream(child)
+
+    for root in dag.roots:
+        get_downstream(root)
+
+    return [{'source_id': source_id, 'target_id': target_id}
+            for source_id, target_id
+            in sorted(edges.union(edges_to_add) - edges_to_skip)]
+
+
 
######################################################################################
 #                                    Error handlers
 
######################################################################################
@@ -1608,32 +1766,8 @@ class Airflow(AirflowBaseView):  # noqa: D101  pylint: 
disable=too-many-public-m
 
         arrange = request.args.get('arrange', dag.orientation)
 
-        nodes = []
-        edges = []
-        for dag_task in dag.tasks:
-            nodes.append({
-                'id': dag_task.task_id,
-                'value': {
-                    'label': dag_task.task_id,
-                    'labelStyle': "fill:{0};".format(dag_task.ui_fgcolor),
-                    'style': "fill:{0};".format(dag_task.ui_color),
-                    'rx': 5,
-                    'ry': 5,
-                }
-            })
-
-        def get_downstream(task):
-            for downstream_task in task.downstream_list:
-                edge = {
-                    'source_id': task.task_id,
-                    'target_id': downstream_task.task_id,
-                }
-                if edge not in edges:
-                    edges.append(edge)
-                    get_downstream(downstream_task)
-
-        for dag_task in dag.roots:
-            get_downstream(dag_task)
+        nodes = task_group_to_dict(dag.task_group)
+        edges = dag_edges(dag)
 
         dt_nr_dr_data = get_date_time_num_runs_dag_runs_form_data(request, 
session, dag)
         dt_nr_dr_data['arrange'] = arrange
diff --git a/docs/concepts.rst b/docs/concepts.rst
index 5d2be97..8e37582 100644
--- a/docs/concepts.rst
+++ b/docs/concepts.rst
@@ -939,6 +939,48 @@ See ``airflow/example_dags`` for a demonstration.
 Note that airflow pool is not honored by SubDagOperator. Hence resources could 
be
 consumed by SubdagOperators.
 
+
+TaskGroup
+=========
+TaskGroup can be used to organize tasks into hierarchical groups in Graph 
View. It is
+useful for creating repeating patterns and cutting down visual clutter. Unlike 
SubDagOperator,
+TaskGroup is a UI grouping concept. Tasks in TaskGroups live on the same 
original DAG. They
+honor all the pool configurations.
+
+Dependency relationships can be applied across all tasks in a TaskGroup with 
the ``>>`` and ``<<``
+operators. For example, the following code puts ``task1`` and ``task2`` in 
TaskGroup ``group1``
+and then puts both tasks upstream of ``task3``:
+
+.. code-block:: python
+
+    with TaskGroup("group1") as group1:
+        task1 = DummyOperator(task_id="task1")
+        task2 = DummyOperator(task_id="task2")
+
+    task3 = DummyOperator(task_id="task3")
+
+    group1 >> task3
+
+.. note::
+   By default, child tasks and TaskGroups have their task_id and group_id 
prefixed with the
+   group_id of their parent TaskGroup. This ensures uniqueness of group_id and 
task_id throughout
+   the DAG. To disable the prefixing, pass ``prefix_group_id=False`` when 
creating the TaskGroup.
+   This then gives the user full control over the actual group_id and task_id. 
They have to ensure
+   group_id and task_id are unique throughout the DAG. The option 
``prefix_group_id=False`` is
+   mainly useful for putting tasks on existing DAGs into a TaskGroup without 
altering their task_id.
+
+Here is a more complicated example DAG with multiple levels of nested 
TaskGroups:
+
+.. exampleinclude:: /../airflow/example_dags/example_task_group.py
+    :language: python
+    :start-after: [START howto_task_group]
+    :end-before: [END howto_task_group]
+
+This animated gif shows the UI interactions. TaskGroups are expanded or 
collapsed when clicked:
+
+.. image:: img/task_group.gif
+
+
 SLAs
 ====
 
diff --git a/docs/img/task_group.gif b/docs/img/task_group.gif
new file mode 100644
index 0000000..ac4f6e9
Binary files /dev/null and b/docs/img/task_group.gif differ
diff --git a/tests/serialization/test_dag_serialization.py 
b/tests/serialization/test_dag_serialization.py
index 75c35e8..1b3a993 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -67,6 +67,17 @@ serialized_simple_dag_ground_truth = {
             }
         },
         "start_date": 1564617600.0,
+        '_task_group': {'_group_id': None,
+                        'prefix_group_id': True,
+                        'children': {'bash_task': ('operator', 'bash_task'),
+                                     'custom_task': ('operator', 
'custom_task')},
+                        'tooltip': '',
+                        'ui_color': 'CornflowerBlue',
+                        'ui_fgcolor': '#000',
+                        'upstream_group_ids': [],
+                        'downstream_group_ids': [],
+                        'upstream_task_ids': [],
+                        'downstream_task_ids': []},
         "is_paused_upon_creation": False,
         "_dag_id": "simple_dag",
         "fileloc": None,
@@ -83,6 +94,7 @@ serialized_simple_dag_ground_truth = {
                 "ui_fgcolor": "#000",
                 "template_fields": ['bash_command', 'env'],
                 "bash_command": "echo {{ task.task_id }}",
+                'label': 'bash_task',
                 "_task_type": "BashOperator",
                 "_task_module": "airflow.operators.bash",
                 "pool": "default_pool",
@@ -107,6 +119,7 @@ serialized_simple_dag_ground_truth = {
                 "_task_type": "CustomOperator",
                 "_task_module": "tests.test_utils.mock_operators",
                 "pool": "default_pool",
+                'label': 'custom_task',
             },
         ],
         "timezone": "UTC",
@@ -329,6 +342,7 @@ class TestStringifiedDAGs(unittest.TestCase):
 
             # Need to check fields in it, to exclude functions
             'default_args',
+            "_task_group"
         }
         for field in fields_to_check:
             assert getattr(serialized_dag, field) == getattr(dag, field), \
@@ -765,6 +779,7 @@ class TestStringifiedDAGs(unittest.TestCase):
                           'execution_timeout': None,
                           'executor_config': {},
                           'inlets': [],
+                          'label': '10',
                           'max_retry_delay': None,
                           'on_execute_callback': None,
                           'on_failure_callback': None,
@@ -804,3 +819,51 @@ class TestStringifiedDAGs(unittest.TestCase):
 
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
                          """
                          )
+
+    def test_task_group_serialization(self):
+        """
+        Test TaskGroup serialization/deserialization.
+        """
+        from airflow.operators.dummy_operator import DummyOperator
+        from airflow.utils.task_group import TaskGroup
+
+        execution_date = datetime(2020, 1, 1)
+        with DAG("test_task_group_serialization", start_date=execution_date) 
as dag:
+            task1 = DummyOperator(task_id="task1")
+            with TaskGroup("group234") as group234:
+                _ = DummyOperator(task_id="task2")
+
+                with TaskGroup("group34") as group34:
+                    _ = DummyOperator(task_id="task3")
+                    _ = DummyOperator(task_id="task4")
+
+            task5 = DummyOperator(task_id="task5")
+            task1 >> group234
+            group34 >> task5
+
+        dag_dict = SerializedDAG.to_dict(dag)
+        SerializedDAG.validate_schema(dag_dict)
+        json_dag = SerializedDAG.from_json(SerializedDAG.to_json(dag))
+        self.validate_deserialized_dag(json_dag, dag)
+
+        serialized_dag = 
SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag))
+
+        assert serialized_dag.task_group.children
+        assert serialized_dag.task_group.children.keys() == 
dag.task_group.children.keys()
+
+        def check_task_group(node):
+            try:
+                children = node.children.values()
+            except AttributeError:
+                # Round-trip serialization and check the result
+                expected_serialized = 
SerializedBaseOperator.serialize_operator(dag.get_task(node.task_id))
+                expected_deserialized = 
SerializedBaseOperator.deserialize_operator(expected_serialized)
+                expected_dict = 
SerializedBaseOperator.serialize_operator(expected_deserialized)
+                assert node
+                assert SerializedBaseOperator.serialize_operator(node) == 
expected_dict
+                return
+
+            for child in children:
+                check_task_group(child)
+
+        check_task_group(serialized_dag.task_group)
diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py
new file mode 100644
index 0000000..c4f7a12
--- /dev/null
+++ b/tests/utils/test_task_group.py
@@ -0,0 +1,561 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pendulum
+import pytest
+
+from airflow.models import DAG
+from airflow.operators.dummy_operator import DummyOperator
+from airflow.utils.task_group import TaskGroup
+from airflow.www.views import dag_edges, task_group_to_dict
+
+EXPECTED_JSON = {
+    'id': None,
+    'value': {
+        'label': None,
+        'labelStyle': 'fill:#000;',
+        'style': 'fill:CornflowerBlue',
+        'rx': 5,
+        'ry': 5,
+        'clusterLabelPos': 'top',
+    },
+    'tooltip': '',
+    'children': [
+        {
+            'id': 'group234',
+            'value': {
+                'label': 'group234',
+                'labelStyle': 'fill:#000;',
+                'style': 'fill:CornflowerBlue',
+                'rx': 5,
+                'ry': 5,
+                'clusterLabelPos': 'top',
+            },
+            'tooltip': '',
+            'children': [
+                {
+                    'id': 'group234.group34',
+                    'value': {
+                        'label': 'group34',
+                        'labelStyle': 'fill:#000;',
+                        'style': 'fill:CornflowerBlue',
+                        'rx': 5,
+                        'ry': 5,
+                        'clusterLabelPos': 'top',
+                    },
+                    'tooltip': '',
+                    'children': [
+                        {
+                            'id': 'group234.group34.task3',
+                            'value': {
+                                'label': 'task3',
+                                'labelStyle': 'fill:#000;',
+                                'style': 'fill:#e8f7e4;',
+                                'rx': 5,
+                                'ry': 5,
+                            },
+                        },
+                        {
+                            'id': 'group234.group34.task4',
+                            'value': {
+                                'label': 'task4',
+                                'labelStyle': 'fill:#000;',
+                                'style': 'fill:#e8f7e4;',
+                                'rx': 5,
+                                'ry': 5,
+                            },
+                        },
+                        {
+                            'id': 'group234.group34.downstream_join_id',
+                            'value': {
+                                'label': '',
+                                'labelStyle': 'fill:#000;',
+                                'style': 'fill:CornflowerBlue;',
+                                'shape': 'circle',
+                            },
+                        },
+                    ],
+                },
+                {
+                    'id': 'group234.task2',
+                    'value': {
+                        'label': 'task2',
+                        'labelStyle': 'fill:#000;',
+                        'style': 'fill:#e8f7e4;',
+                        'rx': 5,
+                        'ry': 5,
+                    },
+                },
+                {
+                    'id': 'group234.upstream_join_id',
+                    'value': {
+                        'label': '',
+                        'labelStyle': 'fill:#000;',
+                        'style': 'fill:CornflowerBlue;',
+                        'shape': 'circle',
+                    },
+                },
+            ],
+        },
+        {
+            'id': 'task1',
+            'value': {
+                'label': 'task1',
+                'labelStyle': 'fill:#000;',
+                'style': 'fill:#e8f7e4;',
+                'rx': 5,
+                'ry': 5,
+            },
+        },
+        {
+            'id': 'task5',
+            'value': {
+                'label': 'task5',
+                'labelStyle': 'fill:#000;',
+                'style': 'fill:#e8f7e4;',
+                'rx': 5,
+                'ry': 5,
+            },
+        },
+    ],
+}
+
+
+def test_build_task_group_context_manager():
+    execution_date = pendulum.parse("20200101")
+    with DAG("test_build_task_group_context_manager", 
start_date=execution_date) as dag:
+        task1 = DummyOperator(task_id="task1")
+        with TaskGroup("group234") as group234:
+            _ = DummyOperator(task_id="task2")
+
+            with TaskGroup("group34") as group34:
+                _ = DummyOperator(task_id="task3")
+                _ = DummyOperator(task_id="task4")
+
+        task5 = DummyOperator(task_id="task5")
+        task1 >> group234
+        group34 >> task5
+
+    assert task1.get_direct_relative_ids(upstream=False) == {
+        'group234.group34.task4',
+        'group234.group34.task3',
+        'group234.task2',
+    }
+    assert task5.get_direct_relative_ids(upstream=True) == {
+        'group234.group34.task4',
+        'group234.group34.task3',
+    }
+
+    assert dag.task_group.group_id is None
+    assert dag.task_group.is_root
+    assert set(dag.task_group.children.keys()) == {"task1", "group234", 
"task5"}
+    assert group34.group_id == "group234.group34"
+
+    assert task_group_to_dict(dag.task_group) == EXPECTED_JSON
+
+
+def test_build_task_group():
+    """
+    This is an alternative syntax to use TaskGroup. It should result in the 
same TaskGroup
+    as using the context manager.
+    """
+    execution_date = pendulum.parse("20200101")
+    dag = DAG("test_build_task_group", start_date=execution_date)
+    task1 = DummyOperator(task_id="task1", dag=dag)
+    group234 = TaskGroup("group234", dag=dag)
+    _ = DummyOperator(task_id="task2", dag=dag, task_group=group234)
+    group34 = TaskGroup("group34", dag=dag, parent_group=group234)
+    _ = DummyOperator(task_id="task3", dag=dag, task_group=group34)
+    _ = DummyOperator(task_id="task4", dag=dag, task_group=group34)
+    task5 = DummyOperator(task_id="task5", dag=dag)
+
+    task1 >> group234
+    group34 >> task5
+
+    assert task_group_to_dict(dag.task_group) == EXPECTED_JSON
+
+
+def extract_node_id(node, include_label=False):
+    ret = {"id": node["id"]}
+    if include_label:
+        ret["label"] = node["value"]["label"]
+    if "children" in node:
+        children = []
+        for child in node["children"]:
+            children.append(extract_node_id(child, 
include_label=include_label))
+
+        ret["children"] = children
+
+    return ret
+
+
+def test_build_task_group_with_prefix():
+    """
+    Tests that prefix_group_id turns on/off prefixing of task_id with group_id.
+    """
+    execution_date = pendulum.parse("20200101")
+    with DAG("test_build_task_group_with_prefix", start_date=execution_date) 
as dag:
+        task1 = DummyOperator(task_id="task1")
+        with TaskGroup("group234", prefix_group_id=False) as group234:
+            task2 = DummyOperator(task_id="task2")
+
+            with TaskGroup("group34") as group34:
+                task3 = DummyOperator(task_id="task3")
+
+                with TaskGroup("group4", prefix_group_id=False) as group4:
+                    task4 = DummyOperator(task_id="task4")
+
+        task5 = DummyOperator(task_id="task5")
+        task1 >> group234
+        group34 >> task5
+
+    assert task2.task_id == "task2"
+    assert group34.group_id == "group34"
+    assert task3.task_id == "group34.task3"
+    assert group4.group_id == "group34.group4"
+    assert task4.task_id == "task4"
+    assert task5.task_id == "task5"
+    assert group234.get_child_by_label("task2") == task2
+    assert group234.get_child_by_label("group34") == group34
+    assert group4.get_child_by_label("task4") == task4
+
+    assert extract_node_id(task_group_to_dict(dag.task_group), 
include_label=True) == {
+        'id': None,
+        'label': None,
+        'children': [
+            {
+                'id': 'group234',
+                'label': 'group234',
+                'children': [
+                    {
+                        'id': 'group34',
+                        'label': 'group34',
+                        'children': [
+                            {
+                                'id': 'group34.group4',
+                                'label': 'group4',
+                                'children': [{'id': 'task4', 'label': 
'task4'}],
+                            },
+                            {'id': 'group34.task3', 'label': 'task3'},
+                            {'id': 'group34.downstream_join_id', 'label': ''},
+                        ],
+                    },
+                    {'id': 'task2', 'label': 'task2'},
+                    {'id': 'group234.upstream_join_id', 'label': ''},
+                ],
+            },
+            {'id': 'task1', 'label': 'task1'},
+            {'id': 'task5', 'label': 'task5'},
+        ],
+    }
+
+
+def test_build_task_group_with_task_decorator():
+    """
+    Test that TaskGroup can be used with the @task decorator.
+    """
+    from airflow.operators.python import task
+
+    @task
+    def task_1():
+        print("task_1")
+
+    @task
+    def task_2():
+        return "task_2"
+
+    @task
+    def task_3():
+        return "task_3"
+
+    @task
+    def task_4(task_2_output, task_3_output):
+        print(task_2_output, task_3_output)
+
+    @task
+    def task_5():
+        print("task_5")
+
+    execution_date = pendulum.parse("20200101")
+    with DAG("test_build_task_group_with_task_decorator", 
start_date=execution_date) as dag:
+        tsk_1 = task_1()
+
+        with TaskGroup("group234") as group234:
+            tsk_2 = task_2()
+            tsk_3 = task_3()
+            tsk_4 = task_4(tsk_2, tsk_3)
+
+        tsk_5 = task_5()
+
+        tsk_1 >> group234 >> tsk_5
+
+    # pylint: disable=no-member
+    assert tsk_1.operator in tsk_2.operator.upstream_list
+    assert tsk_1.operator in tsk_3.operator.upstream_list
+    assert tsk_5.operator in tsk_4.operator.downstream_list
+    # pylint: enable=no-member
+
+    assert extract_node_id(task_group_to_dict(dag.task_group)) == {
+        'id': None,
+        'children': [
+            {
+                'id': 'group234',
+                'children': [
+                    {'id': 'group234.task_2'},
+                    {'id': 'group234.task_3'},
+                    {'id': 'group234.task_4'},
+                    {'id': 'group234.upstream_join_id'},
+                    {'id': 'group234.downstream_join_id'},
+                ],
+            },
+            {'id': 'task_1'},
+            {'id': 'task_5'},
+        ],
+    }
+
+    edges = dag_edges(dag)
+    assert sorted((e["source_id"], e["target_id"]) for e in edges) == [
+        ('group234.downstream_join_id', 'task_5'),
+        ('group234.task_2', 'group234.task_4'),
+        ('group234.task_3', 'group234.task_4'),
+        ('group234.task_4', 'group234.downstream_join_id'),
+        ('group234.upstream_join_id', 'group234.task_2'),
+        ('group234.upstream_join_id', 'group234.task_3'),
+        ('task_1', 'group234.upstream_join_id'),
+    ]
+
+
+def test_sub_dag_task_group():
+    """
+    Tests dag.sub_dag() updates task_group correctly.
+    """
+    execution_date = pendulum.parse("20200101")
+    with DAG("test_test_task_group_sub_dag", start_date=execution_date) as dag:
+        task1 = DummyOperator(task_id="task1")
+        with TaskGroup("group234") as group234:
+            _ = DummyOperator(task_id="task2")
+
+            with TaskGroup("group34") as group34:
+                _ = DummyOperator(task_id="task3")
+                _ = DummyOperator(task_id="task4")
+
+        with TaskGroup("group6") as group6:
+            _ = DummyOperator(task_id="task6")
+
+        task7 = DummyOperator(task_id="task7")
+        task5 = DummyOperator(task_id="task5")
+
+        task1 >> group234
+        group34 >> task5
+        group234 >> group6
+        group234 >> task7
+
+    subdag = dag.sub_dag(task_regex="task5", include_upstream=True, 
include_downstream=False)
+
+    assert extract_node_id(task_group_to_dict(subdag.task_group)) == {
+        'id': None,
+        'children': [
+            {
+                'id': 'group234',
+                'children': [
+                    {
+                        'id': 'group234.group34',
+                        'children': [
+                            {'id': 'group234.group34.task3'},
+                            {'id': 'group234.group34.task4'},
+                            {'id': 'group234.group34.downstream_join_id'},
+                        ],
+                    },
+                    {'id': 'group234.upstream_join_id'},
+                ],
+            },
+            {'id': 'task1'},
+            {'id': 'task5'},
+        ],
+    }
+
+    edges = dag_edges(subdag)
+    assert sorted((e["source_id"], e["target_id"]) for e in edges) == [
+        ('group234.group34.downstream_join_id', 'task5'),
+        ('group234.group34.task3', 'group234.group34.downstream_join_id'),
+        ('group234.group34.task4', 'group234.group34.downstream_join_id'),
+        ('group234.upstream_join_id', 'group234.group34.task3'),
+        ('group234.upstream_join_id', 'group234.group34.task4'),
+        ('task1', 'group234.upstream_join_id'),
+    ]
+
+    subdag_task_groups = subdag.task_group.get_task_group_dict()
+    assert subdag_task_groups.keys() == {None, "group234", "group234.group34"}
+
+    included_group_ids = {"group234", "group234.group34"}
+    included_task_ids = {'group234.group34.task3', 'group234.group34.task4', 
'task1', 'task5'}
+
+    for task_group in subdag_task_groups.values():
+        assert task_group.upstream_group_ids.issubset(included_group_ids)
+        assert task_group.downstream_group_ids.issubset(included_group_ids)
+        assert task_group.upstream_task_ids.issubset(included_task_ids)
+        assert task_group.downstream_task_ids.issubset(included_task_ids)
+
+    for task in subdag.task_group:
+        assert task.upstream_task_ids.issubset(included_task_ids)
+        assert task.downstream_task_ids.issubset(included_task_ids)
+
+
+def test_dag_edges():
+    execution_date = pendulum.parse("20200101")
+    with DAG("test_dag_edges", start_date=execution_date) as dag:
+        task1 = DummyOperator(task_id="task1")
+        with TaskGroup("group_a") as group_a:
+            with TaskGroup("group_b") as group_b:
+                task2 = DummyOperator(task_id="task2")
+                task3 = DummyOperator(task_id="task3")
+                task4 = DummyOperator(task_id="task4")
+                task2 >> [task3, task4]
+
+            task5 = DummyOperator(task_id="task5")
+
+            task5 << group_b
+
+        task1 >> group_a
+
+        with TaskGroup("group_c") as group_c:
+            task6 = DummyOperator(task_id="task6")
+            task7 = DummyOperator(task_id="task7")
+            task8 = DummyOperator(task_id="task8")
+            [task6, task7] >> task8
+            group_a >> group_c
+
+        task5 >> task8
+
+        task9 = DummyOperator(task_id="task9")
+        task10 = DummyOperator(task_id="task10")
+
+        group_c >> [task9, task10]
+
+        with TaskGroup("group_d") as group_d:
+            task11 = DummyOperator(task_id="task11")
+            task12 = DummyOperator(task_id="task12")
+            task11 >> task12
+
+        group_d << group_c
+
+    nodes = task_group_to_dict(dag.task_group)
+    edges = dag_edges(dag)
+
+    assert extract_node_id(nodes) == {
+        'id': None,
+        'children': [
+            {
+                'id': 'group_a',
+                'children': [
+                    {
+                        'id': 'group_a.group_b',
+                        'children': [
+                            {'id': 'group_a.group_b.task2'},
+                            {'id': 'group_a.group_b.task3'},
+                            {'id': 'group_a.group_b.task4'},
+                            {'id': 'group_a.group_b.downstream_join_id'},
+                        ],
+                    },
+                    {'id': 'group_a.task5'},
+                    {'id': 'group_a.upstream_join_id'},
+                    {'id': 'group_a.downstream_join_id'},
+                ],
+            },
+            {
+                'id': 'group_c',
+                'children': [
+                    {'id': 'group_c.task6'},
+                    {'id': 'group_c.task7'},
+                    {'id': 'group_c.task8'},
+                    {'id': 'group_c.upstream_join_id'},
+                    {'id': 'group_c.downstream_join_id'},
+                ],
+            },
+            {
+                'id': 'group_d',
+                'children': [
+                    {'id': 'group_d.task11'},
+                    {'id': 'group_d.task12'},
+                    {'id': 'group_d.upstream_join_id'},
+                ],
+            },
+            {'id': 'task1'},
+            {'id': 'task10'},
+            {'id': 'task9'},
+        ],
+    }
+
+    assert sorted((e["source_id"], e["target_id"]) for e in edges) == [
+        ('group_a.downstream_join_id', 'group_c.upstream_join_id'),
+        ('group_a.group_b.downstream_join_id', 'group_a.task5'),
+        ('group_a.group_b.task2', 'group_a.group_b.task3'),
+        ('group_a.group_b.task2', 'group_a.group_b.task4'),
+        ('group_a.group_b.task3', 'group_a.group_b.downstream_join_id'),
+        ('group_a.group_b.task4', 'group_a.group_b.downstream_join_id'),
+        ('group_a.task5', 'group_a.downstream_join_id'),
+        ('group_a.task5', 'group_c.task8'),
+        ('group_a.upstream_join_id', 'group_a.group_b.task2'),
+        ('group_c.downstream_join_id', 'group_d.upstream_join_id'),
+        ('group_c.downstream_join_id', 'task10'),
+        ('group_c.downstream_join_id', 'task9'),
+        ('group_c.task6', 'group_c.task8'),
+        ('group_c.task7', 'group_c.task8'),
+        ('group_c.task8', 'group_c.downstream_join_id'),
+        ('group_c.upstream_join_id', 'group_c.task6'),
+        ('group_c.upstream_join_id', 'group_c.task7'),
+        ('group_d.task11', 'group_d.task12'),
+        ('group_d.upstream_join_id', 'group_d.task11'),
+        ('task1', 'group_a.upstream_join_id'),
+    ]
+
+
+def test_duplicate_group_id():
+    from airflow.exceptions import DuplicateTaskIdFound
+
+    execution_date = pendulum.parse("20200101")
+
+    with pytest.raises(DuplicateTaskIdFound, match=r".* 'task1' .*"):
+        with DAG("test_duplicate_group_id", start_date=execution_date):
+            _ = DummyOperator(task_id="task1")
+            with TaskGroup("task1"):
+                pass
+
+    with pytest.raises(DuplicateTaskIdFound, match=r".* 'group1' .*"):
+        with DAG("test_duplicate_group_id", start_date=execution_date):
+            _ = DummyOperator(task_id="task1")
+            with TaskGroup("group1", prefix_group_id=False):
+                with TaskGroup("group1"):
+                    pass
+
+    with pytest.raises(DuplicateTaskIdFound, match=r".* 'group1' .*"):
+        with DAG("test_duplicate_group_id", start_date=execution_date):
+            with TaskGroup("group1", prefix_group_id=False):
+                _ = DummyOperator(task_id="group1")
+
+    with pytest.raises(DuplicateTaskIdFound, match=r".* 'group1.downstream_join_id' .*"):
+        with DAG("test_duplicate_group_id", start_date=execution_date):
+            _ = DummyOperator(task_id="task1")
+            with TaskGroup("group1"):
+                _ = DummyOperator(task_id="downstream_join_id")
+
+    with pytest.raises(DuplicateTaskIdFound, match=r".* 'group1.upstream_join_id' .*"):
+        with DAG("test_duplicate_group_id", start_date=execution_date):
+            _ = DummyOperator(task_id="task1")
+            with TaskGroup("group1"):
+                _ = DummyOperator(task_id="upstream_join_id")

Reply via email to