samskalicky commented on issue #19535:
URL: 
https://github.com/apache/incubator-mxnet/issues/19535#issuecomment-727269397


   Here's a complete, working solution on the v1.8.x branch. Notice that the new `save` 
and `load` functions perform a full export/reload of both the model 
architecture and its parameters. 
   ```
   import mxnet as mx
   import json
   
   class MyBlock(mx.gluon.nn.Block):
       """Container Block that runs its children sequentially.

       Each child receives the previous child's output(s) as positional
       arguments. ``save``/``load`` persist both the parameters and the cached
       graphs of hybridized children, so a freshly constructed network built
       the same way can be restored without hybridizing and re-running it.
       """

       def __init__(self, **kwargs):
           super(MyBlock, self).__init__(**kwargs)

       def add(self, block):
           """Append ``block`` as a child, keyed by its name plus its position."""
           self._children[block.name + str(len(self._children))] = block

       def forward(self, x, *args):
           """Feed the inputs through each child in insertion order.

           A child that returns a tuple/list has its elements passed as separate
           positional arguments to the next child; any other return value (e.g.
           a single NDArray) is passed as one argument. Returns the last child's
           output unchanged, or ``(x,) + args`` when there are no children.
           """
           out = (x,) + args
           for block in self._children.values():
               # Wrap single outputs: ``*ndarray`` would otherwise unpack the
               # array along its first axis into separate arguments.
               if not isinstance(out, (tuple, list)):
                   out = (out,)
               out = block(*out)
           return out

       def save(self, prefix):
           """Export the model structure and parameters.

           Writes ``<prefix>-model.json`` (block names, child keys and, for
           hybrid blocks that have built one, the cached graph) and
           ``<prefix>-model.params`` (via ``save_parameters``). Blocks are keyed
           in the JSON by ``<lowercase class name><depth-first visit index>``,
           so ``load`` must be called on a network constructed the same way.
           """
           model = {}

           def _save_cached_graphs(blk, index, structure):
               # create a new entry for this block, recording its current name
               mdl = {'orig_name': blk.name}
               # encode a unique key from the block type and visit index
               name = type(blk).__name__.lower()
               structure[name + str(index[0])] = mdl
               # Only hybrid blocks that actually built a cached graph (were
               # hybridized and run) have one; otherwise ``_cached_graph`` is
               # empty and unpacking it would fail.
               if isinstance(blk, mx.gluon.nn.HybridBlock) and blk._cached_graph:
                   mdl['in_format'] = blk._in_format
                   mdl['out_format'] = blk._out_format
                   syms, out = blk._cached_graph
                   mdl['inputs'] = [sym.tojson() for sym in syms]
                   mdl['symbol'] = out.tojson()
               # map each child's current name to its key in this block
               children = {}
               mdl['children'] = children
               for ch_name, child in blk._children.items():
                   index[0] += 1
                   children[child.name] = ch_name
                   _save_cached_graphs(child, index, mdl)

           # index is a one-element list so the nested function can mutate it
           _save_cached_graphs(self, [0], model)
           with open(prefix + '-model.json', 'w') as fp:
               json.dump(model, fp)
           self.save_parameters(prefix + '-model.params')

       def load(self, prefix):
           """Restore a model written by ``save`` into this network.

           Renames blocks, their parameters and child keys back to the names
           recorded at save time, reinstalls saved cached graphs on hybrid
           blocks (marking them active), then calls ``load_parameters`` on
           ``<prefix>-model.params``.
           """
           with open(prefix + '-model.json') as fp:
               model = json.load(fp)

           def _load_cached_graphs(blk, index, log):
               # look up the saved entry by block type and visit index
               name = type(blk).__name__.lower()
               mdl = log[name + str(index[0])]
               # rename the block to what it was when saved
               blk._name = mdl['orig_name']
               if isinstance(blk, mx.gluon.nn.HybridBlock) and 'symbol' in mdl:
                   blk._in_format = mdl['in_format']
                   blk._out_format = mdl['out_format']
                   out = mx.sym.load_json(mdl['symbol'])
                   syms = [mx.sym.load_json(inp) for inp in mdl['inputs']]
                   # reinstall the cached graph and mark the block hybridized
                   blk._cached_graph = (syms, out)
                   blk._active = True
               # Re-key parameters under the restored block name so they match
               # the variable names in the saved symbol. Rebuild the dict rather
               # than pop/insert in place, which could clobber an unprocessed key.
               params = blk.params._params
               prefix_len = len(blk.params._prefix)
               renamed = [(blk.name + '_' + p[prefix_len:], param)
                          for p, param in params.items()]
               params.clear()
               for new_name, param in renamed:
                   params[new_name] = param
               # recursively reload children
               for child in blk._children.values():
                   index[0] += 1
                   _load_cached_graphs(child, index, mdl)
               # Re-key children under their save-time keys, looked up by the
               # (now restored) child name; rebuilding preserves insertion order
               # and avoids old/new key collisions.
               saved_keys = mdl['children']
               remapped = [(saved_keys[child.name], child)
                           for child in blk._children.values()]
               blk._children.clear()
               for orig_name, child in remapped:
                   blk._children[orig_name] = child

           # index is a one-element list so the nested function can mutate it
           _load_cached_graphs(self, [0], model)
           self.load_parameters(prefix + '-model.params')
   
   def createNet():
       """Build the demo network: MyBlock(MyBlock(Dense(10)), Dense(10)).

       Blocks are created in a fixed order (inner container, its Dense, outer
       container, outer Dense) because their auto-generated names depend on it.
       """
       inner = MyBlock()
       inner.add(mx.gluon.nn.Dense(10))
       outer = MyBlock()
       for child in (inner, mx.gluon.nn.Dense(10)):
           outer.add(child)
       return outer
   
   # create and initialize model
   net = createNet()
   net.initialize()
   # deterministic input; mx.nd.empty would feed uninitialized memory
   x = mx.nd.ones((1, 10))
   # run a first (imperative) inference to test
   out = net(x)
   # hybridize (the hybridizable blocks, i.e. the Dense layers) and run again
   # so their cached graphs are built before saving
   net.hybridize()
   out = net(x)
   expected = out

   # save hybridized model
   net.save('MyModel')

   # create a new model with the same structure, uninitialized
   net = createNet()
   # reload hybridized model
   net.load('MyModel')
   # run inference again and confirm the reloaded model reproduces the output
   out = net(x)
   max_diff = mx.nd.abs(out - expected).max().asscalar()
   if max_diff > 1e-5:
       raise RuntimeError('reloaded model output differs by %g' % max_diff)
   ```


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to