This is an automated email from the ASF dual-hosted git repository. jxie pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 0fb57ff fix symbolblock save_params (#10748) 0fb57ff is described below commit 0fb57ff31ef5caa32edf973213bde8a8faba85e5 Author: Eric Junyuan Xie <piiswr...@users.noreply.github.com> AuthorDate: Mon May 14 22:31:05 2018 -0700 fix symbolblock save_params (#10748) * fix symbolblock save_params * fix --- python/mxnet/gluon/block.py | 14 ++++++++++++++ tests/python/unittest/test_gluon.py | 27 +++++++++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 7e41272..4779484 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -649,6 +649,18 @@ class HybridBlock(Block): # pylint: disable= invalid-name raise NotImplementedError +def _common_prefix(names): + """Get the common prefix for all names""" + if not names: + return '' + prefix = names[0] + for name in names: + i = 0 + while i < len(prefix) and i < len(name) and prefix[i] == name[i]: + i += 1 + prefix = prefix[:i] + return prefix + class SymbolBlock(HybridBlock): """Construct block from symbol. This is useful for using pre-trained models @@ -710,6 +722,8 @@ class SymbolBlock(HybridBlock): self.params.get(i, grad_req='null', allow_deferred_init=True) self._cached_graph = syms, out + len_prefix = len(_common_prefix(list(self._params.keys()))) + self._reg_params = {key[len_prefix:]: val for key, val in self._params.items()} def forward(self, x, *args): if isinstance(x, NDArray): diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index b054aa6..fb73e53 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -986,6 +986,33 @@ def test_save_load(): net.load_params('test.params') +def test_symbol_block_save_load(): + class Net(gluon.HybridBlock): + def __init__(self): + super(Net, self).__init__() + with self.name_scope(): + backbone = gluon.model_zoo.vision.resnet18_v1() + data = mx.sym.var('data') + featnames = ['stage1_activation0', 'stage2_activation0', 'stage3_activation0'] + out_names = ['_'.join([backbone.name, featname, 'output']) for featname in featnames] + internals = backbone(data).get_internals() + outs = [internals[out_name] for out_name in out_names] + self.backbone = gluon.SymbolBlock(outs, data, params=backbone.collect_params()) + self.body = nn.Conv2D(3, 1) + + def hybrid_forward(self, F, x): + x = self.body(x) + return self.backbone(x) + + net1 = Net() + net1.initialize(mx.init.Normal()) + net1.hybridize() + net1(mx.nd.random.normal(shape=(1, 3, 32, 32))) + net1.save_params('./test.params') + + net2 = Net() + net2.load_params('./test.params', ctx=mx.cpu()) + def test_hybrid_multi_context(): net = mx.gluon.model_zoo.vision.get_resnet(1, 18) -- To stop receiving notification emails like this one, please contact j...@apache.org.