You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/01/04 09:03:17 UTC

[incubator-mxnet] 01/06: Improve performance of take operator

This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch take_opt
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit ad38228e1b334bf657f4133b903febc42f0d74d2
Author: Bartlomiej Gawrych <ba...@intel.com>
AuthorDate: Mon Nov 15 09:07:35 2021 +0100

    Improve performance of take operator
---
 src/operator/tensor/indexing_op.cc | 94 ++++++++++++++++++++++++++------------
 1 file changed, 64 insertions(+), 30 deletions(-)

diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc
index 3082541..4602640 100644
--- a/src/operator/tensor/indexing_op.cc
+++ b/src/operator/tensor/indexing_op.cc
@@ -60,6 +60,46 @@ struct TakeZeroAxisCPU {
   }
 };
 
+template <bool clip = true>
+struct TakeNonzeroAxisCPU {
+  /*!
+   * \brief For outer index i, memcpy every slice selected by indices along axis
+   * \param i                 flattened index over the dimensions before axis
+   * \param out_data          ptr to output buffer
+   * \param in_data           ptr to input buffer
+   * \param indices           ptr to indices buffer (cast to index_t, not int)
+   * \param outer_dim_stride  input stride of the dimension just before axis
+   * \param axis_dim_stride   stride of axis dimension (elements per copied slice)
+   * \param idx_size          number of elements in the indices tensor
+   * \param axis_dim          dim size of the axis dimension
+   * \param axis              axis id (unused here; kept for signature parity)
+   */
+  template <typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(index_t i,
+                                  DType* out_data,
+                                  const DType* in_data,
+                                  const IType* indices,
+                                  const index_t outer_dim_stride,
+                                  const index_t axis_dim_stride,
+                                  const index_t idx_size,
+                                  const index_t axis_dim,
+                                  const int axis) {
+    for (index_t j = 0; j < idx_size; ++j) {
+      index_t index = static_cast<index_t>(indices[j]);  // avoid int truncation
+      if (clip) {
+        index = (index < 0) ? 0 : index;
+        index = (index > axis_dim - 1) ? (axis_dim - 1) : index;
+      } else {
+        index %= axis_dim;
+        index += (index < 0) ? axis_dim : 0;
+      }
+      const index_t in_offset  = i * outer_dim_stride + index * axis_dim_stride;
+      const index_t out_offset = (i * idx_size + j) * axis_dim_stride;
+      memcpy(out_data + out_offset, in_data + in_offset, axis_dim_stride * sizeof(DType));
+    }
+  }
+};
+
 /*
  * \brief returns true if all indices are between [min, max]
  * \param data_ptr the indices to check
@@ -323,6 +363,7 @@ void TakeOpForward<cpu>(const nnvm::NodeAttrs& attrs,
                         const std::vector<OpReqType>& req,
                         const std::vector<TBlob>& outputs) {
   using namespace mxnet_op;
+
   if (req[take_::kOut] == kNullOp)
     return;
   const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed);
@@ -375,39 +416,32 @@ void TakeOpForward<cpu>(const nnvm::NodeAttrs& attrs,
         for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) {
           in_strides[i] = stride;
         }
-        mshadow::Shape<10> out_strides;
-        stride = 1;
-        for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) {
-          out_strides[i] = stride;
+        int outer_dimensions = 1;
+        for (int i = 0; i < actual_axis; i++) {
+          outer_dimensions *= oshape[i];
         }
         if (param.mode == take_::kClip) {
-          Kernel<TakeNonzeroAxis<true>, cpu>::Launch(s,
-                                                     oshape.Size(),
-                                                     outputs[take_::kOut].dptr<DType>(),
-                                                     inputs[take_::kArr].dptr<DType>(),
-                                                     inputs[take_::kIdx].dptr<IType>(),
-                                                     out_strides[actual_axis - 1],
-                                                     in_strides[actual_axis - 1],
-                                                     in_strides[actual_axis],
-                                                     arrshape.ndim(),
-                                                     oshape.ndim(),
-                                                     idxshape.ndim(),
-                                                     arrshape[actual_axis],
-                                                     actual_axis);
+          Kernel<TakeNonzeroAxisCPU<true>, cpu>::Launch(s,
+                                                        outer_dimensions,
+                                                        outputs[take_::kOut].dptr<DType>(),
+                                                        inputs[take_::kArr].dptr<DType>(),
+                                                        inputs[take_::kIdx].dptr<IType>(),
+                                                        in_strides[actual_axis - 1],
+                                                        in_strides[actual_axis],
+                                                        idxshape.Size(),
+                                                        arrshape[actual_axis],
+                                                        actual_axis);
         } else {
-          Kernel<TakeNonzeroAxis<false>, cpu>::Launch(s,
-                                                      oshape.Size(),
-                                                      outputs[take_::kOut].dptr<DType>(),
-                                                      inputs[take_::kArr].dptr<DType>(),
-                                                      inputs[take_::kIdx].dptr<IType>(),
-                                                      out_strides[actual_axis - 1],
-                                                      in_strides[actual_axis - 1],
-                                                      in_strides[actual_axis],
-                                                      arrshape.ndim(),
-                                                      oshape.ndim(),
-                                                      idxshape.ndim(),
-                                                      arrshape[actual_axis],
-                                                      actual_axis);
+          Kernel<TakeNonzeroAxisCPU<false>, cpu>::Launch(s,
+                                                         outer_dimensions,
+                                                         outputs[take_::kOut].dptr<DType>(),
+                                                         inputs[take_::kArr].dptr<DType>(),
+                                                         inputs[take_::kIdx].dptr<IType>(),
+                                                         in_strides[actual_axis - 1],
+                                                         in_strides[actual_axis],
+                                                         idxshape.Size(),
+                                                         arrshape[actual_axis],
+                                                         actual_axis);
         }
       }
     });