You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/05/06 20:57:22 UTC

[incubator-mxnet] branch master updated: fix creating cpu sparse array from gpu data (#10830)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 1968bba  fix creating cpu sparse array from gpu data (#10830)
1968bba is described below

commit 1968bba659bfb6fcc461132aac1729a4f33d5c41
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Sun May 6 13:57:15 2018 -0700

    fix creating cpu sparse array from gpu data (#10830)
---
 src/ndarray/ndarray.cc                |  8 ++++----
 tests/python/gpu/test_operator_gpu.py | 18 ++++++++++++++++++
 2 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index a643da1..67b4c06 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -1911,7 +1911,7 @@ void NDArray::SyncCopyFromNDArray(const NDArray& src, int i, int j) {
     if (src_dev_mask == cpu::kDevMask && dst_dev_mask == gpu::kDevMask) {
       Engine::Get()->PushAsync(
         [&](RunContext rctx, Engine::CallbackOnComplete on_complete) {
-          const TBlob src_data = (i >= 0? src.aux_data(i) : src.data());
+          const TBlob src_data = (i >= 0 ? src.aux_data(i) : src.data());
           TBlob dst_data = get_dst_data(src_data.shape_);
           ndarray::Copy<cpu, gpu>(src_data, &dst_data, src.ctx(), this->ctx(), rctx);
           rctx.get_stream<gpu>()->Wait();
@@ -1921,17 +1921,17 @@ void NDArray::SyncCopyFromNDArray(const NDArray& src, int i, int j) {
     } else if (src_dev_mask == gpu::kDevMask && dst_dev_mask == cpu::kDevMask) {
       Engine::Get()->PushAsync(
         [&](RunContext rctx, Engine::CallbackOnComplete on_complete) {
-          const TBlob src_data = (i >= 0? src.aux_data(i) : src.data());
+          const TBlob src_data = (i >= 0 ? src.aux_data(i) : src.data());
           TBlob dst_data = get_dst_data(src_data.shape_);
           ndarray::Copy<gpu, cpu>(src_data, &dst_data, src.ctx(), this->ctx(), rctx);
           rctx.get_stream<gpu>()->Wait();
           on_complete();
-        }, this->ctx(), const_vars, {this->var()},
+        }, src.ctx(), const_vars, {this->var()},
         FnProperty::kCopyFromGPU, 0, "SyncCopyFromNDArrayGPU2CPU");
     } else if (src_dev_mask == gpu::kDevMask && dst_dev_mask == gpu::kDevMask) {
       Engine::Get()->PushAsync(
         [&](RunContext rctx, Engine::CallbackOnComplete on_complete) {
-          const TBlob src_data = (i >= 0? src.aux_data(i) : src.data());
+          const TBlob src_data = (i >= 0 ? src.aux_data(i) : src.data());
           TBlob dst_data = get_dst_data(src_data.shape_);
           ndarray::Copy<gpu, gpu>(src_data, &dst_data, src.ctx(), this->ctx(), rctx);
           rctx.get_stream<gpu>()->Wait();
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 08c749e..313730c 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -1834,6 +1834,24 @@ def test_batchnorm_backwards_notrain():
                 loss=y.square().sum()
             loss.backward(train_mode=False)
 
+@with_seed()
+def test_create_sparse_ndarray_gpu_to_cpu():
+    dim0 = 10
+    dim1 = 5
+    densities = [0, 0.5, 1]
+    for density in densities:
+        shape = rand_shape_2d(dim0, dim1)
+        matrix = rand_ndarray(shape, 'row_sparse', density)
+        data = matrix.data
+        indices = matrix.indices
+        rsp_created = mx.nd.sparse.row_sparse_array((data, indices), shape=shape, ctx=mx.cpu())
+        assert rsp_created.stype == 'row_sparse'
+        assert same(rsp_created.data.asnumpy(), data.asnumpy())
+        assert same(rsp_created.indices.asnumpy(), indices.asnumpy())
+        rsp_copy = mx.nd.array(rsp_created)
+        assert(same(rsp_copy.asnumpy(), rsp_created.asnumpy()))
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()

-- 
To stop receiving notification emails like this one, please contact
jxie@apache.org.