This is an automated email from the ASF dual-hosted git repository.
thomasdelteil pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new a4f2ed5 ONNX export: Add Flatten before Gemm (#13356)
a4f2ed5 is described below
commit a4f2ed5675a4852ac227f477667c90c32bb293e2
Author: Vandana Kannan <[email protected]>
AuthorDate: Thu Dec 20 13:17:05 2018 -0800
ONNX export: Add Flatten before Gemm (#13356)
* Add Flatten before Gemm
* ONNX export test: Allow multiple inputs in forward pass
* ONNX export: Test for fully connected
---
.../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 11 +++
.../python-pytest/onnx/export/mxnet_export_test.py | 78 ++++++++++++++++++++--
2 files changed, 82 insertions(+), 7 deletions(-)
diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 0d20c76..15624b6 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -232,6 +232,17 @@ def convert_fully_connected(node, **kwargs):
fcnode = []
+ op_name = "flatten_" + str(kwargs["idx"])
+ flatten_node = onnx.helper.make_node(
+ 'Flatten',
+ inputs=[input_nodes[0]],
+ outputs=[op_name],
+ name=op_name
+ )
+
+ input_nodes[0] = op_name
+ fcnode.append(flatten_node)
+
if no_bias:
data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('int64')]
bias_name = "bias" + str(kwargs["idx"])
diff --git a/tests/python-pytest/onnx/export/mxnet_export_test.py b/tests/python-pytest/onnx/export/mxnet_export_test.py
index 22db0d6..b4fa4b1 100644
--- a/tests/python-pytest/onnx/export/mxnet_export_test.py
+++ b/tests/python-pytest/onnx/export/mxnet_export_test.py
@@ -94,15 +94,33 @@ def get_test_files(name):
def forward_pass(sym, arg, aux, data_names, input_data):
- """ Perform forward pass on given data"""
+ """ Perform forward pass on given data
+ :param sym: Symbol
+ :param arg: Arg params
+ :param aux: Aux params
+ :param data_names: Input names (list)
+ :param input_data: Input data (list). If there is only one input,
+ pass it as a list. For example, if input is [1, 2],
+ pass input_data=[[1, 2]]
+ :return: result of forward pass
+ """
# create module
mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(),
label_names=None)
- mod.bind(for_training=False, data_shapes=[(data_names[0],
input_data.shape)], label_shapes=None)
+
+ data_shapes = []
+ data_forward = []
+ for idx in range(len(data_names)):
+ val = input_data[idx]
+ data_shapes.append((data_names[idx], np.shape(val)))
+ data_forward.append(mx.nd.array(val))
+
+ mod.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)
mod.set_params(arg_params=arg, aux_params=aux,
allow_missing=True, allow_extra=True)
+
# run inference
batch = namedtuple('Batch', ['data'])
- mod.forward(batch([mx.nd.array(input_data)]), is_train=False)
+ mod.forward(batch(data_forward), is_train=False)
return mod.get_outputs()[0].asnumpy()
@@ -136,7 +154,7 @@ def test_models(model_name, input_shape, output_shape):
logging.info("Running inference on onnx re-import model in mxnet")
# run test for each test file
for input_data, output_data in zip(inputs, outputs):
- result = forward_pass(sym, arg_params, aux_params, data_names,
input_data)
+ result = forward_pass(sym, arg_params, aux_params, data_names,
[input_data])
# verify the results
npt.assert_equal(result.shape, output_data.shape)
@@ -156,7 +174,7 @@ def test_model_accuracy(model_name, input_shape):
expected_result= []
for input_data, output_data in zip(inputs, outputs):
- result = forward_pass(sym, arg_params, aux_params, data_names,
input_data)
+ result = forward_pass(sym, arg_params, aux_params, data_names,
[input_data])
expected_result.append(result)
params = {}
@@ -178,7 +196,7 @@ def test_model_accuracy(model_name, input_shape):
actual_result = []
for input_data, output_data in zip(inputs, outputs):
- result = forward_pass(sym, arg_params, aux_params, data_names,
input_data)
+ result = forward_pass(sym, arg_params, aux_params, data_names,
[input_data])
actual_result.append(result)
# verify the results
@@ -235,13 +253,59 @@ def test_square():
converted_model = onnx_mxnet.export_model(square, params,
[np.shape(input1)], np.float32, "square.onnx")
sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model)
- result = forward_pass(sym, arg_params, aux_params, ['input1'], input1)
+ result = forward_pass(sym, arg_params, aux_params, ['input1'], [input1])
numpy_op = np.square(input1)
npt.assert_almost_equal(result, numpy_op)
+@with_seed()
+def test_fully_connected():
+ def random_arrays(*shapes):
+ """Generate some random numpy arrays."""
+ arrays = [np.random.randn(*s).astype("float32")
+ for s in shapes]
+ if len(arrays) == 1:
+ return arrays[0]
+ return arrays
+
+ data_names = ['x', 'w', 'b']
+
+ dim_in, dim_out = (3, 4)
+ input_data = random_arrays((4, dim_in), (dim_out, dim_in), (dim_out,))
+
+ ipsym = []
+ data_shapes = []
+ data_forward = []
+ for idx in range(len(data_names)):
+ val = input_data[idx]
+ data_shapes.append((data_names[idx], np.shape(val)))
+ data_forward.append(mx.nd.array(val))
+ ipsym.append(mx.sym.Variable(data_names[idx]))
+
+ op = mx.sym.FullyConnected(data=ipsym[0], weight=ipsym[1], bias=ipsym[2],
num_hidden=dim_out, name='FC')
+
+ model = mx.mod.Module(op, data_names=data_names, label_names=None)
+ model.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)
+
+ model.init_params()
+
+ args, auxs = model.get_params()
+ params = {}
+ params.update(args)
+ params.update(auxs)
+
+ converted_model = onnx_mxnet.export_model(op, params, [shape[1] for shape
in data_shapes], np.float32, "fc.onnx")
+
+ sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model)
+ result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
+
+ numpy_op = np.dot(input_data[0], input_data[1].T) + input_data[2]
+
+ npt.assert_almost_equal(result, numpy_op)
+
+
def test_softmax():
input1 = np.random.rand(1000, 1000).astype("float32")
label1 = np.random.rand(1000)