Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/06/20 21:31:15 UTC
[incubator-mxnet] branch master updated: [MXNET-404] elemwise_add/sub between rsp and rsp on GPU (#11179)
This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new ba9784d [MXNET-404] elemwise_add/sub between rsp and rsp on GPU (#11179)
ba9784d is described below
commit ba9784dec62d00cef77ddcc4b2fedb4bece753c9
Author: Hao Jin <ha...@users.noreply.github.com>
AuthorDate: Wed Jun 20 14:31:09 2018 -0700
[MXNET-404] elemwise_add/sub between rsp and rsp on GPU (#11179)
* Support for elemwise_add/sub between rsp and rsp on GPU
* add extra test coverage for inplace cases
---
include/mxnet/ndarray.h | 2 +-
src/operator/tensor/elemwise_binary_op-inl.h | 18 +---
src/operator/tensor/elemwise_binary_op.h | 2 +-
src/operator/tensor/elemwise_binary_op_basic.cu | 134 ++++++++++++++++++++++++
tests/python/unittest/test_sparse_operator.py | 35 ++++++-
5 files changed, 171 insertions(+), 20 deletions(-)
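In short, the commit adds a native GPU implementation of RspRspOp so that elemwise_add/elemwise_sub on two row_sparse (rsp) NDArrays no longer aborts with a fatal error on GPU. A minimal usage sketch of what this enables (shapes, values, and row indices are illustrative; assumes an MXNet build with GPU support):

    import mxnet as mx

    # two 6x2 row_sparse operands on the GPU; only the listed rows are stored
    lhs = mx.nd.sparse.row_sparse_array(
        (mx.nd.array([[1, 1], [2, 2]]), [0, 4]), shape=(6, 2), ctx=mx.gpu())
    rhs = mx.nd.sparse.row_sparse_array(
        (mx.nd.array([[3, 3], [4, 4]]), [2, 4]), shape=(6, 2), ctx=mx.gpu())

    out = mx.nd.sparse.elemwise_add(lhs, rhs)
    print(out.stype)              # row_sparse
    print(out.indices.asnumpy())  # [0 2 4], the union of the two row-index sets

The per-file changes below support this path: a const-correctness fix in ndarray.h, removal of the old GPU stub that raised LOG(FATAL), a relaxed storage check, a dispatch-mode change, and the new GPU implementation plus tests.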
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index ae96fd8..faffe1b 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -156,7 +156,7 @@ class NDArray {
}
/* \brief Check whether the two arrays are the same array */
- inline bool IsSame(const NDArray& other) {
+ inline bool IsSame(const NDArray& other) const {
return ptr_ == other.ptr_ &&
shape_ == other.shape_ &&
byte_offset_ == other.byte_offset_ &&
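This one-line change const-qualifies IsSame so it can be called on the const NDArray references the operator code receives; the new GPU RspRspOp below calls output.IsSame(lhs) and rhs.IsSame(output) on exactly such references, which would not compile against a non-const member function.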
diff --git a/src/operator/tensor/elemwise_binary_op-inl.h b/src/operator/tensor/elemwise_binary_op-inl.h
index 911c369..878dfb2 100644
--- a/src/operator/tensor/elemwise_binary_op-inl.h
+++ b/src/operator/tensor/elemwise_binary_op-inl.h
@@ -31,22 +31,6 @@
namespace mxnet {
namespace op {
-template<typename OP>
-void ElemwiseBinaryOp::RspRspOp(mshadow::Stream<gpu> *s,
- const nnvm::NodeAttrs &attrs,
- const OpContext &ctx,
- const NDArray &lhs,
- const NDArray &rhs,
- const OpReqType req,
- const NDArray &output,
- const bool lhs_may_be_dense,
- const bool rhs_may_be_dense,
- const bool allow_inplace,
- const bool scatter) {
- LOG(FATAL) << "GPU not supported for RspRspOp";
-}
-
-
/*! \brief binary op handling for the following row sparse inputs/outputs
rsp, rsp -> rsp,
dns, rsp -> rsp,
@@ -622,7 +606,7 @@ void ElemwiseBinaryOp::DnsRspDnsOp(mshadow::Stream<xpu> *s,
const bool reverse) {
using namespace mshadow;
using namespace mxnet_op;
- CHECK_EQ(dns.storage_type(), kDefaultStorage);
+ CHECK(dns.storage_type() == kDefaultStorage || dns.storage_type() == kRowSparseStorage);
CHECK_EQ(rsp.storage_type(), kRowSparseStorage);
CHECK_EQ(output.data().Size(), dns.data().Size());
CHECK(req != kAddTo);
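The relaxed CHECK is what lets the new GPU RspRspOp reuse DnsRspDnsOp as a fast path: when one rsp operand stores all of the output's rows it is effectively dense, and it is passed to DnsRspDnsOp as the dns argument while still carrying kRowSparseStorage as its storage type.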
diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h
index ad4b3e7..fbd79bb 100644
--- a/src/operator/tensor/elemwise_binary_op.h
+++ b/src/operator/tensor/elemwise_binary_op.h
@@ -420,7 +420,7 @@ class ElemwiseBinaryOp : public OpBase {
if (!dispatched && rsp && ContainsOnlyStorage(*in_attrs, kRowSparseStorage)) {
// rsp, rsp, ... -> rsp
dispatched = storage_type_assign(out_attrs, kRowSparseStorage,
- dispatch_mode, dispatch_ex);
+ dispatch_mode, DispatchMode::kFComputeEx);
}
if (!dispatched && csr && ContainsOnlyStorage(*in_attrs, kCSRStorage)) {
// csr, csr, ... -> csr
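With a native FComputeEx implementation now available on GPU, the rsp, rsp -> rsp case can always be assigned DispatchMode::kFComputeEx instead of the context-dependent dispatch_ex, which could previously route the GPU case to a dense fallback.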
diff --git a/src/operator/tensor/elemwise_binary_op_basic.cu b/src/operator/tensor/elemwise_binary_op_basic.cu
index 5cdd894..ea8c1fb 100644
--- a/src/operator/tensor/elemwise_binary_op_basic.cu
+++ b/src/operator/tensor/elemwise_binary_op_basic.cu
@@ -22,12 +22,146 @@
* \file elemwise_binary_scalar_op.cu
* \brief GPU Implementation of unary function.
*/
+#include <cub/cub.cuh>
#include "./elemwise_binary_op.h"
#include "./elemwise_binary_op-inl.h"
namespace mxnet {
namespace op {
+template<typename OP>
+struct RspElemwiseKernel {
+ template<typename DType, typename IType>
+ static MSHADOW_XINLINE void Map(int i, DType* out, const IType* lookup_table,
+ const DType* data, const IType* indices,
+ const nnvm::dim_t nz_rows, const nnvm::dim_t num_cols) {
+ if (i < nz_rows * num_cols) {
+ const nnvm::dim_t row = i / num_cols;
+ const nnvm::dim_t col = i % num_cols;
+ const nnvm::dim_t out_row = lookup_table[indices[row]] - 1;
+ const nnvm::dim_t out_idx = out_row * num_cols + col;
+ out[out_idx] = OP::Map(out[out_idx], data[i]);
+ }
+ }
+};
+
+template<typename OP>
+void ElemwiseBinaryOp::RspRspOp(mshadow::Stream<gpu> *s,
+ const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const NDArray &lhs,
+ const NDArray &rhs,
+ const OpReqType req,
+ const NDArray &output,
+ const bool lhs_may_be_dense,
+ const bool rhs_may_be_dense,
+ const bool allow_inplace,
+ const bool scatter) {
+ using namespace mshadow;
+ using namespace mxnet_op;
+ using namespace mshadow::expr;
+ using namespace rowsparse;
+
+ if (req == kNullOp) return;
+
+ CHECK(!scatter) << "scatter is not supported in RspRspOp on GPU yet...";
+ CHECK(lhs.storage_type() == kRowSparseStorage && rhs.storage_type() == kRowSparseStorage);
+ CHECK(output.storage_type() == kRowSparseStorage);
+ CHECK(req != kAddTo);
+
+ const nnvm::dim_t num_rows = output.shape()[0];
+ MSHADOW_TYPE_SWITCH(lhs.data().type_flag_, DType, {
+ MSHADOW_IDX_TYPE_SWITCH(lhs.aux_data(kIdx).type_flag_, IType, {
+ if (lhs.storage_initialized() && rhs.storage_initialized()) {
+ const nnvm::dim_t lhs_nz_rows = lhs.storage_shape()[0];
+ const nnvm::dim_t rhs_nz_rows = rhs.storage_shape()[0];
+ const nnvm::dim_t num_cols = lhs.data().Size() / lhs_nz_rows;
+ // Optimize for the case where one of the rsps is actually dense
+ if ((lhs_nz_rows == num_rows || rhs_nz_rows == num_rows) && req == kWriteInplace) {
+ const NDArray& dns = (output.IsSame(lhs)) ? lhs : rhs;
+ const NDArray& rsp = (output.IsSame(lhs)) ? rhs : lhs;
+ const bool reverse = !(lhs_nz_rows == num_rows);
+ ElemwiseBinaryOp::DnsRspDnsOp<gpu, OP>(s, attrs, ctx, dns, rsp, req, output, reverse);
+ return;
+ }
+ CHECK(req == kWriteTo) << "Should be kWriteTo but got " << req;
+ const TBlob& lhs_indices = lhs.aux_data(kIdx);
+ const TBlob& rhs_indices = rhs.aux_data(kIdx);
+ size_t common_row_table_bytes = num_rows * sizeof(IType);
+ IType* common_row_table = NULL;
+ void* temp_storage_ptr = NULL;
+ size_t temp_storage_bytes = 0;
+ cub::DeviceScan::InclusiveSum(temp_storage_ptr,
+ temp_storage_bytes,
+ common_row_table,
+ common_row_table,
+ num_rows,
+ mshadow::Stream<gpu>::GetStream(s));
+ size_t workspace_bytes = common_row_table_bytes + temp_storage_bytes;
+ Tensor<gpu, 1, char> workspace =
+ ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(workspace_bytes), s);
+ common_row_table = reinterpret_cast<IType*>(workspace.dptr_);
+ temp_storage_ptr = workspace.dptr_ + common_row_table_bytes;
+ mxnet_op::Kernel<set_zero, gpu>::Launch(s, num_rows, common_row_table);
+ Kernel<MarkRspRowFlgKernel, gpu>::Launch(
+ s, lhs_nz_rows, common_row_table, lhs_indices.dptr<IType>(), lhs_nz_rows);
+ Kernel<MarkRspRowFlgKernel, gpu>::Launch(
+ s, rhs_nz_rows, common_row_table, rhs_indices.dptr<IType>(), rhs_nz_rows);
+ cub::DeviceScan::InclusiveSum(temp_storage_ptr,
+ temp_storage_bytes,
+ common_row_table,
+ common_row_table,
+ num_rows,
+ mshadow::Stream<gpu>::GetStream(s));
+ nnvm::dim_t nnr_out = 0;
+ CUDA_CALL(cudaMemcpy(&nnr_out, &common_row_table[num_rows-1], sizeof(nnvm::dim_t),
+ cudaMemcpyDeviceToHost));
+ output.CheckAndAlloc({mshadow::Shape1(nnr_out)});
+ Kernel<FillRspRowIdxKernel, gpu>::Launch(
+ s, num_rows, output.aux_data(kIdx).dptr<IType>(), common_row_table, num_rows);
+ Kernel<set_zero, gpu>::Launch(s, nnr_out * num_cols, output.data().dptr<DType>());
+ Kernel<RspElemwiseKernel<mshadow_op::plus>, gpu>::Launch(
+ s, lhs_nz_rows * num_cols, output.data().dptr<DType>(), common_row_table,
+ lhs.data().dptr<DType>(), lhs_indices.dptr<IType>(), lhs_nz_rows, num_cols);
+ Kernel<RspElemwiseKernel<OP>, gpu>::Launch(
+ s, rhs_nz_rows * num_cols, output.data().dptr<DType>(), common_row_table,
+ rhs.data().dptr<DType>(), rhs_indices.dptr<IType>(), rhs_nz_rows, num_cols);
+ } else {
+ if (lhs.storage_initialized()) {
+ if (req == kWriteTo) {
+ output.CheckAndAlloc({lhs.aux_shape(kIdx)});
+ Copy(output.data().FlatTo1D<gpu, DType>(),
+ lhs.data().FlatTo1D<gpu, DType>(), s);
+ Copy(output.aux_data(kIdx).FlatTo1D<gpu, IType>(),
+ lhs.aux_data(kIdx).FlatTo1D<gpu, IType>(), s);
+ } else if (req == kWriteInplace && rhs.IsSame(output)) {
+ LOG(FATAL) << "Inplace on an empty rhs is not supported";
+ }
+ } else if (rhs.storage_initialized()) {
+ if (req == kWriteTo) {
+ output.CheckAndAlloc({rhs.aux_shape(kIdx)});
+ } else if (req == kWriteInplace && lhs.IsSame(output)) {
+ LOG(FATAL) << "Inplace on an empty lhs is not supported";
+ }
+ if (std::is_same<OP, mshadow_op::minus>::value) {
+ Kernel<op_with_req<mshadow_op::negation, kWriteTo>, gpu>::Launch(
+ s, rhs.data().Size(), output.data().dptr<DType>(), rhs.data().dptr<DType>());
+ } else if (req == kWriteTo) {
+ Copy(output.data().FlatTo1D<gpu, DType>(),
+ rhs.data().FlatTo1D<gpu, DType>(), s);
+ }
+ if (req == kWriteTo) {
+ Copy(output.aux_data(kIdx).FlatTo1D<gpu, IType>(),
+ rhs.aux_data(kIdx).FlatTo1D<gpu, IType>(), s);
+ }
+ } else {
+ FillZerosRspImpl(s, output);
+ }
+ }
+ });
+ });
+}
+
NNVM_REGISTER_OP(elemwise_add)
.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu, op::mshadow_op::plus>)
.set_attr<FComputeEx>("FComputeEx<gpu>", ElemwiseBinaryOp::ComputeEx<gpu, op::mshadow_op::plus>);
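The kWriteTo branch above computes the union of the two row-index sets with a mark-and-scan pattern: both operands set flags in a zero-initialized table of length num_rows (MarkRspRowFlgKernel), an inclusive prefix sum over the flags (cub::DeviceScan::InclusiveSum, first called with a NULL temp pointer only to query temp_storage_bytes, then again with real storage) turns the flags into 1-based output row positions whose last entry is the output row count, and RspElemwiseKernel then scatter-accumulates each operand through that lookup table. A rough CPU analogue of the same bookkeeping, with illustrative names that are not part of the commit:

    import numpy as np

    def rsp_add_sketch(lhs_data, lhs_idx, rhs_data, rhs_idx, num_rows, num_cols):
        # MarkRspRowFlgKernel: flag every row present in either operand
        row_flg = np.zeros(num_rows, dtype=np.int64)
        row_flg[lhs_idx] = 1
        row_flg[rhs_idx] = 1
        # cub::DeviceScan::InclusiveSum: flags become 1-based output positions
        lookup = np.cumsum(row_flg)
        nnr_out = lookup[-1]                # last entry = number of output rows
        # FillRspRowIdxKernel: the sorted union of the two row-index sets
        out_idx = np.flatnonzero(row_flg)
        # RspElemwiseKernel, launched once per operand: scatter-accumulate
        out_data = np.zeros((nnr_out, num_cols))
        out_data[lookup[lhs_idx] - 1] += lhs_data  # first launch uses plus
        out_data[lookup[rhs_idx] - 1] += rhs_data  # second launch applies OP (minus for sub)
        return out_data, out_idx

The else branch handles operands with no stored rows: the result is a copy (or, for subtraction, a negation) of the non-empty operand, and all zeros when both operands are empty.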
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 62f5f3e..70af2bc 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -360,7 +360,8 @@ def test_elemwise_binary_ops():
verbose=False)
if ((lhs_stype is 'default' and rhs_stype is 'row_sparse') or
- (lhs_stype is 'default' and rhs_stype is 'csr')):
+ (lhs_stype is 'default' and rhs_stype is 'csr') or
+ (lhs_stype is 'row_sparse' and rhs_stype is 'row_sparse') and (rhs_density == 0.0)):
test_elemwise_binary_op("elemwise_add", lhs_stype, rhs_stype, shape,
lambda l, r: mx.sym.sparse.elemwise_add(l, r, out=l),
lambda l, r: l + r,
@@ -371,6 +372,38 @@ def test_elemwise_binary_ops():
force_grad_overlap=force_grad_overlap,
lhs_density=lhs_density, rhs_density=rhs_density,
verbose=False)
+ test_elemwise_binary_op("elemwise_sub", lhs_stype, rhs_stype, shape,
+ lambda l, r: mx.sym.sparse.elemwise_sub(l, r, out=l),
+ lambda l, r: l - r,
+ lambda outg, l, r: (outg, -outg),
+ lhs_grad_stype, rhs_grad_stype,
+ ograd_density=ograd_density,
+ force_lr_overlap=force_lr_overlap,
+ force_grad_overlap=force_grad_overlap,
+ lhs_density=lhs_density, rhs_density=rhs_density,
+ verbose=False)
+
+ if ((lhs_stype is 'row_sparse' and rhs_stype is 'row_sparse') and (lhs_density == 0.0)):
+ test_elemwise_binary_op("elemwise_add", lhs_stype, rhs_stype, shape,
+ lambda l, r: mx.sym.sparse.elemwise_add(l, r, out=r),
+ lambda l, r: l + r,
+ lambda outg, l, r: (outg, outg),
+ lhs_grad_stype, rhs_grad_stype,
+ ograd_density=ograd_density,
+ force_lr_overlap=force_lr_overlap,
+ force_grad_overlap=force_grad_overlap,
+ lhs_density=lhs_density, rhs_density=rhs_density,
+ verbose=False)
+ test_elemwise_binary_op("elemwise_sub", lhs_stype, rhs_stype, shape,
+ lambda l, r: mx.sym.sparse.elemwise_sub(l, r, out=l),
+ lambda l, r: l - r,
+ lambda outg, l, r: (outg, -outg),
+ lhs_grad_stype, rhs_grad_stype,
+ ograd_density=ograd_density,
+ force_lr_overlap=force_lr_overlap,
+ force_grad_overlap=force_grad_overlap,
+ lhs_density=lhs_density, rhs_density=rhs_density,
+ verbose=False)
test_elemwise_binary_op("elemwise_sub", lhs_stype, rhs_stype, shape,
lambda l, r: mx.sym.sparse.elemwise_sub(l, r),