This is an automated email from the ASF dual-hosted git repository.
skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 7d91602 [MXNET-533] MXNet-ONNX export (#11213)
7d91602 is described below
commit 7d91602ba771d973360f8a0c66c976c67f700aa3
Author: Roshani Nagmote <[email protected]>
AuthorDate: Mon Jun 25 09:43:20 2018 -0700
[MXNET-533] MXNet-ONNX export (#11213)
* Resolve conflicts
* Export module Test Framework
* refactoring export to work with pretrained models
* comments added
* 1. Refactored export module.
2. Refactored test framework to support ONNX backend tests.
3. Added Operator support:
- Convolution2D
- BatchNorm
- Add
* Added Arithmetic operators:
- Add, Sub, Mul, Div, Sum
* Added operator support:
- sigmoid, relu, pad( constant, edge, reflect), tanh
- enabled corresponding ONNX backend tests.
* Enabled ONNX tests: test_conv, test_basic_conv
Added Operators :
Ceil, Floor
* Added support for:
MaxPool, AvgPool, GlobalMaxPool, GlobalAvgPool, matmul
* adding more operators
* Added Operator support:
ArgMax, ArgMin, maximum, minimum
* Enabled more BASIC_MODEL tests
* Added power operator tests
* Added support for reshape. ONNX only supports 0, -1 special values.
Added only for these.
Fixed logic error with convert_string_to_list()
* some tests enabled
* enabling squeezenet
* LRN Op support
* mul_scalar modified to take scalar input
* cleaning some code
* Resolving conflicts on rebase
* Resolving rebase conflicts
* id mapping updated for all operators
* save onnx models added, some code cleanup
* enabled more tests
* conv pad calc fixed
* reshape op fix
* Added support for elu, leakyRelu, prelu
* Cleanup
- Removed run_node, not needed anymore.
- Used correct get_metadata api
* valueinfoproto fix, googlenet test added
* Removed redundant code.
- run_node
- Using correct get_metadata_api
* dilation added
* Lint fixes
* lint fixes
* some fixes to make export work with onnx 1.2.1
* enabled more tests
* mxnet_export_test file added
* duplicate file deleted
* reduce ops added
* some small fixes
* some lint fixes
* Add tests for inception_v1 and inception_v2
* Add CI runs for export module
* docstring added
* lint fixes, pooling attr fix
* fix
* fix global_pool
* CI run fix
* code cleanup
* lint fix
* some code cleanup
* pad in pooling added
* slicechannel notimplementederror raised
* Added required license comments
* Lint fixes
* lint fix
* lint fix
* lint fix
* lint fix
* Correct license statement
* Adding onnx as a runtime dependency
* Fix import module error for string_types
* Making ONNX a runtime dependency
* fixing some comments
* addressing some comments
* params rename
* lint fixes
* fixes
* spatial disabled, path fixed
* fixing some comments
* Added support for remaining act_type(softsign, sigmoid, softrelu) in
Activation operator
* changing import
* adding some comments
* Add squeeze op
* Refactored logic to handle extra node(output label node) for saved mxnet
model
Added comments
* minor fix for squeeze operator.
Also, added error handling
* identity operator added
* scalar ops added
* Renamed onnx support folders to mark them as public folders
Changed underlying files to public or private as per usage
Resolved conflicts with the latest
* Added support for L2Normalization op
Added some error checking
* added comments and warning
* added comments and warning
* doc API ref added
---
LICENSE | 52 +-
ci/docker/runtime_functions.sh | 2 +
docs/api/python/contrib/onnx.md | 2 +
python/mxnet/contrib/onnx/__init__.py | 5 +-
python/mxnet/contrib/onnx/mx2onnx/LICENSE | 44 +
.../contrib/onnx/{_import => mx2onnx}/__init__.py | 10 +-
.../mxnet/contrib/onnx/mx2onnx/_export_helper.py | 65 +
.../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 1863 ++++++++++++++++++++
python/mxnet/contrib/onnx/mx2onnx/export_model.py | 95 +
python/mxnet/contrib/onnx/mx2onnx/export_onnx.py | 347 ++++
.../contrib/onnx/{_import => onnx2mx}/__init__.py | 0
.../import_helper.py => onnx2mx/_import_helper.py} | 39 +-
.../_op_translations.py} | 6 +-
.../_translation_utils.py} | 1 +
.../onnx/{_import => onnx2mx}/import_model.py | 0
.../onnx/{_import => onnx2mx}/import_onnx.py | 2 +-
.../onnx/{_import => onnx2mx}/import_to_gluon.py | 0
.../{import/mxnet_backend.py => export/backend.py} | 49 +-
.../mxnet_backend_rep.py => export/backend_rep.py} | 8 +-
.../python-pytest/onnx/export/mxnet_export_test.py | 191 ++
.../test_cases.py => export/onnx_backend_test.py} | 61 +-
tests/python-pytest/onnx/import/gluon_backend.py | 6 +-
tests/python-pytest/onnx/import/mxnet_backend.py | 3 +-
.../python-pytest/onnx/import/mxnet_backend_rep.py | 1 -
tests/python-pytest/onnx/import/test_cases.py | 1 +
25 files changed, 2783 insertions(+), 70 deletions(-)
diff --git a/LICENSE b/LICENSE
index 158bd37..a8b57e5 100644
--- a/LICENSE
+++ b/LICENSE
@@ -298,8 +298,6 @@
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
=======================================================================================
Other Licenses
=======================================================================================
@@ -512,3 +510,53 @@
For details, see, 3rdparty/dmlc-core/include/dmlc/concurrentqueue.h
=======================================================================================
+
+ 11. ONNX Export module
+ For details, see, python/mxnet/contrib/onnx/_export/LICENSE
+
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+ #
+ # Based on
+ # https://github.com/NVIDIA/mxnet_to_onnx/blob/master/mx2onnx_converter/#
+ # Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Redistribution and use in source and binary forms, with or without
+ # modification, are permitted provided that the following conditions
+ # are met:
+ # * Redistributions of source code must retain the above copyright
+ # notice, this list of conditions and the following disclaimer.
+ # * Redistributions in binary form must reproduce the above copyright
+ # notice, this list of conditions and the following disclaimer in the
+ # documentation and/or other materials provided with the distribution.
+ # * Neither the name of NVIDIA CORPORATION nor the names of its
+ # contributors may be used to endorse or promote products derived
+ # from this software without specific prior written permission.
+ #
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+ # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+ # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+ # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index 6e6abf0..0798047 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -700,6 +700,8 @@ integrationtest_ubuntu_cpu_onnx() {
pytest tests/python-pytest/onnx/import/mxnet_backend_test.py
pytest tests/python-pytest/onnx/import/onnx_import_test.py
pytest tests/python-pytest/onnx/import/gluon_backend_test.py
+ pytest tests/python-pytest/onnx/export/onnx_backend_test.py
+ python tests/python-pytest/onnx/export/mxnet_export_test.py
}
integrationtest_ubuntu_gpu_python() {
diff --git a/docs/api/python/contrib/onnx.md b/docs/api/python/contrib/onnx.md
index 6fb546f..8cd6198 100644
--- a/docs/api/python/contrib/onnx.md
+++ b/docs/api/python/contrib/onnx.md
@@ -24,6 +24,7 @@ This document describes all the ONNX-MXNet APIs.
mxnet.contrib.onnx.import_model
mxnet.contrib.onnx.get_model_metadata
+ mxnet.contrib.onnx.export_model
```
## ONNX Tutorials
@@ -46,6 +47,7 @@ This document describes all the ONNX-MXNet APIs.
.. automodule:: mxnet.contrib.onnx
:members: import_model
:members: get_model_metadata
+ :members: export_model
```
diff --git a/python/mxnet/contrib/onnx/__init__.py
b/python/mxnet/contrib/onnx/__init__.py
index 4f9296d..9f27060 100644
--- a/python/mxnet/contrib/onnx/__init__.py
+++ b/python/mxnet/contrib/onnx/__init__.py
@@ -16,5 +16,6 @@
# under the License.
"""Module for ONNX model format support for Apache MXNet."""
-from ._import.import_model import import_model, get_model_metadata
-from ._import.import_to_gluon import import_to_gluon
+from .onnx2mx.import_model import import_model, get_model_metadata
+from .onnx2mx.import_to_gluon import import_to_gluon
+from .mx2onnx.export_model import export_model
diff --git a/python/mxnet/contrib/onnx/mx2onnx/LICENSE
b/python/mxnet/contrib/onnx/mx2onnx/LICENSE
new file mode 100644
index 0000000..3abe1ee
--- /dev/null
+++ b/python/mxnet/contrib/onnx/mx2onnx/LICENSE
@@ -0,0 +1,44 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# Based on
+# https://github.com/NVIDIA/mxnet_to_onnx/blob/master/mx2onnx_converter/#
+# Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+# * Neither the name of NVIDIA CORPORATION nor the names of its
+# contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/python/mxnet/contrib/onnx/_import/__init__.py
b/python/mxnet/contrib/onnx/mx2onnx/__init__.py
similarity index 84%
copy from python/mxnet/contrib/onnx/_import/__init__.py
copy to python/mxnet/contrib/onnx/mx2onnx/__init__.py
index d0411df..238174e 100644
--- a/python/mxnet/contrib/onnx/_import/__init__.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/__init__.py
@@ -16,7 +16,9 @@
# under the License.
# coding: utf-8
-"""ONNX Import module"""
-from . import import_model
-from . import import_onnx
-from . import import_to_gluon
+"""ONNX Export module"""
+from __future__ import absolute_import
+
+from . import export_model
+from . import export_onnx
+from . import _op_translations
diff --git a/python/mxnet/contrib/onnx/mx2onnx/_export_helper.py
b/python/mxnet/contrib/onnx/mx2onnx/_export_helper.py
new file mode 100644
index 0000000..781fb4c
--- /dev/null
+++ b/python/mxnet/contrib/onnx/mx2onnx/_export_helper.py
@@ -0,0 +1,65 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""export helper functions"""
+# coding: utf-8
+import os
+import logging
+import mxnet as mx
+
+
+def load_module(sym_filepath, params_filepath):
+ """Loads the MXNet model file and
+ returns MXNet symbol and params (weights).
+
+ Parameters
+ ----------
+ json_path : str
+ Path to the json file
+ params_path : str
+ Path to the params file
+
+ Returns
+ -------
+ sym : MXNet symbol
+ Model symbol object
+
+ params : params object
+ Model weights including both arg and aux params.
+ """
+ if not (os.path.isfile(sym_filepath) and os.path.isfile(params_filepath)):
+ raise ValueError("Symbol and params files provided are invalid")
+ else:
+ try:
+ # reads symbol.json file from given path and
+ # retrieves model prefix and number of epochs
+ model_name = sym_filepath.rsplit('.', 1)[0].rsplit('-', 1)[0]
+ params_file_list = params_filepath.rsplit('.', 1)[0].rsplit('-', 1)
+ # Setting num_epochs to 0 if not present in filename
+ num_epochs = 0 if len(params_file_list) == 1 else
int(params_file_list[1])
+ except IndexError:
+ logging.info("Model and params name should be in format: "
+ "prefix-symbol.json, prefix-epoch.params")
+ raise
+
+ sym, arg_params, aux_params = mx.model.load_checkpoint(model_name,
num_epochs)
+
+ # Merging arg and aux parameters
+ params = {}
+ params.update(arg_params)
+ params.update(aux_params)
+
+ return sym, params
diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
new file mode 100644
index 0000000..5f5561a
--- /dev/null
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -0,0 +1,1863 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# Based on
+# https://github.com/NVIDIA/mxnet_to_onnx/blob/master/mx2onnx_converter/
+# mx2onnx_converter_functions.py
+# Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+# * Neither the name of NVIDIA CORPORATION nor the names of its
+# contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# coding: utf-8
+# pylint: disable=too-many-locals,no-else-return,too-many-lines
+# pylint: disable=anomalous-backslash-in-string,eval-used
+"""
+Conversion Functions for common layers.
+Add new functions here with a decorator.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import re
+import logging
+import numpy as np
+from .export_onnx import MXNetGraph as mx_op
+
+def import_onnx_modules():
+ """ To make sure ONNX is runtime dependency, it is imported used only when
needed"""
+ try:
+ from onnx import helper, numpy_helper, mapping
+ except ImportError:
+ raise ImportError("Onnx and protobuf need to be installed. "
+ + "Instructions to install -
https://github.com/onnx/onnx")
+ return helper, numpy_helper, mapping
+
+
+def parse_helper(attrs, attrs_name, alt_value=None):
+ """Helper function to parse operator attributes in required format."""
+ tuple_re = re.compile('\([0-9L|,| ]+\)')
+ if attrs is None:
+ return alt_value
+ attrs_str = None if attrs.get(attrs_name) is None else
str(attrs.get(attrs_name))
+ if attrs_str is None:
+ return alt_value
+ attrs_match = tuple_re.search(attrs_str)
+ if attrs_match is not None:
+ if attrs_match.span() == (0, len(attrs_str)):
+ dims = eval(attrs_str)
+ return dims
+ else:
+ raise AttributeError("Malformed %s dimensions: %s" % (attrs_name,
str(attrs_str)))
+ return alt_value
+
+def transform_padding(pad_width):
+ """Helper function to convert padding format for pad operator.
+ """
+ num_pad_values = len(pad_width)
+ onnx_pad_width = [0]*num_pad_values
+
+ start_index = 0
+ # num_pad_values will always be multiple of 2
+ end_index = int(num_pad_values/2)
+ for idx in range(0, num_pad_values):
+ if idx % 2 == 0:
+ onnx_pad_width[start_index] = pad_width[idx]
+ start_index += 1
+ else:
+ onnx_pad_width[end_index] = pad_width[idx]
+ end_index += 1
+
+ return onnx_pad_width
+
+
+def convert_string_to_list(string_val):
+ """Helper function to convert string to list.
+ Used to convert shape attribute string to list format.
+ """
+ result_list = []
+
+ list_string = string_val.split(',')
+ for val in list_string:
+ val = str(val.strip())
+ val = val.replace("(", "")
+ val = val.replace(")", "")
+ val = val.replace("L", "")
+ val = val.replace("[", "")
+ val = val.replace("]", "")
+ if val != "" and val != "None":
+ result_list.append(int(val))
+
+ return result_list
+
+@mx_op.register("null")
+def convert_weights_and_inputs(node, **kwargs):
+ """Helper function to convert weights and inputs.
+ """
+
+ helper, _, mapping = import_onnx_modules()
+ name = node["name"]
+
+ if kwargs["is_input"] is False:
+ weights = kwargs["weights"]
+ initializer = kwargs["initializer"]
+ np_arr = weights[name]
+ data_type = mapping.NP_TYPE_TO_TENSOR_TYPE[np_arr.dtype]
+ dims = np.shape(np_arr)
+
+ tensor_node = helper.make_tensor_value_info(name, data_type, dims)
+
+ initializer.append(
+ helper.make_tensor(
+ name=name,
+ data_type=data_type,
+ dims=dims,
+ vals=np_arr.flatten().tolist(),
+ raw=False,
+ )
+ )
+
+ return [tensor_node]
+ else:
+ tval_node = helper.make_tensor_value_info(name, kwargs["in_type"],
kwargs["in_shape"])
+ return [tval_node]
+
+
+@mx_op.register("Convolution")
+def convert_convolution(node, **kwargs):
+ """Map MXNet's convolution operator attributes to onnx's Conv operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ inputs = node["inputs"]
+
+ num_inputs = len(inputs)
+
+ proc_nodes = kwargs["proc_nodes"]
+ input_node = proc_nodes[kwargs["index_lookup"][inputs[0][0]]].name
+ weights_node = proc_nodes[kwargs["index_lookup"][inputs[1][0]]].name
+
+ if num_inputs > 2:
+ bias_node = proc_nodes[kwargs["index_lookup"][inputs[2][0]]].name
+
+ attrs = node.get("attrs")
+
+ kernel_dims = list(parse_helper(attrs, "kernel"))
+ stride_dims = list(parse_helper(attrs, "stride", [1, 1]))
+ pad_dims = list(parse_helper(attrs, "pad", [0, 0]))
+ num_group = int(attrs.get("num_group", 1))
+ dilations = list(parse_helper(attrs, "dilate", [1, 1]))
+
+ pad_dims = pad_dims + pad_dims
+
+ input_nodes = [input_node, weights_node]
+ if num_inputs > 2:
+ input_nodes.append(bias_node)
+
+ conv_node = helper.make_node(
+ "Conv",
+ inputs=input_nodes,
+ outputs=[name],
+ kernel_shape=kernel_dims,
+ strides=stride_dims,
+ dilations=dilations,
+ pads=pad_dims,
+ group=num_group,
+ name=name
+ )
+
+ return [conv_node]
+
+
+@mx_op.register("FullyConnected")
+def convert_fully_connected(node, **kwargs):
+ """Map MXNet's FullyConnected operator attributes to onnx's Gemm operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ inputs = node["inputs"]
+ input_node_id = kwargs["index_lookup"][inputs[0][0]]
+ weight_node_id = kwargs["index_lookup"][inputs[1][0]]
+ bias_node_id = kwargs["index_lookup"][inputs[2][0]]
+ proc_nodes = kwargs["proc_nodes"]
+ input_node = proc_nodes[input_node_id]
+ weights_node = proc_nodes[weight_node_id]
+ bias_node = proc_nodes[bias_node_id]
+
+ input_name = input_node.name
+ weights_name = weights_node.name
+ bias_name = bias_node.name
+
+ node = helper.make_node(
+ "Gemm",
+ [input_name, weights_name, bias_name], # input (A, B, C) - C can be
in place
+ [name], # output
+ alpha=1.0,
+ beta=1.0,
+ transA=False,
+ transB=True,
+ name=name
+ )
+
+ return [node]
+
+
+@mx_op.register("BatchNorm")
+def convert_batchnorm(node, **kwargs):
+ """Map MXNet's BatchNorm operator attributes to onnx's BatchNormalization
operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+
+ attrs = node["attrs"]
+ momentum = float(node.get("attrs", {}).get("momentum", 0.9))
+ eps = float(attrs.get("eps", 0.001))
+
+ data_idx = kwargs["index_lookup"][inputs[0][0]]
+ gamma_idx = kwargs["index_lookup"][inputs[1][0]]
+ beta_idx = kwargs["index_lookup"][inputs[2][0]]
+ moving_mean_idx = kwargs["index_lookup"][inputs[3][0]]
+ moving_var_idx = kwargs["index_lookup"][inputs[4][0]]
+
+ data_node = proc_nodes[data_idx].name
+ gamma_node = proc_nodes[gamma_idx].name
+ beta_node = proc_nodes[beta_idx].name
+
+ mov_mean_node = proc_nodes[moving_mean_idx]
+ mov_mean_node = mov_mean_node.name
+ mov_var_node = proc_nodes[moving_var_idx].name
+
+ bn_node = helper.make_node(
+ "BatchNormalization",
+ [data_node,
+ gamma_node, # scale
+ beta_node, # bias
+ mov_mean_node,
+ mov_var_node
+ ],
+ [name],
+ name=name,
+ epsilon=eps,
+ momentum=momentum,
+ # MXNet computes mean and variance per feature for batchnorm
+ # Default for onnx is across all spatial features. So disabling the
parameter.
+ spatial=0
+ )
+ return [bn_node]
+
+
+@mx_op.register("tanh")
+def convert_tanh(node, **kwargs):
+ """Map MXNet's tanh operator attributes to onnx's Tanh operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ inputs = node["inputs"]
+ input_node_idx = kwargs["index_lookup"][inputs[0][0]]
+ proc_nodes = kwargs["proc_nodes"]
+ input_node = proc_nodes[input_node_idx].name
+
+ node = helper.make_node(
+ 'Tanh',
+ [input_node],
+ [name],
+ name=name
+ )
+ return [node]
+
+#Basic neural network functions
+@mx_op.register("sigmoid")
+def convert_sigmoid(node, **kwargs):
+ """Map MXNet's sigmoid operator attributes to onnx's Sigmoid operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ inputs = node["inputs"]
+ input_node_idx = kwargs["index_lookup"][inputs[0][0]]
+ proc_nodes = kwargs["proc_nodes"]
+ input_node = proc_nodes[input_node_idx].name
+
+ node = helper.make_node(
+ 'Sigmoid',
+ [input_node],
+ [name],
+ name=name
+ )
+ return [node]
+
+@mx_op.register("relu")
+def convert_relu(node, **kwargs):
+ """Map MXNet's relu operator attributes to onnx's Relu operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ inputs = node["inputs"]
+ input_node_idx = kwargs["index_lookup"][inputs[0][0]]
+ proc_nodes = kwargs["proc_nodes"]
+ input_node = proc_nodes[input_node_idx].name
+
+ node = helper.make_node(
+ 'Relu',
+ [input_node],
+ [name],
+ name=name
+ )
+
+ return [node]
+
+@mx_op.register("Activation")
+def convert_activation(node, **kwargs):
+ """Map MXNet's Activation operator attributes to onnx's Tanh/Relu operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+
+ proc_nodes = kwargs["proc_nodes"]
+ attrs = node["attrs"]
+ act_type = attrs["act_type"]
+
+ inputs = node["inputs"]
+ input_node_idx = kwargs["index_lookup"][inputs[0][0]]
+ input_node = proc_nodes[input_node_idx].output[0]
+
+ # Creating a dictionary here, but if this titlecase pattern
+ # mxnet_name.title()
+ act_types = {
+ "tanh": "Tanh",
+ "relu": "Relu",
+ "sigmoid": "Sigmoid",
+ "softrelu": "Softplus",
+ "softsign": "Softsign"
+ }
+
+ act_name = act_types.get(act_type)
+ if act_name:
+ node = helper.make_node(
+ act_name,
+ [input_node],
+ [name],
+ name=name
+ )
+ else:
+ raise AttributeError(
+ "Activation %s not implemented or recognized in the converter" %
act_type
+ )
+
+ return [node]
+
+
+@mx_op.register("Pad")
+def convert_pad(node, **kwargs):
+ """Map MXNet's pad operator attributes to onnx's Pad operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ attrs = node["attrs"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+ input_node_idx = kwargs["index_lookup"][inputs[0][0]]
+ input_node = proc_nodes[input_node_idx].name
+
+ mxnet_pad_width = convert_string_to_list(attrs.get("pad_width"))
+ onnx_pad_width = transform_padding(mxnet_pad_width)
+
+ pad_mode = attrs.get("mode")
+
+ if pad_mode == "constant":
+ pad_value = float(attrs.get("constant_value")) \
+ if "constant_value" in attrs else 0.0
+ node = helper.make_node(
+ 'Pad',
+ inputs=[input_node],
+ outputs=[name],
+ mode='constant',
+ value=pad_value,
+ pads=onnx_pad_width,
+ name=name
+ )
+ else:
+ node = helper.make_node(
+ 'Pad',
+ inputs=[input_node],
+ outputs=[name],
+ mode=pad_mode,
+ pads=onnx_pad_width,
+ name=name
+ )
+
+ return [node]
+
+
+@mx_op.register("_linalg_gemm2")
+def convert_linalg_gemm2(node, **kwargs):
+ """Map MXNet's _linalg_gemm2 operator attributes to onnx's
+ MatMul and Transpose operators based on the values set for
+ transpose_a, transpose_b attributes.
+ Return multiple nodes created.
+ """
+ helper, _, _ = import_onnx_modules()
+ proc_nodes = kwargs["proc_nodes"]
+ node_inputs = node["inputs"]
+ name = node["name"]
+
+ input_a_idx = kwargs["index_lookup"][node_inputs[0][0]]
+ input_node_a = proc_nodes[input_a_idx].name
+ input_b_idx = kwargs["index_lookup"][node_inputs[1][0]]
+ input_node_b = proc_nodes[input_b_idx].name
+
+ # Getting the attributes and assigning default values.
+ if "attrs" in node:
+ attrs = node["attrs"]
+ alpha = float(attrs["alpha"])
+ trans_a = int(attrs["transpose_a"])
+ trans_b = int(attrs["transpose_b"])
+ else:
+ alpha = 1.0
+ trans_a = 0
+ trans_b = 0
+
+ op_name = "transpose" + str(kwargs["idx"])
+
+ if alpha == 1.0 and trans_a == 0 and trans_b == 0:
+ matmul_node = helper.make_node(
+ 'MatMul',
+ inputs=[input_node_a, input_node_b],
+ outputs=[name],
+ name=name
+ )
+ return [matmul_node]
+ elif trans_a == 1 and trans_b == 0:
+ op_name = "transpose" + str(kwargs["idx"])
+ node_name = op_name+"_a"
+ trans_a_node = helper.make_node(
+ 'Transpose',
+ inputs=[input_node_a],
+ outputs=[op_name+"_a"],
+ name=node_name
+ )
+
+ matmul_node = helper.make_node(
+ 'MatMul',
+ inputs=[node_name, input_node_b],
+ outputs=[name],
+ name=name
+ )
+ return [trans_a_node, matmul_node]
+
+ elif trans_a == 0 and trans_b == 1:
+ node_name = op_name + "_b"
+ trans_b_node = helper.make_node(
+ 'Transpose',
+ inputs=[input_node_b],
+ outputs=[op_name+"_b"],
+ name=node_name
+ )
+
+ matmul_node = helper.make_node(
+ 'MatMul',
+ inputs=[input_node_a, node_name],
+ outputs=[name],
+ name=name
+ )
+
+ return [trans_b_node, matmul_node]
+ else:
+ node_name_a = op_name+"_a"
+ trans_a_node = helper.make_node(
+ 'Transpose',
+ inputs=[input_node_a],
+ outputs=[op_name+"_a"],
+ name=node_name_a
+ )
+
+ node_name_b = op_name + "_b"
+ trans_b_node = helper.make_node(
+ 'Transpose',
+ inputs=[input_node_b],
+ outputs=[op_name+"_b"],
+ name=node_name_b
+ )
+
+ matmul_node = helper.make_node(
+ 'MatMul',
+ inputs=[node_name_a, node_name_b],
+ outputs=[name],
+ name=name
+ )
+
+ return [trans_a_node, trans_b_node, matmul_node]
+
+
+@mx_op.register("Pooling")
+def convert_pooling(node, **kwargs):
+ """Map MXNet's Pooling operator attributes to onnx's
+ MaxPool/AveragePool/GlobalMaxPool/GlobalAveragePool operators
+ based on the input node's attributes and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ proc_nodes = kwargs["proc_nodes"]
+ attrs = node["attrs"]
+ kernel = eval(attrs["kernel"])
+ pool_type = attrs["pool_type"]
+ stride = eval(attrs["stride"]) if attrs.get("stride") else None
+ global_pool = True if "global_pool" in attrs and\
+ attrs.get("global_pool") == "True" else False
+ node_inputs = node["inputs"]
+ input_node_idx = kwargs["index_lookup"][node_inputs[0][0]]
+ input_node = proc_nodes[input_node_idx]
+ name = node["name"]
+
+ pooling_convention = attrs.get('pooling_convention', 'valid')
+
+ if pooling_convention == 'full':
+ pooling_warning = "Pooling: ONNX currently doesn't support
pooling_convention. " \
+ "This might lead to shape or accuracy issues. " \
+ "https://github.com/onnx/onnx/issues/549"
+
+ logging.warning(pooling_warning)
+
+ pad_dims = list(parse_helper(attrs, "pad", [0, 0]))
+ pad_dims = pad_dims + pad_dims
+ pool_types = {"max": "MaxPool", "avg": "AveragePool"}
+ global_pool_types = {"max": "GlobalMaxPool", "avg": "GlobalAveragePool"}
+
+ if global_pool:
+ node = helper.make_node(
+ global_pool_types[pool_type],
+ [input_node.name], # input
+ [name],
+ name=name
+ )
+ else:
+ node = helper.make_node(
+ pool_types[pool_type],
+ [input_node.name], # input
+ [name],
+ kernel_shape=kernel,
+ pads=pad_dims,
+ strides=stride,
+ name=name
+ )
+
+ return [node]
+
+
+@mx_op.register("exp")
+def convert_exp(node, **kwargs):
+ """Map MXNet's exp operator attributes to onnx's Exp operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+
+ input_node_id = kwargs["index_lookup"][inputs[0][0]]
+ input_node = proc_nodes[input_node_id].name
+
+ node = helper.make_node(
+ "Exp",
+ [input_node],
+ [name],
+ name=name,
+ )
+ return [node]
+
+
+@mx_op.register("_copy")
+def convert_identity(node, **kwargs):
+ """Map MXNet's _copy operator attributes to onnx's Identity operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+
+ input_node_id = kwargs["index_lookup"][inputs[0][0]]
+ input_node = proc_nodes[input_node_id].name
+
+ node = helper.make_node(
+ "Identity",
+ [input_node],
+ [name],
+ name=name,
+ )
+ return [node]
+
+
+@mx_op.register("LeakyReLU")
+def convert_leakyrelu(node, **kwargs):
+ """Map MXNet's LeakyReLU operator attributes to onnx's Elu/LeakyRelu/PRelu
operators
+ based on the input node's attributes and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+ input_node_id = kwargs["index_lookup"][inputs[0][0]]
+ input_node = proc_nodes[input_node_id].name
+ attrs = node["attrs"]
+
+ act_type = attrs.get("act_type", "LeakyRelu")
+ alpha = float(attrs.get("slope", 0.25))
+
+ act_name = {"elu": "Elu", "LeakyRelu": "LeakyRelu", "prelu": "PRelu"}
+
+ if act_type == "prelu":
+ alpha_node_index = kwargs["index_lookup"][inputs[1][0]]
+ alpha_node_name = proc_nodes[alpha_node_index].name
+
+ node = helper.make_node(
+ act_name[act_type],
+ inputs=[input_node, alpha_node_name],
+ outputs=[name],
+ name=name)
+ else:
+ node = helper.make_node(
+ act_name[act_type],
+ inputs=[input_node],
+ outputs=[name],
+ name=name,
+ alpha=alpha)
+
+ return [node]
+
+
+@mx_op.register("softmax")
+def convert_softmax(node, **kwargs):
+ """Map MXNet's softmax operator attributes to onnx's Softmax operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ inputs = node["inputs"]
+ input_idx = kwargs["index_lookup"][inputs[0][0]]
+ proc_nodes = kwargs["proc_nodes"]
+ input_node = proc_nodes[input_idx]
+
+ name = node["name"]
+ axis = int(node.get("attrs", {}).get("axis", -1))
+
+ softmax_node = helper.make_node(
+ "Softmax",
+ [input_node.name],
+ [name],
+ axis=axis,
+ name=name
+ )
+
+ return [softmax_node]
+
+
+# There's also mx.sym.softmax(), which doesn't do cross-entropy loss,
+# just softmax for inference - hence the name convert_softmax_output.
+@mx_op.register("SoftmaxOutput")
+def convert_softmax_output(node, **kwargs):
+ """Map MXNet's SoftmaxOutput operator attributes to onnx's Softmax operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ inputs = node["inputs"]
+ input1_idx = kwargs["index_lookup"][inputs[0][0]]
+ proc_nodes = kwargs["proc_nodes"]
+ input1 = proc_nodes[input1_idx]
+ name = node["name"]
+
+ softmax_node = helper.make_node(
+ "Softmax",
+ [input1.output[0]],
+ [name],
+ axis=1,
+ name=name
+ )
+
+ return [softmax_node]
+
+
+@mx_op.register("Concat")
+def convert_concat(node, **kwargs):
+ """Map MXNet's Concat operator attributes to onnx's Concat operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ inputs = node["inputs"]
+ proc_nodes = kwargs["proc_nodes"]
+ input_names = [proc_nodes[kwargs["index_lookup"][i[0]]].name for i in
inputs]
+ axis = int(node.get("attrs", {}).get("dim", 1))
+ concat_node = helper.make_node(
+ "Concat",
+ input_names,
+ [name],
+ axis=axis,
+ name=name
+ )
+ return [concat_node]
+
+
+@mx_op.register("transpose")
+def convert_transpose(node, **kwargs):
+ """Map MXNet's transpose operator attributes to onnx's Transpose operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ input_idx = kwargs["index_lookup"][node["inputs"][0][0]]
+ proc_nodes = kwargs["proc_nodes"]
+ input_node = proc_nodes[input_idx].name
+ axes = node.get("attrs", {}).get("axes", ())
+ if axes:
+ axes = tuple(map(int, re.findall(r'\d+', axes)))
+
+ transpose_node = helper.make_node(
+ "Transpose",
+ [input_node],
+ [name],
+ perm=axes,
+ name=name
+ )
+ else:
+ transpose_node = helper.make_node(
+ "Transpose",
+ [input_node],
+ [name],
+ name=name
+ )
+
+ return [transpose_node]
+
+
+@mx_op.register("LRN")
+def convert_lrn(node, **kwargs):
+ """Map MXNet's LRN operator attributes to onnx's LRN operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ input_idx = kwargs["index_lookup"][node["inputs"][0][0]]
+ proc_nodes = kwargs["proc_nodes"]
+ input_node = proc_nodes[input_idx].name
+
+ attrs = node["attrs"]
+ alpha = float(attrs["alpha"]) if "alpha" in attrs else 0.0001
+ beta = float(attrs["beta"]) if "beta" in attrs else 0.75
+ bias = float(attrs["knorm"]) if "knorm" in attrs else 1.0
+ size = int(attrs["nsize"])
+
+ lrn_node = helper.make_node(
+ "LRN",
+ inputs=[input_node],
+ outputs=[name],
+ name=name,
+ alpha=alpha,
+ beta=beta,
+ bias=bias,
+ size=size
+ )
+
+ return [lrn_node]
+
+
+@mx_op.register("L2Normalization")
+def convert_l2normalization(node, **kwargs):
+ """Map MXNet's L2Normalization operator attributes to onnx's
LpNormalization operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ input_id = kwargs["index_lookup"][node["inputs"][0][0]]
+ input_name = kwargs["proc_nodes"][input_id].name
+ attrs = node["attrs"]
+ mode = attrs.get("mode", "instance")
+
+ if mode != "channel":
+ raise AttributeError("ONNX currently supports channel mode only")
+
+ l2norm_node = helper.make_node(
+ "LpNormalization",
+ [input_name],
+ [name],
+ axis=1, # channel only
+ name=name
+ )
+ return [l2norm_node]
+
+
+@mx_op.register("Dropout")
+def convert_dropout(node, **kwargs):
+ """Map MXNet's Dropout operator attributes to onnx's Dropout operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ input_id = kwargs["index_lookup"][node["inputs"][0][0]]
+ input_name = kwargs["proc_nodes"][input_id].name
+ attrs = node["attrs"]
+ probability = float(attrs["p"])
+
+ dropout_node = helper.make_node(
+ "Dropout",
+ [input_name],
+ [name],
+ ratio=probability,
+ name=name
+ )
+ return [dropout_node]
+
+
+@mx_op.register("Flatten")
+def convert_flatten(node, **kwargs):
+ """Map MXNet's Flatten operator attributes to onnx's Flatten operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ input_idx = kwargs["index_lookup"][node["inputs"][0][0]]
+ proc_nodes = kwargs["proc_nodes"]
+ input_node = proc_nodes[input_idx].name # .output[0]
+
+ flatten_node = helper.make_node(
+ "Flatten",
+ [input_node],
+ [name],
+ name=name
+ )
+ return [flatten_node]
+
+
def scalar_op_helper(node, op_name, **kwargs):
    """Helper function for scalar arithmetic operations.

    Maps MXNet's *_scalar operators onto the ONNX element-wise operator
    ``op_name`` (Mul/Sub/Add/Div).  If the non-scalar input is already a
    graph initializer (a constant), the operation is constant-folded into a
    new initializer instead of emitting an ONNX compute node.
    """
    helper, numpy_helper, mapping = import_onnx_modules()
    name = node["name"]
    proc_nodes = kwargs["proc_nodes"]
    inputs = node["inputs"]
    # MXNet stores the scalar operand as a string attribute; default 1.
    scalar_value = [float(node.get("attrs", {}).get("scalar", 1))]

    input_name_id = kwargs["index_lookup"][inputs[0][0]]
    input_node = proc_nodes[input_name_id].name

    initializer = kwargs["initializer"]
    flag = True  # True -> input is not a constant; emit a real ONNX node
    # If the input value is in initializer, just multiply with scalar input
    # and create a new initializer
    for i in initializer:
        if i.name == input_node:
            if op_name == 'Mul':
                new_initializer = numpy_helper.to_array(i) * scalar_value[0]
            elif op_name == 'Sub':
                new_initializer = numpy_helper.to_array(i) - scalar_value[0]
            elif op_name == 'Add':
                new_initializer = numpy_helper.to_array(i) + scalar_value[0]
            elif op_name == 'Div':
                new_initializer = numpy_helper.to_array(i) / scalar_value[0]
            flag = False
            break

    # else create a new tensor of the scalar value, add it in initializer
    if flag is True:
        np_arr = np.array(scalar_value)
        data_type = mapping.NP_TYPE_TO_TENSOR_TYPE[np_arr.dtype]
        dims = np.shape(np_arr)

        # Unique name for the scalar constant, keyed by the node index.
        scalar_op_name = "scalar_op" + str(kwargs["idx"])
        tensor_node = helper.make_tensor_value_info(scalar_op_name, data_type, dims)

        initializer.append(
            helper.make_tensor(
                name=scalar_op_name,
                data_type=data_type,
                dims=dims,
                vals=scalar_value,
                raw=False,
            )
        )

        mul_node = helper.make_node(
            op_name,
            [input_node, scalar_op_name],
            [name],
            name=name
        )

        return [tensor_node, mul_node]
    else:
        data_type = mapping.NP_TYPE_TO_TENSOR_TYPE[new_initializer.dtype]
        dims = np.shape(new_initializer)

        # NOTE(review): the folded constant is renamed input_node + idx and only
        # a value_info is returned — no node/output named ``name`` is produced.
        # Presumably the caller's index_lookup remaps consumers to this new
        # name; confirm against export_onnx.MXNetGraph.
        new_a_node = input_node + str(kwargs["idx"])
        tensor_node = helper.make_tensor_value_info(new_a_node, data_type, dims)

        initializer.append(
            helper.make_tensor(
                name=new_a_node,
                data_type=data_type,
                dims=dims,
                vals=new_initializer,
                raw=False,
            )
        )
        return [tensor_node]
+
+# Convert scalar value into node and pass it as input to mul_node
+@mx_op.register("_mul_scalar")
+def convert_mul_scalar(node, **kwargs):
+ """Map MXNet's _mul_scalar operator attributes to onnx's Mul operator.
+ Creates a new node for the input scalar value, adds it to the initializer
+ and return multiple created nodes.
+ """
+ return scalar_op_helper(node, 'Mul', **kwargs)
+
+
+# Convert scalar value into node and pass it as input to mul_node
+@mx_op.register("_minus_scalar")
+def convert_minus_scalar(node, **kwargs):
+ """Map MXNet's _minus_scalar operator attributes to onnx's Minus operator.
+ Creates a new node for the input scalar value, adds it to the initializer
+ and return multiple created nodes.
+ """
+ return scalar_op_helper(node, 'Sub', **kwargs)
+
+
+# Convert scalar value into node and pass it as input to mul_node
+@mx_op.register("_plus_scalar")
+def convert_add_scalar(node, **kwargs):
+ """Map MXNet's _plus_scalar operator attributes to onnx's Add operator.
+ Creates a new node for the input scalar value, adds it to the initializer
+ and return multiple created nodes.
+ """
+ return scalar_op_helper(node, 'Add', **kwargs)
+
+# Convert scalar value into node and pass it as input to mul_node
+@mx_op.register("_div_scalar")
+def convert_div_scalar(node, **kwargs):
+ """Map MXNet's _div_scalar operator attributes to onnx's Div operator.
+ Creates a new node for the input scalar value, adds it to the initializer
+ and return multiple created nodes.
+ """
+ return scalar_op_helper(node, 'Div', **kwargs)
+
+
+# Sorting and Searching
+@mx_op.register("argmax")
+def convert_argmax(node, **kwargs):
+ """Map MXNet's argmax operator attributes to onnx's ArgMax operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ proc_nodes = kwargs["proc_nodes"]
+ node_inputs = node["inputs"]
+
+ input_node_idx = kwargs["index_lookup"][node_inputs[0][0]]
+ input_node = proc_nodes[input_node_idx].name
+ name = node["name"]
+ attrs = node["attrs"]
+
+ axis = int(attrs.get("axis"))
+ keepdims = int(attrs.get("keepdims")) if "keepdims" in attrs else 1
+
+ node = helper.make_node(
+ 'ArgMax',
+ inputs=[input_node],
+ axis=axis,
+ keepdims=keepdims,
+ outputs=[name],
+ name=name
+ )
+ return [node]
+
+@mx_op.register("argmin")
+def convert_argmin(node, **kwargs):
+ """Map MXNet's argmin operator attributes to onnx's ArgMin operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ proc_nodes = kwargs["proc_nodes"]
+ node_inputs = node["inputs"]
+
+ input_node_idx = kwargs["index_lookup"][node_inputs[0][0]]
+ input_node = proc_nodes[input_node_idx].name
+ name = node["name"]
+ attrs = node["attrs"]
+
+ axis = int(attrs.get("axis"))
+ keepdims = int(attrs.get("keepdims")) if "keepdims" in attrs else 1
+
+ node = helper.make_node(
+ 'ArgMin',
+ inputs=[input_node],
+ axis=axis,
+ keepdims=keepdims,
+ outputs=[name],
+ name=name
+ )
+ return [node]
+
+@mx_op.register("_maximum")
+def convert_maximum(node, **kwargs):
+ """Map MXNet's _maximum operator attributes to onnx's Max operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ proc_nodes = kwargs["proc_nodes"]
+ node_inputs = node["inputs"]
+
+ input_node_list = []
+ for node_input in node_inputs:
+ node_id = kwargs["index_lookup"][node_input[0]]
+ input_node_list.append(proc_nodes[node_id].name)
+
+ name = node["name"]
+
+ node = helper.make_node(
+ 'Max',
+ inputs=input_node_list,
+ outputs=[name],
+ name=name,
+ )
+
+ return [node]
+
+
+@mx_op.register("_minimum")
+def convert_minimum(node, **kwargs):
+ """Map MXNet's _minimum operator attributes to onnx's Min operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ proc_nodes = kwargs["proc_nodes"]
+ node_inputs = node["inputs"]
+
+ input_node_list = []
+ for node_input in node_inputs:
+ node_id = kwargs["index_lookup"][node_input[0]]
+ input_node_list.append(proc_nodes[node_id].name)
+
+ name = node["name"]
+
+ node = helper.make_node(
+ 'Min',
+ inputs=input_node_list,
+ outputs=[name],
+ name=name,
+ )
+
+ return [node]
+
+
+@mx_op.register("min")
+def convert_min(node, **kwargs):
+ """Map MXNet's min operator attributes to onnx's ReduceMin operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+
+ mx_axis = node.get("attrs", {}).get("axis", None)
+ axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else
None
+
+ keepdims = int(node.get("attrs", {}).get("keepdims", 0))
+
+ input_node_id = kwargs["index_lookup"][inputs[0][0]]
+ input_node = proc_nodes[input_node_id].name
+
+ if axes is not None:
+ node = helper.make_node(
+ 'ReduceMin',
+ inputs=[input_node],
+ outputs=[name],
+ axes=axes,
+ keepdims=keepdims,
+ name=name
+ )
+
+ return [node]
+ else:
+ node = helper.make_node(
+ 'ReduceMin',
+ inputs=[input_node],
+ outputs=[name],
+ keepdims=keepdims,
+ name=name
+ )
+
+ return [node]
+
+
+@mx_op.register("max")
+def convert_max(node, **kwargs):
+ """Map MXNet's max operator attributes to onnx's ReduceMax operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+
+ mx_axis = node.get("attrs", {}).get("axis", None)
+ axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else
None
+
+ keepdims = int(node.get("attrs", {}).get("keepdims", 0))
+
+ input_node_id = kwargs["index_lookup"][inputs[0][0]]
+ input_node = proc_nodes[input_node_id].name
+
+ if axes is not None:
+ node = helper.make_node(
+ 'ReduceMax',
+ inputs=[input_node],
+ outputs=[name],
+ axes=axes,
+ keepdims=keepdims,
+ name=name
+ )
+
+ return [node]
+ else:
+ node = helper.make_node(
+ 'ReduceMax',
+ inputs=[input_node],
+ outputs=[name],
+ keepdims=keepdims,
+ name=name
+ )
+
+ return [node]
+
+
+@mx_op.register("mean")
+def convert_mean(node, **kwargs):
+ """Map MXNet's mean operator attributes to onnx's ReduceMean operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+
+ mx_axis = node.get("attrs", {}).get("axis", None)
+ axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else
None
+
+ keepdims = int(node.get("attrs", {}).get("keepdims", 0))
+
+ input_node_id = kwargs["index_lookup"][inputs[0][0]]
+ input_node = proc_nodes[input_node_id].name
+
+ if axes is not None:
+ node = helper.make_node(
+ 'ReduceMean',
+ inputs=[input_node],
+ outputs=[name],
+ axes=axes,
+ keepdims=keepdims,
+ name=name
+ )
+
+ return [node]
+ else:
+ node = helper.make_node(
+ 'ReduceMean',
+ inputs=[input_node],
+ outputs=[name],
+ keepdims=keepdims,
+ name=name
+ )
+
+ return [node]
+
+
+@mx_op.register("prod")
+def convert_prod(node, **kwargs):
+ """Map MXNet's prod operator attributes to onnx's ReduceProd operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+
+ mx_axis = node.get("attrs", {}).get("axis", None)
+ axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else
None
+
+ keepdims = int(node.get("attrs", {}).get("keepdims", 0))
+
+ input_node_id = kwargs["index_lookup"][inputs[0][0]]
+ input_node = proc_nodes[input_node_id].name
+
+ if axes is not None:
+ node = helper.make_node(
+ 'ReduceProd',
+ inputs=[input_node],
+ outputs=[name],
+ axes=axes,
+ keepdims=keepdims,
+ name=name
+ )
+
+ return [node]
+ else:
+ node = helper.make_node(
+ 'ReduceProd',
+ inputs=[input_node],
+ outputs=[name],
+ keepdims=keepdims,
+ name=name
+ )
+
+ return [node]
+
+
+# Arithmetic Operations
+@mx_op.register("elemwise_add")
+def convert_elementwise_add(node, **kwargs):
+ """Map MXNet's elemwise_add operator attributes to onnx's Add operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+
+ input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
+ input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
+
+ input_node_a = proc_nodes[input_node_a_id].name
+ input_node_b = proc_nodes[input_node_b_id].name
+
+ add_node = helper.make_node(
+ "Add",
+ [input_node_a, input_node_b],
+ [name],
+ name=name,
+ )
+
+ return [add_node]
+
+
+@mx_op.register("broadcast_add")
+def covert_broadcast_add(node, **kwargs):
+ """Map MXNet's broadcast_add operator attributes to onnx's Add operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+
+ input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
+ input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
+
+ input_node_a = proc_nodes[input_node_a_id].name
+ input_node_b = proc_nodes[input_node_b_id].name
+
+ add_node = helper.make_node(
+ "Add",
+ [input_node_a, input_node_b],
+ [name],
+ name=name,
+ )
+
+ return [add_node]
+
+
+@mx_op.register("elemwise_sub")
+def convert_elementwise_sub(node, **kwargs):
+ """Map MXNet's elemwise_sub operator attributes to onnx's Sub operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+
+ input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
+ input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
+
+ input_node_a = proc_nodes[input_node_a_id].name
+ input_node_b = proc_nodes[input_node_b_id].name
+
+ sub_node = helper.make_node(
+ "Sub",
+ [input_node_a, input_node_b],
+ [name],
+ name=name,
+ )
+
+ return [sub_node]
+
+@mx_op.register("broadcast_sub")
+def covert_broadcast_sub(node, **kwargs):
+ """Map MXNet's broadcast_sub operator attributes to onnx's Sub operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+
+ input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
+ input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
+
+ input_node_a = proc_nodes[input_node_a_id].name
+ input_node_b = proc_nodes[input_node_b_id].name
+
+ sub_node = helper.make_node(
+ "Sub",
+ [input_node_a, input_node_b],
+ [name],
+ name=name,
+ )
+
+ return [sub_node]
+
+
+@mx_op.register("elemwise_mul")
+def convert_elemwise_mul(node, **kwargs):
+ """Map MXNet's elemwise_mul operator attributes to onnx's Mul operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+
+ input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
+ input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
+
+ input_node_a = proc_nodes[input_node_a_id].name
+ input_node_b = proc_nodes[input_node_b_id].name
+
+ mul_node = helper.make_node(
+ "Mul",
+ [input_node_a, input_node_b],
+ [name],
+ name=name,
+ )
+
+ return [mul_node]
+
+@mx_op.register("broadcast_mul")
+def convert_broadcast_mul(node, **kwargs):
+ """Map MXNet's broadcast_mul operator attributes to onnx's Mul operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+
+ input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
+ input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
+
+ input_node_a = proc_nodes[input_node_a_id].name
+ input_node_b = proc_nodes[input_node_b_id].name
+
+ mul_node = helper.make_node(
+ "Mul",
+ [input_node_a, input_node_b],
+ [name],
+ name=name
+ )
+
+ return [mul_node]
+
+
+@mx_op.register("elemwise_div")
+def convert_elemwise_div(node, **kwargs):
+ """Map MXNet's elemwise_div operator attributes to onnx's Div operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+
+ input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
+ input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
+
+ input_node_a = proc_nodes[input_node_a_id].name
+ input_node_b = proc_nodes[input_node_b_id].name
+
+ div_node = helper.make_node(
+ "Div",
+ [input_node_a, input_node_b],
+ [name],
+ name=name
+ )
+
+ return [div_node]
+
+
+@mx_op.register("broadcast_div")
+def convert_broadcast_div(node, **kwargs):
+ """Map MXNet's broadcast_div operator attributes to onnx's Div operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+
+ input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
+ input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
+
+ input_node_a = proc_nodes[input_node_a_id].name
+ input_node_b = proc_nodes[input_node_b_id].name
+
+ div_node = helper.make_node(
+ "Div",
+ [input_node_a, input_node_b],
+ [name],
+ name=name
+ )
+
+ return [div_node]
+
+
+@mx_op.register("negative")
+def convert_negative(node, **kwargs):
+ """Map MXNet's negative operator attributes to onnx's Neg operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+
+ input_node_id = kwargs["index_lookup"][inputs[0][0]]
+
+ input_node = proc_nodes[input_node_id].name
+
+ neg_node = helper.make_node(
+ "Neg",
+ [input_node],
+ [name],
+ name=name,
+ )
+
+ return [neg_node]
+
+
+@mx_op.register("abs")
+def convert_abs(node, **kwargs):
+ """Map MXNet's abs operator attributes to onnx's Abs operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+
+ input_node_id = kwargs["index_lookup"][inputs[0][0]]
+
+ input_node = proc_nodes[input_node_id].name
+
+ abs_node = helper.make_node(
+ "Abs",
+ [input_node],
+ [name],
+ name=name
+ )
+
+ return [abs_node]
+
+
+@mx_op.register("add_n")
+def convert_addn(node, **kwargs):
+ """Map MXNet's add_n operator attributes to onnx's Sum operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+
+ input_list = []
+ for input_val in inputs:
+
input_list.append(proc_nodes[kwargs["index_lookup"][input_val[0]]].name)
+
+ sum_node = helper.make_node(
+ "Sum",
+ input_list,
+ [name],
+ name=name
+ )
+ return [sum_node]
+
+ # Rounding
+@mx_op.register("ceil")
+def convert_ceil(node, **kwargs):
+ """Map MXNet's ceil operator attributes to onnx's Ceil operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+
+ input_node_id = kwargs["index_lookup"][inputs[0][0]]
+ input_node = proc_nodes[input_node_id].name
+
+ node = helper.make_node(
+ "Ceil",
+ [input_node],
+ [name],
+ name=name
+ )
+ return [node]
+
+@mx_op.register("floor")
+def convert_floor(node, **kwargs):
+ """Map MXNet's floor operator attributes to onnx's Floor operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+
+ input_node_id = kwargs["index_lookup"][inputs[0][0]]
+ input_node = proc_nodes[input_node_id].name
+
+ node = helper.make_node(
+ "Floor",
+ [input_node],
+ [name],
+ name=name
+ )
+ return [node]
+
+# Changing shape and type.
+@mx_op.register("Reshape")
+def convert_reshape(node, **kwargs):
+ """Map MXNet's Reshape operator attributes to onnx's Reshape operator.
+ Converts output shape attribute to output shape tensor
+ and return multiple created nodes.
+ """
+ helper, _, mapping = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+ attrs = node["attrs"]
+
+ output_shape_list = convert_string_to_list(attrs["shape"])
+
+ initializer = kwargs["initializer"]
+ output_shape_np = np.array(output_shape_list)
+ data_type = mapping.NP_TYPE_TO_TENSOR_TYPE[output_shape_np.dtype]
+ dims = np.shape(output_shape_np)
+
+ output_shape_name = "reshape_attr_tensor" + str(kwargs["idx"])
+ tensor_node = helper.make_tensor_value_info(output_shape_name, data_type,
dims)
+
+ initializer.append(
+ helper.make_tensor(
+ name=output_shape_name,
+ data_type=data_type,
+ dims=dims,
+ vals=output_shape_list,
+ raw=False,
+ )
+ )
+
+ input_node_idx = kwargs["index_lookup"][inputs[0][0]]
+ input_node_name = proc_nodes[input_node_idx].name
+
+ not_supported_shape = [-2, -3, -4]
+
+ for val in output_shape_list:
+ if val in not_supported_shape:
+ raise AttributeError("Shape value not supported in ONNX", val)
+
+ reshape_node = helper.make_node(
+ "Reshape",
+ [input_node_name, output_shape_name],
+ [name],
+ name=name
+ )
+
+ return [tensor_node, reshape_node]
+
+@mx_op.register("Cast")
+def convert_cast(node, **kwargs):
+ """Map MXNet's Cast operator attributes to onnx's Cast operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+ dtype = node["attrs"]["dtype"]
+
+ input_node_id = kwargs["index_lookup"][inputs[0][0]]
+ input_node = proc_nodes[input_node_id].name
+
+ node = helper.make_node(
+ "Cast",
+ [input_node],
+ [name],
+ to=dtype,
+ name=name,
+ )
+ return [node]
+
+
+@mx_op.register("slice_axis")
+def convert_slice_axis(node, **kwargs):
+ """Map MXNet's slice_axis operator attributes to onnx's Slice operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+ axes = int(node["attrs"]["axis"])
+ starts = int(node["attrs"]["begin"])
+ if node["attrs"]["end"] == 'None':
+ raise ValueError("Slice: ONNX doesnt't support 'None' in 'end'
attribute")
+ else:
+ ends = int(node["attrs"]["end"])
+
+ input_node_id = kwargs["index_lookup"][inputs[0][0]]
+ input_node = proc_nodes[input_node_id].name
+
+ node = helper.make_node(
+ "Slice",
+ [input_node],
+ [name],
+ axes=[axes],
+ starts=[starts],
+ ends=[ends],
+ name=name,
+ )
+ return [node]
+
+
+@mx_op.register("SliceChannel")
+def convert_slice_channel(node, **kwargs):
+ """Map MXNet's SliceChannel operator attributes to onnx's Squeeze or Split
+ operator based on squeeze_axis attribute
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+ num_outputs = int(node.get("attrs", {})["num_outputs"])
+ axis = int(node.get("attrs", {}).get("axis", 1))
+ squeeze_axis = int(node.get("attrs", {}).get("squeeze_axis", 0))
+
+ input_node_id = kwargs["index_lookup"][inputs[0][0]]
+ input_node = proc_nodes[input_node_id].name
+
+ if squeeze_axis == 1 and num_outputs == 1:
+ node = helper.make_node(
+ "Squeeze",
+ [input_node],
+ [name],
+ axes=[axis],
+ name=name,
+ )
+ return [node]
+ elif squeeze_axis == 0 and num_outputs > 1:
+ node = helper.make_node(
+ "Split",
+ [input_node],
+ [name],
+ axis=axis,
+ split=[num_outputs],
+ name=name,
+ )
+ return [node]
+ else:
+ raise NotImplementedError("SliceChannel operator with num_outputs>1
and"
+ "squeeze_axis true is not implemented.")
+
+
+@mx_op.register("expand_dims")
+def convert_expand_dims(node, **kwargs):
+ """Map MXNet's expand_dims operator attributes to onnx's Unsqueeze operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+ axis = int(node["attrs"]["axis"])
+
+ input_node_id = kwargs["index_lookup"][inputs[0][0]]
+ input_node = proc_nodes[input_node_id].name
+
+ node = helper.make_node(
+ "Unsqueeze",
+ [input_node],
+ [name],
+ axes=[axis],
+ name=name,
+ )
+ return [node]
+
+@mx_op.register("squeeze")
+def convert_squeeze(node, **kwargs):
+ """Map MXNet's squeeze operator attributes to onnx's squeeze operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+ if "axis" in node["attrs"]:
+ axis = convert_string_to_list(node["attrs"]["axis"])
+ else:
+ raise AttributeError("Missing axis attribute: ONNX currently requires
axis to "
+ "be specified for squeeze operator")
+
+ input_node_id = kwargs["index_lookup"][inputs[0][0]]
+ input_node = proc_nodes[input_node_id].name
+
+ node = helper.make_node(
+ "Squeeze",
+ [input_node],
+ [name],
+ axes=axis,
+ name=name,
+ )
+ return [node]
+
+
+@mx_op.register("log")
+def convert_log(node, **kwargs):
+ """Map MXNet's log operator attributes to onnx's Log operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+
+ input_node_id = kwargs["index_lookup"][inputs[0][0]]
+ input_node = proc_nodes[input_node_id].name
+
+ node = helper.make_node(
+ "Log",
+ [input_node],
+ [name],
+ name=name,
+ )
+ return [node]
+
+
+@mx_op.register("reciprocal")
+def convert_reciprocal(node, **kwargs):
+ """Map MXNet's reciprocal operator attributes to onnx's Reciprocal operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+
+ input_node_id = kwargs["index_lookup"][inputs[0][0]]
+ input_node = proc_nodes[input_node_id].name
+
+ node = helper.make_node(
+ "Reciprocal",
+ [input_node],
+ [name],
+ name=name,
+ )
+ return [node]
+
+
+@mx_op.register("_power")
+def convert_power(node, **kwargs):
+ """Map MXNet's _power operator attributes to onnx's Pow operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+
+ input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
+ input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
+
+ input_node_a = proc_nodes[input_node_a_id].name
+ input_node_b = proc_nodes[input_node_b_id].name
+
+ node = helper.make_node(
+ "Pow",
+ [input_node_a, input_node_b],
+ [name],
+ name=None
+ )
+ return [node]
+
+@mx_op.register("sqrt")
+def convert_sqrt(node, **kwargs):
+ """Map MXNet's sqrt operator attributes to onnx's Sqrt operator
+ and return the created node.
+ """
+ helper, _, _ = import_onnx_modules()
+ name = node["name"]
+ proc_nodes = kwargs["proc_nodes"]
+ inputs = node["inputs"]
+
+ input_node_id = kwargs["index_lookup"][inputs[0][0]]
+ input_node = proc_nodes[input_node_id].name
+
+ node = helper.make_node(
+ "Sqrt",
+ [input_node],
+ [name],
+ name=name,
+ )
+ return [node]
diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_model.py
b/python/mxnet/contrib/onnx/mx2onnx/export_model.py
new file mode 100644
index 0000000..0dbfdc1
--- /dev/null
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_model.py
@@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+#pylint: disable-msg=too-many-arguments
+
+"""export function"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+import logging
+import numpy as np
+
+from ....base import string_types
+from .... import symbol
+from .export_onnx import MXNetGraph
+from ._export_helper import load_module
+
+
+def export_model(sym, params, input_shape, input_type=np.float32,
+ onnx_file_path='model.onnx', verbose=False):
+ """Exports the MXNet model file, passed as a parameter, into ONNX model.
+ Accepts both symbol,parameter objects as well as json and params filepaths
as input.
+ Operator support and coverage -
https://cwiki.apache.org/confluence/display/MXNET/ONNX
+
+ Parameters
+ ----------
+ sym : str or symbol object
+ Path to the json file or Symbol object
+ params : str or symbol object
+ Path to the params file or params dictionary. (Including both
arg_params and aux_params)
+ input_shape : List of tuple
+ Input shape of the model e.g [(1,3,224,224)]
+ input_type : data type
+ Input data type e.g. np.float32
+ onnx_file_path : str
+ Path where to save the generated onnx file
+ verbose : Boolean
+ If true will print logs of the model conversion
+
+ Returns
+ -------
+ onnx_file_path : str
+ Onnx file path
+ """
+
+ try:
+ from onnx import helper, mapping
+ except ImportError:
+ raise ImportError("Onnx and protobuf need to be installed. "
+ + "Instructions to install -
https://github.com/onnx/onnx")
+
+ converter = MXNetGraph()
+
+ data_format = np.dtype(input_type)
+ # if input parameters are strings(file paths), load files and create
symbol parameter objects
+ if isinstance(sym, string_types) and isinstance(params, string_types):
+ logging.info("Converting json and weight file to sym and params")
+ sym_obj, params_obj = load_module(sym, params)
+ onnx_graph = converter.create_onnx_graph_proto(sym_obj, params_obj,
input_shape,
+
mapping.NP_TYPE_TO_TENSOR_TYPE[data_format],
+ verbose=verbose)
+ elif isinstance(sym, symbol.Symbol) and isinstance(params, dict):
+ onnx_graph = converter.create_onnx_graph_proto(sym, params,
input_shape,
+
mapping.NP_TYPE_TO_TENSOR_TYPE[data_format],
+ verbose=verbose)
+ else:
+ raise ValueError("Input sym and params should either be files or
objects")
+
+ # Create the model (ModelProto)
+ onnx_model = helper.make_model(onnx_graph)
+
+ # Save model on disk
+ with open(onnx_file_path, "wb") as file_handle:
+ serialized = onnx_model.SerializeToString()
+ file_handle.write(serialized)
+ logging.info("Input shape of the model %s ", input_shape)
+ logging.info("Exported ONNX file %s saved to disk", onnx_file_path)
+
+ return onnx_file_path
diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
new file mode 100644
index 0000000..1184738
--- /dev/null
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
@@ -0,0 +1,347 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# Based on
+#
https://github.com/NVIDIA/mxnet_to_onnx/blob/master/mx2onnx_converter/mx2onnx_converter.py#
+# Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+# * Neither the name of NVIDIA CORPORATION nor the names of its
+# contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# coding: utf-8
+# pylint: disable=invalid-name,too-many-locals,no-self-use,too-many-arguments,
+# pylint: disable=maybe-no-member,too-many-nested-blocks
+"""MXNet to ONNX graph converter functions"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+import logging
+import json
+import numpy as np
+
+from .... import context
+from .... import ndarray as nd
+from .... import io
+from .... import module as mod
+
+
+class MXNetGraph(object):
+ """Class to convert MXNet to ONNX graph"""
+ registry_ = {}
+ input_output_maps_ = {}
+
+ def __init__(self):
+ # topologically sorted nodes
+ self.nodes = []
+ self.input_tensors = []
+ self.output_tensors = []
+
+ @staticmethod
+ def register(op_name):
+ """Register operators"""
+ def wrapper(func):
+ """Helper function to map functions"""
+ MXNetGraph.registry_[op_name] = func
+ return func
+
+ return wrapper
+
+ @staticmethod
+ def convert_layer(node, **kwargs):
+ """Convert MXNet layer to ONNX"""
+ op = str(node["op"])
+ if op not in MXNetGraph.registry_:
+ raise AttributeError("No conversion function registered for op
type %s yet." % op)
+ convert_func = MXNetGraph.registry_[op]
+ return convert_func(node, **kwargs)
+
+ @staticmethod
+ def forward_pass(inputs, sym, arg_params, aux_params, output_label):
+ """Do a forward pass based on the sym and params to get the shape
+ of the output using dummy data
+
+ Parameters
+ ----------
+ inputs : json string
+
+ sym : :class:`~mxnet.symbol.Symbol`
+ MXNet symbol object
+ arg_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
+ Dict of converted parameters stored in ``mxnet.ndarray.NDArray``
format
+ aux_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
+ Dict of converted parameters stored in ``mxnet.ndarray.NDArray``
format
+
+ Returns
+ -------
+ shape : Shape
+ Output shape
+ """
+ # if label is not provided, MXNet adds label "softmax_label" by default
+ # while running load_checkpoint which is not actually a graph input.
So ignoring it here
+ data_names = [graph_input for graph_input in sym.list_inputs()
+ if graph_input not in arg_params and graph_input not in
aux_params
+ and graph_input != output_label]
+
+ data_shapes = []
+ # Adding extra dimension of batch_size 1 if the batch_size is
different for multiple inputs.
+ for idx, input_name in enumerate(data_names):
+ data_shapes.append((input_name, inputs[idx].shape))
+
+ # create module, passing cpu context
+ ctx = context.cpu()
+ test_mod = mod.Module(symbol=sym, data_names=data_names, context=ctx,
label_names=None)
+ test_mod.bind(for_training=False, data_shapes=data_shapes,
label_shapes=None)
+
+ # initializing parameters for calculating result of each individual
node
+ if arg_params is None and aux_params is None:
+ test_mod.init_params()
+ else:
+ test_mod.set_params(arg_params=arg_params, aux_params=aux_params,
allow_missing=True)
+
+ data_forward = []
+ for idx, input_name in enumerate(data_names):
+ val = inputs[idx]
+ data_forward.append(nd.array(val))
+
+ test_mod.forward(io.DataBatch(data_forward))
+ result = test_mod.get_outputs()[0].asnumpy()
+
+ return result.shape
+
+
+ @staticmethod
+ def split_params(sym, params):
+ """Helper function to split params dictionary into args and aux params
+
+ Parameters
+ ----------
+ sym : :class:`~mxnet.symbol.Symbol`
+ MXNet symbol object
+ params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
+ Dict of converted parameters stored in ``mxnet.ndarray.NDArray``
format
+
+ Returns
+ -------
+ arg_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
+ Dict of converted parameters stored in ``mxnet.ndarray.NDArray``
format
+ aux_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
+ Dict of converted parameters stored in ``mxnet.ndarray.NDArray``
format
+ """
+ arg_params = {}
+ aux_params = {}
+ for args in sym.list_arguments():
+ if args in params:
+ arg_params.update({args: nd.array(params[args])})
+ for aux in sym.list_auxiliary_states():
+ if aux in params:
+ aux_params.update({aux: nd.array(params[aux])})
+ return arg_params, aux_params
+
+
+ @staticmethod
+ def infer_output_shape(sym, params, in_shape, output_label):
+ """Infer output shape by doing a forward pass using dummy inputs """
+ # create dummy input
+ inputs = [np.random.randn(*input_shape) for input_shape in in_shape]
+ arg, aux = MXNetGraph.split_params(sym, params)
+ return MXNetGraph.forward_pass(inputs, sym, arg, aux, output_label)
+
+
+ @staticmethod
+ def convert_weights_to_numpy(weights_dict):
+ """Convert weights to numpy"""
+ return dict([(k.replace("arg:", "").replace("aux:", ""), v.asnumpy())
+ for k, v in weights_dict.items()])
+
+ def create_onnx_graph_proto(self, sym, params, in_shape, in_type,
verbose=False):
+ """Convert MXNet graph to ONNX graph
+
+ Parameters
+ ----------
+ sym : :class:`~mxnet.symbol.Symbol`
+ MXNet symbol object
+ params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
+ Dict of converted parameters stored in ``mxnet.ndarray.NDArray``
format
+ in_shape : List of tuple
+ Input shape of the model e.g [(1,3,224,224)]
+ in_type : data type
+ Input data type e.g. np.float32
+ verbose : Boolean
+ If true will print logs of the model conversion
+
+ Returns
+ -------
+ graph : GraphProto
+ ONNX graph
+ """
+ try:
+ from onnx import (checker, helper, NodeProto, ValueInfoProto,
TensorProto)
+ from onnx.helper import make_tensor_value_info
+ except ImportError:
+ raise ImportError("Onnx and protobuf need to be installed. "
+ + "Instructions to install -
https://github.com/onnx/onnx")
+
+ # When MXNet model is saved to json file , MXNet adds a node for label.
+ # The name of this node is, name of the last node + "_label" ( i.e if
last node
+ # name is "Softmax", this node will have a name "Softmax_label". Also,
the new node
+ # will always be second last node in the json graph.
+ # Deriving the output_label name.
+ output_label = sym.get_internals()[len(sym.get_internals()) - 1].name
+ "_label"
+
+ # Determine output shape
+ output_shape = MXNetGraph.infer_output_shape(sym, params, in_shape,
output_label)
+
+ weights = MXNetGraph.convert_weights_to_numpy(params)
+
+ mx_graph = json.loads(sym.tojson())["nodes"]
+
+ initializer = []
+ all_processed_nodes = []
+ onnx_processed_nodes = []
+ onnx_processed_inputs = []
+ onnx_processed_outputs = []
+ index_lookup = []
+
+ graph_input_idx = 0
+ for idx, node in enumerate(mx_graph):
+ op = node["op"]
+ name = node["name"]
+ if verbose:
+ logging.info("Converting idx: %d, op: %s, name: %s", idx, op,
name)
+
+ # A node is an input node if its op_name is "null" and is not
+ # in params dict
+ if op == "null" and name not in params:
+ # Handling graph input
+
+ # Skipping output_label node, as this node is not part of graph
+ # Refer "output_label" assignment above for more details.
+ if name == output_label:
+ continue
+ converted = MXNetGraph.convert_layer(
+ node,
+ is_input=True,
+ mx_graph=mx_graph,
+ weights=weights,
+ in_shape=in_shape[graph_input_idx],
+ in_type=in_type,
+ proc_nodes=all_processed_nodes,
+ initializer=initializer,
+ index_lookup=index_lookup)
+ graph_input_idx += 1
+
+ else:
+ # Handling graph layers
+ converted = MXNetGraph.convert_layer(
+ node,
+ is_input=False,
+ mx_graph=mx_graph,
+ weights=weights,
+ in_shape=in_shape,
+ in_type=in_type,
+ proc_nodes=all_processed_nodes,
+ initializer=initializer,
+ index_lookup=index_lookup,
+ idx=idx
+ )
+
+ if isinstance(converted, list):
+ # Iterate for all converted nodes
+ for converted_node in converted:
+ # If converted node is ValueInfoProto, add it in inputs
+ if isinstance(converted_node, ValueInfoProto):
+ onnx_processed_inputs.append(converted_node)
+ # If converted node is NodeProto, add it in processed
nodes list
+ elif isinstance(converted_node, NodeProto):
+ onnx_processed_nodes.append(converted_node)
+ if idx == (len(mx_graph) - 1):
+ # If converted node doesnt have name, use it from
output field
+ if not converted_node.name:
+ onnx_processed_outputs.append(
+ make_tensor_value_info(
+ name=converted_node.output[0],
+ elem_type=in_type,
+ shape=output_shape
+ )
+ )
+ else:
+ onnx_processed_outputs.append(
+ make_tensor_value_info(
+ name=converted_node.name,
+ elem_type=in_type,
+ shape=output_shape
+ )
+ )
+ if verbose:
+ logging.info("Output node is: %s",
converted_node.name)
+ elif isinstance(converted_node, TensorProto):
+ raise ValueError("Did not expect TensorProto")
+ else:
+ raise ValueError("node is of an unrecognized type: %s"
% type(node))
+
+ all_processed_nodes.append(converted_node)
+
+ if idx > 0:
+ # Handling extra node added to the graph if the MXNet
model was
+ # saved to json file,
+ # refer "output_label" initialization above for more
details.
+ # if extra node was added then prev_index to the last node
is adjusted.
+ if idx == (len(mx_graph) - 1) and \
+ mx_graph[len(mx_graph)-2]["name"] == output_label:
+ prev_index = index_lookup[idx - 2]
+ else:
+ prev_index = index_lookup[idx - 1]
+
+ index_lookup.append(prev_index+len(converted))
+ else:
+ index_lookup.append(len(converted) - 1)
+ else:
+ logging.info("Operator converter function should always return
a list")
+
+ graph = helper.make_graph(
+ onnx_processed_nodes,
+ "mxnet_converted_model",
+ onnx_processed_inputs,
+ onnx_processed_outputs
+ )
+
+ graph.initializer.extend(initializer)
+
+ checker.check_graph(graph)
+ return graph
diff --git a/python/mxnet/contrib/onnx/_import/__init__.py
b/python/mxnet/contrib/onnx/onnx2mx/__init__.py
similarity index 100%
rename from python/mxnet/contrib/onnx/_import/__init__.py
rename to python/mxnet/contrib/onnx/onnx2mx/__init__.py
diff --git a/python/mxnet/contrib/onnx/_import/import_helper.py
b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
similarity index 74%
rename from python/mxnet/contrib/onnx/_import/import_helper.py
rename to python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
index 3dfff3e..c19f0f2 100644
--- a/python/mxnet/contrib/onnx/_import/import_helper.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
@@ -15,27 +15,27 @@
# specific language governing permissions and limitations
# under the License.
-# coding: utf-8
+# coding: utf-8
# pylint: disable=invalid-name
"""Operator attributes conversion"""
-from .op_translations import identity, random_uniform, random_normal
-from .op_translations import add, subtract, multiply, divide, absolute,
negative, add_n
-from .op_translations import tanh
-from .op_translations import ceil, floor
-from .op_translations import concat
-from .op_translations import leaky_relu, _elu, _prelu, softmax, fully_connected
-from .op_translations import global_avgpooling, global_maxpooling, linalg_gemm
-from .op_translations import sigmoid, pad, relu, matrix_multiplication,
batch_norm
-from .op_translations import dropout, local_response_norm, conv, deconv
-from .op_translations import reshape, cast, split, _slice, transpose, squeeze,
flatten
-from .op_translations import reciprocal, squareroot, power, exponent, _log,
unsqueeze
-from .op_translations import reduce_max, reduce_mean, reduce_min, reduce_sum
-from .op_translations import reduce_prod, avg_pooling, max_pooling
-from .op_translations import argmax, argmin, maximum, minimum
-from .op_translations import clip, reduce_log_sum, reduce_log_sum_exp
-from .op_translations import reduce_sum_square, reduce_l2, max_roi_pooling,
instance_norm
-from .op_translations import log_softmax, softsign, lesser, greater, equal
-from .op_translations import logical_and, logical_or, logical_xor, logical_not
+from ._op_translations import identity, random_uniform, random_normal
+from ._op_translations import add, subtract, multiply, divide, absolute,
negative, add_n
+from ._op_translations import tanh
+from ._op_translations import ceil, floor
+from ._op_translations import concat
+from ._op_translations import leaky_relu, _elu, _prelu, softmax,
fully_connected
+from ._op_translations import global_avgpooling, global_maxpooling, linalg_gemm
+from ._op_translations import sigmoid, pad, relu, matrix_multiplication,
batch_norm
+from ._op_translations import dropout, local_response_norm, conv, deconv
+from ._op_translations import reshape, cast, split, _slice, transpose,
squeeze, flatten
+from ._op_translations import reciprocal, squareroot, power, exponent, _log,
unsqueeze
+from ._op_translations import reduce_max, reduce_mean, reduce_min, reduce_sum
+from ._op_translations import reduce_prod, avg_pooling, max_pooling
+from ._op_translations import argmax, argmin, maximum, minimum
+from ._op_translations import clip, reduce_log_sum, reduce_log_sum_exp
+from ._op_translations import reduce_sum_square, reduce_l2, max_roi_pooling,
instance_norm
+from ._op_translations import log_softmax, softsign, lesser, greater, equal
+from ._op_translations import logical_and, logical_or, logical_xor, logical_not
# convert_map defines maps of ONNX operator names to converter
functor(callable)
# defined in the op_translations module.
@@ -89,6 +89,7 @@ _convert_map = {
'Squeeze' : squeeze,
'Unsqueeze' : unsqueeze,
'Flatten' : flatten,
+ 'Identity' : identity,
#Powers
'Reciprocal' : reciprocal,
'Sqrt' : squareroot,
diff --git a/python/mxnet/contrib/onnx/_import/op_translations.py
b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
similarity index 99%
rename from python/mxnet/contrib/onnx/_import/op_translations.py
rename to python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
index 0fad008..2b98aa0 100644
--- a/python/mxnet/contrib/onnx/_import/op_translations.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
@@ -19,7 +19,7 @@
""" Module for translating ONNX operators into Mxnet operatoes"""
# pylint: disable=unused-argument,protected-access
import numpy as np
-from . import translation_utils
+from . import _translation_utils as translation_utils
from .... import symbol
# Method definitions for the callable objects mapped in the import_helper
module
@@ -130,7 +130,7 @@ def maximum(attrs, inputs, proto_obj):
for op_input in inputs[2:]:
mxnet_op = symbol.maximum(mxnet_op, op_input)
else:
- mxnet_op = inputs[0]
+ mxnet_op = symbol.maximum(inputs[0], inputs[0])
return mxnet_op, attrs, inputs
def minimum(attrs, inputs, proto_obj):
@@ -143,7 +143,7 @@ def minimum(attrs, inputs, proto_obj):
for op_input in inputs[2:]:
mxnet_op = symbol.minimum(mxnet_op, op_input)
else:
- mxnet_op = inputs[0]
+ mxnet_op = symbol.minimum(inputs[0], inputs[0])
return mxnet_op, attrs, inputs
def lesser(attrs, inputs, proto_obj):
diff --git a/python/mxnet/contrib/onnx/_import/translation_utils.py
b/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py
similarity index 99%
rename from python/mxnet/contrib/onnx/_import/translation_utils.py
rename to python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py
index fe25a94..f63c1e9 100644
--- a/python/mxnet/contrib/onnx/_import/translation_utils.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py
@@ -168,6 +168,7 @@ def _fix_broadcast(op_name, inputs, broadcast_axis,
proto_obj):
op_sym = op_name
return op_sym
+
def _fix_channels(op_name, attrs, inputs, proto_obj):
"""A workaround for getting 'channels' or 'units' since onnx don't provide
these attributes. We check the shape of weights provided to get the number.
diff --git a/python/mxnet/contrib/onnx/_import/import_model.py
b/python/mxnet/contrib/onnx/onnx2mx/import_model.py
similarity index 100%
rename from python/mxnet/contrib/onnx/_import/import_model.py
rename to python/mxnet/contrib/onnx/onnx2mx/import_model.py
diff --git a/python/mxnet/contrib/onnx/_import/import_onnx.py
b/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py
similarity index 99%
rename from python/mxnet/contrib/onnx/_import/import_onnx.py
rename to python/mxnet/contrib/onnx/onnx2mx/import_onnx.py
index d81ec96..4e85171 100644
--- a/python/mxnet/contrib/onnx/_import/import_onnx.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py
@@ -23,7 +23,7 @@ from .... import symbol
from .... import cpu, gpu
from .... import ndarray as nd
from ....base import string_types
-from .import_helper import _convert_map as convert_map
+from ._import_helper import _convert_map as convert_map
class GraphProto(object): # pylint: disable=too-few-public-methods
"""A helper class for handling mxnet symbol copying from pb2.GraphProto.
diff --git a/python/mxnet/contrib/onnx/_import/import_to_gluon.py
b/python/mxnet/contrib/onnx/onnx2mx/import_to_gluon.py
similarity index 100%
rename from python/mxnet/contrib/onnx/_import/import_to_gluon.py
rename to python/mxnet/contrib/onnx/onnx2mx/import_to_gluon.py
diff --git a/tests/python-pytest/onnx/import/mxnet_backend.py
b/tests/python-pytest/onnx/export/backend.py
similarity index 55%
copy from tests/python-pytest/onnx/import/mxnet_backend.py
copy to tests/python-pytest/onnx/export/backend.py
index bbe8899..e23cc01 100644
--- a/tests/python-pytest/onnx/import/mxnet_backend.py
+++ b/tests/python-pytest/onnx/export/backend.py
@@ -16,25 +16,47 @@
# under the License.
# coding: utf-8
-"""MXNet backend wrapper for onnx test infrastructure"""
-import mxnet as mx
-from mxnet.contrib.onnx._import.import_onnx import GraphProto
+"""backend wrapper for onnx test infrastructure"""
+import numpy as np
+from mxnet.contrib.onnx.onnx2mx.import_onnx import GraphProto
+from mxnet.contrib.onnx.mx2onnx.export_onnx import MXNetGraph
try:
- from onnx import helper, TensorProto
+ from onnx import helper, TensorProto, mapping
from onnx.backend.base import Backend
except ImportError:
- raise ImportError("Onnx and protobuf need to be installed. Instructions to"
- + " install - https://github.com/onnx/onnx#installation")
-from mxnet_backend_rep import MXNetBackendRep
+ raise ImportError("Onnx and protobuf need to be installed")
+from backend_rep import MXNetBackendRep
+# Using these functions for onnx test infrastructure.
+# Implemented by following onnx docs guide:
+#
https://github.com/onnx/onnx/blob/master/docs/Implementing%20an%20ONNX%20backend.md
# MXNetBackend class will take an ONNX model with inputs, perform a
computation,
# and then return the output.
-# Implemented by following onnx docs guide:
-# https://github.com/onnx/onnx/blob/master/docs/ImplementingAnOnnxBackend.md
class MXNetBackend(Backend):
"""MXNet backend for ONNX"""
+ @staticmethod
+ def perform_import_export(graph_proto, input_shape):
+ """ Import ONNX model to mxnet model and then export to ONNX model
+ and then import it back to mxnet for verifying the result"""
+ graph = GraphProto()
+
+ sym, arg_params, aux_params = graph.from_onnx(graph_proto)
+
+ params = {}
+ params.update(arg_params)
+ params.update(aux_params)
+ # exporting to onnx graph proto format
+ converter = MXNetGraph()
+ graph_proto = converter.create_onnx_graph_proto(sym, params,
in_shape=input_shape,
in_type=mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('float32')])
+
+ # importing back to MXNET for verifying result.
+ sym, arg_params, aux_params = graph.from_onnx(graph_proto)
+
+ return sym, arg_params, aux_params
+
+
@classmethod
def prepare(cls, model, device='CPU', **kwargs):
"""For running end to end model(used for onnx test backend)
@@ -54,8 +76,12 @@ class MXNetBackend(Backend):
Returns object of MXNetBackendRep class which will be in turn
used to run inference on the input model and return the result for
comparison.
"""
+
graph = GraphProto()
- sym, arg_params, aux_params = graph.from_onnx(model.graph)
+ metadata = graph.get_graph_metadata(model.graph)
+ input_data = metadata['input_tensor_data']
+ input_shape = [data[1] for data in input_data]
+ sym, arg_params, aux_params =
MXNetBackend.perform_import_export(model.graph, input_shape)
return MXNetBackendRep(sym, arg_params, aux_params, device)
@classmethod
@@ -63,6 +89,9 @@ class MXNetBackend(Backend):
"""Supports only CPU for testing"""
return device == 'CPU'
+
prepare = MXNetBackend.prepare
+run_node = MXNetBackend.run_node
+
supports_device = MXNetBackend.supports_device
diff --git a/tests/python-pytest/onnx/import/mxnet_backend_rep.py
b/tests/python-pytest/onnx/export/backend_rep.py
similarity index 91%
copy from tests/python-pytest/onnx/import/mxnet_backend_rep.py
copy to tests/python-pytest/onnx/export/backend_rep.py
index 5ce29f5..8729eaf 100644
--- a/tests/python-pytest/onnx/import/mxnet_backend_rep.py
+++ b/tests/python-pytest/onnx/export/backend_rep.py
@@ -16,18 +16,16 @@
# under the License.
# coding: utf-8
-"""MXNet backend rep for onnx test infrastructure"""
-import numpy as np
+"""backend rep for onnx test infrastructure"""
try:
from onnx.backend.base import BackendRep
except ImportError:
- raise ImportError("Onnx and protobuf need to be installed. Instructions to"
- + " install - https://github.com/onnx/onnx#installation")
+ raise ImportError("Onnx and protobuf need to be installed")
import mxnet as mx
# Using these functions for onnx test infrastructure.
# Implemented by following onnx docs guide:
-# https://github.com/onnx/onnx/blob/master/docs/ImplementingAnOnnxBackend.md
+#
https://github.com/onnx/onnx/blob/master/docs/Implementing%20an%20ONNX%20backend.md
# MXNetBackendRep object will be returned by MXNetBackend's prepare method
which is used to
# execute a model repeatedly.
# Inputs will be passed to the run method of MXNetBackendRep class, it will
perform computation and
diff --git a/tests/python-pytest/onnx/export/mxnet_export_test.py
b/tests/python-pytest/onnx/export/mxnet_export_test.py
new file mode 100644
index 0000000..7e1df07
--- /dev/null
+++ b/tests/python-pytest/onnx/export/mxnet_export_test.py
@@ -0,0 +1,191 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Tests for individual operators
+This module contains operator tests which currently do not exist on
+ONNX backend test framework. Once we have PRs on the ONNX repo and get
+those PRs merged, this file will get EOL'ed.
+"""
+# pylint: disable=too-many-locals,wrong-import-position,import-error
+from __future__ import absolute_import
+import sys
+import os
+import logging
+import tarfile
+from collections import namedtuple
+import numpy as np
+import numpy.testing as npt
+from onnx import numpy_helper
+from onnx import TensorProto
+from mxnet.test_utils import download
+from mxnet.contrib import onnx as onnx_mxnet
+import mxnet as mx
+CURR_PATH = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
+sys.path.insert(0, os.path.join(CURR_PATH, '../../python/unittest'))
+logger = logging.getLogger()
+logger.setLevel(logging.DEBUG)
+URLS = {
+ 'bvlc_googlenet':
+ 'https://s3.amazonaws.com/onnx-mxnet/model-zoo/bvlc_googlenet.tar.gz',
+ 'bvlc_reference_caffenet':
+
'https://s3.amazonaws.com/onnx-mxnet/model-zoo/bvlc_reference_caffenet.tar.gz',
+ 'bvlc_reference_rcnn_ilsvrc13':
+
'https://s3.amazonaws.com/onnx-mxnet/model-zoo/bvlc_reference_rcnn_ilsvrc13.tar.gz',
+ 'inception_v1':
+ 'https://s3.amazonaws.com/onnx-mxnet/model-zoo/inception_v1.tar.gz',
+ 'inception_v2':
+ 'https://s3.amazonaws.com/onnx-mxnet/model-zoo/inception_v2.tar.gz'
+}
+
+def get_test_files(name):
+ """Extract tar file and returns model path and input, output data"""
+ tar_name = download(URLS.get(name), dirname=CURR_PATH.__str__())
+ # extract tar file
+ tar_path = os.path.join(CURR_PATH, tar_name)
+ tar = tarfile.open(tar_path.__str__(), "r:*")
+ tar.extractall(path=CURR_PATH.__str__())
+ tar.close()
+ data_dir = os.path.join(CURR_PATH, name)
+ model_path = os.path.join(data_dir, 'model.onnx')
+
+ inputs = []
+ outputs = []
+ # get test files
+ for test_file in os.listdir(data_dir):
+ case_dir = os.path.join(data_dir, test_file)
+ # skip the non-dir files
+ if not os.path.isdir(case_dir):
+ continue
+ input_file = os.path.join(case_dir, 'input_0.pb')
+ input_tensor = TensorProto()
+ with open(input_file, 'rb') as proto_file:
+ input_tensor.ParseFromString(proto_file.read())
+ inputs.append(numpy_helper.to_array(input_tensor))
+
+ output_tensor = TensorProto()
+ output_file = os.path.join(case_dir, 'output_0.pb')
+ with open(output_file, 'rb') as proto_file:
+ output_tensor.ParseFromString(proto_file.read())
+ outputs.append(numpy_helper.to_array(output_tensor))
+
+ return model_path, inputs, outputs
+
+
+def forward_pass(sym, arg, aux, data_names, input_data):
+ """ Perform forward pass on given data"""
+ # create module
+ mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(),
label_names=None)
+ mod.bind(for_training=False, data_shapes=[(data_names[0],
input_data.shape)], label_shapes=None)
+ mod.set_params(arg_params=arg, aux_params=aux,
+ allow_missing=True, allow_extra=True)
+ # run inference
+ batch = namedtuple('Batch', ['data'])
+ mod.forward(batch([mx.nd.array(input_data)]), is_train=False)
+
+ return mod.get_outputs()[0].asnumpy()
+
+
+def test_models(model_name, input_shape, output_shape):
+ """ Tests Googlenet model for both onnx import and export"""
+ model_path, inputs, outputs = get_test_files(model_name)
+ logging.info("Translating model from ONNX model zoo to Mxnet")
+ sym, arg_params, aux_params = onnx_mxnet.import_model(model_path)
+ params = {}
+ params.update(arg_params)
+ params.update(aux_params)
+
+ dir_path = os.path.dirname(model_path)
+ new_model_name = "exported_" + model_name + ".onnx"
+ onnx_file = os.path.join(dir_path, new_model_name)
+
+ logging.info("Translating converted model from mxnet to ONNX")
+ converted_model_path = onnx_mxnet.export_model(sym, params, [input_shape],
np.float32, onnx_file)
+
+ sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model_path)
+
+ metadata = onnx_mxnet.get_model_metadata(converted_model_path)
+ assert len(metadata) == 2
+ assert metadata.get('input_tensor_data')
+ assert metadata.get('input_tensor_data')[0][1] == input_shape
+ assert metadata.get('output_tensor_data')
+ assert metadata.get('output_tensor_data')[0][1] == output_shape
+ data_names = [input_name[0] for input_name in
metadata.get('input_tensor_data')]
+
+ logging.info("Running inference on onnx re-import model in mxnet")
+ # run test for each test file
+ for input_data, output_data in zip(inputs, outputs):
+ result = forward_pass(sym, arg_params, aux_params, data_names,
input_data)
+
+ # verify the results
+ npt.assert_equal(result.shape, output_data.shape)
+ npt.assert_almost_equal(output_data, result, decimal=3)
+ logging.info(model_name + " conversion successful")
+
+
+def test_model_accuracy(model_name, input_shape):
+ """ Imports ONNX model, runs inference, exports and imports back
+ run inference, compare result with the previous inference result"""
+ model_path, inputs, outputs = get_test_files(model_name)
+ logging.info("Translating model from ONNX model zoo to Mxnet")
+ sym, arg_params, aux_params = onnx_mxnet.import_model(model_path)
+
+ metadata = onnx_mxnet.get_model_metadata(model_path)
+ data_names = [input_name[0] for input_name in
metadata.get('input_tensor_data')]
+
+ expected_result= []
+ for input_data, output_data in zip(inputs, outputs):
+ result = forward_pass(sym, arg_params, aux_params, data_names,
input_data)
+ expected_result.append(result)
+
+ params = {}
+ params.update(arg_params)
+ params.update(aux_params)
+
+ dir_path = os.path.dirname(model_path)
+ new_model_name = "exported_" + model_name + ".onnx"
+ onnx_file = os.path.join(dir_path, new_model_name)
+
+ logging.info("Translating converted model from mxnet to ONNX")
+ converted_model_path = onnx_mxnet.export_model(sym, params, [input_shape],
np.float32,
+ onnx_file)
+
+ sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model_path)
+
+ metadata = onnx_mxnet.get_model_metadata(converted_model_path)
+ data_names = [input_name[0] for input_name in
metadata.get('input_tensor_data')]
+
+ actual_result = []
+ for input_data, output_data in zip(inputs, outputs):
+ result = forward_pass(sym, arg_params, aux_params, data_names,
input_data)
+ actual_result.append(result)
+
+ # verify the results
+ for expected, actual in zip(expected_result, actual_result):
+ npt.assert_equal(expected.shape, actual.shape)
+ npt.assert_almost_equal(expected, actual, decimal=3)
+
+
+if __name__ == '__main__':
+ test_models("bvlc_googlenet", (1, 3, 224, 224), (1, 1000))
+ test_models("bvlc_reference_caffenet", (1, 3, 224, 224), (1, 1000))
+ test_models("bvlc_reference_rcnn_ilsvrc13", (1, 3, 224, 224), (1, 200))
+
+ # Comparing MXNet inference result, since MXNet results don't match
+ # ONNX expected results due to AveragePool issue github issue(#10194)
+ test_model_accuracy("inception_v1", (1, 3, 224, 224))
+ test_model_accuracy("inception_v2", (1, 3, 224, 224))
diff --git a/tests/python-pytest/onnx/import/test_cases.py
b/tests/python-pytest/onnx/export/onnx_backend_test.py
similarity index 64%
copy from tests/python-pytest/onnx/import/test_cases.py
copy to tests/python-pytest/onnx/export/onnx_backend_test.py
index 1a4d8c4..803d290 100644
--- a/tests/python-pytest/onnx/import/test_cases.py
+++ b/tests/python-pytest/onnx/export/onnx_backend_test.py
@@ -15,7 +15,24 @@
# specific language governing permissions and limitations
# under the License.
-"""Test Cases to be run for the import module"""
+"""ONNX test backend wrapper"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import unittest
+try:
+ import onnx.backend.test
+except ImportError:
+ raise ImportError("Onnx and protobuf need to be installed")
+
+import backend as mxnet_backend
+
+# This is a pytest magic variable to load extra plugins
+pytest_plugins = "onnx.backend.test.report",
+
+BACKEND_TESTS = onnx.backend.test.BackendTest(mxnet_backend, __name__)
IMPLEMENTED_OPERATORS_TEST = [
'test_random_uniform',
@@ -31,6 +48,7 @@ IMPLEMENTED_OPERATORS_TEST = [
'test_ceil',
'test_floor',
'test_concat',
+ 'test_identity',
'test_sigmoid',
'test_relu',
'test_constant_pad',
@@ -41,7 +59,6 @@ IMPLEMENTED_OPERATORS_TEST = [
'test_reduce_mean',
'test_reduce_prod',
'test_squeeze',
- 'test_unsqueeze',
'test_softmax_example',
'test_softmax_large_number',
'test_softmax_axis_2',
@@ -58,16 +75,7 @@ IMPLEMENTED_OPERATORS_TEST = [
'test_argmax',
'test_argmin',
'test_min',
- 'test_logical_and',
- 'test_logical_xor',
- 'test_logical_not',
- 'test_logical_or',
- 'test_clip',
- 'test_softsign',
- 'test_reduce_l2',
- 'test_reduce_log_sum',
- 'test_reduce_log_sum_exp',
- 'test_reduce_sum_square'
+ 'test_max',
#pytorch operator tests
'test_operator_exp',
'test_operator_maxpool',
@@ -78,7 +86,7 @@ IMPLEMENTED_OPERATORS_TEST = [
BASIC_MODEL_TESTS = [
'test_AvgPool2D',
'test_BatchNorm',
- 'test_ConstantPad2d'
+ 'test_ConstantPad2d',
'test_Conv2d',
'test_ELU',
'test_LeakyReLU',
@@ -95,11 +103,30 @@ BASIC_MODEL_TESTS = [
STANDARD_MODEL = [
'test_bvlc_alexnet',
'test_densenet121',
- #'test_inception_v1',
- #'test_inception_v2',
+ # 'test_inception_v1',
+ # 'test_inception_v2',
'test_resnet50',
- #'test_shufflenet',
+ # 'test_shufflenet',
'test_squeezenet',
- 'test_zfnet512',
+ 'test_vgg16',
'test_vgg19'
]
+
+for op_test in IMPLEMENTED_OPERATORS_TEST:
+ BACKEND_TESTS.include(op_test)
+
+for basic_model_test in BASIC_MODEL_TESTS:
+ BACKEND_TESTS.include(basic_model_test)
+
+for std_model_test in STANDARD_MODEL:
+ BACKEND_TESTS.include(std_model_test)
+
+BACKEND_TESTS.exclude('.*broadcast.*')
+BACKEND_TESTS.exclude('.*bcast.*')
+
+
+# import all test cases at global scope to make them visible to python.unittest
+globals().update(BACKEND_TESTS.enable_report().test_cases)
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/python-pytest/onnx/import/gluon_backend.py
b/tests/python-pytest/onnx/import/gluon_backend.py
index d2946f7..302fd4d 100644
--- a/tests/python-pytest/onnx/import/gluon_backend.py
+++ b/tests/python-pytest/onnx/import/gluon_backend.py
@@ -17,10 +17,8 @@
# coding: utf-8
"""Gluon backend wrapper for onnx test infrastructure"""
-import mxnet as mx
-from mxnet import nd
-from mxnet.contrib.onnx._import.import_onnx import GraphProto
-import numpy as np
+from mxnet.contrib.onnx.onnx2mx.import_onnx import GraphProto
+
try:
from onnx import helper, TensorProto
from onnx.backend.base import Backend
diff --git a/tests/python-pytest/onnx/import/mxnet_backend.py
b/tests/python-pytest/onnx/import/mxnet_backend.py
index bbe8899..10f89ec 100644
--- a/tests/python-pytest/onnx/import/mxnet_backend.py
+++ b/tests/python-pytest/onnx/import/mxnet_backend.py
@@ -17,8 +17,7 @@
# coding: utf-8
"""MXNet backend wrapper for onnx test infrastructure"""
-import mxnet as mx
-from mxnet.contrib.onnx._import.import_onnx import GraphProto
+from mxnet.contrib.onnx.onnx2mx.import_onnx import GraphProto
try:
from onnx import helper, TensorProto
from onnx.backend.base import Backend
diff --git a/tests/python-pytest/onnx/import/mxnet_backend_rep.py
b/tests/python-pytest/onnx/import/mxnet_backend_rep.py
index 5ce29f5..067ef15 100644
--- a/tests/python-pytest/onnx/import/mxnet_backend_rep.py
+++ b/tests/python-pytest/onnx/import/mxnet_backend_rep.py
@@ -17,7 +17,6 @@
# coding: utf-8
"""MXNet backend rep for onnx test infrastructure"""
-import numpy as np
try:
from onnx.backend.base import BackendRep
except ImportError:
diff --git a/tests/python-pytest/onnx/import/test_cases.py
b/tests/python-pytest/onnx/import/test_cases.py
index 1a4d8c4..f7addbb 100644
--- a/tests/python-pytest/onnx/import/test_cases.py
+++ b/tests/python-pytest/onnx/import/test_cases.py
@@ -31,6 +31,7 @@ IMPLEMENTED_OPERATORS_TEST = [
'test_ceil',
'test_floor',
'test_concat',
+ 'test_identity',
'test_sigmoid',
'test_relu',
'test_constant_pad',