This is an automated email from the ASF dual-hosted git repository.
nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 636933d ONNX import/export: Size (#13112)
636933d is described below
commit 636933d424d789661d9e954ebfb569e1a2945a78
Author: Vandana Kannan <[email protected]>
AuthorDate: Fri Dec 7 19:39:47 2018 -0800
ONNX import/export: Size (#13112)
---
python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 8 ++++++++
python/mxnet/contrib/onnx/onnx2mx/_import_helper.py | 3 ++-
python/mxnet/contrib/onnx/onnx2mx/_op_translations.py | 4 ++++
tests/python-pytest/onnx/export/onnx_backend_test.py | 3 ++-
tests/python-pytest/onnx/import/test_cases.py | 3 ++-
5 files changed, 18 insertions(+), 3 deletions(-)
diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 0f4b448..0d20c76 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -1647,3 +1647,11 @@ def convert_logical_not(node, **kwargs):
and return the created node.
"""
return create_basic_op_node('Not', node, kwargs)
+
+
+@mx_op.register("size_array")
+def convert_size(node, **kwargs):
+ """Map MXNet's size_array operator to ONNX's Size operator and return
+ the node built by create_basic_op_node (no attribute translation here).
+ """
+ return create_basic_op_node('Size', node, kwargs)
diff --git a/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
index f61910f..2ceabae 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
@@ -21,7 +21,7 @@
from ._op_translations import identity, random_uniform, random_normal
from ._op_translations import add, subtract, multiply, divide, absolute,
negative, add_n
from ._op_translations import tanh, arccos, arcsin, arctan, _cos, _sin, _tan
-from ._op_translations import softplus, shape, gather, lp_pooling
+from ._op_translations import softplus, shape, gather, lp_pooling, size
from ._op_translations import ceil, floor, hardsigmoid, global_lppooling
from ._op_translations import concat
from ._op_translations import leaky_relu, _elu, _prelu, _selu, softmax,
fully_connected
@@ -139,6 +139,7 @@ _convert_map = {
'Softplus' : softplus,
'Tan' : _tan,
'Shape' : shape,
+ 'Size' : size,
'Gather' : gather,
'HardSigmoid' : hardsigmoid,
'LpPool' : lp_pooling,
diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
index 368b98d..7028325 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
@@ -642,6 +642,10 @@ def shape(attrs, inputs, proto_obj):
"""Returns shape of input array."""
return 'shape_array', attrs, inputs
+def size(attrs, inputs, proto_obj):
+ """Map ONNX's Size operator to MXNet's size_array; attrs and inputs pass through unchanged."""
+ return "size_array", attrs, inputs
+
def reduce_l2(attrs, inputs, proto_obj):
"""Reduce input tensor by l2 normalization."""
new_attrs = translation_utils._fix_attribute_names(attrs, {'axes':'axis'})
diff --git a/tests/python-pytest/onnx/export/onnx_backend_test.py b/tests/python-pytest/onnx/export/onnx_backend_test.py
index be9273e..c9926c4 100644
--- a/tests/python-pytest/onnx/export/onnx_backend_test.py
+++ b/tests/python-pytest/onnx/export/onnx_backend_test.py
@@ -97,7 +97,8 @@ IMPLEMENTED_OPERATORS_TEST = [
'test_depthtospace',
'test_hardsigmoid',
'test_instancenorm',
- 'test_shape'
+ 'test_shape',
+ 'test_size'
]
BASIC_MODEL_TESTS = [
diff --git a/tests/python-pytest/onnx/import/test_cases.py b/tests/python-pytest/onnx/import/test_cases.py
index f41fe92..e0b26cc 100644
--- a/tests/python-pytest/onnx/import/test_cases.py
+++ b/tests/python-pytest/onnx/import/test_cases.py
@@ -85,7 +85,8 @@ IMPLEMENTED_OPERATORS_TEST = [
'test_operator_maxpool',
'test_operator_params',
'test_operator_permute2',
- 'test_depthtospace'
+ 'test_depthtospace',
+ 'test_size'
]
BASIC_MODEL_TESTS = [