This is an automated email from the ASF dual-hosted git repository.
thomasdelteil pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new cb627bc Onnx multi output (#13390)
cb627bc is described below
commit cb627bcaccc127a00ab035a2a3006e5cbb6d501d
Author: Sina Afrooze <[email protected]>
AuthorDate: Mon Nov 26 00:00:08 2018 -0800
Onnx multi output (#13390)
* Fix ONNX export to support multi-output graphs
* Add ONNX unit-test
* Added multi-output shape inference.
- Removed unnecessary forward_pass() call
- Modified infer_output_shape (renamed to get_outputs) to return multiple
shapes for multiple outputs, keyed by output name.
* Fixed pylint
---
python/mxnet/contrib/onnx/mx2onnx/export_onnx.py | 128 +++++++--------------
.../python-pytest/onnx/export/mxnet_export_test.py | 76 ++++++++++++
2 files changed, 119 insertions(+), 85 deletions(-)
diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
index b02d970..14c674f 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
@@ -53,12 +53,8 @@ from __future__ import print_function
from __future__ import unicode_literals
import logging
import json
-import numpy as np
-from .... import context
from .... import ndarray as nd
-from .... import io
-from .... import module as mod
class MXNetGraph(object):
@@ -96,60 +92,6 @@ class MXNetGraph(object):
return convert_func(node, **kwargs)
@staticmethod
- def forward_pass(inputs, sym, arg_params, aux_params, output_label):
- """Do a forward pass based on the sym and params to get the shape
- of the output using dummy data
-
- Parameters
- ----------
- inputs : json string
-
- sym : :class:`~mxnet.symbol.Symbol`
- MXNet symbol object
- arg_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
- Dict of converted parameters stored in ``mxnet.ndarray.NDArray``
format
- aux_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
- Dict of converted parameters stored in ``mxnet.ndarray.NDArray``
format
-
- Returns
- -------
- shape : Shape
- Output shape
- """
- # if label is not provided, MXNet adds label "softmax_label" by default
- # while running load_checkpoint which is not actually a graph input.
So ignoring it here
- data_names = [graph_input for graph_input in sym.list_inputs()
- if graph_input not in arg_params and graph_input not in
aux_params
- and graph_input != output_label]
-
- data_shapes = []
- # Adding extra dimension of batch_size 1 if the batch_size is
different for multiple inputs.
- for idx, input_name in enumerate(data_names):
- data_shapes.append((input_name, inputs[idx].shape))
-
- # create module, passing cpu context
- ctx = context.cpu()
- test_mod = mod.Module(symbol=sym, data_names=data_names, context=ctx,
label_names=None)
- test_mod.bind(for_training=False, data_shapes=data_shapes,
label_shapes=None)
-
- # initializing parameters for calculating result of each individual
node
- if arg_params is None and aux_params is None:
- test_mod.init_params()
- else:
- test_mod.set_params(arg_params=arg_params, aux_params=aux_params,
allow_missing=True)
-
- data_forward = []
- for idx, input_name in enumerate(data_names):
- val = inputs[idx]
- data_forward.append(nd.array(val))
-
- test_mod.forward(io.DataBatch(data_forward))
- result = test_mod.get_outputs()[0].asnumpy()
-
- return result.shape
-
-
- @staticmethod
def split_params(sym, params):
"""Helper function to split params dictionary into args and aux params
@@ -177,15 +119,40 @@ class MXNetGraph(object):
aux_params.update({aux: nd.array(params[aux])})
return arg_params, aux_params
-
@staticmethod
- def infer_output_shape(sym, params, in_shape, output_label):
- """Infer output shape by doing a forward pass using dummy inputs """
- # create dummy input
- inputs = [np.random.randn(*input_shape) for input_shape in in_shape]
- arg, aux = MXNetGraph.split_params(sym, params)
- return MXNetGraph.forward_pass(inputs, sym, arg, aux, output_label)
+ def get_outputs(sym, params, in_shape, in_label):
+ """ Infer output shapes and return dictionary of output name to shape
+
+ :param :class:`~mxnet.symbol.Symbol` sym: symbol to perform infer
shape on
+ :param dic of (str, nd.NDArray) params:
+ :param list of tuple(int, ...) in_shape: list of all input shapes
+ :param in_label: name of label typically used in loss that may be
left in graph. This name is
+ removed from list of inputs required by symbol
+ :return: dictionary of output name to shape
+ :rtype: dict of (str, tuple(int, ...))
+ """
+ # remove any input listed in params from sym.list_inputs() and bind
them to the input shapes provided
+ # by user. Also remove in_label, which is the name of the label symbol
that may have been used
+ # as the label for loss during training.
+ inputs = {n: s for n, s in zip([n for n in sym.list_inputs() if n not
in params and n != in_label], in_shape)}
+ # Add params and their shape to list of inputs
+ inputs.update({n: v.shape for n, v in params.items()})
+ # Provide input data as well as input params to infer_shape()
+ _, out_shapes, _ = sym.infer_shape(**inputs)
+
+ out_names = list()
+ for name in sym.list_outputs():
+ if name.endswith('_output'):
+ out_names.append(name[:-len('_output')])
+ else:
+ logging.warning("output '%s' does not end with '_output'",
name)
+ out_names.append(name)
+ assert len(out_shapes) == len(out_names)
+ # bind output shapes with output names
+ graph_outputs = {n: s for n, s in zip(out_names, out_shapes)}
+
+ return graph_outputs
@staticmethod
def convert_weights_to_numpy(weights_dict):
@@ -228,9 +195,6 @@ class MXNetGraph(object):
# Deriving the output_label name.
output_label = sym.get_internals()[len(sym.get_internals()) - 1].name
+ "_label"
- # Determine output shape
- output_shape = MXNetGraph.infer_output_shape(sym, params, in_shape,
output_label)
-
weights = MXNetGraph.convert_weights_to_numpy(params)
mx_graph = json.loads(sym.tojson())["nodes"]
@@ -242,6 +206,9 @@ class MXNetGraph(object):
onnx_processed_outputs = []
index_lookup = []
+ # Determine output shape
+ graph_outputs = MXNetGraph.get_outputs(sym, params, in_shape,
output_label)
+
graph_input_idx = 0
for idx, node in enumerate(mx_graph):
op = node["op"]
@@ -294,24 +261,15 @@ class MXNetGraph(object):
# If converted node is NodeProto, add it in processed
nodes list
elif isinstance(converted_node, NodeProto):
onnx_processed_nodes.append(converted_node)
- if idx == (len(mx_graph) - 1):
- # If converted node doesnt have name, use it from
output field
- if not converted_node.name:
- onnx_processed_outputs.append(
- make_tensor_value_info(
- name=converted_node.output[0],
- elem_type=in_type,
- shape=output_shape
- )
- )
- else:
- onnx_processed_outputs.append(
- make_tensor_value_info(
- name=converted_node.name,
- elem_type=in_type,
- shape=output_shape
- )
+ node_name = converted_node.name if converted_node.name
else converted_node.output[0]
+ if node_name in graph_outputs:
+ onnx_processed_outputs.append(
+ make_tensor_value_info(
+ name=node_name,
+ elem_type=in_type,
+ shape=graph_outputs[node_name]
)
+ )
if verbose:
logging.info("Output node is: %s",
converted_node.name)
elif isinstance(converted_node, TensorProto):
diff --git a/tests/python-pytest/onnx/export/mxnet_export_test.py
b/tests/python-pytest/onnx/export/mxnet_export_test.py
index 9f91369..bbff783 100644
--- a/tests/python-pytest/onnx/export/mxnet_export_test.py
+++ b/tests/python-pytest/onnx/export/mxnet_export_test.py
@@ -28,11 +28,14 @@ import os
import unittest
import logging
import tarfile
+import tempfile
from collections import namedtuple
import numpy as np
import numpy.testing as npt
from onnx import numpy_helper, helper
from onnx import TensorProto
+from mxnet import nd, sym
+from mxnet.gluon import nn
from mxnet.test_utils import download
from mxnet.contrib import onnx as onnx_mxnet
import mxnet as mx
@@ -238,6 +241,79 @@ def test_square():
npt.assert_almost_equal(result, numpy_op)
+
+def _assert_sym_equal(lhs, rhs):
+ assert lhs.list_inputs() == rhs.list_inputs() # input names must be
identical
+ assert len(lhs.list_outputs()) == len(rhs.list_outputs()) # number of
outputs must be identical
+
+
+def _force_list(output):
+ if isinstance(output, nd.NDArray):
+ return [output]
+ return list(output)
+
+
+def _optional_group(symbols, group=False):
+ if group:
+ return sym.Group(symbols)
+ else:
+ return symbols
+
+
+def _check_onnx_export(net, group_outputs=False):
+ net.initialize()
+ data = nd.random.uniform(0, 1, (1, 1024))
+ output = _force_list(net(data)) # initialize weights
+ net_sym = _optional_group(net(sym.Variable('data')), group_outputs)
+ net_params = {name:param._reduce() for name, param in
net.collect_params().items()}
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ onnx_file_path = os.path.join(tmpdirname, 'net.onnx')
+ export_path = onnx_mxnet.export_model(
+ sym=net_sym,
+ params=net_params,
+ input_shape=[data.shape],
+ onnx_file_path=onnx_file_path)
+ assert export_path == onnx_file_path
+ # Try importing the model to symbol
+ _assert_sym_equal(net_sym, onnx_mxnet.import_model(export_path)[0])
+
+ # Try importing the model to gluon
+ imported_net = onnx_mxnet.import_to_gluon(export_path, ctx=None)
+ _assert_sym_equal(net_sym,
_optional_group(imported_net(sym.Variable('data')), group_outputs))
+
+ # Confirm network outputs are the same
+ imported_net_output = _force_list(imported_net(data))
+ for out, imp_out in zip(output, imported_net_output):
+ mx.test_utils.assert_almost_equal(out.asnumpy(), imp_out.asnumpy())
+
+
+@with_seed()
+def test_onnx_export_single_output():
+ net = nn.HybridSequential(prefix='single_output_net')
+ with net.name_scope():
+ net.add(nn.Dense(100, activation='relu'), nn.Dense(10))
+ _check_onnx_export(net)
+
+
+@with_seed()
+def test_onnx_export_multi_output():
+ class MultiOutputBlock(nn.HybridBlock):
+ def __init__(self):
+ super(MultiOutputBlock, self).__init__()
+ with self.name_scope():
+ self.net = nn.HybridSequential()
+ for i in range(10):
+ self.net.add(nn.Dense(100 + i * 10, activation='relu'))
+
+ def hybrid_forward(self, F, x):
+ out = tuple(block(x) for block in self.net._children.values())
+ return out
+
+ net = MultiOutputBlock()
+ assert len(sym.Group(net(sym.Variable('data'))).list_outputs()) == 10
+ _check_onnx_export(net, group_outputs=True)
+
+
if __name__ == '__main__':
test_models("bvlc_googlenet", (1, 3, 224, 224), (1, 1000))
test_models("bvlc_reference_caffenet", (1, 3, 224, 224), (1, 1000))