vandanavk closed pull request #13627: [WIP] [MXNET-895] ONNX import/export: TopK
URL: https://github.com/apache/incubator-mxnet/pull/13627
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would not otherwise show it:

diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 0d20c76240b..fd3a4338b0d 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -1655,3 +1655,40 @@ def convert_size(node, **kwargs):
     and return the created node.
     """
     return create_basic_op_node('Size', node, kwargs)
+
+
+@mx_op.register("topk")
+def convert_topk(node, **kwargs):
+    """Map MXNet's topk operator attributes to onnx's TopK operator
+    and return the created TopK and Cast nodes.
+    """
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
+    axis = int(attrs.get('axis', '-1'))
+    k = int(attrs.get('k', '1'))
+    ret_type = attrs.get('ret_typ')
+    outputs = [name + '_output0']
+
+    if ret_type == 'both':
+        outputs.append(name + '_output1')
+    else:
+        raise NotImplementedError("ONNX expects both value and indices as output")
+
+    topk_node = onnx.helper.make_node(
+        "TopK",
+        input_nodes,
+        [outputs[0], 'cast_' + outputs[1]],
+        axis=axis,
+        k=k,
+        name=name
+    )
+
+    cast_node = onnx.helper.make_node(
+        "Cast",
+        ['cast_' + outputs[1]],
+        [outputs[1]],
+        to=onnx.TensorProto.INT64,
+        name=outputs[1]
+    )
+
+    return [topk_node, cast_node]
diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
index 84db5decd50..e494c137ad2 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
@@ -262,7 +262,7 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
                     # If converted node is NodeProto, add it in processed nodes list
                     elif isinstance(converted_node, NodeProto):
                         onnx_processed_nodes.append(converted_node)
-                        node_name = converted_node.name if converted_node.name else converted_node.output[0]
+                        node_name = converted_node.output[0]
                         if node_name in graph_outputs:
                             onnx_processed_outputs.append(
                                 make_tensor_value_info(
diff --git a/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
index 2ceabaec1dc..10e0c02e0ae 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
@@ -37,7 +37,7 @@
 from ._op_translations import reduce_sum_square, reduce_l1, reduce_l2, max_roi_pooling
 from ._op_translations import log_softmax, softsign, lesser, greater, equal
 from ._op_translations import logical_and, logical_or, logical_xor, logical_not
-from ._op_translations import mean, depthtospace, spacetodepth
+from ._op_translations import mean, depthtospace, spacetodepth, topk
 
 # convert_map defines maps of ONNX operator names to converter functor(callable)
 # defined in the op_translations module.
@@ -144,5 +144,6 @@
     'HardSigmoid'       : hardsigmoid,
     'LpPool'            : lp_pooling,
     'DepthToSpace'      : depthtospace,
-    'SpaceToDepth'      : spacetodepth
+    'SpaceToDepth'      : spacetodepth,
+    'TopK'              : topk,
 }
diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
index 70283252931..50ec51c8167 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
@@ -714,3 +714,10 @@ def spacetodepth(attrs, inputs, proto_obj):
     new_attrs = translation_utils._fix_attribute_names(attrs, {'blocksize':'block_size'})
 
     return "space_to_depth", new_attrs, inputs
+
+
+def topk(attrs, inputs, proto_obj):
+    """Returns the top k elements in an input array along the given axis."""
+    new_attrs = translation_utils._add_extra_attributes(attrs,
+                                                        {'ret_typ': 'both'})
+    return 'topk', new_attrs, inputs
diff --git a/tests/python-pytest/onnx/backend_rep.py b/tests/python-pytest/onnx/backend_rep.py
index 63836ac848d..9b0660d4529 100644
--- a/tests/python-pytest/onnx/backend_rep.py
+++ b/tests/python-pytest/onnx/backend_rep.py
@@ -80,5 +80,7 @@ def run(self, inputs, **kwargs):
         args = dict(zip(data_names, data_forward))
         exe = self.symbol.bind(ctx, args=args, aux_states=self.aux_params)
         exe.forward(is_train=False)
-        result = exe.outputs[0].asnumpy()
-        return [result]
+        result = []
+        for output in exe.outputs:
+            result.append(output.asnumpy())
+        return result
diff --git a/tests/python-pytest/onnx/export/mxnet_export_test.py b/tests/python-pytest/onnx/export/mxnet_export_test.py
index 22db0d637a3..ec12db78510 100644
--- a/tests/python-pytest/onnx/export/mxnet_export_test.py
+++ b/tests/python-pytest/onnx/export/mxnet_export_test.py
@@ -98,13 +98,19 @@ def forward_pass(sym, arg, aux, data_names, input_data):
     # create module
     mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
     mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None)
-    mod.set_params(arg_params=arg, aux_params=aux,
-                   allow_missing=True, allow_extra=True)
+    if not arg and not aux:
+        mod.init_params()
+    else:
+        mod.set_params(arg_params=arg, aux_params=aux,
+                       allow_missing=True, allow_extra=True)
     # run inference
     batch = namedtuple('Batch', ['data'])
     mod.forward(batch([mx.nd.array(input_data)]), is_train=False)
 
-    return mod.get_outputs()[0].asnumpy()
+    result = []
+    for output in mod.get_outputs():
+        result.append(output.asnumpy())
+    return result
 
 
 def test_models(model_name, input_shape, output_shape):
@@ -139,8 +145,8 @@ def test_models(model_name, input_shape, output_shape):
         result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
 
         # verify the results
-        npt.assert_equal(result.shape, output_data.shape)
-        npt.assert_almost_equal(output_data, result, decimal=3)
+        npt.assert_equal(result[0].shape, output_data.shape)
+        npt.assert_almost_equal(output_data, result[0], decimal=3)
     logging.info(model_name + " conversion successful")
 
 
@@ -157,7 +163,7 @@ def test_model_accuracy(model_name, input_shape):
     expected_result= []
     for input_data, output_data in zip(inputs, outputs):
         result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
-        expected_result.append(result)
+        expected_result.append(result[0])
 
     params = {}
     params.update(arg_params)
@@ -179,7 +185,7 @@ def test_model_accuracy(model_name, input_shape):
     actual_result = []
     for input_data, output_data in zip(inputs, outputs):
         result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
-        actual_result.append(result)
+        actual_result.append(result[0])
 
     # verify the results
     for expected, actual in zip(expected_result, actual_result):
@@ -239,7 +245,7 @@ def test_square():
 
     numpy_op = np.square(input1)
 
-    npt.assert_almost_equal(result, numpy_op)
+    npt.assert_almost_equal(result[0], numpy_op)
 
 
 def test_softmax():
@@ -261,7 +267,35 @@ def test_softmax():
     result = forward_pass(sym, arg_params, aux_params, ['ipsym'], input1)
 
     # Comparing result of forward pass before using onnx export, import
-    npt.assert_almost_equal(result, softmax_out)
+    npt.assert_almost_equal(result[0], softmax_out)
+
+
+@with_seed()
+def test_topk():
+    input1 = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=np.float32)
+    k = 3
+    dtype = 'int32'
+    inputs = [helper.make_tensor_value_info("input1", TensorProto.FLOAT, shape=np.shape(input1))]
+    sym = mx.sym.topk(mx.sym.Variable('input1'), k=k, ret_typ='both', dtype=dtype)
+    sym_output = forward_pass(sym, None, None, ['input1'], input1)
+
+    outputs = [helper.make_tensor_value_info("output1", TensorProto.FLOAT, shape=np.shape(sym_output[0])),
+               helper.make_tensor_value_info("output2", TensorProto.INT64, shape=np.shape(sym_output[1]))]
+
+    nodes = [helper.make_node("TopK", ["input1"], ["output1", "output2"], k=k)]
+
+    graph = helper.make_graph(nodes,
+                              "topk_test",
+                              inputs,
+                              outputs)
+
+    topk_model = helper.make_model(graph)
+
+    bkd_rep = backend.prepare(topk_model)
+    output = bkd_rep.run([input1])
+
+    npt.assert_almost_equal(output, sym_output)
+
 
 @with_seed()
 def test_comparison_ops():


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to