This is an automated email from the ASF dual-hosted git repository.

sxjscience pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 706c369  fix trainer when the model involves share_parameters (#18880)
706c369 is described below

commit 706c369fd0b7eb302d81bf2c7862baaea7eeed30
Author: Ziyue Huang <[email protected]>
AuthorDate: Sun Aug 9 07:55:16 2020 +0800

    fix trainer when the model involves share_parameters (#18880)
    
    * fix trainer when using shared_param
    
    * add unittest
---
 python/mxnet/gluon/trainer.py               | 10 ++++----
 tests/python/unittest/test_gluon_trainer.py | 40 +++++++++++++++++++++++++++++
 2 files changed, 45 insertions(+), 5 deletions(-)

diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index 0229e9b..7f1a7d0 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -398,25 +398,25 @@ class Trainer(object):
             return
         for i, param in enumerate(self._params):
             if param.grad_req != 'null':
-
+                idx = self._param2idx[param._uuid]
                 grad_list = param.list_grad()
                 # sparse gradients, call push and pull separately
                 if grad_list[0].stype != 'default':
-                    self._kvstore.push(i, grad_list, priority=-i)
+                    self._kvstore.push(idx, grad_list, priority=-i)
                     if param._stype == 'default':
                         if self._update_on_kvstore:
                             pull_list = param.list_data()
                         else:
                             pull_list = param.list_grad()
-                        self._kvstore.pull(i, pull_list, priority=-i,
+                        self._kvstore.pull(idx, pull_list, priority=-i,
                                            ignore_sparse=self._distributed)
                 else:
                     # allreduce dense gradients if not update_on_kvstore,
                     # otherwise push dense gradients, pull dense weights
                     if self._update_on_kvstore:
-                        self._kvstore.pushpull(i, grad_list, out=param.list_data(), priority=-i)
+                        self._kvstore.pushpull(idx, grad_list, out=param.list_data(), priority=-i)
                     else:
-                        self._kvstore.pushpull(i, grad_list, priority=-i)
+                        self._kvstore.pushpull(idx, grad_list, priority=-i)
 
     def update(self, batch_size, ignore_stale_grad=False):
         """Makes one step of parameter update.
diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py
index a83d104..5c94fc8 100644
--- a/tests/python/unittest/test_gluon_trainer.py
+++ b/tests/python/unittest/test_gluon_trainer.py
@@ -359,3 +359,43 @@ def test_trainer_allreduce_hybridsequential():
             out = net(mx.nd.ones((1, 1), ctx=ctx))
         out.backward()
     trainer.allreduce_grads()
+
+
+def test_trainer_share_parameters():
+    class Net(gluon.Block):
+        def __init__(self, **kwargs):
+            super(Net, self).__init__(**kwargs)
+            self.dense1 = gluon.nn.Dense(5, in_units=2, use_bias=False)
+            params = self.dense1.collect_params()
+            self.dense2 = gluon.nn.Dense(5, in_units=2,
+                                         use_bias=False).share_parameters(params)
+            self.dense3 = gluon.nn.Dense(5, in_units=5, use_bias=False)
+
+        def forward(self, x):
+            hidden = self.dense1(x) + self.dense2(x)
+            out = self.dense3(hidden)
+            return out
+
+    net = Net()
+    ctxes = [mx.cpu(0), mx.cpu(1)]
+    net.initialize(mx.init.One(), ctx=ctxes)
+    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 1})
+    data = mx.nd.array([[1, 1], [1, 1]])
+    xs = gluon.utils.split_and_load(data, ctxes)
+    ys = []
+    with mx.autograd.record():
+        for x in xs:
+            y = net(x)
+            ys.append(y)
+    for y in ys:
+        y.backward()
+    trainer.step(1)
+    params = net.collect_params()
+    shared_params = []
+    for param in params.values():
+        p = param.data(mx.cpu(0)).asnumpy()
+        if p.shape[1] == 2:
+            shared_params.append(p)
+
+    assert((shared_params[0] == shared_params[1]).all())
+

Reply via email to