Roshrini closed pull request #13641: onnx export operators added
URL: https://github.com/apache/incubator-mxnet/pull/13641
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):
diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 15624b6c3a2..0dd816bcc6f 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -630,12 +630,19 @@ def convert_exp(node, **kwargs):
return create_basic_op_node('Exp', node, kwargs)
@mx_op.register("_copy")
-def convert_identity(node, **kwargs):
+def convert_copy(node, **kwargs):
"""Map MXNet's _copy operator attributes to onnx's Identity operator
and return the created node.
"""
return create_basic_op_node('Identity', node, kwargs)
+@mx_op.register("identity")
+def convert_identity(node, **kwargs):
+ """Map MXNet's identity operator attributes to onnx's ConstantFill operator
+ and return the created node.
+ """
+ return create_basic_op_node('ConstantFill', node, kwargs)
+
@mx_op.register("InstanceNorm")
def convert_instancenorm(node, **kwargs):
"""Map MXNet's InstanceNorm operator attributes to onnx's
InstanceNormalization operator
@@ -726,6 +733,32 @@ def convert_softmax_output(node, **kwargs):
return [softmax_node]
+@mx_op.register("LogisticRegressionOutput")
+def convert_logistic_regression_output(node, **kwargs):
+    """Map MXNet's LogisticRegressionOutput operator attributes to onnx's
+    Sigmoid operator and return the created node.
+ """
+ name = node["name"]
+ input1_idx = kwargs["index_lookup"][node["inputs"][0][0]]
+ input1 = kwargs["proc_nodes"][input1_idx]
+ sigmoid_node = onnx.helper.make_node(
+ "Sigmoid",
+ [input1.name],
+ [name],
+ name=name
+ )
+ return [sigmoid_node]
+
+@mx_op.register("BlockGrad")
+def convert_blockgrad(node, **kwargs):
+ """ Skip operator """
+ return create_basic_op_node('ConstantFill', node, kwargs)
+
+@mx_op.register("MakeLoss")
+def convert_makeloss(node, **kwargs):
+ """ Skip operator """
+ return create_basic_op_node('ConstantFill', node, kwargs)
+
@mx_op.register("Concat")
def convert_concat(node, **kwargs):
@@ -872,6 +905,7 @@ def convert_clip(node, **kwargs):
def scalar_op_helper(node, op_name, **kwargs):
"""Helper function for scalar arithmetic operations"""
name, input_nodes, attrs = get_inputs(node, kwargs)
+ from onnx import numpy_helper
input_type = kwargs["in_type"]
scalar_value = np.array([attrs.get("scalar", 1)],
@@ -884,13 +918,18 @@ def scalar_op_helper(node, op_name, **kwargs):
for i in initializer:
if i.name == input_nodes[0]:
if op_name == 'Mul':
- new_initializer = onnx.numpy_helper.to_array(i) *
scalar_value[0]
+ new_initializer = numpy_helper.to_array(i) * scalar_value[0]
elif op_name == 'Sub':
- new_initializer = onnx.numpy_helper.to_array(i) -
scalar_value[0]
+ if name.startswith("_rminusscalar"):
+ new_initializer = scalar_value[0] -
numpy_helper.to_array(i)
+ else:
+ new_initializer = numpy_helper.to_array(i) -
scalar_value[0]
elif op_name == 'Add':
- new_initializer = onnx.numpy_helper.to_array(i) +
scalar_value[0]
+ new_initializer = numpy_helper.to_array(i) + scalar_value[0]
elif op_name == 'Div':
- new_initializer = onnx.numpy_helper.to_array(i) /
scalar_value[0]
+ new_initializer = numpy_helper.to_array(i) / scalar_value[0]
+ elif op_name == 'Pow':
+ new_initializer = numpy_helper.to_array(i) ** scalar_value[0]
flag = False
break
@@ -956,6 +995,13 @@ def convert_minus_scalar(node, **kwargs):
"""
return scalar_op_helper(node, 'Sub', **kwargs)
+@mx_op.register("_rminus_scalar")
+def convert_rminus_scalar(node, **kwargs):
+ """Map MXNet's _rminus_scalar operator attributes to onnx's Sub operator.
+ Creates a new node for the input scalar value, adds it to the initializer
+ and return multiple created nodes.
+ """
+ return scalar_op_helper(node, 'Sub', **kwargs)
# Convert scalar value into node and pass it as input to mul_node
@mx_op.register("_plus_scalar")
@@ -975,6 +1021,21 @@ def convert_div_scalar(node, **kwargs):
"""
return scalar_op_helper(node, 'Div', **kwargs)
+@mx_op.register("_rdiv_scalar")
+def convert_rdiv_scalar(node, **kwargs):
+ """Map MXNet's _rdiv_scalar operator attributes to onnx's Div operator.
+ Creates a new node for the input scalar value, adds it to the initializer
+ and return multiple created nodes.
+ """
+ return scalar_op_helper(node, 'Div', **kwargs)
+
+@mx_op.register("_power_scalar")
+def convert_pow_scalar(node, **kwargs):
+    """Map MXNet's _power_scalar operator attributes to onnx's Pow operator.
+ Creates a new node for the input scalar value, adds it to the initializer
+ and return multiple created nodes.
+ """
+ return scalar_op_helper(node, 'Pow', **kwargs)
# Sorting and Searching
@mx_op.register("argmax")
diff --git a/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py
b/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py
index f63c1e9e8e6..9700dd6a30a 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py
@@ -158,7 +158,7 @@ def _fix_broadcast(op_name, inputs, broadcast_axis,
proto_obj):
assert len(list(inputs)) == 2
input0_shape = get_input_shape(inputs[0], proto_obj)
- #creating reshape shape
+ # creating reshape shape
reshape_shape = list(len(input0_shape) * (1,))
reshape_shape[broadcast_axis] = -1
reshape_shape = tuple(reshape_shape)
diff --git a/tests/python-pytest/onnx/backend_test.py
b/tests/python-pytest/onnx/backend_test.py
index 6c6c3d2d9c7..5ec5efb8ea2 100644
--- a/tests/python-pytest/onnx/backend_test.py
+++ b/tests/python-pytest/onnx/backend_test.py
@@ -50,6 +50,7 @@ def prepare_tests(backend, operation):
for std_model_test in std_models:
BACKEND_TESTS.include(std_model_test)
- BACKEND_TESTS.exclude('.*bcast.*')
+ # Tests for scalar ops are in test_node.py
+ BACKEND_TESTS.exclude('.*scalar.*')
return BACKEND_TESTS
diff --git a/tests/python-pytest/onnx/test_node.py
b/tests/python-pytest/onnx/test_node.py
index 07ae866b96c..41b86de4b9b 100644
--- a/tests/python-pytest/onnx/test_node.py
+++ b/tests/python-pytest/onnx/test_node.py
@@ -138,6 +138,30 @@ def get_onnx_graph(testname, input_names, inputs,
output_name, output_shape, att
npt.assert_almost_equal(output[0], mxnet_output)
+ input1 = get_rnd((1, 10, 2, 3))
+ ipsym = mx.sym.Variable("input1")
+ for test in test_scalar_ops:
+ if test == 'Add':
+ outsym = 2 + ipsym
+ if test == "Sub":
+ outsym = ipsym - 2
+ if test == "rSub":
+ outsym = ipsym.__rsub__(2)
+ if test == "Mul":
+ outsym = 2 * ipsym
+ if test == "Div":
+ outsym = ipsym / 2
+ if test == "Pow":
+ outsym = ipsym ** 2
+ forward_op = forward_pass(outsym, None, None, ['input1'], input1)
+ converted_model = onnx_mxnet.export_model(outsym, {},
[np.shape(input1)], np.float32,
+
onnx_file_path=outsym.name + ".onnx")
+
+ sym, arg_params, aux_params =
onnx_mxnet.import_model(converted_model)
+ result = forward_pass(sym, arg_params, aux_params, ['input1'],
input1)
+
+ npt.assert_almost_equal(result, forward_op)
+
# test_case = ("test_case_name", mxnet op, "ONNX_op_name", [input_list],
attribute map, MXNet_specific=True/False)
test_cases = [
@@ -156,9 +180,13 @@ def get_onnx_graph(testname, input_names, inputs,
output_name, output_shape, att
{'block_size': 2}, False),
("test_softmax", mx.sym.SoftmaxOutput, "Softmax", [get_rnd((1000, 1000)),
get_rnd(1000)],
{'ignore_label': 0, 'use_ignore': False}, True),
+ ("test_logistic_regression", mx.sym.LogisticRegressionOutput, "Sigmoid",
+ [get_rnd((1000, 1000)), get_rnd((1000, 1000))], {}, True),
("test_fullyconnected", mx.sym.FullyConnected, "Gemm", [get_rnd((4,3)),
get_rnd((4, 3)), get_rnd(4)],
{'num_hidden': 4, 'name': 'FC'}, True)
]
+test_scalar_ops = ['Add', 'Sub', 'rSub', 'Mul', 'Div', 'Pow']
+
if __name__ == '__main__':
unittest.main()
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services