This is an automated email from the ASF dual-hosted git repository.

zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new aa00f4b  [v1.x] ONNX Support for MXNet repeat op (#19732)
aa00f4b is described below

commit aa00f4b300cdabcb6a99ee6e80727c35369c523d
Author: Zhaoqi Zhu <[email protected]>
AuthorDate: Mon Jan 11 14:37:53 2021 -0800

    [v1.x] ONNX Support for MXNet repeat op (#19732)
    
    * repeat op
    
    * remove extra print
    
    * restore sanity
    
    * Update test_operators.py
    
    * fix axis=1 case
    
    * Update _op_translations.py
    
    * Update test_operators.py
---
 .../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 81 +++++++++++++++++++++-
 tests/python-pytest/onnx/test_operators.py         | 10 ++-
 2 files changed, 89 insertions(+), 2 deletions(-)

diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index bc4b414..57ef546 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -2778,7 +2778,86 @@ def convert_arange(node, **kwargs):
         create_const_scalar_node(name+"_start", np.array([start], 
dtype=dtype), kwargs),
         create_const_scalar_node(name+"_stop", np.array([stop], dtype=dtype), 
kwargs),
         create_const_scalar_node(name+"_step", np.array([step], dtype=dtype), 
kwargs),
-        make_node("Range", [name+"_start", name+"_stop", name+"_step"], [name])
+        make_node("Range", [name+"_start", name+"_stop", name+"_step"], 
[name], name=name)
     ]
 
     return nodes
+
+
+@mx_op.register('repeat')
+def convert_repeat(node, **kwargs):
+    """Map MXNet's repeat operator attributes to onnx's Tile operator.
+    """
+    from onnx.helper import make_node
+    from onnx import TensorProto
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
+    opset_version = kwargs['opset_version']
+    if opset_version < 11:
+        raise AttributeError('ONNX opset 11 or greater is required to export 
this operator')
+
+    repeats = int(attrs.get('repeats', 1))
+    axis = attrs.get('axis', 'None')
+
+    if repeats <= 0:
+        raise NotImplementedError('repeat operator does not support parameter 
repeats==0')
+
+    nodes = []
+    if axis == 'None':
+        nodes += [
+            create_tensor([repeats], name+'_rep', kwargs['initializer']),
+            create_tensor([1, repeats], name+'_repeats', 
kwargs['initializer']),
+            make_node('Shape', [input_nodes[0]], [name+'_shape']),
+            make_node('ReduceProd', [name+'_shape'], [name+'_size']),
+            make_node('Reshape', [input_nodes[0], name+'_size'], 
[name+'_flat']),
+            make_node('Unsqueeze', [name+'_flat'], [name+'_unsqueeze'], 
axes=[-1]),
+            make_node('Tile', [name+'_unsqueeze', name+'_repeats'], 
[name+'_tile']),
+            make_node('Mul', [name+'_size', name+'_rep'], [name+'_new_size']),
+            make_node('Reshape', [name+'_tile', name+'_new_size'], [name], 
name=name)
+            ]
+    else:
+        axis = int(axis)
+        repeats -= 1
+        nodes += [
+            create_tensor([repeats], name+'_repeats', kwargs['initializer']),
+            create_tensor([1], name+'_1', kwargs['initializer']),
+            create_tensor([0], name+'_0', kwargs['initializer']),
+            create_tensor([], name+'_void', kwargs['initializer']),
+            create_tensor([axis], name+'_axis', kwargs['initializer']),
+            make_node('Shape', [input_nodes[0]], [name+'_shape']),
+            make_node('Shape', [name+'_shape'], [name+'_dim']),
+            make_node('Reshape', [name+'_dim', name+'_void'], [name+'_dim_s']),
+            make_node('Range', [name+'_0', name+'_dim_s', name+'_1'], 
[name+'_range'])
+            ]
+        if axis < 0:
+            nodes += [
+                make_node('Add', [name+'_axis', name+'_dim'], 
[name+'_true_axis']),
+                make_node('Equal', [name+'_range', name+'_true_axis'], 
[name+'_one_hot'])
+                ]
+        else:
+            nodes += [
+                make_node('Equal', [name+'_range', name+'_axis'], 
[name+'_one_hot'])
+                ]
+        nodes += [
+            make_node('Cast', [name+'_one_hot'], [name+'_one_hot_int'], 
to=int(TensorProto.INT64)),
+            make_node('Mul', [name+'_repeats', name+'_one_hot_int'], 
[name+'_mul']),
+            make_node('Add', [name+'_mul', name+'_1'], [name+'_add']),
+            make_node('Concat', [name+'_1', name+'_add'], 
[name+'_repeats_tensor'], axis=0)
+            ]
+        if axis == -1:
+            nodes += [
+                make_node('Concat', [name+'_shape', name+'_1'], 
[name+'_unsqueeze_shape'], axis=0),
+                make_node('Reshape', [input_nodes[0], name+'_unsqueeze_shape'],
+                          [name+'_unsqueeze'])
+                ]
+        else:
+            nodes += [
+                make_node('Unsqueeze', [input_nodes[0]], [name+'_unsqueeze'], 
axes=[axis+1])
+                ]
+        nodes += [
+            make_node('Tile', [name+'_unsqueeze', name+'_repeats_tensor'], 
[name+'_tile']),
+            make_node('Mul', [name+'_shape', name+'_add'], 
[name+'_new_shape']),
+            make_node('Reshape', [name+'_tile', name+'_new_shape'], [name], 
name=name)
+            ]
+
+    return nodes
diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py
index 7800008..c17a03b 100644
--- a/tests/python-pytest/onnx/test_operators.py
+++ b/tests/python-pytest/onnx/test_operators.py
@@ -354,6 +354,15 @@ def test_onnx_export_softmax(tmp_path, dtype):
 
 
 @pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 
'int64'])
+@pytest.mark.parametrize('axis', [None, 0, 1, 2, -1, -2, -3])
+@pytest.mark.parametrize('repeats', [2, 1, 3])
+def test_onnx_export_repeat(tmp_path, dtype, axis, repeats):
+    x = mx.nd.arange(0, 27, dtype=dtype).reshape((3, 3, 3))
+    M = def_model('repeat', axis=axis, repeats=repeats)
+    op_export_test('repeat', M, [x], tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64'])
 @pytest.mark.parametrize('params', [{'height': 7, 'width': 13},
                                     {'height': 10, 'width': 16},
                                     {'height': 3, 'width': 5},
@@ -369,4 +378,3 @@ def test_onnx_export_contrib_BilinearResize2D(tmp_path, dtype, params):
     x = mx.nd.arange(0, 160).reshape((2, 2, 5, 8))
     M = def_model('contrib.BilinearResize2D', **params)
     op_export_test('contrib_BilinearResize2D', M, [x], tmp_path)
-

Reply via email to