You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/03/05 13:54:52 UTC

[incubator-mxnet] branch master updated: Non-blocking row_sparse_pull. Fix incorrect indices generated by device kvstore.row_sparse_pull (#9887)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 02dd89a  Non-blocking row_sparse_pull. Fix incorrect indices generated by device kvstore.row_sparse_pull (#9887)
02dd89a is described below

commit 02dd89a68f659c2a9b0bff62c54c50dff1151f6b
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Mon Mar 5 21:54:48 2018 +0800

    Non-blocking row_sparse_pull. Fix incorrect indices generated by device kvstore.row_sparse_pull (#9887)
    
    * nonblocking Kvstore (#195)
    
    * draft
    
    * rm use_copy. fix dist kvstore. TODO: fix dtype
    
    * fix dtype, shape
    
    * remove reshape
    
    * cleanup
    
    * fix compilation
    
    * rsp draft
    
    * update param name
    
    * doc update and small refactoring
    
    * minor updates
    
    * enhance test case with 2-D rowids
    
    * update gpu tests
    
    * rewrite gpu unique kernels
    
    * update gpu tests
    
    * update reshape test/
    
    * fix lint
    
    * update test for py3
---
 python/mxnet/kvstore.py               |   2 +-
 src/kvstore/comm.h                    | 196 ++++++++++------------------------
 src/kvstore/kvstore_dist.h            |  34 +++---
 src/kvstore/kvstore_local.h           |  58 ++++++----
 src/kvstore/kvstore_utils.cc          |  17 +--
 src/kvstore/kvstore_utils.cu          |  96 +++++++++--------
 src/kvstore/kvstore_utils.h           |   9 +-
 tests/nightly/dist_sync_kvstore.py    |   4 +-
 tests/python/gpu/test_kvstore_gpu.py  |  18 +++-
 tests/python/unittest/test_kvstore.py |   2 +-
 10 files changed, 196 insertions(+), 240 deletions(-)

diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py
index 890c902..221b94f 100644
--- a/python/mxnet/kvstore.py
+++ b/python/mxnet/kvstore.py
@@ -321,7 +321,7 @@ class KVStore(object):
             other pull actions.
 
         row_ids : NDArray or list of NDArray
-            The row_ids for which to pull for each value. Each row_id is an 1D NDArray \
+            The row_ids for which to pull for each value. Each row_id is a 1-D NDArray \
             whose values don't have to be unique nor sorted.
 
         Examples
diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h
index da2d03d..3085966 100644
--- a/src/kvstore/comm.h
+++ b/src/kvstore/comm.h
@@ -65,14 +65,14 @@ class Comm {
 
   /**
    * \brief broadcast src to dst[i] with target row_ids for every i
+   * \param key the identifier key for the stored ndarray
+   * \param src the source row_sparse ndarray to broadcast
    * \param dst a list of destination row_sparse NDArray and its target row_ids to broadcast,
-            where the row_ids are expected to be unique and sorted
-   * \param use_copy if set to true, directly copy src to dst[i] without looking up the
-            provided row_ids
+            where the row_ids are expected to be unique and sorted in row_id.data()
+   * \param priority the priority of the operation
    */
   virtual void BroadcastRowSparse(int key, const NDArray& src,
                                   const std::vector<std::pair<NDArray*, NDArray>>& dst,
-                                  const bool use_copy,
                                   const int priority) = 0;
 
   /**
@@ -209,7 +209,6 @@ class CommCPU : public Comm {
 
   void BroadcastRowSparse(int key, const NDArray& src,
                           const std::vector<std::pair<NDArray*, NDArray>>& dst,
-                          const bool use_copy,
                           const int priority) override {
     using namespace mshadow;
     CHECK_EQ(src.storage_type(), kRowSparseStorage)
@@ -219,107 +218,30 @@ class CommCPU : public Comm {
     for (size_t i = 0; i < dst.size(); ++i) {
       NDArray* out = dst[i].first;
       NDArray row_id = dst[i].second;
-      if (use_copy) {
-        CopyFromTo(src, out, priority);
-      } else {
-        CHECK_EQ(out->storage_type(), kRowSparseStorage)
-                 << "BroadcastRowSparse expects row_sparse dst NDArray";
-        CHECK_EQ(row_id.ctx().dev_mask(), Context::kCPU)
-                 << "BroadcastRowSparse with row_indices on gpu context not supported";
-        // retain according to unique indices
-        const bool use_sparse_retain = (src.shape()[0] != src.storage_shape()[0])
-          || (row_id.dtype() != out->aux_type(rowsparse::kIdx))
-          || (out->ctx().dev_mask() != Context::kGPU);
-        if (use_sparse_retain) {  // use sparse_retain op
-          const bool is_to_gpu = out->ctx().dev_mask() == Context::kGPU;
-          NDArray out_cpu = is_to_gpu? NDArray(kRowSparseStorage, src.shape(),
-              src.ctx(), true, src.dtype(), src.aux_types()) : *out;
-          Engine::Get()->PushAsync(
-            [=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
-              const TBlob& indices = row_id.data();
-              NDArray temp = out_cpu;  // get rid of const qualifier
-              op::SparseRetainOpForwardRspImpl<cpu>(rctx.get_stream<cpu>(),
-                                                    src, indices, kWriteTo,
-                                                    &temp);
-              on_complete();
-            }, Context::CPU(), {src.var(), row_id.var()}, {out_cpu.var()},
-            FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreSparseRetain"));
-          if (is_to_gpu) {
-            CopyFromTo(out_cpu, out, priority);
-          }
-        } else {  // direct copy rows
-          Engine::Get()->PushAsync(
-            [=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
-              CopyRetainedRowsToGPU(rctx.get_stream<cpu>(), rctx.get_stream<gpu>(),
-                                    src, row_id, out);
-              // wait for GPU operations to complete
-              rctx.get_stream<gpu>()->Wait();
-              on_complete();
-            }, out->ctx(), {src.var(), row_id.var()}, {out->var()},
-            FnProperty::kCopyToGPU, priority, PROFILER_MESSAGE("KVStoreCopyRetainedRowsToGPU"));
-        }
-      }
+      CHECK_EQ(out->storage_type(), kRowSparseStorage)
+               << "BroadcastRowSparse expects row_sparse dst NDArray";
+      CHECK_EQ(row_id.ctx().dev_mask(), Context::kCPU)
+               << "BroadcastRowSparse with row_indices on gpu context not supported";
+      // retain according to unique indices
+      const bool is_to_gpu = out->ctx().dev_mask() == Context::kGPU;
+      NDArray retained_cpu = is_to_gpu ? NDArray(kRowSparseStorage, src.shape(),
+          src.ctx(), true, src.dtype(), src.aux_types()) : *out;
+      Engine::Get()->PushAsync(
+        [=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+          const TBlob& indices = row_id.data();
+          NDArray temp = retained_cpu;  // get rid of the const qualifier
+          op::SparseRetainOpForwardRspImpl<cpu>(rctx.get_stream<cpu>(),
+                                                src, indices, kWriteTo,
+                                                &temp);
+          on_complete();
+        }, Context::CPU(), {src.var(), row_id.var()}, {retained_cpu.var()},
+        FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreSparseRetain"));
+      // if retained_cpu == out, CopyFromTo will ignore the copy operation
+      CopyFromTo(retained_cpu, out, priority);
     }
   }
 
  private:
-  /*!
-   * \brief When src is a rsp with full rows,
-   * simply copy retained rows directly from cpu to gpu
-   * without invoking sparse_retain op.
-   */
-  void CopyRetainedRowsToGPU(mshadow::Stream<cpu>* cpu_stream,
-                             mshadow::Stream<gpu>* gpu_stream,
-                             const NDArray& src,
-                             const NDArray& indices,
-                             NDArray* dst) {
-#if MXNET_USE_CUDA == 1
-    CHECK_EQ(src.storage_type(), kRowSparseStorage)
-      << "CopyRetainedRowsToGPU expects row-sparse src NDArray";
-    CHECK_EQ(src.ctx().dev_mask(), Context::kCPU)
-      << "CopyRetainedRowsToGPU with src on gpu context not supported";
-    CHECK_EQ(src.storage_shape()[0], src.shape()[0])
-      << "CopyRetainedRowsToGPU only supports src rsp with full rows";
-    CHECK_EQ(indices.storage_type(), kDefaultStorage);
-    CHECK_EQ(indices.ctx().dev_mask(), Context::kCPU);
-    CHECK_EQ(dst->storage_type(), kRowSparseStorage);
-    CHECK_EQ(dst->ctx().dev_mask(), Context::kGPU);
-    CHECK_EQ(indices.dtype(), dst->aux_type(rowsparse::kIdx))
-      << "CopyRetainedRowsToGPU only supports same data type for idx array and dst aux_data(0)";
-    if (!src.storage_initialized() || indices.data().Size() == 0U) {
-      op::FillZerosRspImpl(gpu_stream, *dst);
-      return;
-    }
-    using namespace mshadow;
-
-    const TBlob& src_data = src.data();
-    const TBlob& idx_data = indices.data();
-    const size_t row_length = src.shape().ProdShape(1, src.shape().ndim());
-    const size_t num_rows_retained = idx_data.Size();
-    dst->CheckAndAlloc({Shape1(num_rows_retained)});
-    TBlob dst_data = dst->data();
-    TBlob dst_idx_data = dst->aux_data(rowsparse::kIdx);
-    MSHADOW_TYPE_SWITCH(src.dtype(), DType, {
-      MSHADOW_IDX_TYPE_SWITCH(indices.dtype(), IType, {
-        // copy idx array
-        Tensor<gpu, 1, IType> dst_idx_tensor = dst_idx_data.FlatTo1D<gpu, IType>(gpu_stream);
-        const Tensor<cpu, 1, IType> idx_tensor = idx_data.FlatTo1D<cpu, IType>(cpu_stream);
-        Copy(dst_idx_tensor, idx_tensor, gpu_stream);
-        // copy src data
-        const Tensor<cpu, 2, DType> src_data_tensor = src_data.get_with_shape<cpu, 2, DType>(
-            Shape2(src_data.shape_[0], row_length), cpu_stream);
-        Tensor<gpu, 2, DType> dst_data_tensor = dst_data.get_with_shape<gpu, 2, DType>(
-            Shape2(dst_data.shape_[0], row_length), gpu_stream);
-        for (size_t i = 0; i < num_rows_retained; ++i) {
-          Copy(dst_data_tensor[i], src_data_tensor[idx_tensor[i]], gpu_stream);
-        }
-      })
-    })
-#else
-    LOG(FATAL) << "GPU not enabled";
-#endif
-  }
-
   // reduce sum into val[0]
   inline void ReduceSumCPU(const std::vector<NDArray> &in_data) {
     MSHADOW_TYPE_SWITCH(in_data[0].dtype(), DType, {
@@ -632,7 +554,6 @@ class CommDevice : public Comm {
 
   void BroadcastRowSparse(int key, const NDArray& src,
                           const std::vector<std::pair<NDArray*, NDArray>>& dst,
-                          const bool use_copy,
                           const int priority) override {
     CHECK_EQ(src.storage_type(), kRowSparseStorage)
       << "BroadcastRowSparse expects row-sparse src NDArray";
@@ -640,46 +561,39 @@ class CommDevice : public Comm {
     for (size_t i = 0; i < dst.size(); ++i) {
       NDArray* out = dst[i].first;
       NDArray row_id = dst[i].second;
-      if (use_copy) {
-        CopyFromTo(src, out, priority);
-      } else {
-        CHECK_EQ(out->storage_type(), kRowSparseStorage)
-                 << "BroadcastRowSparse expects row_sparse dst NDArray";
-
-        const bool is_diff_ctx = out->ctx() != src.ctx();
-        NDArray out_gpu = is_diff_ctx? NDArray(kRowSparseStorage, out->shape(),
-            src.ctx(), true, out->dtype(), out->aux_types()) : *out;
-
-        CHECK_EQ(row_id.ctx(), src.ctx())
-                << "row_id and src are expected to be on the same context";
-
-        Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
-            NDArray temp = out_gpu;
-            const TBlob& indices = row_id.data();
-            switch (temp.ctx().dev_mask()) {
-              case cpu::kDevMask: {
-                mxnet::common::SparseRetainOpForwardRspWrapper<cpu>(rctx.get_stream<cpu>(),
-                    src, indices, kWriteTo, &temp);
-                break;
-              }
+      CHECK_EQ(out->storage_type(), kRowSparseStorage)
+               << "BroadcastRowSparse expects row_sparse dst NDArray";
+      CHECK_EQ(row_id.ctx(), src.ctx())
+              << "row_id and src are expected to be on the same context";
+      // retain according to indices
+      const bool is_diff_ctx = out->ctx() != src.ctx();
+      NDArray out_gpu = is_diff_ctx? NDArray(kRowSparseStorage, out->shape(),
+          src.ctx(), true, out->dtype(), out->aux_types()) : *out;
+      Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+          const TBlob& indices = row_id.data();
+          using namespace mxnet::common;
+          NDArray temp = out_gpu;
+          switch (temp.ctx().dev_mask()) {
+            case cpu::kDevMask: {
+              SparseRetainOpForwardRspWrapper<cpu>(rctx.get_stream<cpu>(),
+                  src, indices, kWriteTo, &temp);
+              break;
+            }
 #if MXNET_USE_CUDA
-              case gpu::kDevMask: {
-                mxnet::common::SparseRetainOpForwardRspWrapper<gpu>(rctx.get_stream<gpu>(),
-                    src, indices, kWriteTo, &temp);
-                // wait for GPU operations to complete
-                rctx.get_stream<gpu>()->Wait();
-                break;
-              }
-#endif
-              default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+            case gpu::kDevMask: {
+              SparseRetainOpForwardRspWrapper<gpu>(rctx.get_stream<gpu>(),
+                  src, indices, kWriteTo, &temp);
+              // wait for GPU operations to complete
+              rctx.get_stream<gpu>()->Wait();
+              break;
             }
-            on_complete();
-          }, out_gpu.ctx(), {src.var(), row_id.var()}, {out_gpu.var()},
-        FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreSparseRetain"));
-        if (is_diff_ctx) {
-          CopyFromTo(out_gpu, out, priority);
-        }
-      }
+#endif
+            default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+          }
+          on_complete();
+        }, out_gpu.ctx(), {src.var(), row_id.var()}, {out_gpu.var()},
+      FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreSparseRetain"));
+      CopyFromTo(out_gpu, out, priority);
     }
   }
 
diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h
index e01cc42..7ab5783 100644
--- a/src/kvstore/kvstore_dist.h
+++ b/src/kvstore/kvstore_dist.h
@@ -279,24 +279,20 @@ class KVStoreDist : public KVStoreLocal {
       }
       auto &target_val_rowids = grouped_val_rowids[i];
       const size_t num_vals = target_val_rowids.size();
-      size_t num_rows = 0;
-      // TODO(haibin) refactor this for loop
       for (size_t i = 0; i < num_vals; i++) {
         auto &row_id = target_val_rowids[i].second;
-        NDArray indices(row_id.shape(), pinned_ctx_, false, mshadow::kInt64);
-        CopyFromTo(row_id, &indices, 0);
-        Unique(&indices, priority);
-        target_val_rowids[i].second = indices;
-        num_rows += indices.shape().Size();
-      }
-      if (num_vals > 1) {
-        // TODO(haibin) aggregate over all unique indices
-        LOG(FATAL) << "RowSparsePull with multiple values is not implemented yet";
-      } else {
-        auto& indices = target_val_rowids[0].second;
-        PullRowSparse_(key, recv_buf, indices, priority);
-        comm_->BroadcastRowSparse(key, recv_buf, grouped_val_rowid, num_vals == 1, priority);
+        target_val_rowids[i].second = Unique(row_id, pinned_ctx_, 0);
       }
+      CHECK_EQ(num_vals, 1) << "RowSparsePull with multiple values is not supported yet";
+      NDArray& indices = target_val_rowids[0].second;
+      PullRowSparse_(key, recv_buf, indices, priority);
+      // The recv_buf contains values pulled from remote server with unique indices.
+      // Directly broadcast w/o rowids if num_vals == 1
+      auto get_val = [](const std::pair<NDArray*, NDArray>& p) { return p.first; };
+      std::vector<NDArray*> grouped_val(grouped_val_rowid.size());
+      std::transform(grouped_val_rowid.begin(), grouped_val_rowid.end(),
+                     grouped_val.begin(), get_val);
+      comm_->Broadcast(key, recv_buf, grouped_val, priority);
     }
   }
 
@@ -462,10 +458,12 @@ class KVStoreDist : public KVStoreLocal {
     auto pull_from_servers = [this, key, recv_buf, indices]
       (RunContext rctx, Engine::CallbackOnComplete cb) {
       // allocate memory for the buffer
-      size_t num_rows = indices.shape().Size();
+      CHECK_EQ(indices.dtype(), mshadow::kInt64);
+      const TBlob idx_data = indices.data();
+      size_t num_rows = idx_data.shape_.Size();
       recv_buf.CheckAndAlloc({mshadow::Shape1(num_rows)});
       real_t* data = recv_buf.data().dptr<real_t>();
-      const auto offsets = indices.data().dptr<int64_t>();
+      const auto offsets = idx_data.dptr<int64_t>();
       const auto unit_len = recv_buf.shape().ProdShape(1, recv_buf.shape().ndim());
       const int64_t size = num_rows * unit_len;
       // convert to ps keys in row sparse format
@@ -480,7 +478,7 @@ class KVStoreDist : public KVStoreLocal {
       // because after pull is done, the callback function returns and locks are released.
       // at this point, later functions may access the indices variable while copy happens
       mshadow::Copy(recv_buf.aux_data(kIdx).FlatTo1D<cpu, int64_t>(),
-                    indices.data().FlatTo1D<cpu, int64_t>());
+                    idx_data.FlatTo1D<cpu, int64_t>());
       CHECK_NOTNULL(ps_worker_)->ZPull(pskv.keys, vals, &pskv.lens,
                                        static_cast<int>(DataHandleType::kRowSparsePushPull),
                                        [vals, cb]() { delete vals; cb(); });
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index 7b3d6fa..69fb37e 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -35,6 +35,7 @@
 #include <algorithm>
 #include "./comm.h"
 #include "./kvstore_utils.h"
+#include "../ndarray/ndarray_function.h"
 
 namespace mxnet {
 namespace kvstore {
@@ -226,12 +227,9 @@ class KVStoreLocal : public KVStore {
       const size_t num_vals = target_val_rowids.size();
       for (size_t j = 0; j < num_vals; j++) {
         auto &row_id = target_val_rowids[j].second;
-        NDArray indices(row_id.shape(), local.ctx(), false, mshadow::kInt64);
-        CopyFromTo(row_id, &indices, 0);
-        Unique(&indices, priority);
-        target_val_rowids[j].second = indices;
+        target_val_rowids[j].second = Unique(row_id, local.ctx(), 0);
       }
-      comm_->BroadcastRowSparse(key, local, grouped_val_rowids[i], false, priority);
+      comm_->BroadcastRowSparse(key, local, grouped_val_rowids[i], priority);
     }
   }
 
@@ -354,42 +352,62 @@ class KVStoreLocal : public KVStore {
     }
   }
 
-  /**
-   * \brief sort and get unique values.
+  /*
+   * \brief Compute the unique values in data and store them in ascending order
+   * in an int64_t row_sparse ndarray on ctx. The operation is async. The result
+   * row_sparse ndarray stores the unique values in out.data(). The aux_data()
+   * contains values that are not necessarily meaningful and should be ignored.
+   * \param data the input data
+   * \param ctx the target context
+   * \param priority the priority of the operation
    */
-  void Unique(NDArray *out, int priority) {
-    Resource rsc = ResourceManager::Get()->Request(out->ctx(),
+  NDArray Unique(const NDArray &data, Context ctx, int priority) {
+    // create kRowSparseStorage output ndarray
+    const size_t num_elements = data.shape().Size();
+    NDArray out(kRowSparseStorage, mshadow::Shape2(num_elements, 1),
+                ctx, true, mshadow::kInt64);
+    bool diff_ctx = data.ctx() != ctx;
+    NDArray data_in_ctx = diff_ctx ? NDArray(data.shape(), ctx, true, data.dtype()) : data;
+    // if data == data_in_ctx, CopyFromTo is smart enough to skip the copy
+    CopyFromTo(data, &data_in_ctx, priority);
+    Resource rsc = ResourceManager::Get()->Request(out.ctx(),
       ResourceRequest(ResourceRequest::kTempSpace));
+    // GPU requires temp resources
+    std::vector<Engine::VarHandle> mutate_vars{out.var()};
+    if (out.ctx().dev_mask() == gpu::kDevMask) mutate_vars.emplace_back(rsc.var);
     Engine::Get()->PushAsync(
-      [rsc, out](RunContext rctx, Engine::CallbackOnComplete on_complete) {
-        NDArray *output = out;
-        CHECK_EQ(out->shape().ndim(), 1) << "Unique expects 1D inputs";
-        nnvm::dim_t size = out->shape()[0];
-        switch (out->ctx().dev_mask()) {
+      [=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+        // copy data.data() to out.data()
+        out.CheckAndAlloc({mshadow::Shape1(num_elements)});
+        TBlob out_data = out.data();
+        switch (out.ctx().dev_mask()) {
           case cpu::kDevMask: {
             mshadow::Stream<cpu> *s = rctx.get_stream<cpu>();
-            UniqueImpl(rsc, s, output, size);
+            ndarray::Copy<cpu, cpu>(data_in_ctx.data(), &out_data,
+                                    ctx, ctx, rctx);
+            UniqueImpl(rsc, s, out);
             break;
           }
   #if MXNET_USE_CUDA
           case gpu::kDevMask: {
             mshadow::Stream<gpu> *s = rctx.get_stream<gpu>();
-            UniqueImpl(rsc, s, output, size);
+            ndarray::Copy<gpu, gpu>(data_in_ctx.data(), &out_data,
+                                    ctx, ctx, rctx);
+            UniqueImpl(rsc, s, out);
             // wait for GPU operations to complete
             s->Wait();
             break;
           }
   #endif
           default:
-            LOG(FATAL) << "GPU not enabled.";
+            LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
         }
         on_complete();
-      }, out->ctx(), {}, {out->var(), rsc.var},
+      }, out.ctx(), {data_in_ctx.var()}, mutate_vars,
       FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreUnique"));
-    out->WaitToRead();
+    return out;
   }
 
-
   /// reducer and broadcaster
   Comm* comm_;
   /// pinned context
diff --git a/src/kvstore/kvstore_utils.cc b/src/kvstore/kvstore_utils.cc
index 9e14d8b..e187b0c 100644
--- a/src/kvstore/kvstore_utils.cc
+++ b/src/kvstore/kvstore_utils.cc
@@ -28,15 +28,18 @@
 namespace mxnet {
 namespace kvstore {
 
-
 template<>
 void UniqueImpl<cpu>(const Resource& rsc, mshadow::Stream<cpu> *s,
-                     NDArray *out, nnvm::dim_t size) {
-  MSHADOW_IDX_TYPE_SWITCH(out->data().type_flag_, IType, {
-    IType *dptr = out->data().dptr<IType>();
-    common::ParallelSort(dptr, dptr + size, omp_get_max_threads());
-    size_t num_unique_idx = std::unique(dptr, dptr + size) - dptr;
-    *out = out->Reshape(mshadow::Shape1(num_unique_idx));
+                      const NDArray& out) {
+  const size_t num_elements = out.shape().Size();
+  CHECK_EQ(out.storage_type(), kRowSparseStorage) << "row_sparse NDArray is expected";
+  MSHADOW_IDX_TYPE_SWITCH(out.dtype(), IType, {
+    IType *dptr = out.data().dptr<IType>();
+    common::ParallelSort(dptr, dptr + num_elements,
+                         engine::OpenMP::Get()->GetRecommendedOMPThreadCount());
+    const size_t num_selected_out = std::unique(dptr, dptr + num_elements) - dptr;
+    // set the shape of data/aux_data according to the number of unique values
+    out.set_aux_shape(rowsparse::kIdx, mshadow::Shape1(num_selected_out));
   });
 }
 
diff --git a/src/kvstore/kvstore_utils.cu b/src/kvstore/kvstore_utils.cu
index 00f316f..438fe29 100644
--- a/src/kvstore/kvstore_utils.cu
+++ b/src/kvstore/kvstore_utils.cu
@@ -40,63 +40,73 @@
 namespace mxnet {
 namespace kvstore {
 
-
 template<typename IType>
 size_t UniqueImplGPU(const Resource& rsc, mshadow::Stream<gpu> *s,
-                     IType *dptr, nnvm::dim_t size) {
-#ifndef SORT_WITH_THRUST
+                   IType *dptr, const size_t size) {
+  // estimate unique temp space. The first sizeof(size_t) bytes are reserved to
+  // store the number of unique values selected
+  const size_t num_selected_bytes = sizeof(size_t);
+  size_t unique_temp_bytes = 0;
+  size_t *null_ptr = nullptr;
+  size_t *null_dptr = nullptr;
+  cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
+  cub::DeviceSelect::Unique(NULL, unique_temp_bytes, null_dptr, null_dptr,
+                            null_ptr, size, stream);
+  // estimate sort temp space
+  const size_t sort_output_bytes = size * sizeof(IType);
   size_t sort_temp_bytes = 0;
-  cub::DeviceRadixSort::SortKeys(NULL, sort_temp_bytes,
-    dptr, dptr, size, 0, sizeof(IType)*8, mshadow::Stream<gpu>::GetStream(s));
-  mshadow::Tensor<gpu, 1, char> sort_space = rsc
-    .get_space_typed<gpu, 1, char>(
-      mshadow::Shape1(sort_temp_bytes), s);
-  void *sort_temp_storage = static_cast<void*>(sort_space.dptr_);
-  cub::DeviceRadixSort::SortKeys(sort_temp_storage, sort_temp_bytes,
-    dptr, dptr, size, 0, sizeof(IType)*8, mshadow::Stream<gpu>::GetStream(s));
+#ifndef SORT_WITH_THRUST
+  // The least-significant bit index (inclusive) needed for key comparison
+  const int begin_bit = 0;
+  // The most-significant bit index (exclusive) needed for key comparison
+  const int end_bit = sizeof(IType) * 8;
+  cub::DeviceRadixSort::SortKeys(NULL, sort_temp_bytes, null_dptr, null_dptr,
+                                 size, begin_bit, end_bit, stream);
 #else
-  thrust::sort(thrust::cuda::par.on(mshadow::Stream<gpu>::GetStream(s)),
-    dptr, dptr + size, thrust::greater<IType>());
+  // sort_temp_bytes remains 0 because thrust requests memory by itself
 #endif
-  size_t unique_temp_bytes = 0;
-  mshadow::Tensor<gpu, 1, char> dummy_space = rsc
-    .get_space_typed<gpu, 1, char>(
-      mshadow::Shape1(sizeof(size_t)), s);
-  size_t *dummy_ptr = reinterpret_cast<size_t*>(dummy_space.dptr_);
-  cub::DeviceSelect::Unique(NULL, unique_temp_bytes, dptr, dptr,
-    dummy_ptr, size, mshadow::Stream<gpu>::GetStream(s));
-
-  mshadow::Tensor<gpu, 1, char> unique_space = rsc
-    .get_space_typed<gpu, 1, char>(
-      mshadow::Shape1((unique_temp_bytes + sizeof(size_t) + 7) / 8 * 8), s);
-
-  void *unique_temp_storage = static_cast<void*>(
-    unique_space.dptr_);
-  size_t *d_num_selected_out = reinterpret_cast<size_t*>(
-    unique_space.dptr_ + (unique_temp_bytes + 7) / 8 * 8);
-
-  cub::DeviceSelect::Unique(unique_temp_storage, unique_temp_bytes, dptr, dptr,
-    d_num_selected_out, size, mshadow::Stream<gpu>::GetStream(s));
-
+  // request temp storage
+  const size_t total_workspace = num_selected_bytes + sort_output_bytes +
+                                 std::max(sort_temp_bytes, unique_temp_bytes);
+  mshadow::Tensor<gpu, 1, char> workspace = rsc
+    .get_space_typed<gpu, 1, char>(mshadow::Shape1(total_workspace), s);
+  // temp space layout: num_selected_ptr, sort_output_bytes, unique/sort_temp_storage
+  size_t* num_selected_ptr = reinterpret_cast<size_t*>(workspace.dptr_);
+  IType* sort_output_ptr = reinterpret_cast<IType*>(workspace.dptr_ + num_selected_bytes);
+  void *temp_storage = static_cast<void*>(workspace.dptr_ +
+                                          num_selected_bytes + sort_output_bytes);
+  // execute the sort kernel
+#ifndef SORT_WITH_THRUST
+  cub::DeviceRadixSort::SortKeys(temp_storage, sort_temp_bytes, dptr, sort_output_ptr,
+                                 size, begin_bit, end_bit, stream);
+#else
+  thrust::sort(thrust::cuda::par.on(stream),
+               dptr, dptr + size, thrust::greater<IType>());
+  CUDA_CALL(cudaMemcpy(sort_output_ptr, dptr, sort_output_bytes,
+                       cudaMemcpyDeviceToDevice));
+#endif
+  // execute unique kernel
+  cub::DeviceSelect::Unique(temp_storage, unique_temp_bytes, sort_output_ptr, dptr,
+                            num_selected_ptr, size, stream);
+  // retrieve num selected unique values
   size_t num_selected_out = 0;
-  CUDA_CALL(cudaMemcpy(&num_selected_out, d_num_selected_out, sizeof(size_t),
+  CUDA_CALL(cudaMemcpy(&num_selected_out, num_selected_ptr, num_selected_bytes,
      cudaMemcpyDeviceToHost));
   return num_selected_out;
 }
 
-/*!
- * \brief sort and get unique values.
- */
 template<>
 void UniqueImpl<gpu>(const Resource& rsc, mshadow::Stream<gpu> *s,
-                     NDArray *out, nnvm::dim_t size) {
-  MSHADOW_IDX_TYPE_SWITCH(out->data().type_flag_, IType, {
-    IType *dptr = out->data().dptr<IType>();
-    size_t num_selected_out = UniqueImplGPU(rsc, s, dptr, size);
-    *out = out->Reshape(mshadow::Shape1(num_selected_out));
+                     const NDArray &out) {
+  const size_t num_elements = out.shape().Size();
+  CHECK_EQ(out.storage_type(), kRowSparseStorage) << "row_sparse NDArray is expected";
+  MSHADOW_IDX_TYPE_SWITCH(out.dtype(), IType, {
+    IType *dptr = out.data().dptr<IType>();
+    size_t num_selected_out = UniqueImplGPU(rsc, s, dptr, num_elements);
+    // set the shape of data/aux_data according to the number of unique values
+    out.set_aux_shape(rowsparse::kIdx, mshadow::Shape1(num_selected_out));
   });
 }
 
-
 }  // namespace kvstore
 }  // namespace mxnet
diff --git a/src/kvstore/kvstore_utils.h b/src/kvstore/kvstore_utils.h
index 8255619..ee173b4 100644
--- a/src/kvstore/kvstore_utils.h
+++ b/src/kvstore/kvstore_utils.h
@@ -35,12 +35,15 @@ namespace kvstore {
 
 
 /*!
- * \brief sort and get unique values.
+ * \brief compute unique and sorted values in a row_sparse ndarray.
+ * \param rsc Temp resource for computation
+ * \param s   Stream
+ * \param out Input and output ndarray. The ndarray stores the
+ *            unique elements in out.data().
  */
 template<typename xpu>
 void UniqueImpl(const Resource& rsc, mshadow::Stream<xpu> *s,
-                NDArray *out, nnvm::dim_t size);
-
+                 const NDArray& out);
 }  // namespace kvstore
 }  // namespace mxnet
 
diff --git a/tests/nightly/dist_sync_kvstore.py b/tests/nightly/dist_sync_kvstore.py
index df85fe5..3a3c916 100644
--- a/tests/nightly/dist_sync_kvstore.py
+++ b/tests/nightly/dist_sync_kvstore.py
@@ -99,7 +99,7 @@ def test_sync_push_pull():
             # select a random subset of rows this worker is interested in
             num_rows = shape[0]
             row_ids_np = np.random.randint(num_rows, size=num_rows)
-            row_ids = mx.nd.array(row_ids_np, dtype='int64')
+            row_ids = mx.nd.array(row_ids_np).reshape((num_rows/2, 2))
             # perform pull
             val = mx.nd.zeros(shape, stype='row_sparse')
             kv.row_sparse_pull('9', out=val, row_ids=row_ids)
@@ -170,7 +170,7 @@ def test_sync_push_pull():
             rnd.seed(my_rank)
             num_rows = big_shape[0]
             row_ids_np = np.random.randint(num_rows, size=num_rows)
-            row_ids = mx.nd.array(row_ids_np)
+            row_ids = mx.nd.array(row_ids_np).reshape((num_rows/2, 2))
             # perform pull
             val = mx.nd.zeros(big_shape, stype='row_sparse')
             kv.row_sparse_pull('100', out=val, row_ids=row_ids)
diff --git a/tests/python/gpu/test_kvstore_gpu.py b/tests/python/gpu/test_kvstore_gpu.py
index 5fd3097..1fc3a4d 100644
--- a/tests/python/gpu/test_kvstore_gpu.py
+++ b/tests/python/gpu/test_kvstore_gpu.py
@@ -57,14 +57,14 @@ def test_rsp_push_pull():
             vals = [mx.nd.sparse.zeros(shape=shape, ctx=ctxs[i], stype='row_sparse') for i in range(count)]
             if is_same_rowid:
                 row_id = np.random.randint(num_rows, size=num_rows)
-                row_ids = [mx.nd.array(row_id, dtype='int64')] * count
+                row_ids = [mx.nd.array(row_id)] * count
             elif use_slice:
-                total_row_ids = mx.nd.array(np.random.randint(num_rows, size=count*num_rows), dtype='int64')
+                total_row_ids = mx.nd.array(np.random.randint(num_rows, size=count*num_rows))
                 row_ids = [total_row_ids[i*num_rows : (i+1)*num_rows] for i in range(count)]
             else:
                 for i in range(count):
                     row_id = np.random.randint(num_rows, size=num_rows)
-                    row_ids.append(mx.nd.array(row_id, dtype='int64'))
+                    row_ids.append(mx.nd.array(row_id))
             row_ids_to_pull = row_ids[0] if (len(row_ids) == 1 or is_same_rowid) else row_ids
             vals_to_pull = vals[0] if len(vals) == 1 else vals
 
@@ -91,6 +91,16 @@ def test_rsp_push_pull():
     check_rsp_push_pull('device')
     check_rsp_push_pull('device', is_push_cpu=False)
 
+def test_rsp_push_pull_large_rowid():
+    num_rows = 793470
+    val = mx.nd.ones((num_rows, 1)).tostype('row_sparse').copyto(mx.gpu())
+    kv = mx.kv.create('device')
+    kv.init('a', val)
+    out = mx.nd.zeros((num_rows,1), stype='row_sparse').copyto(mx.gpu())
+    kv.push('a', val)
+    kv.row_sparse_pull('a', out=out, row_ids=mx.nd.arange(0, num_rows, dtype='int64'))
+    assert(out.indices.shape[0] == num_rows)
 
 if __name__ == '__main__':
-    test_rsp_push_pull()
+    import nose
+    nose.runmodule()
diff --git a/tests/python/unittest/test_kvstore.py b/tests/python/unittest/test_kvstore.py
index 6bab06c..c56046c 100644
--- a/tests/python/unittest/test_kvstore.py
+++ b/tests/python/unittest/test_kvstore.py
@@ -76,7 +76,7 @@ def test_row_sparse_pull():
         for i in range(count):
             vals.append(mx.nd.zeros(shape).tostype('row_sparse'))
             row_id = np.random.randint(num_rows, size=num_rows)
-            row_ids.append(mx.nd.array(row_id))
+            row_ids.append(mx.nd.array(row_id).reshape((2, num_rows//2)))
         row_ids_to_pull = row_ids[0] if len(row_ids) == 1 else row_ids
         vals_to_pull = vals[0] if len(vals) == 1 else vals
 

-- 
To stop receiving notification emails like this one, please contact
haibin@apache.org.