This is an automated email from the ASF dual-hosted git repository.

nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 636933d  ONNX import/export: Size (#13112)
636933d is described below

commit 636933d424d789661d9e954ebfb569e1a2945a78
Author: Vandana Kannan <[email protected]>
AuthorDate: Fri Dec 7 19:39:47 2018 -0800

    ONNX import/export: Size (#13112)
---
 python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 8 ++++++++
 python/mxnet/contrib/onnx/onnx2mx/_import_helper.py   | 3 ++-
 python/mxnet/contrib/onnx/onnx2mx/_op_translations.py | 4 ++++
 tests/python-pytest/onnx/export/onnx_backend_test.py  | 3 ++-
 tests/python-pytest/onnx/import/test_cases.py         | 3 ++-
 5 files changed, 18 insertions(+), 3 deletions(-)

diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 0f4b448..0d20c76 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -1647,3 +1647,11 @@ def convert_logical_not(node, **kwargs):
     and return the created node.
     """
     return create_basic_op_node('Not', node, kwargs)
+
+
+@mx_op.register("size_array")
+def convert_size(node, **kwargs):
+    """Map MXNet's size_array operator attributes to onnx's Size operator
+    and return the created node.
+    """
+    return create_basic_op_node('Size', node, kwargs)
diff --git a/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
index f61910f..2ceabae 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
@@ -21,7 +21,7 @@
 from ._op_translations import identity, random_uniform, random_normal
 from ._op_translations import add, subtract, multiply, divide, absolute, negative, add_n
 from ._op_translations import tanh, arccos, arcsin, arctan, _cos, _sin, _tan
-from ._op_translations import softplus, shape, gather, lp_pooling
+from ._op_translations import softplus, shape, gather, lp_pooling, size
 from ._op_translations import ceil, floor, hardsigmoid, global_lppooling
 from ._op_translations import concat
 from ._op_translations import leaky_relu, _elu, _prelu, _selu, softmax, fully_connected
@@ -139,6 +139,7 @@ _convert_map = {
     'Softplus'          : softplus,
     'Tan'               : _tan,
     'Shape'             : shape,
+    'Size'              : size,
     'Gather'            : gather,
     'HardSigmoid'       : hardsigmoid,
     'LpPool'            : lp_pooling,
diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
index 368b98d..7028325 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
@@ -642,6 +642,10 @@ def shape(attrs, inputs, proto_obj):
     """Returns shape of input array."""
     return 'shape_array', attrs, inputs
 
+def size(attrs, inputs, proto_obj):
+    """Returns array containing size of data."""
+    return "size_array", attrs, inputs
+
 def reduce_l2(attrs, inputs, proto_obj):
     """Reduce input tensor by l2 normalization."""
     new_attrs = translation_utils._fix_attribute_names(attrs, {'axes':'axis'})
diff --git a/tests/python-pytest/onnx/export/onnx_backend_test.py b/tests/python-pytest/onnx/export/onnx_backend_test.py
index be9273e..c9926c4 100644
--- a/tests/python-pytest/onnx/export/onnx_backend_test.py
+++ b/tests/python-pytest/onnx/export/onnx_backend_test.py
@@ -97,7 +97,8 @@ IMPLEMENTED_OPERATORS_TEST = [
     'test_depthtospace',
     'test_hardsigmoid',
     'test_instancenorm',
-    'test_shape'
+    'test_shape',
+    'test_size'
     ]
 
 BASIC_MODEL_TESTS = [
diff --git a/tests/python-pytest/onnx/import/test_cases.py b/tests/python-pytest/onnx/import/test_cases.py
index f41fe92..e0b26cc 100644
--- a/tests/python-pytest/onnx/import/test_cases.py
+++ b/tests/python-pytest/onnx/import/test_cases.py
@@ -85,7 +85,8 @@ IMPLEMENTED_OPERATORS_TEST = [
     'test_operator_maxpool',
     'test_operator_params',
     'test_operator_permute2',
-    'test_depthtospace'
+    'test_depthtospace',
+    'test_size'
     ]
 
 BASIC_MODEL_TESTS = [

Reply via email to