ThomasDelteil closed pull request #13356: ONNX export: Add Flatten before Gemm
URL: https://github.com/apache/incubator-mxnet/pull/13356
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this is a foreign pull request (from a fork), GitHub does not display
its diff after the merge, so the diff is supplied below:

diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 0d20c76240b..15624b6c3a2 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -232,6 +232,17 @@ def convert_fully_connected(node, **kwargs):
 
     fcnode = []
 
+    op_name = "flatten_" + str(kwargs["idx"])
+    flatten_node = onnx.helper.make_node(
+        'Flatten',
+        inputs=[input_nodes[0]],
+        outputs=[op_name],
+        name=op_name
+    )
+
+    input_nodes[0] = op_name
+    fcnode.append(flatten_node)
+
     if no_bias:
         data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('int64')]
         bias_name = "bias" + str(kwargs["idx"])
diff --git a/tests/python-pytest/onnx/export/mxnet_export_test.py b/tests/python-pytest/onnx/export/mxnet_export_test.py
index 22db0d637a3..b4fa4b12c78 100644
--- a/tests/python-pytest/onnx/export/mxnet_export_test.py
+++ b/tests/python-pytest/onnx/export/mxnet_export_test.py
@@ -94,15 +94,33 @@ def get_test_files(name):
 
 
 def forward_pass(sym, arg, aux, data_names, input_data):
-    """ Perform forward pass on given data"""
+    """ Perform forward pass on given data
+    :param sym: Symbol
+    :param arg: Arg params
+    :param aux: Aux params
+    :param data_names: Input names (list)
+    :param input_data: Input data (list). If there is only one input,
+                        pass it as a list. For example, if input is [1, 2],
+                        pass input_data=[[1, 2]]
+    :return: result of forward pass
+    """
     # create module
     mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
-    mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None)
+
+    data_shapes = []
+    data_forward = []
+    for idx in range(len(data_names)):
+        val = input_data[idx]
+        data_shapes.append((data_names[idx], np.shape(val)))
+        data_forward.append(mx.nd.array(val))
+
+    mod.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)
     mod.set_params(arg_params=arg, aux_params=aux,
                    allow_missing=True, allow_extra=True)
+
     # run inference
     batch = namedtuple('Batch', ['data'])
-    mod.forward(batch([mx.nd.array(input_data)]), is_train=False)
+    mod.forward(batch(data_forward), is_train=False)
 
     return mod.get_outputs()[0].asnumpy()
 
@@ -136,7 +154,7 @@ def test_models(model_name, input_shape, output_shape):
     logging.info("Running inference on onnx re-import model in mxnet")
     # run test for each test file
     for input_data, output_data in zip(inputs, outputs):
-        result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
+        result = forward_pass(sym, arg_params, aux_params, data_names, [input_data])
 
         # verify the results
         npt.assert_equal(result.shape, output_data.shape)
@@ -156,7 +174,7 @@ def test_model_accuracy(model_name, input_shape):
 
     expected_result= []
     for input_data, output_data in zip(inputs, outputs):
-        result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
+        result = forward_pass(sym, arg_params, aux_params, data_names, [input_data])
         expected_result.append(result)
 
     params = {}
@@ -178,7 +196,7 @@ def test_model_accuracy(model_name, input_shape):
 
     actual_result = []
     for input_data, output_data in zip(inputs, outputs):
-        result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
+        result = forward_pass(sym, arg_params, aux_params, data_names, [input_data])
         actual_result.append(result)
 
     # verify the results
@@ -235,13 +253,59 @@ def test_square():
     converted_model = onnx_mxnet.export_model(square, params, [np.shape(input1)], np.float32, "square.onnx")
 
     sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model)
-    result = forward_pass(sym, arg_params, aux_params, ['input1'], input1)
+    result = forward_pass(sym, arg_params, aux_params, ['input1'], [input1])
 
     numpy_op = np.square(input1)
 
     npt.assert_almost_equal(result, numpy_op)
 
 
+@with_seed()
+def test_fully_connected():
+    def random_arrays(*shapes):
+        """Generate some random numpy arrays."""
+        arrays = [np.random.randn(*s).astype("float32")
+                  for s in shapes]
+        if len(arrays) == 1:
+            return arrays[0]
+        return arrays
+
+    data_names = ['x', 'w', 'b']
+
+    dim_in, dim_out = (3, 4)
+    input_data = random_arrays((4, dim_in), (dim_out, dim_in), (dim_out,))
+
+    ipsym = []
+    data_shapes = []
+    data_forward = []
+    for idx in range(len(data_names)):
+        val = input_data[idx]
+        data_shapes.append((data_names[idx], np.shape(val)))
+        data_forward.append(mx.nd.array(val))
+        ipsym.append(mx.sym.Variable(data_names[idx]))
+
+    op = mx.sym.FullyConnected(data=ipsym[0], weight=ipsym[1], bias=ipsym[2], num_hidden=dim_out, name='FC')
+
+    model = mx.mod.Module(op, data_names=data_names, label_names=None)
+    model.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)
+
+    model.init_params()
+
+    args, auxs = model.get_params()
+    params = {}
+    params.update(args)
+    params.update(auxs)
+
+    converted_model = onnx_mxnet.export_model(op, params, [shape[1] for shape in data_shapes], np.float32, "fc.onnx")
+
+    sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model)
+    result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
+
+    numpy_op = np.dot(input_data[0], input_data[1].T) + input_data[2]
+
+    npt.assert_almost_equal(result, numpy_op)
+
+
 def test_softmax():
     input1 = np.random.rand(1000, 1000).astype("float32")
     label1 = np.random.rand(1000)


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to