Here's a complete, working solution for the v1.8.x branch. Note that the new `save` and
`load` methods export and reload the full model: both the model
architecture and its parameters.
```
import mxnet as mx
import json

class MyBlock(mx.gluon.nn.Block):
    """Sequential container that chains its children and can persist itself.

    Children run in insertion order; each child's output is fed to the next
    child. ``save``/``load`` write and restore both the cached symbolic graphs
    of any hybridized descendants and all parameter values.
    """

    def __init__(self, **kwargs):
        super(MyBlock, self).__init__(**kwargs)

    def add(self, block):
        """Append ``block`` as a child, keyed by its name plus its position."""
        self._children[block.name + str(len(self._children))] = block

    def forward(self, x, *args):
        """Run ``(x, *args)`` through every child in order.

        Returns the last child's output, or the input tuple ``(x, *args)``
        when there are no children.
        """
        out = (x,) + args
        for block in self._children.values():
            # A child that returns a single array must be re-wrapped: ``*out``
            # on a bare NDArray would unpack the array itself (splitting it
            # along its first axis) instead of passing it as one argument.
            if not isinstance(out, (tuple, list)):
                out = (out,)
            out = block(*out)
        return out

    def save(self, prefix):
        """Export the block tree to ``<prefix>-model.json`` / ``-model.params``.

        The JSON records, for every block visited depth-first, its name, the
        key each child had in the block's ``_children`` map and, for
        HybridBlocks that have a cached graph, the in/out formats plus the
        cached input and output symbols. Entries are keyed by lower-cased
        type name + depth-first visit index, which is how ``load`` finds them.

        Args:
            prefix: path prefix for the two output files.
        """
        model = {}

        def _save_cached_graphs(blk, index, structure):
            # create this block's entry under a type+visit-index key
            mdl = {'orig_name': blk.name}
            name = type(blk).__name__.lower()
            structure[name + str(index[0])] = mdl
            # _cached_graph is presumably empty (falsy) until the block has
            # been hybridized and called; there is nothing to export then.
            if isinstance(blk, mx.gluon.nn.HybridBlock) and blk._cached_graph:
                mdl['in_format'] = blk._in_format
                mdl['out_format'] = blk._out_format
                syms, out = blk._cached_graph
                mdl['inputs'] = [sym.tojson() for sym in syms]
                mdl['symbol'] = out.tojson()
            children = {}
            mdl['children'] = children
            # recursively save children, remembering each one's key here
            for ch_name, child in blk._children.items():
                index[0] += 1
                children[child.name] = ch_name
                _save_cached_graphs(child, index, mdl)

        _save_cached_graphs(self, [0], model)
        # `with` guarantees the file is closed even if json.dump raises
        with open(prefix + '-model.json', 'w') as fp:
            json.dump(model, fp)
        self.save_parameters(prefix + '-model.params')

    def load(self, prefix):
        """Restore a block tree previously written by ``save``.

        ``self`` must have the same structure as the saved model (same block
        types in the same depth-first order), since entries are matched by
        type name + visit index. Each block is renamed to its saved name,
        saved cached graphs are reinstalled, parameter and child keys are
        remapped to the saved names, then parameter values are loaded.

        Args:
            prefix: path prefix used when the model was saved.

        Raises:
            KeyError: if the saved JSON does not match this block's structure.
        """
        with open(prefix + '-model.json') as fp:
            model = json.load(fp)

        def _load_cached_graphs(blk, index, log):
            # look up the entry saved for this block type + visit index
            name = type(blk).__name__.lower()
            mdl = log[name + str(index[0])]
            # rename block to what it was when saved
            blk._name = mdl['orig_name']
            # only blocks that had a cached graph at save time carry 'symbol'
            if isinstance(blk, mx.gluon.nn.HybridBlock) and 'symbol' in mdl:
                blk._in_format = mdl['in_format']
                blk._out_format = mdl['out_format']
                out = mx.sym.load_json(mdl['symbol'])
                syms = [mx.sym.load_json(inp) for inp in mdl['inputs']]
                # reinstall the cached graph and mark the block active
                blk._cached_graph = (syms, out)
                blk._active = True
            # Rename params to use the restored block name. Rebuilding the
            # dict in one pass keeps the order and means a new key can never
            # overwrite a not-yet-renamed one. Updating the prefix keeps it
            # consistent with the keys, so a repeated load slices correctly.
            pdict = blk.params
            old_prefix = pdict._prefix
            new_prefix = blk.name + '_'
            renamed = [(new_prefix + p[len(old_prefix):], param)
                       for p, param in pdict._params.items()]
            pdict._params.clear()
            pdict._params.update(renamed)
            pdict._prefix = new_prefix
            # recursively reload children (same depth-first order as save)
            for child in blk._children.values():
                index[0] += 1
                _load_cached_graphs(child, index, mdl)
            # re-key children with their saved keys, again in a single pass
            saved_keys = mdl['children']
            rekeyed = [(saved_keys[child.name], child)
                       for child in blk._children.values()]
            blk._children.clear()
            blk._children.update(rekeyed)

        _load_cached_graphs(self, [0], model)
        self.load_parameters(prefix + '-model.params')

def createNet():
    """Build the demo network.

    The top-level MyBlock contains a nested MyBlock (wrapping one Dense(10))
    followed by a second Dense(10).
    """
    # NOTE: keep the construction order unchanged; auto-generated block names
    # (which MyBlock.add uses as keys) presumably depend on it.
    nested = MyBlock()
    nested.add(mx.gluon.nn.Dense(10))
    top = MyBlock()
    top.add(nested)
    tail = mx.gluon.nn.Dense(10)
    top.add(tail)
    return top

# create and initialize model
net = createNet()
net.initialize()
# Use a deterministic input: mx.nd.empty returns uninitialized memory, so the
# outputs could hold garbage/NaN and could not be compared after reload.
x = mx.nd.ones((1, 10))
# run first inference to test
out = net(x)
# hybridize (the hybridizable blocks, i.e. the Dense layers)
net.hybridize()
out = net(x)
# keep the hybridized output to compare the reloaded model against
expected = out

# save hybridized model
net.save('MyModel')

# create a new model, uninitialized
net = createNet()
# reload hybridized model
net.load('MyModel')
# run inference again
out = net(x)
# the reloaded model should reproduce the saved model's output
print('max abs difference after reload:',
      mx.nd.abs(out - expected).max().asscalar())
```

-- 
You are receiving this because you are subscribed to this thread.
Reply to this email directly or view it on GitHub:
https://github.com/apache/incubator-mxnet/issues/19535#issuecomment-727269397

Reply via email to