This is an automated email from the ASF dual-hosted git repository.
marcoabreu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 6ca9092 fix bug in 'device' type kvstore (#12350)
6ca9092 is described below
commit 6ca909270d5a7f8af280ad83c8aa4862ba819d04
Author: solin319 <[email protected]>
AuthorDate: Thu Aug 30 17:25:00 2018 +0800
fix bug in 'device' type kvstore (#12350)
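* fix bug in 'device' type kvstore
When a key is initialized after another key has already been pushed, the new
key has no merge buffer in 'comm.h', yet inited_ is already true, so the
buffers are never set up for it and the new key cannot be pulled. This change
makes Init() reset inited_ so buffer initialization runs again for newly added
keys, and keeps existing merge buffers instead of reallocating them.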
```
import mxnet as mx
a=mx.nd.array([1,2,3], ctx=mx.gpu(0))
b=mx.nd.array([0,0,0], ctx=mx.gpu(0))
kv=mx.kv.create('device')
kv.init('1', a)
kv.push('1', [a,a,a,a])
kv.pull('1', b)
kv.init('2', a)
kv.pull('2', b)
```
* add a kvstore pull test (test_pull)
---
src/kvstore/comm.h | 7 +++++--
tests/python/unittest/test_kvstore.py | 17 +++++++++++++++++
2 files changed, 22 insertions(+), 2 deletions(-)
diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h
index 34cab30..61370a5 100644
--- a/src/kvstore/comm.h
+++ b/src/kvstore/comm.h
@@ -459,6 +459,7 @@ class CommDevice : public Comm {
void Init(int key, const NDArrayStorageType stype, const TShape& shape,
int dtype = mshadow::kFloat32) override {
sorted_key_attrs_.emplace_back(key, shape, dtype);
+ inited_ = false;
}
void InitBuffersAndComm(const std::vector<NDArray>& src) {
@@ -701,8 +702,10 @@ class CommDevice : public Comm {
}
// Delayed allocation - as the dense merged buffer might not be used at all if push()
// only sees sparse arrays
- bool delay_alloc = true;
- buf.merged = NDArray(shape, ctx, delay_alloc, type);
+ if (buf.merged.is_none()) {
+ bool delay_alloc = true;
+ buf.merged = NDArray(shape, ctx, delay_alloc, type);
+ }
ctx_info[ctx.dev_id].second += shape.Size();
}
inited_ = true;
diff --git a/tests/python/unittest/test_kvstore.py b/tests/python/unittest/test_kvstore.py
index 921a570..28d4ec2 100644
--- a/tests/python/unittest/test_kvstore.py
+++ b/tests/python/unittest/test_kvstore.py
@@ -107,6 +107,23 @@ def test_init():
check_init(mx.kv.create(), 'a')
@with_seed()
+def test_pull():
+ """test pull"""
+ def check_pull(kv):
+ a = mx.nd.ones(shape)
+ b = mx.nd.zeros(shape)
+ kv.init('1', mx.nd.zeros(shape))
+ kv.push('1', [a,a,a,a])
+ kv.pull('1', b)
+ check_diff_to_scalar(b, 4)
+ kv.init('2', mx.nd.zeros(shape))
+ kv.pull('2', b)
+ check_diff_to_scalar(b, 0)
+
+ check_pull(mx.kv.create('device'))
+ check_pull(mx.kv.create())
+
+@with_seed()
def test_list_kv_pair():
"""list key-value pair push & pull"""
def check_list_kv_pair(kv, key, stype):