Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/03/05 13:54:52 UTC
[incubator-mxnet] branch master updated: Non-blocking row_sparse_pull. Fix incorrect indices generated by device kvstore.row_sparse_pull (#9887)
This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 02dd89a Non-blocking row_sparse_pull. Fix incorrect indices generated by device kvstore.row_sparse_pull (#9887)
02dd89a is described below
commit 02dd89a68f659c2a9b0bff62c54c50dff1151f6b
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Mon Mar 5 21:54:48 2018 +0800
Non-blocking row_sparse_pull. Fix incorrect indices generated by device kvstore.row_sparse_pull (#9887)
* nonblocking Kvstore (#195)
* draft
* rm use_copy. fix dist kvstore. TODO: fix dtype
* fix dtype, shape
* remove reshape
* cleanup
* fix compilation
* rsp draft
* update param name
* doc update and small refactoring
* minor updates
* enhance test case with 2-D rowids
* update gpu tests
* rewrite gpu unique kernels
* update gpu tests
* update reshape test
* fix lint
* update test for py3
---
python/mxnet/kvstore.py | 2 +-
src/kvstore/comm.h | 196 ++++++++++------------------------
src/kvstore/kvstore_dist.h | 34 +++---
src/kvstore/kvstore_local.h | 58 ++++++----
src/kvstore/kvstore_utils.cc | 17 +--
src/kvstore/kvstore_utils.cu | 96 +++++++++--------
src/kvstore/kvstore_utils.h | 9 +-
tests/nightly/dist_sync_kvstore.py | 4 +-
tests/python/gpu/test_kvstore_gpu.py | 18 +++-
tests/python/unittest/test_kvstore.py | 2 +-
10 files changed, 196 insertions(+), 240 deletions(-)
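
For context, a minimal sketch of the user-facing behavior this commit targets (the shapes and key names below are illustrative, not taken from the patch): row_sparse_pull now accepts row_ids that need not be 1-D, unique, or sorted, and no longer blocks on an internal WaitToRead.

import mxnet as mx
import numpy as np

shape = (8, 2)
kv = mx.kv.create('device')
kv.init('w', mx.nd.ones(shape).tostype('row_sparse'))

# row_ids may be unsorted, contain duplicates, and (after this patch) be 2-D.
row_ids = mx.nd.array(np.array([[3, 1], [3, 0]]), dtype='int64')
out = mx.nd.sparse.zeros('row_sparse', shape)
kv.row_sparse_pull('w', out=out, row_ids=row_ids)
# out.indices holds the sorted unique rows requested: [0, 1, 3]
print(out.indices.asnumpy())
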
diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py
index 890c902..221b94f 100644
--- a/python/mxnet/kvstore.py
+++ b/python/mxnet/kvstore.py
@@ -321,7 +321,7 @@ class KVStore(object):
other pull actions.
row_ids : NDArray or list of NDArray
- The row_ids for which to pull for each value. Each row_id is an 1D NDArray \
+ The row_ids for which to pull for each value. Each row_id is a 1-D NDArray \
whose values don't have to be unique nor sorted.
Examples
diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h
index da2d03d..3085966 100644
--- a/src/kvstore/comm.h
+++ b/src/kvstore/comm.h
@@ -65,14 +65,14 @@ class Comm {
/**
* \brief broadcast src to dst[i] with target row_ids for every i
+ * \param key the identifier key for the stored ndarray
+ * \param src the source row_sparse ndarray to broadcast
* \param dst a list of destination row_sparse NDArray and its target row_ids to broadcast,
- where the row_ids are expected to be unique and sorted
- * \param use_copy if set to true, directly copy src to dst[i] without looking up the
- provided row_ids
+ where the row_ids in row_id.data() are expected to be unique and sorted
+ * \param priority the priority of the operation
*/
virtual void BroadcastRowSparse(int key, const NDArray& src,
const std::vector<std::pair<NDArray*, NDArray>>& dst,
- const bool use_copy,
const int priority) = 0;
/**
@@ -209,7 +209,6 @@ class CommCPU : public Comm {
void BroadcastRowSparse(int key, const NDArray& src,
const std::vector<std::pair<NDArray*, NDArray>>& dst,
- const bool use_copy,
const int priority) override {
using namespace mshadow;
CHECK_EQ(src.storage_type(), kRowSparseStorage)
@@ -219,107 +218,30 @@ class CommCPU : public Comm {
for (size_t i = 0; i < dst.size(); ++i) {
NDArray* out = dst[i].first;
NDArray row_id = dst[i].second;
- if (use_copy) {
- CopyFromTo(src, out, priority);
- } else {
- CHECK_EQ(out->storage_type(), kRowSparseStorage)
- << "BroadcastRowSparse expects row_sparse dst NDArray";
- CHECK_EQ(row_id.ctx().dev_mask(), Context::kCPU)
- << "BroadcastRowSparse with row_indices on gpu context not supported";
- // retain according to unique indices
- const bool use_sparse_retain = (src.shape()[0] != src.storage_shape()[0])
- || (row_id.dtype() != out->aux_type(rowsparse::kIdx))
- || (out->ctx().dev_mask() != Context::kGPU);
- if (use_sparse_retain) { // use sparse_retain op
- const bool is_to_gpu = out->ctx().dev_mask() == Context::kGPU;
- NDArray out_cpu = is_to_gpu? NDArray(kRowSparseStorage, src.shape(),
- src.ctx(), true, src.dtype(), src.aux_types()) : *out;
- Engine::Get()->PushAsync(
- [=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
- const TBlob& indices = row_id.data();
- NDArray temp = out_cpu; // get rid of const qualifier
- op::SparseRetainOpForwardRspImpl<cpu>(rctx.get_stream<cpu>(),
- src, indices, kWriteTo,
- &temp);
- on_complete();
- }, Context::CPU(), {src.var(), row_id.var()}, {out_cpu.var()},
- FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreSparseRetain"));
- if (is_to_gpu) {
- CopyFromTo(out_cpu, out, priority);
- }
- } else { // direct copy rows
- Engine::Get()->PushAsync(
- [=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
- CopyRetainedRowsToGPU(rctx.get_stream<cpu>(), rctx.get_stream<gpu>(),
- src, row_id, out);
- // wait for GPU operations to complete
- rctx.get_stream<gpu>()->Wait();
- on_complete();
- }, out->ctx(), {src.var(), row_id.var()}, {out->var()},
- FnProperty::kCopyToGPU, priority, PROFILER_MESSAGE("KVStoreCopyRetainedRowsToGPU"));
- }
- }
+ CHECK_EQ(out->storage_type(), kRowSparseStorage)
+ << "BroadcastRowSparse expects row_sparse dst NDArray";
+ CHECK_EQ(row_id.ctx().dev_mask(), Context::kCPU)
+ << "BroadcastRowSparse with row_indices on gpu context not supported";
+ // retain according to unique indices
+ const bool is_to_gpu = out->ctx().dev_mask() == Context::kGPU;
+ NDArray retained_cpu = is_to_gpu ? NDArray(kRowSparseStorage, src.shape(),
+ src.ctx(), true, src.dtype(), src.aux_types()) : *out;
+ Engine::Get()->PushAsync(
+ [=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+ const TBlob& indices = row_id.data();
+ NDArray temp = retained_cpu; // get rid of the const qualifier
+ op::SparseRetainOpForwardRspImpl<cpu>(rctx.get_stream<cpu>(),
+ src, indices, kWriteTo,
+ &temp);
+ on_complete();
+ }, Context::CPU(), {src.var(), row_id.var()}, {retained_cpu.var()},
+ FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreSparseRetain"));
+ // if retained_cpu == out, CopyFromTo will ignore the copy operation
+ CopyFromTo(retained_cpu, out, priority);
}
}
private:
- /*!
- * \brief When src is a rsp with full rows,
- * simply copy retained rows directly from cpu to gpu
- * without invoking sparse_retain op.
- */
- void CopyRetainedRowsToGPU(mshadow::Stream<cpu>* cpu_stream,
- mshadow::Stream<gpu>* gpu_stream,
- const NDArray& src,
- const NDArray& indices,
- NDArray* dst) {
-#if MXNET_USE_CUDA == 1
- CHECK_EQ(src.storage_type(), kRowSparseStorage)
- << "CopyRetainedRowsToGPU expects row-sparse src NDArray";
- CHECK_EQ(src.ctx().dev_mask(), Context::kCPU)
- << "CopyRetainedRowsToGPU with src on gpu context not supported";
- CHECK_EQ(src.storage_shape()[0], src.shape()[0])
- << "CopyRetainedRowsToGPU only supports src rsp with full rows";
- CHECK_EQ(indices.storage_type(), kDefaultStorage);
- CHECK_EQ(indices.ctx().dev_mask(), Context::kCPU);
- CHECK_EQ(dst->storage_type(), kRowSparseStorage);
- CHECK_EQ(dst->ctx().dev_mask(), Context::kGPU);
- CHECK_EQ(indices.dtype(), dst->aux_type(rowsparse::kIdx))
- << "CopyRetainedRowsToGPU only supports same data type for idx array and dst aux_data(0)";
- if (!src.storage_initialized() || indices.data().Size() == 0U) {
- op::FillZerosRspImpl(gpu_stream, *dst);
- return;
- }
- using namespace mshadow;
-
- const TBlob& src_data = src.data();
- const TBlob& idx_data = indices.data();
- const size_t row_length = src.shape().ProdShape(1, src.shape().ndim());
- const size_t num_rows_retained = idx_data.Size();
- dst->CheckAndAlloc({Shape1(num_rows_retained)});
- TBlob dst_data = dst->data();
- TBlob dst_idx_data = dst->aux_data(rowsparse::kIdx);
- MSHADOW_TYPE_SWITCH(src.dtype(), DType, {
- MSHADOW_IDX_TYPE_SWITCH(indices.dtype(), IType, {
- // copy idx array
- Tensor<gpu, 1, IType> dst_idx_tensor = dst_idx_data.FlatTo1D<gpu, IType>(gpu_stream);
- const Tensor<cpu, 1, IType> idx_tensor = idx_data.FlatTo1D<cpu, IType>(cpu_stream);
- Copy(dst_idx_tensor, idx_tensor, gpu_stream);
- // copy src data
- const Tensor<cpu, 2, DType> src_data_tensor = src_data.get_with_shape<cpu, 2, DType>(
- Shape2(src_data.shape_[0], row_length), cpu_stream);
- Tensor<gpu, 2, DType> dst_data_tensor = dst_data.get_with_shape<gpu, 2, DType>(
- Shape2(dst_data.shape_[0], row_length), gpu_stream);
- for (size_t i = 0; i < num_rows_retained; ++i) {
- Copy(dst_data_tensor[i], src_data_tensor[idx_tensor[i]], gpu_stream);
- }
- })
- })
-#else
- LOG(FATAL) << "GPU not enabled";
-#endif
- }
-
// reduce sum into val[0]
inline void ReduceSumCPU(const std::vector<NDArray> &in_data) {
MSHADOW_TYPE_SWITCH(in_data[0].dtype(), DType, {
@@ -632,7 +554,6 @@ class CommDevice : public Comm {
void BroadcastRowSparse(int key, const NDArray& src,
const std::vector<std::pair<NDArray*, NDArray>>& dst,
- const bool use_copy,
const int priority) override {
CHECK_EQ(src.storage_type(), kRowSparseStorage)
<< "BroadcastRowSparse expects row-sparse src NDArray";
@@ -640,46 +561,39 @@ class CommDevice : public Comm {
for (size_t i = 0; i < dst.size(); ++i) {
NDArray* out = dst[i].first;
NDArray row_id = dst[i].second;
- if (use_copy) {
- CopyFromTo(src, out, priority);
- } else {
- CHECK_EQ(out->storage_type(), kRowSparseStorage)
- << "BroadcastRowSparse expects row_sparse dst NDArray";
-
- const bool is_diff_ctx = out->ctx() != src.ctx();
- NDArray out_gpu = is_diff_ctx? NDArray(kRowSparseStorage, out->shape(),
- src.ctx(), true, out->dtype(), out->aux_types()) : *out;
-
- CHECK_EQ(row_id.ctx(), src.ctx())
- << "row_id and src are expected to be on the same context";
-
- Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
- NDArray temp = out_gpu;
- const TBlob& indices = row_id.data();
- switch (temp.ctx().dev_mask()) {
- case cpu::kDevMask: {
- mxnet::common::SparseRetainOpForwardRspWrapper<cpu>(rctx.get_stream<cpu>(),
- src, indices, kWriteTo, &temp);
- break;
- }
+ CHECK_EQ(out->storage_type(), kRowSparseStorage)
+ << "BroadcastRowSparse expects row_sparse dst NDArray";
+ CHECK_EQ(row_id.ctx(), src.ctx())
+ << "row_id and src are expected to be on the same context";
+ // retain according to indices
+ const bool is_diff_ctx = out->ctx() != src.ctx();
+ NDArray out_gpu = is_diff_ctx ? NDArray(kRowSparseStorage, out->shape(),
+ src.ctx(), true, out->dtype(), out->aux_types()) : *out;
+ Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+ const TBlob& indices = row_id.data();
+ using namespace mxnet::common;
+ NDArray temp = out_gpu;
+ switch (temp.ctx().dev_mask()) {
+ case cpu::kDevMask: {
+ SparseRetainOpForwardRspWrapper<cpu>(rctx.get_stream<cpu>(),
+ src, indices, kWriteTo, &temp);
+ break;
+ }
#if MXNET_USE_CUDA
- case gpu::kDevMask: {
- mxnet::common::SparseRetainOpForwardRspWrapper<gpu>(rctx.get_stream<gpu>(),
- src, indices, kWriteTo, &temp);
- // wait for GPU operations to complete
- rctx.get_stream<gpu>()->Wait();
- break;
- }
-#endif
- default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+ case gpu::kDevMask: {
+ SparseRetainOpForwardRspWrapper<gpu>(rctx.get_stream<gpu>(),
+ src, indices, kWriteTo, &temp);
+ // wait for GPU operations to complete
+ rctx.get_stream<gpu>()->Wait();
+ break;
}
- on_complete();
- }, out_gpu.ctx(), {src.var(), row_id.var()}, {out_gpu.var()},
- FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreSparseRetain"));
- if (is_diff_ctx) {
- CopyFromTo(out_gpu, out, priority);
- }
- }
+#endif
+ default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+ }
+ on_complete();
+ }, out_gpu.ctx(), {src.var(), row_id.var()}, {out_gpu.var()},
+ FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreSparseRetain"));
+ CopyFromTo(out_gpu, out, priority);
}
}
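
After this change, BroadcastRowSparse is, per destination, a sparse_retain of src by the (unique, sorted) row ids followed by a CopyFromTo that is skipped when source and destination alias. A rough numpy sketch of the retain step, assuming the ids are already unique and sorted:

import numpy as np

def sparse_retain_rsp(src_idx, src_data, row_ids):
    # Keep only the rows of a row_sparse (idx, data) pair whose row id
    # appears in row_ids; row_ids is assumed unique and sorted.
    mask = np.isin(src_idx, row_ids)
    return src_idx[mask], src_data[mask]

idx = np.array([0, 2, 5])
data = np.arange(6).reshape(3, 2)
print(sparse_retain_rsp(idx, data, np.array([2, 3, 5])))
# -> (array([2, 5]), rows 1 and 2 of data)
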
diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h
index e01cc42..7ab5783 100644
--- a/src/kvstore/kvstore_dist.h
+++ b/src/kvstore/kvstore_dist.h
@@ -279,24 +279,20 @@ class KVStoreDist : public KVStoreLocal {
}
auto &target_val_rowids = grouped_val_rowids[i];
const size_t num_vals = target_val_rowids.size();
- size_t num_rows = 0;
- // TODO(haibin) refactor this for loop
for (size_t i = 0; i < num_vals; i++) {
auto &row_id = target_val_rowids[i].second;
- NDArray indices(row_id.shape(), pinned_ctx_, false, mshadow::kInt64);
- CopyFromTo(row_id, &indices, 0);
- Unique(&indices, priority);
- target_val_rowids[i].second = indices;
- num_rows += indices.shape().Size();
- }
- if (num_vals > 1) {
- // TODO(haibin) aggregate over all unique indices
- LOG(FATAL) << "RowSparsePull with multiple values is not implemented yet";
- } else {
- auto& indices = target_val_rowids[0].second;
- PullRowSparse_(key, recv_buf, indices, priority);
- comm_->BroadcastRowSparse(key, recv_buf, grouped_val_rowid, num_vals == 1, priority);
+ target_val_rowids[i].second = Unique(row_id, pinned_ctx_, 0);
}
+ CHECK_EQ(num_vals, 1) << "RowSparsePull with multiple values is not supported yet";
+ NDArray& indices = target_val_rowids[0].second;
+ PullRowSparse_(key, recv_buf, indices, priority);
+ // The recv_buf contains values pulled from remote server with unique indices.
+ // Directly broadcast w/o rowids if num_vals == 1
+ auto get_val = [](const std::pair<NDArray*, NDArray>& p) { return p.first; };
+ std::vector<NDArray*> grouped_val(grouped_val_rowid.size());
+ std::transform(grouped_val_rowid.begin(), grouped_val_rowid.end(),
+ grouped_val.begin(), get_val);
+ comm_->Broadcast(key, recv_buf, grouped_val, priority);
}
}
@@ -462,10 +458,12 @@ class KVStoreDist : public KVStoreLocal {
auto pull_from_servers = [this, key, recv_buf, indices]
(RunContext rctx, Engine::CallbackOnComplete cb) {
// allocate memory for the buffer
- size_t num_rows = indices.shape().Size();
+ CHECK_EQ(indices.dtype(), mshadow::kInt64);
+ const TBlob idx_data = indices.data();
+ size_t num_rows = idx_data.shape_.Size();
recv_buf.CheckAndAlloc({mshadow::Shape1(num_rows)});
real_t* data = recv_buf.data().dptr<real_t>();
- const auto offsets = indices.data().dptr<int64_t>();
+ const auto offsets = idx_data.dptr<int64_t>();
const auto unit_len = recv_buf.shape().ProdShape(1, recv_buf.shape().ndim());
const int64_t size = num_rows * unit_len;
// convert to ps keys in row sparse format
@@ -480,7 +478,7 @@ class KVStoreDist : public KVStoreLocal {
// because after pull is done, the callback function returns and locks are released.
// at this point, later functions may access the indices variable while copy happens
mshadow::Copy(recv_buf.aux_data(kIdx).FlatTo1D<cpu, int64_t>(),
- indices.data().FlatTo1D<cpu, int64_t>());
+ idx_data.FlatTo1D<cpu, int64_t>());
CHECK_NOTNULL(ps_worker_)->ZPull(pskv.keys, vals, &pskv.lens,
static_cast<int>(DataHandleType::kRowSparsePushPull),
[vals, cb]() { delete vals; cb(); });
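
The dist path can now use a plain Broadcast because the single recv_buf pulled from the server already contains exactly the rows for the unique, sorted indices; retaining it again by the same ids would be a no-op. A small numpy check of that claim (all names here are illustrative):

import numpy as np

weight = np.arange(12).reshape(6, 2)
uniq = np.array([1, 4])                    # unique, sorted row ids sent to the server
recv_idx, recv_data = uniq, weight[uniq]   # what PullRowSparse_ fills in

# Retaining recv_buf by the same ids keeps every row, so Broadcast suffices.
mask = np.isin(recv_idx, uniq)
assert np.array_equal(recv_data[mask], recv_data)
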
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index 7b3d6fa..69fb37e 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -35,6 +35,7 @@
#include <algorithm>
#include "./comm.h"
#include "./kvstore_utils.h"
+#include "../ndarray/ndarray_function.h"
namespace mxnet {
namespace kvstore {
@@ -226,12 +227,9 @@ class KVStoreLocal : public KVStore {
const size_t num_vals = target_val_rowids.size();
for (size_t j = 0; j < num_vals; j++) {
auto &row_id = target_val_rowids[j].second;
- NDArray indices(row_id.shape(), local.ctx(), false, mshadow::kInt64);
- CopyFromTo(row_id, &indices, 0);
- Unique(&indices, priority);
- target_val_rowids[j].second = indices;
+ target_val_rowids[j].second = Unique(row_id, local.ctx(), 0);
}
- comm_->BroadcastRowSparse(key, local, grouped_val_rowids[i], false, priority);
+ comm_->BroadcastRowSparse(key, local, grouped_val_rowids[i], priority);
}
}
@@ -354,42 +352,62 @@ class KVStoreLocal : public KVStore {
}
}
- /**
- * \brief sort and get unique values.
+ /*!
+ * \brief Compute the unique values in data and store them in ascending order
+ * in an int64_t row_sparse ndarray on ctx. The operation is async. The result
+ * row_sparse ndarray stores the unique values in out.data(). The aux_data()
+ * contains values that are not necessarily meaningful and should be ignored.
+ * \param data the input data
+ * \param ctx the target context
+ * \param priority the priority of the operation
*/
- void Unique(NDArray *out, int priority) {
- Resource rsc = ResourceManager::Get()->Request(out->ctx(),
+ NDArray Unique(const NDArray &data, Context ctx, int priority) {
+ // create kRowSparseStorage output ndarray
+ const size_t num_elements = data.shape().Size();
+ NDArray out(kRowSparseStorage, mshadow::Shape2(num_elements, 1),
+ ctx, true, mshadow::kInt64);
+ bool diff_ctx = data.ctx() != ctx;
+ NDArray data_in_ctx = diff_ctx ? NDArray(data.shape(), ctx, true, data.dtype()) : data;
+ // if data == data_in_ctx, CopyFromTo is smart enough to skip the copy
+ CopyFromTo(data, &data_in_ctx, priority);
+ Resource rsc = ResourceManager::Get()->Request(out.ctx(),
ResourceRequest(ResourceRequest::kTempSpace));
+ // GPU requires temp resources
+ std::vector<Engine::VarHandle> mutate_vars{out.var()};
+ if (out.ctx().dev_mask() == gpu::kDevMask) mutate_vars.emplace_back(rsc.var);
Engine::Get()->PushAsync(
- [rsc, out](RunContext rctx, Engine::CallbackOnComplete on_complete) {
- NDArray *output = out;
- CHECK_EQ(out->shape().ndim(), 1) << "Unique expects 1D inputs";
- nnvm::dim_t size = out->shape()[0];
- switch (out->ctx().dev_mask()) {
+ [=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+ // copy data.data() to out.data()
+ out.CheckAndAlloc({mshadow::Shape1(num_elements)});
+ TBlob out_data = out.data();
+ switch (out.ctx().dev_mask()) {
case cpu::kDevMask: {
mshadow::Stream<cpu> *s = rctx.get_stream<cpu>();
- UniqueImpl(rsc, s, output, size);
+ ndarray::Copy<cpu, cpu>(data_in_ctx.data(), &out_data,
+ ctx, ctx, rctx);
+ UniqueImpl(rsc, s, out);
break;
}
#if MXNET_USE_CUDA
case gpu::kDevMask: {
mshadow::Stream<gpu> *s = rctx.get_stream<gpu>();
- UniqueImpl(rsc, s, output, size);
+ ndarray::Copy<gpu, gpu>(data_in_ctx.data(), &out_data,
+ ctx, ctx, rctx);
+ UniqueImpl(rsc, s, out);
// wait for GPU operations to complete
s->Wait();
break;
}
#endif
default:
- LOG(FATAL) << "GPU not enabled.";
+ LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
}
on_complete();
- }, out->ctx(), {}, {out->var(), rsc.var},
+ }, out.ctx(), {data_in_ctx.var()}, mutate_vars,
FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreUnique"));
- out->WaitToRead();
+ return out;
}
-
/// reducer and broadcaster
Comm* comm_;
/// pinned context
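
The contract of the rewritten Unique, in numpy terms (an illustrative equivalent of the result carried in the returned ndarray's .data(), not the engine-scheduled implementation):

import numpy as np

def unique_sketch(row_ids):
    # Flatten (row_ids may be 2-D), convert to int64, sort, deduplicate.
    return np.unique(row_ids.reshape(-1).astype(np.int64))

print(unique_sketch(np.array([[3, 1], [3, 0]])))  # -> [0 1 3]
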
diff --git a/src/kvstore/kvstore_utils.cc b/src/kvstore/kvstore_utils.cc
index 9e14d8b..e187b0c 100644
--- a/src/kvstore/kvstore_utils.cc
+++ b/src/kvstore/kvstore_utils.cc
@@ -28,15 +28,18 @@
namespace mxnet {
namespace kvstore {
-
template<>
void UniqueImpl<cpu>(const Resource& rsc, mshadow::Stream<cpu> *s,
- NDArray *out, nnvm::dim_t size) {
- MSHADOW_IDX_TYPE_SWITCH(out->data().type_flag_, IType, {
- IType *dptr = out->data().dptr<IType>();
- common::ParallelSort(dptr, dptr + size, omp_get_max_threads());
- size_t num_unique_idx = std::unique(dptr, dptr + size) - dptr;
- *out = out->Reshape(mshadow::Shape1(num_unique_idx));
+ const NDArray& out) {
+ const size_t num_elements = out.shape().Size();
+ CHECK_EQ(out.storage_type(), kRowSparseStorage) << "row_sparse NDArray is expected";
+ MSHADOW_IDX_TYPE_SWITCH(out.dtype(), IType, {
+ IType *dptr = out.data().dptr<IType>();
+ common::ParallelSort(dptr, dptr + num_elements,
+ engine::OpenMP::Get()->GetRecommendedOMPThreadCount());
+ const size_t num_selected_out = std::unique(dptr, dptr + num_elements) - dptr;
+ // set the shape of data/aux_data according to the number of unique values
+ out.set_aux_shape(rowsparse::kIdx, mshadow::Shape1(num_selected_out));
});
}
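
The CPU path above is a parallel sort followed by an in-place std::unique-style compaction, after which only the aux shape is shrunk. A plain-Python rendering of that compaction, assuming the buffer is already sorted:

def unique_compact(buf):
    # In-place adjacent deduplication over a sorted list; returns the
    # number of unique values kept at the front, like std::unique.
    if not buf:
        return 0
    n = 1
    for x in buf[1:]:
        if x != buf[n - 1]:
            buf[n] = x
            n += 1
    return n

buf = sorted([3, 1, 3, 0])
n = unique_compact(buf)
print(buf[:n])  # -> [0, 1, 3]
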
diff --git a/src/kvstore/kvstore_utils.cu b/src/kvstore/kvstore_utils.cu
index 00f316f..438fe29 100644
--- a/src/kvstore/kvstore_utils.cu
+++ b/src/kvstore/kvstore_utils.cu
@@ -40,63 +40,73 @@
namespace mxnet {
namespace kvstore {
-
template<typename IType>
size_t UniqueImplGPU(const Resource& rsc, mshadow::Stream<gpu> *s,
- IType *dptr, nnvm::dim_t size) {
-#ifndef SORT_WITH_THRUST
+ IType *dptr, const size_t size) {
+ // estimate unique temp space. The first sizeof(size_t) bytes of the workspace
+ // are reserved to store the number of unique values selected
+ const size_t num_selected_bytes = sizeof(size_t);
+ size_t unique_temp_bytes = 0;
+ size_t *null_ptr = nullptr;
+ size_t *null_dptr = nullptr;
+ cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
+ cub::DeviceSelect::Unique(NULL, unique_temp_bytes, null_dptr, null_dptr,
+ null_ptr, size, stream);
+ // estimate sort temp space
+ const size_t sort_output_bytes = size * sizeof(IType);
size_t sort_temp_bytes = 0;
- cub::DeviceRadixSort::SortKeys(NULL, sort_temp_bytes,
- dptr, dptr, size, 0, sizeof(IType)*8, mshadow::Stream<gpu>::GetStream(s));
- mshadow::Tensor<gpu, 1, char> sort_space = rsc
- .get_space_typed<gpu, 1, char>(
- mshadow::Shape1(sort_temp_bytes), s);
- void *sort_temp_storage = static_cast<void*>(sort_space.dptr_);
- cub::DeviceRadixSort::SortKeys(sort_temp_storage, sort_temp_bytes,
- dptr, dptr, size, 0, sizeof(IType)*8, mshadow::Stream<gpu>::GetStream(s));
+#ifndef SORT_WITH_THRUST
+ // The least-significant bit index (inclusive) needed for key comparison
+ const int begin_bit = 0;
+ // The most-significant bit index (exclusive) needed for key comparison
+ const int end_bit = sizeof(IType) * 8;
+ cub::DeviceRadixSort::SortKeys(NULL, sort_temp_bytes, null_dptr, null_dptr,
+ size, begin_bit, end_bit, stream);
#else
- thrust::sort(thrust::cuda::par.on(mshadow::Stream<gpu>::GetStream(s)),
- dptr, dptr + size, thrust::greater<IType>());
+ // sort_temp_bytes remains 0 because thrust requests memory by itself
#endif
- size_t unique_temp_bytes = 0;
- mshadow::Tensor<gpu, 1, char> dummy_space = rsc
- .get_space_typed<gpu, 1, char>(
- mshadow::Shape1(sizeof(size_t)), s);
- size_t *dummy_ptr = reinterpret_cast<size_t*>(dummy_space.dptr_);
- cub::DeviceSelect::Unique(NULL, unique_temp_bytes, dptr, dptr,
- dummy_ptr, size, mshadow::Stream<gpu>::GetStream(s));
-
- mshadow::Tensor<gpu, 1, char> unique_space = rsc
- .get_space_typed<gpu, 1, char>(
- mshadow::Shape1((unique_temp_bytes + sizeof(size_t) + 7) / 8 * 8), s);
-
- void *unique_temp_storage = static_cast<void*>(
- unique_space.dptr_);
- size_t *d_num_selected_out = reinterpret_cast<size_t*>(
- unique_space.dptr_ + (unique_temp_bytes + 7) / 8 * 8);
-
- cub::DeviceSelect::Unique(unique_temp_storage, unique_temp_bytes, dptr, dptr,
- d_num_selected_out, size, mshadow::Stream<gpu>::GetStream(s));
-
+ // request temp storage
+ const size_t total_workspace = num_selected_bytes + sort_output_bytes +
+ std::max(sort_temp_bytes, unique_temp_bytes);
+ mshadow::Tensor<gpu, 1, char> workspace = rsc
+ .get_space_typed<gpu, 1, char>(mshadow::Shape1(total_workspace), s);
+ // temp space layout: num_selected_ptr, sort_output_ptr, unique/sort temp storage
+ size_t* num_selected_ptr = reinterpret_cast<size_t*>(workspace.dptr_);
+ IType* sort_output_ptr = reinterpret_cast<IType*>(workspace.dptr_ + num_selected_bytes);
+ void *temp_storage = static_cast<void*>(workspace.dptr_ +
+ num_selected_bytes + sort_output_bytes);
+ // execute the sort kernel
+#ifndef SORT_WITH_THRUST
+ cub::DeviceRadixSort::SortKeys(temp_storage, sort_temp_bytes, dptr, sort_output_ptr,
+ size, begin_bit, end_bit, stream);
+#else
+ thrust::sort(thrust::cuda::par.on(stream),
+ dptr, dptr + size, thrust::greater<IType>());
+ CUDA_CALL(cudaMemcpy(sort_output_ptr, dptr, sort_output_bytes,
+ cudaMemcpyDeviceToDevice));
+#endif
+ // execute unique kernel
+ cub::DeviceSelect::Unique(temp_storage, unique_temp_bytes, sort_output_ptr, dptr,
+ num_selected_ptr, size, stream);
+ // retrieve num selected unique values
size_t num_selected_out = 0;
- CUDA_CALL(cudaMemcpy(&num_selected_out, d_num_selected_out, sizeof(size_t),
+ CUDA_CALL(cudaMemcpy(&num_selected_out, num_selected_ptr, num_selected_bytes,
cudaMemcpyDeviceToHost));
return num_selected_out;
}
-/*!
- * \brief sort and get unique values.
- */
template<>
void UniqueImpl<gpu>(const Resource& rsc, mshadow::Stream<gpu> *s,
- NDArray *out, nnvm::dim_t size) {
- MSHADOW_IDX_TYPE_SWITCH(out->data().type_flag_, IType, {
- IType *dptr = out->data().dptr<IType>();
- size_t num_selected_out = UniqueImplGPU(rsc, s, dptr, size);
- *out = out->Reshape(mshadow::Shape1(num_selected_out));
+ const NDArray &out) {
+ const size_t num_elements = out.shape().Size();
+ CHECK_EQ(out.storage_type(), kRowSparseStorage) << "row_sparse NDArray is expected";
+ MSHADOW_IDX_TYPE_SWITCH(out.dtype(), IType, {
+ IType *dptr = out.data().dptr<IType>();
+ size_t num_selected_out = UniqueImplGPU(rsc, s, dptr, num_elements);
+ // set the shape of data/aux_data according to the number of unique values
+ out.set_aux_shape(rowsparse::kIdx, mshadow::Shape1(num_selected_out));
});
}
-
} // namespace kvstore
} // namespace mxnet
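
The rewritten GPU kernel requests a single temp workspace packing three regions: the num-selected counter, a separate sort output buffer (the old code sorted in place, with the same pointer as input and output), and scratch space shared by the sort and unique kernels. A small Python sketch of the offset arithmetic used above (sizes are illustrative):

def workspace_layout(size, itype_bytes, sort_temp, unique_temp,
                     sizeof_size_t=8):
    # Mirrors: num_selected_ptr | sort_output | max(sort, unique) temp storage
    num_selected_bytes = sizeof_size_t
    sort_output_bytes = size * itype_bytes
    total = num_selected_bytes + sort_output_bytes + max(sort_temp, unique_temp)
    offsets = {
        'num_selected_ptr': 0,
        'sort_output_ptr': num_selected_bytes,
        'temp_storage': num_selected_bytes + sort_output_bytes,
    }
    return total, offsets

print(workspace_layout(size=1000, itype_bytes=8,
                       sort_temp=4096, unique_temp=2048))
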
diff --git a/src/kvstore/kvstore_utils.h b/src/kvstore/kvstore_utils.h
index 8255619..ee173b4 100644
--- a/src/kvstore/kvstore_utils.h
+++ b/src/kvstore/kvstore_utils.h
@@ -35,12 +35,15 @@ namespace kvstore {
/*!
- * \brief sort and get unique values.
+ * \brief Compute unique and sorted values in a row_sparse ndarray.
+ * \param rsc Temp resource for computation
+ * \param s Stream
+ * \param out Input and output ndarray. The ndarray stores the
+ * unique elements in out.data().
*/
template<typename xpu>
void UniqueImpl(const Resource& rsc, mshadow::Stream<xpu> *s,
- NDArray *out, nnvm::dim_t size);
-
+ const NDArray& out);
} // namespace kvstore
} // namespace mxnet
diff --git a/tests/nightly/dist_sync_kvstore.py b/tests/nightly/dist_sync_kvstore.py
index df85fe5..3a3c916 100644
--- a/tests/nightly/dist_sync_kvstore.py
+++ b/tests/nightly/dist_sync_kvstore.py
@@ -99,7 +99,7 @@ def test_sync_push_pull():
# select a random subset of rows this worker is interested in
num_rows = shape[0]
row_ids_np = np.random.randint(num_rows, size=num_rows)
- row_ids = mx.nd.array(row_ids_np, dtype='int64')
+ row_ids = mx.nd.array(row_ids_np).reshape((num_rows//2, 2))
# perform pull
val = mx.nd.zeros(shape, stype='row_sparse')
kv.row_sparse_pull('9', out=val, row_ids=row_ids)
@@ -170,7 +170,7 @@ def test_sync_push_pull():
rnd.seed(my_rank)
num_rows = big_shape[0]
row_ids_np = np.random.randint(num_rows, size=num_rows)
- row_ids = mx.nd.array(row_ids_np)
+ row_ids = mx.nd.array(row_ids_np).reshape((num_rows//2, 2))
# perform pull
val = mx.nd.zeros(big_shape, stype='row_sparse')
kv.row_sparse_pull('100', out=val, row_ids=row_ids)
diff --git a/tests/python/gpu/test_kvstore_gpu.py b/tests/python/gpu/test_kvstore_gpu.py
index 5fd3097..1fc3a4d 100644
--- a/tests/python/gpu/test_kvstore_gpu.py
+++ b/tests/python/gpu/test_kvstore_gpu.py
@@ -57,14 +57,14 @@ def test_rsp_push_pull():
vals = [mx.nd.sparse.zeros(shape=shape, ctx=ctxs[i], stype='row_sparse') for i in range(count)]
if is_same_rowid:
row_id = np.random.randint(num_rows, size=num_rows)
- row_ids = [mx.nd.array(row_id, dtype='int64')] * count
+ row_ids = [mx.nd.array(row_id)] * count
elif use_slice:
- total_row_ids = mx.nd.array(np.random.randint(num_rows, size=count*num_rows), dtype='int64')
+ total_row_ids = mx.nd.array(np.random.randint(num_rows, size=count*num_rows))
row_ids = [total_row_ids[i*num_rows : (i+1)*num_rows] for i in range(count)]
else:
for i in range(count):
row_id = np.random.randint(num_rows, size=num_rows)
- row_ids.append(mx.nd.array(row_id, dtype='int64'))
+ row_ids.append(mx.nd.array(row_id))
row_ids_to_pull = row_ids[0] if (len(row_ids) == 1 or is_same_rowid) else row_ids
vals_to_pull = vals[0] if len(vals) == 1 else vals
@@ -91,6 +91,16 @@ def test_rsp_push_pull():
check_rsp_push_pull('device')
check_rsp_push_pull('device', is_push_cpu=False)
+def test_rsp_push_pull_large_rowid():
+ num_rows = 793470
+ val = mx.nd.ones((num_rows, 1)).tostype('row_sparse').copyto(mx.gpu())
+ kv = mx.kv.create('device')
+ kv.init('a', val)
+ out = mx.nd.zeros((num_rows, 1), stype='row_sparse').copyto(mx.gpu())
+ kv.push('a', val)
+ kv.row_sparse_pull('a', out=out, row_ids=mx.nd.arange(0, num_rows, dtype='int64'))
+ assert(out.indices.shape[0] == num_rows)
if __name__ == '__main__':
- test_rsp_push_pull()
+ import nose
+ nose.runmodule()
diff --git a/tests/python/unittest/test_kvstore.py b/tests/python/unittest/test_kvstore.py
index 6bab06c..c56046c 100644
--- a/tests/python/unittest/test_kvstore.py
+++ b/tests/python/unittest/test_kvstore.py
@@ -76,7 +76,7 @@ def test_row_sparse_pull():
for i in range(count):
vals.append(mx.nd.zeros(shape).tostype('row_sparse'))
row_id = np.random.randint(num_rows, size=num_rows)
- row_ids.append(mx.nd.array(row_id))
+ row_ids.append(mx.nd.array(row_id).reshape((2, num_rows//2)))
row_ids_to_pull = row_ids[0] if len(row_ids) == 1 else row_ids
vals_to_pull = vals[0] if len(vals) == 1 else vals
--