joddiy commented on issue #691:
URL: https://github.com/apache/singa/issues/691#issuecomment-629608607
```
# handle ONNX
def to_onnx(model):
return an onnx model
class SONNXModel(Module):
def __init__(self, onnx_model):
singa_rep = sonnx.prepare(onnx_model) # will update the prepare function to remove device and batchsize
for layer_name, layer in singa_rep.layers:
self.__dict__[layer_name] = layer
# store weights here as numpy
for weight_name, weight in singa_rep.weights:
self.weights[weight_name] = weight
# store layer info such as input and output name(only weights)
for layer_name, layer_info in singa_rep.layer_infos:
self.layer_infos[layer_name] = layer_info
def forward(self, aux_output):
# run forward according to onnx graph
return the last output + aux_output
def compile(self, inputs, is_train, use_graph, graph_alg):
# init weights
super.compile(self, inputs, is_train, use_graph, graph_alg)
# set weights' value
for layer_name, layer in self.__dict__:
input_info, output_info = self.layer_infos[layer_name]
for input_name in input_info:
layer.set_weight(self.weights[input_name]) # ** remember to release self.weights to free memory
class MyModel(SONNXModel):
def __init__(self, onnx):
super().__init__(onnx)
self.layer1 = Conv()
self.layer2 = Conv()
def forward(self, x):
x1, x2 = super.forward(x, aux_output)
x = self.layer1.forward(x2)
return self.layer2.forward(x1) + x
def train_one_batch(self, x, y):
y_ = self.forward(x)
....
ox = onnx.load(fpath)
x = Placeholder((2, 3), device = gpu, dtype=singa.float) # alias of Tensor
m = MyModel(ox)
# compatible with existing code which does not have the following two statements.
m.compile([x], is_train=True, use_graph=True, graph_alg='sequence')
y = Placeholder((2,), device = gpu)
for npx, npy in data:
x.copy_from(npx)
y.copy_from(npy)
m.train_one_batch(x, y) # build the graph in the first iter. For the old code, the params are initialized here.
```
Update the code according to the reviewer comments marked with `**`.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]