This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 7395be8  Fix row_sparse_pull with single gpu (#10772)
7395be8 is described below

commit 7395be8b54c83042b735a6d162afc08b6af21f5a
Author: Leonard Lausen <[email protected]>
AuthorDate: Thu May 3 20:17:04 2018 -0700

    Fix row_sparse_pull with single gpu (#10772)
    
    * Fix row_sparse_pull with single gpu
    
    * Add test
    
    * Fix row_sparse_pull with single gpu
    
    * Add test
    
    * fix sparse retain in comm.h
    
    * remove dedup var
    
    * update test
---
 src/kvstore/comm.h                   | 25 ++++++++++++++++---------
 tests/python/gpu/test_kvstore_gpu.py | 15 +++++++++++++++
 2 files changed, 31 insertions(+), 9 deletions(-)

diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h
index 9624899..70de79b 100644
--- a/src/kvstore/comm.h
+++ b/src/kvstore/comm.h
@@ -223,9 +223,12 @@ class CommCPU : public Comm {
       CHECK_EQ(row_id.ctx().dev_mask(), Context::kCPU)
                << "BroadcastRowSparse with row_indices on gpu context not 
supported";
       // retain according to unique indices
-      const bool is_to_gpu = out->ctx().dev_mask() == Context::kGPU;
-      NDArray retained_cpu = is_to_gpu ? NDArray(kRowSparseStorage, 
src.shape(),
-          src.ctx(), true, src.dtype(), src.aux_types()) : *out;
+      const bool is_same_ctx = out->ctx() == src.ctx();
+      const bool is_diff_var = out->var() != src.var();
+      NDArray retained_cpu = (is_same_ctx && is_diff_var) ? *out :
+          NDArray(kRowSparseStorage, src.shape(), src.ctx(), true,
+                  src.dtype(), src.aux_types());
+
       Engine::Get()->PushAsync(
         [=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
           const TBlob& indices = row_id.data();
@@ -565,14 +568,18 @@ class CommDevice : public Comm {
                << "BroadcastRowSparse expects row_sparse dst NDArray";
       CHECK_EQ(row_id.ctx(), src.ctx())
               << "row_id and src are expected to be on the same context";
+
       // retain according to indices
-      const bool is_diff_ctx = out->ctx() != src.ctx();
-      NDArray out_gpu = is_diff_ctx? NDArray(kRowSparseStorage, out->shape(),
-          src.ctx(), true, out->dtype(), out->aux_types()) : *out;
+      const bool is_same_ctx = out->ctx() == src.ctx();
+      const bool is_diff_var = out->var() != src.var();
+      NDArray retained_gpu = (is_same_ctx && is_diff_var) ? *out :
+          NDArray(kRowSparseStorage, out->shape(), src.ctx(), true,
+                  out->dtype(), out->aux_types());
+
       Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete 
on_complete) {
           const TBlob& indices = row_id.data();
           using namespace mxnet::common;
-          NDArray temp = out_gpu;
+          NDArray temp = retained_gpu;
           switch (temp.ctx().dev_mask()) {
             case cpu::kDevMask: {
               SparseRetainOpForwardRspWrapper<cpu>(rctx.get_stream<cpu>(),
@@ -591,9 +598,9 @@ class CommDevice : public Comm {
             default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
           }
           on_complete();
-        }, out_gpu.ctx(), {src.var(), row_id.var()}, {out_gpu.var()},
+        }, retained_gpu.ctx(), {src.var(), row_id.var()}, {retained_gpu.var()},
       FnProperty::kNormal, priority, "KVStoreSparseRetain");
-      CopyFromTo(out_gpu, out, priority);
+      CopyFromTo(retained_gpu, out, priority);
     }
   }
 
diff --git a/tests/python/gpu/test_kvstore_gpu.py b/tests/python/gpu/test_kvstore_gpu.py
index 1fc3a4d..a6e8ebf 100644
--- a/tests/python/gpu/test_kvstore_gpu.py
+++ b/tests/python/gpu/test_kvstore_gpu.py
@@ -91,6 +91,21 @@ def test_rsp_push_pull():
     check_rsp_push_pull('device')
     check_rsp_push_pull('device', is_push_cpu=False)
 
+
+def test_row_sparse_pull_single_device():  # regression test added with the #10772 fix: row_sparse_pull on a single GPU
+    kvstore = mx.kv.create('device')  # 'device' kvstore; only one GPU context is used below
+    copy = mx.nd.random_normal(shape=(4,4), ctx=mx.gpu(0))  # dense 4x4 reference values on gpu(0)
+    grad = copy.tostype("row_sparse")  # same values in row_sparse storage; also reused as the pull target below
+
+    key = 0
+    kvstore.init(key, grad)
+    idx = grad.indices  # row ids to pull: the indices of grad itself
+    kvstore.push(key, grad)
+    kvstore.row_sparse_pull(key, out=grad, row_ids=idx)  # out is the same array that was init'ed/pushed; presumably hits the out/src same-var case now handled in comm.h — verify
+
+    assert_almost_equal(grad.asnumpy(), copy.asnumpy())  # pulled result must match the dense source values
+
+
 def test_rsp_push_pull_large_rowid():
     num_rows = 793470
     val = mx.nd.ones((num_rows, 1)).tostype('row_sparse').copyto(mx.gpu())

-- 
To stop receiving notification emails like this one, please contact
[email protected].

Reply via email to