nswamy closed pull request #13112: [MXNET-897] ONNX import/export: Size
URL: https://github.com/apache/incubator-mxnet/pull/13112
This PR was merged from a forked repository.
Because GitHub hides the original diff of a foreign (forked) pull request
once it is merged, the diff is reproduced below for the sake of provenance:
diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 86767a66712..e326c71020c 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -1645,3 +1645,11 @@ def convert_logical_not(node, **kwargs):
and return the created node.
"""
return create_basic_op_node('Not', node, kwargs)
+
+
+@mx_op.register("size_array")
+def convert_size(node, **kwargs):
+    """Map MXNet's size_array operator attributes to onnx's Size operator
+    and return the created node.
+    """
+    return create_basic_op_node('Size', node, kwargs)
diff --git a/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
index f61910f838e..2ceabaec1dc 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
@@ -21,7 +21,7 @@
from ._op_translations import identity, random_uniform, random_normal
from ._op_translations import add, subtract, multiply, divide, absolute, negative, add_n
from ._op_translations import tanh, arccos, arcsin, arctan, _cos, _sin, _tan
-from ._op_translations import softplus, shape, gather, lp_pooling
+from ._op_translations import softplus, shape, gather, lp_pooling, size
from ._op_translations import ceil, floor, hardsigmoid, global_lppooling
from ._op_translations import concat
from ._op_translations import leaky_relu, _elu, _prelu, _selu, softmax, fully_connected
@@ -139,6 +139,7 @@
'Softplus' : softplus,
'Tan' : _tan,
'Shape' : shape,
+ 'Size' : size,
'Gather' : gather,
'HardSigmoid' : hardsigmoid,
'LpPool' : lp_pooling,
diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
index 368b98d662b..70283252931 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
@@ -642,6 +642,10 @@ def shape(attrs, inputs, proto_obj):
"""Returns shape of input array."""
return 'shape_array', attrs, inputs
+def size(attrs, inputs, proto_obj):
+    """Returns array containing size of data."""
+    return "size_array", attrs, inputs
+
def reduce_l2(attrs, inputs, proto_obj):
"""Reduce input tensor by l2 normalization."""
new_attrs = translation_utils._fix_attribute_names(attrs, {'axes':'axis'})
diff --git a/tests/python-pytest/onnx/export/onnx_backend_test.py b/tests/python-pytest/onnx/export/onnx_backend_test.py
index be9273eb6fa..c9926c4d5e1 100644
--- a/tests/python-pytest/onnx/export/onnx_backend_test.py
+++ b/tests/python-pytest/onnx/export/onnx_backend_test.py
@@ -97,7 +97,8 @@
'test_depthtospace',
'test_hardsigmoid',
'test_instancenorm',
- 'test_shape'
+ 'test_shape',
+ 'test_size'
]
BASIC_MODEL_TESTS = [
diff --git a/tests/python-pytest/onnx/import/test_cases.py b/tests/python-pytest/onnx/import/test_cases.py
index f41fe92352d..e0b26cc4983 100644
--- a/tests/python-pytest/onnx/import/test_cases.py
+++ b/tests/python-pytest/onnx/import/test_cases.py
@@ -85,7 +85,8 @@
'test_operator_maxpool',
'test_operator_params',
'test_operator_permute2',
- 'test_depthtospace'
+ 'test_depthtospace',
+ 'test_size'
]
BASIC_MODEL_TESTS = [
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services