This is an automated email from the ASF dual-hosted git repository.

thomasdelteil pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new a4f2ed5  ONNX export: Add Flatten before Gemm (#13356)
a4f2ed5 is described below

commit a4f2ed5675a4852ac227f477667c90c32bb293e2
Author: Vandana Kannan <[email protected]>
AuthorDate: Thu Dec 20 13:17:05 2018 -0800

    ONNX export: Add Flatten before Gemm (#13356)
    
    * Add Flatten before Gemm
    
    * ONNX export test: Allow multiple inputs in forward pass
    
    * ONNX export: Test for fully connected
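
    Why the Flatten: ONNX's Gemm operator requires rank-2 inputs, while
    MXNet's FullyConnected implicitly flattens any trailing dimensions of
    its input, so the exporter now emits a Flatten node whose output feeds
    the Gemm. Below is a minimal sketch of the resulting node pair, with
    illustrative tensor names (the exporter derives the real ones from
    kwargs["idx"]) and only the shape-relevant Gemm attribute shown:

        from onnx import helper

        # Flatten collapses the input to rank 2 (default axis=1), matching
        # FullyConnected's implicit flattening of trailing dimensions.
        flatten = helper.make_node(
            'Flatten',
            inputs=['data'],
            outputs=['flatten_0'],
            name='flatten_0'
        )
        # Gemm consumes the flattened tensor; transB=1 accounts for MXNet
        # storing weights as (num_hidden, dim_in).
        gemm = helper.make_node(
            'Gemm',
            inputs=['flatten_0', 'weight', 'bias'],
            outputs=['fc_0'],
            name='fc_0',
            transB=1
        )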
---
 .../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 11 +++
 .../python-pytest/onnx/export/mxnet_export_test.py | 78 ++++++++++++++++++++--
 2 files changed, 82 insertions(+), 7 deletions(-)

diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 0d20c76..15624b6 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -232,6 +232,17 @@ def convert_fully_connected(node, **kwargs):
 
     fcnode = []
 
+    op_name = "flatten_" + str(kwargs["idx"])
+    flatten_node = onnx.helper.make_node(
+        'Flatten',
+        inputs=[input_nodes[0]],
+        outputs=[op_name],
+        name=op_name
+    )
+
+    input_nodes[0] = op_name
+    fcnode.append(flatten_node)
+
     if no_bias:
         data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('int64')]
         bias_name = "bias" + str(kwargs["idx"])
diff --git a/tests/python-pytest/onnx/export/mxnet_export_test.py b/tests/python-pytest/onnx/export/mxnet_export_test.py
index 22db0d6..b4fa4b1 100644
--- a/tests/python-pytest/onnx/export/mxnet_export_test.py
+++ b/tests/python-pytest/onnx/export/mxnet_export_test.py
@@ -94,15 +94,33 @@ def get_test_files(name):
 
 
 def forward_pass(sym, arg, aux, data_names, input_data):
-    """ Perform forward pass on given data"""
+    """ Perform forward pass on given data
+    :param sym: Symbol
+    :param arg: Arg params
+    :param aux: Aux params
+    :param data_names: Input names (list)
+    :param input_data: Input data (list). If there is only one input,
+                        pass it as a list. For example, if input is [1, 2],
+                        pass input_data=[[1, 2]]
+    :return: result of forward pass
+    """
     # create module
     mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
-    mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None)
+
+    data_shapes = []
+    data_forward = []
+    for idx in range(len(data_names)):
+        val = input_data[idx]
+        data_shapes.append((data_names[idx], np.shape(val)))
+        data_forward.append(mx.nd.array(val))
+
+    mod.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)
     mod.set_params(arg_params=arg, aux_params=aux,
                    allow_missing=True, allow_extra=True)
+
     # run inference
     batch = namedtuple('Batch', ['data'])
-    mod.forward(batch([mx.nd.array(input_data)]), is_train=False)
+    mod.forward(batch(data_forward), is_train=False)
 
     return mod.get_outputs()[0].asnumpy()
 
@@ -136,7 +154,7 @@ def test_models(model_name, input_shape, output_shape):
     logging.info("Running inference on onnx re-import model in mxnet")
     # run test for each test file
     for input_data, output_data in zip(inputs, outputs):
-        result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
+        result = forward_pass(sym, arg_params, aux_params, data_names, [input_data])
 
         # verify the results
         npt.assert_equal(result.shape, output_data.shape)
@@ -156,7 +174,7 @@ def test_model_accuracy(model_name, input_shape):
 
     expected_result= []
     for input_data, output_data in zip(inputs, outputs):
-        result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
+        result = forward_pass(sym, arg_params, aux_params, data_names, [input_data])
         expected_result.append(result)
 
     params = {}
@@ -178,7 +196,7 @@ def test_model_accuracy(model_name, input_shape):
 
     actual_result = []
     for input_data, output_data in zip(inputs, outputs):
-        result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
+        result = forward_pass(sym, arg_params, aux_params, data_names, [input_data])
         actual_result.append(result)
 
     # verify the results
@@ -235,13 +253,59 @@ def test_square():
     converted_model = onnx_mxnet.export_model(square, params, [np.shape(input1)], np.float32, "square.onnx")
 
     sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model)
-    result = forward_pass(sym, arg_params, aux_params, ['input1'], input1)
+    result = forward_pass(sym, arg_params, aux_params, ['input1'], [input1])
 
     numpy_op = np.square(input1)
 
     npt.assert_almost_equal(result, numpy_op)
 
 
+@with_seed()
+def test_fully_connected():
+    def random_arrays(*shapes):
+        """Generate some random numpy arrays."""
+        arrays = [np.random.randn(*s).astype("float32")
+                  for s in shapes]
+        if len(arrays) == 1:
+            return arrays[0]
+        return arrays
+
+    data_names = ['x', 'w', 'b']
+
+    dim_in, dim_out = (3, 4)
+    input_data = random_arrays((4, dim_in), (dim_out, dim_in), (dim_out,))
+
+    ipsym = []
+    data_shapes = []
+    data_forward = []
+    for idx in range(len(data_names)):
+        val = input_data[idx]
+        data_shapes.append((data_names[idx], np.shape(val)))
+        data_forward.append(mx.nd.array(val))
+        ipsym.append(mx.sym.Variable(data_names[idx]))
+
+    op = mx.sym.FullyConnected(data=ipsym[0], weight=ipsym[1], bias=ipsym[2], num_hidden=dim_out, name='FC')
+
+    model = mx.mod.Module(op, data_names=data_names, label_names=None)
+    model.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)
+
+    model.init_params()
+
+    args, auxs = model.get_params()
+    params = {}
+    params.update(args)
+    params.update(auxs)
+
+    converted_model = onnx_mxnet.export_model(op, params, [shape[1] for shape in data_shapes], np.float32, "fc.onnx")
+
+    sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model)
+    result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
+
+    numpy_op = np.dot(input_data[0], input_data[1].T) + input_data[2]
+
+    npt.assert_almost_equal(result, numpy_op)
+
+
 def test_softmax():
     input1 = np.random.rand(1000, 1000).astype("float32")
     label1 = np.random.rand(1000)

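For reference, the reworked forward_pass helper now takes its inputs as a list
even in the single-input case, matched positionally to data_names. A minimal
usage sketch, assuming sym, arg_params and aux_params come from an earlier
onnx_mxnet.import_model call:

    import numpy as np

    # Single input: still wrapped in a one-element list.
    out = forward_pass(sym, arg_params, aux_params, ['input1'],
                       [np.random.rand(2, 3).astype('float32')])

    # Multiple inputs, one array per name, as in test_fully_connected.
    x = np.random.randn(4, 3).astype('float32')  # data, (batch, dim_in)
    w = np.random.randn(4, 3).astype('float32')  # weight, (dim_out, dim_in)
    b = np.random.randn(4).astype('float32')     # bias, (dim_out,)
    out = forward_pass(sym, arg_params, aux_params, ['x', 'w', 'b'], [x, w, b])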