Here's a complete, working solution on the v1.8.x branch. Note that the new `save` and `load` methods perform a full export and reload of both the model architecture and its parameters.

```python
import mxnet as mx
import json
class MyBlock(mx.gluon.nn.Block): def __init__(self, **kwargs): super(MyBlock, self).__init__(**kwargs) def add(self, block): self._children[block.name + str(len(self._children))] = block def forward(self, x, *args): out = (x,) + args for block in self._children.values(): out = block(*out) return out def save(self, prefix): # create empty model structure model = {} def _save_cached_graphs(blk, index, structure): # create new entry for this block mdl = {'orig_name': blk.name} # encode unique name based on block type and ID name = type(blk).__name__.lower() structure[name+str(index[0])] = mdl if isinstance(blk, mx.gluon.nn.HybridBlock): # save in/out formats mdl['in_format'] = blk._in_format mdl['out_format'] = blk._out_format # save cached graph & input symbols syms, out = blk._cached_graph mdl_syms = [] for sym in syms: mdl_syms.append(sym.tojson()) mdl['inputs'] = mdl_syms mdl['symbol'] = out.tojson() children = dict() mdl['children'] = children # recursively save children for ch_name, child in blk._children.items(): index[0] += 1 # save child's original name in this block's map children[child.name] = ch_name _save_cached_graphs(child, index, mdl) # save top-level block index = [0] _save_cached_graphs(self, index, model) # save model fp = open(prefix+'-model.json','w') json.dump(model, fp) fp.close() # save params self.save_parameters(prefix+'-model.params') def load(self, prefix): # load model json from file fp = open(prefix+'-model.json') model = json.load(fp) fp.close() def _load_cached_graphs(blk, index, log): # get block name name = type(blk).__name__.lower() # lookup previous encoded name based on block type and ID mdl = log[name+str(index[0])] # rename block to what it was when saved blk._name = mdl['orig_name'] if isinstance(blk, mx.gluon.nn.HybridBlock): # restore in/out formats blk._in_format = mdl['in_format'] blk._out_format = mdl['out_format'] # get saved symbol out = mx.sym.load_json(mdl['symbol']) syms = [] # recreate inputs for this symbol for inp 
in mdl['inputs']: syms.append(mx.sym.load_json(inp)) # reset cached_graph and active status blk._cached_graph = (syms, out) blk._active = True # rename params with updated block name pnames = list(blk.params.keys()) for p in pnames: param = blk.params._params[p] new_name = blk.name +'_'+ p[len(blk.params._prefix):] blk.params._params.pop(p) blk.params._params[new_name] = param # recursively reload children for ch_name, child in blk._children.items(): index[0] += 1 _load_cached_graphs(child, index, mdl) # current set of child names ch_names = list(blk._children.keys()) # original child names children = mdl['children'] # loop and remap children with original names for ch_name in ch_names: child = blk._children[ch_name] blk._children.pop(ch_name) orig_name = children[child.name] blk._children[orig_name] = child # load top-level block index = [0] _load_cached_graphs(self, index, model) # load params self.load_parameters(prefix+'-model.params') def createNet(): inside = MyBlock() dense = mx.gluon.nn.Dense(10) inside.add(dense) net = MyBlock() net.add(inside) net.add(mx.gluon.nn.Dense(10)) return net # create and initialize model net = createNet() net.initialize() # run first inference to test x = mx.nd.empty((1,10)) out = net(x) # hybridize (the hybridizeable blocks, ie. the Dense layers) net.hybridize() out = net(x) # save hybridized model net.save('MyModel') # create a new model, uninitialized net = createNet() # reload hybridized model net.load('MyModel') # run inference again out = net(x) ``` -- You are receiving this because you are subscribed to this thread. Reply to this email directly or view it on GitHub: https://github.com/apache/incubator-mxnet/issues/19535#issuecomment-727269397