TriLoo opened a new issue #18048: convert mxnet op `split` to onnx `Split` output names error
URL: https://github.com/apache/incubator-mxnet/issues/18048

## Description

If a network contains `mx.nd.split()`, converting it to `onnx` may raise an error: `the input of followed layer is not an output of previous layer's output`.

The cause is that the outputs are renamed while `mx.nd.split(...)` is converted to an `onnx.Split()` node, *i.e.* `'_output' + str(i)` is appended to the layer name of `mx.nd.split()`. Meanwhile, the layers after `mx.nd.split()` that consume its outputs still expect the bare layer name of `mx.nd.split()` as their input. `onnx` then complains that the input is not an output of any previous layer.

## To Reproduce

```python
from mxnet import gluon, nd


class TmpMulScalar(gluon.HybridBlock):
    def __init__(self, **kwargs):
        super(TmpMulScalar, self).__init__(**kwargs)
        with self.name_scope():
            self.conv = gluon.nn.Conv2D(1, 3, 1, 1, use_bias=False)
            self.val = 0.1

    def hybrid_forward(self, F, x):
        r, g, b = F.split(x, axis=1, num_outputs=3)  # split can cause error!!!
        r = g + self.val
        feat = self.conv(r)
        output = F.concat(feat, g, b, dim=1)
        return output


def try_mul_scalar():
    net = TmpMulScalar()
    net.initialize()
    data = nd.random.uniform(0.0, 1.0, (1, 3, 10, 10))
    net.hybridize()
    net(data)  # run one forward pass so the graph can be exported
    net.export('./temp')

    import onnx
    from mxnet.contrib import onnx as mx_onnx

    sym_file = './temp-symbol.json'
    param_file = './temp-0000.params'
    converted_file = mx_onnx.export_model(sym_file, param_file, [(1, 3, 10, 10)], onnx_file_path='./temp.onnx')
    print('converted_file: ', converted_file)

    from onnx import checker

    model_onnx = onnx.load_model(converted_file)
    checker.check_graph(model_onnx.graph)


if __name__ == '__main__':
    try_mul_scalar()
```

### Steps to reproduce

Run the code above to reproduce the error.

## Possible Solutions

1. Add a check in [get_inputs() - op_translation](https://github.com/apache/incubator-mxnet/blob/e3d7866e6854a5c11ab2b2c8bfb63de66f79e132/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py#L129), *i.e.* if the name of an input node contains `split`, change that input name to `proc_nodes[input_node_id].name + '_output' + str(ip[1])`, where `ip[1]` is the output index of `mx.nd.split()`. The complete changed function is shown below:

```python
def get_inputs(node, kwargs):
    """Helper function to get inputs"""
    name = node["name"]
    proc_nodes = kwargs["proc_nodes"]
    index_lookup = kwargs["index_lookup"]
    inputs = node["inputs"]
    attrs = node.get("attrs", {})

    input_nodes = []
    for ip in inputs:
        input_node_id = index_lookup[ip[0]]
        input_node_name = proc_nodes[input_node_id].name
        # Outputs of a split node are exported as '<name>_output<k>',
        # so reference the k-th output instead of the bare node name.
        if 'split' in input_node_name:
            input_node_name = input_node_name + '_output' + str(ip[1])
        # input_nodes.append(proc_nodes[input_node_id].name)
        input_nodes.append(input_node_name)

    return name, input_nodes, attrs
```
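The effect of the proposed change can be checked without MXNet. The following is a minimal sketch that calls the same lookup logic on stand-in data: `FakeNode`, the node names `split0` and `plusscalar0`, and the contents of `proc_nodes` and `index_lookup` are hypothetical and only mimic the shape of what the exporter passes in.

```python
from collections import namedtuple

# Stand-in for the converted nodes kept in kwargs["proc_nodes"];
# only the `.name` attribute is used by get_inputs() (assumption for this sketch).
FakeNode = namedtuple("FakeNode", ["name"])


def get_inputs(node, kwargs):
    """Proposed lookup: append '_output<k>' for inputs fed by a split node."""
    name = node["name"]
    proc_nodes = kwargs["proc_nodes"]
    index_lookup = kwargs["index_lookup"]
    attrs = node.get("attrs", {})

    input_nodes = []
    for ip in node["inputs"]:
        input_node_name = proc_nodes[index_lookup[ip[0]]].name
        if "split" in input_node_name:
            # ip[1] is the index of the producer's output being consumed
            input_node_name += "_output" + str(ip[1])
        input_nodes.append(input_node_name)
    return name, input_nodes, attrs


# Hypothetical graph fragment: node 1 is the split, the consumer reads its output 1.
kwargs = {
    "proc_nodes": [FakeNode("data"), FakeNode("split0")],
    "index_lookup": [0, 1],
}
consumer = {"name": "plusscalar0", "inputs": [[1, 1, 0]]}

name, input_nodes, attrs = get_inputs(consumer, kwargs)
print(name, input_nodes, attrs)  # plusscalar0 ['split0_output1'] {}
```

With the check in place, the consumer references `split0_output1`, which matches the renamed output of the `Split` node. Note that the check matches on the substring `split` in the node name, so it would also rewrite inputs coming from any other layer whose name happens to contain `split`.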
