Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2021/04/22 21:00:11 UTC

[GitHub] [incubator-mxnet] ptrendx commented on a change in pull request #19905: [PERF] Moving GPU softmax to RTC and optimizations

ptrendx commented on a change in pull request #19905:
URL: https://github.com/apache/incubator-mxnet/pull/19905#discussion_r618735053



##########
File path: src/operator/nn/softmax.cu
##########
@@ -22,18 +22,809 @@
  * \file softmax.cu
  * \brief GPU Implementation of softmax
  */
+#include <string>
 #include "./softmax-inl.h"
-#include "../tensor/elemwise_unary_op.h"
+#include "../../common/cuda/utils.h"
+#include "../../common/cuda/rtc.h"
+#include "../../common/cuda/rtc/vectorization-inl.h"
 
 namespace mxnet {
 namespace op {
 
+namespace {
+
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+const char softmax_common_functions[] = R"code(
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+softmax_fwd(const DType a, const DType2 b) {
+  return op::exp(a) / b;
+}
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+log_softmax_fwd(const DType a, const DType2 b) {
+  return a - op::log(b);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+    return out * (ograd - sum);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+log_softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+    return ograd - op::exp(out) * sum;
+}
+
+)code";
+
+const char simple_softmax_kernel_fwd[] = R"code(
+__launch_bounds__(kRTCMaxThreadsPerBlock)
+__global__ void simple_softmax_kernel(const softmax_params param,
+                                      const index_t lead_dim) {
+  using LengthType = AccType<InputType1>;
+  const InputType0* input = reinterpret_cast<const InputType0*>(param.inputs[0]);
+  const InputType1* length = reinterpret_cast<const InputType1*>(param.inputs[1]);
+  const index_t len = length == nullptr
+                      ? lead_dim
+                      : static_cast<index_t>(LengthType::from(length[blockIdx.x]));
+  const int my_row = threadIdx.x % param.rows_per_block;
+  const int my_id = threadIdx.x / param.rows_per_block;
+  const int threads_per_row = blockDim.x / param.rows_per_block;
+  const index_t base_x = (blockIdx.x * param.rows_per_block + my_row) % param.stride;
+  const index_t base_n = (blockIdx.x * param.rows_per_block + my_row) / param.stride;
+  const index_t base = base_x + param.stride * lead_dim * base_n;
+  if (base >= param.num_elements * param.total_rows) return;
+  using IType = AccType<InputType0>;
+  using OType = AccType<OutputType0>;
+  using AType = type_util::mixed_type<typename IType::type,
+                                      typename OType::type>;
+  __shared__ AType smem[kRTCMaxThreadsPerBlock];
+  AType max;
+  red::maximum::SetInitValue(max);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    max = op::max(max, negate ? -val : val);
+  }
+  smem[threadIdx.x] = max;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = op::max(smem[threadIdx.x], smem[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y)
+                                                         { return op::max(x, y); },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  AType smax = smem[my_row];
+  __syncthreads();
+
+  AType sum;
+  red::sum::SetInitValue(sum);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    val = negate ? -val : val;
+    sum += op::exp((val - smax) / static_cast<AType>(param.temperature));
+  }
+  smem[threadIdx.x] = sum;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = op::add(smem[threadIdx.x], smem[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y)
+                                                         { return op::add(x, y); },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  sum = smem[my_row];
+  __syncthreads();
+
+  OutputType0* output = reinterpret_cast<OutputType0*>(param.outputs[0]);
+  for (index_t i = my_id; i < lead_dim; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    val = negate ? -val : val;
+    val = (i < len) ? OP((val - smax)/static_cast<AType>(param.temperature), sum) : 0;
+    if (req == OpReqType::kAddTo) {
+      if (i < len) {
+        output[base + i * param.stride] = OType::to(val +
+                                                    OType::from(output[base + i * param.stride]));
+      }
+    } else {
+      output[base + i * param.stride] = OType::to(val);
+    }
+  }
+}
+)code";
+
+const char softmax_stride1_kernel_fwd[] = R"code(
+__launch_bounds__(vector::vectorized_kernel_thread_num)
+__global__ void softmax_stride1_compute_kernel(const softmax_params param,
+                                               const index_t total_length,
+                                               const index_t other_dim,
+                                               const index_t N,
+                                               const index_t num_aligned_elements) {
+  using namespace vector;
+  using IType = AccType<InputType0>;
+  using OType = AccType<OutputType0>;
+  using LengthType = AccType<InputType1>;
+  const InputType1* length = reinterpret_cast<const InputType1*>(param.inputs[1]);
+  using AType = type_util::mixed_type<typename IType::type,
+                                      typename OType::type>;
+  __shared__ AType scratch[vectorized_kernel_thread_num];
+  __shared__ AType persistent_storage[20 * 1024 / sizeof(AType)];
+  const int warp_size = 32;
+  const int threads_per_row = vectorized_kernel_thread_num / param.rows_per_block;
+  const int my_local_row = threadIdx.x / threads_per_row;
+  const int base_row = blockIdx.x * param.rows_per_block;
+  const int my_row = base_row + my_local_row;
+  const index_t len = (length == nullptr ||
+                       my_row >= param.total_rows) ? param.num_elements
+                                                   : LengthType::from(length[my_row]);
+  const int my_id = threadIdx.x % threads_per_row;
+
+  AType* row;
+  if (only_full_blocks || blockIdx.x < gridDim.x - 1) {
+    // full rows_per_block rows to compute
+    VectorizedLoader<InputType0, nvec, aligned> loader(
+      reinterpret_cast<const InputType0*>(param.inputs[0]) + base_row * param.num_elements,
+      total_length);
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      loader.load(i, total_length);
+#pragma unroll
+      for (int j = 0; j < nvec; ++j) {
+        persistent_storage[i*nvec + j] = IType::from(loader.separate()[j]);
+      }
+    }
+    row = persistent_storage + my_local_row * param.num_elements + loader.alignment();
+  } else {
+    // less than rows_per_block rows to compute
+    const index_t real_length = min(total_length,
+                                    (param.total_rows - base_row) * param.num_elements);
+    VectorizedLoader<InputType0, nvec, false> loader(
+      reinterpret_cast<const InputType0*>(param.inputs[0]) + base_row * param.num_elements,
+      real_length);
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      loader.load(i, real_length);
+#pragma unroll
+      for (int j = 0; j < nvec; ++j) {
+        persistent_storage[i*nvec + j] = IType::from(loader.separate()[j]);
+      }
+    }
+    row = persistent_storage + my_local_row * param.num_elements + loader.alignment();
+  }
+  __syncthreads();
+
+  AType my_max_value;
+  red::maximum::SetInitValue(my_max_value);
+
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    my_max_value = ::max(my_max_value, negate ? -row[i] : row[i]);
+  }
+  AType smax;
+  if (!reduction_inside_warp) {
+    scratch[threadIdx.x] = my_max_value;
+    __syncthreads();
+    for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+      if (my_id < size) {
+        scratch[threadIdx.x] = ::max(scratch[threadIdx.x], scratch[threadIdx.x + size]);
+      }
+      __syncthreads();
+    }
+    if (my_id < warp_size) {
+      AType my_value = util::grouped_warp_allreduce(scratch[threadIdx.x],
+                                                    [](AType x, AType y) { return op::max(x, y); },
+                                                    min(threads_per_row, warp_size));
+      scratch[threadIdx.x] = my_value;
+    }
+    __syncthreads();
+    smax = scratch[threadIdx.x - my_id];
+    __syncthreads();
+  } else {
+    smax = util::grouped_warp_allreduce(my_max_value,
+                                        [](AType x, AType y) { return op::max(x, y); },
+                                        threads_per_row);
+  }
+
+  AType my_sum;
+  red::sum::SetInitValue(my_sum);
+
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    const AType val = negate ? -row[i] : row[i];
+    my_sum += op::exp((val - smax) / static_cast<AType>(param.temperature));
+  }
+  AType ssum;
+  if (!reduction_inside_warp) {
+    scratch[threadIdx.x] = my_sum;
+    __syncthreads();
+    for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+      if (my_id < size) {
+        scratch[threadIdx.x] += scratch[threadIdx.x + size];
+      }
+      __syncthreads();
+    }
+    if (my_id < warp_size) {
+      AType my_value = util::grouped_warp_allreduce(scratch[threadIdx.x],
+                                                    [](AType x, AType y) { return x + y; },
+                                                    min(threads_per_row, warp_size));
+      scratch[threadIdx.x] = my_value;
+    }
+    __syncthreads();
+
+    ssum = scratch[threadIdx.x - my_id];
+    __syncthreads();
+  } else {
+    ssum = util::grouped_warp_allreduce(my_sum,
+                                        [](AType x, AType y) { return x + y; },
+                                        threads_per_row);
+  }
+
+  for (index_t i = my_id; i < param.num_elements; i += threads_per_row) {
+    const AType val = negate ? -row[i] : row[i];
+    row[i] = (i < len) ? OP((val - smax)/static_cast<AType>(param.temperature), ssum) :
+                         0;
+  }
+  __syncthreads();
+
+  if (only_full_blocks || blockIdx.x < gridDim.x - 1) {
+    VectorizedStorer<OutputType0, nvec, aligned> storer(
+      reinterpret_cast<OutputType0*>(param.outputs[0]) + base_row * param.num_elements,
+      total_length);
+
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      if (req == OpReqType::kAddTo) {
+        storer.load(i, total_length);
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(op::add(persistent_storage[i*nvec + j],
+                                                   OType::from(storer.separate()[j])));
+        }
+      } else {
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(persistent_storage[i*nvec + j]);
+        }
+      }
+      storer.store(i, total_length);
+    }
+  } else {
+    const index_t real_length = min(total_length,
+                                    (param.total_rows - base_row) * param.num_elements);
+    VectorizedStorer<OutputType0, nvec, false> storer(
+      reinterpret_cast<OutputType0*>(param.outputs[0]) + base_row * param.num_elements,
+      real_length);
+
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      if (req == OpReqType::kAddTo) {
+        storer.load(i, real_length);
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(op::add(persistent_storage[i*nvec + j],
+                                                   OType::from(storer.separate()[j])));
+        }
+      } else {
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(persistent_storage[i*nvec + j]);
+        }
+      }
+      storer.store(i, real_length);
+    }
+  }
+}
+)code";
+
+bool IsPower2(size_t N) {

Review comment:
       Moved to common/utils.h

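For context, both forward kernels in this hunk compute the standard numerically stable, temperature-scaled softmax: one pass reduces each row to its maximum smax, a second accumulates sum = exp((x_i - smax) / T), and a final pass applies OP (softmax_fwd or log_softmax_fwd) elementwise. A minimal host-side C++ reference of the same computation, for illustration only (softmax_reference is not part of the PR):

#include <algorithm>
#include <cmath>
#include <vector>

// softmax(x_i) = exp((x_i - smax) / T) / sum_j exp((x_j - smax) / T).
// Subtracting the row maximum before exp() prevents overflow; this mirrors
// the smax and sum passes of simple_softmax_kernel above.
std::vector<double> softmax_reference(const std::vector<double>& x, double T) {
  const double smax = *std::max_element(x.begin(), x.end());  // assumes x is non-empty
  double sum = 0.0;
  for (double v : x) {
    sum += std::exp((v - smax) / T);
  }
  std::vector<double> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    out[i] = std::exp((x[i] - smax) / T) / sum;  // softmax_fwd(a, b) = exp(a) / b
  }
  return out;
}
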
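Both kernels also finish their reductions with warp-level primitives (util::strided_grouped_warp_reduce and util::grouped_warp_allreduce) once the partial results fit inside a single warp. Those helpers are mxnet-internal; the sketch below shows only the underlying CUDA idiom, a butterfly allreduce over __shfl_xor_sync assuming group_size is a power of two, not the actual implementation:

#include <cuda_runtime.h>

// Every thread in a group of group_size consecutive lanes ends up holding
// the reduction of the whole group's values, using register shuffles only.
template <typename T, typename OP>
__device__ T grouped_warp_allreduce_sketch(T value, OP op, int group_size) {
  for (int offset = group_size / 2; offset > 0; offset /= 2) {
    value = op(value, __shfl_xor_sync(0xffffffff, value, offset));
  }
  return value;
}
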
##########
File path: src/operator/nn/softmax.cu
##########
@@ -22,18 +22,809 @@
+bool IsPower2(size_t N) {
+  return ((N & (N - 1)) == 0) && N != 0;
+}
+
+index_t RoundToPower2(index_t N) {

Review comment:
       Moved to common/utils.h
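
Since only the signature of RoundToPower2 is quoted above, here is a sketch of how the pair plausibly reads after the move. IsPower2's body is taken from the diff; RoundToPower2's body is an assumption (the usual round-up-to-the-next-power-of-two loop), not necessarily what landed in common/utils.h:

#include <cstddef>
#include <cstdint>

using index_t = int64_t;  // assumption: stands in for mxnet's index_t

// A power of two has exactly one set bit, so N & (N - 1) clears it to zero.
bool IsPower2(size_t N) {
  return ((N & (N - 1)) == 0) && N != 0;
}

// Assumed body: smallest power of two greater than or equal to N.
index_t RoundToPower2(index_t N) {
  index_t result = 1;
  while (result < N) {
    result *= 2;
  }
  return result;
}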




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org