You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/01/15 00:08:24 UTC

[incubator-mxnet] branch master updated: rsp push and rsp pull for comm device, used in kvstore('device') (#8732)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 786e376  rsp push and rsp pull for comm device, used in kvstore('device') (#8732)
786e376 is described below

commit 786e376651c7f6f9b05b7758d091b22a7a72ef55
Author: Ziyue Huang <zy...@gmail.com>
AuthorDate: Mon Jan 15 08:08:18 2018 +0800

    rsp push and rsp pull for comm device, used in kvstore('device') (#8732)
    
    * comm device for rsp push and pull
    
    * update
    
    * update test
    
    * optimization for same row_ids
    
    * add stream->wait
    
    * remove using space
    
    * fix race of rsc and extend ElementwiseSum to rsp cases
    
    * add log fatal in ElementwiseSum
    
    * direct copy rows if full rsp and put all outputs on ctx of src
    
    * trigger
    
    * fix
    
    * simplify copy
    
    * move check same rowids to utils and add test for same rowids case
    
    * remove direct copy row by row
    
    * fix checkSameRowid
    
    * gpu unique impl draft
    
    * unique
    
    * update
    
    * fix windows build
    
    * trigger windows build
    
    * support single rowid with multiple vals
    
    * address comments
    
    * check same row_ids and copy in frontend
    
    * revise names and disable test for local kvstore
---
 python/mxnet/kvstore.py                  |  20 ++++-
 src/common/utils.cc                      |  10 +++
 src/common/utils.cu                      |  10 +++
 src/common/utils.h                       |  11 +++
 src/kvstore/comm.h                       | 130 +++++++++++++++++++++++--------
 src/kvstore/kvstore_local.h              |  51 +++++++-----
 src/{common => kvstore}/utils.cc         |  24 +++---
 src/kvstore/utils.cu                     | 102 ++++++++++++++++++++++++
 src/{common/utils.cc => kvstore/utils.h} |  36 +++++----
 src/ndarray/ndarray.cc                   |  82 ++++++++++++-------
 tests/python/gpu/test_kvstore_gpu.py     |  74 +++++++++++-------
 11 files changed, 414 insertions(+), 136 deletions(-)

diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py
index b2a4bea..890c902 100644
--- a/python/mxnet/kvstore.py
+++ b/python/mxnet/kvstore.py
@@ -298,7 +298,8 @@ class KVStore(object):
 
     def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
         """ Pulls a single RowSparseNDArray value or a sequence of RowSparseNDArray values \
-        from the store with specified row_ids.
+        from the store with specified row_ids. When there is only one row_id, KVStoreRowSparsePull \
+        is invoked just once and the result is broadcast to all the remaining outputs.
 
         `row_sparse_pull` is executed asynchronously after all previous
         `pull`/`row_sparse_pull` calls and the last `push` call for the
@@ -349,7 +350,17 @@ class KVStore(object):
         """
         assert(out is not None)
         assert(row_ids is not None)
-        ckeys, cvals, use_str_keys = _ctype_key_value(key, out)
+        if isinstance(row_ids, NDArray):
+            row_ids = [row_ids]
+        assert(isinstance(row_ids, list)), \
+            "row_ids should be NDArray or list of NDArray"
+        first_out = out
+        # whether row_ids are the same
+        single_rowid = False
+        if len(row_ids) == 1 and isinstance(out, list):
+            single_rowid = True
+            first_out = [out[0]]
+        ckeys, cvals, use_str_keys = _ctype_key_value(key, first_out)
         _, crow_ids, _ = _ctype_key_value(key, row_ids)
         assert(len(crow_ids) == len(cvals)), \
                "the number of row_ids doesn't match the number of values"
@@ -359,6 +370,11 @@ class KVStore(object):
         else:
             check_call(_LIB.MXKVStorePullRowSparse(
                 self.handle, mx_uint(len(ckeys)), ckeys, cvals, crow_ids, ctypes.c_int(priority)))
+        # the result can be copied to other devices without invoking row_sparse_pull
+        # if the indices are the same
+        if single_rowid:
+            for out_i in out[1:]:
+                out[0].copyto(out_i)
 
     def set_gradient_compression(self, compression_params):
         """ Specifies type of low-bit quantization for gradient compression \
diff --git a/src/common/utils.cc b/src/common/utils.cc
index 784fcf8..9fe46d9 100644
--- a/src/common/utils.cc
+++ b/src/common/utils.cc
@@ -24,6 +24,7 @@
 
 #include "./utils.h"
 #include "../operator/tensor/cast_storage-inl.h"
+#include "../operator/tensor/sparse_retain-inl.h"
 
 namespace mxnet {
 namespace common {
@@ -35,6 +36,15 @@ void CheckFormatWrapper<cpu>(const RunContext &rctx, const NDArray &input,
 }
 
 template<>
+void SparseRetainOpForwardRspWrapper<cpu>(mshadow::Stream<cpu> *s,
+                                          const NDArray& input_nd,
+                                          const TBlob& idx_data,
+                                          const OpReqType req,
+                                          NDArray* output_nd) {
+  mxnet::op::SparseRetainOpForwardRspImpl<cpu>(s, input_nd, idx_data, req, output_nd);
+}
+
+template<>
 void CastStorageDispatch<cpu>(const OpContext& ctx,
                               const NDArray& input,
                               const NDArray& output) {
diff --git a/src/common/utils.cu b/src/common/utils.cu
index c6e2bf8..0937d7a 100644
--- a/src/common/utils.cu
+++ b/src/common/utils.cu
@@ -24,6 +24,7 @@
 
 #include "./utils.h"
 #include "../operator/tensor/cast_storage-inl.h"
+#include "../operator/tensor/sparse_retain-inl.h"
 
 namespace mxnet {
 namespace common {
@@ -35,6 +36,15 @@ void CheckFormatWrapper<gpu>(const RunContext &rctx, const NDArray &input,
 }
 
 template<>
+void SparseRetainOpForwardRspWrapper<gpu>(mshadow::Stream<gpu> *s,
+                                          const NDArray& input_nd,
+                                          const TBlob& idx_data,
+                                          const OpReqType req,
+                                          NDArray* output_nd) {
+  mxnet::op::SparseRetainOpForwardRspImpl<gpu>(s, input_nd, idx_data, req, output_nd);
+}
+
+template<>
 void CastStorageDispatch<gpu>(const OpContext& ctx,
                               const NDArray& input,
                               const NDArray& output) {
diff --git a/src/common/utils.h b/src/common/utils.h
index 6f7e452..4bb8024 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -213,7 +213,18 @@ void CheckFormatImpl(const RunContext &rctx, const NDArray &input,
   }
 }
 
+/*! \brief Pick rows specified by user input index array from a row sparse ndarray
+ *         and save them in the output sparse ndarray.
+ */
+template<typename xpu>
+void SparseRetainOpForwardRspWrapper(mshadow::Stream<xpu> *s,
+                                     const NDArray& input_nd,
+                                     const TBlob& idx_data,
+                                     const OpReqType req,
+                                     NDArray* output_nd);
 
+/*! \brief Casts tensor storage type to the new type.
+ */
 template<typename xpu>
 void CastStorageDispatch(const OpContext& ctx, const NDArray& input, const NDArray& output);
 
diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h
index 5429df7..d41fa64 100644
--- a/src/kvstore/comm.h
+++ b/src/kvstore/comm.h
@@ -34,6 +34,7 @@
 #include "gradient_compression.h"
 #include "../ndarray/ndarray_function.h"
 #include "../operator/tensor/sparse_retain-inl.h"
+#include "./utils.h"
 namespace mxnet {
 namespace kvstore {
 /**
@@ -176,17 +177,17 @@ class CommCPU : public Comm {
         reduce[i] = buf.copy_buf[i];
         const_vars[i] = reduce[i].var();
       }
-      auto result = buf.merged;
+      NDArray result = buf.merged;
+      Resource rsc = ResourceManager::Get()->Request(result.ctx(),
+          ResourceRequest(ResourceRequest::kTempSpace));
       Engine::Get()->PushAsync(
-        [reduce, result, this](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+        [reduce, result, rsc, this](RunContext rctx, Engine::CallbackOnComplete on_complete) {
           NDArray out = result;
-          Resource rsc = ResourceManager::Get()->Request(rctx.ctx,
-              ResourceRequest(ResourceRequest::kTempSpace));
           is_serial_push_?
             ReduceSumCPUExSerial(reduce, &out)
             : mxnet::ndarray::ElementwiseSum(rctx.get_stream<cpu>(), rsc, reduce, &out);
           on_complete();
-        }, Context::CPU(), const_vars, {result.var()},
+        }, Context::CPU(), const_vars, {result.var(), rsc.var},
         FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce"));
     }
 
@@ -491,11 +492,7 @@ class CommDevice : public Comm {
 
   void Init(int key, const NDArrayStorageType stype, const TShape& shape,
             int dtype = mshadow::kFloat32) override {
-    if (stype == kDefaultStorage) {
-      sorted_key_attrs_.push_back(std::make_tuple(key, shape, dtype));
-    } else {
-      LOG(FATAL) << "storage type " << stype << " not implemented for device yet";
-    }
+    sorted_key_attrs_.emplace_back(key, shape, dtype, stype);
   }
 
   void InitBuffersAndComm(const std::vector<NDArray>& src) {
@@ -528,26 +525,42 @@ class CommDevice : public Comm {
     InitBuffersAndComm(src);
     auto& buf = merge_buf_[key];
     std::vector<NDArray> reduce(src.size());
-    CopyFromTo(src[0], &(buf.merged), priority);
-    reduce[0] = buf.merged;
 
-    if (buf.copy_buf.empty()) {
-      // TODO(mli) this results in large device memory usage for huge ndarray,
-      // such as the largest fullc in VGG. consider to do segment reduce with
-      // NDArray.Slice or gpu direct memory access. for the latter, we need to
-      // remove some ctx check, and also it reduces 20% perf
-      buf.copy_buf.resize(src.size()-1);
+    const NDArrayStorageType stype = buf.merged.storage_type();
+    if (stype == kDefaultStorage) {
+      CopyFromTo(src[0], &(buf.merged), priority);
+      reduce[0] = buf.merged;
+
+      if (buf.copy_buf.empty()) {
+        // TODO(mli) this results in large device memory usage for huge ndarray,
+        // such as the largest fullc in VGG. consider to do segment reduce with
+        // NDArray.Slice or gpu direct memory access. for the latter, we need to
+        // remove some ctx check, and also it reduces 20% perf
+        buf.copy_buf.resize(src.size()-1);
+        for (size_t i = 0; i < src.size()-1; ++i) {
+          buf.copy_buf[i] = NDArray(
+            buf.merged.shape(), buf.merged.ctx(), false, buf.merged.dtype());
+        }
+      }
       for (size_t i = 0; i < src.size()-1; ++i) {
-        buf.copy_buf[i] = NDArray(
-          buf.merged.shape(), buf.merged.ctx(), false, buf.merged.dtype());
+        CopyFromTo(src[i+1], &(buf.copy_buf[i]), priority);
+        reduce[i+1] = buf.copy_buf[i];
+      }
+    } else {
+      if (buf.copy_buf.empty()) {
+        buf.copy_buf.resize(src.size());
+        for (size_t j = 0; j < src.size(); ++j) {
+          buf.copy_buf[j] = NDArray(
+            buf.merged.storage_type(), buf.merged.shape(), buf.merged.ctx(),
+            true, buf.merged.dtype());
+        }
+      }
+      for (size_t i = 0; i < src.size(); ++i) {
+        CopyFromTo(src[i], &(buf.copy_buf[i]), priority);
+        reduce[i] = buf.copy_buf[i];
       }
     }
-    for (size_t i = 0; i < src.size()-1; ++i) {
-      CopyFromTo(src[i+1], &(buf.copy_buf[i]), priority);
-      reduce[i+1] = buf.copy_buf[i];
-    }
-
-    ElementwiseSum(reduce, &buf.merged);
+    ElementwiseSum(reduce, &buf.merged, priority);
     return buf.merged;
   }
 
@@ -621,7 +634,53 @@ class CommDevice : public Comm {
                           const std::vector<std::pair<NDArray*, NDArray>>& dst,
                           const bool use_copy,
                           const int priority) override {
-    LOG(FATAL) << "Not implemented yet";
+    CHECK_EQ(src.storage_type(), kRowSparseStorage)
+      << "BroadcastRowSparse expects row-sparse src NDArray";
+
+    for (size_t i = 0; i < dst.size(); ++i) {
+      NDArray* out = dst[i].first;
+      NDArray row_id = dst[i].second;
+      if (use_copy) {
+        CopyFromTo(src, out, priority);
+      } else {
+        CHECK_EQ(out->storage_type(), kRowSparseStorage)
+                 << "BroadcastRowSparse expects row_sparse dst NDArray";
+
+        const bool is_diff_ctx = out->ctx() != src.ctx();
+        NDArray out_gpu = is_diff_ctx? NDArray(kRowSparseStorage, out->shape(),
+            src.ctx(), true, out->dtype(), out->aux_types()) : *out;
+
+        CHECK_EQ(row_id.ctx(), src.ctx())
+                << "row_id and src are expected to be on the same context";
+
+        Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+            NDArray temp = out_gpu;
+            const TBlob& indices = row_id.data();
+            switch (temp.ctx().dev_mask()) {
+              case cpu::kDevMask: {
+                mxnet::common::SparseRetainOpForwardRspWrapper<cpu>(rctx.get_stream<cpu>(),
+                    src, indices, kWriteTo, &temp);
+                break;
+              }
+#if MXNET_USE_CUDA
+              case gpu::kDevMask: {
+                mxnet::common::SparseRetainOpForwardRspWrapper<gpu>(rctx.get_stream<gpu>(),
+                    src, indices, kWriteTo, &temp);
+                // wait for GPU operations to complete
+                rctx.get_stream<gpu>()->Wait();
+                break;
+              }
+#endif
+              default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+            }
+            on_complete();
+          }, out_gpu.ctx(), {src.var(), row_id.var()}, {out_gpu.var()},
+        FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreSparseRetain"));
+        if (is_diff_ctx) {
+          CopyFromTo(out_gpu, out, priority);
+        }
+      }
+    }
   }
 
  private:
@@ -667,7 +726,7 @@ class CommDevice : public Comm {
 #endif
   }
 
-  using KeyAttrs = std::tuple<int, TShape, int>;
+  using KeyAttrs = std::tuple<int, TShape, int, NDArrayStorageType>;
   // try to allocate buff on device evenly
   void InitMergeBuffer(const std::vector<Context>& devs) {
     std::sort(sorted_key_attrs_.begin(), sorted_key_attrs_.end(), [](
@@ -680,9 +739,10 @@ class CommDevice : public Comm {
       ctx_info[d.dev_id] = std::make_pair(d, 0);
     }
     for (size_t i = 0; i < sorted_key_attrs_.size(); ++i) {
-      int key  = std::get<0>(sorted_key_attrs_[i]);
-      TShape s = std::get<1>(sorted_key_attrs_[i]);
-      int type = std::get<2>(sorted_key_attrs_[i]);
+      const int key  = std::get<0>(sorted_key_attrs_[i]);
+      const TShape& shape = std::get<1>(sorted_key_attrs_[i]);
+      const int type = std::get<2>(sorted_key_attrs_[i]);
+      const NDArrayStorageType stype = std::get<3>(sorted_key_attrs_[i]);
       auto& buf = merge_buf_[key];
       Context ctx;
       size_t min_size = std::numeric_limits<size_t>::max();
@@ -693,8 +753,12 @@ class CommDevice : public Comm {
           min_size = size;
         }
       }
-      buf.merged = NDArray(s, ctx, false, type);
-      ctx_info[ctx.dev_id].second += s.Size();
+      if (stype == kDefaultStorage) {
+        buf.merged = NDArray(shape, ctx, false, type);
+      } else {
+        buf.merged = NDArray(stype, shape, ctx, true, type);
+      }
+      ctx_info[ctx.dev_id].second += shape.Size();
     }
     inited_ = true;
   }
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index 1bb84fd..78b6c8f 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -34,6 +34,7 @@
 #include <functional>
 #include <algorithm>
 #include "./comm.h"
+#include "./utils.h"
 
 namespace mxnet {
 namespace kvstore {
@@ -223,12 +224,12 @@ class KVStoreLocal : public KVStore {
                << "PullRowSparse expects row_sparse src NDArray";
       auto &target_val_rowids = grouped_val_rowids[i];
       const size_t num_vals = target_val_rowids.size();
-      for (size_t i = 0; i < num_vals; i++) {
-        auto &row_id = target_val_rowids[i].second;
-        NDArray indices(row_id.shape(), pinned_ctx_, false, mshadow::kInt64);
+      for (size_t j = 0; j < num_vals; j++) {
+        auto &row_id = target_val_rowids[j].second;
+        NDArray indices(row_id.shape(), local.ctx(), false, mshadow::kInt64);
         CopyFromTo(row_id, &indices, 0);
         Unique(&indices, priority);
-        target_val_rowids[i].second = indices;
+        target_val_rowids[j].second = indices;
       }
       comm_->BroadcastRowSparse(key, local, grouped_val_rowids[i], false, priority);
     }
@@ -354,29 +355,41 @@ class KVStoreLocal : public KVStore {
   }
 
   /**
-   * \brief sort and get unique values. Output is expected to be on cpu_pinned context
+   * \brief sort and get unique values.
    */
-  void Unique(NDArray *out, int priority = 0) {
-    CHECK_EQ(out->ctx().dev_mask(), pinned_ctx_.dev_mask())
-             << "Unique expects input with `pinned_ctx_`";
+  void Unique(NDArray *out, int priority) {
+    Resource rsc = ResourceManager::Get()->Request(out->ctx(),
+      ResourceRequest(ResourceRequest::kTempSpace));
     Engine::Get()->PushAsync(
-      [out](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+      [rsc, out](RunContext rctx, Engine::CallbackOnComplete on_complete) {
         NDArray *output = out;
         CHECK_EQ(out->shape().ndim(), 1) << "Unique expects 1D inputs";
-        const auto size = out->shape()[0];
-        auto out_data = output->data();
-        MSHADOW_IDX_TYPE_SWITCH(out_data.type_flag_, IType, {
-          auto dptr = output->data().dptr<IType>();
-          common::ParallelSort(dptr, dptr + size, omp_get_max_threads());
-          auto num_unique_idx = std::unique(dptr, dptr + size) - dptr;
-          *output = output->Reshape(mshadow::Shape1(num_unique_idx));
-        });
+        nnvm::dim_t size = out->shape()[0];
+        switch (out->ctx().dev_mask()) {
+          case cpu::kDevMask: {
+            mshadow::Stream<cpu> *s = rctx.get_stream<cpu>();
+            UniqueImpl(rsc, s, output, size);
+            break;
+          }
+  #if MXNET_USE_CUDA
+          case gpu::kDevMask: {
+            mshadow::Stream<gpu> *s = rctx.get_stream<gpu>();
+            UniqueImpl(rsc, s, output, size);
+            // wait for GPU operations to complete
+            s->Wait();
+            break;
+          }
+  #endif
+          default:
+            LOG(FATAL) << "GPU not enabled.";
+        }
         on_complete();
-      }, pinned_ctx_, {}, {out->var()},
-      FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreUnique"));
+      }, out->ctx(), {}, {out->var(), rsc.var},
+      FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreUnique"));
     out->WaitToRead();
   }
 
+
   /// reducer and broadcaster
   Comm* comm_;
   /// pinned context
diff --git a/src/common/utils.cc b/src/kvstore/utils.cc
similarity index 64%
copy from src/common/utils.cc
copy to src/kvstore/utils.cc
index 784fcf8..c22553f 100644
--- a/src/common/utils.cc
+++ b/src/kvstore/utils.cc
@@ -23,23 +23,23 @@
  */
 
 #include "./utils.h"
-#include "../operator/tensor/cast_storage-inl.h"
+#include "../common/utils.h"
 
 namespace mxnet {
-namespace common {
+namespace kvstore {
 
-template<>
-void CheckFormatWrapper<cpu>(const RunContext &rctx, const NDArray &input,
-                             const TBlob &err_cpu, const bool full_check) {
-  CheckFormatImpl<cpu>(rctx, input, err_cpu, full_check);
-}
 
 template<>
-void CastStorageDispatch<cpu>(const OpContext& ctx,
-                              const NDArray& input,
-                              const NDArray& output) {
-  mxnet::op::CastStorageComputeImpl<cpu>(ctx, input, output);
+void UniqueImpl<cpu>(const Resource& rsc, mshadow::Stream<cpu> *s,
+                     NDArray *out, nnvm::dim_t size) {
+  MSHADOW_IDX_TYPE_SWITCH(out->data().type_flag_, IType, {
+    IType *dptr = out->data().dptr<IType>();
+    common::ParallelSort(dptr, dptr + size, omp_get_max_threads());
+    size_t num_unique_idx = std::unique(dptr, dptr + size) - dptr;
+    *out = out->Reshape(mshadow::Shape1(num_unique_idx));
+  });
 }
 
-}  // namespace common
+
+}  // namespace kvstore
 }  // namespace mxnet
diff --git a/src/kvstore/utils.cu b/src/kvstore/utils.cu
new file mode 100644
index 0000000..088a49e
--- /dev/null
+++ b/src/kvstore/utils.cu
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \file utils.cu
+ * \brief gpu implementation of util functions
+ */
+#if defined(_MSC_VER) && __CUDACC_VER_MAJOR__ == 8 && __CUDACC_VER_BUILD__ != 44
+// Many CUDA 8 compilers other than V8.0.44 crash on Windows
+#pragma warning("Potential crash on CUDA compiler detected. Switching sorting from CUB to Thrust")
+#define SORT_WITH_THRUST
+#include <thrust/device_ptr.h>
+#include <thrust/sort.h>
+#include <thrust/system/cuda/execution_policy.h>
+#else
+#undef SORT_WITH_THRUST
+#endif
+#include "./utils.h"
+#include "../common/utils.h"
+#include <cub/cub.cuh>
+#include <mxnet/resource.h>
+
+namespace mxnet {
+namespace kvstore {
+
+
+template<typename IType>
+size_t UniqueImplGPU(const Resource& rsc, mshadow::Stream<gpu> *s,
+                     IType *dptr, nnvm::dim_t size) {
+#ifndef SORT_WITH_THRUST
+  size_t sort_temp_bytes = 0;
+  cub::DeviceRadixSort::SortKeys(NULL, sort_temp_bytes,
+    dptr, dptr, size, 0, sizeof(IType)*8, mshadow::Stream<gpu>::GetStream(s));
+  mshadow::Tensor<gpu, 1, char> sort_space = rsc
+    .get_space_typed<gpu, 1, char>(
+      mshadow::Shape1(sort_temp_bytes), s);
+  void *sort_temp_storage = static_cast<void*>(sort_space.dptr_);
+  cub::DeviceRadixSort::SortKeys(sort_temp_storage, sort_temp_bytes,
+    dptr, dptr, size, 0, sizeof(IType)*8, mshadow::Stream<gpu>::GetStream(s));
+#else
+  thrust::sort(thrust::cuda::par.on(mshadow::Stream<gpu>::GetStream(s)),
+    dptr, dptr + size, thrust::greater<IType>());
+#endif
+  size_t unique_temp_bytes = 0;
+  mshadow::Tensor<gpu, 1, char> dummy_space = rsc
+    .get_space_typed<gpu, 1, char>(
+      mshadow::Shape1(sizeof(size_t)), s);
+  size_t *dummy_ptr = reinterpret_cast<size_t*>(dummy_space.dptr_);
+  cub::DeviceSelect::Unique(NULL, unique_temp_bytes, dptr, dptr,
+    dummy_ptr, size, mshadow::Stream<gpu>::GetStream(s));
+
+  mshadow::Tensor<gpu, 1, char> unique_space = rsc
+    .get_space_typed<gpu, 1, char>(
+      mshadow::Shape1((unique_temp_bytes + sizeof(size_t) + 7) / 8 * 8), s);
+
+  void *unique_temp_storage = static_cast<void*>(
+    unique_space.dptr_);
+  size_t *d_num_selected_out = reinterpret_cast<size_t*>(
+    unique_space.dptr_ + (unique_temp_bytes + 7) / 8 * 8);
+
+  cub::DeviceSelect::Unique(unique_temp_storage, unique_temp_bytes, dptr, dptr,
+    d_num_selected_out, size, mshadow::Stream<gpu>::GetStream(s));
+
+  size_t num_selected_out = 0;
+  CUDA_CALL(cudaMemcpy(&num_selected_out, d_num_selected_out, sizeof(size_t),
+     cudaMemcpyDeviceToHost));
+  return num_selected_out;
+}
+
+/*!
+ * \brief sort and get unique values.
+ */
+template<>
+void UniqueImpl<gpu>(const Resource& rsc, mshadow::Stream<gpu> *s,
+                     NDArray *out, nnvm::dim_t size) {
+  MSHADOW_IDX_TYPE_SWITCH(out->data().type_flag_, IType, {
+    IType *dptr = out->data().dptr<IType>();
+    size_t num_selected_out = UniqueImplGPU(rsc, s, dptr, size);
+    *out = out->Reshape(mshadow::Shape1(num_selected_out));
+  });
+}
+
+
+}  // namespace kvstore
+}  // namespace mxnet
diff --git a/src/common/utils.cc b/src/kvstore/utils.h
similarity index 57%
copy from src/common/utils.cc
copy to src/kvstore/utils.h
index 784fcf8..7547345 100644
--- a/src/common/utils.cc
+++ b/src/kvstore/utils.h
@@ -18,28 +18,30 @@
  */
 
 /*!
- * \file utils.cc
- * \brief cpu implementation of util functions
+ * \file utils.h
+ * \brief Basic utility functions.
  */
+#ifndef MXNET_KVSTORE_UTILS_H_
+#define MXNET_KVSTORE_UTILS_H_
 
-#include "./utils.h"
-#include "../operator/tensor/cast_storage-inl.h"
+#include <dmlc/logging.h>
+#include <mxnet/ndarray.h>
+#include <mxnet/resource.h>
+#include <utility>
+#include <vector>
 
 namespace mxnet {
-namespace common {
+namespace kvstore {
 
-template<>
-void CheckFormatWrapper<cpu>(const RunContext &rctx, const NDArray &input,
-                             const TBlob &err_cpu, const bool full_check) {
-  CheckFormatImpl<cpu>(rctx, input, err_cpu, full_check);
-}
 
-template<>
-void CastStorageDispatch<cpu>(const OpContext& ctx,
-                              const NDArray& input,
-                              const NDArray& output) {
-  mxnet::op::CastStorageComputeImpl<cpu>(ctx, input, output);
-}
+/*!
+ * \brief sort and get unique values.
+ */
+template<typename xpu>
+void UniqueImpl(const Resource& rsc, mshadow::Stream<xpu> *s,
+                NDArray *out, nnvm::dim_t size);
 
-}  // namespace common
+}  // namespace kvstore
 }  // namespace mxnet
+
+#endif  // MXNET_KVSTORE_UTILS_H_
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 8a3bb8d..4db314f 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -623,36 +623,66 @@ void ElementwiseSum(const std::vector<NDArray> &source, NDArray *out, int priori
   // important: callback must always capture by value
   NDArray ret = *out;
 
-  switch (out->ctx().dev_mask()) {
-    case cpu::kDevMask: {
-      Engine::Get()->PushSync([source, ret](RunContext ctx) {
-          std::vector<TBlob> source_tblob(source.size());
-          for (size_t i = 0; i < source.size(); ++i) {
-            source_tblob[i] = source[i].data();
-          }
-          TBlob tmp = ret.data();
-          ndarray::ElementwiseSum<cpu>(source_tblob, &tmp, ctx);
-        }, out->ctx(), const_vars, {ret.var()},
-        FnProperty::kNormal, priority, PROFILER_MESSAGE_FUNCNAME);
-      break;
+  const NDArrayStorageType stype = ret.storage_type();
+
+  if (stype == kDefaultStorage) {
+    switch (out->ctx().dev_mask()) {
+      case cpu::kDevMask: {
+        Engine::Get()->PushSync([source, ret](RunContext ctx) {
+            std::vector<TBlob> source_tblob(source.size());
+            for (size_t i = 0; i < source.size(); ++i) {
+              source_tblob[i] = source[i].data();
+            }
+            TBlob tmp = ret.data();
+            ndarray::ElementwiseSum<cpu>(source_tblob, &tmp, ctx);
+          }, out->ctx(), const_vars, {ret.var()},
+          FnProperty::kNormal, priority, PROFILER_MESSAGE_FUNCNAME);
+        break;
+      }
+#if MXNET_USE_CUDA
+      case gpu::kDevMask: {
+        Engine::Get()->PushSync([source, ret](RunContext ctx) {
+            std::vector<TBlob> source_tblob(source.size());
+            for (size_t i = 0; i < source.size(); ++i) {
+              source_tblob[i] = source[i].data();
+            }
+            TBlob tmp = ret.data();
+            ndarray::ElementwiseSum<gpu>(source_tblob, &tmp, ctx);
+            // Wait GPU kernel to complete
+            ctx.get_stream<gpu>()->Wait();
+          }, out->ctx(), const_vars, {ret.var()},
+          FnProperty::kNormal, priority, PROFILER_MESSAGE("DenseElementwiseSum"));
+        break;
+      }
+#endif
+      default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
     }
+  } else if (stype == kRowSparseStorage) {
+    Resource rsc = ResourceManager::Get()->Request(ret.ctx(),
+      ResourceRequest(ResourceRequest::kTempSpace));
+
+    Engine::Get()->PushSync(
+      [source, ret, rsc](RunContext rctx) {
+        NDArray result = ret;
+        switch (ret.ctx().dev_mask()) {
+          case cpu::kDevMask: {
+            mxnet::ndarray::ElementwiseSum(rctx.get_stream<cpu>(), rsc, source, &result);
+            break;
+          }
 #if MXNET_USE_CUDA
-    case gpu::kDevMask: {
-      Engine::Get()->PushSync([source, ret](RunContext ctx) {
-          std::vector<TBlob> source_tblob(source.size());
-          for (size_t i = 0; i < source.size(); ++i) {
-            source_tblob[i] = source[i].data();
+          case gpu::kDevMask: {
+            mxnet::ndarray::ElementwiseSum(rctx.get_stream<gpu>(), rsc, source, &result);
+            // wait for GPU operations to complete
+            rctx.get_stream<gpu>()->Wait();
+            break;
           }
-          TBlob tmp = ret.data();
-          ndarray::ElementwiseSum<gpu>(source_tblob, &tmp, ctx);
-          // Wait GPU kernel to complete
-          ctx.get_stream<gpu>()->Wait();
-        }, out->ctx(), const_vars, {ret.var()},
-        FnProperty::kNormal, priority, PROFILER_MESSAGE_FUNCNAME);
-      break;
-    }
 #endif
-    default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+          default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+        }
+      }, ret.ctx(), const_vars, {ret.var(), rsc.var},
+    FnProperty::kNormal, priority, PROFILER_MESSAGE("RowSparseElementwiseSum"));
+  } else {
+    LOG(FATAL) << "Not implemented for storage_type " << common::stype_string(stype);
   }
 }
 
diff --git a/tests/python/gpu/test_kvstore_gpu.py b/tests/python/gpu/test_kvstore_gpu.py
index 20528be..3249c98 100644
--- a/tests/python/gpu/test_kvstore_gpu.py
+++ b/tests/python/gpu/test_kvstore_gpu.py
@@ -26,9 +26,9 @@ keys = [5, 7, 11]
 str_keys = ['b', 'c', 'd']
 
 
-def init_kv_with_str(stype='default'):
+def init_kv_with_str(stype='default', kv_type='local'):
     """init kv """
-    kv = mx.kv.create()
+    kv = mx.kv.create(kv_type)
     # single
     kv.init('a', mx.nd.zeros(shape, stype=stype))
     # list
@@ -36,34 +36,54 @@ def init_kv_with_str(stype='default'):
     return kv
 
 
-def test_row_sparse_pull():
-    kv = init_kv_with_str('row_sparse')
-    kv.init('e', mx.nd.ones(shape).tostype('row_sparse'))
+def test_rsp_push_pull():
+    def check_rsp_push_pull(kv_type, is_push_cpu=True):
+        kv = init_kv_with_str('row_sparse', kv_type)
+        kv.init('e', mx.nd.ones(shape).tostype('row_sparse'))
+        push_ctxs = [mx.cpu(i) if is_push_cpu else mx.gpu(i) for i in range(2)]
+        kv.push('e', [mx.nd.ones(shape, ctx=context).tostype('row_sparse') for context in push_ctxs])
 
-    def check_row_sparse_pull(kv, count, ctx=default_context()):
-        num_rows = shape[0]
-        vals = []
-        row_ids = []
-        all_row_ids = np.arange(num_rows)
-        for i in range(count):
-            vals.append(mx.nd.zeros(shape, ctx=ctx).tostype('row_sparse'))
-            row_id = np.random.randint(num_rows, size=num_rows)
-            row_ids.append(mx.nd.array(row_id, dtype='int64'))
-        row_ids_to_pull = row_ids[0] if len(row_ids) == 1 else row_ids
-        vals_to_pull = vals[0] if len(vals) == 1 else vals
+        def check_rsp_pull(kv, count, ctxs, is_same_rowid=False, use_slice=False):
+            num_rows = shape[0]
+            row_ids = []
+            all_row_ids = np.arange(num_rows)
+            vals = [mx.nd.sparse.zeros(shape=shape, ctx=ctxs[i], stype='row_sparse') for i in range(count)]
+            if is_same_rowid:
+                row_id = np.random.randint(num_rows, size=num_rows)
+                row_ids = [mx.nd.array(row_id, dtype='int64')] * count
+            elif use_slice:
+                total_row_ids = mx.nd.array(np.random.randint(num_rows, size=count*num_rows), dtype='int64')
+                row_ids = [total_row_ids[i*num_rows : (i+1)*num_rows] for i in range(count)]
+            else:
+                for i in range(count):
+                    row_id = np.random.randint(num_rows, size=num_rows)
+                    row_ids.append(mx.nd.array(row_id, dtype='int64'))
+            row_ids_to_pull = row_ids[0] if (len(row_ids) == 1 or is_same_rowid) else row_ids
+            vals_to_pull = vals[0] if len(vals) == 1 else vals
 
-        kv.row_sparse_pull('e', out=vals_to_pull, row_ids=row_ids_to_pull)
-        for val, row_id in zip(vals, row_ids):
-            retained = val.asnumpy()
-            excluded_row_ids = np.setdiff1d(all_row_ids, row_id.asnumpy())
-            for row in range(num_rows):
-                expected_val = np.zeros_like(retained[row])
-                expected_val += 0 if row in excluded_row_ids else 1
-                assert_almost_equal(retained[row], expected_val)
+            kv.row_sparse_pull('e', out=vals_to_pull, row_ids=row_ids_to_pull)
+            for val, row_id in zip(vals, row_ids):
+                retained = val.asnumpy()
+                excluded_row_ids = np.setdiff1d(all_row_ids, row_id.asnumpy())
+                for row in range(num_rows):
+                    expected_val = np.zeros_like(retained[row])
+                    expected_val += 0 if row in excluded_row_ids else 2
+                    assert_almost_equal(retained[row], expected_val)
 
-    check_row_sparse_pull(kv, 1, mx.gpu(0))
-    check_row_sparse_pull(kv, 4, mx.gpu(0))
+        check_rsp_pull(kv, 1, [mx.gpu(0)])
+        check_rsp_pull(kv, 1, [mx.cpu(0)])
+        check_rsp_pull(kv, 4, [mx.gpu(i//2) for i in range(4)])
+        check_rsp_pull(kv, 4, [mx.gpu(i//2) for i in range(4)], is_same_rowid=True)
+        check_rsp_pull(kv, 4, [mx.cpu(i) for i in range(4)])
+        check_rsp_pull(kv, 4, [mx.cpu(i) for i in range(4)], is_same_rowid=True)
+        check_rsp_pull(kv, 4, [mx.gpu(i//2) for i in range(4)], use_slice=True) 
+        check_rsp_pull(kv, 4, [mx.cpu(i) for i in range(4)], use_slice=True)
+
+    # test fails intermittently. temporarily disabled till it gets fixed. tracked at https://github.com/apache/incubator-mxnet/issues/9384
+    # check_rsp_push_pull('local')
+    check_rsp_push_pull('device')
+    check_rsp_push_pull('device', is_push_cpu=False)
 
 
 if __name__ == '__main__':
-    test_row_sparse_pull()
+    test_rsp_push_pull()

-- 
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].