You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2021/08/11 09:05:23 UTC

[GitHub] [incubator-mxnet] shuo-ouyang commented on a change in pull request #17002: Multi_sum_sq review, AtomicAdd removal

shuo-ouyang commented on a change in pull request #17002:
URL: https://github.com/apache/incubator-mxnet/pull/17002#discussion_r686642884



##########
File path: src/operator/contrib/multi_sum_sq.cu
##########
@@ -43,96 +43,121 @@ struct MultiSumSqKernelParam {
   int sizes[ARRAY_LIMIT];
   unsigned char block_to_tensor[BLOCK_LIMIT];
   int block_to_chunk[BLOCK_LIMIT];
+  int max_chunks_per_tensor = -1;
 };
 
 template<typename DType>
-__device__ __forceinline__ DType reduce_block_into_lanes(DType* x,
-                                                         DType val,
-                                                         int lanes = 1,
-                                                         bool share_result = false) {
-  int tid = threadIdx.x + threadIdx.y * blockDim.x;
-  int blockSize = blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.
-
-  if (blockSize >= 64) {
+__device__ __forceinline__ DType ReduceBlockIntoLanes(DType* x,
+                                                      DType val) {
+  int tid = threadIdx.x;
+  int block_size = blockDim.x;
+
+  if (block_size >= 64) {
     x[tid] = val;
     __syncthreads();
   }
 
   #pragma unroll
-  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
+  for (int i = (block_size >> 1); i >= 64; i >>= 1) {
     if (tid < i)
       x[tid] = x[tid] + x[tid+i];
     __syncthreads();
   }
 
   DType final;
-
   if (tid < 32) {
-    if (blockSize >= 64)
+    if (block_size >= 64)
       final = x[tid] + x[tid+32];
     else
       final = val;
-    // __SYNCWARP();
 
     #pragma unroll
-    for (int i = 16; i >= lanes; i >>= 1)
+    for (int i = 16; i >= 1; i >>= 1)
       final = final + __shfl_down_sync(0xffffffff, final, i);
   }
-
-  if (share_result) {
-    if (tid < lanes)
-      x[tid] = final;  // EpilogueOp
-    // Make sure the smem result is visible to all warps.
-    __syncthreads();
-  }
-
   return final;
 }
 
 template<typename DType>
 __global__ void MultiSumSqKernel(int chunk_size,
                                  MultiSumSqKernelParam<DType> param,
-                                 float* output) {
+                                 float* block_reductions,
+                                 int start_tensor_id) {
   const int tensor_loc = param.block_to_tensor[blockIdx.x];
   const int chunk_len = param.block_to_chunk[blockIdx.x] * chunk_size;
   const int n = param.sizes[tensor_loc] - chunk_len;
   const DType* x = param.addresses[tensor_loc] + chunk_len;
-  const auto iMax = n <= chunk_size? n : chunk_size;
+  const auto i_max = n <= chunk_size ? n : chunk_size;
   __shared__ float vals[512];
 
   // Non-divergent exit condition for __syncthreads, not necessary here
   float val = 0;
   for (int i_start = 0;
-       i_start < iMax;
+       i_start < i_max;
        i_start += blockDim.x * ILP) {
     int i = i_start + threadIdx.x;
-    //    #pragma unroll
-    for (int ii = 0; ii < ILP && i < iMax; ++ii, i += blockDim.x) {
+#pragma unroll
+    for (int ii = 0; ii < ILP && i < i_max; ++ii, i += blockDim.x) {
       const auto incoming_val = static_cast<float>(x[i]);
       val += incoming_val * incoming_val;
     }
   }
+  const float final = ReduceBlockIntoLanes(vals, val);
+
+  if (threadIdx.x == 0) {
+    block_reductions[(start_tensor_id + tensor_loc) * param.max_chunks_per_tensor +
+                    param.block_to_chunk[blockIdx.x]] = final;

Review comment:
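       Maybe we should rename this variable? `final` is a C++ identifier with special meaning: used as a specifier, it declares that a virtual function cannot be overridden in a derived class (or that a class cannot be derived from). It is still legal as a variable name, but using it here is confusing.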




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@mxnet.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org