This is an automated email from the ASF dual-hosted git repository.
wkcn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new b077965 Add dtype visualization to plot_network (#14066)
b077965 is described below
commit b077965b7e41108cafe0123659c4802cb771442b
Author: Przemyslaw Tredak <[email protected]>
AuthorDate: Thu Mar 14 18:14:44 2019 -0700
Add dtype visualization to plot_network (#14066)
* Add dtype to plot_network
* Added docstring for the new param
* Added dtype to the plot_network test
* Changes from review
* Fixes from review
* Fix typo
* Retrigger CI
---
python/mxnet/visualization.py | 38 ++++++++++++++++++++++++++++++--------
tests/python/unittest/test_viz.py | 2 ++
2 files changed, 32 insertions(+), 8 deletions(-)
diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py
index 1ebdcb5..dd3a1df 100644
--- a/python/mxnet/visualization.py
+++ b/python/mxnet/visualization.py
@@ -208,7 +208,7 @@ def print_summary(symbol, shape=None, line_length=120, positions=[.44, .64, .74,
print('Total params: %s' % total_params)
print('_' * line_length)
-def plot_network(symbol, title="plot", save_format='pdf', shape=None,
node_attrs={},
+def plot_network(symbol, title="plot", save_format='pdf', shape=None,
dtype=None, node_attrs={},
hide_weights=True):
"""Creates a visualization (Graphviz digraph object) of the given
computation graph.
Graphviz must be installed for this function to work.
@@ -224,6 +224,10 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs
Specifies the shape of the input tensors. If specified, the
visualization will include
the shape of the tensors between the nodes. `shape` is a dictionary
mapping
input symbol names (str) to the corresponding tensor shape (tuple).
+ dtype: dict, optional
+ Specifies the type of the input tensors. If specified, the
visualization will include
+ the type of the tensors between the nodes. `dtype` is a dictionary
mapping
+ input symbol names (str) to the corresponding tensor type (e.g.
`numpy.float32`).
node_attrs: dict, optional
Specifies the attributes for nodes in the generated visualization.
`node_attrs` is
a dictionary of Graphviz attribute names and values. For example::
@@ -271,14 +275,19 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs
raise ImportError("Draw network requires graphviz library")
if not isinstance(symbol, Symbol):
raise TypeError("symbol must be a Symbol")
- draw_shape = False
- if shape is not None:
- draw_shape = True
- interals = symbol.get_internals()
- _, out_shapes, _ = interals.infer_shape(**shape)
+ internals = symbol.get_internals()
+ draw_shape = shape is not None
+ if draw_shape:
+ _, out_shapes, _ = internals.infer_shape(**shape)
if out_shapes is None:
raise ValueError("Input shape is incomplete")
- shape_dict = dict(zip(interals.list_outputs(), out_shapes))
+ shape_dict = dict(zip(internals.list_outputs(), out_shapes))
+ draw_type = dtype is not None
+ if draw_type:
+ _, out_types, _ = internals.infer_type(**dtype)
+ if out_types is None:
+ raise ValueError("Input type is incomplete")
+ type_dict = dict(zip(internals.list_outputs(), out_types))
conf = json.loads(symbol.tojson())
nodes = conf["nodes"]
# check if multiple nodes have the same name
@@ -370,7 +379,7 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs
input_node = nodes[item[0]]
input_name = input_node["name"]
if input_name not in hidden_nodes:
- attr = {"dir": "back", 'arrowtail':'open'}
+ attr = {"dir": "back", 'arrowtail':'open', 'label': ''}
# add shapes
if draw_shape:
if input_node["op"] != "null":
@@ -387,6 +396,19 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs
shape = shape_dict[key][1:]
label = "x".join([str(x) for x in shape])
attr["label"] = label
+ if draw_type:
+ if input_node["op"] != "null":
+ key = input_name + "_output"
+ if "attrs" in input_node:
+ params = input_node["attrs"]
+ if "num_outputs" in params:
+ key += str(int(params["num_outputs"]) - 1)
+ dtype = type_dict[key]
+ attr["label"] += '(' + dtype.__name__ + ')'
+ else:
+ key = input_name
+ dtype = type_dict[key]
+ attr["label"] += '(' + dtype.__name__ + ')'
dot.edge(tail_name=name, head_name=input_name, **attr)
return dot
diff --git a/tests/python/unittest/test_viz.py b/tests/python/unittest/test_viz.py
index fe564b0..1321099 100644
--- a/tests/python/unittest/test_viz.py
+++ b/tests/python/unittest/test_viz.py
@@ -19,6 +19,7 @@ import unittest
import warnings
import mxnet as mx
+import numpy as np
def test_print_summary():
@@ -55,6 +56,7 @@ def test_plot_network():
net = mx.sym.SoftmaxOutput(data=net, name='out')
with warnings.catch_warnings(record=True) as w:
digraph = mx.viz.plot_network(net, shape={'data': (100, 200)},
+ dtype={'data': np.float32},
node_attrs={"fixedsize": "false"})
assert len(w) == 1
assert "There are multiple variables with the same name in your graph" in
str(w[-1].message)