This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new be0579c  refactor gluon trainer (#7338)
be0579c is described below

commit be0579ce7519cd910dab8c0261df7212d155d0b1
Author: Eric Junyuan Xie <piiswr...@users.noreply.github.com>
AuthorDate: Sat Aug 5 21:28:24 2017 -0700

    refactor gluon trainer (#7338)
    
    * fix optimizer
    
    * Update trainer.py
---
 python/mxnet/gluon/parameter.py                    |  2 +-
 python/mxnet/gluon/trainer.py                      | 59 +++++++++++++---------
 python/mxnet/optimizer.py                          | 12 +++--
 .../python/unittest/{test_nn.py => test_gluon.py}  | 23 +++++++++
 4 files changed, 69 insertions(+), 27 deletions(-)

diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index 0ae829a..bdc9674 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -361,7 +361,7 @@ class ParameterDict(object):
     """
     def __init__(self, prefix='', shared=None):
         self._prefix = prefix
-        self._params = {}
+        self._params = OrderedDict()
         self._shared = shared
 
     def __getitem__(self, key):
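
A side note on the OrderedDict change above: Trainer assigns integer kvstore keys by enumerating the parameters, so ParameterDict iteration order needs to be deterministic. A minimal illustrative sketch (the names and placeholder values below are not from this commit):

    from collections import OrderedDict

    # ParameterDict now backs its storage with an OrderedDict, so iterating it
    # (and hence enumerate() over the parameters in Trainer) follows insertion
    # order on every run.
    params = OrderedDict()
    params['dense0_weight'] = 'param0'   # placeholders standing in for Parameter objects
    params['dense0_bias'] = 'param1'

    for idx, name in enumerate(params):
        print(idx, name)   # 0 dense0_weight, then 1 dense0_bias
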
diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index 5483f6b..e8aae71 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -15,14 +15,19 @@ class Trainer(object):
     params : ParameterDict
         The set of parameters to optimize.
     optimizer : str or Optimizer
-        The optimizer to use.
+        The optimizer to use. See
+        `help <http://mxnet.io/api/python/optimization.html#the-mxnet-optimizer-package>`_
+        on Optimizer for a list of available optimizers.
     optimizer_params : dict
         Key-word arguments to be passed to optimizer constructor. For example,
-        `{'learning_rate': 0.1}`
+        `{'learning_rate': 0.1}`. All optimizers accept learning_rate, wd (weight decay),
+        clip_gradient, and lr_scheduler. See each optimizer's
+        constructor for a list of additional supported arguments.
     kvstore : str or KVStore
-        kvstore type for multi-gpu and distributed training.
+        kvstore type for multi-gpu and distributed training. See help on
+        :any:`mxnet.kvstore.create` for more information.
     """
-    def __init__(self, params, optimizer, optimizer_params, kvstore='device'):
+    def __init__(self, params, optimizer, optimizer_params=None, kvstore='device'):
         if isinstance(params, (dict, ParameterDict)):
             params = list(params.values())
         if not isinstance(params, (list, tuple)):
@@ -35,9 +40,9 @@ class Trainer(object):
                 raise ValueError(
                     "First argument must be a list or dict of Parameters, " \
                     "got list of %s."%(type(param)))
-            if param.grad_req != 'null':
-                self._params.append(param)
+            self._params.append(param)
 
+        optimizer_params = optimizer_params if optimizer_params else {}
         self._scale = optimizer_params.get('rescale_grad', 1.0)
         self._contexts = self._check_contexts()
         self._init_optimizer(optimizer, optimizer_params)
@@ -56,32 +61,39 @@ class Trainer(object):
         return contexts
 
     def _init_optimizer(self, optimizer, optimizer_params):
-        self._optimizer = opt.create(optimizer, **optimizer_params)
-
-        lr_mult = {}
-        wd_mult = {}
-        for i, param in enumerate(self._params):
-            lr_mult[i] = param.lr_mult
-            wd_mult[i] = param.wd_mult
-        self._optimizer.set_lr_mult(lr_mult)
-        self._optimizer.set_wd_mult(wd_mult)
+        param_dict = {i: param for i, param in enumerate(self._params)}
+        if isinstance(optimizer, opt.Optimizer):
+            assert not optimizer_params, \
+                "optimizer_params must be None if optimizer is an instance of 
" \
+                "Optimizer instead of str"
+            self._optimizer = optimizer
+            self._optimizer.param_dict = param_dict
+        else:
+            self._optimizer = opt.create(optimizer, param_dict=param_dict,
+                                         **optimizer_params)
 
         self._updaters = [opt.get_updater(self._optimizer) \
                             for _ in self._contexts]
 
     def _init_kvstore(self):
         arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params}
-        kvstore, update_on_kvstore = _create_kvstore(self._kvstore, len(self._contexts), arg_arrays)
-        self._kvstore = kvstore
-        self._update_on_kvstore = update_on_kvstore
+        kvstore, update_on_kvstore = _create_kvstore(self._kvstore, len(self._contexts),
+                                                     arg_arrays)
         if kvstore:
-            assert 'dist' not in self._kvstore.type, "distributed training not supported yet"
+            if 'dist' in kvstore.type:
+                update_on_kvstore = False
             for i, param in enumerate(self._params):
                 param_arrays = param.list_data()
                 kvstore.init(i, param_arrays[0])
                 kvstore.pull(i, param_arrays, priority=-i)
             if update_on_kvstore:
                 kvstore.set_optimizer(self._optimizer)
+            self._kvstore = kvstore
+            self._update_on_kvstore = update_on_kvstore
+        else:
+            self._kvstore = None
+            self._update_on_kvstore = None
+
         self._kv_initialized = True
 
     def step(self, batch_size, ignore_stale_grad=False):
@@ -103,9 +115,8 @@ class Trainer(object):
         self._optimizer.rescale_grad = self._scale / batch_size
 
         for i, param in enumerate(self._params):
-            assert param.list_ctx() == self._contexts, \
-                "Parameter %s's contexts changed after Optim initialization: " 
\
-                "was %s, now %s"%(param.name, self._contexts, param.list_ctx())
+            if param.grad_req == 'null':
+                continue
             if not ignore_stale_grad:
                 for data in param.list_data():
                     if not data._fresh_grad:
@@ -117,6 +128,7 @@ class Trainer(object):
                             "call step with ignore_stale_grad=True to suppress 
this "
                             "warning and skip updating of Parameters with 
stale gradient" \
                             %(param.name, str(data.context)))
+
             if self._kvstore:
                 self._kvstore.push(i, param.list_grad(), priority=-i)
                 if self._update_on_kvstore:
@@ -124,7 +136,8 @@ class Trainer(object):
                     continue
                 else:
                     self._kvstore.pull(i, param.list_grad(), priority=-i)
+
             for upd, arr, grad in zip(self._updaters, param.list_data(), param.list_grad()):
-                if arr._fresh_grad:
+                if not ignore_stale_grad or arr._fresh_grad:
                     upd(i, grad, arr)
                     arr._fresh_grad = False
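
For reference, a minimal usage sketch of the Trainer signature after this refactor (the tiny network and hyperparameters below are illustrative, not part of the commit): optimizer_params is now optional, and a pre-built Optimizer instance can be passed directly, in which case optimizer_params must be left unset.

    import mxnet as mx
    from mxnet import gluon

    net = gluon.nn.Dense(1)
    net.initialize(ctx=mx.cpu(0))

    # Pass the optimizer by name together with keyword arguments...
    trainer = gluon.Trainer(net.collect_params(), 'sgd',
                            {'learning_rate': 0.1, 'wd': 1e-4})

    # ...or pass an Optimizer instance; per the assert added above,
    # optimizer_params must then stay None.
    sgd = mx.optimizer.SGD(learning_rate=0.1)
    trainer = gluon.Trainer(net.collect_params(), sgd)
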
diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 57fadf4..934566e 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -43,7 +43,8 @@ class Optimizer(object):
     """
     def __init__(self, rescale_grad=1., param_idx2name=None, wd=0.,
                  clip_gradient=None, learning_rate=0.01,
-                 lr_scheduler=None, sym=None, begin_num_update=0):
+                 lr_scheduler=None, sym=None, begin_num_update=0,
+                 param_dict=None):
         self.rescale_grad = rescale_grad
         self.lr = learning_rate
         self.lr_scheduler = lr_scheduler
@@ -64,6 +65,7 @@ class Optimizer(object):
             'param_idx2name should be a dict of param indexes to names.'
         self.idx2name = param_idx2name.copy()
         self.sym = sym
+        self.param_dict = param_dict if param_dict else {}
 
         self.set_lr_mult({})
         self.set_wd_mult({})
@@ -277,7 +279,9 @@ class Optimizer(object):
         else:
             lr = self.lr
 
-        if index in self.lr_mult:
+        if index in self.param_dict:
+            lr *= self.param_dict[index].lr_mult
+        elif index in self.lr_mult:
             lr *= self.lr_mult[index]
         elif index in self.idx2name:
             lr *= self.lr_mult.get(self.idx2name[index], 1.0)
@@ -298,7 +302,9 @@ class Optimizer(object):
             Weight decay for this index.
         """
         wd = self.wd
-        if index in self.wd_mult:
+        if index in self.param_dict:
+            wd *= self.param_dict[index].wd_mult
+        elif index in self.wd_mult:
             wd *= self.wd_mult[index]
         elif index in self.idx2name:
             wd *= self.wd_mult.get(self.idx2name[index], 1.0)
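
To make the new lookup order in _get_lr/_get_wd concrete, here is a small self-contained sketch that mirrors the patched logic (FakeParam and get_lr are illustrative stand-ins, not MXNet APIs): an entry in param_dict takes precedence, then the explicit lr_mult dict, then the name-based idx2name lookup.

    class FakeParam(object):
        """Illustrative stand-in for gluon.Parameter with per-parameter multipliers."""
        def __init__(self, lr_mult=1.0, wd_mult=1.0):
            self.lr_mult = lr_mult
            self.wd_mult = wd_mult

    def get_lr(base_lr, index, param_dict, lr_mult, idx2name):
        # Mirrors the precedence introduced above in Optimizer._get_lr.
        lr = base_lr
        if index in param_dict:                  # per-parameter multiplier wins
            lr *= param_dict[index].lr_mult
        elif index in lr_mult:                   # then the explicit lr_mult dict
            lr *= lr_mult[index]
        elif index in idx2name:                  # finally the name-based lookup
            lr *= lr_mult.get(idx2name[index], 1.0)
        return lr

    # Parameter 0 carries lr_mult=0.5, which overrides the dict entry of 2.0.
    print(get_lr(0.1, 0, {0: FakeParam(lr_mult=0.5)}, {0: 2.0}, {}))  # 0.05
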
diff --git a/tests/python/unittest/test_nn.py b/tests/python/unittest/test_gluon.py
similarity index 93%
rename from tests/python/unittest/test_nn.py
rename to tests/python/unittest/test_gluon.py
index e293063..8256c71 100644
--- a/tests/python/unittest/test_nn.py
+++ b/tests/python/unittest/test_gluon.py
@@ -302,6 +302,29 @@ def test_flatten():
     assert flatten(x).shape == (3, 1)
 
 
+def test_trainer():
+    x = gluon.Parameter('x', shape=(10,))
+    x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
+    trainer = gluon.Trainer([x], 'sgd', {'learning_rate': 1.0})
+    with mx.autograd.record():
+        for w in x.list_data():
+            y = w + 1
+            y.backward()
+    trainer.step(1)
+
+    assert (x.data(mx.cpu(1)).asnumpy() == -2).all()
+
+    x.lr_mult = 0.5
+
+    with mx.autograd.record():
+        for w in x.list_data():
+            y = w + 1
+            y.backward()
+    trainer.step(1)
+
+    assert (x.data(mx.cpu(1)).asnumpy() == -3).all()
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()

-- 
To stop receiving notification emails like this one, please contact
['"comm...@mxnet.apache.org" <comm...@mxnet.apache.org>'].
