This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 666df953b8 Move dag_edges and task_group_to_dict to corresponding util
modules (#26212)
666df953b8 is described below
commit 666df953b8b8cd1d68d16c5666206247663a09d4
Author: Jarek Potiuk <[email protected]>
AuthorDate: Wed Sep 7 22:15:40 2022 +0200
Move dag_edges and task_group_to_dict to corresponding util modules (#26212)
The methods were implemented in "view" but they were used in other
places (in dot_renderer) so they conceptually belong to common
util code. Having those in a wrong package (airflow/www) caused
the tests to pass because selective checks did not realise that
change in "airflow/www" also requires running Core tests.
This PR moves the methods to "airflow/utils". The methods
are imported in the "views" module so even if someone used them
from there, they will still be available there, so the change is
fully backwards compatible (even if those are not "public"
airflow API methods).
Follow up after #26188
---
airflow/utils/dag_edges.py | 127 ++++++++++++++++++++++++++++++
airflow/utils/dot_renderer.py | 2 +-
airflow/utils/task_group.py | 64 ++++++++++++++++
airflow/www/views.py | 170 +----------------------------------------
tests/utils/test_task_group.py | 4 +-
5 files changed, 196 insertions(+), 171 deletions(-)
diff --git a/airflow/utils/dag_edges.py b/airflow/utils/dag_edges.py
new file mode 100644
index 0000000000..570960b53e
--- /dev/null
+++ b/airflow/utils/dag_edges.py
@@ -0,0 +1,127 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import List
+
+from airflow.models import Operator
+from airflow.models.abstractoperator import AbstractOperator
+from airflow.models.dag import DAG
+
+
+def dag_edges(dag: DAG):
+ """
+ Create the list of edges needed to construct the Graph view.
+
+ A special case is made if a TaskGroup is immediately upstream/downstream
of another
+ TaskGroup or task. Two dummy nodes named upstream_join_id and
downstream_join_id are
+ created for the TaskGroup. Instead of drawing an edge onto every task in
the TaskGroup,
+ all edges are directed onto the dummy nodes. This is to cut down the
number of edges on
+ the graph.
+
+ For example: A DAG with TaskGroups group1 and group2:
+ group1: task1, task2, task3
+ group2: task4, task5, task6
+
+ group2 is downstream of group1:
+ group1 >> group2
+
+ Edges to add (This avoids having to create edges between every task in
group1 and group2):
+ task1 >> downstream_join_id
+ task2 >> downstream_join_id
+ task3 >> downstream_join_id
+ downstream_join_id >> upstream_join_id
+ upstream_join_id >> task4
+ upstream_join_id >> task5
+ upstream_join_id >> task6
+ """
+ # Edges to add between TaskGroup
+ edges_to_add = set()
+ # Edges to remove between individual tasks that are replaced by
edges_to_add.
+ edges_to_skip = set()
+
+ task_group_map = dag.task_group.get_task_group_dict()
+
+ def collect_edges(task_group):
+ """Update edges_to_add and edges_to_skip according to TaskGroups."""
+ if isinstance(task_group, AbstractOperator):
+ return
+
+ for target_id in task_group.downstream_group_ids:
+ # For every TaskGroup immediately downstream, add edges between
downstream_join_id
+ # and upstream_join_id. Skip edges between individual tasks of the
TaskGroups.
+ target_group = task_group_map[target_id]
+ edges_to_add.add((task_group.downstream_join_id,
target_group.upstream_join_id))
+
+ for child in task_group.get_leaves():
+ edges_to_add.add((child.task_id,
task_group.downstream_join_id))
+ for target in target_group.get_roots():
+ edges_to_skip.add((child.task_id, target.task_id))
+ edges_to_skip.add((child.task_id,
target_group.upstream_join_id))
+
+ for child in target_group.get_roots():
+ edges_to_add.add((target_group.upstream_join_id,
child.task_id))
+ edges_to_skip.add((task_group.downstream_join_id,
child.task_id))
+
+ # For every individual task immediately downstream, add edges between
downstream_join_id and
+ # the downstream task. Skip edges between individual tasks of the
TaskGroup and the
+ # downstream task.
+ for target_id in task_group.downstream_task_ids:
+ edges_to_add.add((task_group.downstream_join_id, target_id))
+
+ for child in task_group.get_leaves():
+ edges_to_add.add((child.task_id,
task_group.downstream_join_id))
+ edges_to_skip.add((child.task_id, target_id))
+
+ # For every individual task immediately upstream, add edges between
the upstream task
+ # and upstream_join_id. Skip edges between the upstream task and
individual tasks
+ # of the TaskGroup.
+ for source_id in task_group.upstream_task_ids:
+ edges_to_add.add((source_id, task_group.upstream_join_id))
+ for child in task_group.get_roots():
+ edges_to_add.add((task_group.upstream_join_id, child.task_id))
+ edges_to_skip.add((source_id, child.task_id))
+
+ for child in task_group.children.values():
+ collect_edges(child)
+
+ collect_edges(dag.task_group)
+
+ # Collect all the edges between individual tasks
+ edges = set()
+
+ tasks_to_trace: List[Operator] = dag.roots
+ while tasks_to_trace:
+ tasks_to_trace_next: List[Operator] = []
+ for task in tasks_to_trace:
+ for child in task.downstream_list:
+ edge = (task.task_id, child.task_id)
+ if edge in edges:
+ continue
+ edges.add(edge)
+ tasks_to_trace_next.append(child)
+ tasks_to_trace = tasks_to_trace_next
+
+ result = []
+ # Build result dicts with the two ends of the edge, plus any extra metadata
+ # if we have it.
+ for source_id, target_id in sorted(edges.union(edges_to_add) -
edges_to_skip):
+ record = {"source_id": source_id, "target_id": target_id}
+ label = dag.get_edge_info(source_id, target_id).get("label")
+ if label:
+ record["label"] = label
+ result.append(record)
+ return result
diff --git a/airflow/utils/dot_renderer.py b/airflow/utils/dot_renderer.py
index f36075eac4..08b66a8961 100644
--- a/airflow/utils/dot_renderer.py
+++ b/airflow/utils/dot_renderer.py
@@ -28,9 +28,9 @@ from airflow.models.dag import DAG
from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskmixin import DependencyMixin
from airflow.serialization.serialized_objects import DagDependency
+from airflow.utils.dag_edges import dag_edges
from airflow.utils.state import State
from airflow.utils.task_group import TaskGroup
-from airflow.www.views import dag_edges
def _refine_color(color: str):
diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py
index 64c11f79db..57881db2ce 100644
--- a/airflow/utils/task_group.py
+++ b/airflow/utils/task_group.py
@@ -30,6 +30,7 @@ from airflow.exceptions import (
DuplicateTaskIdFound,
TaskAlreadyInTaskGroup,
)
+from airflow.models.abstractoperator import AbstractOperator
from airflow.models.taskmixin import DAGNode, DependencyMixin
from airflow.serialization.enums import DagAttributeTypes
from airflow.utils.helpers import validate_group_key
@@ -487,3 +488,66 @@ class TaskGroupContext:
return dag.task_group
return cls._context_managed_task_group
+
+
+def task_group_to_dict(task_item_or_group):
+ """
+ Create a nested dict representation of this TaskGroup and its children
used to construct
+ the Graph.
+ """
+ if isinstance(task_item_or_group, AbstractOperator):
+ return {
+ 'id': task_item_or_group.task_id,
+ 'value': {
+ 'label': task_item_or_group.label,
+ 'labelStyle': f"fill:{task_item_or_group.ui_fgcolor};",
+ 'style': f"fill:{task_item_or_group.ui_color};",
+ 'rx': 5,
+ 'ry': 5,
+ },
+ }
+ task_group = task_item_or_group
+ children = [
+ task_group_to_dict(child) for child in
sorted(task_group.children.values(), key=lambda t: t.label)
+ ]
+
+ if task_group.upstream_group_ids or task_group.upstream_task_ids:
+ children.append(
+ {
+ 'id': task_group.upstream_join_id,
+ 'value': {
+ 'label': '',
+ 'labelStyle': f"fill:{task_group.ui_fgcolor};",
+ 'style': f"fill:{task_group.ui_color};",
+ 'shape': 'circle',
+ },
+ }
+ )
+
+ if task_group.downstream_group_ids or task_group.downstream_task_ids:
+ # This is the join node used to reduce the number of edges between two
TaskGroup.
+ children.append(
+ {
+ 'id': task_group.downstream_join_id,
+ 'value': {
+ 'label': '',
+ 'labelStyle': f"fill:{task_group.ui_fgcolor};",
+ 'style': f"fill:{task_group.ui_color};",
+ 'shape': 'circle',
+ },
+ }
+ )
+
+ return {
+ "id": task_group.group_id,
+ 'value': {
+ 'label': task_group.label,
+ 'labelStyle': f"fill:{task_group.ui_fgcolor};",
+ 'style': f"fill:{task_group.ui_color}",
+ 'rx': 5,
+ 'ry': 5,
+ 'clusterLabelPos': 'top',
+ 'tooltip': task_group.tooltip,
+ },
+ 'children': children,
+ }
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 46b815abcc..d004f7bb71 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -108,6 +108,7 @@ from airflow.timetables.base import DataInterval,
TimeRestriction
from airflow.timetables.interval import CronDataIntervalTimetable
from airflow.utils import json as utils_json, timezone, yaml
from airflow.utils.airflow_flask_app import get_airflow_app
+from airflow.utils.dag_edges import dag_edges
from airflow.utils.dates import infer_time_unit, scale_time_units
from airflow.utils.docs import get_doc_url_for_provider, get_docs_url
from airflow.utils.helpers import alchemy_to_dict
@@ -117,6 +118,7 @@ from airflow.utils.net import get_hostname
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.strings import to_boolean
+from airflow.utils.task_group import task_group_to_dict
from airflow.utils.timezone import td_format, utcnow
from airflow.version import version
from airflow.www import auth, utils as wwwutils
@@ -404,69 +406,6 @@ def dag_to_grid(dag, dag_runs, session):
return task_group_to_grid(dag.task_group, dag_runs, grouped_tis)
-def task_group_to_dict(task_item_or_group):
- """
- Create a nested dict representation of this TaskGroup and its children
used to construct
- the Graph.
- """
- if isinstance(task_item_or_group, AbstractOperator):
- return {
- 'id': task_item_or_group.task_id,
- 'value': {
- 'label': task_item_or_group.label,
- 'labelStyle': f"fill:{task_item_or_group.ui_fgcolor};",
- 'style': f"fill:{task_item_or_group.ui_color};",
- 'rx': 5,
- 'ry': 5,
- },
- }
- task_group = task_item_or_group
- children = [
- task_group_to_dict(child) for child in
sorted(task_group.children.values(), key=lambda t: t.label)
- ]
-
- if task_group.upstream_group_ids or task_group.upstream_task_ids:
- children.append(
- {
- 'id': task_group.upstream_join_id,
- 'value': {
- 'label': '',
- 'labelStyle': f"fill:{task_group.ui_fgcolor};",
- 'style': f"fill:{task_group.ui_color};",
- 'shape': 'circle',
- },
- }
- )
-
- if task_group.downstream_group_ids or task_group.downstream_task_ids:
- # This is the join node used to reduce the number of edges between two
TaskGroup.
- children.append(
- {
- 'id': task_group.downstream_join_id,
- 'value': {
- 'label': '',
- 'labelStyle': f"fill:{task_group.ui_fgcolor};",
- 'style': f"fill:{task_group.ui_color};",
- 'shape': 'circle',
- },
- }
- )
-
- return {
- "id": task_group.group_id,
- 'value': {
- 'label': task_group.label,
- 'labelStyle': f"fill:{task_group.ui_fgcolor};",
- 'style': f"fill:{task_group.ui_color}",
- 'rx': 5,
- 'ry': 5,
- 'clusterLabelPos': 'top',
- 'tooltip': task_group.tooltip,
- },
- 'children': children,
- }
-
-
def get_key_paths(input_dict):
"""Return a list of dot-separated dictionary paths"""
for key, value in input_dict.items():
@@ -490,111 +429,6 @@ def get_value_from_path(key_path, content):
return elem
-def dag_edges(dag: DAG):
- """
- Create the list of edges needed to construct the Graph view.
-
- A special case is made if a TaskGroup is immediately upstream/downstream
of another
- TaskGroup or task. Two dummy nodes named upstream_join_id and
downstream_join_id are
- created for the TaskGroup. Instead of drawing an edge onto every task in
the TaskGroup,
- all edges are directed onto the dummy nodes. This is to cut down the
number of edges on
- the graph.
-
- For example: A DAG with TaskGroups group1 and group2:
- group1: task1, task2, task3
- group2: task4, task5, task6
-
- group2 is downstream of group1:
- group1 >> group2
-
- Edges to add (This avoids having to create edges between every task in
group1 and group2):
- task1 >> downstream_join_id
- task2 >> downstream_join_id
- task3 >> downstream_join_id
- downstream_join_id >> upstream_join_id
- upstream_join_id >> task4
- upstream_join_id >> task5
- upstream_join_id >> task6
- """
- # Edges to add between TaskGroup
- edges_to_add = set()
- # Edges to remove between individual tasks that are replaced by
edges_to_add.
- edges_to_skip = set()
-
- task_group_map = dag.task_group.get_task_group_dict()
-
- def collect_edges(task_group):
- """Update edges_to_add and edges_to_skip according to TaskGroups."""
- if isinstance(task_group, AbstractOperator):
- return
-
- for target_id in task_group.downstream_group_ids:
- # For every TaskGroup immediately downstream, add edges between
downstream_join_id
- # and upstream_join_id. Skip edges between individual tasks of the
TaskGroups.
- target_group = task_group_map[target_id]
- edges_to_add.add((task_group.downstream_join_id,
target_group.upstream_join_id))
-
- for child in task_group.get_leaves():
- edges_to_add.add((child.task_id,
task_group.downstream_join_id))
- for target in target_group.get_roots():
- edges_to_skip.add((child.task_id, target.task_id))
- edges_to_skip.add((child.task_id,
target_group.upstream_join_id))
-
- for child in target_group.get_roots():
- edges_to_add.add((target_group.upstream_join_id,
child.task_id))
- edges_to_skip.add((task_group.downstream_join_id,
child.task_id))
-
- # For every individual task immediately downstream, add edges between
downstream_join_id and
- # the downstream task. Skip edges between individual tasks of the
TaskGroup and the
- # downstream task.
- for target_id in task_group.downstream_task_ids:
- edges_to_add.add((task_group.downstream_join_id, target_id))
-
- for child in task_group.get_leaves():
- edges_to_add.add((child.task_id,
task_group.downstream_join_id))
- edges_to_skip.add((child.task_id, target_id))
-
- # For every individual task immediately upstream, add edges between
the upstream task
- # and upstream_join_id. Skip edges between the upstream task and
individual tasks
- # of the TaskGroup.
- for source_id in task_group.upstream_task_ids:
- edges_to_add.add((source_id, task_group.upstream_join_id))
- for child in task_group.get_roots():
- edges_to_add.add((task_group.upstream_join_id, child.task_id))
- edges_to_skip.add((source_id, child.task_id))
-
- for child in task_group.children.values():
- collect_edges(child)
-
- collect_edges(dag.task_group)
-
- # Collect all the edges between individual tasks
- edges = set()
-
- tasks_to_trace: List[Operator] = dag.roots
- while tasks_to_trace:
- tasks_to_trace_next: List[Operator] = []
- for task in tasks_to_trace:
- for child in task.downstream_list:
- edge = (task.task_id, child.task_id)
- if edge in edges:
- continue
- edges.add(edge)
- tasks_to_trace_next.append(child)
- tasks_to_trace = tasks_to_trace_next
-
- result = []
- # Build result dicts with the two ends of the edge, plus any extra metadata
- # if we have it.
- for source_id, target_id in sorted(edges.union(edges_to_add) -
edges_to_skip):
- record = {"source_id": source_id, "target_id": target_id}
- label = dag.get_edge_info(source_id, target_id).get("label")
- if label:
- record["label"] = label
- result.append(record)
- return result
-
-
def get_task_stats_from_query(qry):
"""
Return a dict of the task quantity, grouped by dag id and task status.
diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py
index 3216ca0d70..13b8a9fff7 100644
--- a/tests/utils/test_task_group.py
+++ b/tests/utils/test_task_group.py
@@ -26,8 +26,8 @@ from airflow.models.xcom_arg import XComArg
from airflow.operators.bash import BashOperator
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import PythonOperator
-from airflow.utils.task_group import TaskGroup
-from airflow.www.views import dag_edges, task_group_to_dict
+from airflow.utils.dag_edges import dag_edges
+from airflow.utils.task_group import TaskGroup, task_group_to_dict
from tests.models import DEFAULT_DATE
EXPECTED_JSON = {