You are viewing a plain text version of this content. The canonical link to it was a hyperlink ("here") that is not preserved in this plain-text version.
Posted to commits@mxnet.apache.org by an...@apache.org on 2018/05/02 06:46:38 UTC

[incubator-mxnet] branch v1.2.0 updated: fix kvstore rowsparse pull (#10777)

This is an automated email from the ASF dual-hosted git repository.

anirudh2290 pushed a commit to branch v1.2.0
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.2.0 by this push:
     new 60641ef  fix kvstore rowsparse pull  (#10777)
60641ef is described below

commit 60641ef1183bb4584c9356e84b6ca6d5fce58d6d
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Tue May 1 23:46:30 2018 -0700

    fix kvstore rowsparse pull  (#10777)
    
    * fix kvstore rowsparse pull
    
    * Trigger CI
    
    * fix compilation error
---
 src/kvstore/comm.h                   | 24 +++++++++++++++---------
 tests/python/gpu/test_kvstore_gpu.py | 13 +++++++++++++
 2 files changed, 28 insertions(+), 9 deletions(-)

diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h
index 9624899..c0514b1 100644
--- a/src/kvstore/comm.h
+++ b/src/kvstore/comm.h
@@ -223,9 +223,12 @@ class CommCPU : public Comm {
       CHECK_EQ(row_id.ctx().dev_mask(), Context::kCPU)
                << "BroadcastRowSparse with row_indices on gpu context not supported";
       // retain according to unique indices
-      const bool is_to_gpu = out->ctx().dev_mask() == Context::kGPU;
-      NDArray retained_cpu = is_to_gpu ? NDArray(kRowSparseStorage, src.shape(),
-          src.ctx(), true, src.dtype(), src.aux_types()) : *out;
+      const bool is_same_ctx = out->ctx() == src.ctx();
+      const bool is_diff_var = out->var() != src.var();
+      NDArray retained_cpu = (is_same_ctx && is_diff_var) ? *out :
+          NDArray(kRowSparseStorage, src.shape(), src.ctx(), true,
+                  src.dtype(), src.aux_types());
+
       Engine::Get()->PushAsync(
         [=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
           const TBlob& indices = row_id.data();
@@ -566,13 +569,16 @@ class CommDevice : public Comm {
       CHECK_EQ(row_id.ctx(), src.ctx())
               << "row_id and src are expected to be on the same context";
       // retain according to indices
-      const bool is_diff_ctx = out->ctx() != src.ctx();
-      NDArray out_gpu = is_diff_ctx? NDArray(kRowSparseStorage, out->shape(),
-          src.ctx(), true, out->dtype(), out->aux_types()) : *out;
+      const bool is_same_ctx = out->ctx() == src.ctx();
+      const bool is_diff_var = out->var() != src.var();
+      NDArray retained_gpu = (is_same_ctx && is_diff_var) ? *out :
+          NDArray(kRowSparseStorage, out->shape(), src.ctx(), true,
+                  out->dtype(), out->aux_types());
+
       Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
           const TBlob& indices = row_id.data();
           using namespace mxnet::common;
-          NDArray temp = out_gpu;
+          NDArray temp = retained_gpu;
           switch (temp.ctx().dev_mask()) {
             case cpu::kDevMask: {
               SparseRetainOpForwardRspWrapper<cpu>(rctx.get_stream<cpu>(),
@@ -591,9 +597,9 @@ class CommDevice : public Comm {
             default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
           }
           on_complete();
-        }, out_gpu.ctx(), {src.var(), row_id.var()}, {out_gpu.var()},
+        }, retained_gpu.ctx(), {src.var(), row_id.var()}, {retained_gpu.var()},
       FnProperty::kNormal, priority, "KVStoreSparseRetain");
-      CopyFromTo(out_gpu, out, priority);
+      CopyFromTo(retained_gpu, out, priority);
     }
   }
 
diff --git a/tests/python/gpu/test_kvstore_gpu.py b/tests/python/gpu/test_kvstore_gpu.py
index 1fc3a4d..04f18f4 100644
--- a/tests/python/gpu/test_kvstore_gpu.py
+++ b/tests/python/gpu/test_kvstore_gpu.py
@@ -91,6 +91,19 @@ def test_rsp_push_pull():
     check_rsp_push_pull('device')
     check_rsp_push_pull('device', is_push_cpu=False)
 
+def test_row_sparse_pull_single_device():
+    kvstore = mx.kv.create('local')
+    copy = mx.nd.random_normal(shape=(4,4), ctx=mx.cpu(0))
+    grad = copy.tostype("row_sparse")
+
+    key = 0
+    kvstore.init(key, grad)
+    idx = grad.indices
+    kvstore.push(key, grad)
+    kvstore.row_sparse_pull(key, out=grad, row_ids=idx)
+
+    assert_almost_equal(grad.asnumpy(), copy.asnumpy())
+
 def test_rsp_push_pull_large_rowid():
     num_rows = 793470
     val = mx.nd.ones((num_rows, 1)).tostype('row_sparse').copyto(mx.gpu())

-- 
To stop receiving notification emails like this one, please contact
anirudh2290@apache.org.