You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2019/02/15 21:28:47 UTC
[incubator-mxnet] branch master updated: Parallelize CPU version
and add GPU version of boolean_mask op (#14090)
This is an automated email from the ASF dual-hosted git repository.
zhengda pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new ba97fb6 Parallelize CPU version and add GPU version of boolean_mask op (#14090)
ba97fb6 is described below
commit ba97fb6663bf1116c8db492668fbc984e3541612
Author: HyperZealot <40...@users.noreply.github.com>
AuthorDate: Fri Feb 15 13:28:24 2019 -0800
Parallelize CPU version and add GPU version of boolean_mask op (#14090)
---
src/operator/contrib/boolean_mask-inl.h | 102 +++++++-------------
src/operator/contrib/boolean_mask.cc | 116 +++++++++++++++++++++-
src/operator/contrib/boolean_mask.cu | 165 ++++++++++++++++++++++++++++++++
tests/python/unittest/test_operator.py | 2 -
4 files changed, 313 insertions(+), 72 deletions(-)
diff --git a/src/operator/contrib/boolean_mask-inl.h b/src/operator/contrib/boolean_mask-inl.h
index ac0681b..775981f 100644
--- a/src/operator/contrib/boolean_mask-inl.h
+++ b/src/operator/contrib/boolean_mask-inl.h
@@ -50,83 +50,53 @@ struct BooleanMaskParam : public dmlc::Parameter<BooleanMaskParam> {
}
};
+/*!
+ * \brief Element-wise forward kernel: one invocation handles one flat element of `data`.
+ *
+ * `idx` is indexed per row; an element is copied only when the value at its row
+ * differs from the value at the previous row (row 0 is compared against 0). The
+ * GPU forward pass in this change passes an inclusive prefix sum of the mask.
+ */
+struct BooleanMaskForwardKernel {
+  /*!
+   * \param i flat index into `data`
+   * \param out destination buffer
+   * \param data source buffer
+   * \param idx per-row values compared between consecutive rows
+   * \param col_size number of elements per row
+   */
+  template<typename DType>
+  static void MSHADOW_XINLINE Map(int i,
+                                  DType* out,
+                                  const DType* data,
+                                  const int32_t* idx,
+                                  const size_t col_size) {
+    const int row = i / col_size;
+    const int col = i % col_size;
+    // Value before this row doubles as the destination row in `out`.
+    const int32_t before = (row == 0) ? 0 : idx[row - 1];
+    const int32_t after = idx[row];
+    if (before == after) {
+      return;  // value did not change at this row: nothing to copy
+    }
+    out[before * col_size + col] = data[i];
+  }
+};
+
+/*!
+ * \brief Element-wise backward kernel: one invocation handles one flat element of `igrad`.
+ *
+ * `idx` is indexed per row; a row counts as selected when its value differs from
+ * the previous row's value (row 0 is compared against 0). Selected elements receive
+ * the matching element of `ograd`; every other element is written as zero so that
+ * no element of `igrad` is left holding stale memory.
+ *
+ * NOTE: the kernel does not receive the write request type, so it always
+ * overwrites `igrad` (write semantics).
+ */
+struct BooleanMaskBackwardKernel {
+  /*!
+   * \param i flat index into `igrad`
+   * \param igrad gradient w.r.t. the input data (written for every i)
+   * \param ograd gradient w.r.t. the masked output
+   * \param idx per-row values compared between consecutive rows
+   * \param col_size number of elements per row
+   */
+  template<typename DType>
+  static void MSHADOW_XINLINE Map(int i,
+                                  DType* igrad,
+                                  const DType* ograd,
+                                  const int32_t* idx,
+                                  const size_t col_size) {
+    int row_id = i / col_size;
+    int col_id = i % col_size;
+    int32_t prev = (row_id == 0) ? 0 : idx[row_id - 1];
+    int32_t curr = idx[row_id];
+    if (prev != curr) {
+      // Selected row: `prev` is its row position inside `ograd`.
+      igrad[i] = ograd[prev * col_size + col_id];
+    } else {
+      // Unselected row contributes nothing to the output, so its gradient is zero.
+      igrad[i] = 0;
+    }
+  }
+};
+
template<typename xpu>
inline void BooleanMaskForward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
- const std::vector<NDArray> &outputs) {
- // TODO(@junrushao1994): This implementation is a proof-of-concept,
- // hence very slow actually. Performance should be improved in the future.
- CHECK_EQ(inputs.size(), 2U);
- CHECK_EQ(outputs.size(), 1U);
- const BooleanMaskParam& param = nnvm::get<BooleanMaskParam>(attrs.parsed);
- const int axis = param.axis;
- const NDArray &data = inputs[0];
- const NDArray &idx = inputs[1];
- const NDArray &out = outputs[0];
- CHECK_EQ(axis, 0) << "Not supported yet";
- CHECK_EQ(data.shape()[axis], idx.shape()[0]);
- CHECK_EQ(idx.shape().ndim(), 1U);
- // count the number of 1s in `idx`, so that we could know the output dimension
- size_t valid_num = 0;
- MSHADOW_TYPE_SWITCH(idx.dtype(), DType, {
- DType* idx_dptr = idx.data().dptr<DType>();
- int length = idx.shape()[0];
- for (int i = 0; i < length; i++) {
- if (idx_dptr[i]) {
- ++valid_num;
- }
- }
- });
- // set the output shape forcefully
- TShape s = data.shape();
- s[axis] = valid_num;
- const_cast<NDArray &>(out).Init(s);
- // do the copy
- MSHADOW_TYPE_SWITCH(idx.dtype(), DType, {
- DType* idx_dptr = idx.data().dptr<DType>();
- int length = idx.shape()[0];
- mshadow::Stream<xpu> *stream = ctx.get_stream<xpu>();
- for (int i = 0, j = 0; i < length; ++i) {
- if (idx_dptr[i]) {
- NDArray src = data.At(i);
- NDArray dst = out.At(j++);
- CHECK(src.shape() == dst.shape());
- mxnet_op::copy(stream, dst.data(), src.data());
- }
- }
- });
-}
+ const std::vector<NDArray> &outputs);
template<typename xpu>
inline void BooleanMaskBackward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
- const std::vector<NDArray> &outputs) {
- CHECK_EQ(inputs.size(), 3U);
- CHECK_EQ(outputs.size(), 2U);
- // inputs: {ograd, data, idx}
- // outputs: {igrad_data, igrad_idx}
- const NDArray& ograd = inputs[0];
- const NDArray& idx = inputs[2];
- const NDArray& igrad_data = outputs[0];
- MSHADOW_TYPE_SWITCH(idx.dtype(), DType, {
- DType* idx_dptr = idx.data().dptr<DType>();
- int length = idx.shape()[0];
- mshadow::Stream<xpu> *stream = ctx.get_stream<xpu>();
- Fill<false>(stream, igrad_data.data(), req[0], 0);
- for (int i = 0, j = 0; i < length; ++i) {
- if (idx_dptr[i]) {
- NDArray src = ograd.At(j++);
- NDArray dst = igrad_data.At(i);
- CHECK(src.shape() == dst.shape());
- mxnet_op::copy(stream, dst.data(), src.data());
- }
- }
- });
-}
+ const std::vector<NDArray> &outputs);
} // namespace op
} // namespace mxnet
diff --git a/src/operator/contrib/boolean_mask.cc b/src/operator/contrib/boolean_mask.cc
index 7fd66bc..18ba8c3 100644
--- a/src/operator/contrib/boolean_mask.cc
+++ b/src/operator/contrib/boolean_mask.cc
@@ -28,7 +28,6 @@ namespace op {
DMLC_REGISTER_PARAMETER(BooleanMaskParam);
-
bool BooleanMaskType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
@@ -75,9 +74,116 @@ bool BooleanMaskBackStorageType(const nnvm::NodeAttrs& attrs,
return true;
}
+/*!
+ * \brief Row-wise forward kernel for CPU: one invocation handles one whole row of `data`.
+ *
+ * A row is copied only when `idx` at that row differs from `idx` at the previous
+ * row (row 0 is compared against 0); the previous value is the destination row.
+ */
+struct BooleanMaskForwardCPUKernel {
+  /*!
+   * \param i row index into `data`
+   * \param out destination buffer
+   * \param data source buffer
+   * \param idx per-row values compared between consecutive rows
+   * \param col_size number of elements per row
+   */
+  template<typename DType>
+  static void Map(int i,
+                  DType* out,
+                  const DType* data,
+                  const int32_t* idx,
+                  const size_t col_size) {
+    const int32_t dst_row = (i == 0) ? 0 : idx[i - 1];
+    if (dst_row == idx[i]) {
+      return;  // value did not change at this row: nothing to copy
+    }
+    const DType* src = data + i * col_size;
+    DType* dst = out + dst_row * col_size;
+    std::memcpy(dst, src, col_size * sizeof(DType));
+  }
+};
+
+/*!
+ * \brief Row-wise backward kernel for CPU: one invocation handles one whole row of `igrad`.
+ *
+ * A row counts as selected when `idx` at that row differs from `idx` at the
+ * previous row (row 0 is compared against 0). Selected rows receive the matching
+ * row of `ograd`; every other row is cleared so that no row of `igrad` is left
+ * holding stale memory.
+ *
+ * NOTE: the kernel does not receive the write request type, so it always
+ * overwrites `igrad` (write semantics).
+ */
+struct BooleanMaskBackwardCPUKernel {
+  /*!
+   * \param i row index into `igrad`
+   * \param igrad gradient w.r.t. the input data (written for every row)
+   * \param ograd gradient w.r.t. the masked output
+   * \param idx per-row values compared between consecutive rows
+   * \param col_size number of elements per row
+   */
+  template<typename DType>
+  static void Map(int i,
+                  DType* igrad,
+                  const DType* ograd,
+                  const int32_t* idx,
+                  const size_t col_size) {
+    // i is row id already
+    int32_t prev = (i == 0) ? 0 : idx[i - 1];
+    int32_t curr = idx[i];
+    if (prev != curr) {
+      // Selected row: `prev` is its row position inside `ograd`.
+      std::memcpy(igrad + i * col_size, ograd + prev * col_size, col_size * sizeof(DType));
+    } else {
+      // Unselected row contributes nothing to the output, so its gradient is zero.
+      std::memset(igrad + i * col_size, 0, col_size * sizeof(DType));
+    }
+  }
+};
+
+/*!
+ * \brief CPU forward pass of boolean_mask along axis 0.
+ *
+ * Builds an inclusive prefix sum of the non-zero entries of the index on the
+ * host, sizes the output from the last prefix-sum entry, and then launches one
+ * kernel invocation per input row to copy the selected rows.
+ *
+ * \param attrs node attributes holding the parsed BooleanMaskParam
+ * \param ctx operator context providing the CPU stream
+ * \param inputs {data, index}
+ * \param req write request types (not consulted here)
+ * \param outputs {out}; its shape is set inside this function
+ */
+template<>
+inline void BooleanMaskForward<cpu>(const nnvm::NodeAttrs& attrs,
+                                    const OpContext &ctx,
+                                    const std::vector<NDArray> &inputs,
+                                    const std::vector<OpReqType> &req,
+                                    const std::vector<NDArray> &outputs) {
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  const BooleanMaskParam& param = nnvm::get<BooleanMaskParam>(attrs.parsed);
+  const int axis = param.axis;
+  const NDArray &data = inputs[0];
+  const NDArray &idx = inputs[1];
+  const NDArray &out = outputs[0];
+  CHECK_EQ(axis, 0) << "Not supported yet";
+  CHECK_EQ(data.shape()[axis], idx.shape()[0]);
+  CHECK_EQ(idx.shape().ndim(), 1U);
+  const size_t idx_size = idx.shape()[0];
+  // An empty index would make prefix_sum[idx_size - 1] read out of bounds and
+  // col_size divide by zero, so reject it explicitly.
+  CHECK_GT(idx_size, 0U) << "boolean_mask requires a non-empty index";
+  // prefix_sum[i] = number of non-zero index entries in [0, i]; the last entry
+  // is the number of selected rows, i.e. the output dimension along `axis`.
+  std::vector<int32_t> prefix_sum(idx_size, 0);
+  size_t valid_num = 0;
+  // Calculate prefix sum
+  MSHADOW_TYPE_SWITCH(idx.dtype(), DType, {
+    DType* idx_dptr = idx.data().dptr<DType>();
+    for (size_t i = 0; i < idx_size; i++) {
+      prefix_sum[i] = (i == 0) ? 0 : prefix_sum[i - 1];
+      prefix_sum[i] += (idx_dptr[i]) ? 1 : 0;
+    }
+    valid_num = prefix_sum[idx_size - 1];
+  });
+  // set the output shape forcefully
+  TShape s = data.shape();
+  s[axis] = valid_num;
+  const_cast<NDArray &>(out).Init(s);
+  // do the copy: one kernel invocation per input row
+  MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
+    size_t input_size = data.shape().Size();
+    size_t col_size = input_size / idx_size;
+    mshadow::Stream<cpu> *stream = ctx.get_stream<cpu>();
+    mxnet_op::Kernel<BooleanMaskForwardCPUKernel, cpu>::Launch(
+      stream, idx_size, out.data().dptr<DType>(), data.data().dptr<DType>(),
+      prefix_sum.data(), col_size);
+  });
+}
+
+/*!
+ * \brief CPU backward pass of boolean_mask.
+ *
+ * Rebuilds the inclusive prefix sum of the non-zero index entries on the host
+ * and launches one kernel invocation per row of the data gradient.
+ *
+ * \param attrs node attributes (not consulted here)
+ * \param ctx operator context providing the CPU stream
+ * \param inputs {ograd, data, idx}
+ * \param req write request types (not consulted here)
+ * \param outputs {igrad_data, igrad_idx}; only igrad_data is written
+ */
+template<>
+inline void BooleanMaskBackward<cpu>(const nnvm::NodeAttrs& attrs,
+                                     const OpContext &ctx,
+                                     const std::vector<NDArray> &inputs,
+                                     const std::vector<OpReqType> &req,
+                                     const std::vector<NDArray> &outputs) {
+  CHECK_EQ(inputs.size(), 3U);
+  CHECK_EQ(outputs.size(), 2U);
+  // inputs: {ograd, data, idx}
+  // outputs: {igrad_data, igrad_idx}
+  const NDArray& ograd = inputs[0];
+  const NDArray& idx = inputs[2];
+  const NDArray& igrad_data = outputs[0];
+  const size_t idx_size = idx.shape()[0];
+  // An empty index would make col_size below divide by zero.
+  CHECK_GT(idx_size, 0U) << "boolean_mask requires a non-empty index";
+  const size_t input_size = igrad_data.shape().Size();
+  const size_t col_size = input_size / idx_size;
+  // The prefix sum depends only on the index dtype, so build it once here
+  // instead of inside the data-dtype switch.
+  std::vector<int32_t> prefix_sum(idx_size, 0);
+  MSHADOW_TYPE_SWITCH(idx.dtype(), IType, {
+    const IType* idx_dptr = idx.data().dptr<IType>();
+    int32_t running = 0;
+    for (size_t i = 0; i < idx_size; i++) {
+      running += (idx_dptr[i]) ? 1 : 0;
+      prefix_sum[i] = running;
+    }
+  });
+  // One kernel invocation per row of igrad_data.
+  MSHADOW_TYPE_SWITCH(igrad_data.dtype(), DType, {
+    mshadow::Stream<cpu> *stream = ctx.get_stream<cpu>();
+    mxnet_op::Kernel<BooleanMaskBackwardCPUKernel, cpu>::Launch(
+      stream, idx_size, igrad_data.data().dptr<DType>(), ograd.data().dptr<DType>(),
+      prefix_sum.data(), col_size);
+  });
+}
+
NNVM_REGISTER_OP(_contrib_boolean_mask)
.describe(R"code(
-Experimental CPU-only support for boolean masking.
Given an n-d NDArray data, and a 1-d NDArray index,
the operator produces an un-predeterminable shaped n-d NDArray out,
which stands for the rows in x where the corresonding element in index is non-zero.
@@ -94,12 +200,14 @@ which stands for the rows in x where the corresonding element in index is non-ze
.set_attr_parser(ParamParser<BooleanMaskParam>)
.set_num_inputs(2)
.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+ [](const NodeAttrs& attrs) {
+ return std::vector<std::string>{"data", "index"};
+ })
.set_attr<nnvm::FInferType>("FInferType", BooleanMaskType)
.set_attr<FComputeEx>("FComputeEx<cpu>", BooleanMaskForward<cpu>)
.set_attr<FInferStorageType>("FInferStorageType", BooleanMaskStorageType)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_contrib_boolean_mask"})
-.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
- return std::vector<std::string>{"data", "index"};})
.add_argument("data", "NDArray-or-Symbol", "Data")
.add_argument("index", "NDArray-or-Symbol", "Mask")
.add_arguments(BooleanMaskParam::__FIELDS__());
diff --git a/src/operator/contrib/boolean_mask.cu b/src/operator/contrib/boolean_mask.cu
new file mode 100644
index 0000000..25a781c
--- /dev/null
+++ b/src/operator/contrib/boolean_mask.cu
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file boolean_mask.cu
+*/
+
+#include "./boolean_mask-inl.h"
+#include <cub/cub.cuh>
+
+namespace mxnet {
+namespace op {
+
+/*!
+ * \brief GPU forward pass of boolean_mask along axis 0.
+ *
+ * Casts the index into an int32 buffer, turns it into an inclusive prefix sum
+ * with cub::DeviceScan, copies the last entry back to the host to size the
+ * output, and then launches one kernel invocation per input element.
+ *
+ * \param attrs node attributes holding the parsed BooleanMaskParam
+ * \param ctx operator context providing the GPU stream and ctx.requested[0]
+ * \param inputs {data, index}
+ * \param req write request types (not consulted here)
+ * \param outputs {out}; its shape is set inside this function
+ */
+template<>
+inline void BooleanMaskForward<gpu>(const nnvm::NodeAttrs& attrs,
+                                    const OpContext &ctx,
+                                    const std::vector<NDArray> &inputs,
+                                    const std::vector<OpReqType> &req,
+                                    const std::vector<NDArray> &outputs) {
+  using namespace mshadow;
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  const BooleanMaskParam& param = nnvm::get<BooleanMaskParam>(attrs.parsed);
+  const int axis = param.axis;
+  const NDArray &data = inputs[0];
+  const NDArray &idx = inputs[1];
+  const NDArray &out = outputs[0];
+  CHECK_EQ(axis, 0) << "Not supported yet";
+  CHECK_EQ(data.shape()[axis], idx.shape()[0]);
+  CHECK_EQ(idx.shape().ndim(), 1U);
+  Stream<gpu>* s = ctx.get_stream<gpu>();
+  cudaStream_t stream = Stream<gpu>::GetStream(s);
+  const size_t idx_size = idx.shape()[0];
+  // An empty index would make prefix_sum[idx_size - 1] read out of bounds and
+  // col_size divide by zero, so reject it explicitly.
+  CHECK_GT(idx_size, 0U) << "boolean_mask requires a non-empty index";
+  int32_t valid_num = 0;
+  int32_t* prefix_sum = nullptr;
+  void* d_temp_storage = nullptr;
+  size_t temp_storage_bytes = 0;
+  // With a null temp-storage pointer CUB only reports the scratch size it needs.
+  cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                temp_storage_bytes,
+                                prefix_sum,
+                                prefix_sum,
+                                idx_size,
+                                stream);
+  // Workspace layout: [ int32 prefix-sum buffer | CUB scratch ].
+  const size_t buffer_size = idx_size * sizeof(int32_t);
+  Tensor<gpu, 1, char> workspace =
+    ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(buffer_size + temp_storage_bytes), s);
+  prefix_sum = reinterpret_cast<int32_t*>(workspace.dptr_);
+  d_temp_storage = workspace.dptr_ + buffer_size;
+  // NOTE(review): the index is cast as-is rather than reduced to 0/1, whereas the
+  // CPU path counts any non-zero entry as 1 - confirm masks are strictly 0/1.
+  MSHADOW_TYPE_SWITCH(idx.dtype(), IType, {
+    mxnet_op::Kernel<mshadow_op::identity_with_cast, gpu>::Launch(
+      s, idx.shape()[0], prefix_sum, idx.data().dptr<IType>());
+  });
+  // Calculate prefix sum in place; temp_storage_bytes is the scratch size only.
+  cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                temp_storage_bytes,
+                                prefix_sum,
+                                prefix_sum,
+                                idx_size,
+                                stream);
+  // Fetch the number of selected rows on the operator's own stream and wait for
+  // it, so the read is ordered after the scan regardless of default-stream mode.
+  CUDA_CALL(cudaMemcpyAsync(&valid_num, &prefix_sum[idx_size - 1], sizeof(int32_t),
+                            cudaMemcpyDeviceToHost, stream));
+  CUDA_CALL(cudaStreamSynchronize(stream));
+  CHECK(valid_num > 0) << "boolean_mask behavior not defined when all masks are 0";
+  // Set the output shape forcefully
+  TShape data_shape = data.shape();
+  data_shape[axis] = valid_num;
+  const_cast<NDArray &>(out).Init(data_shape);
+  size_t input_size = data.shape().Size();
+  size_t col_size = input_size / idx_size;
+  // Do the copy: one kernel invocation per input element
+  MSHADOW_TYPE_SWITCH(out.dtype(), DType, {
+    mxnet_op::Kernel<BooleanMaskForwardKernel, gpu>::Launch(
+      s, input_size, out.data().dptr<DType>(), data.data().dptr<DType>(), prefix_sum, col_size);
+  });
+}
+
+/*!
+ * \brief GPU backward pass of boolean_mask.
+ *
+ * Casts the index into an int32 buffer, turns it into an inclusive prefix sum
+ * with cub::DeviceScan, and launches one kernel invocation per element of the
+ * data gradient.
+ *
+ * \param attrs node attributes (not consulted here)
+ * \param ctx operator context providing the GPU stream and ctx.requested[0]
+ * \param inputs {ograd, data, idx}
+ * \param req write request types (not consulted here)
+ * \param outputs {igrad_data, igrad_idx}; only igrad_data is written
+ */
+template<>
+inline void BooleanMaskBackward<gpu>(const nnvm::NodeAttrs& attrs,
+                                     const OpContext &ctx,
+                                     const std::vector<NDArray> &inputs,
+                                     const std::vector<OpReqType> &req,
+                                     const std::vector<NDArray> &outputs) {
+  using namespace mshadow;
+  CHECK_EQ(inputs.size(), 3U);
+  CHECK_EQ(outputs.size(), 2U);
+  // inputs: {ograd, data, idx}
+  // outputs: {igrad_data, igrad_idx}
+  const NDArray& ograd = inputs[0];
+  const NDArray& idx = inputs[2];
+  const NDArray& igrad_data = outputs[0];
+  Stream<gpu>* s = ctx.get_stream<gpu>();
+  cudaStream_t stream = Stream<gpu>::GetStream(s);
+  const size_t idx_size = idx.shape()[0];
+  // An empty index would make col_size below divide by zero.
+  CHECK_GT(idx_size, 0U) << "boolean_mask requires a non-empty index";
+  int32_t* prefix_sum = nullptr;
+  void* d_temp_storage = nullptr;
+  size_t temp_storage_bytes = 0;
+  // With a null temp-storage pointer CUB only reports the scratch size it needs.
+  cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                temp_storage_bytes,
+                                prefix_sum,
+                                prefix_sum,
+                                idx_size,
+                                stream);
+  // Workspace layout: [ int32 prefix-sum buffer | CUB scratch ].
+  const size_t buffer_size = idx_size * sizeof(int32_t);
+  Tensor<gpu, 1, char> workspace =
+    ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(buffer_size + temp_storage_bytes), s);
+  prefix_sum = reinterpret_cast<int32_t*>(workspace.dptr_);
+  d_temp_storage = workspace.dptr_ + buffer_size;
+  MSHADOW_TYPE_SWITCH(idx.dtype(), IType, {
+    mxnet_op::Kernel<mshadow_op::identity_with_cast, gpu>::Launch(
+      s, idx.shape()[0], prefix_sum, idx.data().dptr<IType>());
+  });
+  // Calculate prefix sum in place; temp_storage_bytes is the scratch size only.
+  cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                temp_storage_bytes,
+                                prefix_sum,
+                                prefix_sum,
+                                idx_size,
+                                stream);
+  size_t input_size = igrad_data.shape().Size();
+  size_t col_size = input_size / idx_size;
+  // Backward pass: one kernel invocation per element of igrad_data
+  MSHADOW_TYPE_SWITCH(igrad_data.dtype(), DType, {
+    mxnet_op::Kernel<BooleanMaskBackwardKernel, gpu>::Launch(
+      s, input_size, igrad_data.data().dptr<DType>(), ograd.data().dptr<DType>(),
+      prefix_sum, col_size);
+  });
+}
+
+NNVM_REGISTER_OP(_contrib_boolean_mask)  // adds GPU attributes to the forward op
+.set_attr<FResourceRequest>("FResourceRequest",
+ [](const NodeAttrs& attrs) {
+ return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};  // presumably backs ctx.requested[0] in BooleanMaskForward<gpu> - confirm
+ })
+.set_attr<FComputeEx>("FComputeEx<gpu>", BooleanMaskForward<gpu>);  // GPU compute entry point
+
+NNVM_REGISTER_OP(_backward_contrib_boolean_mask)  // adds GPU attributes to the backward op
+.set_attr<FResourceRequest>("FResourceRequest",
+ [](const NodeAttrs& attrs) {
+ return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};  // presumably backs ctx.requested[0] in BooleanMaskBackward<gpu> - confirm
+ })
+.set_attr<FComputeEx>("FComputeEx<gpu>", BooleanMaskBackward<gpu>);  // GPU compute entry point
+
+} // namespace op
+} // namespace mxnet
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index b0c640b..a9b9cc8 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4880,8 +4880,6 @@ def test_index_copy():
@with_seed()
def test_boolean_mask():
- if default_context().device_type != 'cpu':
- return
data = mx.nd.array([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
index = mx.nd.array([0, 1, 0])
data.attach_grad()