You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/05/10 17:20:26 UTC
[incubator-mxnet] branch master updated: support broadcast_add/sub between csr and 1D dense vector (#10714)
This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 3972861 support broadcast_add/sub between csr and 1D dense vector (#10714)
3972861 is described below
commit 3972861ae6a95958b527331ace783a3c50f13072
Author: Hao Jin <ha...@users.noreply.github.com>
AuthorDate: Thu May 10 10:20:19 2018 -0700
support broadcast_add/sub between csr and 1D dense vector (#10714)
---
src/operator/tensor/elemwise_binary_broadcast_op.h | 183 ++++++++++++++++++++-
.../tensor/elemwise_binary_broadcast_op_basic.cc | 16 +-
.../tensor/elemwise_binary_broadcast_op_basic.cu | 6 +-
src/operator/tensor/elemwise_binary_op.h | 10 +-
tests/python/unittest/test_sparse_operator.py | 25 +++
5 files changed, 226 insertions(+), 14 deletions(-)
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h
index 42a8f0f..933d0be 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -105,6 +105,32 @@ inline bool BinaryBroadcastMulStorageType(const nnvm::NodeAttrs& attrs,
return dispatched;
}
+/*!
+ * \brief Storage type inference shared by broadcast_add and broadcast_sub.
+ *
+ *  - dense op dense               -> dense output, dispatched to FCompute
+ *  - csr op dense / dense op csr  -> dense output, dispatched to FComputeEx
+ *  - any other combination        -> storage fallback
+ */
+inline bool BinaryBroadcastAddStorageType(const nnvm::NodeAttrs& attrs,
+                                          const int dev_mask,
+                                          DispatchMode* dispatch_mode,
+                                          std::vector<int>* in_attrs,
+                                          std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  const int lhs_stype = in_attrs->at(0);
+  const int rhs_stype = in_attrs->at(1);
+  int& out_stype = out_attrs->at(0);
+  if (common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) &&
+      storage_type_assign(&out_stype, kDefaultStorage,
+                          dispatch_mode, DispatchMode::kFCompute)) {
+    return true;
+  }
+  // Exactly one CSR operand paired with one dense operand.
+  const bool csr_with_dense =
+      (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
+      (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage);
+  if (csr_with_dense &&
+      storage_type_assign(&out_stype, kDefaultStorage,
+                          dispatch_mode, DispatchMode::kFComputeEx)) {
+    return true;
+  }
+  return dispatch_fallback(out_attrs, dispatch_mode);
+}
+
#define BROADCAST_NDIM_SWITCH(ndim, NDim, ...) \
if (ndim <= 2) { \
const int NDim = 2; \
@@ -183,6 +209,24 @@ struct binary_broadcast_kernel {
KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs[lidx], rhs[ridx]));
}
}
+
+  /*!
+   * \brief Map function for binary_broadcast_kernel with a scalar lhs.
+   *
+   * Writes out[base + i] = OP::Map(lhs, rhs[ridx_i]) for i in [0, length), where
+   * ridx_i is the broadcast index into rhs for the (base + i)-th output element.
+   * Because lhs is a single value, lidx is never read; it is only tracked because
+   * inc() is called with both the lhs and rhs index/stride pairs.
+   */
+  MSHADOW_XINLINE static void Map(int base, int length, OpReqType req,
+                                  const Shape<ndim> &lstride, const Shape<ndim> &rstride,
+                                  const Shape<ndim> &oshape, DType lhs, DType *rhs,
+                                  DType *out) {
+    Shape<ndim> coord = unravel(base, oshape);
+    auto lidx = static_cast<index_t>(dot(coord, lstride));
+    auto ridx = static_cast<index_t>(dot(coord, rstride));
+    KERNEL_ASSIGN(out[base], req, OP::Map(lhs, rhs[ridx]));
+    // starts from 1 to avoid extra inc at end of loop
+    for (int i = 1; i < length; ++i) {
+      inc(&coord, oshape, &lidx, lstride, &ridx, rstride);
+      KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs, rhs[ridx]));
+    }
+  }
};
template<int req, typename OP>
@@ -200,6 +244,25 @@ struct csr_dns_csr_broadcast_kernel {
}
};
+/*!
+ * \brief Folds the stored values of a CSR matrix into a dense row-major output,
+ *        one CSR row per kernel index.
+ *
+ * For each stored element (row, col) with value v:
+ *   reverse == false: out[row * num_cols + col] = OP(v, out[...])
+ *   reverse == true : out[row * num_cols + col] = OP(out[...], v)
+ * Indices with row >= num_rows do nothing.
+ */
+template<int req, typename OP, bool reverse = false>
+struct csr_dns_map_kernel {
+  template <typename DType, typename CType, typename RType>
+  MSHADOW_XINLINE static void Map(int row, const DType *csr_data, const CType *csr_indices,
+                                  const RType *csr_indptr, DType *out, const nnvm::dim_t num_rows,
+                                  const nnvm::dim_t num_cols) {
+    if (row >= num_rows) return;
+    const nnvm::dim_t row_offset = row * num_cols;
+    const nnvm::dim_t row_end = csr_indptr[row + 1];
+    for (nnvm::dim_t j = csr_indptr[row]; j < row_end; ++j) {
+      const nnvm::dim_t target = row_offset + csr_indices[j];
+      const DType csr_val = csr_data[j];
+      KERNEL_ASSIGN(out[target], req,
+                    reverse ? OP::Map(out[target], csr_val) : OP::Map(csr_val, out[target]));
+    }
+  }
+};
+
} // namespace mxnet_op
template<typename xpu, typename OP>
@@ -284,11 +347,77 @@ void BinaryBroadcastCsrDnsCsrImpl(const OpContext& ctx,
}
template<typename xpu, typename OP>
-void BinaryBroadcastComputeEx(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs) {
+/*!
+ * \brief broadcast_add / broadcast_sub between a CSR matrix and a broadcastable
+ *        dense array, writing a dense output.
+ *
+ * Works in two passes over the output:
+ *  1. Broadcast the dense operand into the output. For dns - csr the dense values
+ *     are copied (0 + dns); otherwise OP(0, dns) is written, which negates them
+ *     for csr - dns.
+ *  2. Fold every stored CSR value into the output: subtracted for dns - csr,
+ *     added in every other case.
+ *
+ * \param ctx          operator context providing the stream
+ * \param csr          the CSR operand
+ * \param dns          the dense operand
+ * \param req          write request; only kWriteTo is accepted
+ * \param output       dense output whose first two dims match csr
+ * \param new_csrshape csr shape as produced by BinaryBroadcastShapeCompact
+ * \param new_dnsshape dns shape as produced by BinaryBroadcastShapeCompact
+ * \param new_oshape   output shape as produced by BinaryBroadcastShapeCompact
+ * \param ndim         dimensionality of the compacted shapes
+ * \param reverse      true when the dense array is the left-hand operand
+ */
+void BinaryBroadcastCsrDnsDnsImpl(const OpContext& ctx,
+                                  const NDArray& csr,
+                                  const NDArray& dns,
+                                  const OpReqType req,
+                                  const NDArray& output,
+                                  const TShape& new_csrshape,
+                                  const TShape& new_dnsshape,
+                                  const TShape& new_oshape,
+                                  const int ndim,
+                                  const bool reverse) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  using namespace csr;
+  CHECK(req == kWriteTo) << "Only kWriteTo supported for broadcast(csr, dns) = dns";
+  const bool op_is_plus = std::is_same<OP, mshadow_op::plus>::value;
+  const bool op_is_minus = std::is_same<OP, mshadow_op::minus>::value;
+  CHECK(op_is_plus || op_is_minus) << "Only add/sub are supported for broadcast(csr, dns) = dns";
+  CHECK_EQ(csr.shape()[0], output.shape()[0]);
+  CHECK_EQ(csr.shape()[1], output.shape()[1]);
+  // dns - csr is the only case where the CSR values must be subtracted in pass 2.
+  const bool dns_minus_csr = reverse && op_is_minus;
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  const nnvm::dim_t num_rows = output.shape()[0];
+  const nnvm::dim_t num_cols = output.shape()[1];
+  const TBlob& csr_data = csr.data();
+  const TBlob& csr_indices = csr.aux_data(kIdx);
+  const TBlob& csr_indptr = csr.aux_data(kIndPtr);
+  TBlob dns_data = dns.data();
+  TBlob out_data = output.data();
+
+  // Pass 1: broadcast the dense operand into the output.
+  MSHADOW_TYPE_SWITCH(output.dtype(), DType, {
+    BROADCAST_NDIM_SWITCH(ndim, NDim, {
+      Shape<NDim> oshape = new_oshape.get<NDim>();
+      Shape<NDim> lstride = calc_stride(new_csrshape.get<NDim>());
+      Shape<NDim> rstride = calc_stride(new_dnsshape.get<NDim>());
+      if (dns_minus_csr) {
+        // 0 + dns: keep the dense values unchanged.
+        Kernel<binary_broadcast_kernel<NDim, DType, mshadow_op::plus>, xpu>::
+          template LaunchEx(s, new_oshape.Size(), req, lstride, rstride, oshape,
+                            DType(0), dns_data.dptr<DType>(), out_data.dptr<DType>());
+      } else {
+        // OP(0, dns): dns for add, -dns for csr - dns.
+        Kernel<binary_broadcast_kernel<NDim, DType, OP>, xpu>::
+          template LaunchEx(s, new_oshape.Size(), req, lstride, rstride, oshape,
+                            DType(0), dns_data.dptr<DType>(), out_data.dptr<DType>());
+      }
+    });
+  });
+
+  // Pass 2: fold the stored CSR values into the output, one CSR row per index.
+  if (!csr.storage_initialized()) return;
+  MSHADOW_TYPE_SWITCH(csr.dtype(), DType, {
+    MSHADOW_IDX_TYPE_SWITCH(csr.aux_type(kIdx), CType, {
+      MSHADOW_IDX_TYPE_SWITCH(csr.aux_type(kIndPtr), RType, {
+        MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+          if (dns_minus_csr) {
+            // out = out - csr
+            Kernel<csr_dns_map_kernel<req_type, mshadow_op::minus, true>, xpu>::Launch(
+              s, num_rows, csr_data.dptr<DType>(), csr_indices.dptr<CType>(),
+              csr_indptr.dptr<RType>(), out_data.dptr<DType>(), num_rows, num_cols);
+          } else {
+            // out = csr + out (pass 1 already negated dns for csr - dns)
+            Kernel<csr_dns_map_kernel<req_type, mshadow_op::plus>, xpu>::Launch(
+              s, num_rows, csr_data.dptr<DType>(), csr_indices.dptr<CType>(),
+              csr_indptr.dptr<RType>(), out_data.dptr<DType>(), num_rows, num_cols);
+          }
+        });
+      });
+    });
+  });
+}
+
+template<typename xpu, typename OP>
+void BinaryBroadcastComputeSparseEx(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
@@ -317,6 +446,50 @@ void BinaryBroadcastComputeEx(const nnvm::NodeAttrs& attrs,
}
}
+/*!
+ * \brief FComputeEx for broadcast_add / broadcast_sub with one CSR input and one
+ *        dense input, producing a dense output.
+ *
+ * Either operand may be the CSR one. When the shapes need no broadcasting the
+ * elementwise dense/CSR kernel is used; otherwise the dense operand (at most 2D)
+ * is broadcast against the CSR matrix. Any other storage combination is logged
+ * as unimplemented.
+ */
+template<typename xpu, typename OP>
+void BinaryBroadcastComputeDenseEx(const nnvm::NodeAttrs& attrs,
+                                   const OpContext& ctx,
+                                   const std::vector<NDArray>& inputs,
+                                   const std::vector<OpReqType>& req,
+                                   const std::vector<NDArray>& outputs) {
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  if (req[0] == kNullOp) return;
+  const NDArray& lhs = inputs[0];
+  const NDArray& rhs = inputs[1];
+  const NDArray& out = outputs[0];
+  const auto lhs_stype = lhs.storage_type();
+  const auto rhs_stype = rhs.storage_type();
+  const auto out_stype = out.storage_type();
+  const bool csr_dns_pair =
+      (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
+      (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage);
+  if (!csr_dns_pair || out_stype != kDefaultStorage) {
+    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
+    return;
+  }
+  // reverse == true means the dense array is the left-hand operand.
+  const bool reverse = (lhs_stype == kDefaultStorage);
+  const NDArray& dns = reverse ? lhs : rhs;
+  const NDArray& csr = reverse ? rhs : lhs;
+  // Validate the dense operand itself: for dns op csr it is inputs[0], so
+  // checking inputs[1] would inspect the CSR matrix instead.
+  CHECK_LE(dns.shape().ndim(), 2U)
+    << "input dense matrix should have less than or equal to 2 dimensions";
+  TShape new_csrshape, new_dnsshape, new_oshape;
+  const int ndim = BinaryBroadcastShapeCompact(csr.shape(), dns.shape(), out.shape(),
+                                               &new_csrshape, &new_dnsshape, &new_oshape);
+  if (!ndim) {
+    // If the input is a matrix with the same shape, should be elemwise
+    mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
+    ElemwiseBinaryOp::DnsCsrDnsOp<xpu, OP>(
+        s, attrs, ctx, dns, csr, req[0], outputs[0], !reverse);
+  } else {
+    // broadcast(CSR, Dense(<=2D)) = Dense
+    BinaryBroadcastCsrDnsDnsImpl<xpu, OP>(ctx, csr, dns, req[0], out,
+                                          new_csrshape, new_dnsshape, new_oshape,
+                                          ndim, reverse);
+  }
+}
+
template<typename xpu, typename LOP, typename ROP>
void BinaryBroadcastBackwardUseNone(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
index 6be4c26..78b2d45 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
+++ b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
@@ -48,8 +48,14 @@ Example::
broadcast_plus(x, y) = [[ 1., 1., 1.],
[ 2., 2., 2.]]
+Supported sparse operations:
+ broadcast_add(csr, dense(1D)) = dense
+ broadcast_add(dense(1D), csr) = dense
+
)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::plus>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryBroadcastComputeDenseEx<cpu, op::mshadow_op::plus>)
+.set_attr<FInferStorageType>("FInferStorageType", BinaryBroadcastAddStorageType)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_add"});
NNVM_REGISTER_OP(_backward_broadcast_add)
@@ -87,8 +93,14 @@ Example::
broadcast_minus(x, y) = [[ 1., 1., 1.],
[ 0., 0., 0.]]
+Supported sparse operations:
+ broadcast_sub/minus(csr, dense(1D)) = dense
+ broadcast_sub/minus(dense(1D), csr) = dense
+
)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::minus>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryBroadcastComputeDenseEx<cpu, op::mshadow_op::minus>)
+.set_attr<FInferStorageType>("FInferStorageType", BinaryBroadcastAddStorageType)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_sub"});
NNVM_REGISTER_OP(_backward_broadcast_sub)
@@ -125,7 +137,7 @@ Supported sparse operations:
)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::mul>)
-.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryBroadcastComputeEx<cpu, op::mshadow_op::mul>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryBroadcastComputeSparseEx<cpu, op::mshadow_op::mul>)
.set_attr<FInferStorageType>("FInferStorageType", BinaryBroadcastMulStorageType)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"});
@@ -164,7 +176,7 @@ Supported sparse operations:
)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::div>)
-.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryBroadcastComputeEx<cpu, op::mshadow_op::div>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryBroadcastComputeSparseEx<cpu, op::mshadow_op::div>)
.set_attr<FInferStorageType>("FInferStorageType", BinaryBroadcastMulStorageType)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_div"});
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cu b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cu
index dc0ba02..976f091 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cu
+++ b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cu
@@ -29,14 +29,16 @@
namespace mxnet {
namespace op {
NNVM_REGISTER_OP(broadcast_add)
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::plus>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::plus>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", BinaryBroadcastComputeDenseEx<gpu, op::mshadow_op::plus>);
NNVM_REGISTER_OP(_backward_broadcast_add)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseNone<gpu, mshadow_op::identity,
mshadow_op::identity>);
NNVM_REGISTER_OP(broadcast_sub)
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::minus>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::minus>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", BinaryBroadcastComputeDenseEx<gpu, op::mshadow_op::minus>);
NNVM_REGISTER_OP(_backward_broadcast_sub)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseNone<gpu, mshadow_op::identity,
diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h
index 9d3f6e0..a5b73da 100644
--- a/src/operator/tensor/elemwise_binary_op.h
+++ b/src/operator/tensor/elemwise_binary_op.h
@@ -200,7 +200,7 @@ class ElemwiseBinaryOp : public OpBase {
}
}
- protected:
+ public:
/*! \brief Binary op handling for lhr/rhs: RspDns, RspRsp, DnsRsp, or RspRsp->Dns result */
template<typename OP>
static void RspRspOp(mshadow::Stream<cpu> *s,
@@ -231,7 +231,7 @@ class ElemwiseBinaryOp : public OpBase {
/*! \brief CSR -op- CSR binary operator for non-canonical NDArray */
template<typename OP>
- static inline void CsrCsrOp(mshadow::Stream<cpu> *s,
+ static void CsrCsrOp(mshadow::Stream<cpu> *s,
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
@@ -241,7 +241,7 @@ class ElemwiseBinaryOp : public OpBase {
/*! \brief CSR -op- CSR binary operator for non-canonical NDArray */
template<typename OP>
- static inline void CsrCsrOp(mshadow::Stream<gpu> *s,
+ static void CsrCsrOp(mshadow::Stream<gpu> *s,
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
@@ -251,7 +251,7 @@ class ElemwiseBinaryOp : public OpBase {
/*! \brief DNS -op- CSR binary operator for non-canonical NDArray */
template<typename xpu, typename OP>
- static inline void DnsCsrDnsOp(mshadow::Stream<xpu> *s,
+ static void DnsCsrDnsOp(mshadow::Stream<xpu> *s,
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
@@ -262,7 +262,7 @@ class ElemwiseBinaryOp : public OpBase {
/*! \brief DNS -op- RSP binary operator for non-canonical NDArray */
template<typename xpu, typename OP>
- static inline void DnsRspDnsOp(mshadow::Stream<xpu> *s,
+ static void DnsRspDnsOp(mshadow::Stream<xpu> *s,
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 2d5ed5a..f21334e 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -1769,6 +1769,31 @@ def test_sparse_embedding():
check_sparse_embedding(in_dim, out_dim, batch, densities, deterministic, stype)
@with_seed()
+def test_sparse_broadcast_add_sub():
+ def check_broadcast_add(mx_lhs, mx_rhs, np_lhs, np_rhs, dtype):
+ assert_almost_equal(mx.nd.sparse.add(mx_lhs, mx_rhs).asnumpy(), np.add(np_lhs, np_rhs), atol=1e-4)
+ def check_broadcast_sub(mx_lhs, mx_rhs, np_lhs, np_rhs, dtype):
+ assert_almost_equal(mx.nd.sparse.subtract(mx_lhs, mx_rhs).asnumpy(), np.subtract(np_lhs, np_rhs), atol=1e-4)
+ stype = 'csr'
+ shape = rand_shape_2d()
+ num_rows = shape[0]
+ num_cols = shape[1]
+ for density in [0.1 * i for i in range(10)]:
+ mx_lhs = rand_ndarray(shape, stype, density)
+ np_lhs = mx_lhs.asnumpy()
+ mx_rhs_row_2D = rand_ndarray((1, num_cols), 'default')
+ mx_rhs_row_1D = mx_rhs_row_2D.reshape((num_cols))
+ mx_rhs_col = rand_ndarray((num_rows, 1), 'default')
+ mx_rhs_scalar_2D = rand_ndarray((1, 1), 'default')
+ mx_rhs_scalar_1D = mx_rhs_scalar_2D.reshape((1, ))
+ for mx_rhs in [mx_rhs_row_2D, mx_rhs_row_1D, mx_rhs_col, mx_rhs_scalar_2D, mx_rhs_scalar_1D]:
+ np_rhs = mx_rhs.asnumpy()
+ check_broadcast_add(mx_lhs, mx_rhs, np_lhs, np_rhs, np.float32)
+ check_broadcast_sub(mx_lhs, mx_rhs, np_lhs, np_rhs, np.float32)
+ check_broadcast_add(mx_rhs, mx_lhs, np_rhs, np_lhs, np.float32)
+ check_broadcast_sub(mx_rhs, mx_lhs, np_rhs, np_lhs, np.float32)
+
+@with_seed()
def test_sparse_broadcast_mul_div():
def check_broadcast_mul(mx_lhs, mx_rhs, np_lhs, np_rhs, dtype):
assert_almost_equal(mx.nd.sparse.multiply(mx_lhs, mx_rhs).asnumpy(), np.multiply(np_lhs, np_rhs), atol=1e-4)
--
To stop receiving notification emails like this one, please contact
haibin@apache.org.