You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2019/02/15 21:28:47 UTC

[incubator-mxnet] branch master updated: Parallelize CPU version and add GPU version of boolean_mask op (#14090)

This is an automated email from the ASF dual-hosted git repository.

zhengda pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new ba97fb6  Parallelize CPU version and add GPU version of boolean_mask op (#14090)
ba97fb6 is described below

commit ba97fb6663bf1116c8db492668fbc984e3541612
Author: HyperZealot <40...@users.noreply.github.com>
AuthorDate: Fri Feb 15 13:28:24 2019 -0800

    Parallelize CPU version and add GPU version of boolean_mask op (#14090)
---
 src/operator/contrib/boolean_mask-inl.h | 102 +++++++-------------
 src/operator/contrib/boolean_mask.cc    | 116 +++++++++++++++++++++-
 src/operator/contrib/boolean_mask.cu    | 165 ++++++++++++++++++++++++++++++++
 tests/python/unittest/test_operator.py  |   2 -
 4 files changed, 313 insertions(+), 72 deletions(-)

diff --git a/src/operator/contrib/boolean_mask-inl.h b/src/operator/contrib/boolean_mask-inl.h
index ac0681b..775981f 100644
--- a/src/operator/contrib/boolean_mask-inl.h
+++ b/src/operator/contrib/boolean_mask-inl.h
@@ -50,83 +50,53 @@ struct BooleanMaskParam : public dmlc::Parameter<BooleanMaskParam> {
   }
 };
 
+/*!
+ * \brief Element-wise gather used by the GPU forward of boolean_mask.
+ *        Work item i handles flat element i of `data`. When its row is selected, the element
+ *        is written to the same column of output row idx[row] - 1.
+ * \param idx inclusive prefix sum of the mask. Row r is selected iff idx[r] differs from
+ *            idx[r - 1] (or from 0 when r == 0).
+ * \param col_size number of elements in one row.
+ */
+struct BooleanMaskForwardKernel {
+  template<typename DType>
+  static void MSHADOW_XINLINE Map(int i,
+                                  DType* out,
+                                  const DType* data,
+                                  const int32_t* idx,
+                                  const size_t col_size) {
+    const int row = i / col_size;
+    const int col = i % col_size;
+    const int32_t dst_row = (row > 0) ? idx[row - 1] : 0;
+    if (idx[row] == dst_row) return;  // row not selected by the mask
+    out[dst_row * col_size + col] = data[i];
+  }
+};
+
+/*!
+ * \brief Element-wise scatter used by the GPU backward of boolean_mask.
+ *        Work item i handles flat element i of `igrad`. When its row was selected, it
+ *        receives the same column of ograd row idx[row] - 1. Otherwise it is left untouched.
+ * \param idx inclusive prefix sum of the mask.
+ * \param col_size number of elements in one row.
+ */
+struct BooleanMaskBackwardKernel {
+  template<typename DType>
+  static void MSHADOW_XINLINE Map(int i,
+                                  DType* igrad,
+                                  const DType* ograd,
+                                  const int32_t* idx,
+                                  const size_t col_size) {
+    const int row = i / col_size;
+    const int col = i % col_size;
+    const int32_t src_row = (row > 0) ? idx[row - 1] : 0;
+    if (idx[row] == src_row) return;  // row was masked out in the forward pass
+    igrad[i] = ograd[src_row * col_size + col];
+  }
+};
+
+/*!
+ * \brief Forward of boolean_mask: keeps the rows of inputs[0] ("data") whose entry in the
+ *        1-D mask inputs[1] ("index") is non-zero, re-initialising outputs[0] to the
+ *        resulting shape. Only axis 0 is supported. Only declared here; the definitions
+ *        are the cpu specialization in boolean_mask.cc and the gpu one in boolean_mask.cu.
+ */
 template<typename xpu>
 inline void BooleanMaskForward(const nnvm::NodeAttrs& attrs,
                                 const OpContext &ctx,
                                 const std::vector<NDArray> &inputs,
                                 const std::vector<OpReqType> &req,
-                               const std::vector<NDArray> &outputs) {
-  // TODO(@junrushao1994): This implementation is a proof-of-concept,
-  // hence very slow actually. Performance should be improved in the future.
-  CHECK_EQ(inputs.size(), 2U);
-  CHECK_EQ(outputs.size(), 1U);
-  const BooleanMaskParam& param = nnvm::get<BooleanMaskParam>(attrs.parsed);
-  const int axis = param.axis;
-  const NDArray &data = inputs[0];
-  const NDArray &idx = inputs[1];
-  const NDArray &out = outputs[0];
-  CHECK_EQ(axis, 0) << "Not supported yet";
-  CHECK_EQ(data.shape()[axis], idx.shape()[0]);
-  CHECK_EQ(idx.shape().ndim(), 1U);
-  // count the number of 1s in `idx`, so that we could know the output dimension
-  size_t valid_num = 0;
-  MSHADOW_TYPE_SWITCH(idx.dtype(), DType, {
-    DType* idx_dptr = idx.data().dptr<DType>();
-    int length = idx.shape()[0];
-    for (int i = 0; i < length; i++) {
-      if (idx_dptr[i]) {
-        ++valid_num;
-      }
-    }
-  });
-  // set the output shape forcefully
-  TShape s = data.shape();
-  s[axis] = valid_num;
-  const_cast<NDArray &>(out).Init(s);
-  // do the copy
-  MSHADOW_TYPE_SWITCH(idx.dtype(), DType, {
-    DType* idx_dptr = idx.data().dptr<DType>();
-    int length = idx.shape()[0];
-    mshadow::Stream<xpu> *stream = ctx.get_stream<xpu>();
-    for (int i = 0, j = 0; i < length; ++i) {
-      if (idx_dptr[i]) {
-        NDArray src = data.At(i);
-        NDArray dst = out.At(j++);
-        CHECK(src.shape() == dst.shape());
-        mxnet_op::copy(stream, dst.data(), src.data());
-      }
-    }
-  });
-}
+                               const std::vector<NDArray> &outputs);
 
+/*!
+ * \brief Backward of boolean_mask. inputs are {ograd, data, idx} and outputs are
+ *        {igrad_data, igrad_idx}. The selected rows of igrad_data receive the matching
+ *        rows of ograd. Only declared here; the definitions are the cpu specialization in
+ *        boolean_mask.cc and the gpu one in boolean_mask.cu.
+ */
 template<typename xpu>
 inline void BooleanMaskBackward(const nnvm::NodeAttrs& attrs,
                                  const OpContext &ctx,
                                  const std::vector<NDArray> &inputs,
                                  const std::vector<OpReqType> &req,
-                                const std::vector<NDArray> &outputs) {
-  CHECK_EQ(inputs.size(), 3U);
-  CHECK_EQ(outputs.size(), 2U);
-  // inputs: {ograd, data, idx}
-  // outputs: {igrad_data, igrad_idx}
-  const NDArray& ograd = inputs[0];
-  const NDArray& idx = inputs[2];
-  const NDArray& igrad_data = outputs[0];
-  MSHADOW_TYPE_SWITCH(idx.dtype(), DType, {
-    DType* idx_dptr = idx.data().dptr<DType>();
-    int length = idx.shape()[0];
-    mshadow::Stream<xpu> *stream = ctx.get_stream<xpu>();
-    Fill<false>(stream, igrad_data.data(), req[0], 0);
-    for (int i = 0, j = 0; i < length; ++i) {
-      if (idx_dptr[i]) {
-        NDArray src = ograd.At(j++);
-        NDArray dst = igrad_data.At(i);
-        CHECK(src.shape() == dst.shape());
-        mxnet_op::copy(stream, dst.data(), src.data());
-      }
-    }
-  });
-}
+                                const std::vector<NDArray> &outputs);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/contrib/boolean_mask.cc b/src/operator/contrib/boolean_mask.cc
index 7fd66bc..18ba8c3 100644
--- a/src/operator/contrib/boolean_mask.cc
+++ b/src/operator/contrib/boolean_mask.cc
@@ -28,7 +28,6 @@ namespace op {
 
 DMLC_REGISTER_PARAMETER(BooleanMaskParam);
 
-
 bool BooleanMaskType(const nnvm::NodeAttrs& attrs,
                      std::vector<int> *in_attrs,
                      std::vector<int> *out_attrs) {
@@ -75,9 +74,116 @@ bool BooleanMaskBackStorageType(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
+/*!
+ * \brief Row-wise gather used by the CPU forward of boolean_mask.
+ *        Index i is a row of `data`. A selected row is copied with a single memcpy into
+ *        output row idx[i] - 1.
+ * \param idx inclusive prefix sum of the mask.
+ * \param col_size number of elements in one row.
+ */
+struct BooleanMaskForwardCPUKernel {
+  template<typename DType>
+  static void Map(int i,
+                  DType* out,
+                  const DType* data,
+                  const int32_t* idx,
+                  const size_t col_size) {
+    const int32_t dst_row = (i > 0) ? idx[i - 1] : 0;
+    if (idx[i] == dst_row) return;  // row i is masked out
+    std::memcpy(out + dst_row * col_size, data + i * col_size, col_size * sizeof(DType));
+  }
+};
+
+/*!
+ * \brief Row-wise scatter used by the CPU backward of boolean_mask.
+ *        Index i is a row of `igrad`. A row that was selected in the forward pass receives
+ *        ograd row idx[i] - 1 via memcpy. Unselected rows are not written.
+ * \param idx inclusive prefix sum of the mask.
+ * \param col_size number of elements in one row.
+ */
+struct BooleanMaskBackwardCPUKernel {
+  template<typename DType>
+  static void Map(int i,
+                  DType* igrad,
+                  const DType* ograd,
+                  const int32_t* idx,
+                  const size_t col_size) {
+    const int32_t src_row = (i > 0) ? idx[i - 1] : 0;
+    if (idx[i] == src_row) return;  // row i was masked out
+    std::memcpy(igrad + i * col_size, ograd + src_row * col_size, col_size * sizeof(DType));
+  }
+};
+
+/*!
+ * \brief CPU forward of boolean_mask: copies the rows of `data` whose mask entry is
+ *        non-zero into `out`, after re-initialising `out` with the selected row count.
+ *        Only axis 0 is supported.
+ */
+template<>
+inline void BooleanMaskForward<cpu>(const nnvm::NodeAttrs& attrs,
+                                    const OpContext &ctx,
+                                    const std::vector<NDArray> &inputs,
+                                    const std::vector<OpReqType> &req,
+                                    const std::vector<NDArray> &outputs) {
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  const BooleanMaskParam& param = nnvm::get<BooleanMaskParam>(attrs.parsed);
+  const int axis = param.axis;
+  const NDArray &data = inputs[0];
+  const NDArray &idx = inputs[1];
+  const NDArray &out = outputs[0];
+  CHECK_EQ(axis, 0) << "Not supported yet";
+  // Validate the mask's rank before indexing into its shape.
+  CHECK_EQ(idx.shape().ndim(), 1U);
+  CHECK_EQ(data.shape()[axis], idx.shape()[0]);
+  const size_t idx_size = idx.shape()[0];
+  // prefix_sum[i] = number of non-zero mask entries in [0, i]. A selected row i goes to
+  // output row prefix_sum[i] - 1, and the running total is the output row count.
+  std::vector<int32_t> prefix_sum(idx_size, 0);
+  int32_t valid_num = 0;
+  MSHADOW_TYPE_SWITCH(idx.dtype(), IType, {
+    const IType* idx_dptr = idx.data().dptr<IType>();
+    for (size_t i = 0; i < idx_size; ++i) {
+      valid_num += idx_dptr[i] ? 1 : 0;
+      prefix_sum[i] = valid_num;
+    }
+  });
+  // set the output shape forcefully
+  TShape s = data.shape();
+  s[axis] = valid_num;
+  const_cast<NDArray &>(out).Init(s);
+  // An empty mask selects nothing. Skip the copy, which would otherwise divide by zero
+  // when computing col_size.
+  if (idx_size == 0) return;
+  // do the copy, one row per kernel index
+  MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
+    const size_t col_size = data.shape().Size() / idx_size;
+    mshadow::Stream<cpu> *stream = ctx.get_stream<cpu>();
+    mxnet_op::Kernel<BooleanMaskForwardCPUKernel, cpu>::Launch(
+      stream, idx_size, out.data().dptr<DType>(), data.data().dptr<DType>(),
+      prefix_sum.data(), col_size);
+  });
+}
+
+/*!
+ * \brief CPU backward of boolean_mask: igrad_data receives the matching ograd row for every
+ *        selected row and zero for every masked-out row.
+ *        inputs: {ograd, data, idx}; outputs: {igrad_data, igrad_idx} (igrad_idx is not written).
+ */
+template<>
+inline void BooleanMaskBackward<cpu>(const nnvm::NodeAttrs& attrs,
+                                     const OpContext &ctx,
+                                     const std::vector<NDArray> &inputs,
+                                     const std::vector<OpReqType> &req,
+                                     const std::vector<NDArray> &outputs) {
+  CHECK_EQ(inputs.size(), 3U);
+  CHECK_EQ(outputs.size(), 2U);
+  if (req[0] == kNullOp) return;
+  // inputs: {ograd, data, idx}
+  // outputs: {igrad_data, igrad_idx}
+  const NDArray& ograd = inputs[0];
+  const NDArray& idx = inputs[2];
+  const NDArray& igrad_data = outputs[0];
+  mshadow::Stream<cpu> *stream = ctx.get_stream<cpu>();
+  // Masked-out rows receive no gradient. The scatter kernel below writes only the selected
+  // rows, so clear igrad first (as the previous implementation did).
+  Fill<false>(stream, igrad_data.data(), req[0], 0);
+  const size_t idx_size = idx.shape()[0];
+  if (idx_size == 0) return;  // nothing selected; also avoids dividing by zero below
+  MSHADOW_TYPE_SWITCH(igrad_data.dtype(), DType, {
+    MSHADOW_TYPE_SWITCH(idx.dtype(), IType, {
+      const size_t col_size = igrad_data.shape().Size() / idx_size;
+      const IType* idx_dptr = idx.data().dptr<IType>();
+      // prefix_sum[i] = number of selected rows in [0, i]. A selected row i reads ograd
+      // row prefix_sum[i] - 1.
+      std::vector<int32_t> prefix_sum(idx_size, 0);
+      int32_t selected = 0;
+      for (size_t i = 0; i < idx_size; ++i) {
+        selected += idx_dptr[i] ? 1 : 0;
+        prefix_sum[i] = selected;
+      }
+      // NOTE(review): selected rows are overwritten, so kAddTo is not honoured for them.
+      mxnet_op::Kernel<BooleanMaskBackwardCPUKernel, cpu>::Launch(
+        stream, idx_size, igrad_data.data().dptr<DType>(), ograd.data().dptr<DType>(),
+        prefix_sum.data(), col_size);
+    });
+  });
+}
+
 NNVM_REGISTER_OP(_contrib_boolean_mask)
 .describe(R"code(
-Experimental CPU-only support for boolean masking.
 Given an n-d NDArray data, and a 1-d NDArray index,
 the operator produces an un-predeterminable shaped n-d NDArray out,
 which stands for the rows in x where the corresonding element in index is non-zero.
@@ -94,12 +200,14 @@ which stands for the rows in x where the corresonding element in index is non-ze
 .set_attr_parser(ParamParser<BooleanMaskParam>)
 .set_num_inputs(2)
 .set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data", "index"};
+  })
 .set_attr<nnvm::FInferType>("FInferType", BooleanMaskType)
 .set_attr<FComputeEx>("FComputeEx<cpu>", BooleanMaskForward<cpu>)
 .set_attr<FInferStorageType>("FInferStorageType", BooleanMaskStorageType)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_contrib_boolean_mask"})
