sandeep-krishnamurthy closed pull request #12920: ONNX export: Instance normalization, Shape
URL: https://github.com/apache/incubator-mxnet/pull/12920
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index e2aab6b1efa..facdcfedcbc 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -623,6 +623,23 @@ def convert_identity(node, **kwargs):
     """
     return create_basic_op_node('Identity', node, kwargs)
 
+@mx_op.register("InstanceNorm")
+def convert_instancenorm(node, **kwargs):
+    """Map MXNet's InstanceNorm operator attributes to onnx's 
InstanceNormalization operator
+    based on the input node's attributes and return the created node.
+    """
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
+    eps = float(attrs.get("eps", 0.001))
+
+    node = onnx.helper.make_node(
+        'InstanceNormalization',
+        inputs=input_nodes,
+        outputs=[name],
+        name=name,
+        epsilon=eps)
+
+    return [node]
 
 @mx_op.register("LeakyReLU")
 def convert_leakyrelu(node, **kwargs):
@@ -1546,6 +1563,15 @@ def convert_sum(node, **kwargs):
         )
     return [node]
 
+
+@mx_op.register("shape_array")
+def convert_shape(node, **kwargs):
+    """Map MXNet's shape_array operator attributes to onnx's Shape operator
+    and return the created node.
+    """
+    return create_basic_op_node('Shape', node, kwargs)
+
+
 @mx_op.register("hard_sigmoid")
 def convert_hardsigmoid(node, **kwargs):
     """Map MXNet's hard_sigmoid operator attributes to onnx's HardSigmoid 
operator
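
For reference, a minimal sketch of how the newly registered InstanceNorm
exporter above could be exercised through mxnet.contrib.onnx.export_model.
The input shape, parameter values, and output file name are illustrative
assumptions, not part of this PR:

    import numpy as np
    import mxnet as mx
    from mxnet.contrib import onnx as onnx_mxnet

    # Tiny symbol using InstanceNorm; gamma/beta sizes match the channel
    # dimension of the assumed (1, 3, 32, 32) input.
    data = mx.sym.Variable('data')
    gamma = mx.sym.Variable('gamma')
    beta = mx.sym.Variable('beta')
    sym = mx.sym.InstanceNorm(data=data, gamma=gamma, beta=beta, eps=1e-3)

    params = {'gamma': mx.nd.ones((3,)), 'beta': mx.nd.zeros((3,))}
    # Routed through the convert_instancenorm registration added above.
    onnx_mxnet.export_model(sym, params, [(1, 3, 32, 32)], np.float32,
                            'instancenorm.onnx')
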
diff --git a/tests/python-pytest/onnx/export/backend_rep.py b/tests/python-pytest/onnx/backend_rep.py
similarity index 78%
rename from tests/python-pytest/onnx/export/backend_rep.py
rename to tests/python-pytest/onnx/backend_rep.py
index 8729eafea1a..63836ac848d 100644
--- a/tests/python-pytest/onnx/export/backend_rep.py
+++ b/tests/python-pytest/onnx/backend_rep.py
@@ -16,16 +16,17 @@
 # under the License.
 
 # coding: utf-8
-"""backend rep for onnx test infrastructure"""
+"""MXNet backend rep for onnx test infrastructure"""
 try:
     from onnx.backend.base import BackendRep
 except ImportError:
-    raise ImportError("Onnx and protobuf need to be installed")
+    raise ImportError("Onnx and protobuf need to be installed. Instructions to"
+                      + " install - https://github.com/onnx/onnx#installation")
 import mxnet as mx
 
 # Using these functions for onnx test infrastructure.
 # Implemented by following onnx docs guide:
-# https://github.com/onnx/onnx/blob/master/docs/Implementing%20an%20ONNX%20backend.md
+# https://github.com/onnx/onnx/blob/master/docs/ImplementingAnOnnxBackend.md
 # MXNetBackendRep object will be returned by MXNetBackend's prepare method which is used to
 # execute a model repeatedly.
 # Inputs will be passed to the run method of MXNetBackendRep class, it will perform computation and
@@ -54,9 +55,6 @@ def run(self, inputs, **kwargs):
         params : numpy array
             result obtained after running the inference on mxnet
         """
-        data_forward = []
-        for val in inputs:
-            data_forward.append(mx.nd.array(val))
         # create module, passing cpu context
         if self.device == 'CPU':
             ctx = mx.cpu()
@@ -68,17 +66,19 @@ def run(self, inputs, **kwargs):
         data_names = [graph_input for graph_input in self.symbol.list_inputs()
                       if graph_input not in self.arg_params and graph_input not in self.aux_params]
 
-        data_shapes = []
+        data_forward = []
         for idx, input_name in enumerate(data_names):
-            data_shapes.append((input_name, inputs[idx].shape))
+            val = inputs[idx]
+            data_forward.append(mx.nd.array(val))
 
-        mod = mx.mod.Module(symbol=self.symbol, data_names=data_names, context=ctx,
-                            label_names=None)
-        mod.bind(for_training=False, data_shapes=data_shapes,
-                 label_shapes=None)
-        mod.set_params(arg_params=self.arg_params, aux_params=self.aux_params)
+        if self.arg_params:
+            for idx, input_name in enumerate(self.arg_params):
+                val = self.arg_params[input_name]
+                data_names.append(input_name)
+                data_forward.append(mx.nd.array(val))
 
-        # run inference
-        mod.forward(mx.io.DataBatch(data_forward))
-        result = mod.get_outputs()[0].asnumpy()
+        args = dict(zip(data_names, data_forward))
+        exe = self.symbol.bind(ctx, args=args, aux_states=self.aux_params)
+        exe.forward(is_train=False)
+        result = exe.outputs[0].asnumpy()
         return [result]
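
The rewrite above drops the Module API in favor of binding an executor
directly on the symbol, so inputs with differing batch sizes no longer need
a separate code path. A standalone sketch of that executor pattern on a toy
symbol (names and values here are illustrative):

    import mxnet as mx

    # Toy symbol standing in for the imported ONNX graph.
    x = mx.sym.Variable('x')
    y = mx.sym.relu(x)

    # As in run() above: map input names to NDArrays, bind, then forward.
    args = {'x': mx.nd.array([[-1.0, 2.0]])}
    exe = y.bind(mx.cpu(), args=args, aux_states={})
    exe.forward(is_train=False)
    print(exe.outputs[0].asnumpy())  # [[0. 2.]]
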
diff --git a/tests/python-pytest/onnx/export/backend.py b/tests/python-pytest/onnx/export/backend.py
index e23cc01494e..3ea1dafca25 100644
--- a/tests/python-pytest/onnx/export/backend.py
+++ b/tests/python-pytest/onnx/export/backend.py
@@ -17,6 +17,8 @@
 
 # coding: utf-8
 """backend wrapper for onnx test infrastructure"""
+import os
+import sys
 import numpy as np
 from mxnet.contrib.onnx.onnx2mx.import_onnx import GraphProto
 from mxnet.contrib.onnx.mx2onnx.export_onnx import MXNetGraph
@@ -25,6 +27,8 @@
     from onnx.backend.base import Backend
 except ImportError:
     raise ImportError("Onnx and protobuf need to be installed")
+CURR_PATH = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
+sys.path.insert(0, os.path.join(CURR_PATH, '../'))
 from backend_rep import MXNetBackendRep
 
 # Using these functions for onnx test infrastructure.
diff --git a/tests/python-pytest/onnx/export/onnx_backend_test.py b/tests/python-pytest/onnx/export/onnx_backend_test.py
index ec9ddf23c25..be9273eb6fa 100644
--- a/tests/python-pytest/onnx/export/onnx_backend_test.py
+++ b/tests/python-pytest/onnx/export/onnx_backend_test.py
@@ -95,7 +95,9 @@
     'test_clip'
     'test_cast',
     'test_depthtospace',
-    'test_hardsigmoid'
+    'test_hardsigmoid',
+    'test_instancenorm',
+    'test_shape'
     ]
 
 BASIC_MODEL_TESTS = [
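
These whitelists drive onnx.backend.test's standard runner, which is roughly
wired up as below (the backend import path is assumed for illustration):

    import onnx.backend.test
    import backend as mxnet_backend  # the wrapper module patched above

    BACKEND_TESTS = onnx.backend.test.BackendTest(mxnet_backend, __name__)
    # Each whitelisted entry enables the matching ONNX node test, e.g. the
    # two added in this PR:
    for op_test in ['test_instancenorm', 'test_shape']:
        BACKEND_TESTS.include(op_test)
    # Expose the generated cases to the test discovery mechanism.
    globals().update(BACKEND_TESTS.enable_report().test_cases)
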
diff --git a/tests/python-pytest/onnx/import/mxnet_backend.py b/tests/python-pytest/onnx/import/mxnet_backend.py
index 10f89ecbbbc..bd4910b64f8 100644
--- a/tests/python-pytest/onnx/import/mxnet_backend.py
+++ b/tests/python-pytest/onnx/import/mxnet_backend.py
@@ -17,6 +17,8 @@
 
 # coding: utf-8
 """MXNet backend wrapper for onnx test infrastructure"""
+import os
+import sys
 from mxnet.contrib.onnx.onnx2mx.import_onnx import GraphProto
 try:
     from onnx import helper, TensorProto
@@ -24,7 +26,9 @@
 except ImportError:
     raise ImportError("Onnx and protobuf need to be installed. Instructions to"
                       + " install - https://github.com/onnx/onnx#installation";)
-from mxnet_backend_rep import MXNetBackendRep
+CURR_PATH = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
+sys.path.insert(0, os.path.join(CURR_PATH, '../'))
+from backend_rep import MXNetBackendRep
 
 # MXNetBackend class will take an ONNX model with inputs, perform a computation,
 # and then return the output.
diff --git a/tests/python-pytest/onnx/import/mxnet_backend_rep.py b/tests/python-pytest/onnx/import/mxnet_backend_rep.py
deleted file mode 100644
index 938f25d38bf..00000000000
--- a/tests/python-pytest/onnx/import/mxnet_backend_rep.py
+++ /dev/null
@@ -1,98 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# coding: utf-8
-"""MXNet backend rep for onnx test infrastructure"""
-try:
-    from onnx.backend.base import BackendRep
-except ImportError:
-    raise ImportError("Onnx and protobuf need to be installed. Instructions to"
-                      + " install - https://github.com/onnx/onnx#installation")
-import mxnet as mx
-
-# Using these functions for onnx test infrastructure.
-# Implemented by following onnx docs guide:
-# https://github.com/onnx/onnx/blob/master/docs/ImplementingAnOnnxBackend.md
-# MXNetBackendRep object will be returned by MXNetBackend's prepare method which is used to
-# execute a model repeatedly.
-# Inputs will be passed to the run method of MXNetBackendRep class, it will perform computation and
-# retrieve the corresponding results for comparison to the onnx backend.
-# https://github.com/onnx/onnx/blob/master/onnx/backend/test/runner/__init__.py.
-
-class MXNetBackendRep(BackendRep):
-    """Running model inference on mxnet engine and return the result
-     to onnx test infrastructure for comparison."""
-    def __init__(self, symbol, arg_params, aux_params, device):
-        self.symbol = symbol
-        self.arg_params = arg_params
-        self.aux_params = aux_params
-        self.device = device
-
-    def run(self, inputs, **kwargs):
-        """Run model inference and return the result
-
-        Parameters
-        ----------
-        inputs : numpy array
-            input to run a layer on
-
-        Returns
-        -------
-        params : numpy array
-            result obtained after running the inference on mxnet
-        """
-        data_forward = []
-        for val in inputs:
-            data_forward.append(mx.nd.array(val))
-        # create module, passing cpu context
-        if self.device == 'CPU':
-            ctx = mx.cpu()
-        else:
-            raise NotImplementedError("ONNX tests are run only for CPU context.")
-
-        # To fetch the data names of the input to the model we list the inputs of the symbol graph
-        # and exclude the argument and auxiliary parameters from the list
-        data_names = [graph_input for graph_input in self.symbol.list_inputs()
-                      if graph_input not in self.arg_params and graph_input not in self.aux_params]
-
-        data_shapes = []
-        for idx, input_name in enumerate(data_names):
-            data_shapes.append((input_name, inputs[idx].shape))
-
-        # module bind method requires all data to have same batch size,
-        # using module if all data have same batch size
-        if len(set([data_shape[1][0] for data_shape in data_shapes])) == 1:
-            mod = mx.mod.Module(symbol=self.symbol, data_names=data_names, context=ctx,
-                                label_names=None)
-            mod.bind(for_training=False, data_shapes=data_shapes,
-                     label_shapes=None)
-            mod.set_params(arg_params=self.arg_params, aux_params=self.aux_params)
-
-            # run inference
-            mod.forward(mx.io.DataBatch(data_forward))
-            result = mod.get_outputs()[0].asnumpy()
-            # split operator inference returns 1 less dimension
-            if self.symbol.name.startswith('split'):
-                return [i.asnumpy() for i in mod.get_outputs()]
-            return [result]
-        # using symbol bind method if data have different batch size
-        else:
-            exec1 = self.symbol.bind(ctx, args=dict(zip(data_names, data_forward)))
-            exec1.forward(is_train=False)
-            result = exec1.outputs[0].asnumpy()
-            return [result]
-
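
Taken together with the existing exporter plumbing, the second new mapping
can be exercised the same way; a short sketch for shape_array (input shape
and file name are assumptions):

    import numpy as np
    import mxnet as mx
    from mxnet.contrib import onnx as onnx_mxnet

    # shape_array now maps to ONNX's Shape operator via convert_shape above.
    data = mx.sym.Variable('data')
    sym = mx.sym.shape_array(data)
    # No learned parameters, so an empty params dict is passed.
    onnx_mxnet.export_model(sym, {}, [(2, 3, 4)], np.float32, 'shape.onnx')
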


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
