cjolivier01 closed pull request #10472: [MXNET-106] [ONNX_MXNet] Change parameter names in imported model
URL: https://github.com/apache/incubator-mxnet/pull/10472
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/docs/tutorials/onnx/fine_tuning_gluon.md b/docs/tutorials/onnx/fine_tuning_gluon.md
index 7961f9f6b8a..4116ff631eb 100644
--- a/docs/tutorials/onnx/fine_tuning_gluon.md
+++ b/docs/tutorials/onnx/fine_tuning_gluon.md
@@ -230,7 +230,7 @@ sym.get_internals()
 
 
 
-```<Symbol group [input_0, param_0, param_1, convolution0, relu0, lrn0, pad0, pooling0, param_2, param_3, convolution1, relu1, lrn1, pad1, pooling1, param_4, param_5, convolution2, relu2, param_6, param_7, convolution3, relu3, param_8, param_9, convolution4, relu4, pad2, pooling2, _mulscalar0, param_10, param_11, _mulscalar1, fullyconnected0, relu5, _mulscalar2, param_12, param_13, _mulscalar3, fullyconnected1, relu6, _mulscalar4, param_14, param_15, _mulscalar5, fullyconnected2, softmax0]>```<!--notebook-skip-line-->
+```<Symbol group [gpu_0/data_0, gpu_0/conv1_w_0, gpu_0/conv1_b_0, convolution0, relu0, lrn0, pad0, pooling0, gpu_0/conv2_w_0, gpu_0/conv2_b_0, convolution1, relu1, lrn1, pad1, pooling1, gpu_0/conv3_w_0, gpu_0/conv3_b_0, convolution2, relu2, gpu_0/conv4_w_0, gpu_0/conv4_b_0, convolution3, relu3, gpu_0/conv5_w_0, gpu_0/conv5_b_0, convolution4, relu4, pad2, pooling2, flatten0, gpu_0/fc6_w_0, linalg_gemm20, gpu_0/fc6_b_0, _mulscalar0, broadcast_add0, relu5, flatten1, gpu_0/fc7_w_0, linalg_gemm21, gpu_0/fc7_b_0, _mulscalar1, broadcast_add1, relu6, flatten2, gpu_0/fc8_w_0, linalg_gemm22, gpu_0/fc8_b_0, _mulscalar2, broadcast_add2, softmax0]>```<!--notebook-skip-line-->
 
 
 
@@ -258,7 +258,7 @@ We create a symbol block that is going to hold all our pre-trained layers, and a
 
 
 ```python
-pre_trained = gluon.nn.SymbolBlock(outputs=new_sym, inputs=mx.sym.var('input_0'))
+pre_trained = gluon.nn.SymbolBlock(outputs=new_sym, inputs=mx.sym.var('gpu_0/data_0'))
 net_params = pre_trained.collect_params()
 for param in new_arg_params:
     if param in net_params:
diff --git a/docs/tutorials/onnx/inference_on_onnx_model.md b/docs/tutorials/onnx/inference_on_onnx_model.md
index 9415d0063c8..bdda820119e 100644
--- a/docs/tutorials/onnx/inference_on_onnx_model.md
+++ b/docs/tutorials/onnx/inference_on_onnx_model.md
@@ -104,11 +104,22 @@ We pick a context, GPU if available, otherwise CPU
 ctx = mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()
 ```
 
-And load them into a MXNet Gluon symbol block. For ONNX models the default input name is `input_0`.
+We obtain the data names of the inputs to the model, by listing all the inputs to the symbol graph and excluding the argument and auxiliary parameters from that list:
 
+```python
+data_names = [graph_input for graph_input in sym.list_inputs()
+                      if graph_input not in arg_params and graph_input not in aux_params]
+print(data_names)
+```
+
+
+```['gpu_0/data_0']```
+
+
+And load them into a MXNet Gluon symbol block. 
 
 ```python
-net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('input_0'))
+net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('gpu_0/data_0'))
 net_params = net.collect_params()
 for param in arg_params:
     if param in net_params:
diff --git a/docs/tutorials/onnx/super_resolution.md b/docs/tutorials/onnx/super_resolution.md
index dc75b6606f2..36c06b743c8 100644
--- a/docs/tutorials/onnx/super_resolution.md
+++ b/docs/tutorials/onnx/super_resolution.md
@@ -51,7 +51,7 @@ mx.viz.plot_network(sym, node_attrs={"shape":"oval","fixedsize":"false"})
 
 
 
-![svg](https://s3.amazonaws.com/onnx-mxnet/examples/super_res_mxnet_model.png)
+![svg](https://s3.amazonaws.com/onnx-mxnet/examples/super_res_mxnet_model.png) <!--notebook-skip-line-->
 
 
 
@@ -71,10 +71,19 @@ test_image = np.array(img_y)[np.newaxis, np.newaxis, :, :]
 
 We will use MXNet's Module API to run the inference. For this we will need to create the module, bind it to the input data and assign the loaded weights from the two parameter objects - argument parameters and auxilliary parameters.
 
+To obtain the input data names we run the following line, which picks all the inputs of the symbol graph excluding the argument and auxiliary parameters:
 
 ```python
-mod = mx.mod.Module(symbol=sym, data_names=['input_0'], context=mx.cpu(), label_names=None)
-mod.bind(for_training=False, data_shapes=[('input_0',test_image.shape)], label_shapes=None)
+data_names = [graph_input for graph_input in sym.list_inputs()
+                      if graph_input not in arg and graph_input not in aux]
+print(data_names)
+```
+
+```['1']```
+
+```python
+mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
+mod.bind(for_training=False, data_shapes=[(data_names[0],test_image.shape)], label_shapes=None)
 mod.set_params(arg_params=arg, aux_params=aux, allow_missing=True, allow_extra=True)
 ```
 
@@ -105,10 +114,10 @@ result_img = Image.merge(
 result_img.save("super_res_output.jpg")
 ```
 
-Here's the input image and the resulting output images compared. As you can see, the model was able to increase the spatial resolution from ``256x256`` to ``672x672``.
+You can now compare the input image and the resulting output image. As you will notice, the model was able to increase the spatial resolution from ``256x256`` to ``672x672``.
 
-| Input Image | Output Image |
-| ----------- | ------------ |
-| ![input](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/super_res_input.jpg?raw=true) | ![output](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/super_res_output.jpg?raw=true) |
+| Input Image | Output Image | <!--notebook-skip-line-->
+| ----------- | ------------ | <!--notebook-skip-line-->
+| ![input](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/super_res_input.jpg?raw=true) | ![output](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/super_res_output.jpg?raw=true) | <!--notebook-skip-line-->
 
 <!-- INSERT SOURCE DOWNLOAD BUTTONS -->
\ No newline at end of file
diff --git a/example/onnx/super_resolution.py b/example/onnx/super_resolution.py
index f7c7886d0df..a52f1a892a6 100644
--- a/example/onnx/super_resolution.py
+++ b/example/onnx/super_resolution.py
@@ -55,9 +55,13 @@ def get_test_image():
 
 def perform_inference(sym, arg_params, aux_params, input_img, img_cb, img_cr):
     """Perform inference on image using mxnet"""
+    # To fetch the data names of the input to the model we list the inputs of the symbol graph
+    # and exclude the argument and auxiliary parameters from the list
+    data_names = [graph_input for graph_input in sym.list_inputs()
+                  if graph_input not in arg_params and graph_input not in aux_params]
     # create module
-    mod = mx.mod.Module(symbol=sym, data_names=['input_0'], label_names=None)
-    mod.bind(for_training=False, data_shapes=[('input_0', input_img.shape)])
+    mod = mx.mod.Module(symbol=sym, data_names=data_names, label_names=None)
+    mod.bind(for_training=False, data_shapes=[(data_names[0], input_img.shape)])
     mod.set_params(arg_params=arg_params, aux_params=aux_params)
 
     # run inference
diff --git a/python/mxnet/contrib/onnx/_import/import_onnx.py b/python/mxnet/contrib/onnx/_import/import_onnx.py
index 92e7cb9c64e..5192c6f8a85 100644
--- a/python/mxnet/contrib/onnx/_import/import_onnx.py
+++ b/python/mxnet/contrib/onnx/_import/import_onnx.py
@@ -31,7 +31,6 @@ class GraphProto(object): # pylint: disable=too-few-public-methods
     def __init__(self):
         self._nodes = {}
         self._params = {}
-        self._renames = {}
         self._num_input = 0
         self._num_param = 0
 
@@ -72,9 +71,6 @@ def _convert_operator(self, node_name, op_name, attrs, inputs):
 
     def from_onnx(self, graph):
         """Construct symbol from onnx graph.
-        The inputs from onnx graph is vague, only providing "1", "2"...
-        For convenience, we rename the `real` input names to "input_0",
-        "input_1"... And renaming parameters to "param_0", "param_1"...
 
         Parameters
         ----------
@@ -98,17 +94,10 @@ def from_onnx(self, graph):
         for i in graph.input:
             if i.name in self._params:
                 # i is a param instead of input
-                name_param = 'param_{}'.format(self._num_param)
-                self._num_param += 1
-                self._params[name_param] = self._params.pop(i.name)
-                self._nodes[name_param] = symbol.Variable(name=name_param,
-                                                          shape=self._params[name_param].shape)
-                self._renames[i.name] = name_param
+                self._nodes[i.name] = symbol.Variable(name=i.name,
+                                                      shape=self._params[i.name].shape)
             else:
-                name_input = 'input_{}'.format(self._num_input)
-                self._num_input += 1
-                self._nodes[name_input] = symbol.Variable(name=name_input)
-                self._renames[i.name] = name_input
+                self._nodes[i.name] = symbol.Variable(name=i.name)
 
         # For storing arg  and aux params for the graph.
         auxDict = {}
@@ -121,7 +110,7 @@ def from_onnx(self, graph):
             node_name = node.name.strip()
             node_name = node_name if node_name else None
             onnx_attr = self._parse_attr(node.attribute)
-            inputs = [self._nodes[self._renames.get(i, i)] for i in node.input]
+            inputs = [self._nodes[i] for i in node.input]
             mxnet_sym = self._convert_operator(node_name, op_name, onnx_attr, inputs)
 
             for k, i in zip(list(node.output), range(len(mxnet_sym.list_outputs()))):
diff --git a/tests/python-pytest/onnx/backend_rep.py b/tests/python-pytest/onnx/backend_rep.py
index 47ea6c1585a..114a2eb7990 100644
--- a/tests/python-pytest/onnx/backend_rep.py
+++ b/tests/python-pytest/onnx/backend_rep.py
@@ -64,9 +64,18 @@ def run(self, inputs, **kwargs):
         else:
             raise NotImplementedError("Only CPU context is supported for now")
 
-        mod = mx.mod.Module(symbol=self.symbol, data_names=['input_0'], context=ctx,
+        # To fetch the data names of the input to the model we list the inputs of the symbol graph
+        # and exclude the argument and auxiliary parameters from the list
+        data_names = [graph_input for graph_input in self.symbol.list_inputs()
+                      if graph_input not in self.arg_params and graph_input not in self.aux_params]
+
+        data_shapes = []
+        for idx, input_name in enumerate(data_names):
+            data_shapes.append((input_name, inputs[idx].shape))
+
+        mod = mx.mod.Module(symbol=self.symbol, data_names=data_names, context=ctx,
                             label_names=None)
-        mod.bind(for_training=False, data_shapes=[('input_0', input_data.shape)],
+        mod.bind(for_training=False, data_shapes=data_shapes,
                  label_shapes=None)
         mod.set_params(arg_params=self.arg_params, aux_params=self.aux_params)
 
diff --git a/tests/python-pytest/onnx/onnx_test.py b/tests/python-pytest/onnx/onnx_test.py
index ddc633e28f6..36cb9abacdd 100644
--- a/tests/python-pytest/onnx/onnx_test.py
+++ b/tests/python-pytest/onnx/onnx_test.py
@@ -117,8 +117,7 @@ def test_super_resolution_example():
 
     inputs = sym.list_inputs()
     assert len(inputs) == 9
-    for i, input_param in enumerate(['param_7', 'param_5', 'param_3', 'param_1',
-                                     'input_0', 'param_0', 'param_2', 'param_4', 'param_6']):
+    for i, input_param in enumerate(['9', '7', '5', '3', '1', '2', '4', '6', '8']):
         assert inputs[i] == input_param
 
     assert len(sym.list_outputs()) == 1
@@ -126,18 +125,16 @@ def test_super_resolution_example():
 
     attrs_keys = sym.attr_dict().keys()
     assert len(attrs_keys) == 19
-    for i, key_item in enumerate(['reshape4', 'param_5', 'param_4', 'param_7',
-                                  'param_6', 'param_1', 'param_0', 'param_3',
-                                  'param_2', 'reshape2', 'reshape3', 'reshape0',
-                                  'reshape1', 'convolution2', 'convolution3',
-                                  'convolution0', 'convolution1', 'reshape5',
-                                  'transpose0']):
+    for i, key_item in enumerate(['reshape4', 'convolution2', 'convolution0',
+                                  'transpose0', '6', 'reshape0', 'reshape2',
+                                  'reshape3', '3', 'reshape1', '5', '4', '7',
+                                  'convolution1', '9', '2', 'convolution3',
+                                  'reshape5', '8']):
         assert key_item in attrs_keys
 
     param_keys = arg_params.keys()
     assert len(param_keys) == 8
-    for i, param_item in enumerate(['param_5', 'param_4', 'param_7', 'param_6',
-                                    'param_1', 'param_0', 'param_3', 'param_2']):
+    for i, param_item in enumerate(['3', '2', '5', '4', '7', '6', '9', '8']):
         assert param_item in param_keys
 
     logging.info("Asserted the result of the onnx model conversion")
@@ -192,8 +189,10 @@ def test_bvlc_googlenet():
     # run test for each test file
     for input_data, output_data in zip(inputs, outputs):
         # create module
-        mod = mx.mod.Module(symbol=sym, data_names=['input_0'], context=mx.cpu(), label_names=None)
-        mod.bind(for_training=False, data_shapes=[('input_0', input_data.shape)], label_shapes=None)
+        data_names = [graph_input for graph_input in sym.list_inputs()
+                      if graph_input not in arg_params and graph_input not in aux_params]
+        mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
+        mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None)
         mod.set_params(arg_params=arg_params, aux_params=aux_params,
                        allow_missing=True, allow_extra=True)
         # run inference
@@ -214,8 +213,10 @@ def test_bvlc_reference_caffenet():
     # run test for each test file
     for input_data, output_data in zip(inputs, outputs):
         # create module
-        mod = mx.mod.Module(symbol=sym, data_names=['input_0'], context=mx.cpu(), label_names=None)
-        mod.bind(for_training=False, data_shapes=[('input_0', input_data.shape)], label_shapes=None)
+        data_names = [graph_input for graph_input in sym.list_inputs()
+                      if graph_input not in arg_params and graph_input not in aux_params]
+        mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
+        mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None)
         mod.set_params(arg_params=arg_params, aux_params=aux_params,
                        allow_missing=True, allow_extra=True)
         # run inference
@@ -236,8 +237,10 @@ def test_bvlc_rcnn_ilsvrc13():
     # run test for each test file
     for input_data, output_data in zip(inputs, outputs):
         # create module
-        mod = mx.mod.Module(symbol=sym, data_names=['input_0'], context=mx.cpu(), label_names=None)
-        mod.bind(for_training=False, data_shapes=[('input_0', input_data.shape)], label_shapes=None)
+        data_names = [graph_input for graph_input in sym.list_inputs()
+                      if graph_input not in arg_params and graph_input not in aux_params]
+        mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
+        mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None)
         mod.set_params(arg_params=arg_params, aux_params=aux_params,
                        allow_missing=True, allow_extra=True)
         # run inference
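
The pattern this change applies across the tutorials and tests can be reduced to a short, self-contained sketch. The `onnx_mxnet.import_model` call and the `super_resolution.onnx` file name below are assumptions taken from the existing tutorials, not from this diff:

```python
import mxnet as mx
# Assumption: the contrib ONNX import API used by the tutorials.
from mxnet.contrib import onnx as onnx_mxnet

# Assumption: any ONNX model file on disk; substitute your own path.
sym, arg_params, aux_params = onnx_mxnet.import_model('super_resolution.onnx')

# With this change the imported symbol keeps the original ONNX input names,
# so the data inputs are simply the graph inputs that are neither argument
# nor auxiliary parameters.
data_names = [graph_input for graph_input in sym.list_inputs()
              if graph_input not in arg_params and graph_input not in aux_params]

# Bind a module using the discovered input name(s) instead of the
# previously hard-coded 'input_0'.
mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
```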


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
