vandanavk closed pull request #12634: [WIP] Add optional output label in ONNX export
URL: https://github.com/apache/incubator-mxnet/pull/12634
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_model.py 
b/python/mxnet/contrib/onnx/mx2onnx/export_model.py
index e5158051d6f..a572cc20021 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/export_model.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_model.py
@@ -32,7 +32,7 @@
 from ._export_helper import load_module
 
 
-def export_model(sym, params, input_shape, input_type=np.float32,
+def export_model(sym, params, input_shape, input_type=np.float32, 
out_label=None,
                  onnx_file_path='model.onnx', verbose=False):
     """Exports the MXNet model file, passed as a parameter, into ONNX model.
     Accepts both symbol,parameter objects as well as json and params filepaths 
as input.
@@ -49,6 +49,8 @@ def export_model(sym, params, input_shape, 
input_type=np.float32,
         Input shape of the model e.g [(1,3,224,224)]
     input_type : data type
         Input data type e.g. np.float32
+    out_label : str
+        custom output node label
     onnx_file_path : str
         Path where to save the generated onnx file
     verbose : Boolean
@@ -75,10 +77,12 @@ def export_model(sym, params, input_shape, 
input_type=np.float32,
         sym_obj, params_obj = load_module(sym, params)
         onnx_graph = converter.create_onnx_graph_proto(sym_obj, params_obj, 
input_shape,
                                                        
mapping.NP_TYPE_TO_TENSOR_TYPE[data_format],
+                                                       out_label,
                                                        verbose=verbose)
     elif isinstance(sym, symbol.Symbol) and isinstance(params, dict):
         onnx_graph = converter.create_onnx_graph_proto(sym, params, 
input_shape,
                                                        
mapping.NP_TYPE_TO_TENSOR_TYPE[data_format],
+                                                       out_label,
                                                        verbose=verbose)
     else:
         raise ValueError("Input sym and params should either be files or 
objects")
diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py 
b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
index 11847381ab2..cca0748def1 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
@@ -189,7 +189,7 @@ def convert_weights_to_numpy(weights_dict):
         return dict([(k.replace("arg:", "").replace("aux:", ""), v.asnumpy())
                      for k, v in weights_dict.items()])
 
-    def create_onnx_graph_proto(self, sym, params, in_shape, in_type, 
verbose=False):
+    def create_onnx_graph_proto(self, sym, params, in_shape, in_type, 
out_label=None, verbose=False):
         """Convert MXNet graph to ONNX graph
 
         Parameters
@@ -202,6 +202,8 @@ def create_onnx_graph_proto(self, sym, params, in_shape, 
in_type, verbose=False)
             Input shape of the model e.g [(1,3,224,224)]
         in_type : data type
             Input data type e.g. np.float32
+        out_label : str
+            Optional output label
         verbose : Boolean
             If true will print logs of the model conversion
 
@@ -222,7 +224,10 @@ def create_onnx_graph_proto(self, sym, params, in_shape, 
in_type, verbose=False)
         # name is "Softmax", this node will have a name "Softmax_label". Also, 
the new node
         # will always be second last node in the json graph.
         # Deriving the output_label name.
-        output_label = sym.get_internals()[len(sym.get_internals()) - 1].name 
+ "_label"
+        if not out_label:
+            output_label = sym.get_internals()[len(sym.get_internals()) - 
1].name + "_label"
+        else:
+            output_label = out_label
 
         # Determine output shape
         output_shape = MXNetGraph.infer_output_shape(sym, params, in_shape, 
output_label)


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to