This is an automated email from the ASF dual-hosted git repository.
skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new ac4ef21 ONNX export - Clip operator (#12457)
ac4ef21 is described below
commit ac4ef212f6269469f3f3827da49e43fb42f1398f
Author: Vandana Kannan <[email protected]>
AuthorDate: Mon Sep 10 11:59:01 2018 -0700
ONNX export - Clip operator (#12457)
---
.../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 24 ++++++++++++++++++++++
.../python-pytest/onnx/export/onnx_backend_test.py | 3 ++-
2 files changed, 26 insertions(+), 1 deletion(-)
diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 0960776..3ffac96 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -1057,6 +1057,30 @@ def convert_flatten(node, **kwargs):
)
return [flatten_node]
@mx_op.register("clip")
def convert_clip(node, **kwargs):
    """Map MXNet's Clip operator attributes to onnx's Clip operator
    and return the created node.

    Parameters
    ----------
    node : dict
        MXNet graph node. This function reads its ``"name"``, ``"inputs"``
        and ``"attrs"`` entries; ``attrs`` may contain ``a_min`` / ``a_max``.
    **kwargs
        Must provide ``index_lookup`` (maps the id found at
        ``node["inputs"][0][0]`` to a position in ``proc_nodes``) and
        ``proc_nodes`` (previously converted nodes, each exposing ``.name``).

    Returns
    -------
    list
        A single-element list holding the ONNX ``Clip`` node. Both the node
        name and its single output name are ``node["name"]``.
    """
    helper, _, _ = import_onnx_modules()
    name = node["name"]
    input_idx = kwargs["index_lookup"][node["inputs"][0][0]]
    proc_nodes = kwargs["proc_nodes"]
    input_node = proc_nodes[input_idx].name
    attrs = node["attrs"]
    # Use the builtin float: the ``np.float`` alias was deprecated in
    # NumPy 1.20 and removed in 1.24, where it raises AttributeError.
    # float() also parses string-valued attributes (presumably how the
    # serialized graph stores them). A missing bound means that side is
    # unbounded.
    a_min = float(attrs.get('a_min', float('-inf')))
    a_max = float(attrs.get('a_max', float('inf')))

    clip_node = helper.make_node(
        "Clip",
        [input_node],
        [name],
        name=name,
        min=a_min,
        max=a_max
    )
    return [clip_node]
+
def scalar_op_helper(node, op_name, **kwargs):
"""Helper function for scalar arithmetic operations"""
diff --git a/tests/python-pytest/onnx/export/onnx_backend_test.py
b/tests/python-pytest/onnx/export/onnx_backend_test.py
index 19bf699..01ae094 100644
--- a/tests/python-pytest/onnx/export/onnx_backend_test.py
+++ b/tests/python-pytest/onnx/export/onnx_backend_test.py
@@ -89,7 +89,8 @@ IMPLEMENTED_OPERATORS_TEST = [
'test_operator_exp',
'test_operator_maxpool',
'test_operator_params',
- 'test_operator_permute2'
+ 'test_operator_permute2',
+ 'test_clip'
]
BASIC_MODEL_TESTS = [