You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/01/04 09:03:17 UTC
[incubator-mxnet] 01/06: Improve performance of take operator
This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch take_opt
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit ad38228e1b334bf657f4133b903febc42f0d74d2
Author: Bartlomiej Gawrych <ba...@intel.com>
AuthorDate: Mon Nov 15 09:07:35 2021 +0100
Improve performance of take operator
---
src/operator/tensor/indexing_op.cc | 94 ++++++++++++++++++++++++++------------
1 file changed, 64 insertions(+), 30 deletions(-)
diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc
index 3082541..4602640 100644
--- a/src/operator/tensor/indexing_op.cc
+++ b/src/operator/tensor/indexing_op.cc
@@ -60,6 +60,46 @@ struct TakeZeroAxisCPU {
}
};
+template <bool clip = true>
+struct TakeNonzeroAxisCPU {
+ /*!
+ * \brief Map function for take operator: for work item i, copies one run of axis_dim_stride elements per entry of indices
+ * \param i work-item id from the launcher; scales outer_dim_stride on the input side and idx_size * axis_dim_stride on the output side
+ * \param out_data ptr to output buffer (destination of each copy)
+ * \param in_data ptr to input buffer (source of each copy)
+ * \param indices ptr to indices buffer; each entry is converted to int before use
+ * \param outer_dim_stride stride of dimension before axis (multiplied by i to locate the input slice)
+ * \param axis_dim_stride stride of axis dimension; also the number of DType elements copied per index
+ * \param idx_size size of the indices tensor (number of entries read from indices)
+ * \param axis_dim dim size of the axis dimension; bound used when clipping (clip=true) or wrapping (clip=false)
+ * \param axis axis id (not referenced in the body)
+ */
+ template <typename DType, typename IType>
+ MSHADOW_XINLINE static void Map(index_t i,
+ DType* out_data,
+ const DType* in_data,
+ const IType* indices,
+ const index_t outer_dim_stride,
+ const index_t axis_dim_stride,
+ const int idx_size,
+ const int axis_dim,
+ const int axis) {
+ for (index_t j = 0; j < static_cast<index_t>(idx_size); ++j) { // one block copy per entry of indices
+ int index = indices[j]; // NOTE(review): narrows IType to int - confirm index values always fit
+ if (clip) {
+ index = (index < 0) ? 0 : index; // clamp from below to 0
+ index = (index > axis_dim - 1) ? (axis_dim - 1) : index; // clamp from above to the last valid position
+ } else {
+ index %= axis_dim; // NOTE(review): presumes axis_dim != 0 - verify against caller
+ index += (index < 0) ? axis_dim : 0; // % may leave a negative remainder; shift it into [0, axis_dim)
+ }
+ size_t in_offset = i * outer_dim_stride + index * axis_dim_stride; // start of the selected run in in_data
+ size_t out_offset = (i * idx_size + j) * axis_dim_stride; // output runs are ordered by i, then by j
+ memcpy(out_data + out_offset, in_data + in_offset, axis_dim_stride * sizeof(DType)); // byte copy of axis_dim_stride elements
+ }
+ }
+};
+
/*
* \brief returns true if all indices are between [min, max]
* \param data_ptr the indices to check
@@ -323,6 +363,7 @@ void TakeOpForward<cpu>(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
+
if (req[take_::kOut] == kNullOp)
return;
const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed);
@@ -375,39 +416,32 @@ void TakeOpForward<cpu>(const nnvm::NodeAttrs& attrs,
for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) {
in_strides[i] = stride;
}
- mshadow::Shape<10> out_strides;
- stride = 1;
- for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) {
- out_strides[i] = stride;
+ int outer_dimensions = 1;
+ for (int i = 0; i < actual_axis; i++) {
+ outer_dimensions *= oshape[i];
}
if (param.mode == take_::kClip) {
- Kernel<TakeNonzeroAxis<true>, cpu>::Launch(s,
- oshape.Size(),
- outputs[take_::kOut].dptr<DType>(),
- inputs[take_::kArr].dptr<DType>(),
- inputs[take_::kIdx].dptr<IType>(),
- out_strides[actual_axis - 1],
- in_strides[actual_axis - 1],
- in_strides[actual_axis],
- arrshape.ndim(),
- oshape.ndim(),
- idxshape.ndim(),
- arrshape[actual_axis],
- actual_axis);
+ Kernel<TakeNonzeroAxisCPU<true>, cpu>::Launch(s,
+ outer_dimensions,
+ outputs[take_::kOut].dptr<DType>(),
+ inputs[take_::kArr].dptr<DType>(),
+ inputs[take_::kIdx].dptr<IType>(),
+ in_strides[actual_axis - 1],
+ in_strides[actual_axis],
+ idxshape.Size(),
+ arrshape[actual_axis],
+ actual_axis);
} else {
- Kernel<TakeNonzeroAxis<false>, cpu>::Launch(s,
- oshape.Size(),
- outputs[take_::kOut].dptr<DType>(),
- inputs[take_::kArr].dptr<DType>(),
- inputs[take_::kIdx].dptr<IType>(),
- out_strides[actual_axis - 1],
- in_strides[actual_axis - 1],
- in_strides[actual_axis],
- arrshape.ndim(),
- oshape.ndim(),
- idxshape.ndim(),
- arrshape[actual_axis],
- actual_axis);
+ Kernel<TakeNonzeroAxisCPU<false>, cpu>::Launch(s,
+ outer_dimensions,
+ outputs[take_::kOut].dptr<DType>(),
+ inputs[take_::kArr].dptr<DType>(),
+ inputs[take_::kIdx].dptr<IType>(),
+ in_strides[actual_axis - 1],
+ in_strides[actual_axis],
+ idxshape.Size(),
+ arrshape[actual_axis],
+ actual_axis);
}
}
});