This is an automated email from the ASF dual-hosted git repository.
zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new aa00f4b [v1.x] ONNX Support for MXNet repeat op (#19732)
aa00f4b is described below
commit aa00f4b300cdabcb6a99ee6e80727c35369c523d
Author: Zhaoqi Zhu <[email protected]>
AuthorDate: Mon Jan 11 14:37:53 2021 -0800
[v1.x] ONNX Support for MXNet repeat op (#19732)
* repeat op
* remove extra print
* restore sanity
* Update test_operators.py
* fix axis=1 case
* Update _op_translations.py
* Update test_operators.py
---
.../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 81 +++++++++++++++++++++-
tests/python-pytest/onnx/test_operators.py | 10 ++-
2 files changed, 89 insertions(+), 2 deletions(-)
diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index bc4b414..57ef546 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -2778,7 +2778,86 @@ def convert_arange(node, **kwargs):
create_const_scalar_node(name+"_start", np.array([start],
dtype=dtype), kwargs),
create_const_scalar_node(name+"_stop", np.array([stop], dtype=dtype),
kwargs),
create_const_scalar_node(name+"_step", np.array([step], dtype=dtype),
kwargs),
- make_node("Range", [name+"_start", name+"_stop", name+"_step"], [name])
+ make_node("Range", [name+"_start", name+"_stop", name+"_step"],
[name], name=name)
]
return nodes
+
+
+@mx_op.register('repeat')
+def convert_repeat(node, **kwargs):
+ """Map MXNet's repeat operator attributes to onnx's Tile operator.
+ """
+ from onnx.helper import make_node
+ from onnx import TensorProto
+ name, input_nodes, attrs = get_inputs(node, kwargs)
+
+ opset_version = kwargs['opset_version']
+ if opset_version < 11:
+ raise AttributeError('ONNX opset 11 or greater is required to export
this operator')
+
+ repeats = int(attrs.get('repeats', 1))
+ axis = attrs.get('axis', 'None')
+
+ if repeats <= 0:
+ raise NotImplementedError('repeat operator does not support parameter
repeats==0')
+
+ nodes = []
+ if axis == 'None':
+ nodes += [
+ create_tensor([repeats], name+'_rep', kwargs['initializer']),
+ create_tensor([1, repeats], name+'_repeats',
kwargs['initializer']),
+ make_node('Shape', [input_nodes[0]], [name+'_shape']),
+ make_node('ReduceProd', [name+'_shape'], [name+'_size']),
+ make_node('Reshape', [input_nodes[0], name+'_size'],
[name+'_flat']),
+ make_node('Unsqueeze', [name+'_flat'], [name+'_unsqueeze'],
axes=[-1]),
+ make_node('Tile', [name+'_unsqueeze', name+'_repeats'],
[name+'_tile']),
+ make_node('Mul', [name+'_size', name+'_rep'], [name+'_new_size']),
+ make_node('Reshape', [name+'_tile', name+'_new_size'], [name],
name=name)
+ ]
+ else:
+ axis = int(axis)
+ repeats -= 1
+ nodes += [
+ create_tensor([repeats], name+'_repeats', kwargs['initializer']),
+ create_tensor([1], name+'_1', kwargs['initializer']),
+ create_tensor([0], name+'_0', kwargs['initializer']),
+ create_tensor([], name+'_void', kwargs['initializer']),
+ create_tensor([axis], name+'_axis', kwargs['initializer']),
+ make_node('Shape', [input_nodes[0]], [name+'_shape']),
+ make_node('Shape', [name+'_shape'], [name+'_dim']),
+ make_node('Reshape', [name+'_dim', name+'_void'], [name+'_dim_s']),
+ make_node('Range', [name+'_0', name+'_dim_s', name+'_1'],
[name+'_range'])
+ ]
+ if axis < 0:
+ nodes += [
+ make_node('Add', [name+'_axis', name+'_dim'],
[name+'_true_axis']),
+ make_node('Equal', [name+'_range', name+'_true_axis'],
[name+'_one_hot'])
+ ]
+ else:
+ nodes += [
+ make_node('Equal', [name+'_range', name+'_axis'],
[name+'_one_hot'])
+ ]
+ nodes += [
+ make_node('Cast', [name+'_one_hot'], [name+'_one_hot_int'],
to=int(TensorProto.INT64)),
+ make_node('Mul', [name+'_repeats', name+'_one_hot_int'],
[name+'_mul']),
+ make_node('Add', [name+'_mul', name+'_1'], [name+'_add']),
+ make_node('Concat', [name+'_1', name+'_add'],
[name+'_repeats_tensor'], axis=0)
+ ]
+ if axis == -1:
+ nodes += [
+ make_node('Concat', [name+'_shape', name+'_1'],
[name+'_unsqueeze_shape'], axis=0),
+ make_node('Reshape', [input_nodes[0], name+'_unsqueeze_shape'],
+ [name+'_unsqueeze'])
+ ]
+ else:
+ nodes += [
+ make_node('Unsqueeze', [input_nodes[0]], [name+'_unsqueeze'],
axes=[axis+1])
+ ]
+ nodes += [
+ make_node('Tile', [name+'_unsqueeze', name+'_repeats_tensor'],
[name+'_tile']),
+ make_node('Mul', [name+'_shape', name+'_add'],
[name+'_new_shape']),
+ make_node('Reshape', [name+'_tile', name+'_new_shape'], [name],
name=name)
+ ]
+
+ return nodes
diff --git a/tests/python-pytest/onnx/test_operators.py
b/tests/python-pytest/onnx/test_operators.py
index 7800008..c17a03b 100644
--- a/tests/python-pytest/onnx/test_operators.py
+++ b/tests/python-pytest/onnx/test_operators.py
@@ -354,6 +354,15 @@ def test_onnx_export_softmax(tmp_path, dtype):
@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32',
'int64'])
[email protected]('axis', [None, 0, 1, 2, -1, -2, -3])
[email protected]('repeats', [2, 1, 3])
+def test_onnx_export_repeat(tmp_path, dtype, axis, repeats):
+ x = mx.nd.arange(0, 27, dtype=dtype).reshape((3, 3, 3))
+ M = def_model('repeat', axis=axis, repeats=repeats)
+ op_export_test('repeat', M, [x], tmp_path)
+
+
[email protected]('dtype', ['float16', 'float32', 'float64', 'int32',
'int64'])
@pytest.mark.parametrize('params', [{'height': 7, 'width': 13},
{'height': 10, 'width': 16},
{'height': 3, 'width': 5},
@@ -369,4 +378,3 @@ def test_onnx_export_contrib_BilinearResize2D(tmp_path,
dtype, params):
x = mx.nd.arange(0, 160).reshape((2, 2, 5, 8))
M = def_model('contrib.BilinearResize2D', **params)
op_export_test('contrib_BilinearResize2D', M, [x], tmp_path)
-