szha closed pull request #11637: [MXNET-652] Allow for multi-context input to
load gluon models from onnx
URL: https://github.com/apache/incubator-mxnet/pull/11637
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
As this is a foreign pull request (from a fork), the diff is supplied
below (it would not otherwise be shown, since GitHub hides the original diff of a fork's pull request once it is merged):
diff --git a/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py
b/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py
index 4e851712972..3af196f8b09 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py
@@ -20,7 +20,6 @@
""" Support import export formats."""
from __future__ import absolute_import as _abs
from .... import symbol
-from .... import cpu, gpu
from .... import ndarray as nd
from ....base import string_types
from ._import_helper import _convert_map as convert_map
@@ -157,15 +156,15 @@ def get_graph_metadata(self, graph):
}
return metadata
- def graph_to_gluon(self, graph, context):
+ def graph_to_gluon(self, graph, ctx):
"""Construct SymbolBlock from onnx graph.
Parameters
----------
graph : onnx protobuf object
The loaded onnx graph
- context : str
- context for mxnet module object. Should be 'CPU' or 'GPU'
+ ctx : Context or list of Context
+ Loads the model into one or many context(s).
Returns
-------
@@ -177,7 +176,6 @@ def graph_to_gluon(self, graph, context):
data_names = [input_tensor[0] for input_tensor in
metadata['input_tensor_data']]
data_inputs = [symbol.var(data_name) for data_name in data_names]
- ctx = gpu() if context == 'GPU' else cpu()
from ....gluon import SymbolBlock
net = SymbolBlock(outputs=sym, inputs=data_inputs)
net_params = net.collect_params()
diff --git a/python/mxnet/contrib/onnx/onnx2mx/import_to_gluon.py
b/python/mxnet/contrib/onnx/onnx2mx/import_to_gluon.py
index eee968b32cd..5df41c3f327 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/import_to_gluon.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/import_to_gluon.py
@@ -21,7 +21,7 @@
from .import_onnx import GraphProto
-def import_to_gluon(model_file, context):
+def import_to_gluon(model_file, ctx):
"""
Imports the ONNX model files, passed as a parameter, into Gluon
SymbolBlock object.
@@ -29,8 +29,8 @@ def import_to_gluon(model_file, context):
----------
model_file : str
ONNX model file name
- context : str
- context. Should be 'CPU' or 'GPU'
+ ctx : Context or list of Context
+ Loads the model into one or many context(s).
Returns
-------
@@ -44,5 +44,5 @@ def import_to_gluon(model_file, context):
raise ImportError("Onnx and protobuf need to be installed.
Instructions to"
+ " install -
https://github.com/onnx/onnx#installation")
model_proto = onnx.load(model_file)
- net = graph.graph_to_gluon(model_proto.graph, context)
+ net = graph.graph_to_gluon(model_proto.graph, ctx)
return net
diff --git a/tests/python-pytest/onnx/import/gluon_backend.py
b/tests/python-pytest/onnx/import/gluon_backend.py
index 302fd4dcf08..25be60b57dc 100644
--- a/tests/python-pytest/onnx/import/gluon_backend.py
+++ b/tests/python-pytest/onnx/import/gluon_backend.py
@@ -18,6 +18,7 @@
# coding: utf-8
"""Gluon backend wrapper for onnx test infrastructure"""
from mxnet.contrib.onnx.onnx2mx.import_onnx import GraphProto
+import mxnet as mx
try:
from onnx import helper, TensorProto
@@ -55,7 +56,12 @@ def prepare(cls, model, device='CPU', **kwargs):
used to run inference on the input model and return the result for
comparison.
"""
graph = GraphProto()
- net = graph.graph_to_gluon(model.graph, device)
+ if device == 'CPU':
+ ctx = mx.cpu()
+ else:
+ raise NotImplementedError("ONNX tests are run only for CPU
context.")
+
+ net = graph.graph_to_gluon(model.graph, ctx)
return GluonBackendRep(net, device)
@classmethod
@@ -63,6 +69,7 @@ def supports_device(cls, device):
"""Supports only CPU for testing"""
return device == 'CPU'
+
prepare = GluonBackend.prepare
supports_device = GluonBackend.supports_device
diff --git a/tests/python-pytest/onnx/import/gluon_backend_rep.py
b/tests/python-pytest/onnx/import/gluon_backend_rep.py
index a90d350c8cd..04c6ddde63e 100644
--- a/tests/python-pytest/onnx/import/gluon_backend_rep.py
+++ b/tests/python-pytest/onnx/import/gluon_backend_rep.py
@@ -34,6 +34,7 @@
# Implemented by following onnx docs guide:
# https://github.com/onnx/onnx/blob/master/docs/ImplementingAnOnnxBackend.md
+
class GluonBackendRep(BackendRep):
"""Running model inference on gluon backend and return the result
to onnx test infrastructure for comparison."""
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services