marcoabreu closed pull request #12350: fix bug in 'device' type kvstore
URL: https://github.com/apache/incubator-mxnet/pull/12350
This is a PR merged from a forked repository. Because GitHub hides the
original diff on merge, and because the diff of a foreign pull request
(one from a fork) would not otherwise be shown, it is reproduced below
for the sake of provenance:
diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h
index 34cab3037ce..61370a5bfaf 100644
--- a/src/kvstore/comm.h
+++ b/src/kvstore/comm.h
@@ -459,6 +459,7 @@ class CommDevice : public Comm {
void Init(int key, const NDArrayStorageType stype, const TShape& shape,
int dtype = mshadow::kFloat32) override {
sorted_key_attrs_.emplace_back(key, shape, dtype);
+ inited_ = false;
}
void InitBuffersAndComm(const std::vector<NDArray>& src) {
@@ -701,8 +702,10 @@ class CommDevice : public Comm {
}
// Delayed allocation - as the dense merged buffer might not be used at all if push()
// only sees sparse arrays
- bool delay_alloc = true;
- buf.merged = NDArray(shape, ctx, delay_alloc, type);
+ if (buf.merged.is_none()) {
+ bool delay_alloc = true;
+ buf.merged = NDArray(shape, ctx, delay_alloc, type);
+ }
ctx_info[ctx.dev_id].second += shape.Size();
}
inited_ = true;
diff --git a/tests/python/unittest/test_kvstore.py b/tests/python/unittest/test_kvstore.py
index 921a5704d54..28d4ec262c0 100644
--- a/tests/python/unittest/test_kvstore.py
+++ b/tests/python/unittest/test_kvstore.py
@@ -106,6 +106,23 @@ def check_init(kv, key):
check_init(mx.kv.create(), 3)
check_init(mx.kv.create(), 'a')
+@with_seed()
+def test_pull():
+ """test pull"""
+ def check_pull(kv):
+ a = mx.nd.ones(shape)
+ b = mx.nd.zeros(shape)
+ kv.init('1', mx.nd.zeros(shape))
+ kv.push('1', [a,a,a,a])
+ kv.pull('1', b)
+ check_diff_to_scalar(b, 4)
+ kv.init('2', mx.nd.zeros(shape))
+ kv.pull('2', b)
+ check_diff_to_scalar(b, 0)
+
+ check_pull(mx.kv.create('device'))
+ check_pull(mx.kv.create())
+
@with_seed()
def test_list_kv_pair():
"""list key-value pair push & pull"""
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services