You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/10/20 22:08:05 UTC
[incubator-mxnet] branch master updated: sparse support for
take(csr, axis=0) (#12889)
This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 0ba259f sparse support for take(csr, axis=0) (#12889)
0ba259f is described below
commit 0ba259fe5fbc2c2ba355a9f6faf1a20cc57c6d07
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Sat Oct 20 15:07:48 2018 -0700
sparse support for take(csr, axis=0) (#12889)
* initial commit
* add test cases for mode
* fix bug
* add comment
* more comments
---
src/operator/tensor/indexing_op.cc | 144 +++++++++++++++++++++++++++
src/operator/tensor/indexing_op.h | 65 ++++++++++++
tests/python/unittest/test_sparse_ndarray.py | 18 ++++
3 files changed, 227 insertions(+)
diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc
index b663ef0..98e2536 100644
--- a/src/operator/tensor/indexing_op.cc
+++ b/src/operator/tensor/indexing_op.cc
@@ -87,6 +87,142 @@ void SparseEmbeddingOpForwardRspImpl<cpu>(const OpContext& ctx,
}
}
+template<bool clip>
+struct CsrTakeDataKernel {
+  /*!
+   * \brief Map function for the forward of take(csr, axis=0): copies the
+   *        column indices and values of one selected source row into the
+   *        output csr arrays.
+   * \param tid global thread id; output row tid selects source row idx_ptr[tid]
+   * \param out_idx ptr to out idx (column indices)
+   * \param out_data ptr to out data
+   * \param out_indptr ptr to out indptr (already prefix-summed offsets)
+   * \param src_idx ptr to original csr idx
+   * \param src_data ptr to original csr data
+   * \param src_indptr ptr to original csr indptr
+   * \param idx_ptr ptr to indices
+   * \param num_rows number of rows in src array, used to clip/wrap indices
+   */
+  template<typename IType, typename DType, typename RType>
+  MSHADOW_XINLINE static void Map(int tid, RType* out_idx, DType* out_data,
+                                  const RType* out_indptr, const RType* src_idx,
+                                  const DType* src_data, const RType* src_indptr,
+                                  const IType* idx_ptr, const nnvm::dim_t num_rows) {
+    nnvm::dim_t idx = static_cast<nnvm::dim_t>(idx_ptr[tid]);
+    if (clip) {
+      // clip mode: saturate out-of-range indices to [0, num_rows - 1]
+      if (idx < 0) idx = 0;
+      if (idx >= num_rows) idx = num_rows - 1;
+    } else {
+      // wrap mode: C++ % keeps the sign of the dividend, so shift negatives up
+      idx = idx % num_rows;
+      idx += (idx < 0) ? num_rows : 0;
+    }
+    // Keep offsets and counts in RType (the indptr type) instead of int so
+    // they are not narrowed for large rows / large indptr values.
+    const RType src_begin = src_indptr[idx];
+    const RType row_nnz = src_indptr[idx + 1] - src_begin;
+    const RType dst_begin = out_indptr[tid];
+    for (RType i = 0; i < row_nnz; i++) {
+      out_data[dst_begin + i] = src_data[src_begin + i];
+      out_idx[dst_begin + i] = src_idx[src_begin + i];
+    }
+  }
+};
+
+template<bool clip>
+struct CsrTakeRowCountKernel {
+  /*!
+   * \brief Map function for the forward of take(csr, axis=0): writes the nnz
+   *        of each selected source row into out_indptr so that a subsequent
+   *        prefix sum turns the counts into the output indptr.
+   *        Launched with (number of output rows + 1) threads: thread 0 writes
+   *        the leading zero, thread tid > 0 handles idx_ptr[tid - 1].
+   * \param tid global thread id
+   * \param out_indptr ptr to out indptr
+   * \param src_indptr ptr to original csr indptr
+   * \param idx_ptr ptr to indices
+   * \param num_rows number of rows in src array, used to clip/wrap indices
+   */
+  template<typename IType, typename RType>
+  MSHADOW_XINLINE static void Map(int tid, RType* out_indptr,
+                                  const RType* src_indptr, const IType* idx_ptr,
+                                  const nnvm::dim_t num_rows) {
+    if (tid == 0) {
+      // Must return here: falling through would read idx_ptr[-1] (out of
+      // bounds) and overwrite the leading zero of out_indptr.
+      out_indptr[0] = 0;
+      return;
+    }
+    nnvm::dim_t idx = static_cast<nnvm::dim_t>(idx_ptr[tid - 1]);
+    if (clip) {
+      // clip mode: saturate out-of-range indices to [0, num_rows - 1]
+      if (idx < 0) idx = 0;
+      if (idx >= num_rows) idx = num_rows - 1;
+    } else {
+      // wrap mode: C++ % keeps the sign of the dividend, so shift negatives up
+      idx = idx % num_rows;
+      idx += (idx < 0) ? num_rows : 0;
+    }
+    out_indptr[tid] = src_indptr[idx + 1] - src_indptr[idx];
+  }
+};
+
+template<>
+void TakeOpForwardCsrImpl<cpu>(const TakeParam& params,
+                               const OpContext& ctx,
+                               const TBlob& idx,
+                               const NDArray& arr,
+                               OpReqType req,
+                               const NDArray& out) {
+  using namespace csr;
+  using namespace mxnet_op;
+  using nnvm::dim_t;
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  if (req == kNullOp) return;
+  // Validate arguments before any early exit so unsupported requests
+  // (e.g. req = kAddTo) are rejected even when arr has no stored values,
+  // instead of silently overwriting the output with zeros.
+  CHECK_EQ(idx.shape_.ndim(), 1U)
+    << "Take with CSR array only supports one-dimensional indices. "
+    << idx.shape_.ndim() << " dimensional input is given instead";
+  CHECK_EQ(req, kWriteTo) << "req = " << req << " is not supported for take(csr)";
+  auto axis = params.axis;
+  CHECK_EQ(axis, 0) << "axis = " << axis << " is not supported for take(csr)";
+  CHECK(params.mode == take_::kClip || params.mode == take_::kWrap)
+    << "mode = " << params.mode << " is not supported";
+  if (!arr.storage_initialized()) {
+    // source has no stored values, so every selected row is empty
+    FillZerosCsrImpl(s, out);
+    return;
+  }
+  const dim_t num_rows = out.shape()[0];
+  const dim_t max_num_rows = arr.shape()[0];
+  out.CheckAndAllocAuxData(kIndPtr, {Shape1(num_rows + 1)});
+
+  MSHADOW_TYPE_SWITCH(idx.type_flag_, IType, {
+    MSHADOW_SGL_DBL_TYPE_SWITCH(arr.dtype(), DType, {
+      MSHADOW_IDX_TYPE_SWITCH(out.aux_type(kIdx), RType, {
+        RType* out_indptr = out.aux_data(kIndPtr).dptr<RType>();
+        const RType* src_indptr = arr.aux_data(kIndPtr).dptr<RType>();
+        const IType* idx_ptr = idx.dptr<IType>();
+        // gather per row nnz information for output; num_rows + 1 threads so
+        // that thread 0 writes the leading zero of out_indptr
+        const bool clip = params.mode == take_::kClip;
+        if (clip) {
+          Kernel<CsrTakeRowCountKernel<true>, cpu>::Launch(s, num_rows + 1,
+              out_indptr, src_indptr, idx_ptr, max_num_rows);
+        } else {
+          Kernel<CsrTakeRowCountKernel<false>, cpu>::Launch(s, num_rows + 1,
+              out_indptr, src_indptr, idx_ptr, max_num_rows);
+        }
+        // turn per-row counts into offsets with a single-threaded prefix sum
+        for (dim_t i = 0; i < num_rows; i++) {
+          out_indptr[i + 1] += out_indptr[i];
+        }
+        // total number of non-zero elements in the output
+        const dim_t nnz = out_indptr[num_rows];
+        if (nnz == 0) {
+          FillZerosCsrImpl(s, out);
+          return;
+        }
+        out.CheckAndAllocAuxData(kIdx, {Shape1(nnz)});
+        out.CheckAndAllocData(Shape1(nnz));
+        RType* out_idx = out.aux_data(kIdx).dptr<RType>();
+        DType* out_data = out.data().dptr<DType>();
+        const RType* src_idx = arr.aux_data(kIdx).dptr<RType>();
+        const DType* src_data = arr.data().dptr<DType>();
+        // copy indices and data for output, one thread per output row
+        if (clip) {
+          Kernel<CsrTakeDataKernel<true>, cpu>::Launch(s, num_rows, out_idx,
+              out_data, out_indptr, src_idx, src_data, src_indptr, idx_ptr, max_num_rows);
+        } else {
+          Kernel<CsrTakeDataKernel<false>, cpu>::Launch(s, num_rows, out_idx,
+              out_data, out_indptr, src_idx, src_data, src_indptr, idx_ptr, max_num_rows);
+        }
+      });
+    });
+  });
+}
template<>
inline void SparseEmbeddingOpBackwardRspImpl<cpu>(const bool deterministic,
@@ -400,6 +536,7 @@ dimension of data (by default outer-most one as axis=0) indexed by indices, and
in an output tensor of rank q + (r - 1).
Examples::
+
x = [4. 5. 6.]
// Trivial case, take the second element along the first axis.
@@ -431,6 +568,11 @@ Examples::
[[ 3., 4.],
[ 5., 6.]]]
+The storage type of ``take`` output depends upon the input storage type:
+
+ - take(default, default) = default
+ - take(csr, default, axis=0) = csr
+
)code" ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
@@ -441,11 +583,13 @@ Examples::
})
.set_attr<nnvm::FInferShape>("FInferShape", TakeOpShape)
.set_attr<nnvm::FInferType>("FInferType", TakeOpType)
+.set_attr<FInferStorageType>("FInferStorageType", TakeOpForwardStorageType)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", TakeOpForward<cpu>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", TakeOpForwardEx<cpu>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
return MakeNonlossGradNode("_backward_take", n, ograds,
diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h
index 1daf0a2..5282a7e 100644
--- a/src/operator/tensor/indexing_op.h
+++ b/src/operator/tensor/indexing_op.h
@@ -755,6 +755,71 @@ inline bool TakeOpType(const nnvm::NodeAttrs& attrs,
return (*in_attrs)[0] != -1;
}
+// Storage type inference for take:
+//   take(dense arr, dense idx)                          -> dense output, FCompute
+//   take(csr arr, dense idx), axis=0, mode=wrap or clip -> csr output, FComputeEx
+//   anything else, or a failed assignment above         -> dispatch_fallback
+inline bool TakeOpForwardStorageType(const nnvm::NodeAttrs& attrs,
+                                     const int dev_mask,
+                                     DispatchMode* dispatch_mode,
+                                     std::vector<int>* in_attrs,
+                                     std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed);
+  const int arr_stype = in_attrs->at(take_::kArr);
+  const int idx_stype = in_attrs->at(take_::kIdx);
+  int& out_stype = out_attrs->at(take_::kOut);
+  // the csr kernel only handles row selection with clip/wrap semantics
+  const bool csr_path_supported = param.axis == 0 &&
+      (param.mode == take_::kWrap || param.mode == take_::kClip);
+  bool dispatched = false;
+  if (idx_stype == kDefaultStorage) {
+    if (arr_stype == kDefaultStorage) {
+      // take(dns arr, dns idx) -> dns
+      dispatched = storage_type_assign(&out_stype, kDefaultStorage,
+                                       dispatch_mode, DispatchMode::kFCompute);
+    } else if (arr_stype == kCSRStorage && csr_path_supported) {
+      // take(csr arr, dns idx, axis=0) -> csr
+      dispatched = storage_type_assign(&out_stype, kCSRStorage,
+                                       dispatch_mode, DispatchMode::kFComputeEx);
+    }
+  }
+  return dispatched || dispatch_fallback(out_attrs, dispatch_mode);
+}
+
+
+/*!
+ * \brief Forward computation of take(csr data, dense indices) along axis 0,
+ *        producing a csr output. Declared here and specialized per device;
+ *        the cpu specialization is defined in indexing_op.cc.
+ * \param params parsed TakeParam (axis and mode are consulted)
+ * \param ctx operator context providing the device stream
+ * \param idx dense indices (the cpu specialization requires them to be 1-D)
+ * \param arr csr data array whose rows are gathered
+ * \param req write request for the output
+ * \param output csr output array
+ */
+template<typename xpu>
+void TakeOpForwardCsrImpl(const TakeParam& params,
+ const OpContext& ctx,
+ const TBlob& idx,
+ const NDArray& arr,
+ OpReqType req,
+ const NDArray& output);
+
+
+/*!
+ * \brief FComputeEx entry point for take. Dispatches take(csr, dense) -> csr
+ *        to TakeOpForwardCsrImpl<xpu>; any other storage combination is
+ *        reported through LogUnimplementedOp.
+ */
+template<typename xpu>
+void TakeOpForwardEx(const nnvm::NodeAttrs& attrs,
+                     const OpContext& ctx,
+                     const std::vector<NDArray>& inputs,
+                     const std::vector<OpReqType>& req,
+                     const std::vector<NDArray>& outputs) {
+  // check container sizes before indexing into any of them
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  const OpReqType out_req = req[take_::kOut];
+  // nothing to write; matches the kNullOp early exit of TakeOpForwardCsrImpl
+  if (out_req == kNullOp) return;
+  CHECK_EQ(out_req, kWriteTo);
+  const NDArray& idx = inputs[take_::kIdx];
+  const NDArray& arr = inputs[take_::kArr];
+  const NDArray& out = outputs[take_::kOut];
+  // bind by reference instead of copying the parsed parameter struct
+  const TakeParam& params = nnvm::get<TakeParam>(attrs.parsed);
+  if (idx.storage_type() == kDefaultStorage &&
+      arr.storage_type() == kCSRStorage &&
+      out.storage_type() == kCSRStorage) {
+    // dns idx, csr arr -> csr
+    TakeOpForwardCsrImpl<xpu>(params, ctx, idx.data(), arr, out_req, out);
+  } else {
+    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
+  }
+}
+
template<typename xpu>
void TakeOpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index 875dea7..8dd250c 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -1013,6 +1013,24 @@ def test_sparse_fc():
# test FC with row_sparse weight w/ density=1, csr data (fallback)
check_sparse_fc(5, 10, 8, 'csr')
+@with_seed()
+def test_sparse_take():
+    def check_sparse_take(density, mode):
+        data_shape = rand_shape_2d()
+        idx_shape = (np.random.randint(low=1, high=10),)
+        data = rand_ndarray(data_shape, 'csr', density=density)
+        # indices drawn from [-5, 15): includes negative values and, for small
+        # shapes, values past the last row, exercising clip/wrap handling
+        idx = mx.nd.array(np.random.randint(low=-5, high=15, size=idx_shape))
+        result = mx.nd.take(data, idx, mode=mode)
+        # take(csr, dense, axis=0) with clip/wrap must be inferred as csr output
+        assert result.stype == 'csr', result.stype
+        data_np = data.asnumpy()
+        idx_np = idx.asnumpy().astype('int32')
+        expected_result = np.take(data_np, idx_np, mode=mode, axis=0)
+        assert_almost_equal(result.asnumpy(), expected_result)
+    densities = [0, 0.5, 1]
+    modes = ['clip', 'wrap']
+    for d in densities:
+        for m in modes:
+            check_sparse_take(d, m)
+
if __name__ == '__main__':
import nose
nose.runmodule()