This is an automated email from the ASF dual-hosted git repository.
nswamy pushed a commit to branch v1.4.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.4.x by this push:
new 2d08816 Add resiliency to onnx export code (#13426) (#13567)
2d08816 is described below
commit 2d08816c393e4172e89aa265493fb2a40111d39f
Author: Sina Afrooze <[email protected]>
AuthorDate: Thu Dec 6 18:06:36 2018 -0800
Add resiliency to onnx export code (#13426) (#13567)
* Added resiliency to onnx export code
- With the previous infer-shape implementation, the export code still worked
if the input shape was given as a list instead of a tuple, or if extra
parameters that do not exist in the symbol were provided. The fixes in this
commit restore that behavior to prevent compatibility issues with existing
export code.
* Fixed name of net in unittest
* Fix pylint
---
python/mxnet/contrib/onnx/mx2onnx/export_onnx.py | 5 +++--
.../python-pytest/onnx/export/mxnet_export_test.py | 21 +++++++++++++++++++--
2 files changed, 22 insertions(+), 4 deletions(-)
diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
index 14c674f..84db5de 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
@@ -134,9 +134,10 @@ class MXNetGraph(object):
# remove any input listed in params from sym.list_inputs() and bind
them to the input shapes provided
# by user. Also remove in_label, which is the name of the label symbol
that may have been used
# as the label for loss during training.
- inputs = {n: s for n, s in zip([n for n in sym.list_inputs() if n not
in params and n != in_label], in_shape)}
+ inputs = {n: tuple(s) for n, s in zip([n for n in sym.list_inputs() if
n not in params and n != in_label],
+ in_shape)}
# Add params and their shape to list of inputs
- inputs.update({n: v.shape for n, v in params.items()})
+ inputs.update({n: v.shape for n, v in params.items() if n in
sym.list_inputs()})
# Provide input data as well as input params to infer_shape()
_, out_shapes, _ = sym.infer_shape(**inputs)
diff --git a/tests/python-pytest/onnx/export/mxnet_export_test.py
b/tests/python-pytest/onnx/export/mxnet_export_test.py
index f4144fd6..964d0e7 100644
--- a/tests/python-pytest/onnx/export/mxnet_export_test.py
+++ b/tests/python-pytest/onnx/export/mxnet_export_test.py
@@ -286,18 +286,19 @@ def _optional_group(symbols, group=False):
return symbols
-def _check_onnx_export(net, group_outputs=False):
+def _check_onnx_export(net, group_outputs=False, shape_type=tuple,
extra_params={}):
net.initialize()
data = nd.random.uniform(0, 1, (1, 1024))
output = _force_list(net(data)) # initialize weights
net_sym = _optional_group(net(sym.Variable('data')), group_outputs)
net_params = {name:param._reduce() for name, param in
net.collect_params().items()}
+ net_params.update(extra_params)
with tempfile.TemporaryDirectory() as tmpdirname:
onnx_file_path = os.path.join(tmpdirname, 'net.onnx')
export_path = onnx_mxnet.export_model(
sym=net_sym,
params=net_params,
- input_shape=[data.shape],
+ input_shape=[shape_type(data.shape)],
onnx_file_path=onnx_file_path)
assert export_path == onnx_file_path
# Try importing the model to symbol
@@ -340,6 +341,22 @@ def test_onnx_export_multi_output():
_check_onnx_export(net, group_outputs=True)
+@with_seed()
+def test_onnx_export_list_shape():
+ net = nn.HybridSequential(prefix='list_shape_net')
+ with net.name_scope():
+ net.add(nn.Dense(100, activation='relu'), nn.Dense(10))
+ _check_onnx_export(net, shape_type=list)
+
+
+@with_seed()
+def test_onnx_export_extra_params():
+ net = nn.HybridSequential(prefix='extra_params_net')
+ with net.name_scope():
+ net.add(nn.Dense(100, activation='relu'), nn.Dense(10))
+ _check_onnx_export(net, extra_params={'extra_param': nd.array([1, 2])})
+
+
if __name__ == '__main__':
test_models("bvlc_googlenet", (1, 3, 224, 224), (1, 1000))
test_models("bvlc_reference_caffenet", (1, 3, 224, 224), (1, 1000))