zhreshold closed pull request #12878: ONNX export: Cleanup
URL: https://github.com/apache/incubator-mxnet/pull/12878
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as GitHub would not otherwise display it):

diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 7cf856c767f..11e75d9a600 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -60,21 +60,16 @@
 import logging
 import numpy as np
 from .export_onnx import MXNetGraph as mx_op
-
-def import_onnx_modules():
-    """ To make sure ONNX is runtime dependency, it is imported used only when needed"""
-    try:
-        import onnx
-    except ImportError:
-        raise ImportError("Onnx and protobuf need to be installed. "
-                          + "Instructions to install - https://github.com/onnx/onnx")
-    return onnx
+try:
+    import onnx
+except ImportError:
+    onnx = None
 
 
 def parse_helper(attrs, attrs_name, alt_value=None):
     """Helper function to parse operator attributes in required format."""
     tuple_re = re.compile('\([0-9L|,| ]+\)')
-    if attrs is None:
+    if not attrs:
         return alt_value
     attrs_str = None if attrs.get(attrs_name) is None else str(attrs.get(attrs_name))
     if attrs_str is None:
@@ -135,12 +130,39 @@ def get_boolean_attribute_value(attrs, attr_name):
     """
     return 1 if attrs.get(attr_name, 0) in ["True", "1"] else 0
 
+def get_inputs(node, kwargs):
+    """Helper function to get inputs"""
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    index_lookup = kwargs["index_lookup"]
+    inputs = node["inputs"]
+    attrs = node.get("attrs", {})
+
+    input_nodes = []
+    for ip in inputs:
+        input_node_id = index_lookup[ip[0]]
+        input_nodes.append(proc_nodes[input_node_id].name)
+
+    return name, input_nodes, attrs
+
+def create_basic_op_node(op_name, node, kwargs):
+    """Helper function to create a basic operator
+    node that doesn't contain op specific attrs"""
+    name, input_nodes, _ = get_inputs(node, kwargs)
+
+    node = onnx.helper.make_node(
+        op_name,
+        input_nodes,
+        [name],
+        name=name
+    )
+    return [node]
+
 @mx_op.register("null")
 def convert_weights_and_inputs(node, **kwargs):
     """Helper function to convert weights and inputs.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
+    name, _, _ = get_inputs(node, kwargs)
 
     if kwargs["is_input"] is False:
         weights = kwargs["weights"]
@@ -172,20 +194,7 @@ def convert_convolution(node, **kwargs):
     """Map MXNet's convolution operator attributes to onnx's Conv operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    inputs = node["inputs"]
-
-    num_inputs = len(inputs)
-
-    proc_nodes = kwargs["proc_nodes"]
-    input_node = proc_nodes[kwargs["index_lookup"][inputs[0][0]]].name
-    weights_node = proc_nodes[kwargs["index_lookup"][inputs[1][0]]].name
-
-    if num_inputs > 2:
-        bias_node = proc_nodes[kwargs["index_lookup"][inputs[2][0]]].name
-
-    attrs = node.get("attrs")
+    name, input_nodes, attrs = get_inputs(node, kwargs)
 
     kernel_dims = list(parse_helper(attrs, "kernel"))
     stride_dims = list(parse_helper(attrs, "stride", [1, 1]))
@@ -195,10 +204,6 @@ def convert_convolution(node, **kwargs):
 
     pad_dims = pad_dims + pad_dims
 
