You are viewing a plain text version of this content. The canonical link for it (a hyperlink in the original page) was not preserved in this plain-text copy.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/08/13 04:22:32 UTC

[GitHub] szha closed pull request #12085: Accelerate the performance of topk for CPU side

szha closed pull request #12085: Accelerate the performance of topk for CPU side
URL: https://github.com/apache/incubator-mxnet/pull/12085
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would not otherwise show it after the merge:

diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h
index 105ee8b90db..16e6c0ecd3f 100644
--- a/src/operator/tensor/ordering_op-inl.h
+++ b/src/operator/tensor/ordering_op-inl.h
@@ -170,11 +170,13 @@ MSHADOW_FORCE_INLINE void TopKSort<cpu>(const Tensor<cpu, 1, real_t>& dat,
   // Use full sort when K is relatively large.
   const bool full_sort(K*8 > N);
   // Batch size.
-  const int M(dat.size(0)/N);
+  const int M(work.size(0)/(sizeof(real_t)*N));
   const int omp_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount());
   #pragma omp parallel for num_threads(omp_threads)
   for (int i = 0; i < M; ++i) {
-    real_t *vals = dat.dptr_;
+    // Tensor `work` stores the flattened source data, while `dat` stores the sorted result.
+    real_t *vals = reinterpret_cast<real_t*>(work.dptr_);
+    real_t *sorted_vals = dat.dptr_+i*N;
     int *indices = ind.dptr_+i*N;
     if (is_ascend) {
       if (full_sort) {
@@ -193,11 +195,9 @@ MSHADOW_FORCE_INLINE void TopKSort<cpu>(const Tensor<cpu, 1, real_t>& dat,
                           [&](const int& i1, const int& i2){ return vals[i1] > vals[i2]; });
       }
     }
-    real_t *buff = reinterpret_cast<real_t*>(work.dptr_)+i*K;
     for (int j = 0; j < K; ++j) {
-      buff[j] = vals[indices[j]];
+      sorted_vals[j] = vals[indices[j]];
     }
-    std::copy(buff, buff+K, &vals[i*N]);
   }
 }
 
@@ -380,16 +380,7 @@ void TopKImpl(RunContext ctx,
   indices = Tensor<xpu, 1, int>(reinterpret_cast<int*>(workspace_curr_ptr),
                                 Shape1(src.Size()), s);  // indices in the original matrix
   workspace_curr_ptr += sizeof(int) * src.Size();
-  if (do_transpose) {
-    sorted_dat = reshape(transpose(dat, Shape3(0, 2, 1)), Shape1(src.Size()));
-  } else {
-    sorted_dat = reshape(dat, Shape1(src.Size()));
-  }
-  mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size * element_num, 1, 0, 1,
-    kWriteTo, indices.dptr_);
 
-  CHECK_EQ(sorted_dat.CheckContiguous(), true);
-  CHECK_EQ(indices.CheckContiguous(), true);
   if (param.ret_typ == topk_enum::kReturnMask) {
     sel_indices = Tensor<xpu, 1, int>(reinterpret_cast<int*>(workspace_curr_ptr),
                                       Shape1(batch_size * k), s);
@@ -401,15 +392,47 @@ void TopKImpl(RunContext ctx,
     CHECK_EQ(sel_indices.CheckContiguous(), true);
     CHECK_EQ(mask_val.CheckContiguous(), true);
   }
-  temp_workspace = Tensor<xpu, 1, char>(workspace_curr_ptr, Shape1(temp_size), s);  // temp space
-  workspace_curr_ptr += temp_size;
+
+  if (std::is_same<xpu, cpu>::value) {
+    Tensor<xpu, 1, real_t> flattened_data;
+    if (do_transpose) {
+      flattened_data = Tensor<xpu, 1, real_t>(reinterpret_cast<real_t*>(workspace_curr_ptr),
+                                              Shape1(src.Size()), s);
+      workspace_curr_ptr += sizeof(real_t) * src.Size();
+      flattened_data = reshape(transpose(dat, Shape3(0, 2, 1)), Shape1(src.Size()));
+      CHECK_EQ(flattened_data.CheckContiguous(), true);
+    } else {
+      flattened_data = src.FlatTo1D<xpu, real_t>(s);
+    }
+    // `temp_workspace` stores the flattened data
+    temp_workspace = Tensor<xpu, 1, char>(reinterpret_cast<char*>(flattened_data.dptr_),
+                                          Shape1(sizeof(real_t)*src.Size()), s);
+    CHECK_EQ(temp_workspace.CheckContiguous(), true);
+  } else {
+    if (do_transpose) {
+      sorted_dat = reshape(transpose(dat, Shape3(0, 2, 1)), Shape1(src.Size()));
+    } else {
+      sorted_dat = reshape(dat, Shape1(src.Size()));
+    }
+    CHECK_EQ(sorted_dat.CheckContiguous(), true);
+    temp_workspace = Tensor<xpu, 1, char>(workspace_curr_ptr, Shape1(temp_size), s);  // temp space
+    workspace_curr_ptr += temp_size;
+  }
+
+  mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size * element_num, 1, 0, 1,
+    kWriteTo, indices.dptr_);
+  CHECK_EQ(indices.CheckContiguous(), true);
 
   // 2. Perform inplace batch sort.
   // After sorting, each batch in `sorted_dat` will be sorted in the corresponding order
   // up to the k-th element and the `indices` will contain the corresponding index in `sorted_dat`
