You are viewing a plain text version of this content. The canonical link for it was a hyperlink in the original HTML message (lost in plain-text conversion); see the Apache mailing-list archive for the canonical copy.
Posted to commits@mxnet.apache.org by an...@apache.org on 2018/05/02 06:46:38 UTC
[incubator-mxnet] branch v1.2.0 updated: fix kvstore rowsparse pull
(#10777)
This is an automated email from the ASF dual-hosted git repository.
anirudh2290 pushed a commit to branch v1.2.0
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.2.0 by this push:
new 60641ef fix kvstore rowsparse pull (#10777)
60641ef is described below
commit 60641ef1183bb4584c9356e84b6ca6d5fce58d6d
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Tue May 1 23:46:30 2018 -0700
fix kvstore rowsparse pull (#10777)
* fix kvstore rowsparse pull
* Trigger CI
* fix compilation error
---
src/kvstore/comm.h | 24 +++++++++++++++---------
tests/python/gpu/test_kvstore_gpu.py | 13 +++++++++++++
2 files changed, 28 insertions(+), 9 deletions(-)
diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h
index 9624899..c0514b1 100644
--- a/src/kvstore/comm.h
+++ b/src/kvstore/comm.h
@@ -223,9 +223,12 @@ class CommCPU : public Comm {
CHECK_EQ(row_id.ctx().dev_mask(), Context::kCPU)
<< "BroadcastRowSparse with row_indices on gpu context not supported";
// retain according to unique indices
- const bool is_to_gpu = out->ctx().dev_mask() == Context::kGPU;
- NDArray retained_cpu = is_to_gpu ? NDArray(kRowSparseStorage, src.shape(),
- src.ctx(), true, src.dtype(), src.aux_types()) : *out;
+ const bool is_same_ctx = out->ctx() == src.ctx();
+ const bool is_diff_var = out->var() != src.var();
+ NDArray retained_cpu = (is_same_ctx && is_diff_var) ? *out :
+ NDArray(kRowSparseStorage, src.shape(), src.ctx(), true,
+ src.dtype(), src.aux_types());
+
Engine::Get()->PushAsync(
[=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
const TBlob& indices = row_id.data();
@@ -566,13 +569,16 @@ class CommDevice : public Comm {
CHECK_EQ(row_id.ctx(), src.ctx())
<< "row_id and src are expected to be on the same context";
// retain according to indices
- const bool is_diff_ctx = out->ctx() != src.ctx();
- NDArray out_gpu = is_diff_ctx? NDArray(kRowSparseStorage, out->shape(),
- src.ctx(), true, out->dtype(), out->aux_types()) : *out;
+ const bool is_same_ctx = out->ctx() == src.ctx();
+ const bool is_diff_var = out->var() != src.var();
+ NDArray retained_gpu = (is_same_ctx && is_diff_var) ? *out :
+ NDArray(kRowSparseStorage, out->shape(), src.ctx(), true,
+ out->dtype(), out->aux_types());
+
Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
const TBlob& indices = row_id.data();
using namespace mxnet::common;
- NDArray temp = out_gpu;
+ NDArray temp = retained_gpu;
switch (temp.ctx().dev_mask()) {
case cpu::kDevMask: {
SparseRetainOpForwardRspWrapper<cpu>(rctx.get_stream<cpu>(),
@@ -591,9 +597,9 @@ class CommDevice : public Comm {
default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
}
on_complete();
- }, out_gpu.ctx(), {src.var(), row_id.var()}, {out_gpu.var()},
+ }, retained_gpu.ctx(), {src.var(), row_id.var()}, {retained_gpu.var()},
FnProperty::kNormal, priority, "KVStoreSparseRetain");
- CopyFromTo(out_gpu, out, priority);
+ CopyFromTo(retained_gpu, out, priority);
}
}
diff --git a/tests/python/gpu/test_kvstore_gpu.py b/tests/python/gpu/test_kvstore_gpu.py
index 1fc3a4d..04f18f4 100644
--- a/tests/python/gpu/test_kvstore_gpu.py
+++ b/tests/python/gpu/test_kvstore_gpu.py
@@ -91,6 +91,19 @@ def test_rsp_push_pull():
check_rsp_push_pull('device')
check_rsp_push_pull('device', is_push_cpu=False)
+def test_row_sparse_pull_single_device():
+ kvstore = mx.kv.create('local')
+ copy = mx.nd.random_normal(shape=(4,4), ctx=mx.cpu(0))
+ grad = copy.tostype("row_sparse")
+
+ key = 0
+ kvstore.init(key, grad)
+ idx = grad.indices
+ kvstore.push(key, grad)
+ kvstore.row_sparse_pull(key, out=grad, row_ids=idx)
+
+ assert_almost_equal(grad.asnumpy(), copy.asnumpy())
+
def test_rsp_push_pull_large_rowid():
num_rows = 793470
val = mx.nd.ones((num_rows, 1)).tostype('row_sparse').copyto(mx.gpu())
--
To stop receiving notification emails like this one, please contact
anirudh2290@apache.org.