This is an automated email from the ASF dual-hosted git repository.

marcoabreu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 6ca9092  fix bug in 'device' type kvstore (#12350)
6ca9092 is described below

commit 6ca909270d5a7f8af280ad83c8aa4862ba819d04
Author: solin319 <[email protected]>
AuthorDate: Thu Aug 30 17:25:00 2018 +0800

    fix bug in 'device' type kvstore (#12350)
    
    * fix bug in 'device' type kvstore
    
    When a key is initialized after another key has already been pushed, the new key has no merged_buf_ in file 'comm.h', yet inited_ is already true, so the new key cannot be pulled.
    ```
    import mxnet as mx
    a=mx.nd.array([1,2,3], ctx=mx.gpu(0))
    b=mx.nd.array([0,0,0], ctx=mx.gpu(0))
    kv=mx.kv.create('device')
    kv.init('1', a)
    kv.push('1', [a,a,a,a])
    kv.pull('1', b)
    kv.init('2', a)
    kv.pull('2', b)
    ```
    
    * add kvstore pull test
---
 src/kvstore/comm.h                    |  7 +++++--
 tests/python/unittest/test_kvstore.py | 17 +++++++++++++++++
 2 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h
index 34cab30..61370a5 100644
--- a/src/kvstore/comm.h
+++ b/src/kvstore/comm.h
@@ -459,6 +459,7 @@ class CommDevice : public Comm {
   void Init(int key, const NDArrayStorageType stype, const TShape& shape,
             int dtype = mshadow::kFloat32) override {
     sorted_key_attrs_.emplace_back(key, shape, dtype);
+    inited_ = false;
   }
 
   void InitBuffersAndComm(const std::vector<NDArray>& src) {
@@ -701,8 +702,10 @@ class CommDevice : public Comm {
       }
       // Delayed allocation - as the dense merged buffer might not be used at all if push()
       // only sees sparse arrays
-      bool delay_alloc = true;
-      buf.merged = NDArray(shape, ctx, delay_alloc, type);
+      if (buf.merged.is_none()) {
+        bool delay_alloc = true;
+        buf.merged = NDArray(shape, ctx, delay_alloc, type);
+      }
       ctx_info[ctx.dev_id].second += shape.Size();
     }
     inited_ = true;
diff --git a/tests/python/unittest/test_kvstore.py b/tests/python/unittest/test_kvstore.py
index 921a570..28d4ec2 100644
--- a/tests/python/unittest/test_kvstore.py
+++ b/tests/python/unittest/test_kvstore.py
@@ -107,6 +107,23 @@ def test_init():
     check_init(mx.kv.create(), 'a')
 
 @with_seed()
+def test_pull():
+    """test pull"""
+    def check_pull(kv):
+        a = mx.nd.ones(shape)
+        b = mx.nd.zeros(shape)
+        kv.init('1', mx.nd.zeros(shape))
+        kv.push('1', [a,a,a,a])
+        kv.pull('1', b)
+        check_diff_to_scalar(b, 4)
+        kv.init('2', mx.nd.zeros(shape))
+        kv.pull('2', b)
+        check_diff_to_scalar(b, 0)
+
+    check_pull(mx.kv.create('device'))
+    check_pull(mx.kv.create())
+
+@with_seed()
 def test_list_kv_pair():
     """list key-value pair push & pull"""
     def check_list_kv_pair(kv, key, stype):

Reply via email to