spidyDev commented on a change in pull request #10605: [MXNET-310] [ONNX-MXNet] API to import ONNX models into Gluon.
URL: https://github.com/apache/incubator-mxnet/pull/10605#discussion_r182510986
 
 

 ##########
 File path: python/mxnet/contrib/onnx/_import/import_onnx.py
 ##########
 @@ -132,13 +133,71 @@ def from_onnx(self, graph):
             out = out[0]
         return out, argDict, auxDict
 
+    def get_graph_metadata(self, graph):
+        """
+        Get metadata from a given onnx graph.
+        """
+        _params = set()
+        for tensor_vals in graph.initializer:
+            _params.add(tensor_vals.name)
+
+        input_data = []
+        for graph_input in graph.input:
+            shape = []
+            if graph_input.name not in _params:
+                for val in graph_input.type.tensor_type.shape.dim:
+                    shape.append(val.dim_value)
+                input_data.append((graph_input.name, tuple(shape)))
+
+        output_data = []
+        for graph_out in graph.output:
+            shape = []
+            for val in graph_out.type.tensor_type.shape.dim:
+                shape.append(val.dim_value)
+            output_data.append((graph_out.name, tuple(shape)))
+        metadata = {'input_tensor_data' : input_data,
+                    'output_tensor_data' : output_data
+                   }
+        return metadata
+
+    def graph_to_gluon(self, graph):
+        """Construct SymbolBlock from onnx graph.
+
+        Parameters
+        ----------
+        graph : onnx protobuf object
+            The loaded onnx graph
+
+        Returns
+        -------
+        sym_block :gluon.nn.SymbolBlock
+            The returned gluon SymbolBlock
+        """
+        sym, arg_params, aux_params = self.from_onnx(graph)
+        metadata = self.get_graph_metadata(graph)
+        data_names = [input_tensor[0] for input_tensor in 
metadata['input_tensor_data']]
+        data_inputs = [symbol.var(data_name) for data_name in data_names]
+
+        from ....gluon import SymbolBlock
+        net = SymbolBlock(outputs=sym, inputs=data_inputs)
 
 Review comment:
   Add a few comments to explain the logic here.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to