This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new f665c13 [MXNET-652] Allow for multi-context input to load gluon
models from onnx (#11637)
f665c13 is described below
commit f665c132c52baeeaa7d0ac3901bbb7a03e488012
Author: Ray Zhang <[email protected]>
AuthorDate: Tue Jul 10 22:48:34 2018 -0700
[MXNET-652] Allow for multi-context input to load gluon models from onnx
(#11637)
* Switch from string format to context objects in load from onnx
* Lint
* Addressed tests
---
python/mxnet/contrib/onnx/onnx2mx/import_onnx.py | 8 +++-----
python/mxnet/contrib/onnx/onnx2mx/import_to_gluon.py | 8 ++++----
tests/python-pytest/onnx/import/gluon_backend.py | 9 ++++++++-
tests/python-pytest/onnx/import/gluon_backend_rep.py | 1 +
4 files changed, 16 insertions(+), 10 deletions(-)
diff --git a/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py
b/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py
index 4e85171..3af196f 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py
@@ -20,7 +20,6 @@
""" Support import export formats."""
from __future__ import absolute_import as _abs
from .... import symbol
-from .... import cpu, gpu
from .... import ndarray as nd
from ....base import string_types
from ._import_helper import _convert_map as convert_map
@@ -157,15 +156,15 @@ class GraphProto(object): # pylint:
disable=too-few-public-methods
}
return metadata
- def graph_to_gluon(self, graph, context):
+ def graph_to_gluon(self, graph, ctx):
"""Construct SymbolBlock from onnx graph.
Parameters
----------
graph : onnx protobuf object
The loaded onnx graph
- context : str
- context for mxnet module object. Should be 'CPU' or 'GPU'
+ ctx : Context or list of Context
+ Loads the model into one or many context(s).
Returns
-------
@@ -177,7 +176,6 @@ class GraphProto(object): # pylint:
disable=too-few-public-methods
data_names = [input_tensor[0] for input_tensor in
metadata['input_tensor_data']]
data_inputs = [symbol.var(data_name) for data_name in data_names]
- ctx = gpu() if context == 'GPU' else cpu()
from ....gluon import SymbolBlock
net = SymbolBlock(outputs=sym, inputs=data_inputs)
net_params = net.collect_params()
diff --git a/python/mxnet/contrib/onnx/onnx2mx/import_to_gluon.py
b/python/mxnet/contrib/onnx/onnx2mx/import_to_gluon.py
index eee968b..5df41c3 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/import_to_gluon.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/import_to_gluon.py
@@ -21,7 +21,7 @@
from .import_onnx import GraphProto
-def import_to_gluon(model_file, context):
+def import_to_gluon(model_file, ctx):
"""
Imports the ONNX model files, passed as a parameter, into Gluon
SymbolBlock object.
@@ -29,8 +29,8 @@ def import_to_gluon(model_file, context):
----------
model_file : str
ONNX model file name
- context : str
- context. Should be 'CPU' or 'GPU'
+ ctx : Context or list of Context
+ Loads the model into one or many context(s).
Returns
-------
@@ -44,5 +44,5 @@ def import_to_gluon(model_file, context):
raise ImportError("Onnx and protobuf need to be installed.
Instructions to"
+ " install -
https://github.com/onnx/onnx#installation")
model_proto = onnx.load(model_file)
- net = graph.graph_to_gluon(model_proto.graph, context)
+ net = graph.graph_to_gluon(model_proto.graph, ctx)
return net
diff --git a/tests/python-pytest/onnx/import/gluon_backend.py
b/tests/python-pytest/onnx/import/gluon_backend.py
index 302fd4d..25be60b 100644
--- a/tests/python-pytest/onnx/import/gluon_backend.py
+++ b/tests/python-pytest/onnx/import/gluon_backend.py
@@ -18,6 +18,7 @@
# coding: utf-8
"""Gluon backend wrapper for onnx test infrastructure"""
from mxnet.contrib.onnx.onnx2mx.import_onnx import GraphProto
+import mxnet as mx
try:
from onnx import helper, TensorProto
@@ -55,7 +56,12 @@ class GluonBackend(Backend):
used to run inference on the input model and return the result for
comparison.
"""
graph = GraphProto()
- net = graph.graph_to_gluon(model.graph, device)
+ if device == 'CPU':
+ ctx = mx.cpu()
+ else:
+ raise NotImplementedError("ONNX tests are run only for CPU
context.")
+
+ net = graph.graph_to_gluon(model.graph, ctx)
return GluonBackendRep(net, device)
@classmethod
@@ -63,6 +69,7 @@ class GluonBackend(Backend):
"""Supports only CPU for testing"""
return device == 'CPU'
+
prepare = GluonBackend.prepare
supports_device = GluonBackend.supports_device
diff --git a/tests/python-pytest/onnx/import/gluon_backend_rep.py
b/tests/python-pytest/onnx/import/gluon_backend_rep.py
index a90d350..04c6ddd 100644
--- a/tests/python-pytest/onnx/import/gluon_backend_rep.py
+++ b/tests/python-pytest/onnx/import/gluon_backend_rep.py
@@ -34,6 +34,7 @@ from mxnet import nd
# Implemented by following onnx docs guide:
# https://github.com/onnx/onnx/blob/master/docs/ImplementingAnOnnxBackend.md
+
class GluonBackendRep(BackendRep):
"""Running model inference on gluon backend and return the result
to onnx test infrastructure for comparison."""