-    input_nodes = [input_node, weights_node]
-    if num_inputs > 2:
-        input_nodes.append(bias_node)
-
     conv_node = onnx.helper.make_node(
         "Conv",
         inputs=input_nodes,
@@ -219,32 +224,15 @@ def convert_fully_connected(node, **kwargs):
     """Map MXNet's FullyConnected operator attributes to onnx's Gemm operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    inputs = node["inputs"]
-    attrs = node["attrs"]
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
     initializer = kwargs["initializer"]
 
     no_bias = get_boolean_attribute_value(attrs, "no_bias")
 
-    input_node_id = kwargs["index_lookup"][inputs[0][0]]
-    weight_node_id = kwargs["index_lookup"][inputs[1][0]]
-
-    proc_nodes = kwargs["proc_nodes"]
-
-    input_node = proc_nodes[input_node_id]
-    input_name = input_node.name
-
-    weights_node = proc_nodes[weight_node_id]
-    weights_name = weights_node.name
-
     fcnode = []
 
-    if no_bias == 0:
-        bias_node_id = kwargs["index_lookup"][inputs[2][0]]
-        bias_node = proc_nodes[bias_node_id]
-        bias_name = bias_node.name
-    else:
+    if no_bias:
         data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('int64')]
         bias_name = "bias" + str(kwargs["idx"])
         tensor_node = onnx.helper.make_tensor_value_info(bias_name, data_type, (1,))
@@ -257,11 +245,12 @@ def convert_fully_connected(node, **kwargs):
                 raw=False,
             )
         )
+        input_nodes.append(bias_name)
         fcnode.append(tensor_node)
 
     node = onnx.helper.make_node(
         "Gemm",
-        [input_name, weights_name, bias_name],  # input (A, B, C) - C can be in place
+        input_nodes,  # input (A, B, C) - C can be in place
         [name],  # output
         alpha=1.0,
         beta=1.0,
@@ -280,37 +269,14 @@ def convert_batchnorm(node, **kwargs):
     """Map MXNet's BatchNorm operator attributes to onnx's BatchNormalization operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
+    name, input_nodes, attrs = get_inputs(node, kwargs)
 
-    attrs = node["attrs"]
-    momentum = float(node.get("attrs", {}).get("momentum", 0.9))
+    momentum = float(attrs.get("momentum", 0.9))
     eps = float(attrs.get("eps", 0.001))
 
-    data_idx = kwargs["index_lookup"][inputs[0][0]]
-    gamma_idx = kwargs["index_lookup"][inputs[1][0]]
-    beta_idx = kwargs["index_lookup"][inputs[2][0]]
-    moving_mean_idx = kwargs["index_lookup"][inputs[3][0]]
-    moving_var_idx = kwargs["index_lookup"][inputs[4][0]]
-
-    data_node = proc_nodes[data_idx].name
-    gamma_node = proc_nodes[gamma_idx].name
-    beta_node = proc_nodes[beta_idx].name
-
-    mov_mean_node = proc_nodes[moving_mean_idx]
-    mov_mean_node = mov_mean_node.name
-    mov_var_node = proc_nodes[moving_var_idx].name
-
     bn_node = onnx.helper.make_node(
         "BatchNormalization",
-        [data_node,
-         gamma_node,  # scale
-         beta_node,  # bias
-         mov_mean_node,
-         mov_var_node
-        ],
+        input_nodes,
         [name],
         name=name,
         epsilon=eps,
@@ -327,140 +293,49 @@ def convert_tanh(node, **kwargs):
     """Map MXNet's tanh operator attributes to onnx's Tanh operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    inputs = node["inputs"]
-    input_node_idx = kwargs["index_lookup"][inputs[0][0]]
-    proc_nodes = kwargs["proc_nodes"]
-    input_node = proc_nodes[input_node_idx].name
-
-    node = onnx.helper.make_node(
-        'Tanh',
-        [input_node],
-        [name],
-        name=name
-    )
-    return [node]
+    return create_basic_op_node('Tanh', node, kwargs)
 
 @mx_op.register("cos")
 def convert_cos(node, **kwargs):
     """Map MXNet's cos operator attributes to onnx's Cos operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    inputs = node["inputs"]
-    input_node_idx = kwargs["index_lookup"][inputs[0][0]]
-    proc_nodes = kwargs["proc_nodes"]
-    input_node = proc_nodes[input_node_idx].name
-
-    node = onnx.helper.make_node(
-        'Cos',
-        [input_node],
-        [name],
-        name=name
-    )
-    return [node]
+    return create_basic_op_node('Cos', node, kwargs)
 
 @mx_op.register("sin")
 def convert_sin(node, **kwargs):
     """Map MXNet's sin operator attributes to onnx's Sin operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    inputs = node["inputs"]
-    input_node_idx = kwargs["index_lookup"][inputs[0][0]]
-    proc_nodes = kwargs["proc_nodes"]
-    input_node = proc_nodes[input_node_idx].name
-
-    node = onnx.helper.make_node(
-        'Sin',
-        [input_node],
-        [name],
-        name=name
-    )
-    return [node]
+    return create_basic_op_node('Sin', node, kwargs)
 
 @mx_op.register("tan")
 def convert_tan(node, **kwargs):
     """Map MXNet's tan operator attributes to onnx's tan operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    inputs = node["inputs"]
-    input_node_idx = kwargs["index_lookup"][inputs[0][0]]
-    proc_nodes = kwargs["proc_nodes"]
-    input_node = proc_nodes[input_node_idx].name
-
-    node = onnx.helper.make_node(
-        'Tan',
-        [input_node],
-        [name],
-        name=name
-    )
-    return [node]
+    return create_basic_op_node('Tan', node, kwargs)
 
 @mx_op.register("arccos")
 def convert_acos(node, **kwargs):
     """Map MXNet's acos operator attributes to onnx's acos operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    inputs = node["inputs"]
-    input_node_idx = kwargs["index_lookup"][inputs[0][0]]
-    proc_nodes = kwargs["proc_nodes"]
-    input_node = proc_nodes[input_node_idx].name
-
-    node = onnx.helper.make_node(
-        'Acos',
-        [input_node],
-        [name],
-        name=name
-    )
-    return [node]
+    return create_basic_op_node('Acos', node, kwargs)
 
 @mx_op.register("arcsin")
 def convert_asin(node, **kwargs):
     """Map MXNet's asin operator attributes to onnx's asin operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    inputs = node["inputs"]
-    input_node_idx = kwargs["index_lookup"][inputs[0][0]]
-    proc_nodes = kwargs["proc_nodes"]
-    input_node = proc_nodes[input_node_idx].name
-
-    node = onnx.helper.make_node(
-        'Asin',
-        [input_node],
-        [name],
-        name=name
-    )
-    return [node]
+    return create_basic_op_node('Asin', node, kwargs)
 
 @mx_op.register("arctan")
 def convert_atan(node, **kwargs):
     """Map MXNet's atan operator attributes to onnx's atan operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    inputs = node["inputs"]
-    input_node_idx = kwargs["index_lookup"][inputs[0][0]]
-    proc_nodes = kwargs["proc_nodes"]
-    input_node = proc_nodes[input_node_idx].name
-
-    node = onnx.helper.make_node(
-        'Atan',
-        [input_node],
-        [name],
-        name=name
-    )
-    return [node]
+    return create_basic_op_node('Atan', node, kwargs)
 
 #Basic neural network functions
 @mx_op.register("sigmoid")
@@ -468,58 +343,24 @@ def convert_sigmoid(node, **kwargs):
     """Map MXNet's sigmoid operator attributes to onnx's Sigmoid operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    inputs = node["inputs"]
-    input_node_idx = kwargs["index_lookup"][inputs[0][0]]
-    proc_nodes = kwargs["proc_nodes"]
-    input_node = proc_nodes[input_node_idx].name
-
-    node = onnx.helper.make_node(
-        'Sigmoid',
-        [input_node],
-        [name],
-        name=name
-    )
-    return [node]
+    return create_basic_op_node('Sigmoid', node, kwargs)
 
 @mx_op.register("relu")
 def convert_relu(node, **kwargs):
     """Map MXNet's relu operator attributes to onnx's Relu operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    inputs = node["inputs"]
-    input_node_idx = kwargs["index_lookup"][inputs[0][0]]
-    proc_nodes = kwargs["proc_nodes"]
-    input_node = proc_nodes[input_node_idx].name
-
-    node = onnx.helper.make_node(
-        'Relu',
-        [input_node],
-        [name],
-        name=name
-    )
-
-    return [node]
+    return create_basic_op_node('Relu', node, kwargs)
 
 @mx_op.register("Activation")
 def convert_activation(node, **kwargs):
     """Map MXNet's Activation operator attributes to onnx's Tanh/Relu operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
+    name, input_nodes, attrs = get_inputs(node, kwargs)
 
-    proc_nodes = kwargs["proc_nodes"]
-    attrs = node["attrs"]
     act_type = attrs["act_type"]
 
-    inputs = node["inputs"]
-    input_node_idx = kwargs["index_lookup"][inputs[0][0]]
-    input_node = proc_nodes[input_node_idx].output[0]
-
     # Creating a dictionary here, but if this titlecase pattern
     # mxnet_name.title()
     act_types = {
@@ -534,7 +375,7 @@ def convert_activation(node, **kwargs):
     if act_name:
         node = onnx.helper.make_node(
             act_name,
-            [input_node],
+            input_nodes,
             [name],
             name=name
         )
@@ -551,13 +392,7 @@ def convert_pad(node, **kwargs):
     """Map MXNet's pad operator attributes to onnx's Pad operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    attrs = node["attrs"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-    input_node_idx = kwargs["index_lookup"][inputs[0][0]]
-    input_node = proc_nodes[input_node_idx].name
+    name, input_nodes, attrs = get_inputs(node, kwargs)
 
     mxnet_pad_width = convert_string_to_list(attrs.get("pad_width"))
     onnx_pad_width = transform_padding(mxnet_pad_width)
@@ -569,7 +404,7 @@ def convert_pad(node, **kwargs):
             if "constant_value" in attrs else 0.0
         node = onnx.helper.make_node(
             'Pad',
-            inputs=[input_node],
+            inputs=input_nodes,
             outputs=[name],
             mode='constant',
             value=pad_value,
@@ -579,7 +414,7 @@ def convert_pad(node, **kwargs):
     else:
         node = onnx.helper.make_node(
             'Pad',
-            inputs=[input_node],
+            inputs=input_nodes,
             outputs=[name],
             mode=pad_mode,
             pads=onnx_pad_width,
@@ -591,8 +426,6 @@ def convert_pad(node, **kwargs):
 
 def create_helper_trans_node(op_name, input_node, node_name):
     """create extra transpose node for dot operator"""
-    onnx = import_onnx_modules()
-
     node_name = op_name + "_" + node_name
     trans_node = onnx.helper.make_node(
         'Transpose',
@@ -608,17 +441,8 @@ def convert_dot(node, **kwargs):
     """Map MXNet's dot operator attributes to onnx's
     MatMul and Transpose operators based on the values set for
     transpose_a, transpose_b attributes."""
-    onnx = import_onnx_modules()
-    proc_nodes = kwargs["proc_nodes"]
-    node_inputs = node["inputs"]
-    name = node["name"]
-
-    input_a_idx = kwargs["index_lookup"][node_inputs[0][0]]
-    input_node_a = proc_nodes[input_a_idx].name
-    input_b_idx = kwargs["index_lookup"][node_inputs[1][0]]
-    input_node_b = proc_nodes[input_b_idx].name
+    name, input_nodes, attrs = get_inputs(node, kwargs)
 
-    attrs = node.get('attrs', {})
     trans_a_node = None
     trans_b_node = None
 
@@ -626,14 +450,12 @@ def convert_dot(node, **kwargs):
     trans_b = get_boolean_attribute_value(attrs, "transpose_b")
 
     op_name = "transpose" + str(kwargs["idx"])
-    create_helper_trans_node(op_name, input_node_a, 'a')
-    create_helper_trans_node(op_name, input_node_b, 'b')
 
     if trans_a:
-        trans_a_node = create_helper_trans_node(op_name, input_node_a, 'a')
+        trans_a_node = create_helper_trans_node(op_name, input_nodes[0], 'a')
         input_node_a = op_name+"_a"
     if trans_b:
-        trans_b_node = create_helper_trans_node(op_name, input_node_b, 'b')
+        trans_b_node = create_helper_trans_node(op_name, input_nodes[1], 'b')
         input_node_b = op_name+"_b"
 
     matmul_node = onnx.helper.make_node(
@@ -660,33 +482,19 @@ def convert_linalg_gemm2(node, **kwargs):
     transpose_a, transpose_b attributes.
     Return multiple nodes created.
     """
-    onnx = import_onnx_modules()
-    proc_nodes = kwargs["proc_nodes"]
-    node_inputs = node["inputs"]
-    name = node["name"]
-
-    input_a_idx = kwargs["index_lookup"][node_inputs[0][0]]
-    input_node_a = proc_nodes[input_a_idx].name
-    input_b_idx = kwargs["index_lookup"][node_inputs[1][0]]
-    input_node_b = proc_nodes[input_b_idx].name
+    name, input_nodes, attrs = get_inputs(node, kwargs)
 
     # Getting the attributes and assigning default values.
-    if "attrs" in node:
-        attrs = node["attrs"]
-        alpha = float(attrs["alpha"])
-        trans_a = int(attrs["transpose_a"])
-        trans_b = int(attrs["transpose_b"])
-    else:
-        alpha = 1.0
-        trans_a = 0
-        trans_b = 0
+    alpha = float(attrs.get("alpha", 1.0))
+    trans_a = get_boolean_attribute_value(attrs, "transpose_a")
+    trans_b = get_boolean_attribute_value(attrs, "transpose_b")
 
     op_name = "transpose" + str(kwargs["idx"])
 
     if alpha == 1.0 and trans_a == 0 and trans_b == 0:
         matmul_node = onnx.helper.make_node(
             'MatMul',
-            inputs=[input_node_a, input_node_b],
+            inputs=input_nodes,
             outputs=[name],
             name=name
         )
@@ -696,14 +504,14 @@ def convert_linalg_gemm2(node, **kwargs):
         node_name = op_name+"_a"
         trans_a_node = onnx.helper.make_node(
             'Transpose',
-            inputs=[input_node_a],
+            inputs=[input_nodes[0]],
             outputs=[op_name+"_a"],
             name=node_name
         )
 
         matmul_node = onnx.helper.make_node(
             'MatMul',
-            inputs=[node_name, input_node_b],
+            inputs=[node_name, input_nodes[1]],
             outputs=[name],
             name=name
         )
@@ -713,14 +521,14 @@ def convert_linalg_gemm2(node, **kwargs):
         node_name = op_name + "_b"
         trans_b_node = onnx.helper.make_node(
             'Transpose',
-            inputs=[input_node_b],
+            inputs=[input_nodes[1]],
             outputs=[op_name+"_b"],
             name=node_name
         )
 
         matmul_node = onnx.helper.make_node(
             'MatMul',
-            inputs=[input_node_a, node_name],
+            inputs=[input_nodes[0], node_name],
             outputs=[name],
             name=name
         )
@@ -730,7 +538,7 @@ def convert_linalg_gemm2(node, **kwargs):
         node_name_a = op_name+"_a"
         trans_a_node = onnx.helper.make_node(
             'Transpose',
-            inputs=[input_node_a],
+            inputs=[input_nodes[0]],
             outputs=[op_name+"_a"],
             name=node_name_a
         )
@@ -738,14 +546,14 @@ def convert_linalg_gemm2(node, **kwargs):
         node_name_b = op_name + "_b"
         trans_b_node = onnx.helper.make_node(
             'Transpose',
-            inputs=[input_node_b],
+            inputs=[input_nodes[1]],
             outputs=[op_name+"_b"],
             name=node_name_b
         )
 
         matmul_node = onnx.helper.make_node(
             'MatMul',
-            inputs=[node_name_a, node_name_b],
+            inputs=input_nodes,
             outputs=[name],
             name=name
         )
@@ -759,19 +567,13 @@ def convert_pooling(node, **kwargs):
     MaxPool/AveragePool/GlobalMaxPool/GlobalAveragePool operators
     based on the input node's attributes and return the created node.
     """
-    onnx = import_onnx_modules()
-    proc_nodes = kwargs["proc_nodes"]
-    attrs = node["attrs"]
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
     kernel = eval(attrs["kernel"])
     pool_type = attrs["pool_type"]
     stride = eval(attrs["stride"]) if attrs.get("stride") else None
     global_pool = get_boolean_attribute_value(attrs, "global_pool")
 
-    node_inputs = node["inputs"]
-    input_node_idx = kwargs["index_lookup"][node_inputs[0][0]]
-    input_node = proc_nodes[input_node_idx]
-    name = node["name"]
-
     pooling_convention = attrs.get('pooling_convention', 'valid')
 
     if pooling_convention == 'full':
@@ -789,14 +591,14 @@ def convert_pooling(node, **kwargs):
     if global_pool:
         node = onnx.helper.make_node(
             global_pool_types[pool_type],
-            [input_node.name],  # input
+            input_nodes,  # input
             [name],
             name=name
         )
     else:
         node = onnx.helper.make_node(
             pool_types[pool_type],
-            [input_node.name],  # input
+            input_nodes,  # input
             [name],
             kernel_shape=kernel,
             pads=pad_dims,
@@ -812,43 +614,14 @@ def convert_exp(node, **kwargs):
     """Map MXNet's exp operator attributes to onnx's Exp operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-
-    input_node_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node = proc_nodes[input_node_id].name
-
-    node = onnx.helper.make_node(
-        "Exp",
-        [input_node],
-        [name],
-        name=name,
-    )
-    return [node]
-
+    return create_basic_op_node('Exp', node, kwargs)
 
 @mx_op.register("_copy")
 def convert_identity(node, **kwargs):
     """Map MXNet's _copy operator attributes to onnx's Identity operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-
-    input_node_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node = proc_nodes[input_node_id].name
-
-    node = onnx.helper.make_node(
-        "Identity",
-        [input_node],
-        [name],
-        name=name,
-    )
-    return [node]
+    return create_basic_op_node('Identity', node, kwargs)
 
 
 @mx_op.register("LeakyReLU")
@@ -856,13 +629,7 @@ def convert_leakyrelu(node, **kwargs):
     """Map MXNet's LeakyReLU operator attributes to onnx's Elu/LeakyRelu/PRelu operators
     based on the input node's attributes and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-    input_node_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node = proc_nodes[input_node_id].name
-    attrs = node["attrs"]
+    name, input_nodes, attrs = get_inputs(node, kwargs)
 
     act_type = attrs.get("act_type", "leaky")
     alpha = float(attrs.get("slope", 0.25))
@@ -870,25 +637,16 @@ def convert_leakyrelu(node, **kwargs):
     act_name = {"elu": "Elu", "leaky": "LeakyRelu", "prelu": "PRelu",
                 "selu": "Selu"}
 
-    if act_type == "prelu":
-        alpha_node_index = kwargs["index_lookup"][inputs[1][0]]
-        alpha_node_name = proc_nodes[alpha_node_index].name
-
-        node = onnx.helper.make_node(
-            act_name[act_type],
-            inputs=[input_node, alpha_node_name],
-            outputs=[name],
-            name=name)
-    elif act_type == "selu":
+    if act_type == "prelu" or act_type == "selu":
         node = onnx.helper.make_node(
             act_name[act_type],
-            inputs=[input_node],
+            inputs=input_nodes,
             outputs=[name],
             name=name)
     else:
         node = onnx.helper.make_node(
             act_name[act_type],
-            inputs=[input_node],
+            inputs=input_nodes,
             outputs=[name],
             name=name,
             alpha=alpha)
@@ -901,18 +659,13 @@ def convert_softmax(node, **kwargs):
     """Map MXNet's softmax operator attributes to onnx's Softmax operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    inputs = node["inputs"]
-    input_idx = kwargs["index_lookup"][inputs[0][0]]
-    proc_nodes = kwargs["proc_nodes"]
-    input_node = proc_nodes[input_idx]
+    name, input_nodes, attrs = get_inputs(node, kwargs)
 
-    name = node["name"]
-    axis = int(node.get("attrs", {}).get("axis", -1))
+    axis = int(attrs.get("axis", -1))
 
     softmax_node = onnx.helper.make_node(
         "Softmax",
-        [input_node.name],
+        input_nodes,
         [name],
         axis=axis,
         name=name
@@ -928,12 +681,10 @@ def convert_softmax_output(node, **kwargs):
     """Map MXNet's SoftmaxOutput operator attributes to onnx's Softmax operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    inputs = node["inputs"]
-    input1_idx = kwargs["index_lookup"][inputs[0][0]]
-    proc_nodes = kwargs["proc_nodes"]
-    input1 = proc_nodes[input1_idx]
-    name = node["name"]
+    name, _, _ = get_inputs(node, kwargs)
+
+    input1_idx = kwargs["index_lookup"][node["inputs"][0][0]]
+    input1 = kwargs["proc_nodes"][input1_idx]
 
     softmax_node = onnx.helper.make_node(
         "Softmax",
@@ -951,15 +702,12 @@ def convert_concat(node, **kwargs):
     """Map MXNet's Concat operator attributes to onnx's Concat operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    inputs = node["inputs"]
-    proc_nodes = kwargs["proc_nodes"]
-    input_names = [proc_nodes[kwargs["index_lookup"][i[0]]].name for i in inputs]
-    axis = int(node.get("attrs", {}).get("dim", 1))
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
+    axis = int(attrs.get("dim", 1))
     concat_node = onnx.helper.make_node(
         "Concat",
-        input_names,
+        input_nodes,
         [name],
         axis=axis,
         name=name
@@ -972,18 +720,15 @@ def convert_transpose(node, **kwargs):
     """Map MXNet's transpose operator attributes to onnx's Transpose operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    input_idx = kwargs["index_lookup"][node["inputs"][0][0]]
-    proc_nodes = kwargs["proc_nodes"]
-    input_node = proc_nodes[input_idx].name
-    axes = node.get("attrs", {}).get("axes", ())
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
+    axes = attrs.get("axes", ())
     if axes:
         axes = tuple(map(int, re.findall(r'\d+', axes)))
 
         transpose_node = onnx.helper.make_node(
             "Transpose",
-            [input_node],
+            input_nodes,
             [name],
             perm=axes,
             name=name
@@ -991,7 +736,7 @@ def convert_transpose(node, **kwargs):
     else:
         transpose_node = onnx.helper.make_node(
             "Transpose",
-            [input_node],
+            input_nodes,
             [name],
             name=name
         )
@@ -1004,21 +749,16 @@ def convert_lrn(node, **kwargs):
     """Map MXNet's LRN operator attributes to onnx's LRN operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    input_idx = kwargs["index_lookup"][node["inputs"][0][0]]
-    proc_nodes = kwargs["proc_nodes"]
-    input_node = proc_nodes[input_idx].name
+    name, input_nodes, attrs = get_inputs(node, kwargs)
 
-    attrs = node["attrs"]
-    alpha = float(attrs["alpha"]) if "alpha" in attrs else 0.0001
-    beta = float(attrs["beta"]) if "beta" in attrs else 0.75
-    bias = float(attrs["knorm"]) if "knorm" in attrs else 1.0
-    size = int(attrs["nsize"])
+    alpha = float(attrs.get("alpha", 0.0001))
+    beta = float(attrs.get("beta", 0.75))
+    bias = float(attrs.get("knorm", 1.0))
+    size = int(attrs.get("nsize"))
 
     lrn_node = onnx.helper.make_node(
         "LRN",
-        inputs=[input_node],
+        inputs=input_nodes,
         outputs=[name],
         name=name,
         alpha=alpha,
@@ -1035,11 +775,8 @@ def convert_l2normalization(node, **kwargs):
     """Map MXNet's L2Normalization operator attributes to onnx's LpNormalization operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    input_id = kwargs["index_lookup"][node["inputs"][0][0]]
-    input_name = kwargs["proc_nodes"][input_id].name
-    attrs = node["attrs"]
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
     mode = attrs.get("mode", "instance")
 
     if mode != "channel":
@@ -1047,7 +784,7 @@ def convert_l2normalization(node, **kwargs):
 
     l2norm_node = onnx.helper.make_node(
         "LpNormalization",
-        [input_name],
+        input_nodes,
         [name],
         axis=1,  # channel only
         name=name
@@ -1060,16 +797,13 @@ def convert_dropout(node, **kwargs):
     """Map MXNet's Dropout operator attributes to onnx's Dropout operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    input_id = kwargs["index_lookup"][node["inputs"][0][0]]
-    input_name = kwargs["proc_nodes"][input_id].name
-    attrs = node["attrs"]
-    probability = float(attrs["p"])
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
+    probability = float(attrs.get("p", 0.5))
 
     dropout_node = onnx.helper.make_node(
         "Dropout",
-        [input_name],
+        input_nodes,
         [name],
         ratio=probability,
         name=name
@@ -1082,37 +816,21 @@ def convert_flatten(node, **kwargs):
     """Map MXNet's Flatten operator attributes to onnx's Flatten operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    input_idx = kwargs["index_lookup"][node["inputs"][0][0]]
-    proc_nodes = kwargs["proc_nodes"]
-    input_node = proc_nodes[input_idx].name  # .output[0]
-
-    flatten_node = onnx.helper.make_node(
-        "Flatten",
-        [input_node],
-        [name],
-        name=name
-    )
-    return [flatten_node]
+    return create_basic_op_node('Flatten', node, kwargs)
 
 @mx_op.register("clip")
 def convert_clip(node, **kwargs):
     """Map MXNet's Clip operator attributes to onnx's Clip operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    input_idx = kwargs["index_lookup"][node["inputs"][0][0]]
-    proc_nodes = kwargs["proc_nodes"]
-    input_node = proc_nodes[input_idx].name
-    attrs = node["attrs"]
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
     a_min = np.float(attrs.get('a_min', -np.inf))
     a_max = np.float(attrs.get('a_max', np.inf))
 
     clip_node = onnx.helper.make_node(
         "Clip",
-        [input_node],
+        input_nodes,
         [name],
         name=name,
         min=a_min,
@@ -1123,21 +841,16 @@ def convert_clip(node, **kwargs):
 
 def scalar_op_helper(node, op_name, **kwargs):
     """Helper function for scalar arithmetic operations"""
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-    scalar_value = [float(node.get("attrs", {}).get("scalar", 1))]
+    name, input_nodes, attrs = get_inputs(node, kwargs)
 
-    input_name_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node = proc_nodes[input_name_id].name
+    scalar_value = [float(attrs.get("scalar", 1))]
 
     initializer = kwargs["initializer"]
     flag = True
     # If the input value is in initializer, just multiply with scalar input
     # and create a new initializer
     for i in initializer:
-        if i.name == input_node:
+        if i.name == input_nodes[0]:
             if op_name == 'Mul':
                 new_initializer = onnx.numpy_helper.to_array(i) * scalar_value[0]
             elif op_name == 'Sub':
@@ -1170,7 +883,7 @@ def scalar_op_helper(node, op_name, **kwargs):
 
         mul_node = onnx.helper.make_node(
             op_name,
-            [input_node, scalar_op_name],
+            [input_nodes[0], scalar_op_name],
             [name],
             name=name
         )
@@ -1180,7 +893,7 @@ def scalar_op_helper(node, op_name, **kwargs):
         data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[new_initializer.dtype]
         dims = np.shape(new_initializer)
 
-        new_a_node = input_node + str(kwargs["idx"])
+        new_a_node = input_nodes[0] + str(kwargs["idx"])
         tensor_node = onnx.helper.make_tensor_value_info(new_a_node, data_type, dims)
 
         initializer.append(
@@ -1239,21 +952,14 @@ def convert_argmax(node, **kwargs):
     """Map MXNet's argmax operator attributes to onnx's ArgMax operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    proc_nodes = kwargs["proc_nodes"]
-    node_inputs = node["inputs"]
-
-    input_node_idx = kwargs["index_lookup"][node_inputs[0][0]]
-    input_node = proc_nodes[input_node_idx].name
-    name = node["name"]
-    attrs = node["attrs"]
+    name, input_nodes, attrs = get_inputs(node, kwargs)
 
     axis = int(attrs.get("axis"))
     keepdims = int(attrs.get("keepdims")) if "keepdims" in attrs  else 1
 
     node = onnx.helper.make_node(
         'ArgMax',
-        inputs=[input_node],
+        inputs=input_nodes,
         axis=axis,
         keepdims=keepdims,
         outputs=[name],
@@ -1266,21 +972,14 @@ def convert_argmin(node, **kwargs):
     """Map MXNet's argmin operator attributes to onnx's ArgMin operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    proc_nodes = kwargs["proc_nodes"]
-    node_inputs = node["inputs"]
-
-    input_node_idx = kwargs["index_lookup"][node_inputs[0][0]]
-    input_node = proc_nodes[input_node_idx].name
-    name = node["name"]
-    attrs = node["attrs"]
+    name, input_nodes, attrs = get_inputs(node, kwargs)
 
     axis = int(attrs.get("axis"))
     keepdims = int(attrs.get("keepdims")) if "keepdims" in attrs  else 1
 
     node = onnx.helper.make_node(
         'ArgMin',
-        inputs=[input_node],
+        inputs=input_nodes,
         axis=axis,
         keepdims=keepdims,
         outputs=[name],
@@ -1293,25 +992,7 @@ def convert_maximum(node, **kwargs):
     """Map MXNet's _maximum operator attributes to onnx's Max operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    proc_nodes = kwargs["proc_nodes"]
-    node_inputs = node["inputs"]
-
-    input_node_list = []
-    for node_input in node_inputs:
-        node_id = kwargs["index_lookup"][node_input[0]]
-        input_node_list.append(proc_nodes[node_id].name)
-
-    name = node["name"]
-
-    node = onnx.helper.make_node(
-        'Max',
-        inputs=input_node_list,
-        outputs=[name],
-        name=name,
-    )
-
-    return [node]
+    return create_basic_op_node('Max', node, kwargs)
 
 
 @mx_op.register("_minimum")
@@ -1319,49 +1000,24 @@ def convert_minimum(node, **kwargs):
     """Map MXNet's _minimum operator attributes to onnx's Min operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    proc_nodes = kwargs["proc_nodes"]
-    node_inputs = node["inputs"]
-
-    input_node_list = []
-    for node_input in node_inputs:
-        node_id = kwargs["index_lookup"][node_input[0]]
-        input_node_list.append(proc_nodes[node_id].name)
-
-    name = node["name"]
-
-    node = onnx.helper.make_node(
-        'Min',
-        inputs=input_node_list,
-        outputs=[name],
-        name=name,
-    )
-
-    return [node]
-
+    return create_basic_op_node('Min', node, kwargs)
 
 @mx_op.register("min")
 def convert_min(node, **kwargs):
     """Map MXNet's min operator attributes to onnx's ReduceMin operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
+    name, input_nodes, attrs = get_inputs(node, kwargs)
 
-    mx_axis = node.get("attrs", {}).get("axis", None)
+    mx_axis = attrs.get("axis", None)
     axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None
 
-    keepdims = int(node.get("attrs", {}).get("keepdims", 0))
-
-    input_node_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node = proc_nodes[input_node_id].name
+    keepdims = int(attrs.get("keepdims", 0))
 
     if axes is not None:
         node = onnx.helper.make_node(
             'ReduceMin',
-            inputs=[input_node],
+            inputs=input_nodes,
             outputs=[name],
             axes=axes,
             keepdims=keepdims,
@@ -1372,7 +1028,7 @@ def convert_min(node, **kwargs):
     else:
         node = onnx.helper.make_node(
             'ReduceMin',
-            inputs=[input_node],
+            inputs=input_nodes,
             outputs=[name],
             keepdims=keepdims,
             name=name
@@ -1386,23 +1042,17 @@ def convert_max(node, **kwargs):
     """Map MXNet's max operator attributes to onnx's ReduceMax operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
+    name, input_nodes, attrs = get_inputs(node, kwargs)
 
-    mx_axis = node.get("attrs", {}).get("axis", None)
+    mx_axis = attrs.get("axis", None)
     axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None
 
-    keepdims = int(node.get("attrs", {}).get("keepdims", 0))
-
-    input_node_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node = proc_nodes[input_node_id].name
+    keepdims = int(attrs.get("keepdims", 0))
 
     if axes is not None:
         node = onnx.helper.make_node(
             'ReduceMax',
-            inputs=[input_node],
+            inputs=input_nodes,
             outputs=[name],
             axes=axes,
             keepdims=keepdims,
@@ -1413,7 +1063,7 @@ def convert_max(node, **kwargs):
     else:
         node = onnx.helper.make_node(
             'ReduceMax',
-            inputs=[input_node],
+            inputs=input_nodes,
             outputs=[name],
             keepdims=keepdims,
             name=name
@@ -1427,23 +1077,17 @@ def convert_mean(node, **kwargs):
     """Map MXNet's mean operator attributes to onnx's ReduceMean operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
+    name, input_nodes, attrs = get_inputs(node, kwargs)
 
-    mx_axis = node.get("attrs", {}).get("axis", None)
+    mx_axis = attrs.get("axis", None)
     axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None
 
-    keepdims = int(node.get("attrs", {}).get("keepdims", 0))
-
-    input_node_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node = proc_nodes[input_node_id].name
+    keepdims = int(attrs.get("keepdims", 0))
 
     if axes is not None:
         node = onnx.helper.make_node(
             'ReduceMean',
-            inputs=[input_node],
+            inputs=input_nodes,
             outputs=[name],
             axes=axes,
             keepdims=keepdims,
@@ -1454,7 +1098,7 @@ def convert_mean(node, **kwargs):
     else:
         node = onnx.helper.make_node(
             'ReduceMean',
-            inputs=[input_node],
+            inputs=input_nodes,
             outputs=[name],
             keepdims=keepdims,
             name=name
@@ -1468,23 +1112,17 @@ def convert_prod(node, **kwargs):
     """Map MXNet's prod operator attributes to onnx's ReduceProd operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
+    name, input_nodes, attrs = get_inputs(node, kwargs)
 
-    mx_axis = node.get("attrs", {}).get("axis", None)
+    mx_axis = attrs.get("axis", None)
     axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None
 
-    keepdims = int(node.get("attrs", {}).get("keepdims", 0))
-
-    input_node_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node = proc_nodes[input_node_id].name
+    keepdims = int(attrs.get("keepdims", 0))
 
     if axes is not None:
         node = onnx.helper.make_node(
             'ReduceProd',
-            inputs=[input_node],
+            inputs=input_nodes,
             outputs=[name],
             axes=axes,
             keepdims=keepdims,
@@ -1495,7 +1133,7 @@ def convert_prod(node, **kwargs):
     else:
         node = onnx.helper.make_node(
             'ReduceProd',
-            inputs=[input_node],
+            inputs=input_nodes,
             outputs=[name],
             keepdims=keepdims,
             name=name
@@ -1510,25 +1148,7 @@ def convert_elementwise_add(node, **kwargs):
     """Map MXNet's elemwise_add operator attributes to onnx's Add operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-
-    input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
-
-    input_node_a = proc_nodes[input_node_a_id].name
-    input_node_b = proc_nodes[input_node_b_id].name
-
-    add_node = onnx.helper.make_node(
-        "Add",
-        [input_node_a, input_node_b],
-        [name],
-        name=name,
-    )
-
-    return [add_node]
+    return create_basic_op_node('Add', node, kwargs)
 
 
 @mx_op.register("broadcast_add")
@@ -1536,25 +1156,7 @@ def covert_broadcast_add(node, **kwargs):
     """Map MXNet's broadcast_add operator attributes to onnx's Add operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-
-    input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
-
-    input_node_a = proc_nodes[input_node_a_id].name
-    input_node_b = proc_nodes[input_node_b_id].name
-
-    add_node = onnx.helper.make_node(
-        "Add",
-        [input_node_a, input_node_b],
-        [name],
-        name=name,
-    )
-
-    return [add_node]
+    return create_basic_op_node('Add', node, kwargs)
 
 
 @mx_op.register("elemwise_sub")
@@ -1562,224 +1164,63 @@ def convert_elementwise_sub(node, **kwargs):
     """Map MXNet's elemwise_sub operator attributes to onnx's Sub operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-
-    input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
-
-    input_node_a = proc_nodes[input_node_a_id].name
-    input_node_b = proc_nodes[input_node_b_id].name
-
-    sub_node = onnx.helper.make_node(
-        "Sub",
-        [input_node_a, input_node_b],
-        [name],
-        name=name,
-    )
-
-    return [sub_node]
+    return create_basic_op_node('Sub', node, kwargs)
 
 @mx_op.register("broadcast_sub")
 def covert_broadcast_sub(node, **kwargs):
     """Map MXNet's broadcast_sub operator attributes to onnx's Sub operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-
-    input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
-
-    input_node_a = proc_nodes[input_node_a_id].name
-    input_node_b = proc_nodes[input_node_b_id].name
-
-    sub_node = onnx.helper.make_node(
-        "Sub",
-        [input_node_a, input_node_b],
-        [name],
-        name=name,
-    )
-
-    return [sub_node]
-
+    return create_basic_op_node('Sub', node, kwargs)
 
 @mx_op.register("elemwise_mul")
 def convert_elemwise_mul(node, **kwargs):
     """Map MXNet's elemwise_mul operator attributes to onnx's Mul operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-
-    input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
-
-    input_node_a = proc_nodes[input_node_a_id].name
-    input_node_b = proc_nodes[input_node_b_id].name
-
-    mul_node = onnx.helper.make_node(
-        "Mul",
-        [input_node_a, input_node_b],
-        [name],
-        name=name,
-    )
-
-    return [mul_node]
+    return create_basic_op_node('Mul', node, kwargs)
 
 @mx_op.register("broadcast_mul")
 def convert_broadcast_mul(node, **kwargs):
     """Map MXNet's broadcast_mul operator attributes to onnx's Mul operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-
-    input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
-
-    input_node_a = proc_nodes[input_node_a_id].name
-    input_node_b = proc_nodes[input_node_b_id].name
-
-    mul_node = onnx.helper.make_node(
-        "Mul",
-        [input_node_a, input_node_b],
-        [name],
-        name=name
-    )
-
-    return [mul_node]
-
+    return create_basic_op_node('Mul', node, kwargs)
 
 @mx_op.register("elemwise_div")
 def convert_elemwise_div(node, **kwargs):
     """Map MXNet's elemwise_div operator attributes to onnx's Div operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-
-    input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
-
-    input_node_a = proc_nodes[input_node_a_id].name
-    input_node_b = proc_nodes[input_node_b_id].name
-
-    div_node = onnx.helper.make_node(
-        "Div",
-        [input_node_a, input_node_b],
-        [name],
-        name=name
-    )
-
-    return [div_node]
-
+    return create_basic_op_node('Div', node, kwargs)
 
 @mx_op.register("broadcast_div")
 def convert_broadcast_div(node, **kwargs):
     """Map MXNet's broadcast_div operator attributes to onnx's Div operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-
-    input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
-
-    input_node_a = proc_nodes[input_node_a_id].name
-    input_node_b = proc_nodes[input_node_b_id].name
-
-    div_node = onnx.helper.make_node(
-        "Div",
-        [input_node_a, input_node_b],
-        [name],
-        name=name
-    )
-
-    return [div_node]
-
+    return create_basic_op_node('Div', node, kwargs)
 
 @mx_op.register("negative")
 def convert_negative(node, **kwargs):
     """Map MXNet's negative operator attributes to onnx's Neg operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-
-    input_node_id = kwargs["index_lookup"][inputs[0][0]]
-
-    input_node = proc_nodes[input_node_id].name
-
-    neg_node = onnx.helper.make_node(
-        "Neg",
-        [input_node],
-        [name],
-        name=name,
-    )
-
-    return [neg_node]
-
+    return create_basic_op_node('Neg', node, kwargs)
 
 @mx_op.register("abs")
 def convert_abs(node, **kwargs):
     """Map MXNet's abs operator attributes to onnx's Abs operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-
-    input_node_id = kwargs["index_lookup"][inputs[0][0]]
-
-    input_node = proc_nodes[input_node_id].name
-
-    abs_node = onnx.helper.make_node(
-        "Abs",
-        [input_node],
-        [name],
-        name=name
-    )
-
-    return [abs_node]
-
+    return create_basic_op_node('Abs', node, kwargs)
 
 @mx_op.register("add_n")
 def convert_addn(node, **kwargs):
     """Map MXNet's add_n operator attributes to onnx's Sum operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-
-    input_list = []
-    for input_val in inputs:
-        input_list.append(proc_nodes[kwargs["index_lookup"][input_val[0]]].name)
-
-    sum_node = onnx.helper.make_node(
-        "Sum",
-        input_list,
-        [name],
-        name=name
-    )
-    return [sum_node]
+    return create_basic_op_node('Sum', node, kwargs)
 
  # Rounding
 @mx_op.register("ceil")
@@ -1787,42 +1228,14 @@ def convert_ceil(node, **kwargs):
     """Map MXNet's ceil operator attributes to onnx's Ceil operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-
-    input_node_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node = proc_nodes[input_node_id].name
-
-    node = onnx.helper.make_node(
-        "Ceil",
-        [input_node],
-        [name],
-        name=name
-    )
-    return [node]
+    return create_basic_op_node('Ceil', node, kwargs)
 
 @mx_op.register("floor")
 def convert_floor(node, **kwargs):
     """Map MXNet's floor operator attributes to onnx's Floor operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-
-    input_node_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node = proc_nodes[input_node_id].name
-
-    node = onnx.helper.make_node(
-        "Floor",
-        [input_node],
-        [name],
-        name=name
-    )
-    return [node]
+    return create_basic_op_node('Floor', node, kwargs)
 
 # Changing shape and type.
 @mx_op.register("Reshape")
@@ -1831,11 +1244,7 @@ def convert_reshape(node, **kwargs):
     Converts output shape attribute to output shape tensor
     and return multiple created nodes.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-    attrs = node["attrs"]
+    name, input_nodes, attrs = get_inputs(node, kwargs)
 
     output_shape_list = convert_string_to_list(attrs["shape"])
 
@@ -1857,8 +1266,7 @@ def convert_reshape(node, **kwargs):
         )
     )
 
-    input_node_idx = kwargs["index_lookup"][inputs[0][0]]
-    input_node_name = proc_nodes[input_node_idx].name
+    input_nodes.append(output_shape_name)
 
     not_supported_shape = [-2, -3, -4]
 
@@ -1868,7 +1276,7 @@ def convert_reshape(node, **kwargs):
 
     reshape_node = onnx.helper.make_node(
         "Reshape",
-        [input_node_name, output_shape_name],
+        input_nodes,
         [name],
         name=name
     )
@@ -1880,11 +1288,9 @@ def convert_cast(node, **kwargs):
     """Map MXNet's Cast operator attributes to onnx's Cast operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-    dtype = node["attrs"]["dtype"]
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
+    dtype = attrs["dtype"]
 
     # dtype can be mapped only with types from TensorProto
     # float32 is mapped to float and float64 to double in onnx
@@ -1894,12 +1300,9 @@ def convert_cast(node, **kwargs):
     elif dtype == 'float64':
         dtype = 'double'
 
-    input_node_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node = proc_nodes[input_node_id].name
-
     node = onnx.helper.make_node(
         "Cast",
-        [input_node],
+        input_nodes,
         [name],
         to=getattr(onnx.TensorProto, dtype.upper()),
         name=name,
@@ -1912,23 +1315,17 @@ def convert_slice_axis(node, **kwargs):
     """Map MXNet's slice_axis operator attributes to onnx's Slice operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-    axes = int(node["attrs"]["axis"])
-    starts = int(node["attrs"]["begin"])
-    if node["attrs"]["end"] == 'None':
-        raise ValueError("Slice: ONNX doesnt't support 'None' in 'end' attribute")
-    else:
-        ends = int(node["attrs"]["end"])
+    name, input_nodes, attrs = get_inputs(node, kwargs)
 
-    input_node_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node = proc_nodes[input_node_id].name
+    axes = int(attrs.get("axis"))
+    starts = int(attrs.get("begin"))
+    ends = int(attrs.get("end", None))
+    if not ends:
+        raise ValueError("Slice: ONNX doesnt't support 'None' in 'end' attribute")
 
     node = onnx.helper.make_node(
         "Slice",
-        [input_node],
+        input_nodes,
         [name],
         axes=[axes],
         starts=[starts],
@@ -1944,21 +1341,16 @@ def convert_slice_channel(node, **kwargs):
     operator based on squeeze_axis attribute
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-    num_outputs = int(node.get("attrs", {})["num_outputs"])
-    axis = int(node.get("attrs", {}).get("axis", 1))
-    squeeze_axis = int(node.get("attrs", {}).get("squeeze_axis", 0))
+    name, input_nodes, attrs = get_inputs(node, kwargs)
 
-    input_node_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node = proc_nodes[input_node_id].name
+    num_outputs = int(attrs.get("num_outputs"))
+    axis = int(attrs.get("axis", 1))
+    squeeze_axis = int(attrs.get("squeeze_axis", 0))
 
     if squeeze_axis == 1 and num_outputs == 1:
         node = onnx.helper.make_node(
             "Squeeze",
-            [input_node],
+            input_nodes,
             [name],
             axes=[axis],
             name=name,
@@ -1967,7 +1359,7 @@ def convert_slice_channel(node, **kwargs):
     elif squeeze_axis == 0 and num_outputs > 1:
         node = onnx.helper.make_node(
             "Split",
-            [input_node],
+            input_nodes,
             [name],
             axis=axis,
             split=[num_outputs],
@@ -1984,18 +1376,13 @@ def convert_expand_dims(node, **kwargs):
     """Map MXNet's expand_dims operator attributes to onnx's Unsqueeze operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-    axis = int(node["attrs"]["axis"])
+    name, input_nodes, attrs = get_inputs(node, kwargs)
 
-    input_node_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node = proc_nodes[input_node_id].name
+    axis = int(attrs.get("axis"))
 
     node = onnx.helper.make_node(
         "Unsqueeze",
-        [input_node],
+        input_nodes,
         [name],
         axes=[axis],
         name=name,
@@ -2007,22 +1394,17 @@ def convert_squeeze(node, **kwargs):
     """Map MXNet's squeeze operator attributes to onnx's squeeze operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-    if "axis" in node["attrs"]:
-        axis = convert_string_to_list(node["attrs"]["axis"])
-    else:
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
+    axis = attrs.get("axis", None)
+    if not axis:
         raise AttributeError("Missing axis attribute: ONNX currently requires axis to "
                              "be specified for squeeze operator")
-
-    input_node_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node = proc_nodes[input_node_id].name
+    axis = convert_string_to_list(axis)
 
     node = onnx.helper.make_node(
         "Squeeze",
-        [input_node],
+        input_nodes,
         [name],
         axes=axis,
         name=name,
@@ -2035,132 +1417,48 @@ def convert_log(node, **kwargs):
     """Map MXNet's log operator attributes to onnx's Log operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-
-    input_node_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node = proc_nodes[input_node_id].name
-
-    node = onnx.helper.make_node(
-        "Log",
-        [input_node],
-        [name],
-        name=name,
-    )
-    return [node]
-
+    return create_basic_op_node('Log', node, kwargs)
 
 @mx_op.register("reciprocal")
 def convert_reciprocal(node, **kwargs):
     """Map MXNet's reciprocal operator attributes to onnx's Reciprocal operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-
-    input_node_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node = proc_nodes[input_node_id].name
-
-    node = onnx.helper.make_node(
-        "Reciprocal",
-        [input_node],
-        [name],
-        name=name,
-    )
-    return [node]
+    return create_basic_op_node('Reciprocal', node, kwargs)
 
 @mx_op.register("_power")
 def convert_power(node, **kwargs):
     """Map MXNet's _power operator attributes to onnx's Pow operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-
-    input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
-
-    input_node_a = proc_nodes[input_node_a_id].name
-    input_node_b = proc_nodes[input_node_b_id].name
-
-    node = onnx.helper.make_node(
-        "Pow",
-        [input_node_a, input_node_b],
-        [name],
-        name=name
-    )
-    return [node]
+    return create_basic_op_node('Pow', node, kwargs)
 
 @mx_op.register("broadcast_power")
 def convert_broadcast_power(node, **kwargs):
     """Map MXNet's _power operator attributes to onnx's Pow operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-
-    input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node_b_id = kwargs["index_lookup"][inputs[1][0]]
-
-    input_node_a = proc_nodes[input_node_a_id].name
-    input_node_b = proc_nodes[input_node_b_id].name
-
-    node = onnx.helper.make_node(
-        "Pow",
-        [input_node_a, input_node_b],
-        [name],
-        name=name
-    )
-    return [node]
+    return create_basic_op_node('Pow', node, kwargs)
 
 @mx_op.register("sqrt")
 def convert_sqrt(node, **kwargs):
     """Map MXNet's sqrt operator attributes to onnx's Sqrt operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-
-    input_node_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node = proc_nodes[input_node_id].name
-
-    node = onnx.helper.make_node(
-        "Sqrt",
-        [input_node],
-        [name],
-        name=name,
-    )
-    return [node]
+    return create_basic_op_node('Sqrt', node, kwargs)
 
 @mx_op.register("depth_to_space")
 def convert_depthtospace(node, **kwargs):
     """Map MXNet's depth_to_space operator attributes to onnx's
     DepthToSpace operator and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-    attrs = node["attrs"]
-
-    input_node_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node = proc_nodes[input_node_id].name
+    name, input_nodes, attrs = get_inputs(node, kwargs)
 
     blksize = int(attrs.get("block_size", 0))
 
     node = onnx.helper.make_node(
         "DepthToSpace",
-        [input_node],
+        input_nodes,
         [name],
         blocksize=blksize,
         name=name,
@@ -2172,20 +1470,13 @@ def convert_spacetodepth(node, **kwargs):
     """Map MXNet's space_to_depth operator attributes to onnx's
     SpaceToDepth operator and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-    attrs = node["attrs"]
-
-    input_node_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node = proc_nodes[input_node_id].name
+    name, input_nodes, attrs = get_inputs(node, kwargs)
 
     blksize = int(attrs.get("block_size", 0))
 
     node = onnx.helper.make_node(
         "SpaceToDepth",
-        [input_node],
+        input_nodes,
         [name],
         blocksize=blksize,
         name=name,
@@ -2197,13 +1488,7 @@ def convert_square(node, **kwargs):
     """Map MXNet's square operator attributes to onnx's Pow operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-
-    input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node_a = proc_nodes[input_node_a_id].name
+    name, input_nodes, _ = get_inputs(node, kwargs)
 
     initializer = kwargs["initializer"]
     data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('int64')]
@@ -2220,9 +1505,11 @@ def convert_square(node, **kwargs):
         )
     )
 
+    input_nodes.append(power2_name)
+
     node = onnx.helper.make_node(
         "Pow",
-        [input_node_a, power2_name],
+        input_nodes,
         [name],
         name=name
     )
@@ -2233,24 +1520,17 @@ def convert_sum(node, **kwargs):
     """Map MXNet's sum operator attributes to onnx's ReduceSum operator
     and return the created node.
     """
-    onnx = import_onnx_modules()
-    name = node["name"]
-    proc_nodes = kwargs["proc_nodes"]
-    inputs = node["inputs"]
-    attrs = node["attrs"]
+    name, input_nodes, attrs = get_inputs(node, kwargs)
 
     mx_axis = attrs.get("axis", None)
     axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None
 
     keepdims = get_boolean_attribute_value(attrs, "keepdims")
 
-    input_node_id = kwargs["index_lookup"][inputs[0][0]]
-    input_node = proc_nodes[input_node_id].name
-
     if axes:
         node = onnx.helper.make_node(
             'ReduceSum',
-            inputs=[input_node],
+            inputs=input_nodes,
             outputs=[name],
             axes=axes,
             keepdims=keepdims,
@@ -2259,7 +1539,7 @@ def convert_sum(node, **kwargs):
     else:
         node = onnx.helper.make_node(
             'ReduceSum',
-            inputs=[input_node],
+            inputs=input_nodes,
             outputs=[name],
             keepdims=keepdims,
             name=name
diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
index 11847381ab2..b02d970f9c2 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
@@ -77,7 +77,11 @@ def register(op_name):
         """Register operators"""
         def wrapper(func):
             """Helper function to map functions"""
-            MXNetGraph.registry_[op_name] = func
+            try:
+                import onnx as _
+                MXNetGraph.registry_[op_name] = func
+            except ImportError:
+                pass
             return func
 
         return wrapper


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to