piiswrong closed pull request #10748: fix symbolblock save_params
URL: https://github.com/apache/incubator-mxnet/pull/10748
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would not otherwise display it:

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 7e4127250a0..4779484ec3e 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -649,6 +649,18 @@ def hybrid_forward(self, F, x, *args, **kwargs):
         # pylint: disable= invalid-name
         raise NotImplementedError
 
+def _common_prefix(names):
+    """Get the common prefix for all names"""
+    if not names:
+        return ''
+    prefix = names[0]
+    for name in names:
+        i = 0
+        while i < len(prefix) and i < len(name) and prefix[i] == name[i]:
+            i += 1
+        prefix = prefix[:i]
+    return prefix
+
 
 class SymbolBlock(HybridBlock):
     """Construct block from symbol. This is useful for using pre-trained models
@@ -710,6 +722,8 @@ def __init__(self, outputs, inputs, params=None):
                 self.params.get(i, grad_req='null', allow_deferred_init=True)
 
         self._cached_graph = syms, out
+        len_prefix = len(_common_prefix(list(self._params.keys())))
+        self._reg_params = {key[len_prefix:]: val for key, val in self._params.items()}
 
     def forward(self, x, *args):
         if isinstance(x, NDArray):
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index b054aa6555f..fb73e53bc05 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -986,6 +986,33 @@ def test_save_load():
 
     net.load_params('test.params')
 
+def test_symbol_block_save_load():
+    class Net(gluon.HybridBlock):
+        def __init__(self):
+            super(Net, self).__init__()
+            with self.name_scope():
+                backbone = gluon.model_zoo.vision.resnet18_v1()
+                data = mx.sym.var('data')
+                featnames = ['stage1_activation0', 'stage2_activation0', 'stage3_activation0']
+                out_names = ['_'.join([backbone.name, featname, 'output']) for featname in featnames]
+                internals = backbone(data).get_internals()
+                outs = [internals[out_name] for out_name in out_names]
+                self.backbone = gluon.SymbolBlock(outs, data, params=backbone.collect_params())
+                self.body = nn.Conv2D(3, 1)
+
+        def hybrid_forward(self, F, x):
+            x = self.body(x)
+            return self.backbone(x)
+
+    net1 = Net()
+    net1.initialize(mx.init.Normal())
+    net1.hybridize()
+    net1(mx.nd.random.normal(shape=(1, 3, 32, 32)))
+    net1.save_params('./test.params')
+
+    net2 = Net()
+    net2.load_params('./test.params', ctx=mx.cpu())
+
 
 def test_hybrid_multi_context():
     net = mx.gluon.model_zoo.vision.get_resnet(1, 18)


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to