This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new f665c13  [MXNET-652] Allow for multi-context input to load gluon 
models from onnx (#11637)
f665c13 is described below

commit f665c132c52baeeaa7d0ac3901bbb7a03e488012
Author: Ray Zhang <[email protected]>
AuthorDate: Tue Jul 10 22:48:34 2018 -0700

    [MXNET-652] Allow for multi-context input to load gluon models from onnx 
(#11637)
    
    * Switch from string format to context objects in load from onnx
    
    * Lint
    
    * Addressed tests
---
 python/mxnet/contrib/onnx/onnx2mx/import_onnx.py     | 8 +++-----
 python/mxnet/contrib/onnx/onnx2mx/import_to_gluon.py | 8 ++++----
 tests/python-pytest/onnx/import/gluon_backend.py     | 9 ++++++++-
 tests/python-pytest/onnx/import/gluon_backend_rep.py | 1 +
 4 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py 
b/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py
index 4e85171..3af196f 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py
@@ -20,7 +20,6 @@
 """ Support import export formats."""
 from __future__ import absolute_import as _abs
 from .... import symbol
-from .... import cpu, gpu
 from .... import ndarray as nd
 from ....base import string_types
 from ._import_helper import _convert_map as convert_map
@@ -157,15 +156,15 @@ class GraphProto(object): # pylint: 
disable=too-few-public-methods
                    }
         return metadata
 
-    def graph_to_gluon(self, graph, context):
+    def graph_to_gluon(self, graph, ctx):
         """Construct SymbolBlock from onnx graph.
 
         Parameters
         ----------
         graph : onnx protobuf object
             The loaded onnx graph
-        context : str
-            context for mxnet module object. Should be 'CPU' or 'GPU'
+        ctx : Context or list of Context
+            Loads the model into one or many context(s).
 
         Returns
         -------
@@ -177,7 +176,6 @@ class GraphProto(object): # pylint: 
disable=too-few-public-methods
         data_names = [input_tensor[0] for input_tensor in 
metadata['input_tensor_data']]
         data_inputs = [symbol.var(data_name) for data_name in data_names]
 
-        ctx = gpu() if context == 'GPU' else cpu()
         from ....gluon import SymbolBlock
         net = SymbolBlock(outputs=sym, inputs=data_inputs)
         net_params = net.collect_params()
diff --git a/python/mxnet/contrib/onnx/onnx2mx/import_to_gluon.py 
b/python/mxnet/contrib/onnx/onnx2mx/import_to_gluon.py
index eee968b..5df41c3 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/import_to_gluon.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/import_to_gluon.py
@@ -21,7 +21,7 @@
 
 from .import_onnx import GraphProto
 
-def import_to_gluon(model_file, context):
+def import_to_gluon(model_file, ctx):
     """
     Imports the ONNX model files, passed as a parameter, into Gluon 
SymbolBlock object.
 
@@ -29,8 +29,8 @@ def import_to_gluon(model_file, context):
     ----------
     model_file : str
         ONNX model file name
-    context : str
-        context. Should be 'CPU' or 'GPU'
+    ctx : Context or list of Context
+        Loads the model into one or many context(s).
 
     Returns
     -------
@@ -44,5 +44,5 @@ def import_to_gluon(model_file, context):
         raise ImportError("Onnx and protobuf need to be installed. 
Instructions to"
                          + " install - 
https://github.com/onnx/onnx#installation")
     model_proto = onnx.load(model_file)
-    net = graph.graph_to_gluon(model_proto.graph, context)
+    net = graph.graph_to_gluon(model_proto.graph, ctx)
     return net
diff --git a/tests/python-pytest/onnx/import/gluon_backend.py 
b/tests/python-pytest/onnx/import/gluon_backend.py
index 302fd4d..25be60b 100644
--- a/tests/python-pytest/onnx/import/gluon_backend.py
+++ b/tests/python-pytest/onnx/import/gluon_backend.py
@@ -18,6 +18,7 @@
 # coding: utf-8
 """Gluon backend wrapper for onnx test infrastructure"""
 from mxnet.contrib.onnx.onnx2mx.import_onnx import GraphProto
+import mxnet as mx
 
 try:
     from onnx import helper, TensorProto
@@ -55,7 +56,12 @@ class GluonBackend(Backend):
             used to run inference on the input model and return the result for 
comparison.
         """
         graph = GraphProto()
-        net = graph.graph_to_gluon(model.graph, device)
+        if device == 'CPU':
+            ctx = mx.cpu()
+        else:
+            raise NotImplementedError("ONNX tests are run only for CPU 
context.")
+
+        net = graph.graph_to_gluon(model.graph, ctx)
         return GluonBackendRep(net, device)
 
     @classmethod
@@ -63,6 +69,7 @@ class GluonBackend(Backend):
         """Supports only CPU for testing"""
         return device == 'CPU'
 
+
 prepare = GluonBackend.prepare
 
 supports_device = GluonBackend.supports_device
diff --git a/tests/python-pytest/onnx/import/gluon_backend_rep.py 
b/tests/python-pytest/onnx/import/gluon_backend_rep.py
index a90d350..04c6ddd 100644
--- a/tests/python-pytest/onnx/import/gluon_backend_rep.py
+++ b/tests/python-pytest/onnx/import/gluon_backend_rep.py
@@ -34,6 +34,7 @@ from mxnet import nd
 # Implemented by following onnx docs guide:
 # https://github.com/onnx/onnx/blob/master/docs/ImplementingAnOnnxBackend.md
 
+
 class GluonBackendRep(BackendRep):
     """Running model inference on gluon backend and return the result
      to onnx test infrastructure for comparison."""

Reply via email to