ptrendx commented on a change in pull request #13346: Aggregate SGD
URL: https://github.com/apache/incubator-mxnet/pull/13346#discussion_r244892351
 
 

 ##########
 File path: python/mxnet/optimizer/optimizer.py
 ##########
 @@ -522,39 +566,72 @@ def create_state(self, index, weight):
             momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=stype)
         return momentum
 
-    def _update_impl(self, index, weight, grad, state, multi_precision=False):
-        assert(isinstance(weight, NDArray))
-        assert(isinstance(grad, NDArray))
-        self._update_count(index)
-        lr = self._get_lr(index)
-        wd = self._get_wd(index)
+    def _update_impl(self, indices, weights, grads, states, multi_precision=False):
+        aggregate = True
+        if not isinstance(indices, (tuple, list)):
+            indices = [indices]
+            weights = [weights]
+            grads = [grads]
+            states = [states]
+        for weight, grad in zip(weights, grads):
+            assert(isinstance(weight, NDArray))
+            assert(isinstance(grad, NDArray))
+            aggregate = (aggregate and
 
 Review comment:
   @anirudhacharya As you can see, `aggregate` is set to `True` at the beginning and changes to `False` when a non-default storage type is encountered, so testing with both dense and sparse data exercises both branches of the code.
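
A minimal, framework-free sketch of the branch selection described above. `FakeArray` and `takes_aggregate_path` are hypothetical names, not MXNet API. The full condition on the truncated `aggregate = (aggregate and` line is not visible in the diff, so this sketch assumes it requires every weight and gradient to have `stype == 'default'`:

```python
from dataclasses import dataclass


@dataclass
class FakeArray:
    """Hypothetical stand-in for an NDArray; only the storage type matters here."""
    stype: str = "default"  # e.g. "default" (dense), "row_sparse", "csr"


def takes_aggregate_path(weights, grads):
    """Return True if the aggregated (multi-tensor) update path would be taken.

    Assumption: aggregation is only allowed when every weight and every
    gradient uses the default (dense) storage type.
    """
    # Mirror the diff: a single tensor is wrapped into a one-element list.
    if not isinstance(weights, (tuple, list)):
        weights = [weights]
        grads = [grads]
    aggregate = True  # start optimistic, as in the PR
    for weight, grad in zip(weights, grads):
        # Any non-default storage type flips the flag to False for good.
        aggregate = (aggregate and
                     weight.stype == "default" and
                     grad.stype == "default")
    return aggregate


# All-dense inputs take the aggregated branch...
print(takes_aggregate_path([FakeArray(), FakeArray()], [FakeArray(), FakeArray()]))
# ...while a single sparse gradient forces the per-tensor branch.
print(takes_aggregate_path([FakeArray(), FakeArray()], [FakeArray(), FakeArray("row_sparse")]))
```

Because the flag can only go from `True` to `False`, a test suite that covers only dense inputs never reaches the non-aggregated branch, and one that covers only sparse inputs never reaches the aggregated one.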

----------------------------------------------------------------
This is an automated message from the Apache Git Service.