This is an automated email from the ASF dual-hosted git repository.

wkcn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 391a1be  Set idx2name for Optimizer object (#14703)
391a1be is described below

commit 391a1be260eb75b437ebced6743647b8e9df7802
Author: Yuxi Hu <[email protected]>
AuthorDate: Thu Apr 18 23:04:46 2019 -0700

    Set idx2name for Optimizer object (#14703)
    
    * set idx2name for optimizer
    
    * add unit test
---
 python/mxnet/model.py                |  2 ++
 python/mxnet/module/module.py        | 16 +++++++++-------
 tests/python/unittest/test_module.py | 28 ++++++++++++++++++++++++++++
 3 files changed, 39 insertions(+), 7 deletions(-)
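
For context, Optimizer.idx2name maps each parameter's integer index to its
name, which the optimizer uses to resolve per-parameter settings (e.g.
lr_mult/wd_mult) by name. A minimal sketch of the attribute this patch
populates, with illustrative parameter names:

    import mxnet as mx

    # idx2name: parameter index -> parameter name, so per-parameter
    # attributes can be looked up by name instead of by opaque index.
    opt = mx.optimizer.SGD(learning_rate=0.1)
    opt.idx2name = {0: 'fc_weight', 1: 'fc_bias'}  # illustrative names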

diff --git a/python/mxnet/model.py b/python/mxnet/model.py
index efb5109..f44ff04 100644
--- a/python/mxnet/model.py
+++ b/python/mxnet/model.py
@@ -884,6 +884,8 @@ class FeedForward(BASE_ESTIMATOR):
                                    rescale_grad=(1.0/batch_size),
                                    **(self.kwargs))
         elif isinstance(self.optimizer, opt.Optimizer):
             optimizer = self.optimizer
+            if not optimizer.idx2name:
+                optimizer.idx2name = param_idx2name.copy()
 
         # do training
diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py
index a7d3336..e83751d 100644
--- a/python/mxnet/module/module.py
+++ b/python/mxnet/module/module.py
@@ -505,14 +505,14 @@ class Module(BaseModule):
             batch_size *= kvstore.num_workers
         rescale_grad = 1.0/batch_size
 
+        idx2name = {}
+        if update_on_kvstore:
+            idx2name.update(enumerate(self._exec_group.param_names))
+        else:
+            for k in range(len(self._context)):
+                idx2name.update({i*len(self._context)+k: n
+                                 for i, n in enumerate(self._exec_group.param_names)})
         if isinstance(optimizer, str):
-            idx2name = {}
-            if update_on_kvstore:
-                idx2name.update(enumerate(self._exec_group.param_names))
-            else:
-                for k in range(len(self._context)):
-                    idx2name.update({i*len(self._context)+k: n
-                                     for i, n in enumerate(self._exec_group.param_names)})
             optimizer_params = dict(optimizer_params)
             if 'rescale_grad' not in optimizer_params:
                 optimizer_params['rescale_grad'] = rescale_grad
@@ -528,6 +528,8 @@ class Module(BaseModule):
                     "is not normalized to 1.0/batch_size/num_workers (%s vs. 
%s). "%(
                         optimizer.rescale_grad, rescale_grad) +
                     "Is this intended?", stacklevel=2)
+            if not optimizer.idx2name:
+                optimizer.idx2name = idx2name.copy()
 
         self._optimizer = optimizer
         self._kvstore = kvstore
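
For reference, a small sketch of the index layout the hoisted block above
builds when updates run on the devices rather than on the kvstore;
param_names and num_ctx stand in for self._exec_group.param_names and
len(self._context):

    # Each parameter gets one optimizer slot per device.
    param_names = ['fc_weight', 'fc_bias']  # stand-in, illustrative
    num_ctx = 2                             # stand-in for len(self._context)

    idx2name = {}
    for k in range(num_ctx):
        idx2name.update({i * num_ctx + k: n
                         for i, n in enumerate(param_names)})
    # idx2name == {0: 'fc_weight', 1: 'fc_weight', 2: 'fc_bias', 3: 'fc_bias'}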
diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py
index 36c1993..c82afdf 100644
--- a/tests/python/unittest/test_module.py
+++ b/tests/python/unittest/test_module.py
@@ -931,6 +931,34 @@ def test_module_update_no_pragram():
     mod.update()
     assert(mod.get_outputs()[0].shape == data_shape)
 
+
+def test_module_init_optimizer():
+    def get_module_idx2name(mod):
+        idx2name = {}
+        idx2name.update(enumerate(mod._exec_group.param_names))
+        return idx2name
+
+    data = mx.sym.Variable('data')
+    sym = mx.sym.FullyConnected(data, num_hidden=20, name='fc')
+    batch_size = 8
+    opt_params = {'learning_rate': 1, 'rescale_grad': 1.0 / batch_size}
+
+    # Pass an optimizer str
+    mod1 = mx.mod.Module(sym, ('data',), None, context=mx.cpu(0))
+    mod1.bind(data_shapes=[('data', (batch_size, 20))])
+    mod1.init_params()
+    mod1.init_optimizer(optimizer='sgd', optimizer_params=opt_params)
+    assert mod1._optimizer.idx2name == get_module_idx2name(mod1)
+
+    # Pass an Optimizer object
+    mod2 = mx.mod.Module(sym, ('data',), None, context=mx.cpu(0))
+    mod2.bind(data_shapes=[('data', (batch_size, 20))])
+    mod2.init_params()
+    opt = mx.optimizer.SGD(**opt_params)
+    mod2.init_optimizer(optimizer=opt)
+    assert mod2._optimizer.idx2name == get_module_idx2name(mod2)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()

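As a usage note, the user-visible effect mirrors the new test: passing a
pre-built Optimizer object to Module.init_optimizer now leaves it with a
populated idx2name. A sketch, with illustrative shapes and names:

    import mxnet as mx

    data = mx.sym.Variable('data')
    sym = mx.sym.FullyConnected(data, num_hidden=20, name='fc')

    mod = mx.mod.Module(sym, ('data',), None, context=mx.cpu(0))
    mod.bind(data_shapes=[('data', (8, 20))])
    mod.init_params()

    opt = mx.optimizer.SGD(learning_rate=1, rescale_grad=1.0 / 8)
    mod.init_optimizer(optimizer=opt)
    # Before this change opt.idx2name stayed empty ({}); now it is
    # populated, e.g. {0: 'fc_weight', 1: 'fc_bias'}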