eric-haibin-lin opened a new issue #11638: distributed kvstore doesn't respect shape of ndarrays
URL: https://github.com/apache/incubator-mxnet/issues/11638

Steps to reproduce:

`test.py`
```
import mxnet as mx

shape = (4, 2)

def init_kv():
    """init kv """
    kv = mx.kv.create('dist')
    kv.init(3, mx.nd.ones(shape=shape))
    return kv

def test_updater():
    def check_updater(kv):
        # single
        key = 3
        vals = mx.nd.ones(shape)
        outs = mx.nd.empty(shape)
        kv.push(key, vals)
        kv.pull(key, out=outs)
        mx.nd.waitall()

    kv = init_kv()
    kv.set_optimizer(mx.optimizer.SGD())
    check_updater(kv)

test_updater()
```

Add the following assertion to `SGD:update_multi_precision()`
```
def update_multi_precision(self, index, weight, grad, state):
    assert weight.shape == (4,2), weight
    use_multi_precision = self.multi_precision and weight.dtype == numpy.float16
    self._update_impl(index, weight, grad, state, multi_precision=use_multi_precision)
```

Run with `python tools/launch.py --launcher=local -n 2 python test.py`
```
Traceback (most recent call last):
  File "_ctypes/callbacks.c", line 315, in 'calling callback function'
  File "/home/ubuntu/haibin2/python/mxnet/kvstore.py", line 83, in updater_handle
    updater(key, lhs, rhs)
  File "/home/ubuntu/haibin2/python/mxnet/optimizer.py", line 1484, in __call__
    self.optimizer.update_multi_precision(index, weight, grad, self.states[index])
  File "/home/ubuntu/haibin2/python/mxnet/optimizer.py", line 545, in update_multi_precision
    assert weight.shape == (4,2), weight
AssertionError: (8L,)
```
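
For reference, the `(8L,)` in the assertion message has the same element count as the `(4, 2)` array passed to `kv.init`, which suggests the weight reaches the updater flattened to 1-D rather than with a different size. The snippet below is only a sketch of that shape arithmetic in plain Python; `flattened_shape` and `is_flattened_version` are hypothetical helper names and are not part of MXNet.

```python
from functools import reduce
from operator import mul


def flattened_shape(shape):
    """Return the 1-D shape with the same number of elements as `shape`."""
    return (reduce(mul, shape, 1),)


def is_flattened_version(expected_shape, observed_shape):
    """True if `observed_shape` is `expected_shape` collapsed to one dimension."""
    expected = tuple(expected_shape)
    observed = tuple(observed_shape)
    return observed != expected and observed == flattened_shape(expected)


expected = (4, 2)   # shape passed to kv.init() in test.py
observed = (8,)     # shape reported by the assertion in the updater

print(flattened_shape(expected))                 # (8,)
print(is_flattened_version(expected, observed))  # True
```

A check like this could be used inside the assertion to tell a flattened weight apart from one whose element count is actually wrong.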
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]

With regards,
Apache Git Services