-.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
-  return std::vector<std::string>{"data", "index"};})
 .add_argument("data", "NDArray-or-Symbol", "Data")
 .add_argument("index", "NDArray-or-Symbol", "Mask")
 .add_arguments(BooleanMaskParam::__FIELDS__());
diff --git a/src/operator/contrib/boolean_mask.cu b/src/operator/contrib/boolean_mask.cu
new file mode 100644
index 0000000..25a781c
--- /dev/null
+++ b/src/operator/contrib/boolean_mask.cu
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file boolean_mask.cu
+*/
+
+#include "./boolean_mask-inl.h"
+#include <cub/cub.cuh>
+
+namespace mxnet {
+namespace op {
+
+/*!
+ * \brief Writes 1 for every non-zero mask entry and 0 otherwise.
+ *        Scanning these flags, rather than the raw mask values, makes the GPU row count
+ *        match the CPU path (which counts `idx[i] ? 1 : 0`) for any mask dtype or value.
+ */
+struct BooleanMaskFlagKernel {
+  template<typename IType>
+  static void MSHADOW_XINLINE Map(int i, int32_t* flags, const IType* mask) {
+    flags[i] = mask[i] ? 1 : 0;
+  }
+};
+
+/*!
+ * \brief Computes, on the op's stream, the inclusive prefix sum of (idx[i] != 0).
+ * \param ctx op context; ctx.requested[0] is used as temporary workspace.
+ * \param idx 1-D mask; callers guarantee it has at least one element.
+ * \return device pointer to idx.shape()[0] int32 prefix sums, stored at the start of the
+ *         temporary workspace. It is only valid until that workspace is requested again.
+ */
+inline int32_t* BooleanMaskPrefixSum(const OpContext &ctx, const NDArray &idx) {
+  using namespace mshadow;
+  Stream<gpu>* s = ctx.get_stream<gpu>();
+  cudaStream_t stream = Stream<gpu>::GetStream(s);
+  const int idx_size = static_cast<int>(idx.shape()[0]);
+  int32_t* prefix_sum = nullptr;
+  // Query the scratch size cub needs (no work is done while the temp buffer is null).
+  size_t scan_bytes = 0;
+  CUDA_CALL(cub::DeviceScan::InclusiveSum(nullptr, scan_bytes, prefix_sum, prefix_sum,
+                                          idx_size, stream));
+  // Workspace layout: [idx_size int32 prefix sums][cub scratch space].
+  const size_t buffer_bytes = static_cast<size_t>(idx_size) * sizeof(int32_t);
+  Tensor<gpu, 1, char> workspace =
+    ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(buffer_bytes + scan_bytes), s);
+  prefix_sum = reinterpret_cast<int32_t*>(workspace.dptr_);
+  void* scan_storage = workspace.dptr_ + buffer_bytes;
+  MSHADOW_TYPE_SWITCH(idx.dtype(), IType, {
+    mxnet_op::Kernel<BooleanMaskFlagKernel, gpu>::Launch(
+      s, idx_size, prefix_sum, idx.data().dptr<IType>());
+  });
+  // In-place inclusive scan of the 0/1 flags, told the true size of its scratch space.
+  CUDA_CALL(cub::DeviceScan::InclusiveSum(scan_storage, scan_bytes, prefix_sum, prefix_sum,
+                                          idx_size, stream));
+  return prefix_sum;
+}
+
+/*!
+ * \brief GPU forward of boolean_mask: gathers the rows of `data` whose mask entry is
+ *        non-zero into `out`, after re-initialising `out` with the selected row count.
+ *        Only axis 0 is supported, and at least one row must be selected.
+ */
+template<>
+inline void BooleanMaskForward<gpu>(const nnvm::NodeAttrs& attrs,
+                                    const OpContext &ctx,
+                                    const std::vector<NDArray> &inputs,
+                                    const std::vector<OpReqType> &req,
+                                    const std::vector<NDArray> &outputs) {
+  using namespace mshadow;
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  const BooleanMaskParam& param = nnvm::get<BooleanMaskParam>(attrs.parsed);
+  const int axis = param.axis;
+  const NDArray &data = inputs[0];
+  const NDArray &idx = inputs[1];
+  const NDArray &out = outputs[0];
+  CHECK_EQ(axis, 0) << "Not supported yet";
+  // Validate the mask's rank before indexing into its shape.
+  CHECK_EQ(idx.shape().ndim(), 1U);
+  CHECK_EQ(data.shape()[axis], idx.shape()[0]);
+  const size_t idx_size = idx.shape()[0];
+  // An empty mask selects nothing, and reading prefix_sum[idx_size - 1] would be out of bounds.
+  CHECK_GT(idx_size, 0U) << "boolean_mask behavior not defined when all masks are 0";
+  Stream<gpu>* s = ctx.get_stream<gpu>();
+  cudaStream_t stream = Stream<gpu>::GetStream(s);
+  const int32_t* prefix_sum = BooleanMaskPrefixSum(ctx, idx);
+  // Read back the selected row count on the op's own stream, so the copy is ordered after
+  // the scan, then wait for it because the host needs the value to size `out`.
+  int32_t valid_num = 0;
+  CUDA_CALL(cudaMemcpyAsync(&valid_num, prefix_sum + idx_size - 1, sizeof(int32_t),
+                            cudaMemcpyDeviceToHost, stream));
+  CUDA_CALL(cudaStreamSynchronize(stream));
+  CHECK(valid_num > 0) << "boolean_mask behavior not defined when all masks are 0";
+  // Set the output shape forcefully
+  TShape data_shape = data.shape();
+  data_shape[axis] = valid_num;
+  const_cast<NDArray &>(out).Init(data_shape);
+  const size_t input_size = data.shape().Size();
+  const size_t col_size = input_size / idx_size;
+  // Do the copy, one kernel index per input element
+  MSHADOW_TYPE_SWITCH(out.dtype(), DType, {
+    mxnet_op::Kernel<BooleanMaskForwardKernel, gpu>::Launch(
+      s, input_size, out.data().dptr<DType>(), data.data().dptr<DType>(), prefix_sum, col_size);
+  });
+}
+
+/*!
+ * \brief GPU backward of boolean_mask: igrad_data receives the matching ograd row for every
+ *        selected row and zero for every masked-out row.
+ *        inputs: {ograd, data, idx}; outputs: {igrad_data, igrad_idx} (igrad_idx is not written).
+ */
+template<>
+inline void BooleanMaskBackward<gpu>(const nnvm::NodeAttrs& attrs,
+                                     const OpContext &ctx,
+                                     const std::vector<NDArray> &inputs,
+                                     const std::vector<OpReqType> &req,
+                                     const std::vector<NDArray> &outputs) {
+  using namespace mshadow;
+  CHECK_EQ(inputs.size(), 3U);
+  CHECK_EQ(outputs.size(), 2U);
+  if (req[0] == kNullOp) return;
+  // inputs: {ograd, data, idx}
+  // outputs: {igrad_data, igrad_idx}
+  const NDArray& ograd = inputs[0];
+  const NDArray& idx = inputs[2];
+  const NDArray& igrad_data = outputs[0];
+  Stream<gpu>* s = ctx.get_stream<gpu>();
+  // Masked-out rows receive no gradient. The scatter kernel below writes only the selected
+  // elements, so clear igrad first.
+  Fill<false>(s, igrad_data.data(), req[0], 0);
+  const size_t idx_size = idx.shape()[0];
+  if (idx_size == 0) return;  // nothing selected; also avoids dividing by zero below
+  const int32_t* prefix_sum = BooleanMaskPrefixSum(ctx, idx);
+  const size_t input_size = igrad_data.shape().Size();
+  const size_t col_size = input_size / idx_size;
+  // NOTE(review): selected rows are overwritten, so kAddTo is not honoured for them.
+  MSHADOW_TYPE_SWITCH(igrad_data.dtype(), DType, {
+    mxnet_op::Kernel<BooleanMaskBackwardKernel, gpu>::Launch(
+      s, input_size, igrad_data.data().dptr<DType>(), ograd.data().dptr<DType>(),
+      prefix_sum, col_size);
+  });
+}
+
+NNVM_REGISTER_OP(_contrib_boolean_mask)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<FComputeEx>("FComputeEx<gpu>", BooleanMaskForward<gpu>);
+
+NNVM_REGISTER_OP(_backward_contrib_boolean_mask)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<FComputeEx>("FComputeEx<gpu>", BooleanMaskBackward<gpu>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index b0c640b..a9b9cc8 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4880,8 +4880,6 @@ def test_index_copy():
 
 @with_seed()
 def test_boolean_mask():
-    if default_context().device_type != 'cpu':
-        return
     data = mx.nd.array([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
     index = mx.nd.array([0, 1, 0])
     data.attach_grad()