+  // `temp_workspace` is used to store the flattend source data for CPU device, and it's used as
+  // a temporal buffer for GPU device.
   TopKSort(sorted_dat, indices, temp_workspace, k, element_num, is_ascend, s);
 
   // 3. Assign results to the ret blob
+  // When returning indices, only update(modulo) required elements instead of full elements
+  // to avoid redundant calculation.
   if (param.ret_typ == topk_enum::kReturnMask) {
     Tensor<xpu, 2, real_t> ret_mask =
       ret[0].get_with_shape<xpu, 2, real_t>(Shape2(ret[0].Size(), 1), s);
@@ -427,7 +450,6 @@ void TopKImpl(RunContext ctx,
     }
     IndexFill(ret_mask, sel_indices, mask_val);
   } else if (param.ret_typ == topk_enum::kReturnIndices) {
-    indices = F<mshadow_op::mod>(indices, element_num);
     if (do_transpose) {
       Tensor<xpu, 3, real_t> ret_indices = ret[0].FlatTo3D<xpu, real_t>(axis, axis, s);
       ret_indices = tcast<real_t>(transpose(
@@ -437,14 +459,15 @@ void TopKImpl(RunContext ctx,
                                                       element_num)),
                                0, k),
                       Shape3(0, 2, 1)));
+      ret_indices = F<mshadow_op::mod>(ret_indices, element_num);
     } else {
       Tensor<xpu, 2, real_t> ret_indices =
         ret[0].get_with_shape<xpu, 2, real_t>(Shape2(batch_size, k), s);
       ret_indices = tcast<real_t>(slice<1>(
                       inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k));
+      ret_indices = F<mshadow_op::mod>(ret_indices, element_num);
     }
   } else {
-    indices = F<mshadow_op::mod>(indices, element_num);
     if (do_transpose) {
       Tensor<xpu, 3, real_t> ret_value = ret[0].FlatTo3D<xpu, real_t>(axis, axis, s);
       Tensor<xpu, 3, real_t> ret_indices = ret[1].FlatTo3D<xpu, real_t>(axis, axis, s);
@@ -460,6 +483,7 @@ void TopKImpl(RunContext ctx,
                                                       element_num)),
                                0, k),
                       Shape3(0, 2, 1)));
+      ret_indices = F<mshadow_op::mod>(ret_indices, element_num);
     } else {
       Tensor<xpu, 2, real_t> ret_value =
         ret[0].get_with_shape<xpu, 2, real_t>(Shape2(batch_size, k), s);
@@ -468,6 +492,7 @@ void TopKImpl(RunContext ctx,
       ret_value = slice<1>(inplace_reshape(sorted_dat, Shape2(batch_size, element_num)), 0, k);
       ret_indices = tcast<real_t>(slice<1>(
                       inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k));
+      ret_indices = F<mshadow_op::mod>(ret_indices, element_num);
     }
   }
 }


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services