This is an automated email from the ASF dual-hosted git repository.

wkcn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new b077965  Add dtype visualization to plot_network (#14066)
b077965 is described below

commit b077965b7e41108cafe0123659c4802cb771442b
Author: Przemyslaw Tredak <[email protected]>
AuthorDate: Thu Mar 14 18:14:44 2019 -0700

    Add dtype visualization to plot_network (#14066)
    
    * Add dtype to plot_network
    
    * Added docstring for the new param
    
    * Added dtype to the plot_network test
    
    * Changes from review
    
    * Fixes from review
    
    * Fix typo
    
    * Retrigger CI
---
 python/mxnet/visualization.py     | 38 ++++++++++++++++++++++++++++++--------
 tests/python/unittest/test_viz.py |  2 ++
 2 files changed, 32 insertions(+), 8 deletions(-)

diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py
index 1ebdcb5..dd3a1df 100644
--- a/python/mxnet/visualization.py
+++ b/python/mxnet/visualization.py
@@ -208,7 +208,7 @@ def print_summary(symbol, shape=None, line_length=120, positions=[.44, .64, .74,
     print('Total params: %s' % total_params)
     print('_' * line_length)
 
-def plot_network(symbol, title="plot", save_format='pdf', shape=None, 
node_attrs={},
+def plot_network(symbol, title="plot", save_format='pdf', shape=None, 
dtype=None, node_attrs={},
                  hide_weights=True):
     """Creates a visualization (Graphviz digraph object) of the given 
computation graph.
     Graphviz must be installed for this function to work.
@@ -224,6 +224,10 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs
         Specifies the shape of the input tensors. If specified, the 
visualization will include
         the shape of the tensors between the nodes. `shape` is a dictionary 
mapping
         input symbol names (str) to the corresponding tensor shape (tuple).
+    dtype: dict, optional
+        Specifies the type of the input tensors. If specified, the 
visualization will include
+        the type of the tensors between the nodes. `dtype` is a dictionary 
mapping
+        input symbol names (str) to the corresponding tensor type (e.g. 
`numpy.float32`).
     node_attrs: dict, optional
         Specifies the attributes for nodes in the generated visualization. 
`node_attrs` is
         a dictionary of Graphviz attribute names and values. For example::
@@ -271,14 +275,19 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs
         raise ImportError("Draw network requires graphviz library")
     if not isinstance(symbol, Symbol):
         raise TypeError("symbol must be a Symbol")
-    draw_shape = False
-    if shape is not None:
-        draw_shape = True
-        interals = symbol.get_internals()
-        _, out_shapes, _ = interals.infer_shape(**shape)
+    internals = symbol.get_internals()
+    draw_shape = shape is not None
+    if draw_shape:
+        _, out_shapes, _ = internals.infer_shape(**shape)
         if out_shapes is None:
             raise ValueError("Input shape is incomplete")
-        shape_dict = dict(zip(interals.list_outputs(), out_shapes))
+        shape_dict = dict(zip(internals.list_outputs(), out_shapes))
+    draw_type = dtype is not None
+    if draw_type:
+        _, out_types, _ = internals.infer_type(**dtype)
+        if out_types is None:
+            raise ValueError("Input type is incomplete")
+        type_dict = dict(zip(internals.list_outputs(), out_types))
     conf = json.loads(symbol.tojson())
     nodes = conf["nodes"]
     # check if multiple nodes have the same name
@@ -370,7 +379,7 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs
                 input_node = nodes[item[0]]
                 input_name = input_node["name"]
                 if input_name not in hidden_nodes:
-                    attr = {"dir": "back", 'arrowtail':'open'}
+                    attr = {"dir": "back", 'arrowtail':'open', 'label': ''}
                     # add shapes
                     if draw_shape:
                         if input_node["op"] != "null":
@@ -387,6 +396,19 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs
                             shape = shape_dict[key][1:]
                             label = "x".join([str(x) for x in shape])
                             attr["label"] = label
+                    if draw_type:
+                        if input_node["op"] != "null":
+                            key = input_name + "_output"
+                            if "attrs" in input_node:
+                                params = input_node["attrs"]
+                                if "num_outputs" in params:
+                                    key += str(int(params["num_outputs"]) - 1)
+                            dtype = type_dict[key]
+                            attr["label"] += '(' + dtype.__name__ + ')'
+                        else:
+                            key = input_name
+                            dtype = type_dict[key]
+                            attr["label"] += '(' + dtype.__name__ + ')'
                     dot.edge(tail_name=name, head_name=input_name, **attr)
 
     return dot
diff --git a/tests/python/unittest/test_viz.py b/tests/python/unittest/test_viz.py
index fe564b0..1321099 100644
--- a/tests/python/unittest/test_viz.py
+++ b/tests/python/unittest/test_viz.py
@@ -19,6 +19,7 @@ import unittest
 import warnings
 
 import mxnet as mx
+import numpy as np
 
 
 def test_print_summary():
@@ -55,6 +56,7 @@ def test_plot_network():
     net = mx.sym.SoftmaxOutput(data=net, name='out')
     with warnings.catch_warnings(record=True) as w:
         digraph = mx.viz.plot_network(net, shape={'data': (100, 200)},
+                                      dtype={'data': np.float32},
                                       node_attrs={"fixedsize": "false"})
     assert len(w) == 1
     assert "There are multiple variables with the same name in your graph" in 
str(w[-1].message)

Reply via email to