This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 7395be8 Fix row_sparse_pull with single gpu (#10772)
7395be8 is described below
commit 7395be8b54c83042b735a6d162afc08b6af21f5a
Author: Leonard Lausen <[email protected]>
AuthorDate: Thu May 3 20:17:04 2018 -0700
Fix row_sparse_pull with single gpu (#10772)
* Fix row_sparse_pull with single gpu
* Add test
* Fix row_sparse_pull with single gpu
* Add test
* fix sparse retain in comm.h
* remove dedup var
* update test
---
src/kvstore/comm.h | 25 ++++++++++++++++---------
tests/python/gpu/test_kvstore_gpu.py | 15 +++++++++++++++
2 files changed, 31 insertions(+), 9 deletions(-)
diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h
index 9624899..70de79b 100644
--- a/src/kvstore/comm.h
+++ b/src/kvstore/comm.h
@@ -223,9 +223,12 @@ class CommCPU : public Comm {
CHECK_EQ(row_id.ctx().dev_mask(), Context::kCPU)
<< "BroadcastRowSparse with row_indices on gpu context not supported";
// retain according to unique indices
- const bool is_to_gpu = out->ctx().dev_mask() == Context::kGPU;
- NDArray retained_cpu = is_to_gpu ? NDArray(kRowSparseStorage, src.shape(),
- src.ctx(), true, src.dtype(), src.aux_types()) : *out;
+ const bool is_same_ctx = out->ctx() == src.ctx();
+ const bool is_diff_var = out->var() != src.var();
+ NDArray retained_cpu = (is_same_ctx && is_diff_var) ? *out :
+ NDArray(kRowSparseStorage, src.shape(), src.ctx(), true,
+ src.dtype(), src.aux_types());
+
Engine::Get()->PushAsync(
[=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
const TBlob& indices = row_id.data();
@@ -565,14 +568,18 @@ class CommDevice : public Comm {
<< "BroadcastRowSparse expects row_sparse dst NDArray";
CHECK_EQ(row_id.ctx(), src.ctx())
<< "row_id and src are expected to be on the same context";
+
// retain according to indices
- const bool is_diff_ctx = out->ctx() != src.ctx();
- NDArray out_gpu = is_diff_ctx? NDArray(kRowSparseStorage, out->shape(),
- src.ctx(), true, out->dtype(), out->aux_types()) : *out;
+ const bool is_same_ctx = out->ctx() == src.ctx();
+ const bool is_diff_var = out->var() != src.var();
+ NDArray retained_gpu = (is_same_ctx && is_diff_var) ? *out :
+ NDArray(kRowSparseStorage, out->shape(), src.ctx(), true,
+ out->dtype(), out->aux_types());
+
Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
const TBlob& indices = row_id.data();
using namespace mxnet::common;
- NDArray temp = out_gpu;
+ NDArray temp = retained_gpu;
switch (temp.ctx().dev_mask()) {
case cpu::kDevMask: {
SparseRetainOpForwardRspWrapper<cpu>(rctx.get_stream<cpu>(),
@@ -591,9 +598,9 @@ class CommDevice : public Comm {
default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
}
on_complete();
- }, out_gpu.ctx(), {src.var(), row_id.var()}, {out_gpu.var()},
+ }, retained_gpu.ctx(), {src.var(), row_id.var()}, {retained_gpu.var()},
FnProperty::kNormal, priority, "KVStoreSparseRetain");
- CopyFromTo(out_gpu, out, priority);
+ CopyFromTo(retained_gpu, out, priority);
}
}
diff --git a/tests/python/gpu/test_kvstore_gpu.py b/tests/python/gpu/test_kvstore_gpu.py
index 1fc3a4d..a6e8ebf 100644
--- a/tests/python/gpu/test_kvstore_gpu.py
+++ b/tests/python/gpu/test_kvstore_gpu.py
@@ -91,6 +91,21 @@ def test_rsp_push_pull():
check_rsp_push_pull('device')
check_rsp_push_pull('device', is_push_cpu=False)
+
+def test_row_sparse_pull_single_device():
+ kvstore = mx.kv.create('device')
+ copy = mx.nd.random_normal(shape=(4,4), ctx=mx.gpu(0))
+ grad = copy.tostype("row_sparse")
+
+ key = 0
+ kvstore.init(key, grad)
+ idx = grad.indices
+ kvstore.push(key, grad)
+ kvstore.row_sparse_pull(key, out=grad, row_ids=idx)
+
+ assert_almost_equal(grad.asnumpy(), copy.asnumpy())
+
+
def test_rsp_push_pull_large_rowid():
num_rows = 793470
val = mx.nd.ones((num_rows, 1)).tostype('row_sparse').copyto(mx.gpu())
--
To stop receiving notification emails like this one, please contact
[email protected].