This is an automated email from the ASF dual-hosted git repository.

cjolivier01 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 33823d3  [MXNET-106] [ONNX_MXNet] Change parameter names in imported model (#10472)
33823d3 is described below

commit 33823d3a572ec21de4d76cb8c9a38804d56f6798
Author: Anirudh <2778341+anirudhacha...@users.noreply.github.com>
AuthorDate: Tue Apr 10 13:15:58 2018 -0700

    [MXNET-106] [ONNX_MXNet] Change parameter names in imported model (#10472)
    
    * fix param names in model
    
    * corresponding changes to tutorials
    
    * test rendering
    
    * add comments to data name fetch stmt.
---
 docs/tutorials/onnx/fine_tuning_gluon.md         |  4 +--
 docs/tutorials/onnx/inference_on_onnx_model.md   | 15 ++++++++--
 docs/tutorials/onnx/super_resolution.md          | 23 +++++++++++-----
 example/onnx/super_resolution.py                 |  8 ++++--
 python/mxnet/contrib/onnx/_import/import_onnx.py | 19 +++----------
 tests/python-pytest/onnx/backend_rep.py          | 13 +++++++--
 tests/python-pytest/onnx/onnx_test.py            | 35 +++++++++++++-----------
 7 files changed, 71 insertions(+), 46 deletions(-)

diff --git a/docs/tutorials/onnx/fine_tuning_gluon.md b/docs/tutorials/onnx/fine_tuning_gluon.md
index 7961f9f..4116ff6 100644
--- a/docs/tutorials/onnx/fine_tuning_gluon.md
+++ b/docs/tutorials/onnx/fine_tuning_gluon.md
@@ -230,7 +230,7 @@ sym.get_internals()
 
 
 
-```<Symbol group [input_0, param_0, param_1, convolution0, relu0, lrn0, pad0, 
pooling0, param_2, param_3, convolution1, relu1, lrn1, pad1, pooling1, param_4, 
param_5, convolution2, relu2, param_6, param_7, convolution3, relu3, param_8, 
param_9, convolution4, relu4, pad2, pooling2, _mulscalar0, param_10, param_11, 
_mulscalar1, fullyconnected0, relu5, _mulscalar2, param_12, param_13, 
_mulscalar3, fullyconnected1, relu6, _mulscalar4, param_14, param_15, 
_mulscalar5, fullyconnected2, softmax [...]
+```<Symbol group [gpu_0/data_0, gpu_0/conv1_w_0, gpu_0/conv1_b_0, 
convolution0, relu0, lrn0, pad0, pooling0, gpu_0/conv2_w_0, gpu_0/conv2_b_0, 
convolution1, relu1, lrn1, pad1, pooling1, gpu_0/conv3_w_0, gpu_0/conv3_b_0, 
convolution2, relu2, gpu_0/conv4_w_0, gpu_0/conv4_b_0, convolution3, relu3, 
gpu_0/conv5_w_0, gpu_0/conv5_b_0, convolution4, relu4, pad2, pooling2, 
flatten0, gpu_0/fc6_w_0, linalg_gemm20, gpu_0/fc6_b_0, _mulscalar0, 
broadcast_add0, relu5, flatten1, gpu_0/fc7_w_0, linalg_ge [...]
 
 
 
@@ -258,7 +258,7 @@ We create a symbol block that is going to hold all our pre-trained layers, and a
 
 
 ```python
-pre_trained = gluon.nn.SymbolBlock(outputs=new_sym, inputs=mx.sym.var('input_0'))
+pre_trained = gluon.nn.SymbolBlock(outputs=new_sym, inputs=mx.sym.var('gpu_0/data_0'))
 net_params = pre_trained.collect_params()
 for param in new_arg_params:
     if param in net_params:
diff --git a/docs/tutorials/onnx/inference_on_onnx_model.md b/docs/tutorials/onnx/inference_on_onnx_model.md
index 9415d00..bdda820 100644
--- a/docs/tutorials/onnx/inference_on_onnx_model.md
+++ b/docs/tutorials/onnx/inference_on_onnx_model.md
@@ -104,11 +104,22 @@ We pick a context, GPU if available, otherwise CPU
 ctx = mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()
 ```
 
-And load them into a MXNet Gluon symbol block. For ONNX models the default input name is `input_0`.
+We obtain the data names of the inputs to the model, by listing all the inputs to the symbol graph and excluding the argument and auxiliary parameters from that list:
 
+```python
+data_names = [graph_input for graph_input in sym.list_inputs()
+                      if graph_input not in arg_params and graph_input not in aux_params]
+print(data_names)
+```
+
+
+```['gpu_0/data_0']```
+
+
+And load them into a MXNet Gluon symbol block. 
 
 ```python
-net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('input_0'))
+net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('gpu_0/data_0'))
 net_params = net.collect_params()
 for param in arg_params:
     if param in net_params:
diff --git a/docs/tutorials/onnx/super_resolution.md b/docs/tutorials/onnx/super_resolution.md
index dc75b66..36c06b7 100644
--- a/docs/tutorials/onnx/super_resolution.md
+++ b/docs/tutorials/onnx/super_resolution.md
@@ -51,7 +51,7 @@ mx.viz.plot_network(sym, node_attrs={"shape":"oval","fixedsize":"false"})
 
 
 
-![svg](https://s3.amazonaws.com/onnx-mxnet/examples/super_res_mxnet_model.png)
+![svg](https://s3.amazonaws.com/onnx-mxnet/examples/super_res_mxnet_model.png) <!--notebook-skip-line-->
 
 
 
@@ -71,10 +71,19 @@ test_image = np.array(img_y)[np.newaxis, np.newaxis, :, :]
 
 We will use MXNet's Module API to run the inference. For this we will need to create the module, bind it to the input data and assign the loaded weights from the two parameter objects - argument parameters and auxilliary parameters.
 
+To obtain the input data names we run the following line, which picks all the inputs of the symbol graph excluding the argument and auxiliary parameters:
 
 ```python
-mod = mx.mod.Module(symbol=sym, data_names=['input_0'], context=mx.cpu(), label_names=None)
-mod.bind(for_training=False, data_shapes=[('input_0',test_image.shape)], label_shapes=None)
+data_names = [graph_input for graph_input in sym.list_inputs()
+                      if graph_input not in arg and graph_input not in aux]
+print(data_names)
+```
+
+```['1']```
+
+```python
+mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
+mod.bind(for_training=False, data_shapes=[(data_names[0],test_image.shape)], label_shapes=None)
 mod.set_params(arg_params=arg, aux_params=aux, allow_missing=True, allow_extra=True)
 ```
 
@@ -105,10 +114,10 @@ result_img = Image.merge(
 result_img.save("super_res_output.jpg")
 ```
 
-Here's the input image and the resulting output images compared. As you can see, the model was able to increase the spatial resolution from ``256x256`` to ``672x672``.
+You can now compare the input image and the resulting output image. As you will notice, the model was able to increase the spatial resolution from ``256x256`` to ``672x672``.
 
-| Input Image | Output Image |
-| ----------- | ------------ |
-| ![input](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/super_res_input.jpg?raw=true) | ![output](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/super_res_output.jpg?raw=true) |
+| Input Image | Output Image | <!--notebook-skip-line-->
+| ----------- | ------------ | <!--notebook-skip-line-->
+| ![input](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/super_res_input.jpg?raw=true) | ![output](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/super_res_output.jpg?raw=true) | <!--notebook-skip-line-->
 
 <!-- INSERT SOURCE DOWNLOAD BUTTONS -->
\ No newline at end of file
diff --git a/example/onnx/super_resolution.py b/example/onnx/super_resolution.py
index f7c7886..a52f1a8 100644
--- a/example/onnx/super_resolution.py
+++ b/example/onnx/super_resolution.py
@@ -55,9 +55,13 @@ def get_test_image():
 
 def perform_inference(sym, arg_params, aux_params, input_img, img_cb, img_cr):
     """Perform inference on image using mxnet"""
+    # To fetch the data names of the input to the model we list the inputs of the symbol graph
+    # and exclude the argument and auxiliary parameters from the list
+    data_names = [graph_input for graph_input in sym.list_inputs()
+                  if graph_input not in arg_params and graph_input not in aux_params]
     # create module
-    mod = mx.mod.Module(symbol=sym, data_names=['input_0'], label_names=None)
-    mod.bind(for_training=False, data_shapes=[('input_0', input_img.shape)])
+    mod = mx.mod.Module(symbol=sym, data_names=data_names, label_names=None)
+    mod.bind(for_training=False, data_shapes=[(data_names[0], input_img.shape)])
     mod.set_params(arg_params=arg_params, aux_params=aux_params)
 
     # run inference
diff --git a/python/mxnet/contrib/onnx/_import/import_onnx.py b/python/mxnet/contrib/onnx/_import/import_onnx.py
index 92e7cb9..5192c6f 100644
--- a/python/mxnet/contrib/onnx/_import/import_onnx.py
+++ b/python/mxnet/contrib/onnx/_import/import_onnx.py
@@ -31,7 +31,6 @@ class GraphProto(object): # pylint: disable=too-few-public-methods
     def __init__(self):
         self._nodes = {}
         self._params = {}
-        self._renames = {}
         self._num_input = 0
         self._num_param = 0
 
@@ -72,9 +71,6 @@ class GraphProto(object): # pylint: disable=too-few-public-methods
 
     def from_onnx(self, graph):
         """Construct symbol from onnx graph.
-        The inputs from onnx graph is vague, only providing "1", "2"...
-        For convenience, we rename the `real` input names to "input_0",
-        "input_1"... And renaming parameters to "param_0", "param_1"...
 
         Parameters
         ----------
@@ -98,17 +94,10 @@ class GraphProto(object): # pylint: disable=too-few-public-methods
         for i in graph.input:
             if i.name in self._params:
                 # i is a param instead of input
-                name_param = 'param_{}'.format(self._num_param)
-                self._num_param += 1
-                self._params[name_param] = self._params.pop(i.name)
-                self._nodes[name_param] = symbol.Variable(name=name_param,
-                                                          
shape=self._params[name_param].shape)
-                self._renames[i.name] = name_param
+                self._nodes[i.name] = symbol.Variable(name=i.name,
+                                                      
shape=self._params[i.name].shape)
             else:
-                name_input = 'input_{}'.format(self._num_input)
-                self._num_input += 1
-                self._nodes[name_input] = symbol.Variable(name=name_input)
-                self._renames[i.name] = name_input
+                self._nodes[i.name] = symbol.Variable(name=i.name)
 
         # For storing arg  and aux params for the graph.
         auxDict = {}
@@ -121,7 +110,7 @@ class GraphProto(object): # pylint: disable=too-few-public-methods
             node_name = node.name.strip()
             node_name = node_name if node_name else None
             onnx_attr = self._parse_attr(node.attribute)
-            inputs = [self._nodes[self._renames.get(i, i)] for i in node.input]
+            inputs = [self._nodes[i] for i in node.input]
             mxnet_sym = self._convert_operator(node_name, op_name, onnx_attr, inputs)
 
             for k, i in zip(list(node.output), range(len(mxnet_sym.list_outputs()))):
diff --git a/tests/python-pytest/onnx/backend_rep.py b/tests/python-pytest/onnx/backend_rep.py
index 47ea6c1..114a2eb 100644
--- a/tests/python-pytest/onnx/backend_rep.py
+++ b/tests/python-pytest/onnx/backend_rep.py
@@ -64,9 +64,18 @@ class MXNetBackendRep(BackendRep):
         else:
             raise NotImplementedError("Only CPU context is supported for now")
 
-        mod = mx.mod.Module(symbol=self.symbol, data_names=['input_0'], context=ctx,
+        # To fetch the data names of the input to the model we list the inputs of the symbol graph
+        # and exclude the argument and auxiliary parameters from the list
+        data_names = [graph_input for graph_input in self.symbol.list_inputs()
+                      if graph_input not in self.arg_params and graph_input not in self.aux_params]
+
+        data_shapes = []
+        for idx, input_name in enumerate(data_names):
+            data_shapes.append((input_name, inputs[idx].shape))
+
+        mod = mx.mod.Module(symbol=self.symbol, data_names=data_names, context=ctx,
                             label_names=None)
-        mod.bind(for_training=False, data_shapes=[('input_0', input_data.shape)],
+        mod.bind(for_training=False, data_shapes=data_shapes,
                  label_shapes=None)
         mod.set_params(arg_params=self.arg_params, aux_params=self.aux_params)
 
diff --git a/tests/python-pytest/onnx/onnx_test.py b/tests/python-pytest/onnx/onnx_test.py
index ddc633e..36cb9ab 100644
--- a/tests/python-pytest/onnx/onnx_test.py
+++ b/tests/python-pytest/onnx/onnx_test.py
@@ -117,8 +117,7 @@ def test_super_resolution_example():
 
     inputs = sym.list_inputs()
     assert len(inputs) == 9
-    for i, input_param in enumerate(['param_7', 'param_5', 'param_3', 'param_1',
-                                     'input_0', 'param_0', 'param_2', 'param_4', 'param_6']):
+    for i, input_param in enumerate(['9', '7', '5', '3', '1', '2', '4', '6', '8']):
         assert inputs[i] == input_param
 
     assert len(sym.list_outputs()) == 1
@@ -126,18 +125,16 @@ def test_super_resolution_example():
 
     attrs_keys = sym.attr_dict().keys()
     assert len(attrs_keys) == 19
-    for i, key_item in enumerate(['reshape4', 'param_5', 'param_4', 'param_7',
-                                  'param_6', 'param_1', 'param_0', 'param_3',
-                                  'param_2', 'reshape2', 'reshape3', 'reshape0',
-                                  'reshape1', 'convolution2', 'convolution3',
-                                  'convolution0', 'convolution1', 'reshape5',
-                                  'transpose0']):
+    for i, key_item in enumerate(['reshape4', 'convolution2', 'convolution0',
+                                  'transpose0', '6', 'reshape0', 'reshape2',
+                                  'reshape3', '3', 'reshape1', '5', '4', '7',
+                                  'convolution1', '9', '2', 'convolution3',
+                                  'reshape5', '8']):
         assert key_item in attrs_keys
 
     param_keys = arg_params.keys()
     assert len(param_keys) == 8
-    for i, param_item in enumerate(['param_5', 'param_4', 'param_7', 'param_6',
-                                    'param_1', 'param_0', 'param_3', 'param_2']):
+    for i, param_item in enumerate(['3', '2', '5', '4', '7', '6', '9', '8']):
         assert param_item in param_keys
 
     logging.info("Asserted the result of the onnx model conversion")
@@ -192,8 +189,10 @@ def test_bvlc_googlenet():
     # run test for each test file
     for input_data, output_data in zip(inputs, outputs):
         # create module
-        mod = mx.mod.Module(symbol=sym, data_names=['input_0'], context=mx.cpu(), label_names=None)
-        mod.bind(for_training=False, data_shapes=[('input_0', input_data.shape)], label_shapes=None)
+        data_names = [graph_input for graph_input in sym.list_inputs()
+                      if graph_input not in arg_params and graph_input not in aux_params]
+        mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
+        mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None)
         mod.set_params(arg_params=arg_params, aux_params=aux_params,
                        allow_missing=True, allow_extra=True)
         # run inference
@@ -214,8 +213,10 @@ def test_bvlc_reference_caffenet():
     # run test for each test file
     for input_data, output_data in zip(inputs, outputs):
         # create module
-        mod = mx.mod.Module(symbol=sym, data_names=['input_0'], context=mx.cpu(), label_names=None)
-        mod.bind(for_training=False, data_shapes=[('input_0', input_data.shape)], label_shapes=None)
+        data_names = [graph_input for graph_input in sym.list_inputs()
+                      if graph_input not in arg_params and graph_input not in aux_params]
+        mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
+        mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None)
         mod.set_params(arg_params=arg_params, aux_params=aux_params,
                        allow_missing=True, allow_extra=True)
         # run inference
@@ -236,8 +237,10 @@ def test_bvlc_rcnn_ilsvrc13():
     # run test for each test file
     for input_data, output_data in zip(inputs, outputs):
         # create module
-        mod = mx.mod.Module(symbol=sym, data_names=['input_0'], context=mx.cpu(), label_names=None)
-        mod.bind(for_training=False, data_shapes=[('input_0', input_data.shape)], label_shapes=None)
+        data_names = [graph_input for graph_input in sym.list_inputs()
+                      if graph_input not in arg_params and graph_input not in aux_params]
+        mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
+        mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None)
         mod.set_params(arg_params=arg_params, aux_params=aux_params,
                        allow_missing=True, allow_extra=True)
         # run inference

-- 
To stop receiving notification emails like this one, please contact
cjolivie...@apache.org.

Reply via email to