TriLoo opened a new issue #18048: convert mxnet op `split` to onnx `Split` 
output names error
URL: https://github.com/apache/incubator-mxnet/issues/18048
 
 
   ## Description
   If a network contains `mx.nd.split()`, converting it to ONNX may fail with
   an error to the effect that the input of a following layer is not an output
   of any previous layer. The cause is that the outputs are renamed when
   `mx.nd.split(...)` is converted to an ONNX `Split` node, *i.e.* the suffix
   `_output + str(i)` is appended to the layer name of `mx.nd.split()`.
   Meanwhile, the layers that consume the outputs of `mx.nd.split()` still
   expect the bare layer name of `mx.nd.split()` as their input. ONNX then
   complains that this input is not an output of any previous layer.
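
The mismatch described above can be illustrated without MXNet or ONNX. The following is a minimal sketch, assuming a split layer named `split0` with three outputs; the layer name and both helper functions are hypothetical, and only the `_output<i>` suffix scheme is taken from the description.

```python
# Hypothetical layer names and helpers; only the `_output<i>` suffix
# scheme comes from the description above.
def exported_split_outputs(layer_name, num_outputs):
    """Names that the exported Split node gives to its outputs."""
    return [layer_name + "_output" + str(i) for i in range(num_outputs)]


def dangling_inputs(consumer_inputs, available_outputs):
    """Inputs of a consumer node that no earlier node produces."""
    return [name for name in consumer_inputs if name not in available_outputs]


produced = exported_split_outputs("split0", 3)
# The consumer still refers to the bare layer name of the split.
missing = dangling_inputs(["split0"], produced)

print(produced)  # ['split0_output0', 'split0_output1', 'split0_output2']
print(missing)   # ['split0']
```

Any name left in `missing` is an input that the graph check would report as not being produced by an earlier node.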
   
   ## To Reproduce
   ``` python
   from mxnet import gluon, nd

   class TmpMulScalar(gluon.HybridBlock):
       def __init__(self, **kwargs):
           super(TmpMulScalar, self).__init__(**kwargs)
   
           with self.name_scope():
               self.conv = gluon.nn.Conv2D(1, 3, 1, 1, use_bias=False)
   
           self.val = 0.1
   
       def hybrid_forward(self, F, x):
           r, g, b = F.split(x, axis=1, num_outputs=3)  # split can cause error!!!
           r = g + self.val
           feat = self.conv(r)
           output = F.concat(feat, g, b, dim=1)
   
           return output
   
   def try_mul_scalar():
       net = TmpMulScalar()
       net.initialize()
       data = nd.random.uniform(0.0, 1.0, (1, 3, 10, 10))
   
       net.hybridize()
       net(data)
   
       net.export('./temp')
   
       import onnx
       from mxnet.contrib import onnx as mx_onnx
   
       sym_file = './temp-symbol.json'
       param_file = './temp-0000.params'
   
       converted_file = mx_onnx.export_model(
           sym_file, param_file, [(1, 3, 10, 10)], onnx_file_path='./temp.onnx')
       print('converted_file: ', converted_file)
   
       from onnx import checker
       model_onnx = onnx.load_model(converted_file)
       checker.check_graph(model_onnx.graph)
   ```
   
   ### Steps to reproduce
   Running the code above (calling `try_mul_scalar()`) reproduces the error.
   
   ## Possible Solutions
   1. Add a check in [get_inputs() - op_translation](https://github.com/apache/incubator-mxnet/blob/e3d7866e6854a5c11ab2b2c8bfb63de66f79e132/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py#L129),
   *i.e.* if the name of an input node contains `split`, change the input name
   to `proc_nodes[input_node_id].name + '_output' + str(ip[1])`, where `ip[1]`
   is the output index of `mx.nd.split()`.
   
   The complete modified function is shown below:
   
   ``` python
   def get_inputs(node, kwargs):
       """Helper function to get inputs"""
       name = node["name"]
       proc_nodes = kwargs["proc_nodes"]
       index_lookup = kwargs["index_lookup"]
       inputs = node["inputs"]
       attrs = node.get("attrs", {})
   
       input_nodes = []
       for ip in inputs:
           input_node_id = index_lookup[ip[0]]
           input_node_name = proc_nodes[input_node_id].name
           if 'split' in input_node_name:
               # split outputs are exported as <name>_output<i>; ip[1] is the output index
               input_node_name = input_node_name + '_output' + str(ip[1])
           # original: input_nodes.append(proc_nodes[input_node_id].name)
           input_nodes.append(input_node_name)
   
       return name, input_nodes, attrs
   ```
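
To see what the proposed change does, the patched function can be exercised on a tiny hand-built graph. This is a sketch under stated assumptions: `FakeNode` is a hypothetical stand-in for the already-converted nodes (only `.name` is used), `index_lookup` is assumed to be a list mapping MXNet node ids to positions in `proc_nodes`, and the node names are made up.

```python
from collections import namedtuple

# Hypothetical stand-in for an already-converted node; only `.name` is read.
FakeNode = namedtuple("FakeNode", ["name"])


def get_inputs(node, kwargs):
    """Patched helper from the proposal above, copied so this block runs alone."""
    name = node["name"]
    proc_nodes = kwargs["proc_nodes"]
    index_lookup = kwargs["index_lookup"]
    inputs = node["inputs"]
    attrs = node.get("attrs", {})

    input_nodes = []
    for ip in inputs:
        input_node_id = index_lookup[ip[0]]
        input_node_name = proc_nodes[input_node_id].name
        if 'split' in input_node_name:
            # ip[1] is the index of the split output being consumed
            input_node_name = input_node_name + '_output' + str(ip[1])
        input_nodes.append(input_node_name)

    return name, input_nodes, attrs


# Made-up graph: node 0 is the data input, node 1 is the split layer.
proc_nodes = [FakeNode("data"), FakeNode("split0")]
index_lookup = [0, 1]  # assumed: MXNet node id -> position in proc_nodes

# A consumer that reads the data input plus outputs 1 and 2 of the split.
concat_node = {
    "name": "concat0",
    "inputs": [[0, 0, 0], [1, 1, 0], [1, 2, 0]],
    "attrs": {"dim": "1"},
}

name, input_nodes, attrs = get_inputs(
    concat_node, {"proc_nodes": proc_nodes, "index_lookup": index_lookup})
print(name, input_nodes)
```

Note that the check is a substring match on the node name, so any layer whose name happens to contain `split` would also get the suffix, even if it is not a split operator. Checking the operator type of the producing node instead would likely be more robust, but that is not part of the proposal above.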
