You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2021/12/07 11:39:13 UTC
[GitHub] [incubator-mxnet] bartekkuncer commented on a change in pull request #20745: Optimize 'take' operator for CPU
bartekkuncer commented on a change in pull request #20745:
URL: https://github.com/apache/incubator-mxnet/pull/20745#discussion_r763881978
##########
File path: src/operator/tensor/indexing_op.cc
##########
@@ -60,6 +60,51 @@ struct TakeZeroAxisCPU {
}
};
+template <bool clip = true>
+struct TakeNonzeroAxisCPU {
+ /*!
+ * \brief Map function for take operator
+ * \param i global thread id
+ * \param out_data ptr to output buffer
+ * \param in_data ptr to input buffer
+ * \param idx ptr to indices buffer
+ * \param outer_dim_stride stride of dimension before axis
+ * \param axis_dim_stride stride of axis dimension
+ * \param idx_size size of the indices tensor
+ * \param axis_dim dim size of the axis dimension
+ * \param axis axis id
+ */
+ template <typename DType, typename IType>
+ MSHADOW_XINLINE static void Map(index_t i,
+ DType* out_data,
+ const DType* in_data,
+ const IType* indices,
+ const index_t outer_dim_stride,
+ const index_t axis_dim_stride,
+ const int idx_size,
+ const int axis_dim,
+ const int axis) {
+ for (index_t j = 0; j < static_cast<index_t>(idx_size); ++j) {
+ int index = indices[j];
+ if (clip) {
+ index = (index < 0) ? 0 : index;
Review comment:
Maybe use max(0, index)?
##########
File path: src/operator/tensor/indexing_op.cc
##########
@@ -60,6 +60,51 @@ struct TakeZeroAxisCPU {
}
};
+template <bool clip = true>
+struct TakeNonzeroAxisCPU {
+ /*!
+ * \brief Map function for take operator
+ * \param i global thread id
+ * \param out_data ptr to output buffer
+ * \param in_data ptr to input buffer
+ * \param idx ptr to indices buffer
+ * \param outer_dim_stride stride of dimension before axis
+ * \param axis_dim_stride stride of axis dimension
+ * \param idx_size size of the indices tensor
+ * \param axis_dim dim size of the axis dimension
+ * \param axis axis id
+ */
+ template <typename DType, typename IType>
+ MSHADOW_XINLINE static void Map(index_t i,
+ DType* out_data,
+ const DType* in_data,
+ const IType* indices,
+ const index_t outer_dim_stride,
+ const index_t axis_dim_stride,
+ const int idx_size,
+ const int axis_dim,
+ const int axis) {
+ for (index_t j = 0; j < static_cast<index_t>(idx_size); ++j) {
+ int index = indices[j];
+ if (clip) {
+ index = (index < 0) ? 0 : index;
+ index = (index > axis_dim - 1) ? (axis_dim - 1) : index;
Review comment:
And use min(index, axis_dim - 1) here?
##########
File path: src/operator/tensor/indexing_op.h
##########
@@ -217,6 +217,7 @@ inline bool EmbeddingOpBackwardStorageType(const nnvm::NodeAttrs& attrs,
/*! \brief name the struct TakeNonzeroAxis for general take when
* axis is not zero, use TakeZeroAxisGPU or TakeZeroAxisCPU for axis zero
+ * or TakeNonZeroAxisCPU for CPU optimized version
Review comment:
Please reread the whole comment and amend it so that it makes sense :)
##########
File path: src/operator/tensor/indexing_op.cc
##########
@@ -60,6 +60,51 @@ struct TakeZeroAxisCPU {
}
};
+template <bool clip = true>
+struct TakeNonzeroAxisCPU {
+ /*!
+ * \brief Map function for take operator
+ * \param i global thread id
+ * \param out_data ptr to output buffer
+ * \param in_data ptr to input buffer
+ * \param idx ptr to indices buffer
+ * \param outer_dim_stride stride of dimension before axis
+ * \param axis_dim_stride stride of axis dimension
+ * \param idx_size size of the indices tensor
+ * \param axis_dim dim size of the axis dimension
+ * \param axis axis id
+ */
+ template <typename DType, typename IType>
+ MSHADOW_XINLINE static void Map(index_t i,
+ DType* out_data,
+ const DType* in_data,
+ const IType* indices,
Review comment:
Here the parameter is named `indices`, but the doc comment above lists it as `idx`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: commits-unsubscribe@mxnet.apache.org
For queries about this service, please contact Infrastructure at:
users@infra.apache.org