This is an automated email from the ASF dual-hosted git repository.
wkcn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 391a1be Set idx2name for Optimizer object (#14703)
391a1be is described below
commit 391a1be260eb75b437ebced6743647b8e9df7802
Author: Yuxi Hu <[email protected]>
AuthorDate: Thu Apr 18 23:04:46 2019 -0700
Set idx2name for Optimizer object (#14703)
* set idx2name for optimizer
* add unit test
---
python/mxnet/model.py | 2 ++
python/mxnet/module/module.py | 16 +++++++++-------
tests/python/unittest/test_module.py | 28 ++++++++++++++++++++++++++++
3 files changed, 39 insertions(+), 7 deletions(-)
diff --git a/python/mxnet/model.py b/python/mxnet/model.py
index efb5109..f44ff04 100644
--- a/python/mxnet/model.py
+++ b/python/mxnet/model.py
@@ -884,6 +884,8 @@ class FeedForward(BASE_ESTIMATOR):
rescale_grad=(1.0/batch_size),
**(self.kwargs))
elif isinstance(self.optimizer, opt.Optimizer):
+ if not self.optimizer.idx2name:  # local `optimizer` is only bound on the next line; use the caller's object
+ self.optimizer.idx2name = param_idx2name.copy()
optimizer = self.optimizer
# do training
diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py
index a7d3336..e83751d 100644
--- a/python/mxnet/module/module.py
+++ b/python/mxnet/module/module.py
@@ -505,14 +505,14 @@ class Module(BaseModule):
batch_size *= kvstore.num_workers
rescale_grad = 1.0/batch_size
+ idx2name = {}  # now built for both str and Optimizer inputs so either can receive names
+ if update_on_kvstore:
+ idx2name.update(enumerate(self._exec_group.param_names))  # one index per parameter
+ else:
+ for k in range(len(self._context)):
+ idx2name.update({i*len(self._context)+k: n  # one index per (parameter i, context k) pair
+ for i, n in enumerate(self._exec_group.param_names)})
if isinstance(optimizer, str):
- idx2name = {}
- if update_on_kvstore:
- idx2name.update(enumerate(self._exec_group.param_names))
- else:
- for k in range(len(self._context)):
- idx2name.update({i*len(self._context)+k: n
- for i, n in enumerate(self._exec_group.param_names)})
optimizer_params = dict(optimizer_params)
if 'rescale_grad' not in optimizer_params:
optimizer_params['rescale_grad'] = rescale_grad
@@ -528,6 +528,8 @@ class Module(BaseModule):
"is not normalized to 1.0/batch_size/num_workers (%s vs. %s). "%(
optimizer.rescale_grad, rescale_grad) +
"Is this intended?", stacklevel=2)
+ if not optimizer.idx2name:  # keep any mapping the caller already attached to the Optimizer
+ optimizer.idx2name = idx2name.copy()
self._optimizer = optimizer
self._kvstore = kvstore
diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py
index 36c1993..c82afdf 100644
--- a/tests/python/unittest/test_module.py
+++ b/tests/python/unittest/test_module.py
@@ -931,6 +931,34 @@ def test_module_update_no_pragram():
mod.update()
assert(mod.get_outputs()[0].shape == data_shape)
+
+def test_module_init_optimizer():  # NOTE(review): covers Module only; FeedForward.fit's Optimizer-object path is untested
+ def get_module_idx2name(mod):  # with a single context both idx2name branches reduce to enumerate(param_names)
+ idx2name = {}
+ idx2name.update(enumerate(mod._exec_group.param_names))
+ return idx2name
+
+ data = mx.sym.Variable('data')
+ sym = mx.sym.FullyConnected(data, num_hidden=20, name='fc')
+ batch_size = 8
+ opt_params = {'learning_rate': 1, 'rescale_grad': 1.0 / batch_size}
+
+ # Pass an optimizer str
+ mod1 = mx.mod.Module(sym, ('data',), None, context=mx.cpu(0))
+ mod1.bind(data_shapes=[('data', (batch_size, 20))])
+ mod1.init_params()
+ mod1.init_optimizer(optimizer='sgd', optimizer_params=opt_params)
+ assert mod1._optimizer.idx2name == get_module_idx2name(mod1)
+
+ # Pass an Optimizer object
+ mod2 = mx.mod.Module(sym, ('data',), None, context=mx.cpu(0))
+ mod2.bind(data_shapes=[('data', (batch_size, 20))])
+ mod2.init_params()
+ opt = mx.optimizer.SGD(**opt_params)
+ mod2.init_optimizer(optimizer=opt)  # opt has no idx2name yet, so init_optimizer must fill it in
+ assert mod2._optimizer.idx2name == get_module_idx2name(mod2)
+
+
if __name__ == '__main__':
import nose
nose.runmodule()