You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/01/15 00:08:24 UTC
[incubator-mxnet] branch master updated: rsp push and rsp pull for
comm device, used in kvstore('device') (#8732)
This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 786e376 rsp push and rsp pull for comm device, used in kvstore('device') (#8732)
786e376 is described below
commit 786e376651c7f6f9b05b7758d091b22a7a72ef55
Author: Ziyue Huang <zy...@gmail.com>
AuthorDate: Mon Jan 15 08:08:18 2018 +0800
rsp push and rsp pull for comm device, used in kvstore('device') (#8732)
* comm device for rsp push and pull
* update
* update test
* optimization for same row_ids
* add stream->wait
* remove using space
* fix race of rsc and extend ElementwiseSum to rsp cases
* add log fatal in ElementwiseSum
* direct copy rows if full rsp and put all outputs on ctx of src
* trigger
* fix
* simplify copy
* move check same rowids to utils and add test for same rowids case
* remove direct copy row by row
* fix checkSameRowid
* gpu unique impl draft
* unique
* update
* fix windows build
* trigger windows build
* support single rowid with multiple vals
* address comments
* check same row_ids and copy in frontend
* revise names and disable test for local kvstore
---
python/mxnet/kvstore.py | 20 ++++-
src/common/utils.cc | 10 +++
src/common/utils.cu | 10 +++
src/common/utils.h | 11 +++
src/kvstore/comm.h | 130 +++++++++++++++++++++++--------
src/kvstore/kvstore_local.h | 51 +++++++-----
src/{common => kvstore}/utils.cc | 24 +++---
src/kvstore/utils.cu | 102 ++++++++++++++++++++++++
src/{common/utils.cc => kvstore/utils.h} | 36 +++++----
src/ndarray/ndarray.cc | 82 ++++++++++++-------
tests/python/gpu/test_kvstore_gpu.py | 74 +++++++++++-------
11 files changed, 414 insertions(+), 136 deletions(-)
diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py
index b2a4bea..890c902 100644
--- a/python/mxnet/kvstore.py
+++ b/python/mxnet/kvstore.py
@@ -298,7 +298,8 @@ class KVStore(object):
def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
""" Pulls a single RowSparseNDArray value or a sequence of RowSparseNDArray values \
- from the store with specified row_ids.
+ from the store with specified row_ids. When there is only one row_id, KVStoreRowSparsePull \
+ is invoked just once and the result is broadcast to all the rest of outputs.
`row_sparse_pull` is executed asynchronously after all previous
`pull`/`row_sparse_pull` calls and the last `push` call for the
@@ -349,7 +350,17 @@ class KVStore(object):
"""
assert(out is not None)
assert(row_ids is not None)
- ckeys, cvals, use_str_keys = _ctype_key_value(key, out)
+ if isinstance(row_ids, NDArray):
+ row_ids = [row_ids]
+ assert(isinstance(row_ids, list)), \
+ "row_ids should be NDArray or list of NDArray"
+ first_out = out
+ # whether row_ids are the same
+ single_rowid = False
+ if len(row_ids) == 1 and isinstance(out, list):
+ single_rowid = True
+ first_out = [out[0]]
+ ckeys, cvals, use_str_keys = _ctype_key_value(key, first_out)
_, crow_ids, _ = _ctype_key_value(key, row_ids)
assert(len(crow_ids) == len(cvals)), \
"the number of row_ids doesn't match the number of values"
@@ -359,6 +370,11 @@ class KVStore(object):
else:
check_call(_LIB.MXKVStorePullRowSparse(
self.handle, mx_uint(len(ckeys)), ckeys, cvals, crow_ids, ctypes.c_int(priority)))
+ # the result can be copied to other devices without invoking row_sparse_pull
+ # if the indices are the same
+ if single_rowid:
+ for out_i in out[1:]:
+ out[0].copyto(out_i)
def set_gradient_compression(self, compression_params):
""" Specifies type of low-bit quantization for gradient compression \
diff --git a/src/common/utils.cc b/src/common/utils.cc
index 784fcf8..9fe46d9 100644
--- a/src/common/utils.cc
+++ b/src/common/utils.cc
@@ -24,6 +24,7 @@
#include "./utils.h"
#include "../operator/tensor/cast_storage-inl.h"
+#include "../operator/tensor/sparse_retain-inl.h"
namespace mxnet {
namespace common {
@@ -35,6 +36,15 @@ void CheckFormatWrapper<cpu>(const RunContext &rctx, const NDArray &input,
}
+/*!
+ * \brief cpu specialization of SparseRetainOpForwardRspWrapper (declared in
+ *        utils.h); forwards every argument unchanged to
+ *        mxnet::op::SparseRetainOpForwardRspImpl<cpu>.
+ */
template<>
+void SparseRetainOpForwardRspWrapper<cpu>(mshadow::Stream<cpu> *s,
+ const NDArray& input_nd,
+ const TBlob& idx_data,
+ const OpReqType req,
+ NDArray* output_nd) {
+ mxnet::op::SparseRetainOpForwardRspImpl<cpu>(s, input_nd, idx_data, req, output_nd);
+}
+
+template<>
void CastStorageDispatch<cpu>(const OpContext& ctx,
const NDArray& input,
const NDArray& output) {
diff --git a/src/common/utils.cu b/src/common/utils.cu
index c6e2bf8..0937d7a 100644
--- a/src/common/utils.cu
+++ b/src/common/utils.cu
@@ -24,6 +24,7 @@
#include "./utils.h"
#include "../operator/tensor/cast_storage-inl.h"
+#include "../operator/tensor/sparse_retain-inl.h"
namespace mxnet {
namespace common {
@@ -35,6 +36,15 @@ void CheckFormatWrapper<gpu>(const RunContext &rctx, const NDArray &input,
}
+/*!
+ * \brief gpu specialization of SparseRetainOpForwardRspWrapper (declared in
+ *        utils.h); forwards every argument unchanged to
+ *        mxnet::op::SparseRetainOpForwardRspImpl<gpu>.
+ */
template<>
+void SparseRetainOpForwardRspWrapper<gpu>(mshadow::Stream<gpu> *s,
+ const NDArray& input_nd,
+ const TBlob& idx_data,
+ const OpReqType req,
+ NDArray* output_nd) {
+ mxnet::op::SparseRetainOpForwardRspImpl<gpu>(s, input_nd, idx_data, req, output_nd);
+}
+
+template<>
void CastStorageDispatch<gpu>(const OpContext& ctx,
const NDArray& input,
const NDArray& output) {
diff --git a/src/common/utils.h b/src/common/utils.h
index 6f7e452..4bb8024 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -213,7 +213,18 @@ void CheckFormatImpl(const RunContext &rctx, const NDArray &input,
}
}
+/*! \brief Pick rows specified by user input index array from a row sparse ndarray
+ * and save them in the output sparse ndarray.
+ *
+ * Specialized for cpu in utils.cc and for gpu in utils.cu; both specializations
+ * forward directly to mxnet::op::SparseRetainOpForwardRspImpl.
+ * \param s         stream the work is issued on
+ * \param input_nd  row_sparse array to pick rows from
+ * \param idx_data  indices of the rows to retain
+ * \param req       write request type applied to the output
+ * \param output_nd array that receives the retained rows
+ */
+template<typename xpu>
+void SparseRetainOpForwardRspWrapper(mshadow::Stream<xpu> *s,
+ const NDArray& input_nd,
+ const TBlob& idx_data,
+ const OpReqType req,
+ NDArray* output_nd);
+/*! \brief Casts tensor storage type to the new type.
+ */
template<typename xpu>
void CastStorageDispatch(const OpContext& ctx, const NDArray& input, const NDArray& output);
diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h
index 5429df7..d41fa64 100644
--- a/src/kvstore/comm.h
+++ b/src/kvstore/comm.h
@@ -34,6 +34,7 @@
#include "gradient_compression.h"
#include "../ndarray/ndarray_function.h"
#include "../operator/tensor/sparse_retain-inl.h"
+#include "./utils.h"
namespace mxnet {
namespace kvstore {
/**
@@ -176,17 +177,17 @@ class CommCPU : public Comm {
reduce[i] = buf.copy_buf[i];
const_vars[i] = reduce[i].var();
}
- auto result = buf.merged;
+ NDArray result = buf.merged;
+ Resource rsc = ResourceManager::Get()->Request(result.ctx(),
+ ResourceRequest(ResourceRequest::kTempSpace));
Engine::Get()->PushAsync(
- [reduce, result, this](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+ [reduce, result, rsc, this](RunContext rctx, Engine::CallbackOnComplete on_complete) {
NDArray out = result;
- Resource rsc = ResourceManager::Get()->Request(rctx.ctx,
- ResourceRequest(ResourceRequest::kTempSpace));
is_serial_push_?
ReduceSumCPUExSerial(reduce, &out)
: mxnet::ndarray::ElementwiseSum(rctx.get_stream<cpu>(), rsc, reduce, &out);
on_complete();
- }, Context::CPU(), const_vars, {result.var()},
+ }, Context::CPU(), const_vars, {result.var(), rsc.var},
FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce"));
}
@@ -491,11 +492,7 @@ class CommDevice : public Comm {
void Init(int key, const NDArrayStorageType stype, const TShape& shape,
int dtype = mshadow::kFloat32) override {
- if (stype == kDefaultStorage) {
- sorted_key_attrs_.push_back(std::make_tuple(key, shape, dtype));
- } else {
- LOG(FATAL) << "storage type " << stype << " not implemented for device yet";
- }
+ sorted_key_attrs_.emplace_back(key, shape, dtype, stype);
}
void InitBuffersAndComm(const std::vector<NDArray>& src) {
@@ -528,26 +525,42 @@ class CommDevice : public Comm {
InitBuffersAndComm(src);
auto& buf = merge_buf_[key];
std::vector<NDArray> reduce(src.size());
- CopyFromTo(src[0], &(buf.merged), priority);
- reduce[0] = buf.merged;
- if (buf.copy_buf.empty()) {
- // TODO(mli) this results in large device memory usage for huge ndarray,
- // such as the largest fullc in VGG. consider to do segment reduce with
- // NDArray.Slice or gpu direct memory access. for the latter, we need to
- // remove some ctx check, and also it reduces 20% perf
- buf.copy_buf.resize(src.size()-1);
+ const NDArrayStorageType stype = buf.merged.storage_type();
+ if (stype == kDefaultStorage) {
+ CopyFromTo(src[0], &(buf.merged), priority);
+ reduce[0] = buf.merged;
+
+ if (buf.copy_buf.empty()) {
+ // TODO(mli) this results in large device memory usage for huge ndarray,
+ // such as the largest fullc in VGG. consider to do segment reduce with
+ // NDArray.Slice or gpu direct memory access. for the latter, we need to
+ // remove some ctx check, and also it reduces 20% perf
+ buf.copy_buf.resize(src.size()-1);
+ for (size_t i = 0; i < src.size()-1; ++i) {
+ buf.copy_buf[i] = NDArray(
+ buf.merged.shape(), buf.merged.ctx(), false, buf.merged.dtype());
+ }
+ }
for (size_t i = 0; i < src.size()-1; ++i) {
- buf.copy_buf[i] = NDArray(
- buf.merged.shape(), buf.merged.ctx(), false, buf.merged.dtype());
+ CopyFromTo(src[i+1], &(buf.copy_buf[i]), priority);
+ reduce[i+1] = buf.copy_buf[i];
+ }
+ } else {
+ if (buf.copy_buf.empty()) {
+ buf.copy_buf.resize(src.size());
+ for (size_t j = 0; j < src.size(); ++j) {
+ buf.copy_buf[j] = NDArray(
+ buf.merged.storage_type(), buf.merged.shape(), buf.merged.ctx(),
+ true, buf.merged.dtype());
+ }
+ }
+ for (size_t i = 0; i < src.size(); ++i) {
+ CopyFromTo(src[i], &(buf.copy_buf[i]), priority);
+ reduce[i] = buf.copy_buf[i];
}
}
- for (size_t i = 0; i < src.size()-1; ++i) {
- CopyFromTo(src[i+1], &(buf.copy_buf[i]), priority);
- reduce[i+1] = buf.copy_buf[i];
- }
-
- ElementwiseSum(reduce, &buf.merged);
+ ElementwiseSum(reduce, &buf.merged, priority);
return buf.merged;
}
@@ -621,7 +634,53 @@ class CommDevice : public Comm {
const std::vector<std::pair<NDArray*, NDArray>>& dst,
const bool use_copy,
const int priority) override {
- LOG(FATAL) << "Not implemented yet";
+ CHECK_EQ(src.storage_type(), kRowSparseStorage)
+ << "BroadcastRowSparse expects row-sparse src NDArray";
+
+ for (size_t i = 0; i < dst.size(); ++i) {
+ NDArray* out = dst[i].first;
+ NDArray row_id = dst[i].second;
+ if (use_copy) {
+ CopyFromTo(src, out, priority);
+ } else {
+ CHECK_EQ(out->storage_type(), kRowSparseStorage)
+ << "BroadcastRowSparse expects row_sparse dst NDArray";
+
+ const bool is_diff_ctx = out->ctx() != src.ctx();
+ NDArray out_gpu = is_diff_ctx? NDArray(kRowSparseStorage, out->shape(),
+ src.ctx(), true, out->dtype(), out->aux_types()) : *out;
+
+ CHECK_EQ(row_id.ctx(), src.ctx())
+ << "row_id and src are expected to be on the same context";
+
+ Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+ NDArray temp = out_gpu;
+ const TBlob& indices = row_id.data();
+ switch (temp.ctx().dev_mask()) {
+ case cpu::kDevMask: {
+ mxnet::common::SparseRetainOpForwardRspWrapper<cpu>(rctx.get_stream<cpu>(),
+ src, indices, kWriteTo, &temp);
+ break;
+ }
+#if MXNET_USE_CUDA
+ case gpu::kDevMask: {
+ mxnet::common::SparseRetainOpForwardRspWrapper<gpu>(rctx.get_stream<gpu>(),
+ src, indices, kWriteTo, &temp);
+ // wait for GPU operations to complete
+ rctx.get_stream<gpu>()->Wait();
+ break;
+ }
+#endif
+ default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+ }
+ on_complete();
+ }, out_gpu.ctx(), {src.var(), row_id.var()}, {out_gpu.var()},
+ FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreSparseRetain"));
+ if (is_diff_ctx) {
+ CopyFromTo(out_gpu, out, priority);
+ }
+ }
+ }
}
private:
@@ -667,7 +726,7 @@ class CommDevice : public Comm {
#endif
}
- using KeyAttrs = std::tuple<int, TShape, int>;
+ using KeyAttrs = std::tuple<int, TShape, int, NDArrayStorageType>;
// try to allocate buff on device evenly
void InitMergeBuffer(const std::vector<Context>& devs) {
std::sort(sorted_key_attrs_.begin(), sorted_key_attrs_.end(), [](
@@ -680,9 +739,10 @@ class CommDevice : public Comm {
ctx_info[d.dev_id] = std::make_pair(d, 0);
}
for (size_t i = 0; i < sorted_key_attrs_.size(); ++i) {
- int key = std::get<0>(sorted_key_attrs_[i]);
- TShape s = std::get<1>(sorted_key_attrs_[i]);
- int type = std::get<2>(sorted_key_attrs_[i]);
+ const int key = std::get<0>(sorted_key_attrs_[i]);
+ const TShape& shape = std::get<1>(sorted_key_attrs_[i]);
+ const int type = std::get<2>(sorted_key_attrs_[i]);
+ const NDArrayStorageType stype = std::get<3>(sorted_key_attrs_[i]);
auto& buf = merge_buf_[key];
Context ctx;
size_t min_size = std::numeric_limits<size_t>::max();
@@ -693,8 +753,12 @@ class CommDevice : public Comm {
min_size = size;
}
}
- buf.merged = NDArray(s, ctx, false, type);
- ctx_info[ctx.dev_id].second += s.Size();
+ if (stype == kDefaultStorage) {
+ buf.merged = NDArray(shape, ctx, false, type);
+ } else {
+ buf.merged = NDArray(stype, shape, ctx, true, type);
+ }
+ ctx_info[ctx.dev_id].second += shape.Size();
}
inited_ = true;
}
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index 1bb84fd..78b6c8f 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -34,6 +34,7 @@
#include <functional>
#include <algorithm>
#include "./comm.h"
+#include "./utils.h"
namespace mxnet {
namespace kvstore {
@@ -223,12 +224,12 @@ class KVStoreLocal : public KVStore {
<< "PullRowSparse expects row_sparse src NDArray";
auto &target_val_rowids = grouped_val_rowids[i];
const size_t num_vals = target_val_rowids.size();
- for (size_t i = 0; i < num_vals; i++) {
- auto &row_id = target_val_rowids[i].second;
- NDArray indices(row_id.shape(), pinned_ctx_, false, mshadow::kInt64);
+ for (size_t j = 0; j < num_vals; j++) {
+ auto &row_id = target_val_rowids[j].second;
+ NDArray indices(row_id.shape(), local.ctx(), false, mshadow::kInt64);
CopyFromTo(row_id, &indices, 0);
Unique(&indices, priority);
- target_val_rowids[i].second = indices;
+ target_val_rowids[j].second = indices;
}
comm_->BroadcastRowSparse(key, local, grouped_val_rowids[i], false, priority);
}
@@ -354,29 +355,41 @@ class KVStoreLocal : public KVStore {
}
/**
- * \brief sort and get unique values. Output is expected to be on cpu_pinned context
+ * \brief sort and get unique values.
*/
- void Unique(NDArray *out, int priority = 0) {
- CHECK_EQ(out->ctx().dev_mask(), pinned_ctx_.dev_mask())
- << "Unique expects input with `pinned_ctx_`";
+ void Unique(NDArray *out, int priority) {
+ Resource rsc = ResourceManager::Get()->Request(out->ctx(),
+ ResourceRequest(ResourceRequest::kTempSpace));
Engine::Get()->PushAsync(
- [out](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+ [rsc, out](RunContext rctx, Engine::CallbackOnComplete on_complete) {
NDArray *output = out;
CHECK_EQ(out->shape().ndim(), 1) << "Unique expects 1D inputs";
- const auto size = out->shape()[0];
- auto out_data = output->data();
- MSHADOW_IDX_TYPE_SWITCH(out_data.type_flag_, IType, {
- auto dptr = output->data().dptr<IType>();
- common::ParallelSort(dptr, dptr + size, omp_get_max_threads());
- auto num_unique_idx = std::unique(dptr, dptr + size) - dptr;
- *output = output->Reshape(mshadow::Shape1(num_unique_idx));
- });
+ nnvm::dim_t size = out->shape()[0];
+ switch (out->ctx().dev_mask()) {
+ case cpu::kDevMask: {
+ mshadow::Stream<cpu> *s = rctx.get_stream<cpu>();
+ UniqueImpl(rsc, s, output, size);
+ break;
+ }
+ #if MXNET_USE_CUDA
+ case gpu::kDevMask: {
+ mshadow::Stream<gpu> *s = rctx.get_stream<gpu>();
+ UniqueImpl(rsc, s, output, size);
+ // wait for GPU operations to complete
+ s->Wait();
+ break;
+ }
+ #endif
+ default:
+ LOG(FATAL) << "GPU not enabled.";
+ }
on_complete();
- }, pinned_ctx_, {}, {out->var()},
- FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreUnique"));
+ }, out->ctx(), {}, {out->var(), rsc.var},
+ FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreUnique"));
out->WaitToRead();
}
+
/// reducer and broadcaster
Comm* comm_;
/// pinned context
diff --git a/src/common/utils.cc b/src/kvstore/utils.cc
similarity index 64%
copy from src/common/utils.cc
copy to src/kvstore/utils.cc
index 784fcf8..c22553f 100644
--- a/src/common/utils.cc
+++ b/src/kvstore/utils.cc
@@ -23,23 +23,23 @@
*/
#include "./utils.h"
-#include "../operator/tensor/cast_storage-inl.h"
+#include "../common/utils.h"
namespace mxnet {
-namespace common {
+namespace kvstore {
-template<>
-void CheckFormatWrapper<cpu>(const RunContext &rctx, const NDArray &input,
- const TBlob &err_cpu, const bool full_check) {
- CheckFormatImpl<cpu>(rctx, input, err_cpu, full_check);
-}
template<>
-void CastStorageDispatch<cpu>(const OpContext& ctx,
- const NDArray& input,
- const NDArray& output) {
- mxnet::op::CastStorageComputeImpl<cpu>(ctx, input, output);
+void UniqueImpl<cpu>(const Resource& rsc, mshadow::Stream<cpu> *s,
+ NDArray *out, nnvm::dim_t size) {
+ MSHADOW_IDX_TYPE_SWITCH(out->data().type_flag_, IType, {
+ IType *dptr = out->data().dptr<IType>();
+ common::ParallelSort(dptr, dptr + size, omp_get_max_threads());
+ size_t num_unique_idx = std::unique(dptr, dptr + size) - dptr;
+ *out = out->Reshape(mshadow::Shape1(num_unique_idx));
+ });
}
-} // namespace common
+
+} // namespace kvstore
} // namespace mxnet
diff --git a/src/kvstore/utils.cu b/src/kvstore/utils.cu
new file mode 100644
index 0000000..088a49e
--- /dev/null
+++ b/src/kvstore/utils.cu
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file utils.cu
+ * \brief gpu implementation of util functions
+ */
+#if defined(_MSC_VER) && __CUDACC_VER_MAJOR__ == 8 && __CUDACC_VER_BUILD__ != 44
+// Many CUDA 8 compilers other than V8.0.44 crash on Windows
+#pragma warning("Potential crash on CUDA compiler detected. Switching sorting from CUB to Thrust")
+#define SORT_WITH_THRUST
+#include <thrust/device_ptr.h>
+#include <thrust/sort.h>
+#include <thrust/system/cuda/execution_policy.h>
+#else
+#undef SORT_WITH_THRUST
+#endif
+#include "./utils.h"
+#include "../common/utils.h"
+#include <cub/cub.cuh>
+#include <mxnet/resource.h>
+
+namespace mxnet {
+namespace kvstore {
+
+
+/*!
+ * \brief Sort the `size` indices at `dptr` and compact the distinct values
+ *        to the front of the same buffer, entirely on the GPU.
+ *
+ * Both sort paths (CUB and the Windows/CUDA 8 Thrust fallback) produce
+ * ascending order, matching the cpu UniqueImpl (ParallelSort + std::unique);
+ * row_sparse indices must be increasing.
+ *
+ * Scratch memory comes from a single temp-space request laid out as
+ *   [ num_selected (size_t) | sorted copy of the input | CUB scratch ]
+ * CUB's DeviceRadixSort::SortKeys and DeviceSelect::Unique require their
+ * input and output ranges not to overlap, so neither runs in place on `dptr`.
+ *
+ * \param rsc  kTempSpace resource supplying the scratch memory
+ * \param s    GPU stream on which all work is enqueued
+ * \param dptr device buffer holding `size` indices; on return its first N
+ *             entries are the sorted unique values
+ * \param size number of indices at `dptr`
+ * \return N, the number of unique values now at the front of `dptr`;
+ *         the host waits on `s` until this count is available
+ */
+template<typename IType>
+size_t UniqueImplGPU(const Resource& rsc, mshadow::Stream<gpu> *s,
+                     IType *dptr, nnvm::dim_t size) {
+  if (size <= 0) return 0;
+  cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
+  // Round a byte count up to a multiple of 8 so every sub-buffer stays aligned.
+  auto align8 = [](size_t bytes) { return (bytes + 7) / 8 * 8; };
+
+  // Size queries only: with NULL temp storage CUB performs no work.
+  size_t sort_temp_bytes = 0;
+#ifndef SORT_WITH_THRUST
+  cub::DeviceRadixSort::SortKeys(NULL, sort_temp_bytes,
+    dptr, dptr, size, 0, sizeof(IType)*8, stream);
+#endif
+  size_t unique_temp_bytes = 0;
+  cub::DeviceSelect::Unique(NULL, unique_temp_bytes, dptr, dptr,
+    static_cast<size_t*>(NULL), size, stream);
+
+  const size_t sorted_offset = align8(sizeof(size_t));
+  const size_t scratch_offset = sorted_offset + align8(size * sizeof(IType));
+  const size_t scratch_bytes = sort_temp_bytes > unique_temp_bytes ?
+                               sort_temp_bytes : unique_temp_bytes;
+  mshadow::Tensor<gpu, 1, char> space = rsc
+    .get_space_typed<gpu, 1, char>(
+      mshadow::Shape1(scratch_offset + scratch_bytes), s);
+  size_t *d_num_selected_out = reinterpret_cast<size_t*>(space.dptr_);
+  IType *sorted = reinterpret_cast<IType*>(space.dptr_ + sorted_offset);
+  void *scratch = static_cast<void*>(space.dptr_ + scratch_offset);
+
+  // Sort the input ascending into the separate `sorted` buffer.
+#ifndef SORT_WITH_THRUST
+  cub::DeviceRadixSort::SortKeys(scratch, sort_temp_bytes,
+    dptr, sorted, size, 0, sizeof(IType)*8, stream);
+#else
+  thrust::sort(thrust::cuda::par.on(stream), dptr, dptr + size);
+  CUDA_CALL(cudaMemcpyAsync(sorted, dptr, size * sizeof(IType),
+                            cudaMemcpyDeviceToDevice, stream));
+#endif
+  // Write the distinct values of `sorted` back to the front of `dptr`.
+  cub::DeviceSelect::Unique(scratch, unique_temp_bytes, sorted, dptr,
+    d_num_selected_out, size, stream);
+
+  // Copy the count on the same stream, then wait for it on the host.
+  size_t num_selected_out = 0;
+  CUDA_CALL(cudaMemcpyAsync(&num_selected_out, d_num_selected_out,
+                            sizeof(size_t), cudaMemcpyDeviceToHost, stream));
+  CUDA_CALL(cudaStreamSynchronize(stream));
+  return num_selected_out;
+}
+
+/*!
+ * \brief gpu specialization of UniqueImpl: sort and get unique values.
+ *
+ * Dispatches on the index type of `out`, runs UniqueImplGPU in place over the
+ * first `size` elements of `out`, then rebinds `out` to a 1-D reshaped view
+ * whose length is the number of unique values found.
+ * \param rsc  temp-space resource forwarded to UniqueImplGPU for scratch memory
+ * \param s    GPU stream the work is issued on
+ * \param out  index array to de-duplicate; replaced by the reshaped view
+ * \param size number of elements of `out` to process
+ */
+template<>
+void UniqueImpl<gpu>(const Resource& rsc, mshadow::Stream<gpu> *s,
+ NDArray *out, nnvm::dim_t size) {
+ MSHADOW_IDX_TYPE_SWITCH(out->data().type_flag_, IType, {
+ IType *dptr = out->data().dptr<IType>();
+ size_t num_selected_out = UniqueImplGPU(rsc, s, dptr, size);
+ *out = out->Reshape(mshadow::Shape1(num_selected_out));
+ });
+}
+
+
+} // namespace kvstore
+} // namespace mxnet
diff --git a/src/common/utils.cc b/src/kvstore/utils.h
similarity index 57%
copy from src/common/utils.cc
copy to src/kvstore/utils.h
index 784fcf8..7547345 100644
--- a/src/common/utils.cc
+++ b/src/kvstore/utils.h
@@ -18,28 +18,30 @@
*/
/*!
- * \file utils.cc
- * \brief cpu implementation of util functions
+ * \file utils.h
+ * \brief Basic utility functions.
*/
+#ifndef MXNET_KVSTORE_UTILS_H_
+#define MXNET_KVSTORE_UTILS_H_
-#include "./utils.h"
-#include "../operator/tensor/cast_storage-inl.h"
+#include <dmlc/logging.h>
+#include <mxnet/ndarray.h>
+#include <mxnet/resource.h>
+#include <utility>
+#include <vector>
namespace mxnet {
-namespace common {
+namespace kvstore {
-template<>
-void CheckFormatWrapper<cpu>(const RunContext &rctx, const NDArray &input,
- const TBlob &err_cpu, const bool full_check) {
- CheckFormatImpl<cpu>(rctx, input, err_cpu, full_check);
-}
-template<>
-void CastStorageDispatch<cpu>(const OpContext& ctx,
- const NDArray& input,
- const NDArray& output) {
- mxnet::op::CastStorageComputeImpl<cpu>(ctx, input, output);
-}
+/*!
+ * \brief sort and get unique values.
+ *
+ * Sorts the first `size` entries of the index array `out`, removes duplicates,
+ * and rebinds `out` to a 1-D reshaped view holding only the unique values.
+ * Specialized for cpu in utils.cc and for gpu in utils.cu.
+ * \param rsc  temp-space resource; the gpu specialization draws scratch memory
+ *             from it, the cpu specialization does not use it
+ * \param s    stream the work is issued on
+ * \param out  index array, modified in place and then reshaped
+ * \param size number of entries of `out` to process
+ */
+template<typename xpu>
+void UniqueImpl(const Resource& rsc, mshadow::Stream<xpu> *s,
+ NDArray *out, nnvm::dim_t size);
-} // namespace common
+} // namespace kvstore
} // namespace mxnet
+
+#endif // MXNET_KVSTORE_UTILS_H_
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 8a3bb8d..4db314f 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -623,36 +623,66 @@ void ElementwiseSum(const std::vector<NDArray> &source, NDArray *out, int priori
// important: callback must always capture by value
NDArray ret = *out;
- switch (out->ctx().dev_mask()) {
- case cpu::kDevMask: {
- Engine::Get()->PushSync([source, ret](RunContext ctx) {
- std::vector<TBlob> source_tblob(source.size());
- for (size_t i = 0; i < source.size(); ++i) {
- source_tblob[i] = source[i].data();
- }
- TBlob tmp = ret.data();
- ndarray::ElementwiseSum<cpu>(source_tblob, &tmp, ctx);
- }, out->ctx(), const_vars, {ret.var()},
- FnProperty::kNormal, priority, PROFILER_MESSAGE_FUNCNAME);
- break;
+ const NDArrayStorageType stype = ret.storage_type();
+
+ if (stype == kDefaultStorage) {
+ switch (out->ctx().dev_mask()) {
+ case cpu::kDevMask: {
+ Engine::Get()->PushSync([source, ret](RunContext ctx) {
+ std::vector<TBlob> source_tblob(source.size());
+ for (size_t i = 0; i < source.size(); ++i) {
+ source_tblob[i] = source[i].data();
+ }
+ TBlob tmp = ret.data();
+ ndarray::ElementwiseSum<cpu>(source_tblob, &tmp, ctx);
+ }, out->ctx(), const_vars, {ret.var()},
+ FnProperty::kNormal, priority, PROFILER_MESSAGE_FUNCNAME);
+ break;
+ }
+#if MXNET_USE_CUDA
+ case gpu::kDevMask: {
+ Engine::Get()->PushSync([source, ret](RunContext ctx) {
+ std::vector<TBlob> source_tblob(source.size());
+ for (size_t i = 0; i < source.size(); ++i) {
+ source_tblob[i] = source[i].data();
+ }
+ TBlob tmp = ret.data();
+ ndarray::ElementwiseSum<gpu>(source_tblob, &tmp, ctx);
+ // Wait GPU kernel to complete
+ ctx.get_stream<gpu>()->Wait();
+ }, out->ctx(), const_vars, {ret.var()},
+ FnProperty::kNormal, priority, PROFILER_MESSAGE("DenseElementwiseSum"));
+ break;
+ }
+#endif
+ default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
}
+ } else if (stype == kRowSparseStorage) {
+ Resource rsc = ResourceManager::Get()->Request(ret.ctx(),
+ ResourceRequest(ResourceRequest::kTempSpace));
+
+ Engine::Get()->PushSync(
+ [source, ret, rsc](RunContext rctx) {
+ NDArray result = ret;
+ switch (ret.ctx().dev_mask()) {
+ case cpu::kDevMask: {
+ mxnet::ndarray::ElementwiseSum(rctx.get_stream<cpu>(), rsc, source, &result);
+ break;
+ }
#if MXNET_USE_CUDA
- case gpu::kDevMask: {
- Engine::Get()->PushSync([source, ret](RunContext ctx) {
- std::vector<TBlob> source_tblob(source.size());
- for (size_t i = 0; i < source.size(); ++i) {
- source_tblob[i] = source[i].data();
+ case gpu::kDevMask: {
+ mxnet::ndarray::ElementwiseSum(rctx.get_stream<gpu>(), rsc, source, &result);
+ // wait for GPU operations to complete
+ rctx.get_stream<gpu>()->Wait();
+ break;
}
- TBlob tmp = ret.data();
- ndarray::ElementwiseSum<gpu>(source_tblob, &tmp, ctx);
- // Wait GPU kernel to complete
- ctx.get_stream<gpu>()->Wait();
- }, out->ctx(), const_vars, {ret.var()},
- FnProperty::kNormal, priority, PROFILER_MESSAGE_FUNCNAME);
- break;
- }
#endif
- default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+ default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+ }
+ }, ret.ctx(), const_vars, {ret.var(), rsc.var},
+ FnProperty::kNormal, priority, PROFILER_MESSAGE("RowSparseElementwiseSum"));
+ } else {
+ LOG(FATAL) << "Not implemented for storage_type " << common::stype_string(stype);
}
}
diff --git a/tests/python/gpu/test_kvstore_gpu.py b/tests/python/gpu/test_kvstore_gpu.py
index 20528be..3249c98 100644
--- a/tests/python/gpu/test_kvstore_gpu.py
+++ b/tests/python/gpu/test_kvstore_gpu.py
@@ -26,9 +26,9 @@ keys = [5, 7, 11]
str_keys = ['b', 'c', 'd']
-def init_kv_with_str(stype='default'):
+def init_kv_with_str(stype='default', kv_type='local'):
"""init kv """
- kv = mx.kv.create()
+ kv = mx.kv.create(kv_type)
# single
kv.init('a', mx.nd.zeros(shape, stype=stype))
# list
@@ -36,34 +36,54 @@ def init_kv_with_str(stype='default'):
return kv
-def test_row_sparse_pull():
- kv = init_kv_with_str('row_sparse')
- kv.init('e', mx.nd.ones(shape).tostype('row_sparse'))
+def test_rsp_push_pull():
+ def check_rsp_push_pull(kv_type, is_push_cpu=True):
+ kv = init_kv_with_str('row_sparse', kv_type)
+ kv.init('e', mx.nd.ones(shape).tostype('row_sparse'))
+ push_ctxs = [mx.cpu(i) if is_push_cpu else mx.gpu(i) for i in range(2)]
+ kv.push('e', [mx.nd.ones(shape, ctx=context).tostype('row_sparse') for context in push_ctxs])
- def check_row_sparse_pull(kv, count, ctx=default_context()):
- num_rows = shape[0]
- vals = []
- row_ids = []
- all_row_ids = np.arange(num_rows)
- for i in range(count):
- vals.append(mx.nd.zeros(shape, ctx=ctx).tostype('row_sparse'))
- row_id = np.random.randint(num_rows, size=num_rows)
- row_ids.append(mx.nd.array(row_id, dtype='int64'))
- row_ids_to_pull = row_ids[0] if len(row_ids) == 1 else row_ids
- vals_to_pull = vals[0] if len(vals) == 1 else vals
+ def check_rsp_pull(kv, count, ctxs, is_same_rowid=False, use_slice=False):
+ num_rows = shape[0]
+ row_ids = []
+ all_row_ids = np.arange(num_rows)
+ vals = [mx.nd.sparse.zeros(shape=shape, ctx=ctxs[i], stype='row_sparse') for i in range(count)]
+ if is_same_rowid:
+ row_id = np.random.randint(num_rows, size=num_rows)
+ row_ids = [mx.nd.array(row_id, dtype='int64')] * count
+ elif use_slice:
+ total_row_ids = mx.nd.array(np.random.randint(num_rows, size=count*num_rows), dtype='int64')
+ row_ids = [total_row_ids[i*num_rows : (i+1)*num_rows] for i in range(count)]
+ else:
+ for i in range(count):
+ row_id = np.random.randint(num_rows, size=num_rows)
+ row_ids.append(mx.nd.array(row_id, dtype='int64'))
+ row_ids_to_pull = row_ids[0] if (len(row_ids) == 1 or is_same_rowid) else row_ids
+ vals_to_pull = vals[0] if len(vals) == 1 else vals
- kv.row_sparse_pull('e', out=vals_to_pull, row_ids=row_ids_to_pull)
- for val, row_id in zip(vals, row_ids):
- retained = val.asnumpy()
- excluded_row_ids = np.setdiff1d(all_row_ids, row_id.asnumpy())
- for row in range(num_rows):
- expected_val = np.zeros_like(retained[row])
- expected_val += 0 if row in excluded_row_ids else 1
- assert_almost_equal(retained[row], expected_val)
+ kv.row_sparse_pull('e', out=vals_to_pull, row_ids=row_ids_to_pull)
+ for val, row_id in zip(vals, row_ids):
+ retained = val.asnumpy()
+ excluded_row_ids = np.setdiff1d(all_row_ids, row_id.asnumpy())
+ for row in range(num_rows):
+ expected_val = np.zeros_like(retained[row])
+ expected_val += 0 if row in excluded_row_ids else 2
+ assert_almost_equal(retained[row], expected_val)
- check_row_sparse_pull(kv, 1, mx.gpu(0))
- check_row_sparse_pull(kv, 4, mx.gpu(0))
+ check_rsp_pull(kv, 1, [mx.gpu(0)])
+ check_rsp_pull(kv, 1, [mx.cpu(0)])
+ check_rsp_pull(kv, 4, [mx.gpu(i//2) for i in range(4)])
+ check_rsp_pull(kv, 4, [mx.gpu(i//2) for i in range(4)], is_same_rowid=True)
+ check_rsp_pull(kv, 4, [mx.cpu(i) for i in range(4)])
+ check_rsp_pull(kv, 4, [mx.cpu(i) for i in range(4)], is_same_rowid=True)
+ check_rsp_pull(kv, 4, [mx.gpu(i//2) for i in range(4)], use_slice=True)
+ check_rsp_pull(kv, 4, [mx.cpu(i) for i in range(4)], use_slice=True)
+
+ # test fails intermittently. temporarily disabled till it gets fixed. tracked at https://github.com/apache/incubator-mxnet/issues/9384
+ # check_rsp_push_pull('local')
+ check_rsp_push_pull('device')
+ check_rsp_push_pull('device', is_push_cpu=False)
if __name__ == '__main__':
- test_row_sparse_pull()
+ test_rsp_push_pull()
--
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].