This is an automated email from the ASF dual-hosted git repository.

skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 7d91602  [MXNET-533] MXNet-ONNX export (#11213)
7d91602 is described below

commit 7d91602ba771d973360f8a0c66c976c67f700aa3
Author: Roshani Nagmote <[email protected]>
AuthorDate: Mon Jun 25 09:43:20 2018 -0700

    [MXNET-533] MXNet-ONNX export (#11213)
    
    * Resolve conflicts
    
    * Export module Test Framework
    
    * refactoring export to work with pretrained models
    
    * comments added
    
    * 1. Refactored export module.
    2. Refactored test framework to support ONNX backend tests.
    3. Added Operator support:
       - Convolution2D
       - BatchNorm
       - Add
    
    * Added Arithmetic operators:
    - Add, Sub, Mul, Div, Sum
    
    * Added operator support:
    - sigmoid, relu, pad (constant, edge, reflect), tanh
    - enabled corresponding ONNX backend tests.
    
    * Enabled ONNX tests: test_conv, test_basic_conv
    
    Added Operators:
    Ceil, Floor
    
    * Added support for:
    MaxPool, AvgPool, GlobalMaxPool, GlobalAvgPool, matmul
    
    * adding more operators
    
    * Added Operator support:
    ArgMax, ArgMin, maximum, minimum
    
    * Enabled more BASIC_MODEL tests
    
    * Added power operator tests
    
    * Added support for reshape. ONNX only supports the 0 and -1 special values,
    so support was added only for these.
    Fixed logic error with convert_string_to_list()
    
    * some tests enabled
    
    * enabling squeezenet
    
    * LRN Op support
    
    * mul_scalar modified to take scalar input
    
    * cleaning some code
    
    * Resolving conflicts on rebase
    
    * Resolving rebase conflicts
    
    * id mapping updated for all operators
    
    * save onnx models added, some code cleanup
    
    * enabled more tests
    
    * conv pad calc fixed
    
    * reshape op fix
    
    * Added support for elu, leakyRelu, prelu
    
    * Cleanup
    - Removed run_node, not needed anymore.
    - Used correct get_metadata api
    
    * valueinfoproto fix, googlenet test added
    
    * Removed redundant code.
    - run_node
    - Using correct get_metadata_api
    
    * dilation added
    
    * Lint fixes
    
    * lint fixes
    
    * some fixes to make export work with onnx 1.2.1
    
    * enabled more tests
    
    * mxnet_export_test file added
    
    * duplicate file deleted
    
    * reduce ops added
    
    * some small fixes
    
    * some lint fixes
    
    * Add tests for inception_v1 and inception_v2
    
    * Add CI runs for export module
    
    * docstring added
    
    * lint fixes, pooling attr fix
    
    * fix
    
    * fix global_pool
    
    * CI run fix
    
    * code cleanup
    
    * lint fix
    
    * some code cleanup
    
    * pad in pooling added
    
    * slicechannel notimplementederror raised
    
    * Added required license comments
    
    * Lint fixes
    
    * lint fix
    
    * lint fix
    
    * lint fix
    
    * lint fix
    
    * Correct license statement
    
    * Adding onnx as a runtime dependency
    
    * Fix import module error for string_types
    
    * Making ONNX runtime dependency
    
    * fixing some comments
    
    * addressing some comments
    
    * params rename
    
    * lint fixes
    
    * fixes
    
    * spatial disabled, path fixed
    
    * fixing some comments
    
    * Added support for the remaining act_types (softsign, sigmoid, softrelu) in
    the Activation operator
    
    * changing import
    
    * adding some comments
    
    * Add squeeze op
    
    * Refactored logic to handle the extra node (output label node) for saved
    MXNet models
    Added comments
    
    * minor fix for squeeze operator.
    Also, added error handling
    
    * identity operator added
    
    * scalar ops added
    
    * Renamed ONNX support folders to mark them as public
    Changed underlying files to public or private as per usage
    
    Resolved conflicts with the latest
    
    * Added support for L2Normalization op
    Added some error checking
    
    * added comments and warning
    
    * added comments and warning
    
    * doc API ref added
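Several items above (pad with constant/edge/reflect modes, "conv pad calc fixed") hinge on one layout difference. MXNet's `Pad` takes `pad_width` as (begin, end) pairs listed axis by axis, while the ONNX `pads` attribute lists every axis's begin value first and then every end value; MXNet's `Convolution` `pad` gives a single symmetric value per spatial axis. The diff below handles these with `transform_padding` (for `Pad`) and `pad_dims = pad_dims + pad_dims` (for `Convolution`). The following is an independent sketch of the same reordering; the helper names `mx_pad_width_to_onnx` and `mx_conv_pad_to_onnx` are hypothetical and not part of MXNet.

```python
def mx_pad_width_to_onnx(pad_width):
    """Hypothetical helper: reorder MXNet Pad's pad_width
    (b1, e1, b2, e2, ...) into ONNX pads [b1, b2, ..., e1, e2, ...]."""
    if len(pad_width) % 2 != 0:
        raise ValueError("pad_width must hold (begin, end) pairs")
    begins = list(pad_width[0::2])  # even positions: padding before each axis
    ends = list(pad_width[1::2])    # odd positions: padding after each axis
    return begins + ends


def mx_conv_pad_to_onnx(pad):
    """Hypothetical helper: MXNet Convolution pads each spatial axis
    symmetrically, so the ONNX begin and end halves repeat the same values."""
    return list(pad) + list(pad)


# NCHW input: no padding on N and C, (1, 2) on H, (3, 4) on W.
print(mx_pad_width_to_onnx((0, 0, 0, 0, 1, 2, 3, 4)))  # [0, 0, 1, 3, 0, 0, 2, 4]
print(mx_conv_pad_to_onnx((1, 2)))                      # [1, 2, 1, 2]
```

The slicing form produces the same ordering as the index-walking loop in `transform_padding`; it is only a more compact way to express it.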
---
 LICENSE                                            |   52 +-
 ci/docker/runtime_functions.sh                     |    2 +
 docs/api/python/contrib/onnx.md                    |    2 +
 python/mxnet/contrib/onnx/__init__.py              |    5 +-
 python/mxnet/contrib/onnx/mx2onnx/LICENSE          |   44 +
 .../contrib/onnx/{_import => mx2onnx}/__init__.py  |   10 +-
 .../mxnet/contrib/onnx/mx2onnx/_export_helper.py   |   65 +
 .../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 1863 ++++++++++++++++++++
 python/mxnet/contrib/onnx/mx2onnx/export_model.py  |   95 +
 python/mxnet/contrib/onnx/mx2onnx/export_onnx.py   |  347 ++++
 .../contrib/onnx/{_import => onnx2mx}/__init__.py  |    0
 .../import_helper.py => onnx2mx/_import_helper.py} |   39 +-
 .../_op_translations.py}                           |    6 +-
 .../_translation_utils.py}                         |    1 +
 .../onnx/{_import => onnx2mx}/import_model.py      |    0
 .../onnx/{_import => onnx2mx}/import_onnx.py       |    2 +-
 .../onnx/{_import => onnx2mx}/import_to_gluon.py   |    0
 .../{import/mxnet_backend.py => export/backend.py} |   49 +-
 .../mxnet_backend_rep.py => export/backend_rep.py} |    8 +-
 .../python-pytest/onnx/export/mxnet_export_test.py |  191 ++
 .../test_cases.py => export/onnx_backend_test.py}  |   61 +-
 tests/python-pytest/onnx/import/gluon_backend.py   |    6 +-
 tests/python-pytest/onnx/import/mxnet_backend.py   |    3 +-
 .../python-pytest/onnx/import/mxnet_backend_rep.py |    1 -
 tests/python-pytest/onnx/import/test_cases.py      |    1 +
 25 files changed, 2783 insertions(+), 70 deletions(-)

diff --git a/LICENSE b/LICENSE
index 158bd37..a8b57e5 100644
--- a/LICENSE
+++ b/LICENSE
@@ -298,8 +298,6 @@
     (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
     SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-
-
     =======================================================================================
     Other Licenses
     =======================================================================================
@@ -512,3 +510,53 @@
     For details, see, 3rdparty/dmlc-core/include/dmlc/concurrentqueue.h
 
     =======================================================================================
+
+    11. ONNX Export module
+    For details, see, python/mxnet/contrib/onnx/mx2onnx/LICENSE
+
+    # Licensed to the Apache Software Foundation (ASF) under one
+    # or more contributor license agreements.  See the NOTICE file
+    # distributed with this work for additional information
+    # regarding copyright ownership.  The ASF licenses this file
+    # to you under the Apache License, Version 2.0 (the
+    # "License"); you may not use this file except in compliance
+    # with the License.  You may obtain a copy of the License at
+    #
+    #   http://www.apache.org/licenses/LICENSE-2.0
+    #
+    # Unless required by applicable law or agreed to in writing,
+    # software distributed under the License is distributed on an
+    # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    # KIND, either express or implied.  See the License for the
+    # specific language governing permissions and limitations
+    # under the License.
+    #
+    # Based on
+    # https://github.com/NVIDIA/mxnet_to_onnx/blob/master/mx2onnx_converter/#
+    #  Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
+    #
+    #  Redistribution and use in source and binary forms, with or without
+    #  modification, are permitted provided that the following conditions
+    #  are met:
+    #  * Redistributions of source code must retain the above copyright
+    #    notice, this list of conditions and the following disclaimer.
+    #  * Redistributions in binary form must reproduce the above copyright
+    #    notice, this list of conditions and the following disclaimer in the
+    #    documentation and/or other materials provided with the distribution.
+    #  * Neither the name of NVIDIA CORPORATION nor the names of its
+    #    contributors may be used to endorse or promote products derived
+    #    from this software without specific prior written permission.
+    #
+    #  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+    #  EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+    #  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+    #  PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+    #  CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+    #  EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+    #  PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+    #  PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+    #  OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+    #  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+    #  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index 6e6abf0..0798047 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -700,6 +700,8 @@ integrationtest_ubuntu_cpu_onnx() {
        pytest tests/python-pytest/onnx/import/mxnet_backend_test.py
        pytest tests/python-pytest/onnx/import/onnx_import_test.py
        pytest tests/python-pytest/onnx/import/gluon_backend_test.py
+       pytest tests/python-pytest/onnx/export/onnx_backend_test.py
+       python tests/python-pytest/onnx/export/mxnet_export_test.py
 }
 
 integrationtest_ubuntu_gpu_python() {
diff --git a/docs/api/python/contrib/onnx.md b/docs/api/python/contrib/onnx.md
index 6fb546f..8cd6198 100644
--- a/docs/api/python/contrib/onnx.md
+++ b/docs/api/python/contrib/onnx.md
@@ -24,6 +24,7 @@ This document describes all the ONNX-MXNet APIs.
 
     mxnet.contrib.onnx.import_model
     mxnet.contrib.onnx.get_model_metadata
+    mxnet.contrib.onnx.export_model
 ```
 
 ## ONNX Tutorials
@@ -46,6 +47,7 @@ This document describes all the ONNX-MXNet APIs.
 .. automodule:: mxnet.contrib.onnx
     :members: import_model
     :members: get_model_metadata
+    :members: export_model
 
 ```
 
diff --git a/python/mxnet/contrib/onnx/__init__.py b/python/mxnet/contrib/onnx/__init__.py
index 4f9296d..9f27060 100644
--- a/python/mxnet/contrib/onnx/__init__.py
+++ b/python/mxnet/contrib/onnx/__init__.py
@@ -16,5 +16,6 @@
 # under the License.
 """Module for ONNX model format support for Apache MXNet."""
 
-from ._import.import_model import import_model, get_model_metadata
-from ._import.import_to_gluon import import_to_gluon
+from .onnx2mx.import_model import import_model, get_model_metadata
+from .onnx2mx.import_to_gluon import import_to_gluon
+from .mx2onnx.export_model import export_model
diff --git a/python/mxnet/contrib/onnx/mx2onnx/LICENSE b/python/mxnet/contrib/onnx/mx2onnx/LICENSE
new file mode 100644
index 0000000..3abe1ee
--- /dev/null
+++ b/python/mxnet/contrib/onnx/mx2onnx/LICENSE
@@ -0,0 +1,44 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# Based on
+# https://github.com/NVIDIA/mxnet_to_onnx/blob/master/mx2onnx_converter/#
+#  Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
+#
+#  Redistribution and use in source and binary forms, with or without
+#  modification, are permitted provided that the following conditions
+#  are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+#  EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+#  PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+#  CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+#  EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+#  PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+#  PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+#  OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+#  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+#  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/python/mxnet/contrib/onnx/_import/__init__.py b/python/mxnet/contrib/onnx/mx2onnx/__init__.py
similarity index 84%
copy from python/mxnet/contrib/onnx/_import/__init__.py
copy to python/mxnet/contrib/onnx/mx2onnx/__init__.py
index d0411df..238174e 100644
--- a/python/mxnet/contrib/onnx/_import/__init__.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/__init__.py
@@ -16,7 +16,9 @@
 # under the License.
 
 # coding: utf-8
-"""ONNX Import module"""
-from . import import_model
-from . import import_onnx
-from . import import_to_gluon
+"""ONNX Export module"""
+from __future__ import absolute_import
+
+from . import export_model
+from . import export_onnx
+from . import _op_translations
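Adding `from . import _op_translations` is what makes the converters available: each function in that module is attached with an `@mx_op.register("...")` decorator, and decorators execute when the module is imported. Nodes in an MXNet symbol JSON carry an `op` field (inputs and weights use `"null"`, which is why `convert_weights_and_inputs` is registered under `"null"`). `MXNetGraph`'s own registry code lives in `export_onnx.py`, which this section does not show, so the class below is a hypothetical sketch of a decorator-based registry of that shape, not its actual implementation.

```python
class ConverterRegistry(object):
    """Hypothetical stand-in for a decorator-based operator registry."""

    _converters = {}  # MXNet op name -> conversion function

    @classmethod
    def register(cls, op_name):
        """Return a decorator that records a function under op_name."""
        def wrapper(func):
            cls._converters[op_name] = func
            return func  # leave the decorated function usable as-is
        return wrapper

    @classmethod
    def convert(cls, node, **kwargs):
        """Dispatch a node dict to the converter registered for node["op"]."""
        op_name = node["op"]
        if op_name not in cls._converters:
            raise AttributeError("No converter registered for op %s" % op_name)
        return cls._converters[op_name](node, **kwargs)


# Registration happens when this decorator line executes, i.e. at import time.
@ConverterRegistry.register("relu")
def convert_relu(node, **kwargs):
    # Real converters build ONNX nodes; a string stands in for one here.
    return ["Relu(%s)" % node["name"]]


print(ConverterRegistry.convert({"op": "relu", "name": "relu0"}))  # ['Relu(relu0)']
```

If the import line were missing, the decorators would never run and every lookup would fail, even though the converter functions exist on disk.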
diff --git a/python/mxnet/contrib/onnx/mx2onnx/_export_helper.py b/python/mxnet/contrib/onnx/mx2onnx/_export_helper.py
new file mode 100644
index 0000000..781fb4c
--- /dev/null
+++ b/python/mxnet/contrib/onnx/mx2onnx/_export_helper.py
@@ -0,0 +1,65 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""export helper functions"""
+# coding: utf-8
+import os
+import logging
+import mxnet as mx
+
+
+def load_module(sym_filepath, params_filepath):
+    """Loads the MXNet model file and
+    returns MXNet symbol and params (weights).
+
+    Parameters
+    ----------
+    sym_filepath : str
+        Path to the symbol json file
+    params_filepath : str
+        Path to the params file
+
+    Returns
+    -------
+    sym : MXNet symbol
+        Model symbol object
+
+    params : params object
+        Model weights including both arg and aux params.
+    """
+    if not (os.path.isfile(sym_filepath) and os.path.isfile(params_filepath)):
+        raise ValueError("Symbol and params files provided are invalid")
+    else:
+        try:
+            # parse the model prefix and epoch number from the
+            # symbol and params file names
+            model_name = sym_filepath.rsplit('.', 1)[0].rsplit('-', 1)[0]
+            params_file_list = params_filepath.rsplit('.', 1)[0].rsplit('-', 1)
+            # Setting num_epochs to 0 if not present in filename
+            num_epochs = 0 if len(params_file_list) == 1 else int(params_file_list[1])
+        except IndexError:
+            logging.info("Model and params name should be in format: "
+                         "prefix-symbol.json, prefix-epoch.params")
+            raise
+
+        sym, arg_params, aux_params = mx.model.load_checkpoint(model_name, num_epochs)
+
+        # Merging arg and aux parameters
+        params = {}
+        params.update(arg_params)
+        params.update(aux_params)
+
+        return sym, params
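`load_module` infers the checkpoint prefix and epoch purely from the file names before calling `mx.model.load_checkpoint`. That parsing can be exercised without MXNet installed; `parse_checkpoint_names` below is a hypothetical standalone copy of the two `rsplit` steps, assuming the `prefix-symbol.json` / `prefix-epoch.params` naming that the log message in `load_module` describes.

```python
def parse_checkpoint_names(sym_filepath, params_filepath):
    """Hypothetical helper mirroring load_module's filename parsing.

    Returns (model_prefix, epoch) for files named like
    'prefix-symbol.json' and 'prefix-0010.params'.
    """
    # Drop the extension, then drop the trailing '-symbol' part.
    model_prefix = sym_filepath.rsplit('.', 1)[0].rsplit('-', 1)[0]
    # Drop the extension, then split off a trailing '-<epoch>' part if present.
    params_parts = params_filepath.rsplit('.', 1)[0].rsplit('-', 1)
    # No '-<epoch>' suffix means epoch 0, as in load_module.
    epoch = 0 if len(params_parts) == 1 else int(params_parts[1])
    return model_prefix, epoch


print(parse_checkpoint_names("resnet-symbol.json", "resnet-0010.params"))  # ('resnet', 10)
print(parse_checkpoint_names("model-symbol.json", "model.params"))         # ('model', 0)
```

One consequence worth noting: for a params file whose prefix contains a hyphen but has no epoch suffix (say `my-net.params`), the second `rsplit('-', 1)` yields `['my', 'net']`, so `int()` raises `ValueError`, which the `except IndexError` clause in `load_module` does not catch.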
diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
new file mode 100644
index 0000000..5f5561a
--- /dev/null
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -0,0 +1,1863 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# Based on
+#  https://github.com/NVIDIA/mxnet_to_onnx/blob/master/mx2onnx_converter/
+# mx2onnx_converter_functions.py
+#  Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
+#
+#  Redistribution and use in source and binary forms, with or without
+#  modification, are permitted provided that the following conditions
+#  are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+#  EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+#  PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+#  CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+#  EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+#  PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+#  PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+#  OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+#  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+#  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# coding: utf-8
+# pylint: disable=too-many-locals,no-else-return,too-many-lines
+# pylint: disable=anomalous-backslash-in-string,eval-used
+"""
+Conversion Functions for common layers.
+Add new functions here with a decorator.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import re
+import logging
+import numpy as np
+from .export_onnx import MXNetGraph as mx_op
+
+def import_onnx_modules():
+    """To make sure ONNX is a runtime dependency, it is imported only when needed."""
+    try:
+        from onnx import helper, numpy_helper, mapping
+    except ImportError:
+        raise ImportError("Onnx and protobuf need to be installed. "
+                          + "Instructions to install - https://github.com/onnx/onnx")
+    return helper, numpy_helper, mapping
+
+
+def parse_helper(attrs, attrs_name, alt_value=None):
+    """Helper function to parse operator attributes in required format."""
+    tuple_re = re.compile('\([0-9L|,| ]+\)')
+    if attrs is None:
+        return alt_value
+    attrs_str = None if attrs.get(attrs_name) is None else str(attrs.get(attrs_name))
+    if attrs_str is None:
+        return alt_value
+    attrs_match = tuple_re.search(attrs_str)
+    if attrs_match is not None:
+        if attrs_match.span() == (0, len(attrs_str)):
+            dims = eval(attrs_str)
+            return dims
+        else:
+            raise AttributeError("Malformed %s dimensions: %s" % (attrs_name, str(attrs_str)))
+    return alt_value
+
+def transform_padding(pad_width):
+    """Helper function to convert padding format for pad operator.
+    """
+    num_pad_values = len(pad_width)
+    onnx_pad_width = [0]*num_pad_values
+
+    start_index = 0
+    # num_pad_values will always be a multiple of 2
+    end_index = int(num_pad_values/2)
+    for idx in range(0, num_pad_values):
+        if idx % 2 == 0:
+            onnx_pad_width[start_index] = pad_width[idx]
+            start_index += 1
+        else:
+            onnx_pad_width[end_index] = pad_width[idx]
+            end_index += 1
+
+    return onnx_pad_width
+
+
+def convert_string_to_list(string_val):
+    """Helper function to convert string to list.
+     Used to convert shape attribute string to list format.
+    """
+    result_list = []
+
+    list_string = string_val.split(',')
+    for val in list_string:
+        val = str(val.strip())
+        val = val.replace("(", "")
+        val = val.replace(")", "")
+        val = val.replace("L", "")
+        val = val.replace("[", "")
+        val = val.replace("]", "")
+        if val != "" and val != "None":
+            result_list.append(int(val))
+
+    return result_list
+
+@mx_op.register("null")
+def convert_weights_and_inputs(node, **kwargs):
+    """Helper function to convert weights and inputs.
+    """
+
+    helper, _, mapping = import_onnx_modules()
+    name = node["name"]
+
+    if kwargs["is_input"] is False:
+        weights = kwargs["weights"]
+        initializer = kwargs["initializer"]
+        np_arr = weights[name]
+        data_type = mapping.NP_TYPE_TO_TENSOR_TYPE[np_arr.dtype]
+        dims = np.shape(np_arr)
+
+        tensor_node = helper.make_tensor_value_info(name, data_type, dims)
+
+        initializer.append(
+            helper.make_tensor(
+                name=name,
+                data_type=data_type,
+                dims=dims,
+                vals=np_arr.flatten().tolist(),
+                raw=False,
+            )
+        )
+
+        return [tensor_node]
+    else:
+        tval_node = helper.make_tensor_value_info(name, kwargs["in_type"], kwargs["in_shape"])
+        return [tval_node]
+
+
+@mx_op.register("Convolution")
+def convert_convolution(node, **kwargs):
+    """Map MXNet's convolution operator attributes to onnx's Conv operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    inputs = node["inputs"]
+
+    num_inputs = len(inputs)
+
+    proc_nodes = kwargs["proc_nodes"]
+    input_node = proc_nodes[kwargs["index_lookup"][inputs[0][0]]].name
+    weights_node = proc_nodes[kwargs["index_lookup"][inputs[1][0]]].name
+
+    if num_inputs > 2:
+        bias_node = proc_nodes[kwargs["index_lookup"][inputs[2][0]]].name
+
+    attrs = node.get("attrs")
+
+    kernel_dims = list(parse_helper(attrs, "kernel"))
+    stride_dims = list(parse_helper(attrs, "stride", [1, 1]))
+    pad_dims = list(parse_helper(attrs, "pad", [0, 0]))
+    num_group = int(attrs.get("num_group", 1))
+    dilations = list(parse_helper(attrs, "dilate", [1, 1]))
+
+    pad_dims = pad_dims + pad_dims
+
+    input_nodes = [input_node, weights_node]
+    if num_inputs > 2:
+        input_nodes.append(bias_node)
+
+    conv_node = helper.make_node(
+        "Conv",
+        inputs=input_nodes,
+        outputs=[name],
+        kernel_shape=kernel_dims,
+        strides=stride_dims,
+        dilations=dilations,
+        pads=pad_dims,
+        group=num_group,
+        name=name
+    )
+
+    return [conv_node]
+
+
+@mx_op.register("FullyConnected")
+def convert_fully_connected(node, **kwargs):
+    """Map MXNet's FullyConnected operator attributes to onnx's Gemm operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    inputs = node["inputs"]
+    input_node_id = kwargs["index_lookup"][inputs[0][0]]
+    weight_node_id = kwargs["index_lookup"][inputs[1][0]]
+    bias_node_id = kwargs["index_lookup"][inputs[2][0]]
+    proc_nodes = kwargs["proc_nodes"]
+    input_node = proc_nodes[input_node_id]
+    weights_node = proc_nodes[weight_node_id]
+    bias_node = proc_nodes[bias_node_id]
+
+    input_name = input_node.name
+    weights_name = weights_node.name
+    bias_name = bias_node.name
+
+    node = helper.make_node(
+        "Gemm",
+        [input_name, weights_name, bias_name],  # input (A, B, C) - C can be in place
+        [name],  # output
+        alpha=1.0,
+        beta=1.0,
+        transA=False,
+        transB=True,
+        name=name
+    )
+
+    return [node]
+
+
+@mx_op.register("BatchNorm")
+def convert_batchnorm(node, **kwargs):
+    """Map MXNet's BatchNorm operator attributes to onnx's BatchNormalization operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+
+    attrs = node["attrs"]
+    momentum = float(node.get("attrs", {}).get("momentum", 0.9))
+    eps = float(attrs.get("eps", 0.001))
+
+    data_idx = kwargs["index_lookup"][inputs[0][0]]
+    gamma_idx = kwargs["index_lookup"][inputs[1][0]]
+    beta_idx = kwargs["index_lookup"][inputs[2][0]]
+    moving_mean_idx = kwargs["index_lookup"][inputs[3][0]]
+    moving_var_idx = kwargs["index_lookup"][inputs[4][0]]
+
+    data_node = proc_nodes[data_idx].name
+    gamma_node = proc_nodes[gamma_idx].name
+    beta_node = proc_nodes[beta_idx].name
+
+    mov_mean_node = proc_nodes[moving_mean_idx]
+    mov_mean_node = mov_mean_node.name
+    mov_var_node = proc_nodes[moving_var_idx].name
+
+    bn_node = helper.make_node(
+        "BatchNormalization",
+        [data_node,
+         gamma_node,  # scale
+         beta_node,  # bias
+         mov_mean_node,
+         mov_var_node
+        ],
+        [name],
+        name=name,
+        epsilon=eps,
+        momentum=momentum,
+        # MXNet computes mean and variance per feature for batchnorm
+        # Default for onnx is across all spatial features. So disabling the parameter.
+        spatial=0
+    )
+    return [bn_node]
+
+
+@mx_op.register("tanh")
+def convert_tanh(node, **kwargs):
+    """Map MXNet's tanh operator attributes to onnx's Tanh operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    inputs = node["inputs"]
+    input_node_idx = kwargs["index_lookup"][inputs[0][0]]
+    proc_nodes = kwargs["proc_nodes"]
+    input_node = proc_nodes[input_node_idx].name
+
+    node = helper.make_node(
+        'Tanh',
+        [input_node],
+        [name],
+        name=name
+    )
+    return [node]
+
+#Basic neural network functions
+@mx_op.register("sigmoid")
+def convert_sigmoid(node, **kwargs):
+    """Map MXNet's sigmoid operator attributes to onnx's Sigmoid operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    inputs = node["inputs"]
+    input_node_idx = kwargs["index_lookup"][inputs[0][0]]
+    proc_nodes = kwargs["proc_nodes"]
+    input_node = proc_nodes[input_node_idx].name
+
+    node = helper.make_node(
+        'Sigmoid',
+        [input_node],
+        [name],
+        name=name
+    )
+    return [node]
+
+@mx_op.register("relu")
+def convert_relu(node, **kwargs):
+    """Map MXNet's relu operator attributes to onnx's Relu operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    inputs = node["inputs"]
+    input_node_idx = kwargs["index_lookup"][inputs[0][0]]
+    proc_nodes = kwargs["proc_nodes"]
+    input_node = proc_nodes[input_node_idx].name
+
+    node = helper.make_node(
+        'Relu',
+        [input_node],
+        [name],
+        name=name
+    )
+
+    return [node]
+
+@mx_op.register("Activation")
+def convert_activation(node, **kwargs):
+    """Map MXNet's Activation operator attributes to onnx's Tanh/Relu/Sigmoid/Softplus/Softsign operators
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+
+    proc_nodes = kwargs["proc_nodes"]
+    attrs = node["attrs"]
+    act_type = attrs["act_type"]
+
+    inputs = node["inputs"]
+    input_node_idx = kwargs["index_lookup"][inputs[0][0]]
+    input_node = proc_nodes[input_node_idx].output[0]
+
+    # An explicit mapping is used because the ONNX names do not all follow
+    # the mxnet_name.title() pattern (e.g. softrelu -> Softplus).
+    act_types = {
+        "tanh": "Tanh",
+        "relu": "Relu",
+        "sigmoid": "Sigmoid",
+        "softrelu": "Softplus",
+        "softsign": "Softsign"
+    }
+
+    act_name = act_types.get(act_type)
+    if act_name:
+        node = helper.make_node(
+            act_name,
+            [input_node],
+            [name],
+            name=name
+        )
+    else:
+        raise AttributeError(
+            "Activation %s not implemented or recognized in the converter" % act_type
+        )
+
+    return [node]
+
+
+@mx_op.register("Pad")
+def convert_pad(node, **kwargs):
+    """Map MXNet's pad operator attributes to onnx's Pad operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    attrs = node["attrs"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+    input_node_idx = kwargs["index_lookup"][inputs[0][0]]
+    input_node = proc_nodes[input_node_idx].name
+
+    mxnet_pad_width = convert_string_to_list(attrs.get("pad_width"))
+    onnx_pad_width = transform_padding(mxnet_pad_width)
+
+    pad_mode = attrs.get("mode")
+
+    if pad_mode == "constant":
+        pad_value = float(attrs.get("constant_value")) \
+            if "constant_value" in attrs else 0.0
+        node = helper.make_node(
+            'Pad',
+            inputs=[input_node],
+            outputs=[name],
+            mode='constant',
+            value=pad_value,
+            pads=onnx_pad_width,
+            name=name
+        )
+    else:
+        node = helper.make_node(
+            'Pad',
+            inputs=[input_node],
+            outputs=[name],
+            mode=pad_mode,
+            pads=onnx_pad_width,
+            name=name
+        )
+
+    return [node]
+
+
+@mx_op.register("_linalg_gemm2")
+def convert_linalg_gemm2(node, **kwargs):
+    """Map MXNet's _linalg_gemm2 operator attributes to onnx's
+    MatMul and Transpose operators based on the values set for
+    transpose_a, transpose_b attributes.
+    Return multiple nodes created.
+    """
+    helper, _, _ = import_onnx_modules()
+    proc_nodes = kwargs["proc_nodes"]
+    node_inputs = node["inputs"]
+    name = node["name"]
+
+    input_a_idx = kwargs["index_lookup"][node_inputs[0][0]]
+    input_node_a = proc_nodes[input_a_idx].name
+    input_b_idx = kwargs["index_lookup"][node_inputs[1][0]]
+    input_node_b = proc_nodes[input_b_idx].name
+
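+    # Getting the attributes and assigning default values.
+    attrs = node.get("attrs", {})
+    alpha = float(attrs.get("alpha", 1.0))
+    trans_a = int(attrs.get("transpose_a", 0))
+    trans_b = int(attrs.get("transpose_b", 0))
+
+    # The Transpose/MatMul nodes emitted below cannot express a scale factor,
+    # so reject alpha values they would otherwise silently drop.
+    if alpha != 1.0:
+        raise AttributeError(
+            "_linalg_gemm2: alpha other than 1.0 is not supported by the converter"
+        )
+
+    op_name = "transpose" + str(kwargs["idx"])
+
+    if trans_a == 0 and trans_b == 0: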
+        matmul_node = helper.make_node(
+            'MatMul',
+            inputs=[input_node_a, input_node_b],
+            outputs=[name],
+            name=name
+        )
+        return [matmul_node]
+    elif trans_a == 1 and trans_b == 0:
+        node_name = op_name+"_a"
+        trans_a_node = helper.make_node(
+            'Transpose',
+            inputs=[input_node_a],
+            outputs=[op_name+"_a"],
+            name=node_name
+        )
+
+        matmul_node = helper.make_node(
+            'MatMul',
+            inputs=[node_name, input_node_b],
+            outputs=[name],
+            name=name
+        )
+        return [trans_a_node, matmul_node]
+
+    elif trans_a == 0 and trans_b == 1:
+        node_name = op_name + "_b"
+        trans_b_node = helper.make_node(
+            'Transpose',
+            inputs=[input_node_b],
+            outputs=[op_name+"_b"],
+            name=node_name
+        )
+
+        matmul_node = helper.make_node(
+            'MatMul',
+            inputs=[input_node_a, node_name],
+            outputs=[name],
+            name=name
+        )
+
+        return [trans_b_node, matmul_node]
+    else:
+        node_name_a = op_name+"_a"
+        trans_a_node = helper.make_node(
+            'Transpose',
+            inputs=[input_node_a],
+            outputs=[op_name+"_a"],
+            name=node_name_a
+        )
+
+        node_name_b = op_name + "_b"
+        trans_b_node = helper.make_node(
+            'Transpose',
+            inputs=[input_node_b],
+            outputs=[op_name+"_b"],
+            name=node_name_b
+        )
+
+        matmul_node = helper.make_node(
+            'MatMul',
+            inputs=[node_name_a, node_name_b],
+            outputs=[name],
+            name=name
+        )
+
+        return [trans_a_node, trans_b_node, matmul_node]
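For reference, `_linalg_gemm2` computes `alpha * op(A) @ op(B)`. Here `op` transposes its argument when the matching `transpose_*` flag is set. The converter turns each set flag into a Transpose node in front of MatMul. Only the `alpha == 1.0` case is expressed by the emitted nodes. Below is a plain-Python sketch of those semantics, useful for checking an exported graph by hand. The function names are hypothetical.

```python
def transpose(m):
    """Transpose a matrix stored as a list of row lists."""
    return [list(row) for row in zip(*m)]


def matmul(a, b):
    """Plain matrix product of two list-of-lists matrices."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def gemm2_reference(a, b, transpose_a=False, transpose_b=False, alpha=1.0):
    """alpha * op(a) @ op(b), where op() optionally transposes its argument."""
    if transpose_a:
        a = transpose(a)
    if transpose_b:
        b = transpose(b)
    return [[alpha * v for v in row] for row in matmul(a, b)]


a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]
# transpose_b=True corresponds to the Transpose(b) -> MatMul branch.
print(gemm2_reference(a, b, transpose_b=True))  # [[17.0, 23.0], [39.0, 53.0]]
```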
+
+
+@mx_op.register("Pooling")
+def convert_pooling(node, **kwargs):
+    """Map MXNet's Pooling operator attributes to onnx's
+    MaxPool/AveragePool/GlobalMaxPool/GlobalAveragePool operators
+    based on the input node's attributes and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    proc_nodes = kwargs["proc_nodes"]
+    attrs = node["attrs"]
+    kernel = eval(attrs["kernel"])
+    pool_type = attrs["pool_type"]
+    stride = eval(attrs["stride"]) if attrs.get("stride") else None
+    global_pool = attrs.get("global_pool") == "True"
+    node_inputs = node["inputs"]
+    input_node_idx = kwargs["index_lookup"][node_inputs[0][0]]
+    input_node = proc_nodes[input_node_idx]
+    name = node["name"]
+
+    pooling_convention = attrs.get('pooling_convention', 'valid')
+
+    if pooling_convention == 'full':
+        pooling_warning = "Pooling: ONNX currently doesn't support " \
+                          "pooling_convention='full'. " \
+                          "This might lead to shape or accuracy issues. " \
+                          "https://github.com/onnx/onnx/issues/549"
+
+        logging.warning(pooling_warning)
+
+    pad_dims = list(parse_helper(attrs, "pad", [0, 0]))
+    pad_dims = pad_dims + pad_dims
+    pool_types = {"max": "MaxPool", "avg": "AveragePool"}
+    global_pool_types = {"max": "GlobalMaxPool", "avg": "GlobalAveragePool"}
+
+    if global_pool:
+        node = helper.make_node(
+            global_pool_types[pool_type],
+            [input_node.name],  # input
+            [name],
+            name=name
+        )
+    else:
+        node = helper.make_node(
+            pool_types[pool_type],
+            [input_node.name],  # input
+            [name],
+            kernel_shape=kernel,
+            pads=pad_dims,
+            strides=stride,
+            name=name
+        )
+
+    return [node]
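The warning above matters because MXNet's two pooling conventions compute different output sizes. `valid` floors the number of window positions, while `full` ceils it. ONNX's MaxPool/AveragePool use floor-based output shapes by default. Below is a small sketch of the per-axis formula. The helper name `pooled_length` is hypothetical.

```python
import math


def pooled_length(size, kernel, stride=1, pad=0, convention="valid"):
    """Output length of one spatial axis under MXNet's pooling conventions.

    'valid' floors the window count (like ONNX's default output shape);
    'full' ceils it, so it can yield one extra output position.
    """
    span = size + 2 * pad - kernel
    if convention == "valid":
        return span // stride + 1
    if convention == "full":
        return int(math.ceil(span / float(stride))) + 1
    raise ValueError("unknown pooling_convention: %s" % convention)


# A 6-wide input, 3-wide kernel, stride 2: the two conventions disagree.
print(pooled_length(6, 3, stride=2, convention="valid"))  # 2
print(pooled_length(6, 3, stride=2, convention="full"))   # 3
```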
+
+
+@mx_op.register("exp")
+def convert_exp(node, **kwargs):
+    """Map MXNet's exp operator attributes to onnx's Exp operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+
+    input_node_id = kwargs["index_lookup"][inputs[0][0]]
+    input_node = proc_nodes[input_node_id].name
+
+    node = helper.make_node(
+        "Exp",
+        [input_node],
+        [name],
+        name=name,
+    )
+    return [node]
+
+
+@mx_op.register("_copy")
+def convert_identity(node, **kwargs):
+    """Map MXNet's _copy operator attributes to onnx's Identity operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+
+    input_node_id = kwargs["index_lookup"][inputs[0][0]]
+    input_node = proc_nodes[input_node_id].name
+
+    node = helper.make_node(
+        "Identity",
+        [input_node],
+        [name],
+        name=name,
+    )
+    return [node]
+
+
+@mx_op.register("LeakyReLU")
+def convert_leakyrelu(node, **kwargs):
+    """Map MXNet's LeakyReLU operator attributes to onnx's Elu/LeakyRelu/PRelu
+    operators based on the input node's attributes and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+    input_node_id = kwargs["index_lookup"][inputs[0][0]]
+    input_node = proc_nodes[input_node_id].name
+    attrs = node["attrs"]
+
+    # MXNet's LeakyReLU uses act_type="leaky" by default.
+    act_type = attrs.get("act_type", "leaky")
+    alpha = float(attrs.get("slope", 0.25))
+
+    act_name = {"elu": "Elu", "leaky": "LeakyRelu", "prelu": "PRelu"}
+
+    if act_type == "prelu":
+        alpha_node_index = kwargs["index_lookup"][inputs[1][0]]
+        alpha_node_name = proc_nodes[alpha_node_index].name
+
+        node = helper.make_node(
+            act_name[act_type],
+            inputs=[input_node, alpha_node_name],
+            outputs=[name],
+            name=name)
+    else:
+        node = helper.make_node(
+            act_name[act_type],
+            inputs=[input_node],
+            outputs=[name],
+            name=name,
+            alpha=alpha)
+
+    return [node]
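The converter passes `slope` explicitly as `alpha` because the defaults differ between the two frameworks. MXNet's `slope` defaults to 0.25, whereas ONNX's LeakyRelu and Elu default `alpha` to 0.01 and 1.0 respectively. PRelu is different again: its slope is a second tensor input (MXNet's learned `gamma`), not an attribute. Below is a scalar reference sketch of the three mappings. The function names are hypothetical.

```python
import math


def leaky_relu(x, slope=0.25):
    # act_type="leaky" -> ONNX LeakyRelu(alpha=slope)
    return x if x >= 0 else slope * x


def elu(x, slope=0.25):
    # act_type="elu" -> ONNX Elu(alpha=slope)
    return x if x >= 0 else slope * (math.exp(x) - 1.0)


def prelu(x, gamma):
    # act_type="prelu" -> ONNX PRelu; gamma arrives as a tensor input
    return x if x >= 0 else gamma * x


print(leaky_relu(-2.0))  # -0.5
print(prelu(-2.0, 0.5))  # -1.0
```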
+
+
+@mx_op.register("softmax")
+def convert_softmax(node, **kwargs):
+    """Map MXNet's softmax operator attributes to onnx's Softmax operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    inputs = node["inputs"]
+    input_idx = kwargs["index_lookup"][inputs[0][0]]
+    proc_nodes = kwargs["proc_nodes"]
+    input_node = proc_nodes[input_idx]
+
+    name = node["name"]
+    axis = int(node.get("attrs", {}).get("axis", -1))
+
+    softmax_node = helper.make_node(
+        "Softmax",
+        [input_node.name],
+        [name],
+        axis=axis,
+        name=name
+    )
+
+    return [softmax_node]
+
+
+# There's also mx.sym.softmax(), which doesn't do cross-entropy loss,
+# just softmax for inference - hence the name convert_softmax_output.
+@mx_op.register("SoftmaxOutput")
+def convert_softmax_output(node, **kwargs):
+    """Map MXNet's SoftmaxOutput operator attributes to onnx's Softmax operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    inputs = node["inputs"]
+    input1_idx = kwargs["index_lookup"][inputs[0][0]]
+    proc_nodes = kwargs["proc_nodes"]
+    input1 = proc_nodes[input1_idx]
+    name = node["name"]
+
+    softmax_node = helper.make_node(
+        "Softmax",
+        [input1.name],
+        [name],
+        axis=1,
+        name=name
+    )
+
+    return [softmax_node]
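At inference time, the forward pass of `SoftmaxOutput` is an ordinary softmax, because its cross-entropy loss only shapes the gradient. That is why it can be exported as ONNX `Softmax` with `axis=1`. Below is a numerically stable plain-Python sketch for the common 2-D `(batch, classes)` case. Higher-rank inputs are not covered, and the helper name is hypothetical.

```python
import math


def softmax_rows(matrix):
    """Softmax over axis 1 of a 2-D list of floats."""
    out = []
    for row in matrix:
        m = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


probs = softmax_rows([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]])
print([round(p, 4) for p in probs[1]])  # [0.3333, 0.3333, 0.3333]
```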
+
+
+@mx_op.register("Concat")
+def convert_concat(node, **kwargs):
+    """Map MXNet's Concat operator attributes to onnx's Concat operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    inputs = node["inputs"]
+    proc_nodes = kwargs["proc_nodes"]
+    input_names = [proc_nodes[kwargs["index_lookup"][i[0]]].name
+                   for i in inputs]
+    axis = int(node.get("attrs", {}).get("dim", 1))
+    concat_node = helper.make_node(
+        "Concat",
+        input_names,
+        [name],
+        axis=axis,
+        name=name
+    )
+    return [concat_node]
+
+
+@mx_op.register("transpose")
+def convert_transpose(node, **kwargs):
+    """Map MXNet's transpose operator attributes to onnx's Transpose operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    input_idx = kwargs["index_lookup"][node["inputs"][0][0]]
+    proc_nodes = kwargs["proc_nodes"]
+    input_node = proc_nodes[input_idx].name
+    axes = node.get("attrs", {}).get("axes", ())
+    if axes:
+        axes = tuple(map(int, re.findall(r'\d+', axes)))
+
+        transpose_node = helper.make_node(
+            "Transpose",
+            [input_node],
+            [name],
+            perm=axes,
+            name=name
+        )
+    else:
+        transpose_node = helper.make_node(
+            "Transpose",
+            [input_node],
+            [name],
+            name=name
+        )
+
+    return [transpose_node]
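The regex in `convert_transpose` keeps only runs of digits. That works for non-negative axes, but it would silently drop a minus sign if an `axes` string ever contained negative values. Whether MXNet emits such values here has not been verified. Below is a sketch comparing the regex with a hypothetical `ast.literal_eval`-based parser that normalizes negative axes. When `axes` is absent, both MXNet and ONNX reverse the dimensions, so omitting `perm` is equivalent.

```python
import ast
import re


def parse_axes_regex(axes):
    # What the converter above does: keep only runs of digits.
    return tuple(map(int, re.findall(r'\d+', axes)))


def parse_axes_literal(axes, ndim):
    # Hypothetical alternative: parse the tuple literal, normalize negatives.
    parsed = ast.literal_eval(axes)
    if isinstance(parsed, int):
        parsed = (parsed,)
    return tuple(a + ndim if a < 0 else a for a in parsed)


print(parse_axes_regex("(0, 2, 1)"))        # (0, 2, 1)
print(parse_axes_regex("(0, -1, 1)"))       # (0, 1, 1): the sign is lost
print(parse_axes_literal("(0, -1, 1)", 3))  # (0, 2, 1)
```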
+
+
+@mx_op.register("LRN")
+def convert_lrn(node, **kwargs):
+    """Map MXNet's LRN operator attributes to onnx's LRN operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    input_idx = kwargs["index_lookup"][node["inputs"][0][0]]
+    proc_nodes = kwargs["proc_nodes"]
+    input_node = proc_nodes[input_idx].name
+
+    attrs = node["attrs"]
+    alpha = float(attrs["alpha"]) if "alpha" in attrs else 0.0001
+    beta = float(attrs["beta"]) if "beta" in attrs else 0.75
+    # MXNet's LRN defaults knorm to 2, while ONNX's bias defaults to 1.
+    bias = float(attrs["knorm"]) if "knorm" in attrs else 2.0
+    size = int(attrs["nsize"])
+
+    lrn_node = helper.make_node(
+        "LRN",
+        inputs=[input_node],
+        outputs=[name],
+        name=name,
+        alpha=alpha,
+        beta=beta,
+        bias=bias,
+        size=size
+    )
+
+    return [lrn_node]
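MXNet and ONNX use the same LRN formula, `x / (k + alpha / n * sum(x^2)) ** beta`. However, they name the parameters differently (`knorm`/`bias`, `nsize`/`size`), and their defaults for `k` differ: MXNet documents `knorm=2`, while ONNX defaults `bias` to 1.0. Below is a sketch of the attribute translation using MXNet's documented defaults. The function name is hypothetical.

```python
def lrn_attrs_to_onnx(attrs):
    """Translate MXNet LRN attributes (string values) to ONNX LRN kwargs."""
    return {
        "alpha": float(attrs.get("alpha", 0.0001)),
        "beta": float(attrs.get("beta", 0.75)),
        "bias": float(attrs.get("knorm", 2.0)),  # MXNet knorm -> ONNX bias
        "size": int(attrs["nsize"]),             # required in MXNet
    }


print(lrn_attrs_to_onnx({"nsize": "5"}))
# {'alpha': 0.0001, 'beta': 0.75, 'bias': 2.0, 'size': 5}
```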
+
+
+@mx_op.register("L2Normalization")
+def convert_l2normalization(node, **kwargs):
+    """Map MXNet's L2Normalization operator attributes to onnx's
+    LpNormalization operator and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    input_id = kwargs["index_lookup"][node["inputs"][0][0]]
+    input_name = kwargs["proc_nodes"][input_id].name
+    attrs = node["attrs"]
+    mode = attrs.get("mode", "instance")
+
+    if mode != "channel":
+        raise AttributeError("ONNX currently supports channel mode only")
+
+    l2norm_node = helper.make_node(
+        "LpNormalization",
+        [input_name],
+        [name],
+        axis=1,  # channel only
+        name=name
+    )
+    return [l2norm_node]
+
+
+@mx_op.register("Dropout")
+def convert_dropout(node, **kwargs):
+    """Map MXNet's Dropout operator attributes to onnx's Dropout operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    input_id = kwargs["index_lookup"][node["inputs"][0][0]]
+    input_name = kwargs["proc_nodes"][input_id].name
+    attrs = node.get("attrs", {})
+    probability = float(attrs.get("p", 0.5))  # MXNet's default dropout rate
+
+    dropout_node = helper.make_node(
+        "Dropout",
+        [input_name],
+        [name],
+        ratio=probability,
+        name=name
+    )
+    return [dropout_node]
+
+
+@mx_op.register("Flatten")
+def convert_flatten(node, **kwargs):
+    """Map MXNet's Flatten operator attributes to onnx's Flatten operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    input_idx = kwargs["index_lookup"][node["inputs"][0][0]]
+    proc_nodes = kwargs["proc_nodes"]
+    input_node = proc_nodes[input_idx].name  # .output[0]
+
+    flatten_node = helper.make_node(
+        "Flatten",
+        [input_node],
+        [name],
+        name=name
+    )
+    return [flatten_node]
+
+
+def scalar_op_helper(node, op_name, **kwargs):
+    """Helper function for scalar arithmetic operations"""
+    helper, numpy_helper, mapping = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+    scalar_value = [float(node.get("attrs", {}).get("scalar", 1))]
+
+    input_name_id = kwargs["index_lookup"][inputs[0][0]]
+    input_node = proc_nodes[input_name_id].name
+
+    initializer = kwargs["initializer"]
+    flag = True
+    # If the input is already an initializer (a constant known at export
+    # time), apply the scalar op to its values directly and store the result
+    # as a new initializer.
+    for i in initializer:
+        if i.name == input_node:
+            if op_name == 'Mul':
+                new_initializer = numpy_helper.to_array(i) * scalar_value[0]
+            elif op_name == 'Sub':
+                new_initializer = numpy_helper.to_array(i) - scalar_value[0]
+            elif op_name == 'Add':
+                new_initializer = numpy_helper.to_array(i) + scalar_value[0]
+            elif op_name == 'Div':
+                new_initializer = numpy_helper.to_array(i) / scalar_value[0]
+            flag = False
+            break
+
+    # Otherwise, store the scalar as a new initializer tensor and feed it
+    # to an op node alongside the input.
+    if flag:
+        np_arr = np.array(scalar_value)
+        data_type = mapping.NP_TYPE_TO_TENSOR_TYPE[np_arr.dtype]
+        dims = np.shape(np_arr)
+
+        scalar_op_name = "scalar_op" + str(kwargs["idx"])
+        tensor_node = helper.make_tensor_value_info(
+            scalar_op_name, data_type, dims)
+
+        initializer.append(
+            helper.make_tensor(
+                name=scalar_op_name,
+                data_type=data_type,
+                dims=dims,
+                vals=scalar_value,
+                raw=False,
+            )
+        )
+
+        mul_node = helper.make_node(
+            op_name,
+            [input_node, scalar_op_name],
+            [name],
+            name=name
+        )
+
+        return [tensor_node, mul_node]
+    else:
+        data_type = mapping.NP_TYPE_TO_TENSOR_TYPE[new_initializer.dtype]
+        dims = np.shape(new_initializer)
+
+        new_a_node = input_node + str(kwargs["idx"])
+        tensor_node = helper.make_tensor_value_info(
+            new_a_node, data_type, dims)
+
+        initializer.append(
+            helper.make_tensor(
+                name=new_a_node,
+                data_type=data_type,
+                dims=dims,
+                vals=new_initializer,
+                raw=False,
+            )
+        )
+        return [tensor_node]
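`scalar_op_helper` takes one of two paths. If the input is a constant, it folds the scalar into the existing initializer at export time. If the input is only known at runtime, it adds the scalar as a one-element initializer and emits a binary op. The sketch below mirrors that decision, using plain dictionaries as stand-ins for ONNX tensors. `apply_scalar_op`, `SCALAR_OPS`, and the generated tensor names are hypothetical; the real helper derives its names from `kwargs["idx"]`.

```python
import operator

# Scalar ops the helper handles, keyed by ONNX op type.
SCALAR_OPS = {
    "Add": operator.add,
    "Sub": operator.sub,
    "Mul": operator.mul,
    "Div": operator.truediv,
}


def apply_scalar_op(op_name, input_name, scalar, initializers):
    """Fold the scalar into a constant input, or describe a runtime node.

    initializers maps tensor names to lists of floats. Returns
    ("folded", new_name) or ("node", op_name, [input, scalar_name]).
    """
    fn = SCALAR_OPS[op_name]
    if input_name in initializers:
        # The input is a weight known at export time: compute the result now.
        new_name = input_name + "_" + op_name.lower()
        initializers[new_name] = [fn(v, scalar) for v in initializers[input_name]]
        return ("folded", new_name)
    # Runtime input: store the scalar as a 1-element constant, emit a binary op.
    scalar_name = "scalar_" + op_name.lower()
    initializers[scalar_name] = [float(scalar)]
    return ("node", op_name, [input_name, scalar_name])


inits = {"w": [1.0, 2.0]}
print(apply_scalar_op("Mul", "w", 3.0, inits), inits["w_mul"])  # ('folded', 'w_mul') [3.0, 6.0]
print(apply_scalar_op("Sub", "data", 1.0, inits))  # ('node', 'Sub', ['data', 'scalar_sub'])
```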
+
+
+# Convert the scalar value into a node and pass it as input to the Mul node
+@mx_op.register("_mul_scalar")
+def convert_mul_scalar(node, **kwargs):
+    """Map MXNet's _mul_scalar operator attributes to onnx's Mul operator.
+    Creates a new node for the input scalar value, adds it to the initializer
+    and returns the created nodes.
+    """
+    return scalar_op_helper(node, 'Mul', **kwargs)
+
+
+# Convert the scalar value into a node and pass it as input to the Sub node
+@mx_op.register("_minus_scalar")
+def convert_minus_scalar(node, **kwargs):
+    """Map MXNet's _minus_scalar operator attributes to onnx's Sub operator.
+    Creates a new node for the input scalar value, adds it to the initializer
+    and returns the created nodes.
+    """
+    return scalar_op_helper(node, 'Sub', **kwargs)
+
+
+# Convert the scalar value into a node and pass it as input to the Add node
+@mx_op.register("_plus_scalar")
+def convert_add_scalar(node, **kwargs):
+    """Map MXNet's _plus_scalar operator attributes to onnx's Add operator.
+    Creates a new node for the input scalar value, adds it to the initializer
+    and returns the created nodes.
+    """
+    return scalar_op_helper(node, 'Add', **kwargs)
+
+
+# Convert the scalar value into a node and pass it as input to the Div node
+@mx_op.register("_div_scalar")
+def convert_div_scalar(node, **kwargs):
+    """Map MXNet's _div_scalar operator attributes to onnx's Div operator.
+    Creates a new node for the input scalar value, adds it to the initializer
+    and returns the created nodes.
+    """
+    return scalar_op_helper(node, 'Div', **kwargs)
+
+
+# Sorting and Searching
+@mx_op.register("argmax")
+def convert_argmax(node, **kwargs):
+    """Map MXNet's argmax operator attributes to onnx's ArgMax operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    proc_nodes = kwargs["proc_nodes"]
+    node_inputs = node["inputs"]
+
+    input_node_idx = kwargs["index_lookup"][node_inputs[0][0]]
+    input_node = proc_nodes[input_node_idx].name
+    name = node["name"]
+    attrs = node["attrs"]
+
+    axis = int(attrs.get("axis"))
+    # MXNet's argmax defaults to keepdims=False; ONNX's ArgMax defaults to 1.
+    keepdims = int(attrs.get("keepdims", 0))
+
+    node = helper.make_node(
+        'ArgMax',
+        inputs=[input_node],
+        axis=axis,
+        keepdims=keepdims,
+        outputs=[name],
+        name=name
+    )
+    return [node]
+
+@mx_op.register("argmin")
+def convert_argmin(node, **kwargs):
+    """Map MXNet's argmin operator attributes to onnx's ArgMin operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    proc_nodes = kwargs["proc_nodes"]
+    node_inputs = node["inputs"]
+
+    input_node_idx = kwargs["index_lookup"][node_inputs[0][0]]
+    input_node = proc_nodes[input_node_idx].name
+    name = node["name"]
+    attrs = node["attrs"]
+
+    axis = int(attrs.get("axis"))
+    # MXNet's argmin defaults to keepdims=False; ONNX's ArgMin defaults to 1.
+    keepdims = int(attrs.get("keepdims", 0))
+
+    node = helper.make_node(
+        'ArgMin',
+        inputs=[input_node],
+        axis=axis,
+        keepdims=keepdims,
+        outputs=[name],
+        name=name
+    )
+    return [node]
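`keepdims` controls whether the reduced axis survives with length 1. ONNX's ArgMax/ArgMin default it to 1, while MXNet's `argmax`/`argmin` default it to False. The fallback used when the attribute is absent therefore changes the exported output shape. Below is a plain-Python sketch for axis 1 of a 2-D input. The helper name is hypothetical.

```python
def argmax_axis1(matrix, keepdims):
    """ArgMax over axis 1 of a 2-D list, mirroring ONNX's keepdims attribute."""
    # max() with a key returns the first index on ties.
    idx = [max(range(len(row)), key=row.__getitem__) for row in matrix]
    # keepdims=1 keeps the reduced axis with length 1: shape (N, 1), not (N,)
    return [[i] for i in idx] if keepdims else idx


m = [[0.1, 0.7, 0.2], [0.9, 0.05, 0.05]]
print(argmax_axis1(m, keepdims=1))  # [[1], [0]]
print(argmax_axis1(m, keepdims=0))  # [1, 0]
```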
+
+@mx_op.register("_maximum")
+def convert_maximum(node, **kwargs):
+    """Map MXNet's _maximum operator attributes to onnx's Max operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    proc_nodes = kwargs["proc_nodes"]
+    node_inputs = node["inputs"]
+
+    input_node_list = []
+    for node_input in node_inputs:
+        node_id = kwargs["index_lookup"][node_input[0]]
+        input_node_list.append(proc_nodes[node_id].name)
+
+    name = node["name"]
+
+    node = helper.make_node(
+        'Max',
+        inputs=input_node_list,
+        outputs=[name],
+        name=name,
+    )
+
+    return [node]
+
+
+@mx_op.register("_minimum")
+def convert_minimum(node, **kwargs):
+    """Map MXNet's _minimum operator attributes to onnx's Min operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    proc_nodes = kwargs["proc_nodes"]
+    node_inputs = node["inputs"]
+
+    input_node_list = []
+    for node_input in node_inputs:
+        node_id = kwargs["index_lookup"][node_input[0]]
+        input_node_list.append(proc_nodes[node_id].name)
+
+    name = node["name"]
+
+    node = helper.make_node(
+        'Min',
+        inputs=input_node_list,
+        outputs=[name],
+        name=name,
+    )
+
+    return [node]
+
+
+@mx_op.register("min")
+def convert_min(node, **kwargs):
+    """Map MXNet's min operator attributes to onnx's ReduceMin operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+
+    mx_axis = node.get("attrs", {}).get("axis", None)
+    axes = (convert_string_to_list(str(mx_axis))
+            if mx_axis is not None else None)
+
+    keepdims = int(node.get("attrs", {}).get("keepdims", 0))
+
+    input_node_id = kwargs["index_lookup"][inputs[0][0]]
+    input_node = proc_nodes[input_node_id].name
+
+    if axes is not None:
+        node = helper.make_node(
+            'ReduceMin',
+            inputs=[input_node],
+            outputs=[name],
+            axes=axes,
+            keepdims=keepdims,
+            name=name
+        )
+
+        return [node]
+    else:
+        node = helper.make_node(
+            'ReduceMin',
+            inputs=[input_node],
+            outputs=[name],
+            keepdims=keepdims,
+            name=name
+        )
+
+        return [node]
+
+
+@mx_op.register("max")
+def convert_max(node, **kwargs):
+    """Map MXNet's max operator attributes to onnx's ReduceMax operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+
+    mx_axis = node.get("attrs", {}).get("axis", None)
+    axes = (convert_string_to_list(str(mx_axis))
+            if mx_axis is not None else None)
+
+    keepdims = int(node.get("attrs", {}).get("keepdims", 0))
+
+    input_node_id = kwargs["index_lookup"][inputs[0][0]]
+    input_node = proc_nodes[input_node_id].name
+
+    if axes is not None:
+        node = helper.make_node(
+            'ReduceMax',
+            inputs=[input_node],
+            outputs=[name],
+            axes=axes,
+            keepdims=keepdims,
+            name=name
+        )
+
+        return [node]
+    else:
+        node = helper.make_node(
+            'ReduceMax',
+            inputs=[input_node],
+            outputs=[name],
+            keepdims=keepdims,
+            name=name
+        )
+
+        return [node]
+
+
+@mx_op.register("mean")
+def convert_mean(node, **kwargs):
+    """Map MXNet's mean operator attributes to onnx's ReduceMean operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+
+    mx_axis = node.get("attrs", {}).get("axis", None)
+    axes = (convert_string_to_list(str(mx_axis))
+            if mx_axis is not None else None)
+
+    keepdims = int(node.get("attrs", {}).get("keepdims", 0))
+
+    input_node_id = kwargs["index_lookup"][inputs[0][0]]
+    input_node = proc_nodes[input_node_id].name
+
+    if axes is not None:
+        node = helper.make_node(
+            'ReduceMean',
+            inputs=[input_node],
+            outputs=[name],
+            axes=axes,
+            keepdims=keepdims,
+            name=name
+        )
+
+        return [node]
+    else:
+        node = helper.make_node(
+            'ReduceMean',
+            inputs=[input_node],
+            outputs=[name],
+            keepdims=keepdims,
+            name=name
+        )
+
+        return [node]
+
+
+@mx_op.register("prod")
+def convert_prod(node, **kwargs):
+    """Map MXNet's prod operator attributes to onnx's ReduceProd operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+
+    mx_axis = node.get("attrs", {}).get("axis", None)
+    axes = (convert_string_to_list(str(mx_axis))
+            if mx_axis is not None else None)
+
+    keepdims = int(node.get("attrs", {}).get("keepdims", 0))
+
+    input_node_id = kwargs["index_lookup"][inputs[0][0]]
+    input_node = proc_nodes[input_node_id].name
+
+    if axes is not None:
+        node = helper.make_node(
+            'ReduceProd',
+            inputs=[input_node],
+            outputs=[name],
+            axes=axes,
+            keepdims=keepdims,
+            name=name
+        )
+
+        return [node]
+    else:
+        node = helper.make_node(
+            'ReduceProd',
+            inputs=[input_node],
+            outputs=[name],
+            keepdims=keepdims,
+            name=name
+        )
+
+        return [node]
+
+
+# Arithmetic Operations
+@mx_op.register("elemwise_add")
+def convert_elementwise_add(node, **kwargs):
+    """Map MXNet's elemwise_add operator attributes to onnx's Add operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+
+    input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
+    input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
+
+    input_node_a = proc_nodes[input_node_a_id].name
+    input_node_b = proc_nodes[input_node_b_id].name
+
+    add_node = helper.make_node(
+        "Add",
+        [input_node_a, input_node_b],
+        [name],
+        name=name,
+    )
+
+    return [add_node]
+
+
+@mx_op.register("broadcast_add")
+def convert_broadcast_add(node, **kwargs):
+    """Map MXNet's broadcast_add operator attributes to onnx's Add operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+
+    input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
+    input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
+
+    input_node_a = proc_nodes[input_node_a_id].name
+    input_node_b = proc_nodes[input_node_b_id].name
+
+    add_node = helper.make_node(
+        "Add",
+        [input_node_a, input_node_b],
+        [name],
+        name=name,
+    )
+
+    return [add_node]
+
+
+@mx_op.register("elemwise_sub")
+def convert_elementwise_sub(node, **kwargs):
+    """Map MXNet's elemwise_sub operator attributes to onnx's Sub operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+
+    input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
+    input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
+
+    input_node_a = proc_nodes[input_node_a_id].name
+    input_node_b = proc_nodes[input_node_b_id].name
+
+    sub_node = helper.make_node(
+        "Sub",
+        [input_node_a, input_node_b],
+        [name],
+        name=name,
+    )
+
+    return [sub_node]
+
+@mx_op.register("broadcast_sub")
+def convert_broadcast_sub(node, **kwargs):
+    """Map MXNet's broadcast_sub operator attributes to onnx's Sub operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+
+    input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
+    input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
+
+    input_node_a = proc_nodes[input_node_a_id].name
+    input_node_b = proc_nodes[input_node_b_id].name
+
+    sub_node = helper.make_node(
+        "Sub",
+        [input_node_a, input_node_b],
+        [name],
+        name=name,
+    )
+
+    return [sub_node]
+
+
+@mx_op.register("elemwise_mul")
+def convert_elemwise_mul(node, **kwargs):
+    """Map MXNet's elemwise_mul operator attributes to onnx's Mul operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+
+    input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
+    input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
+
+    input_node_a = proc_nodes[input_node_a_id].name
+    input_node_b = proc_nodes[input_node_b_id].name
+
+    mul_node = helper.make_node(
+        "Mul",
+        [input_node_a, input_node_b],
+        [name],
+        name=name,
+    )
+
+    return [mul_node]
+
+@mx_op.register("broadcast_mul")
+def convert_broadcast_mul(node, **kwargs):
+    """Map MXNet's broadcast_mul operator attributes to onnx's Mul operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+
+    input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
+    input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
+
+    input_node_a = proc_nodes[input_node_a_id].name
+    input_node_b = proc_nodes[input_node_b_id].name
+
+    mul_node = helper.make_node(
+        "Mul",
+        [input_node_a, input_node_b],
+        [name],
+        name=name
+    )
+
+    return [mul_node]
+
+
+@mx_op.register("elemwise_div")
+def convert_elemwise_div(node, **kwargs):
+    """Map MXNet's elemwise_div operator attributes to onnx's Div operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+
+    input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
+    input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
+
+    input_node_a = proc_nodes[input_node_a_id].name
+    input_node_b = proc_nodes[input_node_b_id].name
+
+    div_node = helper.make_node(
+        "Div",
+        [input_node_a, input_node_b],
+        [name],
+        name=name
+    )
+
+    return [div_node]
+
+
+@mx_op.register("broadcast_div")
+def convert_broadcast_div(node, **kwargs):
+    """Map MXNet's broadcast_div operator attributes to onnx's Div operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+
+    input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
+    input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
+
+    input_node_a = proc_nodes[input_node_a_id].name
+    input_node_b = proc_nodes[input_node_b_id].name
+
+    div_node = helper.make_node(
+        "Div",
+        [input_node_a, input_node_b],
+        [name],
+        name=name
+    )
+
+    return [div_node]
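The eight `elemwise_*`/`broadcast_*` converters above share one two-input pattern and differ only in the ONNX op type. Below is a table-driven sketch of that mapping. `BINARY_OPS` and `binary_node_spec` are hypothetical. The returned dict only describes the arguments that `helper.make_node` would receive.

```python
# Hypothetical table: MXNet binary op name -> ONNX op type.
BINARY_OPS = {
    "elemwise_add": "Add", "broadcast_add": "Add",
    "elemwise_sub": "Sub", "broadcast_sub": "Sub",
    "elemwise_mul": "Mul", "broadcast_mul": "Mul",
    "elemwise_div": "Div", "broadcast_div": "Div",
}


def binary_node_spec(mx_op_name, node_name, input_a, input_b):
    """Describe the ONNX node a binary MXNet op maps to."""
    return {
        "op_type": BINARY_OPS[mx_op_name],
        "inputs": [input_a, input_b],
        "outputs": [node_name],
        "name": node_name,
    }


print(binary_node_spec("broadcast_div", "div0", "a", "b"))
# {'op_type': 'Div', 'inputs': ['a', 'b'], 'outputs': ['div0'], 'name': 'div0'}
```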
+
+
+@mx_op.register("negative")
+def convert_negative(node, **kwargs):
+    """Map MXNet's negative operator attributes to onnx's Neg operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+
+    input_node_id = kwargs["index_lookup"][inputs[0][0]]
+
+    input_node = proc_nodes[input_node_id].name
+
+    neg_node = helper.make_node(
+        "Neg",
+        [input_node],
+        [name],
+        name=name,
+    )
+
+    return [neg_node]
+
+
+@mx_op.register("abs")
+def convert_abs(node, **kwargs):
+    """Map MXNet's abs operator attributes to onnx's Abs operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+
+    input_node_id = kwargs["index_lookup"][inputs[0][0]]
+
+    input_node = proc_nodes[input_node_id].name
+
+    abs_node = helper.make_node(
+        "Abs",
+        [input_node],
+        [name],
+        name=name
+    )
+
+    return [abs_node]
+
+
+@mx_op.register("add_n")
+def convert_addn(node, **kwargs):
+    """Map MXNet's add_n operator attributes to onnx's Sum operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+
+    input_list = []
+    for input_val in inputs:
+        input_list.append(
+            proc_nodes[kwargs["index_lookup"][input_val[0]]].name)
+
+    sum_node = helper.make_node(
+        "Sum",
+        input_list,
+        [name],
+        name=name
+    )
+    return [sum_node]
+
+
+# Rounding
+@mx_op.register("ceil")
+def convert_ceil(node, **kwargs):
+    """Map MXNet's ceil operator attributes to onnx's Ceil operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+
+    input_node_id = kwargs["index_lookup"][inputs[0][0]]
+    input_node = proc_nodes[input_node_id].name
+
+    node = helper.make_node(
+        "Ceil",
+        [input_node],
+        [name],
+        name=name
+    )
+    return [node]
+
+@mx_op.register("floor")
+def convert_floor(node, **kwargs):
+    """Map MXNet's floor operator attributes to onnx's Floor operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+
+    input_node_id = kwargs["index_lookup"][inputs[0][0]]
+    input_node = proc_nodes[input_node_id].name
+
+    node = helper.make_node(
+        "Floor",
+        [input_node],
+        [name],
+        name=name
+    )
+    return [node]
+
+# Changing shape and type.
+@mx_op.register("Reshape")
+def convert_reshape(node, **kwargs):
+    """Map MXNet's Reshape operator attributes to onnx's Reshape operator.
+    Converts the output shape attribute to an output shape tensor
+    and returns the multiple created nodes.
+    """
+    helper, _, mapping = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+    attrs = node["attrs"]
+
+    output_shape_list = convert_string_to_list(attrs["shape"])
+
+    initializer = kwargs["initializer"]
+    output_shape_np = np.array(output_shape_list)
+    data_type = mapping.NP_TYPE_TO_TENSOR_TYPE[output_shape_np.dtype]
+    dims = np.shape(output_shape_np)
+
+    output_shape_name = "reshape_attr_tensor" + str(kwargs["idx"])
+    tensor_node = helper.make_tensor_value_info(output_shape_name, data_type, dims)
+
+    initializer.append(
+        helper.make_tensor(
+            name=output_shape_name,
+            data_type=data_type,
+            dims=dims,
+            vals=output_shape_list,
+            raw=False,
+        )
+    )
+
+    input_node_idx = kwargs["index_lookup"][inputs[0][0]]
+    input_node_name = proc_nodes[input_node_idx].name
+
+    not_supported_shape = [-2, -3, -4]
+
+    for val in output_shape_list:
+        if val in not_supported_shape:
+            raise AttributeError("Shape value not supported in ONNX", val)
+
+    reshape_node = helper.make_node(
+        "Reshape",
+        [input_node_name, output_shape_name],
+        [name],
+        name=name
+    )
+
+    return [tensor_node, reshape_node]
+
+@mx_op.register("Cast")
+def convert_cast(node, **kwargs):
+    """Map MXNet's Cast operator attributes to onnx's Cast operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+    dtype = node["attrs"]["dtype"]
+
+    input_node_id = kwargs["index_lookup"][inputs[0][0]]
+    input_node = proc_nodes[input_node_id].name
+
+    node = helper.make_node(
+        "Cast",
+        [input_node],
+        [name],
+        to=dtype,
+        name=name,
+    )
+    return [node]
+
+
+@mx_op.register("slice_axis")
+def convert_slice_axis(node, **kwargs):
+    """Map MXNet's slice_axis operator attributes to onnx's Slice operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+    axes = int(node["attrs"]["axis"])
+    starts = int(node["attrs"]["begin"])
+    if node["attrs"]["end"] == 'None':
+        raise ValueError("Slice: ONNX doesn't support 'None' in 'end' attribute")
+    else:
+        ends = int(node["attrs"]["end"])
+
+    input_node_id = kwargs["index_lookup"][inputs[0][0]]
+    input_node = proc_nodes[input_node_id].name
+
+    node = helper.make_node(
+        "Slice",
+        [input_node],
+        [name],
+        axes=[axes],
+        starts=[starts],
+        ends=[ends],
+        name=name,
+    )
+    return [node]
+
+
+@mx_op.register("SliceChannel")
+def convert_slice_channel(node, **kwargs):
+    """Map MXNet's SliceChannel operator attributes to onnx's Squeeze or Split
+    operator based on squeeze_axis attribute
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+    num_outputs = int(node.get("attrs", {})["num_outputs"])
+    axis = int(node.get("attrs", {}).get("axis", 1))
+    squeeze_axis = int(node.get("attrs", {}).get("squeeze_axis", 0))
+
+    input_node_id = kwargs["index_lookup"][inputs[0][0]]
+    input_node = proc_nodes[input_node_id].name
+
+    if squeeze_axis == 1 and num_outputs == 1:
+        node = helper.make_node(
+            "Squeeze",
+            [input_node],
+            [name],
+            axes=[axis],
+            name=name,
+        )
+        return [node]
+    elif squeeze_axis == 0 and num_outputs > 1:
+        node = helper.make_node(
+            "Split",
+            [input_node],
+            [name],
+            axis=axis,
+            split=[num_outputs],
+            name=name,
+        )
+        return [node]
+    else:
+        raise NotImplementedError("SliceChannel operator with num_outputs>1 and "
+                                  "squeeze_axis true is not implemented.")
+
+
+@mx_op.register("expand_dims")
+def convert_expand_dims(node, **kwargs):
+    """Map MXNet's expand_dims operator attributes to onnx's Unsqueeze operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+    axis = int(node["attrs"]["axis"])
+
+    input_node_id = kwargs["index_lookup"][inputs[0][0]]
+    input_node = proc_nodes[input_node_id].name
+
+    node = helper.make_node(
+        "Unsqueeze",
+        [input_node],
+        [name],
+        axes=[axis],
+        name=name,
+    )
+    return [node]
+
+@mx_op.register("squeeze")
+def convert_squeeze(node, **kwargs):
+    """Map MXNet's squeeze operator attributes to onnx's Squeeze operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+    if "axis" in node["attrs"]:
+        axis = convert_string_to_list(node["attrs"]["axis"])
+    else:
+        raise AttributeError("Missing axis attribute: ONNX currently requires axis to "
+                             "be specified for squeeze operator")
+
+    input_node_id = kwargs["index_lookup"][inputs[0][0]]
+    input_node = proc_nodes[input_node_id].name
+
+    node = helper.make_node(
+        "Squeeze",
+        [input_node],
+        [name],
+        axes=axis,
+        name=name,
+    )
+    return [node]
+
+
+@mx_op.register("log")
+def convert_log(node, **kwargs):
+    """Map MXNet's log operator attributes to onnx's Log operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+
+    input_node_id = kwargs["index_lookup"][inputs[0][0]]
+    input_node = proc_nodes[input_node_id].name
+
+    node = helper.make_node(
+        "Log",
+        [input_node],
+        [name],
+        name=name,
+    )
+    return [node]
+
+
+@mx_op.register("reciprocal")
+def convert_reciprocal(node, **kwargs):
+    """Map MXNet's reciprocal operator attributes to onnx's Reciprocal operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+
+    input_node_id = kwargs["index_lookup"][inputs[0][0]]
+    input_node = proc_nodes[input_node_id].name
+
+    node = helper.make_node(
+        "Reciprocal",
+        [input_node],
+        [name],
+        name=name,
+    )
+    return [node]
+
+
+@mx_op.register("_power")
+def convert_power(node, **kwargs):
+    """Map MXNet's _power operator attributes to onnx's Pow operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+
+    input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
+    input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
+
+    input_node_a = proc_nodes[input_node_a_id].name
+    input_node_b = proc_nodes[input_node_b_id].name
+
+    node = helper.make_node(
+        "Pow",
+        [input_node_a, input_node_b],
+        [name],
+        name=name
+    )
+    return [node]
+
+@mx_op.register("sqrt")
+def convert_sqrt(node, **kwargs):
+    """Map MXNet's sqrt operator attributes to onnx's Sqrt operator
+    and return the created node.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    inputs = node["inputs"]
+
+    input_node_id = kwargs["index_lookup"][inputs[0][0]]
+    input_node = proc_nodes[input_node_id].name
+
+    node = helper.make_node(
+        "Sqrt",
+        [input_node],
+        [name],
+        name=name,
+    )
+    return [node]
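A note on the Reshape converter above: ONNX's Reshape only understands two special values in the target shape, `0` (copy the input dimension at the same position) and `-1` (infer one dimension from the total element count), which is why `convert_reshape` rejects MXNet's `-2`, `-3` and `-4`. The following is a minimal pure-Python sketch of those ONNX rules, assuming a fully known input shape; `onnx_reshape_output_shape` is a hypothetical helper, not part of MXNet or ONNX.

```python
from functools import reduce
import operator


def onnx_reshape_output_shape(in_shape, target):
    """Hypothetical helper: resolve an ONNX-style Reshape target shape.

    0 copies the input dimension at the same index, -1 is inferred from
    the total element count. MXNet-only codes -2, -3, -4 are rejected,
    mirroring the check in convert_reshape.
    """
    unsupported = {-2, -3, -4}
    for val in target:
        if val in unsupported:
            raise AttributeError("Shape value not supported in ONNX", val)
    if list(target).count(-1) > 1:
        raise ValueError("At most one -1 is allowed in the target shape")

    total = reduce(operator.mul, in_shape, 1)
    out = []
    for pos, val in enumerate(target):
        # 0 means "keep the input dimension at this position"
        out.append(in_shape[pos] if val == 0 else val)

    if -1 in out:
        # Infer the single -1 from the product of the known dimensions
        known = reduce(operator.mul, (d for d in out if d != -1), 1)
        if known == 0 or total % known != 0:
            raise ValueError("Cannot infer the -1 dimension")
        out[out.index(-1)] = total // known
    elif reduce(operator.mul, out, 1) != total:
        raise ValueError("Element count does not match the input")
    return tuple(out)
```

For example, a target of `(0, -1)` applied to an input of shape `(2, 3, 4)` yields `(2, 12)`, while a target containing `-2` raises the same `AttributeError` that the converter raises.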
diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_model.py b/python/mxnet/contrib/onnx/mx2onnx/export_model.py
new file mode 100644
index 0000000..0dbfdc1
--- /dev/null
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_model.py
@@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+#pylint: disable-msg=too-many-arguments
+
+"""export function"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+import logging
+import numpy as np
+
+from ....base import string_types
+from .... import symbol
+from .export_onnx import MXNetGraph
+from ._export_helper import load_module
+
+
+def export_model(sym, params, input_shape, input_type=np.float32,
+                 onnx_file_path='model.onnx', verbose=False):
+    """Exports the MXNet model, passed as a parameter, to an ONNX model.
+    Accepts either symbol and parameter objects or json and params file paths as input.
+    Operator support and coverage - https://cwiki.apache.org/confluence/display/MXNET/ONNX
+
+    Parameters
+    ----------
+    sym : str or symbol object
+        Path to the json file or Symbol object
+    params : str or dict
+        Path to the params file or params dictionary (including both arg_params and aux_params)
+    input_shape : List of tuple
+        Input shape of the model e.g. [(1,3,224,224)]
+    input_type : data type
+        Input data type e.g. np.float32
+    onnx_file_path : str
+        Path where to save the generated onnx file
+    verbose : Boolean
+        If true will print logs of the model conversion
+
+    Returns
+    -------
+    onnx_file_path : str
+        Onnx file path
+    """
+
+    try:
+        from onnx import helper, mapping
+    except ImportError:
+        raise ImportError("Onnx and protobuf need to be installed. "
+                          + "Instructions to install - https://github.com/onnx/onnx")
+
+    converter = MXNetGraph()
+
+    data_format = np.dtype(input_type)
+    # if input parameters are strings (file paths), load files and create symbol parameter objects
+    if isinstance(sym, string_types) and isinstance(params, string_types):
+        logging.info("Converting json and weight file to sym and params")
+        sym_obj, params_obj = load_module(sym, params)
+        onnx_graph = converter.create_onnx_graph_proto(sym_obj, params_obj, input_shape,
+                                                       mapping.NP_TYPE_TO_TENSOR_TYPE[data_format],
+                                                       verbose=verbose)
+    elif isinstance(sym, symbol.Symbol) and isinstance(params, dict):
+        onnx_graph = converter.create_onnx_graph_proto(sym, params, input_shape,
+                                                       mapping.NP_TYPE_TO_TENSOR_TYPE[data_format],
+                                                       verbose=verbose)
+    else:
+        raise ValueError("Input sym and params should either be files or objects")
+
+    # Create the model (ModelProto)
+    onnx_model = helper.make_model(onnx_graph)
+
+    # Save model on disk
+    with open(onnx_file_path, "wb") as file_handle:
+        serialized = onnx_model.SerializeToString()
+        file_handle.write(serialized)
+        logging.info("Input shape of the model %s ", input_shape)
+        logging.info("Exported ONNX file %s saved to disk", onnx_file_path)
+
+    return onnx_file_path
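`export_model` above accepts either two file paths or a `Symbol` plus a params dict, and rejects any other combination. A minimal sketch of that branching, runnable without MXNet: `resolve_model_inputs`, `FakeSymbol` and the `load_fn` callback are hypothetical stand-ins for the real dispatch, `mxnet.symbol.Symbol` and `_export_helper.load_module`, and plain `str` replaces `string_types` (Python 3 only).

```python
class FakeSymbol(object):
    """Hypothetical stand-in for mxnet.symbol.Symbol."""


def resolve_model_inputs(sym, params, load_fn, symbol_type=FakeSymbol):
    """Return (symbol, params) following export_model's input rules.

    Two paths -> load them with load_fn (stand-in for load_module);
    a symbol object plus a dict -> use them as-is; otherwise ValueError.
    """
    if isinstance(sym, str) and isinstance(params, str):
        return load_fn(sym, params)
    if isinstance(sym, symbol_type) and isinstance(params, dict):
        return sym, params
    raise ValueError("Input sym and params should either be files or objects")
```

Keeping the loader injectable makes the branching easy to exercise in isolation; in the real function both accepted branches pass the resolved pair on to `MXNetGraph.create_onnx_graph_proto`.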
diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
new file mode 100644
index 0000000..1184738
--- /dev/null
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
@@ -0,0 +1,347 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# Based on
+# https://github.com/NVIDIA/mxnet_to_onnx/blob/master/mx2onnx_converter/mx2onnx_converter.py#
+#  Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
+#
+#  Redistribution and use in source and binary forms, with or without
+#  modification, are permitted provided that the following conditions
+#  are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+#  EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+#  PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+#  CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+#  EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+#  PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+#  PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+#  OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+#  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+#  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# coding: utf-8
+# pylint: disable=invalid-name,too-many-locals,no-self-use,too-many-arguments,
+# pylint: disable=maybe-no-member,too-many-nested-blocks
+"""MXNet to ONNX graph converter functions"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+import logging
+import json
+import numpy as np
+
+from .... import context
+from .... import ndarray as nd
+from .... import io
+from .... import module as mod
+
+
+class MXNetGraph(object):
+    """Class to convert MXNet to ONNX graph"""
+    registry_ = {}
+    input_output_maps_ = {}
+
+    def __init__(self):
+        # topologically sorted nodes
+        self.nodes = []
+        self.input_tensors = []
+        self.output_tensors = []
+
+    @staticmethod
+    def register(op_name):
+        """Register operators"""
+        def wrapper(func):
+            """Helper function to map functions"""
+            MXNetGraph.registry_[op_name] = func
+            return func
+
+        return wrapper
+
+    @staticmethod
+    def convert_layer(node, **kwargs):
+        """Convert MXNet layer to ONNX"""
+        op = str(node["op"])
+        if op not in MXNetGraph.registry_:
+            raise AttributeError("No conversion function registered for op type %s yet." % op)
+        convert_func = MXNetGraph.registry_[op]
+        return convert_func(node, **kwargs)
+
+    @staticmethod
+    def forward_pass(inputs, sym, arg_params, aux_params, output_label):
+        """Do a forward pass based on the sym and params to get the shape
+        of the output using dummy data
+
+        Parameters
+        ----------
+        inputs : List of numpy arrays
+            Dummy input data, one array per graph input
+        sym : :class:`~mxnet.symbol.Symbol`
+            MXNet symbol object
+        arg_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
+            Dict of converted parameters stored in ``mxnet.ndarray.NDArray`` format
+        aux_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
+            Dict of converted parameters stored in ``mxnet.ndarray.NDArray`` format
+
+        Returns
+        -------
+        shape : Shape
+            Output shape
+        """
+        # If no label is provided, MXNet adds the label "softmax_label" by default while
+        # running load_checkpoint; it is not actually a graph input, so it is ignored here
+        data_names = [graph_input for graph_input in sym.list_inputs()
+                      if graph_input not in arg_params and graph_input not in aux_params
+                      and graph_input != output_label]
+
+        data_shapes = []
+        # Pair each data input name with the shape of its dummy input array.
+        for idx, input_name in enumerate(data_names):
+            data_shapes.append((input_name, inputs[idx].shape))
+
+        # create module, passing cpu context
+        ctx = context.cpu()
+        test_mod = mod.Module(symbol=sym, data_names=data_names, context=ctx, label_names=None)
+        test_mod.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)
+
+        # initializing parameters for calculating result of each individual node
+        if arg_params is None and aux_params is None:
+            test_mod.init_params()
+        else:
+            test_mod.set_params(arg_params=arg_params, aux_params=aux_params, allow_missing=True)
+
+        data_forward = []
+        for idx, input_name in enumerate(data_names):
+            val = inputs[idx]
+            data_forward.append(nd.array(val))
+
+        test_mod.forward(io.DataBatch(data_forward))
+        result = test_mod.get_outputs()[0].asnumpy()
+
+        return result.shape
+
+
+    @staticmethod
+    def split_params(sym, params):
+        """Helper function to split params dictionary into args and aux params
+
+        Parameters
+        ----------
+        sym : :class:`~mxnet.symbol.Symbol`
+            MXNet symbol object
+        params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
+            Dict of converted parameters stored in ``mxnet.ndarray.NDArray`` format
+
+        Returns
+        -------
+        arg_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
+            Dict of converted parameters stored in ``mxnet.ndarray.NDArray`` format
+        aux_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
+            Dict of converted parameters stored in ``mxnet.ndarray.NDArray`` format
+        """
+        arg_params = {}
+        aux_params = {}
+        for args in sym.list_arguments():
+            if args in params:
+                arg_params.update({args: nd.array(params[args])})
+        for aux in sym.list_auxiliary_states():
+            if aux in params:
+                aux_params.update({aux: nd.array(params[aux])})
+        return arg_params, aux_params
+
+
+    @staticmethod
+    def infer_output_shape(sym, params, in_shape, output_label):
+        """Infer output shape by doing a forward pass using dummy inputs """
+        # create dummy input
+        inputs = [np.random.randn(*input_shape) for input_shape in in_shape]
+        arg, aux = MXNetGraph.split_params(sym, params)
+        return MXNetGraph.forward_pass(inputs, sym, arg, aux, output_label)
+
+
+    @staticmethod
+    def convert_weights_to_numpy(weights_dict):
+        """Convert weights to numpy"""
+        return dict([(k.replace("arg:", "").replace("aux:", ""), v.asnumpy())
+                     for k, v in weights_dict.items()])
+
+    def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False):
+        """Convert MXNet graph to ONNX graph
+
+        Parameters
+        ----------
+        sym : :class:`~mxnet.symbol.Symbol`
+            MXNet symbol object
+        params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
+            Dict of converted parameters stored in ``mxnet.ndarray.NDArray`` format
+        in_shape : List of tuple
+            Input shape of the model e.g [(1,3,224,224)]
+        in_type : data type
+            Input data type e.g. np.float32
+        verbose : Boolean
+            If true will print logs of the model conversion
+
+        Returns
+        -------
+        graph : GraphProto
+            ONNX graph
+        """
+        try:
+            from onnx import (checker, helper, NodeProto, ValueInfoProto, TensorProto)
+            from onnx.helper import make_tensor_value_info
+        except ImportError:
+            raise ImportError("Onnx and protobuf need to be installed. "
+                              + "Instructions to install - https://github.com/onnx/onnx")
+
+        # When an MXNet model is saved to a json file, MXNet adds a node for the label.
+        # Its name is the name of the last node + "_label" (i.e. if the last node
+        # is named "Softmax", this node is named "Softmax_label"). The new node
+        # is always the second-to-last node in the json graph.
+        # Deriving the output_label name.
+        output_label = sym.get_internals()[len(sym.get_internals()) - 1].name + "_label"
+
+        # Determine output shape
+        output_shape = MXNetGraph.infer_output_shape(sym, params, in_shape, output_label)
+
+        weights = MXNetGraph.convert_weights_to_numpy(params)
+
+        mx_graph = json.loads(sym.tojson())["nodes"]
+
+        initializer = []
+        all_processed_nodes = []
+        onnx_processed_nodes = []
+        onnx_processed_inputs = []
+        onnx_processed_outputs = []
+        index_lookup = []
+
+        graph_input_idx = 0
+        for idx, node in enumerate(mx_graph):
+            op = node["op"]
+            name = node["name"]
+            if verbose:
+                logging.info("Converting idx: %d, op: %s, name: %s", idx, op, name)
+
+            # A node is an input node if its op_name is "null" and is not
+            # in params dict
+            if op == "null" and name not in params:
+                # Handling graph input
+
+                # Skipping output_label node, as this node is not part of graph
+                # Refer "output_label" assignment above for more details.
+                if name == output_label:
+                    continue
+                converted = MXNetGraph.convert_layer(
+                    node,
+                    is_input=True,
+                    mx_graph=mx_graph,
+                    weights=weights,
+                    in_shape=in_shape[graph_input_idx],
+                    in_type=in_type,
+                    proc_nodes=all_processed_nodes,
+                    initializer=initializer,
+                    index_lookup=index_lookup)
+                graph_input_idx += 1
+
+            else:
+                # Handling graph layers
+                converted = MXNetGraph.convert_layer(
+                    node,
+                    is_input=False,
+                    mx_graph=mx_graph,
+                    weights=weights,
+                    in_shape=in_shape,
+                    in_type=in_type,
+                    proc_nodes=all_processed_nodes,
+                    initializer=initializer,
+                    index_lookup=index_lookup,
+                    idx=idx
+                )
+
+            if isinstance(converted, list):
+                # Iterate for all converted nodes
+                for converted_node in converted:
+                    # If converted node is ValueInfoProto, add it in inputs
+                    if isinstance(converted_node, ValueInfoProto):
+                        onnx_processed_inputs.append(converted_node)
+                    # If converted node is NodeProto, add it in processed nodes list
+                    elif isinstance(converted_node, NodeProto):
+                        onnx_processed_nodes.append(converted_node)
+                        if idx == (len(mx_graph) - 1):
+                            # If the converted node doesn't have a name, use its output name
+                            if not converted_node.name:
+                                onnx_processed_outputs.append(
+                                    make_tensor_value_info(
+                                        name=converted_node.output[0],
+                                        elem_type=in_type,
+                                        shape=output_shape
+                                    )
+                                )
+                            else:
+                                onnx_processed_outputs.append(
+                                    make_tensor_value_info(
+                                        name=converted_node.name,
+                                        elem_type=in_type,
+                                        shape=output_shape
+                                    )
+                                )
+                            if verbose:
+                                logging.info("Output node is: %s", converted_node.name)
+                    elif isinstance(converted_node, TensorProto):
+                        raise ValueError("Did not expect TensorProto")
+                    else:
+                        raise ValueError("node is of an unrecognized type: %s" % type(converted_node))
+
+                    all_processed_nodes.append(converted_node)
+
+                if idx > 0:
+                    # Handle the extra label node added to the graph if the MXNet model was
+                    # saved to a json file
+                    # (refer to the "output_label" initialization above for more details):
+                    # if the extra node was added, prev_index for the last node is adjusted.
+                    if idx == (len(mx_graph) - 1) and \
+                            mx_graph[len(mx_graph)-2]["name"] == output_label:
+                        prev_index = index_lookup[idx - 2]
+                    else:
+                        prev_index = index_lookup[idx - 1]
+
+                    index_lookup.append(prev_index+len(converted))
+                else:
+                    index_lookup.append(len(converted) - 1)
+            else:
+                logging.info("Operator converter function should always return a list")
+
+        graph = helper.make_graph(
+            onnx_processed_nodes,
+            "mxnet_converted_model",
+            onnx_processed_inputs,
+            onnx_processed_outputs
+        )
+
+        graph.initializer.extend(initializer)
+
+        checker.check_graph(graph)
+        return graph
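In `create_onnx_graph_proto` above, one MXNet node can expand into several entries in `all_processed_nodes` (for example, `convert_reshape` returns a shape `ValueInfoProto` followed by the `Reshape` node). `index_lookup[i]` therefore records the position of the last entry emitted for MXNet node `i`, which is the entry that converters resolve through `proc_nodes[index_lookup[...]].name`. A minimal sketch of that running offset, assuming one entry per input or weight node and ignoring the adjustment for the skipped `*_label` node; `build_index_lookup` and the node names are hypothetical.

```python
def build_index_lookup(emitted_counts):
    """Hypothetical helper mirroring index_lookup in create_onnx_graph_proto.

    emitted_counts[i] is how many entries the converter for MXNet node i
    appended to all_processed_nodes. The result maps each MXNet node
    index to the position of its *last* emitted entry.
    """
    index_lookup = []
    for idx, count in enumerate(emitted_counts):
        if idx == 0:
            index_lookup.append(count - 1)
        else:
            # Offset of the previous node's last entry plus this node's entries
            index_lookup.append(index_lookup[idx - 1] + count)
    return index_lookup


# Hypothetical graph: data input, weight, Reshape (shape tensor + node), relu
processed = ["data", "weight", "reshape_attr_tensor2", "reshape0", "relu0"]
lookup = build_index_lookup([1, 1, 2, 1])
```

With these counts `lookup` is `[0, 1, 3, 4]`, so MXNet node 2 resolves to `reshape0` rather than to its shape tensor, and a consumer of node 2 wires its input to the Reshape output.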
diff --git a/python/mxnet/contrib/onnx/_import/__init__.py b/python/mxnet/contrib/onnx/onnx2mx/__init__.py
similarity index 100%
rename from python/mxnet/contrib/onnx/_import/__init__.py
rename to python/mxnet/contrib/onnx/onnx2mx/__init__.py
diff --git a/python/mxnet/contrib/onnx/_import/import_helper.py b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
similarity index 74%
rename from python/mxnet/contrib/onnx/_import/import_helper.py
rename to python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
index 3dfff3e..c19f0f2 100644
--- a/python/mxnet/contrib/onnx/_import/import_helper.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
@@ -15,27 +15,27 @@
 # specific language governing permissions and limitations
 # under the License.
 
 # coding: utf-8
 # pylint: disable=invalid-name
 """Operator attributes conversion"""
-from .op_translations import identity, random_uniform, random_normal
-from .op_translations import add, subtract, multiply, divide, absolute, negative, add_n
-from .op_translations import tanh
-from .op_translations import ceil, floor
-from .op_translations import concat
-from .op_translations import leaky_relu, _elu, _prelu, softmax, fully_connected
-from .op_translations import global_avgpooling, global_maxpooling, linalg_gemm
-from .op_translations import sigmoid, pad, relu, matrix_multiplication, batch_norm
-from .op_translations import dropout, local_response_norm, conv, deconv
-from .op_translations import reshape, cast, split, _slice, transpose, squeeze, flatten
-from .op_translations import reciprocal, squareroot, power, exponent, _log, unsqueeze
-from .op_translations import reduce_max, reduce_mean, reduce_min, reduce_sum
-from .op_translations import reduce_prod, avg_pooling, max_pooling
-from .op_translations import argmax, argmin, maximum, minimum
-from .op_translations import clip, reduce_log_sum, reduce_log_sum_exp
-from .op_translations import reduce_sum_square, reduce_l2, max_roi_pooling, instance_norm
-from .op_translations import log_softmax, softsign, lesser, greater, equal
-from .op_translations import logical_and, logical_or, logical_xor, logical_not
+from ._op_translations import identity, random_uniform, random_normal
+from ._op_translations import add, subtract, multiply, divide, absolute, negative, add_n
+from ._op_translations import tanh
+from ._op_translations import ceil, floor
+from ._op_translations import concat
+from ._op_translations import leaky_relu, _elu, _prelu, softmax, fully_connected
+from ._op_translations import global_avgpooling, global_maxpooling, linalg_gemm
+from ._op_translations import sigmoid, pad, relu, matrix_multiplication, batch_norm
+from ._op_translations import dropout, local_response_norm, conv, deconv
+from ._op_translations import reshape, cast, split, _slice, transpose, squeeze, flatten
+from ._op_translations import reciprocal, squareroot, power, exponent, _log, unsqueeze
+from ._op_translations import reduce_max, reduce_mean, reduce_min, reduce_sum
+from ._op_translations import reduce_prod, avg_pooling, max_pooling
+from ._op_translations import argmax, argmin, maximum, minimum
+from ._op_translations import clip, reduce_log_sum, reduce_log_sum_exp
+from ._op_translations import reduce_sum_square, reduce_l2, max_roi_pooling, 
instance_norm
+from ._op_translations import log_softmax, softsign, lesser, greater, equal
+from ._op_translations import logical_and, logical_or, logical_xor, logical_not
 
 # convert_map defines maps of ONNX operator names to converter functor(callable)
 # defined in the op_translations module.
@@ -89,6 +89,7 @@ _convert_map = {
     'Squeeze'           : squeeze,
     'Unsqueeze'         : unsqueeze,
     'Flatten'           : flatten,
+    'Identity'          : identity,
     #Powers
     'Reciprocal'        : reciprocal,
     'Sqrt'              : squareroot,
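The `_convert_map` dictionary above is a dispatch table from ONNX operator names to converter callables that take `(attrs, inputs, proto_obj)` and return an `(op, attrs, inputs)` triple, as the `maximum`/`minimum` converters further down do. The following is a minimal pure-Python sketch of that lookup pattern. The converters here are hypothetical stand-ins that return strings instead of MXNet symbols, and `convert_node` is a hypothetical helper, not the importer's actual code.

```python
# Hypothetical stand-in converters; the real ones in _op_translations build
# mxnet.symbol objects, whereas these use strings so the sketch runs anywhere.
def identity(attrs, inputs, proto_obj):
    """Pass the single input through unchanged."""
    return inputs[0], attrs, inputs


def tanh(attrs, inputs, proto_obj):
    """Wrap the input name in a tanh(...) string."""
    return "tanh(%s)" % inputs[0], attrs, inputs


# Dispatch table from ONNX op type to converter callable.
_convert_map = {
    'Identity': identity,
    'Tanh': tanh,
}


def convert_node(op_type, attrs, inputs, proto_obj=None):
    """Look up the converter registered for op_type and apply it."""
    if op_type not in _convert_map:
        raise NotImplementedError("Operator %s not implemented." % op_type)
    converted_op, _, _ = _convert_map[op_type](attrs, inputs, proto_obj)
    return converted_op


print(convert_node('Identity', {}, ['x']))  # x
print(convert_node('Tanh', {}, ['x']))      # tanh(x)
```

Supporting a new ONNX operator in this pattern comes down to writing one converter and adding one dictionary entry, which matches how this commit wires up `'Identity'`.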
diff --git a/python/mxnet/contrib/onnx/_import/op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
similarity index 99%
rename from python/mxnet/contrib/onnx/_import/op_translations.py
rename to python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
index 0fad008..2b98aa0 100644
--- a/python/mxnet/contrib/onnx/_import/op_translations.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
@@ -19,7 +19,7 @@
 """ Module for translating ONNX operators into Mxnet operatoes"""
 # pylint: disable=unused-argument,protected-access
 import numpy as np
-from . import translation_utils
+from . import _translation_utils as translation_utils
 from .... import symbol
 
 # Method definitions for the callable objects mapped in the import_helper module
@@ -130,7 +130,7 @@ def maximum(attrs, inputs, proto_obj):
         for op_input in inputs[2:]:
             mxnet_op = symbol.maximum(mxnet_op, op_input)
     else:
-        mxnet_op = inputs[0]
+        mxnet_op = symbol.maximum(inputs[0], inputs[0])
     return mxnet_op, attrs, inputs
 
 def minimum(attrs, inputs, proto_obj):
@@ -143,7 +143,7 @@ def minimum(attrs, inputs, proto_obj):
         for op_input in inputs[2:]:
             mxnet_op = symbol.minimum(mxnet_op, op_input)
     else:
-        mxnet_op = inputs[0]
+        mxnet_op = symbol.minimum(inputs[0], inputs[0])
     return mxnet_op, attrs, inputs
 
 def lesser(attrs, inputs, proto_obj):
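ONNX `Max` and `Min` accept any number of inputs, but `symbol.maximum` and `symbol.minimum` are binary, so the converters fold the inputs pairwise. With this commit, the single-input case emits `maximum(x, x)` (or `minimum(x, x)`) instead of returning the input itself. Presumably this is so that an operator node is always produced; the values are unchanged either way. The NumPy sketch below shows the same folding logic, with `np.maximum`/`np.minimum` standing in for the MXNet symbols. `fold_max`, `fold_min` and `_fold` are hypothetical names used only for illustration.

```python
import numpy as np


def _fold(binary_op, inputs):
    """Apply a binary element-wise op across one or more inputs, pairwise."""
    if len(inputs) > 1:
        result = binary_op(inputs[0], inputs[1])
        for op_input in inputs[2:]:
            result = binary_op(result, op_input)
    else:
        # Single input: op(x, x) returns a new array with the same values as x.
        result = binary_op(inputs[0], inputs[0])
    return result


def fold_max(inputs):
    """Element-wise maximum over all inputs (mirrors ONNX Max)."""
    return _fold(np.maximum, inputs)


def fold_min(inputs):
    """Element-wise minimum over all inputs (mirrors ONNX Min)."""
    return _fold(np.minimum, inputs)


a = np.array([1, 5, 3])
b = np.array([4, 2, 6])
c = np.array([0, 7, 1])
print(fold_max([a, b, c]))  # [4 7 6]
print(fold_min([a, b, c]))  # [0 2 1]
```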
diff --git a/python/mxnet/contrib/onnx/_import/translation_utils.py b/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py
similarity index 99%
rename from python/mxnet/contrib/onnx/_import/translation_utils.py
rename to python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py
index fe25a94..f63c1e9 100644
--- a/python/mxnet/contrib/onnx/_import/translation_utils.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py
@@ -168,6 +168,7 @@ def _fix_broadcast(op_name, inputs, broadcast_axis, proto_obj):
         op_sym = op_name
     return op_sym
 
+
 def _fix_channels(op_name, attrs, inputs, proto_obj):
     """A workaround for getting 'channels' or 'units' since onnx don't provide
     these attributes. We check the shape of weights provided to get the number.
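The `_fix_channels` docstring describes how the output count (`num_filter` or `num_hidden`) is recovered from the weight tensor's shape, because ONNX does not store it as an attribute. The sketch below illustrates the idea. It assumes an output-channels-first weight layout, as in ONNX `Conv` (`M x C/group x kH x kW`) and MXNet `FullyConnected` (`num_hidden x input_dim`). `infer_num_outputs` is a hypothetical helper, not the function in `_translation_utils`.

```python
import numpy as np


def infer_num_outputs(weight_shape):
    """Return the number of output channels/units implied by a weight shape.

    Assumes output-channels-first layout: (out, in, kH, kW) for a 2-D
    convolution or (out, in) for a fully connected layer.
    """
    if len(weight_shape) < 2:
        raise ValueError("expected a weight tensor with at least 2 dimensions")
    return int(weight_shape[0])


conv_weight = np.zeros((64, 3, 7, 7), dtype=np.float32)  # 64 filters over RGB input
fc_weight = np.zeros((1000, 4096), dtype=np.float32)     # 1000 output units
print(infer_num_outputs(conv_weight.shape))  # 64
print(infer_num_outputs(fc_weight.shape))    # 1000
```

Layouts that put the output dimension elsewhere, such as `ConvTranspose` weights (`C x M/group x kH x kW` in ONNX), would need a different index.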
diff --git a/python/mxnet/contrib/onnx/_import/import_model.py b/python/mxnet/contrib/onnx/onnx2mx/import_model.py
similarity index 100%
rename from python/mxnet/contrib/onnx/_import/import_model.py
rename to python/mxnet/contrib/onnx/onnx2mx/import_model.py
diff --git a/python/mxnet/contrib/onnx/_import/import_onnx.py b/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py
similarity index 99%
rename from python/mxnet/contrib/onnx/_import/import_onnx.py
rename to python/mxnet/contrib/onnx/onnx2mx/import_onnx.py
index d81ec96..4e85171 100644
--- a/python/mxnet/contrib/onnx/_import/import_onnx.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py
@@ -23,7 +23,7 @@ from .... import symbol
 from .... import cpu, gpu
 from .... import ndarray as nd
 from ....base import string_types
-from .import_helper import _convert_map as convert_map
+from ._import_helper import _convert_map as convert_map
 
 class GraphProto(object): # pylint: disable=too-few-public-methods
     """A helper class for handling mxnet symbol copying from pb2.GraphProto.
diff --git a/python/mxnet/contrib/onnx/_import/import_to_gluon.py b/python/mxnet/contrib/onnx/onnx2mx/import_to_gluon.py
similarity index 100%
rename from python/mxnet/contrib/onnx/_import/import_to_gluon.py
rename to python/mxnet/contrib/onnx/onnx2mx/import_to_gluon.py
diff --git a/tests/python-pytest/onnx/import/mxnet_backend.py b/tests/python-pytest/onnx/export/backend.py
similarity index 55%
copy from tests/python-pytest/onnx/import/mxnet_backend.py
copy to tests/python-pytest/onnx/export/backend.py
index bbe8899..e23cc01 100644
--- a/tests/python-pytest/onnx/import/mxnet_backend.py
+++ b/tests/python-pytest/onnx/export/backend.py
@@ -16,25 +16,47 @@
 # under the License.
 
 # coding: utf-8
-"""MXNet backend wrapper for onnx test infrastructure"""
-import mxnet as mx
-from mxnet.contrib.onnx._import.import_onnx import GraphProto
+"""backend wrapper for onnx test infrastructure"""
+import numpy as np
+from mxnet.contrib.onnx.onnx2mx.import_onnx import GraphProto
+from mxnet.contrib.onnx.mx2onnx.export_onnx import MXNetGraph
 try:
-    from onnx import helper, TensorProto
+    from onnx import helper, TensorProto, mapping
     from onnx.backend.base import Backend
 except ImportError:
-    raise ImportError("Onnx and protobuf need to be installed. Instructions to"
-                      + " install - https://github.com/onnx/onnx#installation")
-from mxnet_backend_rep import MXNetBackendRep
+    raise ImportError("Onnx and protobuf need to be installed")
+from backend_rep import MXNetBackendRep
 
+# Using these functions for onnx test infrastructure.
+# Implemented by following onnx docs guide:
+# https://github.com/onnx/onnx/blob/master/docs/Implementing%20an%20ONNX%20backend.md
 # MXNetBackend class will take an ONNX model with inputs, perform a computation,
 # and then return the output.
-# Implemented by following onnx docs guide:
-# https://github.com/onnx/onnx/blob/master/docs/ImplementingAnOnnxBackend.md
 
 class MXNetBackend(Backend):
     """MXNet backend for ONNX"""
 
+    @staticmethod
+    def perform_import_export(graph_proto, input_shape):
+        """ Import ONNX model to mxnet model and then export to ONNX model
+            and then import it back to mxnet for verifying the result"""
+        graph = GraphProto()
+
+        sym, arg_params, aux_params = graph.from_onnx(graph_proto)
+
+        params = {}
+        params.update(arg_params)
+        params.update(aux_params)
+        # exporting to onnx graph proto format
+        converter = MXNetGraph()
+        graph_proto = converter.create_onnx_graph_proto(sym, params, in_shape=input_shape, in_type=mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('float32')])
+
+        # importing back to MXNET for verifying result.
+        sym, arg_params, aux_params = graph.from_onnx(graph_proto)
+
+        return sym, arg_params, aux_params
+
+
     @classmethod
     def prepare(cls, model, device='CPU', **kwargs):
         """For running end to end model(used for onnx test backend)
@@ -54,8 +76,12 @@ class MXNetBackend(Backend):
             Returns object of MXNetBackendRep class which will be in turn
             used to run inference on the input model and return the result for comparison.
         """
+
         graph = GraphProto()
-        sym, arg_params, aux_params = graph.from_onnx(model.graph)
+        metadata = graph.get_graph_metadata(model.graph)
+        input_data = metadata['input_tensor_data']
+        input_shape = [data[1] for data in input_data]
+        sym, arg_params, aux_params = MXNetBackend.perform_import_export(model.graph, input_shape)
         return MXNetBackendRep(sym, arg_params, aux_params, device)
 
     @classmethod
@@ -63,6 +89,9 @@ class MXNetBackend(Backend):
         """Supports only CPU for testing"""
         return device == 'CPU'
 
+
 prepare = MXNetBackend.prepare
 
+run_node = MXNetBackend.run_node
+
 supports_device = MXNetBackend.supports_device
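`prepare` now calls `get_graph_metadata(model.graph)` and passes the shape from each `input_tensor_data` entry to `perform_import_export`. The export tests index the same entries, using `[0]` as the input name and `[1]` as its shape. The pure-Python sketch below performs that extraction on a hand-written metadata dict. The dict is illustrative only: it mirrors the `(name, shape)` pairs the tests index into and is not output captured from MXNet. `input_names_and_shapes` is a hypothetical helper.

```python
def input_names_and_shapes(metadata):
    """Split metadata['input_tensor_data'] into parallel name and shape lists."""
    input_data = metadata['input_tensor_data']
    names = [data[0] for data in input_data]
    shapes = [data[1] for data in input_data]
    return names, shapes


# Illustrative metadata for a single-input image classifier (values are made up).
example_metadata = {
    'input_tensor_data': [('data_0', (1, 3, 224, 224))],
    'output_tensor_data': [('prob_1', (1, 1000))],
}

names, shapes = input_names_and_shapes(example_metadata)
print(names)   # ['data_0']
print(shapes)  # [(1, 3, 224, 224)]
```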
diff --git a/tests/python-pytest/onnx/import/mxnet_backend_rep.py b/tests/python-pytest/onnx/export/backend_rep.py
similarity index 91%
copy from tests/python-pytest/onnx/import/mxnet_backend_rep.py
copy to tests/python-pytest/onnx/export/backend_rep.py
index 5ce29f5..8729eaf 100644
--- a/tests/python-pytest/onnx/import/mxnet_backend_rep.py
+++ b/tests/python-pytest/onnx/export/backend_rep.py
@@ -16,18 +16,16 @@
 # under the License.
 
 # coding: utf-8
-"""MXNet backend rep for onnx test infrastructure"""
-import numpy as np
+"""backend rep for onnx test infrastructure"""
 try:
     from onnx.backend.base import BackendRep
 except ImportError:
-    raise ImportError("Onnx and protobuf need to be installed. Instructions to"
-                      + " install - https://github.com/onnx/onnx#installation")
+    raise ImportError("Onnx and protobuf need to be installed")
 import mxnet as mx
 
 # Using these functions for onnx test infrastructure.
 # Implemented by following onnx docs guide:
-# https://github.com/onnx/onnx/blob/master/docs/ImplementingAnOnnxBackend.md
+# https://github.com/onnx/onnx/blob/master/docs/Implementing%20an%20ONNX%20backend.md
 # MXNetBackendRep object will be returned by MXNetBackend's prepare method which is used to
 # execute a model repeatedly.
 # Inputs will be passed to the run method of MXNetBackendRep class, it will perform computation and
diff --git a/tests/python-pytest/onnx/export/mxnet_export_test.py b/tests/python-pytest/onnx/export/mxnet_export_test.py
new file mode 100644
index 0000000..7e1df07
--- /dev/null
+++ b/tests/python-pytest/onnx/export/mxnet_export_test.py
@@ -0,0 +1,191 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Tests for individual operators
+This module contains operator tests which currently do not exist on
+ONNX backend test framework. Once we have PRs on the ONNX repo and get
+those PRs merged, this file will get EOL'ed.
+"""
+# pylint: disable=too-many-locals,wrong-import-position,import-error
+from __future__ import absolute_import
+import sys
+import os
+import logging
+import tarfile
+from collections import namedtuple
+import numpy as np
+import numpy.testing as npt
+from onnx import numpy_helper
+from onnx import TensorProto
+from mxnet.test_utils import download
+from mxnet.contrib import onnx as onnx_mxnet
+import mxnet as mx
+CURR_PATH = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
+sys.path.insert(0, os.path.join(CURR_PATH, '../../python/unittest'))
+logger = logging.getLogger()
+logger.setLevel(logging.DEBUG)
+URLS = {
+    'bvlc_googlenet':
+        'https://s3.amazonaws.com/onnx-mxnet/model-zoo/bvlc_googlenet.tar.gz',
+    'bvlc_reference_caffenet':
+        'https://s3.amazonaws.com/onnx-mxnet/model-zoo/bvlc_reference_caffenet.tar.gz',
+    'bvlc_reference_rcnn_ilsvrc13':
+        'https://s3.amazonaws.com/onnx-mxnet/model-zoo/bvlc_reference_rcnn_ilsvrc13.tar.gz',
+    'inception_v1':
+        'https://s3.amazonaws.com/onnx-mxnet/model-zoo/inception_v1.tar.gz',
+    'inception_v2':
+        'https://s3.amazonaws.com/onnx-mxnet/model-zoo/inception_v2.tar.gz'
+}
+
+def get_test_files(name):
+    """Extract tar file and returns model path and input, output data"""
+    tar_name = download(URLS.get(name), dirname=CURR_PATH.__str__())
+    # extract tar file
+    tar_path = os.path.join(CURR_PATH, tar_name)
+    tar = tarfile.open(tar_path.__str__(), "r:*")
+    tar.extractall(path=CURR_PATH.__str__())
+    tar.close()
+    data_dir = os.path.join(CURR_PATH, name)
+    model_path = os.path.join(data_dir, 'model.onnx')
+
+    inputs = []
+    outputs = []
+    # get test files
+    for test_file in os.listdir(data_dir):
+        case_dir = os.path.join(data_dir, test_file)
+        # skip the non-dir files
+        if not os.path.isdir(case_dir):
+            continue
+        input_file = os.path.join(case_dir, 'input_0.pb')
+        input_tensor = TensorProto()
+        with open(input_file, 'rb') as proto_file:
+            input_tensor.ParseFromString(proto_file.read())
+        inputs.append(numpy_helper.to_array(input_tensor))
+
+        output_tensor = TensorProto()
+        output_file = os.path.join(case_dir, 'output_0.pb')
+        with open(output_file, 'rb') as proto_file:
+            output_tensor.ParseFromString(proto_file.read())
+        outputs.append(numpy_helper.to_array(output_tensor))
+
+    return model_path, inputs, outputs
+
+
+def forward_pass(sym, arg, aux, data_names, input_data):
+    """ Perform forward pass on given data"""
+    # create module
+    mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
+    mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None)
+    mod.set_params(arg_params=arg, aux_params=aux,
+                   allow_missing=True, allow_extra=True)
+    # run inference
+    batch = namedtuple('Batch', ['data'])
+    mod.forward(batch([mx.nd.array(input_data)]), is_train=False)
+
+    return mod.get_outputs()[0].asnumpy()
+
+
+def test_models(model_name, input_shape, output_shape):
+    """ Tests an ONNX model-zoo model through ONNX import, export and re-import"""
+    model_path, inputs, outputs = get_test_files(model_name)
+    logging.info("Translating model from ONNX model zoo to Mxnet")
+    sym, arg_params, aux_params = onnx_mxnet.import_model(model_path)
+    params = {}
+    params.update(arg_params)
+    params.update(aux_params)
+
+    dir_path = os.path.dirname(model_path)
+    new_model_name = "exported_" + model_name + ".onnx"
+    onnx_file = os.path.join(dir_path, new_model_name)
+
+    logging.info("Translating converted model from mxnet to ONNX")
+    converted_model_path = onnx_mxnet.export_model(sym, params, [input_shape], np.float32, onnx_file)
+
+    sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model_path)
+
+    metadata = onnx_mxnet.get_model_metadata(converted_model_path)
+    assert len(metadata) == 2
+    assert metadata.get('input_tensor_data')
+    assert metadata.get('input_tensor_data')[0][1] == input_shape
+    assert metadata.get('output_tensor_data')
+    assert metadata.get('output_tensor_data')[0][1] == output_shape
+    data_names = [input_name[0] for input_name in metadata.get('input_tensor_data')]
+
+    logging.info("Running inference on onnx re-import model in mxnet")
+    # run test for each test file
+    for input_data, output_data in zip(inputs, outputs):
+        result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
+
+        # verify the results
+        npt.assert_equal(result.shape, output_data.shape)
+        npt.assert_almost_equal(output_data, result, decimal=3)
+    logging.info(model_name + " conversion successful")
+
+
+def test_model_accuracy(model_name, input_shape):
+    """ Imports an ONNX model and runs inference, then exports it, imports it back,
+        runs inference again and compares with the first inference result"""
+    model_path, inputs, outputs = get_test_files(model_name)
+    logging.info("Translating model from ONNX model zoo to Mxnet")
+    sym, arg_params, aux_params = onnx_mxnet.import_model(model_path)
+
+    metadata = onnx_mxnet.get_model_metadata(model_path)
+    data_names = [input_name[0] for input_name in metadata.get('input_tensor_data')]
+
+    expected_result = []
+    for input_data, output_data in zip(inputs, outputs):
+        result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
+        expected_result.append(result)
+
+    params = {}
+    params.update(arg_params)
+    params.update(aux_params)
+
+    dir_path = os.path.dirname(model_path)
+    new_model_name = "exported_" + model_name + ".onnx"
+    onnx_file = os.path.join(dir_path, new_model_name)
+
+    logging.info("Translating converted model from mxnet to ONNX")
+    converted_model_path = onnx_mxnet.export_model(sym, params, [input_shape], np.float32,
+                                                   onnx_file)
+
+    sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model_path)
+
+    metadata = onnx_mxnet.get_model_metadata(converted_model_path)
+    data_names = [input_name[0] for input_name in metadata.get('input_tensor_data')]
+
+    actual_result = []
+    for input_data, output_data in zip(inputs, outputs):
+        result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
+        actual_result.append(result)
+
+    # verify the results
+    for expected, actual in zip(expected_result, actual_result):
+        npt.assert_equal(expected.shape, actual.shape)
+        npt.assert_almost_equal(expected, actual, decimal=3)
+
+
+if __name__ == '__main__':
+    test_models("bvlc_googlenet", (1, 3, 224, 224), (1, 1000))
+    test_models("bvlc_reference_caffenet", (1, 3, 224, 224), (1, 1000))
+    test_models("bvlc_reference_rcnn_ilsvrc13", (1, 3, 224, 224), (1, 200))
+
+    # Compare against MXNet's own inference result, because MXNet results don't match
+    # the ONNX expected results due to an AveragePool issue (GitHub issue #10194)
+    test_model_accuracy("inception_v1", (1, 3, 224, 224))
+    test_model_accuracy("inception_v2", (1, 3, 224, 224))
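Both export tests compare outputs with `npt.assert_almost_equal(..., decimal=3)`. According to the NumPy documentation, this accepts element-wise differences satisfying `abs(desired - actual) < 1.5 * 10**(-decimal)`, which is 1.5e-3 here. The self-contained check below illustrates that threshold. The `passes` wrapper is a hypothetical helper that only turns the assertion into a boolean.

```python
import numpy as np
import numpy.testing as npt


def passes(actual, desired, decimal=3):
    """Return True if assert_almost_equal accepts actual vs desired."""
    try:
        npt.assert_almost_equal(desired, actual, decimal=decimal)
        return True
    except AssertionError:
        return False


expected = np.array([0.1234, 0.5000], dtype=np.float32)
close = expected + np.float32(0.001)  # differences of about 1e-3: below 1.5e-3
far = expected + np.float32(0.002)    # differences of about 2e-3: above 1.5e-3

print(passes(close, expected))  # True
print(passes(far, expected))    # False
```

Because the tolerance is absolute, `decimal=3` is fairly loose for small-magnitude outputs such as softmax probabilities and fairly strict for large logits.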
diff --git a/tests/python-pytest/onnx/import/test_cases.py b/tests/python-pytest/onnx/export/onnx_backend_test.py
similarity index 64%
copy from tests/python-pytest/onnx/import/test_cases.py
copy to tests/python-pytest/onnx/export/onnx_backend_test.py
index 1a4d8c4..803d290 100644
--- a/tests/python-pytest/onnx/import/test_cases.py
+++ b/tests/python-pytest/onnx/export/onnx_backend_test.py
@@ -15,7 +15,24 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""Test Cases to be run for the import module"""
+"""ONNX test backend wrapper"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import unittest
+try:
+    import onnx.backend.test
+except ImportError:
+    raise ImportError("Onnx and protobuf need to be installed")
+
+import backend as mxnet_backend
+
+# This is a pytest magic variable to load extra plugins
+pytest_plugins = "onnx.backend.test.report",
+
+BACKEND_TESTS = onnx.backend.test.BackendTest(mxnet_backend, __name__)
 
 IMPLEMENTED_OPERATORS_TEST = [
     'test_random_uniform',
@@ -31,6 +48,7 @@ IMPLEMENTED_OPERATORS_TEST = [
     'test_ceil',
     'test_floor',
     'test_concat',
+    'test_identity',
     'test_sigmoid',
     'test_relu',
     'test_constant_pad',
@@ -41,7 +59,6 @@ IMPLEMENTED_OPERATORS_TEST = [
     'test_reduce_mean',
     'test_reduce_prod',
     'test_squeeze',
-    'test_unsqueeze',
     'test_softmax_example',
     'test_softmax_large_number',
     'test_softmax_axis_2',
@@ -58,16 +75,7 @@ IMPLEMENTED_OPERATORS_TEST = [
     'test_argmax',
     'test_argmin',
     'test_min',
-    'test_logical_and',
-    'test_logical_xor',
-    'test_logical_not',
-    'test_logical_or',
-    'test_clip',
-    'test_softsign',
-    'test_reduce_l2',
-    'test_reduce_log_sum',
-    'test_reduce_log_sum_exp',
-    'test_reduce_sum_square'
+    'test_max',
     #pytorch operator tests
     'test_operator_exp',
     'test_operator_maxpool',
@@ -78,7 +86,7 @@ IMPLEMENTED_OPERATORS_TEST = [
 BASIC_MODEL_TESTS = [
     'test_AvgPool2D',
     'test_BatchNorm',
-    'test_ConstantPad2d'
+    'test_ConstantPad2d',
     'test_Conv2d',
     'test_ELU',
     'test_LeakyReLU',
@@ -95,11 +103,30 @@ BASIC_MODEL_TESTS = [
 STANDARD_MODEL = [
     'test_bvlc_alexnet',
     'test_densenet121',
-    #'test_inception_v1',
-    #'test_inception_v2',
+    # 'test_inception_v1',
+    # 'test_inception_v2',
     'test_resnet50',
-    #'test_shufflenet',
+    # 'test_shufflenet',
     'test_squeezenet',
-    'test_zfnet512',
+    'test_vgg16',
     'test_vgg19'
     ]
+
+for op_test in IMPLEMENTED_OPERATORS_TEST:
+    BACKEND_TESTS.include(op_test)
+
+for basic_model_test in BASIC_MODEL_TESTS:
+    BACKEND_TESTS.include(basic_model_test)
+
+for std_model_test in STANDARD_MODEL:
+    BACKEND_TESTS.include(std_model_test)
+
+BACKEND_TESTS.exclude('.*broadcast.*')
+BACKEND_TESTS.exclude('.*bcast.*')
+
+
+# import all test cases at global scope to make them visible to python.unittest
+globals().update(BACKEND_TESTS.enable_report().test_cases)
+
+if __name__ == '__main__':
+    unittest.main()
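Python concatenates adjacent string literals at compile time, even when a comment sits between them. A missing comma in a list such as `IMPLEMENTED_OPERATORS_TEST` therefore silently merges two test names into one entry, which then matches neither test. The `'test_ConstantPad2d'` change in this commit fixes one such case. The pure-Python sketch below reproduces the effect and flags merged entries. `find_merged_names` is a hypothetical lint helper, not part of the ONNX test runner.

```python
# A missing comma after 'test_max' merges it with the next literal,
# even though a comment line sits in between.
tests_missing_comma = [
    'test_min',
    'test_max'
    # pytorch operator tests
    'test_operator_exp',
]

tests_fixed = [
    'test_min',
    'test_max',
    # pytorch operator tests
    'test_operator_exp',
]


def find_merged_names(names, prefix='test_'):
    """Return entries that contain the test-name prefix more than once."""
    return [name for name in names if name.count(prefix) > 1]


print(len(tests_missing_comma))                # 2
print(find_merged_names(tests_missing_comma))  # ['test_maxtest_operator_exp']
print(find_merged_names(tests_fixed))          # []
```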
diff --git a/tests/python-pytest/onnx/import/gluon_backend.py b/tests/python-pytest/onnx/import/gluon_backend.py
index d2946f7..302fd4d 100644
--- a/tests/python-pytest/onnx/import/gluon_backend.py
+++ b/tests/python-pytest/onnx/import/gluon_backend.py
@@ -17,10 +17,8 @@
 
 # coding: utf-8
 """Gluon backend wrapper for onnx test infrastructure"""
-import mxnet as mx
-from mxnet import nd
-from mxnet.contrib.onnx._import.import_onnx import GraphProto
-import numpy as np
+from mxnet.contrib.onnx.onnx2mx.import_onnx import GraphProto
+
 try:
     from onnx import helper, TensorProto
     from onnx.backend.base import Backend
diff --git a/tests/python-pytest/onnx/import/mxnet_backend.py b/tests/python-pytest/onnx/import/mxnet_backend.py
index bbe8899..10f89ec 100644
--- a/tests/python-pytest/onnx/import/mxnet_backend.py
+++ b/tests/python-pytest/onnx/import/mxnet_backend.py
@@ -17,8 +17,7 @@
 
 # coding: utf-8
 """MXNet backend wrapper for onnx test infrastructure"""
-import mxnet as mx
-from mxnet.contrib.onnx._import.import_onnx import GraphProto
+from mxnet.contrib.onnx.onnx2mx.import_onnx import GraphProto
 try:
     from onnx import helper, TensorProto
     from onnx.backend.base import Backend
diff --git a/tests/python-pytest/onnx/import/mxnet_backend_rep.py b/tests/python-pytest/onnx/import/mxnet_backend_rep.py
index 5ce29f5..067ef15 100644
--- a/tests/python-pytest/onnx/import/mxnet_backend_rep.py
+++ b/tests/python-pytest/onnx/import/mxnet_backend_rep.py
@@ -17,7 +17,6 @@
 
 # coding: utf-8
 """MXNet backend rep for onnx test infrastructure"""
-import numpy as np
 try:
     from onnx.backend.base import BackendRep
 except ImportError:
diff --git a/tests/python-pytest/onnx/import/test_cases.py b/tests/python-pytest/onnx/import/test_cases.py
index 1a4d8c4..f7addbb 100644
--- a/tests/python-pytest/onnx/import/test_cases.py
+++ b/tests/python-pytest/onnx/import/test_cases.py
@@ -31,6 +31,7 @@ IMPLEMENTED_OPERATORS_TEST = [
     'test_ceil',
     'test_floor',
     'test_concat',
+    'test_identity',
     'test_sigmoid',
     'test_relu',
     'test_constant_pad',
