vandanavk closed pull request #12634: [WIP] Add optional output label in ONNX
export
URL: https://github.com/apache/incubator-mxnet/pull/12634
This is a pull request from a forked repository. Because GitHub hides
the original diff once a foreign pull request is merged, the diff is
reproduced below for the sake of provenance:
diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_model.py
b/python/mxnet/contrib/onnx/mx2onnx/export_model.py
index e5158051d6f..a572cc20021 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/export_model.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_model.py
@@ -32,7 +32,7 @@
from ._export_helper import load_module
-def export_model(sym, params, input_shape, input_type=np.float32,
+def export_model(sym, params, input_shape, input_type=np.float32,
out_label=None,
onnx_file_path='model.onnx', verbose=False):
"""Exports the MXNet model file, passed as a parameter, into ONNX model.
Accepts both symbol,parameter objects as well as json and params filepaths
as input.
@@ -49,6 +49,8 @@ def export_model(sym, params, input_shape,
input_type=np.float32,
Input shape of the model e.g [(1,3,224,224)]
input_type : data type
Input data type e.g. np.float32
+ out_label : str
+ custom output node label
onnx_file_path : str
Path where to save the generated onnx file
verbose : Boolean
@@ -75,10 +77,12 @@ def export_model(sym, params, input_shape,
input_type=np.float32,
sym_obj, params_obj = load_module(sym, params)
onnx_graph = converter.create_onnx_graph_proto(sym_obj, params_obj,
input_shape,
mapping.NP_TYPE_TO_TENSOR_TYPE[data_format],
+ out_label,
verbose=verbose)
elif isinstance(sym, symbol.Symbol) and isinstance(params, dict):
onnx_graph = converter.create_onnx_graph_proto(sym, params,
input_shape,
mapping.NP_TYPE_TO_TENSOR_TYPE[data_format],
+ out_label,
verbose=verbose)
else:
raise ValueError("Input sym and params should either be files or
objects")
diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
index 11847381ab2..cca0748def1 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
@@ -189,7 +189,7 @@ def convert_weights_to_numpy(weights_dict):
return dict([(k.replace("arg:", "").replace("aux:", ""), v.asnumpy())
for k, v in weights_dict.items()])
- def create_onnx_graph_proto(self, sym, params, in_shape, in_type,
verbose=False):
+ def create_onnx_graph_proto(self, sym, params, in_shape, in_type,
out_label=None, verbose=False):
"""Convert MXNet graph to ONNX graph
Parameters
@@ -202,6 +202,8 @@ def create_onnx_graph_proto(self, sym, params, in_shape,
in_type, verbose=False)
Input shape of the model e.g [(1,3,224,224)]
in_type : data type
Input data type e.g. np.float32
+ out_label : str
+ Optional output label
verbose : Boolean
If true will print logs of the model conversion
@@ -222,7 +224,10 @@ def create_onnx_graph_proto(self, sym, params, in_shape,
in_type, verbose=False)
# name is "Softmax", this node will have a name "Softmax_label". Also,
the new node
# will always be second last node in the json graph.
# Deriving the output_label name.
- output_label = sym.get_internals()[len(sym.get_internals()) - 1].name
+ "_label"
+ if not out_label:
+ output_label = sym.get_internals()[len(sym.get_internals()) -
1].name + "_label"
+ else:
+ output_label = out_label
# Determine output shape
output_shape = MXNetGraph.infer_output_shape(sym, params, in_shape,
output_label)
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services