Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/05/05 17:57:37 UTC
[incubator-mxnet] branch master updated: [MXNET-354] Support
elemwise_add/sub between dense and row sparse tensors on CPU/GPU (#10645)
This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 602523e [MXNET-354] Support elemwise_add/sub between dense and row sparse tensors on CPU/GPU (#10645)
602523e is described below
commit 602523e643d315f2293442689000dc41e20ba3d9
Author: Hao Jin <ha...@users.noreply.github.com>
AuthorDate: Sat May 5 10:57:30 2018 -0700
[MXNET-354] Support elemwise_add/sub between dense and row sparse tensors on CPU/GPU (#10645)
* support elemwise_add/sub between dense and rsp
* refactor to add support for gpu
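
For reference, a minimal usage sketch of the new behavior (a sketch, not part
of the commit; assumes a build with sparse support, and the array values are
illustrative):

    import mxnet as mx

    dense = mx.nd.ones((3, 2))
    # row_sparse operand with stored rows 0 and 2 only
    rsp = mx.nd.sparse.row_sparse_array(
        (mx.nd.array([[1, 1], [2, 2]]), mx.nd.array([0, 2])), shape=(3, 2))

    # Mixed dense/row_sparse operands now dispatch to FComputeEx on CPU and
    # GPU; the output falls back to dense ('default') storage.
    out = mx.nd.sparse.elemwise_add(dense, rsp)
    print(out.stype)  # 'default'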
---
src/operator/tensor/elemwise_binary_op-inl.h | 681 ++++++++++++++----------
src/operator/tensor/elemwise_binary_op.h | 107 ++--
src/operator/tensor/elemwise_binary_op_basic.cc | 4 +
src/operator/tensor/elemwise_binary_op_basic.cu | 8 +-
src/operator/tensor/elemwise_scatter_op.h | 8 +-
tests/python/unittest/test_sparse_operator.py | 13 +
6 files changed, 497 insertions(+), 324 deletions(-)
diff --git a/src/operator/tensor/elemwise_binary_op-inl.h b/src/operator/tensor/elemwise_binary_op-inl.h
index 54b7aa6..2cf6481 100644
--- a/src/operator/tensor/elemwise_binary_op-inl.h
+++ b/src/operator/tensor/elemwise_binary_op-inl.h
@@ -31,6 +31,22 @@
namespace mxnet {
namespace op {
+template<typename OP>
+void ElemwiseBinaryOp::RspRspOp(mshadow::Stream<gpu> *s,
+ const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const NDArray &lhs,
+ const NDArray &rhs,
+ const OpReqType req,
+ const NDArray &output,
+ const bool lhs_may_be_dense,
+ const bool rhs_may_be_dense,
+ const bool allow_inplace,
+ const bool scatter) {
+ LOG(FATAL) << "GPU not supported for RspRspOp";
+}
+
+
/*! \brief binary op handling for the following row sparse inputs/outputs
rsp, rsp -> rsp,
dns, rsp -> rsp,
@@ -38,7 +54,7 @@ namespace op {
dns, rsp -> dns,
rsp, dns -> dns,
*/
-template<typename DType, typename IType, typename OP>
+template<typename OP>
void ElemwiseBinaryOp::RspRspOp(mshadow::Stream<cpu> *s,
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
@@ -52,199 +68,217 @@ void ElemwiseBinaryOp::RspRspOp(mshadow::Stream<cpu> *s,
const bool scatter) {
using namespace mshadow;
using namespace mshadow::expr;
-
+ const NDArray& rsp = lhs.storage_type() == kRowSparseStorage ? lhs : rhs;
const bool is_dense_result = output.storage_type() == kDefaultStorage;
const bool lhs_is_dense = lhs.storage_type() == kDefaultStorage;
const bool rhs_is_dense = rhs.storage_type() == kDefaultStorage;
CHECK(!lhs_is_dense || lhs_may_be_dense) << "lvalue cannot be dense";
CHECK(!rhs_is_dense || rhs_may_be_dense) << "rvalue cannot be dense";
CHECK(!lhs_is_dense || !rhs_is_dense);
- // Only one item at most may be dense (lhs, rhs or result)
- if (rhs_is_dense) {
- // For right-side dense, in order to have sparse output, lhs input zero should
- // always output zero
- CHECK(fabs(static_cast<float>(OP::Map(DType(0), DType(99)))) < 1e-4f);
- CHECK(!is_dense_result); // Currently not handled
- }
- if (lhs_is_dense) {
- // For left-side dense, in order to have sparse output, rhs input zero should
- // always output zero
- CHECK(fabs(static_cast<float>(OP::Map(DType(99), DType(0)))) < 1e-4f);
- CHECK(!is_dense_result); // Currently not handled
- }
+ MSHADOW_IDX_TYPE_SWITCH(rsp.aux_type(rowsparse::kIdx), IType, {
+ MSHADOW_TYPE_SWITCH(output.dtype(), DType, {
+ // Only one item at most may be dense (lhs, rhs or result)
+ if (rhs_is_dense) {
+ // For right-side dense, in order to have sparse output, lhs input zero should
+ // always output zero
+ CHECK(fabs(static_cast<float>(OP::Map(DType(0), DType(99)))) < 1e-4f);
+ CHECK(!is_dense_result); // Currently not handled
+ }
+ if (lhs_is_dense) {
+ // For left-side dense, in order to have sparse output, rhs input zero should
+ // always output zero
+ CHECK(fabs(static_cast<float>(OP::Map(DType(99), DType(0)))) < 1e-4f);
+ CHECK(!is_dense_result); // Currently not handled
+ }
- // Memory Estimation: This is (roughly) the number of result rows. We may still
- // need to subtract the number of common rows
- bool lhs_in_place = false, rhs_in_place = false;
- const size_t num_rows_l = lhs_is_dense ? lhs.shape()[0] : lhs.aux_shape(rowsparse::kIdx).Size();
- const size_t num_rows_r = rhs_is_dense ? rhs.shape()[0] : rhs.aux_shape(rowsparse::kIdx).Size();
- if (is_dense_result) {
- output.CheckAndAlloc();
- } else {
- if (rhs_is_dense || scatter) {
- output.CheckAndAlloc({mshadow::Shape1(num_rows_l)});
- } else if (lhs_is_dense) {
- output.CheckAndAlloc({mshadow::Shape1(num_rows_r)});
- } else {
- lhs_in_place = IsSameArray(lhs, output);
- rhs_in_place = IsSameArray(rhs, output);
- if (!lhs_in_place && !rhs_in_place) {
- output.CheckAndAlloc({mshadow::Shape1(num_rows_l + num_rows_r)});
+ // Memory Estimation: This is (roughly) the number of result rows. We may still
+ // need to subtract the number of common rows
+ bool lhs_in_place = false, rhs_in_place = false;
+ const size_t num_rows_l = lhs_is_dense ? lhs.shape()[0] :
+ lhs.aux_shape(rowsparse::kIdx).Size();
+ const size_t num_rows_r = rhs_is_dense ? rhs.shape()[0] :
+ rhs.aux_shape(rowsparse::kIdx).Size();
+ if (is_dense_result) {
+ output.CheckAndAlloc();
} else {
- CHECK_EQ(allow_inplace, true);
- CHECK_EQ(is_dense_result, false);
- if (lhs_in_place) {
- // For in-place, zero L-value must always be zero output
- DCHECK(fabs(static_cast<float>(OP::Map(DType(0), DType(99)))) < DType(1e-3));
+ if (rhs_is_dense || scatter) {
+ output.CheckAndAlloc({mshadow::Shape1(num_rows_l)});
+ } else if (lhs_is_dense) {
+ output.CheckAndAlloc({mshadow::Shape1(num_rows_r)});
} else {
- // For in-place, zero R-value must always be zero output
- DCHECK(fabs(static_cast<float>(OP::Map(DType(99), DType(0)))) < DType(1e-3));
+ lhs_in_place = IsSameArray(lhs, output);
+ rhs_in_place = IsSameArray(rhs, output);
+ if (!lhs_in_place && !rhs_in_place) {
+ output.CheckAndAlloc({mshadow::Shape1(num_rows_l + num_rows_r)});
+ } else {
+ CHECK_EQ(allow_inplace, true);
+ CHECK_EQ(is_dense_result, false);
+ if (lhs_in_place) {
+ // For in-place, zero L-value must always be zero output
+ DCHECK(fabs(static_cast<float>(OP::Map(DType(0), DType(99)))) < DType(1e-3));
+ } else {
+ // For in-place, zero R-value must always be zero output
+ DCHECK(fabs(static_cast<float>(OP::Map(DType(99), DType(0)))) < DType(1e-3));
+ }
+ }
}
}
- }
- }
- // Indices
- const Tensor<cpu, 1, IType> indices_l = lhs_is_dense
- ? Tensor<cpu, 1, IType>()
- : lhs.aux_data(rowsparse::kIdx).FlatTo1D<cpu, IType>(s);
- const Tensor<cpu, 1, IType> indices_r = rhs_is_dense
- ? Tensor<cpu, 1, IType>()
- : rhs.aux_data(rowsparse::kIdx).FlatTo1D<cpu, IType>(s);
- Tensor<cpu, 1, IType> indices_out = is_dense_result
- ? Tensor<cpu, 1, IType>()
- : output.aux_data(rowsparse::kIdx).FlatTo1D<cpu, IType>(s);
-
- // Data
- // TODO(cjolivier01): Change to get_with_shape() calls
- const Tensor<cpu, 2, DType> data_l = AsRowise2D<DType>(s, lhs.data());
- const Tensor<cpu, 2, DType> data_r = AsRowise2D<DType>(s, rhs.data());
- Tensor<cpu, 2, DType> out = AsRowise2D<DType>(s, output.data());
-
- size_t iter_l = 0;
- size_t iter_r = 0;
- size_t iter_out = 0;
- int32_t num_common_rows = 0;
-
- if (is_dense_result) {
- if (!num_rows_l && !num_rows_r) {
- const size_t all_rows = static_cast<size_t>(lhs.shape()[0]);
- iter_out = FillDense<DType, OP>(s, all_rows, all_rows, req, &out, iter_out);
- }
- }
-
- while (iter_l < num_rows_l && iter_r < num_rows_r) {
- IType idx_l = lhs_is_dense ? indices_r[iter_r] : indices_l[iter_l];
- IType idx_r = rhs_is_dense ? idx_l : indices_r[iter_r];
- if (lhs_in_place) {
- while (idx_r < idx_l && ++iter_r < num_rows_r) {
- idx_r = indices_r[iter_r];
- }
- if (iter_r >= num_rows_r) {
- break;
- }
- } else if (rhs_in_place) {
- while (idx_l < idx_r && ++iter_l < num_rows_l) {
- idx_l = indices_l[iter_l];
- }
- if (iter_l >= num_rows_l) {
- break;
+ // Indices
+ const Tensor<cpu, 1, IType> indices_l = lhs_is_dense ?
+ Tensor<cpu, 1, IType>() :
+ lhs.aux_data(rowsparse::kIdx).FlatTo1D<cpu, IType>(s);
+ const Tensor<cpu, 1, IType> indices_r = rhs_is_dense ?
+ Tensor<cpu, 1, IType>() :
+ rhs.aux_data(rowsparse::kIdx).FlatTo1D<cpu, IType>(s);
+ Tensor<cpu, 1, IType> indices_out = is_dense_result ?
+ Tensor<cpu, 1, IType>() :
+ output.aux_data(rowsparse::kIdx).FlatTo1D<cpu, IType>(s);
+
+ // Data
+ // TODO(cjolivier01): Change to get_with_shape() calls
+ const Tensor<cpu, 2, DType> data_l = AsRowise2D<DType>(s, lhs.data());
+ const Tensor<cpu, 2, DType> data_r = AsRowise2D<DType>(s, rhs.data());
+ Tensor<cpu, 2, DType> out = AsRowise2D<DType>(s, output.data());
+
+ size_t iter_l = 0;
+ size_t iter_r = 0;
+ size_t iter_out = 0;
+ int32_t num_common_rows = 0;
+
+ if (is_dense_result) {
+ if (!num_rows_l && !num_rows_r) {
+ const size_t all_rows = static_cast<size_t>(lhs.shape()[0]);
+ iter_out = FillDense<DType, OP>(s, all_rows, all_rows, req, &out, iter_out);
+ }
}
- }
- if (is_dense_result) {
- iter_out = FillDense<DType, OP>(s, idx_l, idx_r, req, &out, iter_out);
- DCHECK_EQ(iter_out, static_cast<size_t>(std::min(idx_l, idx_r)));
- }
- if (idx_l == idx_r) {
- // Same row
- if (!is_dense_result) {
- indices_out[iter_out] = idx_l;
+
+ while (iter_l < num_rows_l && iter_r < num_rows_r) {
+ IType idx_l = lhs_is_dense ? indices_r[iter_r] : indices_l[iter_l];
+ IType idx_r = rhs_is_dense ? idx_l : indices_r[iter_r];
+ if (lhs_in_place) {
+ while (idx_r < idx_l && ++iter_r < num_rows_r) {
+ idx_r = indices_r[iter_r];
+ }
+ if (iter_r >= num_rows_r) {
+ break;
+ }
+ } else if (rhs_in_place) {
+ while (idx_l < idx_r && ++iter_l < num_rows_l) {
+ idx_l = indices_l[iter_l];
+ }
+ if (iter_l >= num_rows_l) {
+ break;
+ }
+ }
+ if (is_dense_result) {
+ iter_out = FillDense<DType, OP>(s, idx_l, idx_r, req, &out, iter_out);
+ DCHECK_EQ(iter_out, static_cast<size_t>(std::min(idx_l, idx_r)));
+ }
+ if (idx_l == idx_r) {
+ // Same row
+ if (!is_dense_result) {
+ indices_out[iter_out] = idx_l;
+ }
+ Tensor<cpu, 1, DType> lvalue = !lhs_is_dense ? data_l[iter_l++] : data_l[idx_l];
+ Tensor<cpu, 1, DType> rvalue = !rhs_is_dense ? data_r[iter_r++] : data_r[idx_r];
+ DCHECK_EQ(lvalue.shape_.Size(), rvalue.shape_.Size());
+ MXNET_ASSIGN_REQ_SWITCH(req, Req, {
+ mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, cpu>::Launch(
+ s, lvalue.shape_.Size(), out[iter_out].dptr_, lvalue.dptr_, rvalue.dptr_);
+ });
+ num_common_rows++;
+ } else if (idx_l < idx_r) {
+ // Left only
+ if (!is_dense_result) {
+ indices_out[iter_out] = idx_l;
+ }
+ Tensor<cpu, 1, DType> lvalue = !lhs_is_dense ? data_l[iter_l++] : data_l[idx_l];
+ MXNET_ASSIGN_REQ_SWITCH(req, Req, {
+ mxnet_op::Kernel<MissingRValueOp<OP, Req>, cpu>::Launch(
+ s, lvalue.shape_.Size(), out[iter_out].dptr_, lvalue.dptr_);
+ });
+ } else {
+ // Right only
+ if (scatter) {
+ ++iter_r;
+ continue; // skip '++iter_out' below
+ }
+ if (!is_dense_result) {
+ indices_out[iter_out] = idx_r;
+ }
+ Tensor<cpu, 1, DType> rvalue = !rhs_is_dense ? data_r[iter_r++] : data_r[idx_r];
+ MXNET_ASSIGN_REQ_SWITCH(req, Req, {
+ mxnet_op::Kernel<MissingLValueOp<OP, Req>, cpu>::Launch(
+ s, rvalue.shape_.Size(), out[iter_out].dptr_, rvalue.dptr_);
+ });
+ }
+ ++iter_out;
}
- Tensor<cpu, 1, DType> lvalue = !lhs_is_dense ? data_l[iter_l++] : data_l[idx_l];
- Tensor<cpu, 1, DType> rvalue = !rhs_is_dense ? data_r[iter_r++] : data_r[idx_r];
- DCHECK_EQ(lvalue.shape_.Size(), rvalue.shape_.Size());
- MXNET_ASSIGN_REQ_SWITCH(req, Req, {
- mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, cpu>::Launch(
- s, lvalue.shape_.Size(), out[iter_out].dptr_, lvalue.dptr_, rvalue.dptr_);
- });
- num_common_rows++;
- } else if (idx_l < idx_r) {
- // Left only
- if (!is_dense_result) {
- indices_out[iter_out] = idx_l;
+ // Evaluate the remaining rows beyond the l and r value row intersection
+ while (iter_l < num_rows_l && !lhs_is_dense && !rhs_in_place) {
+ if (!is_dense_result) {
+ indices_out[iter_out] = indices_l[iter_l];
+ } else {
+ const IType idx_l = indices_l[iter_l];
+ iter_out = FillDense<DType, OP>(s, lhs.shape()[0], idx_l, req, &out, iter_out);
+ }
+ Tensor<cpu, 1, DType> lvalue = data_l[iter_l++];
+ MXNET_ASSIGN_REQ_SWITCH(req, Req, {
+ mxnet_op::Kernel<MissingRValueOp<OP, Req>, cpu>::Launch(
+ s, lvalue.shape_.Size(), out[iter_out++].dptr_, lvalue.dptr_);
+ });
}
- Tensor<cpu, 1, DType> lvalue = !lhs_is_dense ? data_l[iter_l++] : data_l[idx_l];
- MXNET_ASSIGN_REQ_SWITCH(req, Req, {
- mxnet_op::Kernel<MissingRValueOp<OP, Req>, cpu>::Launch(
- s, lvalue.shape_.Size(), out[iter_out].dptr_, lvalue.dptr_);
- });
- } else {
- // Right only
- if (scatter) {
- ++iter_r;
- continue; // skip '++iter_out' below
+ while (iter_r < num_rows_r && !rhs_is_dense && !lhs_in_place && !scatter) {
+ if (!is_dense_result) {
+ indices_out[iter_out] = indices_r[iter_r];
+ } else {
+ const IType idx_r = indices_r[iter_r];
+ iter_out = FillDense<DType, OP>(s, lhs.shape()[0], idx_r, req, &out, iter_out);
+ }
+ Tensor<cpu, 1, DType> rvalue = data_r[iter_r++];
+ MXNET_ASSIGN_REQ_SWITCH(req, Req, {
+ mxnet_op::Kernel<MissingLValueOp<OP, Req>, cpu>::Launch(
+ s, rvalue.shape_.Size(), out[iter_out++].dptr_, rvalue.dptr_);
+ });
}
- if (!is_dense_result) {
- indices_out[iter_out] = idx_r;
+ if (is_dense_result) {
+ const size_t all_rows = static_cast<size_t>(lhs.shape()[0]);
+ iter_out = FillDense<DType, OP>(s, all_rows, all_rows, req, &out, iter_out);
+ } else {
+ if (lhs_in_place) {
+ CHECK_LE(iter_out, num_rows_l);
+ }
+ if (rhs_in_place) {
+ CHECK_LE(iter_out, num_rows_r);
+ }
+ DCHECK_LE(iter_out, num_rows_l + num_rows_r); // Make sure that we didn't overrun
+ nnvm::TShape new_shape = output.aux_shape(rowsparse::kIdx);
+ CHECK_LE(iter_out, new_shape.Size());
+ if (!rhs_is_dense && !lhs_is_dense && !lhs_in_place && !rhs_in_place && !scatter) {
+ // Reduce the first-dimension size by the number of common rows
+ new_shape[0] -= num_common_rows;
+ output.set_aux_shape(rowsparse::kIdx, new_shape);
+ }
}
- Tensor<cpu, 1, DType> rvalue = !rhs_is_dense ? data_r[iter_r++] : data_r[idx_r];
- MXNET_ASSIGN_REQ_SWITCH(req, Req, {
- mxnet_op::Kernel<MissingLValueOp<OP, Req>, cpu>::Launch(
- s, rvalue.shape_.Size(), out[iter_out].dptr_, rvalue.dptr_);
- });
- }
- ++iter_out;
- }
- // Evaluate the remaining rows beyond the l and r value row intersection
- while (iter_l < num_rows_l && !lhs_is_dense && !rhs_in_place) {
- if (!is_dense_result) {
- indices_out[iter_out] = indices_l[iter_l];
- } else {
- const IType idx_l = indices_l[iter_l];
- iter_out = FillDense<DType, OP>(s, lhs.shape()[0], idx_l, req, &out, iter_out);
- }
- Tensor<cpu, 1, DType> lvalue = data_l[iter_l++];
- MXNET_ASSIGN_REQ_SWITCH(req, Req, {
- mxnet_op::Kernel<MissingRValueOp<OP, Req>, cpu>::Launch(
- s, lvalue.shape_.Size(), out[iter_out++].dptr_, lvalue.dptr_);
});
- }
- while (iter_r < num_rows_r && !rhs_is_dense && !lhs_in_place && !scatter) {
- if (!is_dense_result) {
- indices_out[iter_out] = indices_r[iter_r];
- } else {
- const IType idx_r = indices_r[iter_r];
- iter_out = FillDense<DType, OP>(s, lhs.shape()[0], idx_r, req, &out, iter_out);
- }
- Tensor<cpu, 1, DType> rvalue = data_r[iter_r++];
- MXNET_ASSIGN_REQ_SWITCH(req, Req, {
- mxnet_op::Kernel<MissingLValueOp<OP, Req>, cpu>::Launch(
- s, rvalue.shape_.Size(), out[iter_out++].dptr_, rvalue.dptr_);
- });
- }
- if (is_dense_result) {
- const size_t all_rows = static_cast<size_t>(lhs.shape()[0]);
- iter_out = FillDense<DType, OP>(s, all_rows, all_rows, req, &out, iter_out);
- } else {
- if (lhs_in_place) {
- CHECK_LE(iter_out, num_rows_l);
- }
- if (rhs_in_place) {
- CHECK_LE(iter_out, num_rows_r);
- }
- DCHECK_LE(iter_out, num_rows_l + num_rows_r); // Make sure that we didn't overrun
- nnvm::TShape new_shape = output.aux_shape(rowsparse::kIdx);
- CHECK_LE(iter_out, new_shape.Size());
- if (!rhs_is_dense && !lhs_is_dense && !lhs_in_place && !rhs_in_place && !scatter) {
- // Reduce the first-dimension size by the number of common rows
- new_shape[0] -= num_common_rows;
- output.set_aux_shape(rowsparse::kIdx, new_shape);
- }
- }
+ });
}
/*! \brief CSR -op- CSR binary operator for non-canonical NDArray */
-template<typename DType, typename IType, typename CType, typename OP>
+template<typename OP>
+void ElemwiseBinaryOp::CsrCsrOp(mshadow::Stream<gpu> *s,
+ const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const NDArray &lhs,
+ const NDArray &rhs,
+ const OpReqType req,
+ const NDArray &output) {
+ LOG(FATAL) << "GPU not supported for CsrCsrOp";
+}
+
+/*! \brief CSR -op- CSR binary operator for non-canonical NDArray */
+template<typename OP>
void ElemwiseBinaryOp::CsrCsrOp(mshadow::Stream<cpu> *s,
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
@@ -276,102 +310,108 @@ void ElemwiseBinaryOp::CsrCsrOp(mshadow::Stream<cpu> *s,
mshadow::Shape1(std::min(output_nnz_guess, lhs.shape().Size()))});
DCHECK_EQ(output.aux_shape(csr::kIndPtr), lhs.aux_shape(csr::kIndPtr));
- const size_t alloc_size = nr_cols * sizeof(IType) + 2 * nr_cols * sizeof(DType);
-
- Tensor<cpu, 1, uint8_t> workspace =
- ctx.requested[ResourceRequestType::kTempSpace].get_space_typed<cpu, 1, uint8_t>(
- mshadow::Shape1(alloc_size), s);
-
- // Allocate temp space and partition into three tensors
- mshadow::Tensor<cpu, 1, IType> next(reinterpret_cast<IType *>(workspace.dptr_),
- Shape1(nr_cols));
- mshadow::Tensor<cpu, 1, DType> lhs_row(reinterpret_cast<DType *>(workspace.dptr_
- + nr_cols * sizeof(IType)),
- Shape1(nr_cols));
- mshadow::Tensor<cpu, 1, DType> rhs_row;
-
- OpBase::FillDense<IType>(s, next.shape_.Size(), IType(-1), req, next.dptr_);
- OpBase::FillDense<DType>(s, lhs_row.shape_.Size(), DType(0), req, lhs_row.dptr_);
-
- if (!same_lhs_rhs) {
- rhs_row = Tensor<cpu, 1, DType>(lhs_row.dptr_ + nr_cols, Shape1(nr_cols));
- OpBase::FillDense<DType>(s, rhs_row.shape_.Size(), DType(0), req, rhs_row.dptr_);
- } else {
- rhs_row = lhs_row;
- }
+ MSHADOW_IDX_TYPE_SWITCH(lhs.aux_type(csr::kIdx), IType, {
+ MSHADOW_IDX_TYPE_SWITCH(lhs.aux_type(csr::kIndPtr), CType, {
+ MSHADOW_TYPE_SWITCH(output.dtype(), DType, {
+ const size_t alloc_size = nr_cols * sizeof(IType) + 2 * nr_cols * sizeof(DType);
+
+ Tensor<cpu, 1, uint8_t> workspace =
+ ctx.requested[ResourceRequestType::kTempSpace].get_space_typed<cpu, 1, uint8_t>(
+ mshadow::Shape1(alloc_size), s);
+
+ // Allocate temp space and partition into three tensors
+ mshadow::Tensor<cpu, 1, IType> next(reinterpret_cast<IType *>(workspace.dptr_),
+ Shape1(nr_cols));
+ mshadow::Tensor<cpu, 1, DType> lhs_row(reinterpret_cast<DType *>(
+ workspace.dptr_ + nr_cols * sizeof(IType)),
+ Shape1(nr_cols));
+ mshadow::Tensor<cpu, 1, DType> rhs_row;
+
+ OpBase::FillDense<IType>(s, next.shape_.Size(), IType(-1), req, next.dptr_);
+ OpBase::FillDense<DType>(s, lhs_row.shape_.Size(), DType(0), req, lhs_row.dptr_);
+
+ if (!same_lhs_rhs) {
+ rhs_row = Tensor<cpu, 1, DType>(lhs_row.dptr_ + nr_cols, Shape1(nr_cols));
+ OpBase::FillDense<DType>(s, rhs_row.shape_.Size(), DType(0), req, rhs_row.dptr_);
+ } else {
+ rhs_row = lhs_row;
+ }
- // Column indices
- const Tensor<cpu, 1, IType> col_indices_l = lhs.aux_data(csr::kIdx).FlatTo1D<cpu, IType>(s);
- const Tensor<cpu, 1, IType> col_indices_r = rhs.aux_data(csr::kIdx).FlatTo1D<cpu, IType>(s);
- Tensor<cpu, 1, IType> col_indices_out = output.aux_data(csr::kIdx).FlatTo1D<cpu, IType>(s);
-
- // Row pointers
- const Tensor<cpu, 1, CType> row_ptr_l = lhs.aux_data(csr::kIndPtr).FlatTo1D<cpu, CType>(s);
- const Tensor<cpu, 1, CType> row_ptr_r = rhs.aux_data(csr::kIndPtr).FlatTo1D<cpu, CType>(s);
- Tensor<cpu, 1, CType> row_ptr_out = output.aux_data(csr::kIndPtr).FlatTo1D<cpu, CType>(s);
-
- Tensor<cpu, 1, DType> data_l = lhs.data().FlatTo1D<cpu, DType>(s);
- Tensor<cpu, 1, DType> data_r = rhs.data().FlatTo1D<cpu, DType>(s);
- Tensor<cpu, 1, DType> data_out = output.data().FlatTo1D<cpu, DType>(s);
-
- IType nnz = 0;
- row_ptr_out[0] = 0;
-
- for (IType i = 0; i < static_cast<IType>(nr_rows); i++) {
- IType head = -2;
- IType length = 0;
-
- // add a row of A to lhs_row
- const IType i_start_l = row_ptr_l[i];
- const IType i_end_l = row_ptr_l[i + 1];
- for (IType jj = i_start_l; jj < i_end_l; jj++) {
- IType col = col_indices_l[jj];
- lhs_row[col] += data_l[jj];
-
- if (next[col] == -1) {
- next[col] = head;
- head = col;
- ++length;
- }
- }
+ // Column indices
+ const Tensor<cpu, 1, IType> col_indices_l = lhs.aux_data(csr::kIdx).FlatTo1D<cpu, IType>(s);
+ const Tensor<cpu, 1, IType> col_indices_r = rhs.aux_data(csr::kIdx).FlatTo1D<cpu, IType>(s);
+ Tensor<cpu, 1, IType> col_indices_out = output.aux_data(csr::kIdx).FlatTo1D<cpu, IType>(s);
+
+ // Row pointers
+ const Tensor<cpu, 1, CType> row_ptr_l = lhs.aux_data(csr::kIndPtr).FlatTo1D<cpu, CType>(s);
+ const Tensor<cpu, 1, CType> row_ptr_r = rhs.aux_data(csr::kIndPtr).FlatTo1D<cpu, CType>(s);
+ Tensor<cpu, 1, CType> row_ptr_out = output.aux_data(csr::kIndPtr).FlatTo1D<cpu, CType>(s);
+
+ Tensor<cpu, 1, DType> data_l = lhs.data().FlatTo1D<cpu, DType>(s);
+ Tensor<cpu, 1, DType> data_r = rhs.data().FlatTo1D<cpu, DType>(s);
+ Tensor<cpu, 1, DType> data_out = output.data().FlatTo1D<cpu, DType>(s);
+
+ IType nnz = 0;
+ row_ptr_out[0] = 0;
+
+ for (IType i = 0; i < static_cast<IType>(nr_rows); i++) {
+ IType head = -2;
+ IType length = 0;
+
+ // add a row of A to lhs_row
+ const IType i_start_l = row_ptr_l[i];
+ const IType i_end_l = row_ptr_l[i + 1];
+ for (IType jj = i_start_l; jj < i_end_l; jj++) {
+ IType col = col_indices_l[jj];
+ lhs_row[col] += data_l[jj];
+
+ if (next[col] == -1) {
+ next[col] = head;
+ head = col;
+ ++length;
+ }
+ }
- if (!same_lhs_rhs) {
- // add a row of B to rhs_row
- const IType i_start_r = row_ptr_r[i];
- const IType i_end_r = row_ptr_r[i + 1];
- for (IType jj = i_start_r; jj < i_end_r; jj++) {
- const IType col = col_indices_r[jj];
- rhs_row[col] += data_r[jj];
-
- if (next[col] == -1) {
- next[col] = head;
- head = col;
- ++length;
- }
- }
- }
+ if (!same_lhs_rhs) {
+ // add a row of B to rhs_row
+ const IType i_start_r = row_ptr_r[i];
+ const IType i_end_r = row_ptr_r[i + 1];
+ for (IType jj = i_start_r; jj < i_end_r; jj++) {
+ const IType col = col_indices_r[jj];
+ rhs_row[col] += data_r[jj];
+
+ if (next[col] == -1) {
+ next[col] = head;
+ head = col;
+ ++length;
+ }
+ }
+ }
- // scan through columns where A or B has
- // contributed a non-zero entry
- for (IType jj = 0; jj < length; jj++) {
- const DType result = OP::Map(lhs_row[head], rhs_row[head]);
+ // scan through columns where A or B has
+ // contributed a non-zero entry
+ for (IType jj = 0; jj < length; jj++) {
+ const DType result = OP::Map(lhs_row[head], rhs_row[head]);
- if (result != 0) {
- col_indices_out[nnz] = head;
- data_out[nnz] = result;
- ++nnz;
- }
+ if (result != 0) {
+ col_indices_out[nnz] = head;
+ data_out[nnz] = result;
+ ++nnz;
+ }
- const IType temp = head;
- head = next[head];
+ const IType temp = head;
+ head = next[head];
- next[temp] = -1;
- lhs_row[temp] = 0;
- if (!same_lhs_rhs) rhs_row[temp] = 0;
- }
+ next[temp] = -1;
+ lhs_row[temp] = 0;
+ if (!same_lhs_rhs) rhs_row[temp] = 0;
+ }
- row_ptr_out[i + 1] = nnz;
- }
+ row_ptr_out[i + 1] = nnz;
+ }
+ });
+ });
+ });
}
/*!
@@ -386,12 +426,13 @@ void ElemwiseBinaryOp::CsrCsrOp(mshadow::Stream<cpu> *s,
* \param num_rows number of rows of both inputs
* \param num_cols number of columns of both inputs
*/
-template<typename OP>
+template<int req, typename OP>
struct ElemwiseDnsCsrDnsKernel {
template<typename DType, typename IType, typename CType>
- static void inline Map(int i, OpReqType req, DType* out, DType* dns_data, const DType* csr_data,
- const IType* csr_indices, const CType* csr_indptr,
- const nnvm::dim_t num_rows, const nnvm::dim_t num_cols) {
+ MSHADOW_XINLINE static void Map(int i, DType* out, DType* dns_data,
+ const DType* csr_data, const IType* csr_indices,
+ const CType* csr_indptr, const nnvm::dim_t num_rows,
+ const nnvm::dim_t num_cols) {
if (i < num_rows) {
for (int j = csr_indptr[i]; j < csr_indptr[i+1]; ++j) {
KERNEL_ASSIGN(out[i * num_cols + csr_indices[j]], req,
@@ -402,8 +443,8 @@ struct ElemwiseDnsCsrDnsKernel {
};
/*! \brief DNS -op- CSR binary operator for non-canonical NDArray */
-template<typename OP>
-void ElemwiseBinaryOp::DnsCsrDnsOp(mshadow::Stream<cpu> *s,
+template<typename xpu, typename OP>
+void ElemwiseBinaryOp::DnsCsrDnsOp(mshadow::Stream<xpu> *s,
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &dns,
@@ -430,17 +471,21 @@ void ElemwiseBinaryOp::DnsCsrDnsOp(mshadow::Stream<cpu> *s,
MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, {
MXNET_ASSIGN_REQ_SWITCH(req, Req, {
if (reverse && std::is_same<OP, mshadow_op::minus>::value) {
- mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::negation, Req>, cpu>::Launch(
+ mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::negation, Req>, xpu>::Launch(
s, output.data().Size(), output.data().dptr<DType>(), dns.data().dptr<DType>());
- mxnet_op::Kernel<ElemwiseDnsCsrDnsKernel<mshadow_op::plus>, cpu>::Launch(
- s, num_csr_rows, Req, output.data().dptr<DType>(),
+ if (!csr.storage_initialized()) { return; }
+ mxnet_op::Kernel<ElemwiseDnsCsrDnsKernel<Req, mshadow_op::plus>, xpu>::Launch(
+ s, num_csr_rows, output.data().dptr<DType>(),
output.data().dptr<DType>(), csr_data.dptr<DType>(), csr_indices.dptr<IType>(),
csr_indptr.dptr<CType>(), num_csr_rows, num_csr_cols);
} else {
- mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, Req>, cpu>::Launch(
- s, output.data().Size(), output.data().dptr<DType>(), dns.data().dptr<DType>());
- mxnet_op::Kernel<ElemwiseDnsCsrDnsKernel<OP>, cpu>::Launch(
- s, num_csr_rows, Req, output.data().dptr<DType>(),
+ if (req == kWriteTo) {
+ mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, Req>, xpu>::Launch(
+ s, output.data().Size(), output.data().dptr<DType>(), dns.data().dptr<DType>());
+ }
+ if (!csr.storage_initialized()) { return; }
+ mxnet_op::Kernel<ElemwiseDnsCsrDnsKernel<Req, OP>, xpu>::Launch(
+ s, num_csr_rows, output.data().dptr<DType>(),
output.data().dptr<DType>(), csr_data.dptr<DType>(), csr_indices.dptr<IType>(),
csr_indptr.dptr<CType>(), num_csr_rows, num_csr_cols);
}
@@ -450,6 +495,92 @@ void ElemwiseBinaryOp::DnsCsrDnsOp(mshadow::Stream<cpu> *s,
});
}
+/*!
+ * \brief Kernel for performing elemwise op between dense and rsp tensor
+ * \param i global thread id
+ * \tparam req type of request
+ * \param out output array
+ * \param dns_data data array of dense input
+ * \param rsp_data data array of rsp input
+ * \param rsp_indices indices array of rsp input
+ * \param num_rows number of rows of both inputs
+ * \param nz_rows number of non-zero rows of rsp tensor
+ * \param num_cols number of columns of both inputs
+ */
+template<int req, typename OP>
+struct ElemwiseDnsRspDnsKernel {
+ template<typename DType, typename IType>
+ MSHADOW_XINLINE static void Map(int i, DType* out, DType* dns_data,
+ const DType* rsp_data, const IType* rsp_indices,
+ const nnvm::dim_t num_rows, const nnvm::dim_t nz_rows,
+ const nnvm::dim_t num_cols) {
+ if (i < nz_rows * num_cols) {
+ const nnvm::dim_t rsp_idx = i / num_cols;
+ const nnvm::dim_t dns_row = rsp_indices[rsp_idx];
+ const nnvm::dim_t col = i % num_cols;
+ KERNEL_ASSIGN(out[dns_row * num_cols + col], req,
+ OP::Map(dns_data[dns_row * num_cols + col],
+ rsp_data[rsp_idx * num_cols + col]));
+ }
+ }
+};
+
+/*! \brief DNS -op- RSP binary operator for non-canonical NDArray */
+template<typename xpu, typename OP>
+void ElemwiseBinaryOp::DnsRspDnsOp(mshadow::Stream<xpu> *s,
+ const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const NDArray &dns,
+ const NDArray &rsp,
+ const OpReqType req,
+ const NDArray &output,
+ const bool reverse) {
+ using namespace mshadow;
+ using namespace mxnet_op;
+ CHECK_EQ(dns.storage_type(), kDefaultStorage);
+ CHECK_EQ(rsp.storage_type(), kRowSparseStorage);
+ CHECK_EQ(output.data().Size(), dns.data().Size());
+ CHECK(req != kAddTo);
+ if (req == kNullOp) return;
+ const bool supported_op = std::is_same<OP, mshadow_op::minus>::value ||
+ std::is_same<OP, mshadow_op::plus>::value;
+ CHECK(supported_op == true) <<
+ "Only plus and minus are currently supported for elemwise operations between default and rsp matrices";
+ const nnvm::dim_t num_rows = dns.shape()[0];
+ const nnvm::dim_t num_cols = dns.data().Size() / num_rows;
+ const nnvm::dim_t nz_rows = rsp.aux_shape(rowsparse::kIdx).Size();
+ TBlob rsp_data = rsp.data();
+ TBlob rsp_indices = rsp.aux_data(rowsparse::kIdx);
+
+ MSHADOW_SGL_DBL_TYPE_SWITCH(rsp_data.type_flag_, DType, {
+ MSHADOW_IDX_TYPE_SWITCH(rsp_indices.type_flag_, IType, {
+ MXNET_ASSIGN_REQ_SWITCH(req, Req, {
+ if (reverse && std::is_same<OP, mshadow_op::minus>::value) {
+ mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::negation, Req>, xpu>::Launch(
+ s, output.data().Size(), output.data().dptr<DType>(), dns.data().dptr<DType>());
+ if (rsp.storage_initialized()) {
+ mxnet_op::Kernel<ElemwiseDnsRspDnsKernel<Req, mshadow_op::plus>, xpu>::Launch(
+ s, nz_rows * num_cols, output.data().dptr<DType>(),
+ output.data().dptr<DType>(), rsp_data.dptr<DType>(), rsp_indices.dptr<IType>(),
+ num_rows, nz_rows, num_cols);
+ }
+ } else {
+ if (req == kWriteTo) {
+ mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, Req>, xpu>::Launch(
+ s, output.data().Size(), output.data().dptr<DType>(), dns.data().dptr<DType>());
+ }
+ if (rsp.storage_initialized()) {
+ mxnet_op::Kernel<ElemwiseDnsRspDnsKernel<Req, OP>, xpu>::Launch(
+ s, nz_rows * num_cols, output.data().dptr<DType>(),
+ output.data().dptr<DType>(), rsp_data.dptr<DType>(), rsp_indices.dptr<IType>(),
+ num_rows, nz_rows, num_cols);
+ }
+ }
+ });
+ });
+ });
+}
+
} // namespace op
} // namespace mxnet
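
The new ElemwiseDnsRspDnsKernel above launches one logical thread per element
of the rsp operand's stored region (nz_rows * num_cols) and recovers the
(row, column) coordinates from the flat thread id. A rough NumPy equivalent of
the kWriteTo addition path (hypothetical function name, for illustration
only):

    import numpy as np

    def dns_rsp_dns_add(dns, rsp_data, rsp_indices):
        # dns: (num_rows, num_cols); rsp_data: (nz_rows, num_cols);
        # rsp_indices[k] is the dense row that stored rsp row k maps to.
        out = dns.copy()                     # identity copy (req == kWriteTo)
        nz_rows, num_cols = rsp_data.shape
        for i in range(nz_rows * num_cols):  # one kernel thread per element
            rsp_idx, col = divmod(i, num_cols)
            dns_row = rsp_indices[rsp_idx]
            out[dns_row, col] = dns[dns_row, col] + rsp_data[rsp_idx, col]
        return out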
diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h
index 645907b..9d3f6e0 100644
--- a/src/operator/tensor/elemwise_binary_op.h
+++ b/src/operator/tensor/elemwise_binary_op.h
@@ -180,37 +180,29 @@ class ElemwiseBinaryOp : public OpBase {
// lhs grad
if (req[0] != kNullOp) {
// RspRspOp can handle dense outputs so long as OP(0, 0) == 0
- MSHADOW_IDX_TYPE_SWITCH(inputs[1].aux_type(rowsparse::kIdx), IType, {
- RspRspOp<DType, IType, LOP>(
+ RspRspOp<LOP>(
s, attrs, ctx, inputs[1], inputs[2], req[0], outputs[0],
false, false, false, false);
- });
// lhs in-place
- MSHADOW_IDX_TYPE_SWITCH(inputs[0].aux_type(rowsparse::kIdx), IType, {
- RspRspOp<DType, IType, op::mshadow_op::mul>(
+ RspRspOp<op::mshadow_op::mul>(
s, attrs, ctx, outputs[0], inputs[0], req[0], outputs[0],
false, false, true, false);
- });
}
// rhs grad
if (req[1] != kNullOp) {
- MSHADOW_IDX_TYPE_SWITCH(inputs[1].aux_type(rowsparse::kIdx), IType, {
- RspRspOp<DType, IType, ROP>(
- s, attrs, ctx, inputs[1], inputs[2], req[1], outputs[1],
- false, false, false, false);
- });
+ RspRspOp<ROP>(
+ s, attrs, ctx, inputs[1], inputs[2], req[1], outputs[1],
+ false, false, false, false);
// rhs in-place
- MSHADOW_IDX_TYPE_SWITCH(inputs[0].aux_type(rowsparse::kIdx), IType, {
- RspRspOp<DType, IType, op::mshadow_op::mul>(
- s, attrs, ctx, inputs[0], outputs[1], req[1], outputs[1],
- false, false, true, false);
- });
+ RspRspOp<op::mshadow_op::mul>(
+ s, attrs, ctx, inputs[0], outputs[1], req[1], outputs[1],
+ false, false, true, false);
}
}
protected:
/*! \brief Binary op handling for lhs/rhs: RspDns, RspRsp, DnsRsp, or RspRsp->Dns result */
- template<typename DType, typename IType, typename OP>
+ template<typename OP>
static void RspRspOp(mshadow::Stream<cpu> *s,
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
@@ -223,8 +215,22 @@ class ElemwiseBinaryOp : public OpBase {
bool allow_inplace,
bool scatter);
+ /*! \brief Binary op handling for lhs/rhs: RspDns, RspRsp, DnsRsp, or RspRsp->Dns result */
+ template<typename OP>
+ static void RspRspOp(mshadow::Stream<gpu> *s,
+ const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const NDArray &lhs,
+ const NDArray &rhs,
+ OpReqType req,
+ const NDArray &output,
+ bool lhs_may_be_dense,
+ bool rhs_may_be_dense,
+ bool allow_inplace,
+ bool scatter);
+
/*! \brief CSR -op- CSR binary operator for non-canonical NDArray */
- template<typename DType, typename IType, typename CType, typename OP>
+ template<typename OP>
static inline void CsrCsrOp(mshadow::Stream<cpu> *s,
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
@@ -233,9 +239,30 @@ class ElemwiseBinaryOp : public OpBase {
OpReqType req,
const NDArray &output);
- /*! \brief DNS -op- CSR binary operator for non-canonical NDArray */
+ /*! \brief CSR -op- CSR binary operator for non-canonical NDArray */
template<typename OP>
- static inline void DnsCsrDnsOp(mshadow::Stream<cpu> *s,
+ static inline void CsrCsrOp(mshadow::Stream<gpu> *s,
+ const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const NDArray &lhs,
+ const NDArray &rhs,
+ OpReqType req,
+ const NDArray &output);
+
+ /*! \brief DNS -op- CSR binary operator for non-canonical NDArray */
+ template<typename xpu, typename OP>
+ static inline void DnsCsrDnsOp(mshadow::Stream<xpu> *s,
+ const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const NDArray &lhs,
+ const NDArray &rhs,
+ OpReqType req,
+ const NDArray &output,
+ const bool reverse);
+
+ /*! \brief DNS -op- RSP binary operator for non-canonical NDArray */
+ template<typename xpu, typename OP>
+ static inline void DnsRspDnsOp(mshadow::Stream<xpu> *s,
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
@@ -360,7 +387,13 @@ class ElemwiseBinaryOp : public OpBase {
(lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage))) {
// dense, csr -> dense / csr, dense -> dense
dispatched = storage_type_assign(out_attrs, kDefaultStorage,
- dispatch_mode, dispatch_ex);
+ dispatch_mode, DispatchMode::kFComputeEx);
+ }
+ if (!dispatched && ((lhs_stype == kDefaultStorage && rhs_stype == kRowSparseStorage) ||
+ (lhs_stype == kRowSparseStorage && rhs_stype == kDefaultStorage))) {
+ // dense, rsp -> dense / rsp, dense -> dense
+ dispatched = storage_type_assign(out_attrs, kDefaultStorage,
+ dispatch_mode, DispatchMode::kFComputeEx);
}
if (!dispatched) {
dispatch_fallback(out_attrs, dispatch_mode);
@@ -448,23 +481,11 @@ class ElemwiseBinaryOp : public OpBase {
(out_stype == kRowSparseStorage || out_stype == kDefaultStorage)) {
// rsp, rsp -> rsp
// rsp, rsp -> dns
- const int rsp_input_idx = lhs_stype == kRowSparseStorage ? 0 : 1;
- MSHADOW_IDX_TYPE_SWITCH(inputs[rsp_input_idx].aux_type(rowsparse::kIdx), IType, {
- MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
- RspRspOp<DType, IType, OP>(
+ RspRspOp<OP>(
s, attrs, ctx, inputs[0], inputs[1], req[0], outputs[0], false, false, false, false);
- });
- });
} else if (ContainsOnlyStorage(inputs, kCSRStorage) && out_stype == kCSRStorage) {
// csr, csr -> csr
- MSHADOW_IDX_TYPE_SWITCH(inputs[0].aux_type(csr::kIdx), IType, {
- MSHADOW_IDX_TYPE_SWITCH(inputs[0].aux_type(csr::kIndPtr), CType, {
- MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
- CsrCsrOp<DType, IType, CType, OP>(
- s, attrs, ctx, inputs[0], inputs[1], req[0], outputs[0]);
- });
- });
- });
+ CsrCsrOp<OP>(s, attrs, ctx, inputs[0], inputs[1], req[0], outputs[0]);
} else if (((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
(lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage)) &&
out_stype == kDefaultStorage) {
@@ -472,7 +493,15 @@ class ElemwiseBinaryOp : public OpBase {
const NDArray& csr = (lhs_stype == kCSRStorage)? inputs[0] : inputs[1];
const bool reverse = (lhs_stype == kCSRStorage);
- DnsCsrDnsOp<OP>(s, attrs, ctx, dns, csr, req[0], outputs[0], reverse);
+ DnsCsrDnsOp<xpu, OP>(s, attrs, ctx, dns, csr, req[0], outputs[0], reverse);
+ } else if (((lhs_stype == kRowSparseStorage && rhs_stype == kDefaultStorage) ||
+ (lhs_stype == kDefaultStorage && rhs_stype == kRowSparseStorage)) &&
+ out_stype == kDefaultStorage) {
+ const NDArray& dns = (lhs_stype == kDefaultStorage)? inputs[0] : inputs[1];
+ const bool reverse = (lhs_stype == kRowSparseStorage);
+ const NDArray& rsp = (reverse)? inputs[0] : inputs[1];
+
+ DnsRspDnsOp<xpu, OP>(s, attrs, ctx, dns, rsp, req[0], outputs[0], reverse);
} else {
LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
}
@@ -506,13 +535,9 @@ class ElemwiseBinaryOp : public OpBase {
// rsp, dns -> dns <-- NOT ALLOWED
// dns, rsp -> dns <-- NOT ALLOWED
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
- MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
- MSHADOW_IDX_TYPE_SWITCH(outputs[0].aux_type(rowsparse::kIdx), IType, {
- RspRspOp<DType, IType, OP>(
+ RspRspOp<OP>(
s, attrs, ctx, inputs[0], inputs[1],
req[0], outputs[0], lhs_may_be_dense, rhs_may_be_dense, false, false);
- });
- });
} else if (lhs_stype == kCSRStorage && rhs_stype == kCSRStorage) {
ComputeEx<xpu, OP>(attrs, ctx, inputs, req, outputs);
} else {
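
The storage-type dispatch added above can be read as a small lookup rule; a
standalone sketch (hypothetical helper, loosely mirroring the branches in
ElemwiseStorageType and ComputeEx):

    def infer_elemwise_storage(lhs_stype, rhs_stype):
        # Same-type sparse inputs keep their storage; any dense/csr or
        # dense/rsp mix (the latter new in this commit) produces a dense
        # output via FComputeEx instead of falling back to dense FCompute.
        if lhs_stype == rhs_stype and lhs_stype != 'default':
            return lhs_stype        # csr,csr -> csr ; rsp,rsp -> rsp
        if {lhs_stype, rhs_stype} == {'default', 'csr'}:
            return 'default'        # dns,csr / csr,dns -> dns
        if {lhs_stype, rhs_stype} == {'default', 'row_sparse'}:
            return 'default'        # dns,rsp / rsp,dns -> dns
        return 'default'            # everything else falls back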
diff --git a/src/operator/tensor/elemwise_binary_op_basic.cc b/src/operator/tensor/elemwise_binary_op_basic.cc
index 4f1e03b..dbb26ea 100644
--- a/src/operator/tensor/elemwise_binary_op_basic.cc
+++ b/src/operator/tensor/elemwise_binary_op_basic.cc
@@ -96,6 +96,8 @@ The storage type of ``elemwise_add`` output depends on storage types of inputs
- elemwise_add(csr, csr) = csr
- elemwise_add(default, csr) = default
- elemwise_add(csr, default) = default
+ - elemwise_add(default, rsp) = default
+ - elemwise_add(rsp, default) = default
- otherwise, ``elemwise_add`` generates output with default storage
)code")
@@ -170,6 +172,8 @@ The storage type of ``elemwise_sub`` output depends on storage types of inputs
- elemwise_sub(csr, csr) = csr
- elemwise_sub(default, csr) = default
- elemwise_sub(csr, default) = default
+ - elemwise_sub(default, rsp) = default
+ - elemwise_sub(rsp, default) = default
- otherwise, ``elemwise_sub`` generates output with default storage
)code")
diff --git a/src/operator/tensor/elemwise_binary_op_basic.cu b/src/operator/tensor/elemwise_binary_op_basic.cu
index c8e208e..9c1fd0e 100644
--- a/src/operator/tensor/elemwise_binary_op_basic.cu
+++ b/src/operator/tensor/elemwise_binary_op_basic.cu
@@ -23,11 +23,14 @@
* \brief GPU Implementation of unary function.
*/
#include "./elemwise_binary_op.h"
+#include "./elemwise_binary_op-inl.h"
namespace mxnet {
namespace op {
+
NNVM_REGISTER_OP(elemwise_add)
-.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu, op::mshadow_op::plus>);
+.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu, op::mshadow_op::plus>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", ElemwiseBinaryOp::ComputeEx<gpu, op::mshadow_op::plus>);
NNVM_REGISTER_OP(_grad_add)
.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu, op::mshadow_op::plus>);
@@ -39,7 +42,8 @@ NNVM_REGISTER_OP(_backward_add)
NNVM_REGISTER_OP(elemwise_sub)
.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<
- gpu, op::mshadow_op::minus>);
+ gpu, op::mshadow_op::minus>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", ElemwiseBinaryOp::ComputeEx<gpu, op::mshadow_op::minus>);
NNVM_REGISTER_OP(_backward_sub)
.set_attr<FCompute>("FCompute<gpu>",
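
Note how the reverse subtraction case is handled by the kernels these
registrations invoke: rsp - dense is evaluated as (-dense) + rsp, so only the
stored rsp rows need a second pass. A NumPy sketch of that identity
(illustrative values):

    import numpy as np

    dense = np.ones((3, 2))
    rsp_data = np.array([[5., 5.]])   # one stored row...
    rsp_indices = np.array([2])       # ...living at dense row 2

    out = -dense                      # negation kernel over the whole output
    out[rsp_indices] += rsp_data      # plus kernel over stored rows only
    # out now equals rsp - dense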
diff --git a/src/operator/tensor/elemwise_scatter_op.h b/src/operator/tensor/elemwise_scatter_op.h
index de6b23c..33bc0da 100644
--- a/src/operator/tensor/elemwise_scatter_op.h
+++ b/src/operator/tensor/elemwise_scatter_op.h
@@ -184,12 +184,8 @@ class ElemwiseScatterBinaryOp : public ElemwiseBinaryOp,
&& (input1_stype == kRowSparseStorage || input1_stype == kDefaultStorage)
&& outputs[0].storage_type() == kRowSparseStorage) {
mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
- MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, {
- MSHADOW_IDX_TYPE_SWITCH(inputs[0].aux_type(rowsparse::kIdx), IType, {
- RspRspOp<DType, IType, OP>(s, attrs, ctx, inputs[0], inputs[1], req[0], outputs[0],
- false, true, false, true);
- });
- });
+ RspRspOp<OP>(s, attrs, ctx, inputs[0], inputs[1], req[0], outputs[0],
+ false, true, false, true);
CHECK_EQ(inputs[0].aux_shape(rowsparse::kIdx).Size(),
outputs[0].aux_shape(rowsparse::kIdx).Size());
} else {
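
ElemwiseScatterBinaryOp calls RspRspOp with scatter=true, which restricts the
output row set to the lhs rows; rhs-only rows are skipped (the '++iter_r;
continue;' path in the merge above). A one-line sketch (hypothetical helper,
with rhs given as a dense matrix):

    def scatter_rsp_op(idx_l, rows_l, rhs_dense, op):
        # Output indices == lhs indices; rhs rows outside them are ignored.
        return idx_l, [op(rows_l[k], rhs_dense[idx_l[k]])
                       for k in range(len(idx_l))]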
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 111c0b7..2d5ed5a 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -349,6 +349,19 @@ def test_elemwise_binary_ops():
lhs_density=lhs_density, rhs_density=rhs_density,
verbose=False)
+ if ((lhs_stype == 'default' and rhs_stype == 'row_sparse') or
+ (lhs_stype == 'default' and rhs_stype == 'csr')):
+ test_elemwise_binary_op("elemwise_add", lhs_stype, rhs_stype, shape,
+ lambda l, r: mx.sym.sparse.elemwise_add(l, r, out=l),
+ lambda l, r: l + r,
+ lambda outg, l, r: (outg, outg),
+ lhs_grad_stype, rhs_grad_stype,
+ ograd_density=ograd_density,
+ force_lr_overlap=force_lr_overlap,
+ force_grad_overlap=force_grad_overlap,
+ lhs_density=lhs_density, rhs_density=rhs_density,
+ verbose=False)
+
test_elemwise_binary_op("elemwise_sub", lhs_stype, rhs_stype, shape,
lambda l, r: mx.sym.sparse.elemwise_sub(l, r),
lambda l, r: l - r,
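
To exercise the new dense/rsp paths locally, running the updated test module
should suffice; the test suite used nose at the time, so something like the
following (command shown as an assumption about the local setup):

    nosetests -v tests/python/unittest/test_sparse_operator.py:test_elemwise_binary_ops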