Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2021/02/16 23:57:36 UTC

[GitHub] [incubator-mxnet] ptrendx opened a new pull request #19905: [PERF] Moving GPU softmax to RTC and optimizations

ptrendx opened a new pull request #19905:
URL: https://github.com/apache/incubator-mxnet/pull/19905


   ## Description ##
   This PR moves the GPU softmax implementation (not yet the masked softmax implementation) to use RTC and adds multiple optimizations to it to improve performance.
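   For reference, the per-row quantity that both kernels compute can be written as a short CPU function. This is a minimal sketch under some assumptions: a float row, a positive `temperature`, and an optional valid length `len`, with entries past it zeroed as in the kernels' `(i < len) ? ... : 0` branch. `softmax_row` is a hypothetical name, and the `negate` and `kAddTo` paths are left out.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical CPU reference for one softmax row with temperature T and an
// optional valid length len:
//   out[i] = exp((x[i] - m) / T) / sum_{j < len} exp((x[j] - m) / T)  for i < len
//   out[i] = 0                                                        for i >= len
// where m is the maximum of x[0..len).
std::vector<float> softmax_row(const std::vector<float>& x, std::size_t len,
                               double temperature) {
  assert(temperature > 0.0);
  len = std::min(len, x.size());
  std::vector<float> out(x.size(), 0.0f);  // entries past len stay 0
  if (len == 0) return out;

  // Pass 1: row maximum (subtracting it keeps exp() from overflowing).
  float row_max = -std::numeric_limits<float>::infinity();
  for (std::size_t i = 0; i < len; ++i) row_max = std::max(row_max, x[i]);

  // Pass 2: sum of the shifted, temperature-scaled exponentials.
  double sum = 0.0;
  for (std::size_t i = 0; i < len; ++i) sum += std::exp((x[i] - row_max) / temperature);

  // Pass 3: normalize each valid element by the row sum.
  for (std::size_t i = 0; i < len; ++i) {
    out[i] = static_cast<float>(std::exp((x[i] - row_max) / temperature) / sum);
  }
  return out;
}
```

   The same three phases (max, sum, normalize) map onto the two block-wide reductions and the final write loop in each kernel.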
   
   ## Checklist ##
   ### Essentials ###
   - [x] PR's title starts with a category (e.g. [BUGFIX], [MODEL], [TUTORIAL], [FEATURE], [DOC], etc)
   - [x] Changes are complete (i.e. I finished coding on this PR)
   - [x] All changes have test coverage
   - [x] Code is well-documented
   
   ### Changes ###
   - [x] Moved both stride1 and non-stride1 versions of the softmax kernels to use RTC
   - [x] The performance of the non-stride1 version was improved by processing multiple rows per block and coalescing memory accesses. Benchmarks show a ~4x speedup in the typical case, and much more (up to ~40x) when the row over which the summation happens is very short.
   - [x] The performance of the stride1 kernel was improved by having the entire block collectively load multiple rows into shared memory and by increasing the amount of work per thread (including the ability for an entire row to be summed by a single thread, down from the minimum of 1 full warp per block in the previous version).
   - [x] The vectorization requirements of the previous implementation were eliminated, resulting in an especially large speedup for cases where the row length is odd.
   - [x] The stride1 kernel can now be used when the output type does not match the input type (e.g. float16 input, float32 output).
   - [x] Overall, the performance of the stride1 kernel improved by factors ranging from 1.1x for BERT-like shapes (12 * 32, 128, 128), through ~2x for typical sizes with even row length and ~4x for typical sizes with odd row length, to >20x for sizes with very small row length.
   - [x] The performance improvements quoted above are for the forward pass; the backward pass shows similar (albeit slightly smaller) improvements.
   - [x] Improved the mixed_type utility for RTC kernels: one can now write `type_util::mixed_type<DType, DType2>` instead of the previous, more verbose `typename type_util::mixed_type<DType, DType2>::type`, and an arbitrary number of types can be passed as template arguments.
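   A variadic `mixed_type` of this kind can be sketched as a recursive fold over a pairwise rule plus an alias template; the alias is what removes the `typename ...::type` boilerplate at call sites. This is a sketch under stated assumptions: the namespace `type_util_sketch`, the helpers `mixed_type_pair` and `mixed_type_impl`, and the use of `std::common_type` as the pairwise rule are hypothetical stand-ins, not MXNet's RTC code, which presumably also has to handle half precision.

```cpp
#include <type_traits>

namespace type_util_sketch {  // hypothetical; not MXNet's type_util

// Pairwise rule (assumption): std::common_type stands in for MXNet's own
// "wider type wins" rule, which would also need to cover float16.
template <typename T, typename U>
struct mixed_type_pair {
  using type = std::common_type_t<T, U>;
};

// Recursive fold: mixed_type_impl<A, B, C> = mixed_type_impl<pair(A, B), C>.
template <typename... Ts>
struct mixed_type_impl;

template <typename T>
struct mixed_type_impl<T> {
  using type = T;
};

template <typename T, typename U, typename... Rest>
struct mixed_type_impl<T, U, Rest...>
    : mixed_type_impl<typename mixed_type_pair<T, U>::type, Rest...> {};

// The alias template lets call sites drop `typename ...::type`.
template <typename... Ts>
using mixed_type = typename mixed_type_impl<Ts...>::type;

}  // namespace type_util_sketch

// Before: typename type_util_sketch::mixed_type_impl<float, double>::type
// After:  type_util_sketch::mixed_type<float, double>
```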
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-mxnet] MoisesHer commented on a change in pull request #19905: [PERF] Moving GPU softmax to RTC and optimizations

Posted by GitBox <gi...@apache.org>.
MoisesHer commented on a change in pull request #19905:
URL: https://github.com/apache/incubator-mxnet/pull/19905#discussion_r585267506



##########
File path: src/operator/nn/softmax.cu
##########
@@ -22,18 +22,809 @@
  * \file softmax.cu
  * \brief GPU Implementation of softmax
  */
+#include <string>
 #include "./softmax-inl.h"
-#include "../tensor/elemwise_unary_op.h"
+#include "../../common/cuda/utils.h"
+#include "../../common/cuda/rtc.h"
+#include "../../common/cuda/rtc/vectorization-inl.h"
 
 namespace mxnet {
 namespace op {
 
+namespace {
+
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+const char softmax_common_functions[] = R"code(
+struct softmax_params {

Review comment:
       Is there a way to avoid duplicated code for this structure?
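       One common way to keep a single definition of a struct that is needed both on the host and inside an RTC source string is to put the definition in a macro and stringify it. The sketch below assumes that approach rather than the duplicated definition in the diff; `SOFTMAX_PARAMS_DEFINITION`, `SKETCH_STRINGIFY`, `softmax_params_source`, `with_params`, and the host-side `index_t` alias are all hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

using index_t = std::int64_t;  // host-side stand-in (assumption)

// Hypothetical single source of truth for the parameter struct. No commas may
// appear at the top level of the body, or the stringify macros below would
// see more than one argument.
#define SOFTMAX_PARAMS_DEFINITION \
  struct softmax_params {         \
    const void* inputs[3];        \
    void* outputs[1];             \
    index_t stride;               \
    index_t num_elements;         \
    double temperature;           \
    int rows_per_block;           \
    index_t total_rows;           \
  };

// Two-level stringification: the outer macro expands its argument first.
#define SKETCH_STRINGIFY_IMPL(x) #x
#define SKETCH_STRINGIFY(x) SKETCH_STRINGIFY_IMPL(x)

// Host-side definition, used to fill in the launch parameters.
SOFTMAX_PARAMS_DEFINITION

// The same definition as text, to be compiled by the RTC.
const char softmax_params_source[] = SKETCH_STRINGIFY(SOFTMAX_PARAMS_DEFINITION);

// Hypothetical helper: prepend the shared struct to a kernel's source string.
std::string with_params(const char* kernel_source) {
  return std::string(softmax_params_source) + "\n" + kernel_source;
}
```

       The two-level `SKETCH_STRINGIFY` is needed because `#` applied directly to a macro parameter would stringify the macro's name rather than its expansion. The host and device `index_t` must still have the same size for the two layouts to match.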

##########
File path: src/operator/nn/softmax.cu
##########
@@ -22,18 +22,809 @@
  * \file softmax.cu
  * \brief GPU Implementation of softmax
  */
+#include <string>
 #include "./softmax-inl.h"
-#include "../tensor/elemwise_unary_op.h"
+#include "../../common/cuda/utils.h"
+#include "../../common/cuda/rtc.h"
+#include "../../common/cuda/rtc/vectorization-inl.h"
 
 namespace mxnet {
 namespace op {
 
+namespace {
+
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+const char softmax_common_functions[] = R"code(
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+softmax_fwd(const DType a, const DType2 b) {
+  return op::exp(a) / b;
+}
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+log_softmax_fwd(const DType a, const DType2 b) {
+  return a - op::log(b);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+    return out * (ograd - sum);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+log_softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+    return ograd - op::exp(out) * sum;
+}
+
+)code";
+
+const char simple_softmax_kernel_fwd[] = R"code(
+__launch_bounds__(kRTCMaxThreadsPerBlock)
+__global__ void simple_softmax_kernel(const softmax_params param,
+                                      const index_t lead_dim) {
+  using LengthType = AccType<InputType1>;
+  const InputType0* input = reinterpret_cast<const InputType0*>(param.inputs[0]);
+  const InputType1* length = reinterpret_cast<const InputType1*>(param.inputs[1]);
+  const index_t len = length == nullptr
+                      ? lead_dim
+                      : static_cast<index_t>(LengthType::from(length[blockIdx.x]));
+  const int my_row = threadIdx.x % param.rows_per_block;
+  const int my_id = threadIdx.x / param.rows_per_block;
+  const int threads_per_row = blockDim.x / param.rows_per_block;
+  const index_t base_x = (blockIdx.x * param.rows_per_block + my_row) % param.stride;
+  const index_t base_n = (blockIdx.x * param.rows_per_block + my_row) / param.stride;
+  const index_t base = base_x + param.stride * lead_dim * base_n;
+  if (base >= param.num_elements * param.total_rows) return;
+  using IType = AccType<InputType0>;
+  using OType = AccType<OutputType0>;
+  using AType = type_util::mixed_type<typename IType::type,
+                                      typename OType::type>;
+  __shared__ AType smem[kRTCMaxThreadsPerBlock];
+  AType max;
+  red::maximum::SetInitValue(max);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    max = op::max(max, negate ? -val : val);
+  }
+  smem[threadIdx.x] = max;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = op::max(smem[threadIdx.x], smem[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y)
+                                                         { return op::max(x, y); },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  AType smax = smem[my_row];
+  __syncthreads();
+
+  AType sum;
+  red::sum::SetInitValue(sum);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    val = negate ? -val :val;
+    sum += op::exp((val - smax) / static_cast<AType>(param.temperature));
+  }
+  smem[threadIdx.x] = sum;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = op::add(smem[threadIdx.x], smem[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y)
+                                                         { return op::add(x, y); },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  sum = smem[my_row];
+  __syncthreads();
+
+  OutputType0* output = reinterpret_cast<OutputType0*>(param.outputs[0]);
+  for (index_t i = my_id; i < lead_dim; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    val = negate ? -val : val;
+    val = (i < len) ? OP((val - smax)/static_cast<AType>(param.temperature), sum) : 0;
+    if (req == OpReqType::kAddTo) {
+      if (i < len) {
+        output[base + i * param.stride] = OType::to(val +
+                                                    OType::from(output[base + i * param.stride]));
+      }
+    } else {
+      output[base + i * param.stride] = OType::to(val);
+    }
+  }
+}
+)code";
+
+const char softmax_stride1_kernel_fwd[] = R"code(
+__launch_bounds__(vector::vectorized_kernel_thread_num)
+__global__ void softmax_stride1_compute_kernel(const softmax_params param,
+                                               const index_t total_length,
+                                               const index_t other_dim,
+                                               const index_t N,
+                                               const index_t num_aligned_elements) {
+  using namespace vector;
+  using IType = AccType<InputType0>;
+  using OType = AccType<OutputType0>;
+  using LengthType = AccType<InputType1>;
+  const InputType1* length = reinterpret_cast<const InputType1*>(param.inputs[1]);
+  using AType = type_util::mixed_type<typename IType::type,
+                                      typename OType::type>;
+  __shared__ AType scratch[vectorized_kernel_thread_num];
+  __shared__ AType persistent_storage[20 * 1024 / sizeof(AType)];
+  const int warp_size = 32;

Review comment:
       Is the `warp_size` defined in https://github.com/apache/incubator-mxnet/blob/ba2c3b411e953b650fd919db59b42b5535d80e83/src/common/cuda/rtc/util-inl.h#L280 visible here?
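       In this kernel, `warp_size` caps the group size passed to `util::grouped_warp_allreduce` (`min(threads_per_row, warp_size)`), because a shuffle-based reduction cannot span more than one warp. The CPU model below assumes the helper uses a butterfly (XOR-shuffle) pattern over aligned groups of consecutive lanes; that is an assumption, not a description of the helper in util-inl.h, and `grouped_allreduce_model` and `kWarpSize` are hypothetical names.

```cpp
#include <array>
#include <cassert>

constexpr int kWarpSize = 32;  // same value as the kernel's local `warp_size`

// CPU model of a butterfly (XOR-shuffle) allreduce over aligned groups of
// `group_size` lanes. On the GPU each step would be one __shfl_xor_sync; here a
// snapshot of the lane array stands in for the shuffle. Assumes group_size is
// a power of two no larger than the warp size.
template <typename T, typename Op>
std::array<T, kWarpSize> grouped_allreduce_model(std::array<T, kWarpSize> lanes,
                                                 Op op, int group_size) {
  assert(group_size > 0 && group_size <= kWarpSize &&
         (group_size & (group_size - 1)) == 0);
  for (int offset = group_size / 2; offset > 0; offset /= 2) {
    const std::array<T, kWarpSize> snapshot = lanes;  // values before this step
    for (int lane = 0; lane < kWarpSize; ++lane) {
      // lane ^ offset flips a bit below group_size, so the partner stays in
      // the same aligned group.
      lanes[lane] = op(snapshot[lane], snapshot[lane ^ offset]);
    }
  }
  return lanes;
}
```

       After log2(group_size) steps every lane holds its group's result, which matches each thread of a row reading the reduced value back.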

##########
File path: src/operator/nn/softmax.cu
##########
@@ -22,18 +22,809 @@
  * \file softmax.cu
  * \brief GPU Implementation of softmax
  */
+#include <string>
 #include "./softmax-inl.h"
-#include "../tensor/elemwise_unary_op.h"
+#include "../../common/cuda/utils.h"
+#include "../../common/cuda/rtc.h"
+#include "../../common/cuda/rtc/vectorization-inl.h"
 
 namespace mxnet {
 namespace op {
 
+namespace {
+
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+const char softmax_common_functions[] = R"code(
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+softmax_fwd(const DType a, const DType2 b) {
+  return op::exp(a) / b;
+}
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+log_softmax_fwd(const DType a, const DType2 b) {
+  return a - op::log(b);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+    return out * (ograd - sum);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+log_softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+    return ograd - op::exp(out) * sum;
+}
+
+)code";
+
+const char simple_softmax_kernel_fwd[] = R"code(
+__launch_bounds__(kRTCMaxThreadsPerBlock)
+__global__ void simple_softmax_kernel(const softmax_params param,
+                                      const index_t lead_dim) {
+  using LengthType = AccType<InputType1>;
+  const InputType0* input = reinterpret_cast<const InputType0*>(param.inputs[0]);
+  const InputType1* length = reinterpret_cast<const InputType1*>(param.inputs[1]);
+  const index_t len = length == nullptr
+                      ? lead_dim
+                      : static_cast<index_t>(LengthType::from(length[blockIdx.x]));
+  const int my_row = threadIdx.x % param.rows_per_block;
+  const int my_id = threadIdx.x / param.rows_per_block;
+  const int threads_per_row = blockDim.x / param.rows_per_block;
+  const index_t base_x = (blockIdx.x * param.rows_per_block + my_row) % param.stride;
+  const index_t base_n = (blockIdx.x * param.rows_per_block + my_row) / param.stride;
+  const index_t base = base_x + param.stride * lead_dim * base_n;
+  if (base >= param.num_elements * param.total_rows) return;
+  using IType = AccType<InputType0>;
+  using OType = AccType<OutputType0>;
+  using AType = type_util::mixed_type<typename IType::type,
+                                      typename OType::type>;
+  __shared__ AType smem[kRTCMaxThreadsPerBlock];
+  AType max;
+  red::maximum::SetInitValue(max);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    max = op::max(max, negate ? -val : val);
+  }
+  smem[threadIdx.x] = max;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = op::max(smem[threadIdx.x], smem[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y)
+                                                         { return op::max(x, y); },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  AType smax = smem[my_row];
+  __syncthreads();
+
+  AType sum;
+  red::sum::SetInitValue(sum);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    val = negate ? -val :val;
+    sum += op::exp((val - smax) / static_cast<AType>(param.temperature));
+  }
+  smem[threadIdx.x] = sum;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = op::add(smem[threadIdx.x], smem[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y)
+                                                         { return op::add(x, y); },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  sum = smem[my_row];
+  __syncthreads();
+
+  OutputType0* output = reinterpret_cast<OutputType0*>(param.outputs[0]);
+  for (index_t i = my_id; i < lead_dim; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    val = negate ? -val : val;
+    val = (i < len) ? OP((val - smax)/static_cast<AType>(param.temperature), sum) : 0;
+    if (req == OpReqType::kAddTo) {
+      if (i < len) {
+        output[base + i * param.stride] = OType::to(val +
+                                                    OType::from(output[base + i * param.stride]));
+      }
+    } else {
+      output[base + i * param.stride] = OType::to(val);
+    }
+  }
+}
+)code";
+
+const char softmax_stride1_kernel_fwd[] = R"code(
+__launch_bounds__(vector::vectorized_kernel_thread_num)
+__global__ void softmax_stride1_compute_kernel(const softmax_params param,
+                                               const index_t total_length,
+                                               const index_t other_dim,
+                                               const index_t N,
+                                               const index_t num_aligned_elements) {
+  using namespace vector;
+  using IType = AccType<InputType0>;
+  using OType = AccType<OutputType0>;
+  using LengthType = AccType<InputType1>;
+  const InputType1* length = reinterpret_cast<const InputType1*>(param.inputs[1]);
+  using AType = type_util::mixed_type<typename IType::type,
+                                      typename OType::type>;
+  __shared__ AType scratch[vectorized_kernel_thread_num];
+  __shared__ AType persistent_storage[20 * 1024 / sizeof(AType)];
+  const int warp_size = 32;
+  const int threads_per_row = vectorized_kernel_thread_num / param.rows_per_block;
+  const int my_local_row = threadIdx.x / threads_per_row;
+  const int base_row = blockIdx.x * param.rows_per_block;
+  const int my_row = base_row + my_local_row;
+  const index_t len = (length == nullptr ||
+                       my_row >= param.total_rows) ? param.num_elements
+                                                   : LengthType::from(length[my_row]);
+  const int my_id = threadIdx.x % threads_per_row;
+
+  AType* row;
+  if (only_full_blocks || blockIdx.x < gridDim.x - 1) {
+    // full rows_per_block rows to compute
+    VectorizedLoader<InputType0, nvec, aligned> loader(
+      reinterpret_cast<const InputType0*>(param.inputs[0]) + base_row * param.num_elements,
+      total_length);
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      loader.load(i, total_length);
+#pragma unroll
+      for (int j = 0; j < nvec; ++j) {
+        persistent_storage[i*nvec + j] = IType::from(loader.separate()[j]);
+      }
+    }
+    row = persistent_storage + my_local_row * param.num_elements + loader.alignment();
+  } else {
+    // less than rows_per_block rows to compute
+    const index_t real_length = min(total_length,
+                                    (param.total_rows - base_row) * param.num_elements);
+    VectorizedLoader<InputType0, nvec, false> loader(
+      reinterpret_cast<const InputType0*>(param.inputs[0]) + base_row * param.num_elements,
+      real_length);
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      loader.load(i, real_length);
+#pragma unroll
+      for (int j = 0; j < nvec; ++j) {
+        persistent_storage[i*nvec + j] = IType::from(loader.separate()[j]);
+      }
+    }
+    row = persistent_storage + my_local_row * param.num_elements + loader.alignment();
+  }
+  __syncthreads();
+
+  AType my_max_value;
+  red::maximum::SetInitValue(my_max_value);
+
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    my_max_value = ::max(my_max_value, negate ? -row[i] : row[i]);
+  }
+  AType smax;
+  if (!reduction_inside_warp) {
+    scratch[threadIdx.x] = my_max_value;
+    __syncthreads();
+    for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+      if (my_id < size) {
+        scratch[threadIdx.x] = ::max(scratch[threadIdx.x], scratch[threadIdx.x + size]);
+      }
+      __syncthreads();
+    }
+    if (my_id < warp_size) {
+      AType my_value = util::grouped_warp_allreduce(scratch[threadIdx.x],
+                                                    [](AType x, AType y) { return op::max(x, y); },
+                                                    min(threads_per_row, warp_size));
+      scratch[threadIdx.x] = my_value;
+    }
+    __syncthreads();
+    smax = scratch[threadIdx.x - my_id];
+    __syncthreads();
+  } else {
+    smax = util::grouped_warp_allreduce(my_max_value,
+                                        [](AType x, AType y) { return op::max(x, y); },
+                                        threads_per_row);
+  }
+
+  AType my_sum;
+  red::sum::SetInitValue(my_sum);
+
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    const AType val = negate ? -row[i] : row[i];
+    my_sum += op::exp((val - smax) / static_cast<AType>(param.temperature));
+  }
+  AType ssum;
+  if (!reduction_inside_warp) {
+    scratch[threadIdx.x] = my_sum;
+    __syncthreads();
+    for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+      if (my_id < size) {
+        scratch[threadIdx.x] += scratch[threadIdx.x + size];
+      }
+      __syncthreads();
+    }
+    if (my_id < warp_size) {
+      AType my_value = util::grouped_warp_allreduce(scratch[threadIdx.x],
+                                                    [](AType x, AType y) { return x + y;},
+                                                    min(threads_per_row, warp_size));
+      scratch[threadIdx.x] = my_value;
+    }
+    __syncthreads();
+
+    ssum = scratch[threadIdx.x - my_id];
+    __syncthreads();
+  } else {
+      ssum = util::grouped_warp_allreduce(my_sum,
+                                          [](AType x, AType y) { return x + y;},
+                                          threads_per_row);
+  }
+
+  for (index_t i = my_id; i < param.num_elements; i += threads_per_row) {
+    const AType val = negate ? -row[i] : row[i];
+    row[i] = (i < len) ? OP((val - smax)/static_cast<AType>(param.temperature), ssum) :
+                         0;
+  }
+  __syncthreads();
+
+  if (only_full_blocks || blockIdx.x < gridDim.x - 1) {
+    VectorizedStorer<OutputType0, nvec, aligned> storer(
+      reinterpret_cast<OutputType0*>(param.outputs[0]) + base_row * param.num_elements,
+      total_length);
+
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      if (req == OpReqType::kAddTo) {
+        storer.load(i, total_length);
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(op::add(persistent_storage[i*nvec + j],
+                                                   OType::from(storer.separate()[j])));
+        }
+      } else {
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(persistent_storage[i*nvec + j]);
+        }
+      }
+      storer.store(i, total_length);
+    }
+  } else {
+    const index_t real_length = min(total_length,
+                                    (param.total_rows - base_row) * param.num_elements);
+    VectorizedStorer<OutputType0, nvec, false> storer(
+      reinterpret_cast<OutputType0*>(param.outputs[0]) + base_row * param.num_elements,
+      real_length);
+
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      if (req == OpReqType::kAddTo) {
+        storer.load(i, real_length);
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(op::add(persistent_storage[i*nvec + j],
+                                                   OType::from(storer.separate()[j])));
+        }
+      } else {
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(persistent_storage[i*nvec + j]);
+        }
+      }
+      storer.store(i, real_length);
+    }
+  }
+}
+)code";
+
+bool IsPower2(size_t N) {
+  return ((N & (N - 1)) == 0) && N != 0;
+}
+
+index_t RoundToPower2(index_t N) {

Review comment:
        Could you move this to util-inl.h if you think it can be reused by other ops?
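        If these helpers were moved to a shared header as suggested, marking them `constexpr` would let them be used at compile time as well; since `constexpr` functions are implicitly `inline`, they can be defined in a header included by several translation units. `IsPower2` below is copied from the diff. The body of `RoundToPower2` is cut off in this hunk, so `RoundUpToPower2Sketch` is a hypothetical round-up helper for illustration only, not the PR's implementation.

```cpp
#include <cstddef>
#include <cstdint>

// IsPower2 as it appears in the diff, made constexpr for use in a header.
constexpr bool IsPower2(std::size_t N) {
  return ((N & (N - 1)) == 0) && N != 0;
}

// Hypothetical helper: smallest power of two >= N, for 1 <= N <= 2^62.
// Not the PR's RoundToPower2, whose body is not shown here.
constexpr std::int64_t RoundUpToPower2Sketch(std::int64_t N) {
  std::int64_t p = 1;
  while (p < N) p <<= 1;
  return p;
}

static_assert(IsPower2(32), "the warp size is a power of two");
```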

##########
File path: src/operator/nn/softmax.cu
##########
@@ -22,18 +22,809 @@
  * \file softmax.cu
  * \brief GPU Implementation of softmax
  */
+#include <string>
 #include "./softmax-inl.h"
-#include "../tensor/elemwise_unary_op.h"
+#include "../../common/cuda/utils.h"
+#include "../../common/cuda/rtc.h"
+#include "../../common/cuda/rtc/vectorization-inl.h"
 
 namespace mxnet {
 namespace op {
 
+namespace {
+
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+const char softmax_common_functions[] = R"code(
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+softmax_fwd(const DType a, const DType2 b) {
+  return op::exp(a) / b;
+}
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+log_softmax_fwd(const DType a, const DType2 b) {
+  return a - op::log(b);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+    return out * (ograd - sum);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+log_softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+    return ograd - op::exp(out) * sum;
+}
+
+)code";
+
+const char simple_softmax_kernel_fwd[] = R"code(
+__launch_bounds__(kRTCMaxThreadsPerBlock)
+__global__ void simple_softmax_kernel(const softmax_params param,
+                                      const index_t lead_dim) {
+  using LengthType = AccType<InputType1>;
+  const InputType0* input = reinterpret_cast<const InputType0*>(param.inputs[0]);
+  const InputType1* length = reinterpret_cast<const InputType1*>(param.inputs[1]);
+  const index_t len = length == nullptr
+                      ? lead_dim
+                      : static_cast<index_t>(LengthType::from(length[blockIdx.x]));
+  const int my_row = threadIdx.x % param.rows_per_block;
+  const int my_id = threadIdx.x / param.rows_per_block;
+  const int threads_per_row = blockDim.x / param.rows_per_block;
+  const index_t base_x = (blockIdx.x * param.rows_per_block + my_row) % param.stride;
+  const index_t base_n = (blockIdx.x * param.rows_per_block + my_row) / param.stride;
+  const index_t base = base_x + param.stride * lead_dim * base_n;
+  if (base >= param.num_elements * param.total_rows) return;
+  using IType = AccType<InputType0>;
+  using OType = AccType<OutputType0>;
+  using AType = type_util::mixed_type<typename IType::type,
+                                      typename OType::type>;
+  __shared__ AType smem[kRTCMaxThreadsPerBlock];
+  AType max;
+  red::maximum::SetInitValue(max);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    max = op::max(max, negate ? -val : val);
+  }
+  smem[threadIdx.x] = max;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = op::max(smem[threadIdx.x], smem[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y)
+                                                         { return op::max(x, y); },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  AType smax = smem[my_row];
+  __syncthreads();
+
+  AType sum;
+  red::sum::SetInitValue(sum);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    val = negate ? -val :val;
+    sum += op::exp((val - smax) / static_cast<AType>(param.temperature));
+  }
+  smem[threadIdx.x] = sum;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = op::add(smem[threadIdx.x], smem[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y)
+                                                         { return op::add(x, y); },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  sum = smem[my_row];
+  __syncthreads();
+
+  OutputType0* output = reinterpret_cast<OutputType0*>(param.outputs[0]);
+  for (index_t i = my_id; i < lead_dim; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    val = negate ? -val : val;
+    val = (i < len) ? OP((val - smax)/static_cast<AType>(param.temperature), sum) : 0;
+    if (req == OpReqType::kAddTo) {
+      if (i < len) {
+        output[base + i * param.stride] = OType::to(val +
+                                                    OType::from(output[base + i * param.stride]));
+      }
+    } else {
+      output[base + i * param.stride] = OType::to(val);
+    }
+  }
+}
+)code";
+
+const char softmax_stride1_kernel_fwd[] = R"code(
+__launch_bounds__(vector::vectorized_kernel_thread_num)
+__global__ void softmax_stride1_compute_kernel(const softmax_params param,
+                                               const index_t total_length,
+                                               const index_t other_dim,
+                                               const index_t N,
+                                               const index_t num_aligned_elements) {
+  using namespace vector;
+  using IType = AccType<InputType0>;
+  using OType = AccType<OutputType0>;
+  using LengthType = AccType<InputType1>;
+  const InputType1* length = reinterpret_cast<const InputType1*>(param.inputs[1]);
+  using AType = type_util::mixed_type<typename IType::type,
+                                      typename OType::type>;
+  __shared__ AType scratch[vectorized_kernel_thread_num];
+  __shared__ AType persistent_storage[20 * 1024 / sizeof(AType)];
+  const int warp_size = 32;
+  const int threads_per_row = vectorized_kernel_thread_num / param.rows_per_block;
+  const int my_local_row = threadIdx.x / threads_per_row;
+  const int base_row = blockIdx.x * param.rows_per_block;
+  const int my_row = base_row + my_local_row;
+  const index_t len = (length == nullptr ||
+                       my_row >= param.total_rows) ? param.num_elements
+                                                   : LengthType::from(length[my_row]);
+  const int my_id = threadIdx.x % threads_per_row;
+
+  AType* row;
+  if (only_full_blocks || blockIdx.x < gridDim.x - 1) {
+    // full rows_per_block rows to compute
+    VectorizedLoader<InputType0, nvec, aligned> loader(
+      reinterpret_cast<const InputType0*>(param.inputs[0]) + base_row * param.num_elements,
+      total_length);
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      loader.load(i, total_length);
+#pragma unroll
+      for (int j = 0; j < nvec; ++j) {
+        persistent_storage[i*nvec + j] = IType::from(loader.separate()[j]);
+      }
+    }
+    row = persistent_storage + my_local_row * param.num_elements + loader.alignment();
+  } else {
+    // less than rows_per_block rows to compute
+    const index_t real_length = min(total_length,
+                                    (param.total_rows - base_row) * param.num_elements);
+    VectorizedLoader<InputType0, nvec, false> loader(
+      reinterpret_cast<const InputType0*>(param.inputs[0]) + base_row * param.num_elements,
+      real_length);
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      loader.load(i, real_length);
+#pragma unroll
+      for (int j = 0; j < nvec; ++j) {
+        persistent_storage[i*nvec + j] = IType::from(loader.separate()[j]);
+      }
+    }
+    row = persistent_storage + my_local_row * param.num_elements + loader.alignment();
+  }
+  __syncthreads();
+
+  AType my_max_value;
+  red::maximum::SetInitValue(my_max_value);
+
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    my_max_value = ::max(my_max_value, negate ? -row[i] : row[i]);
+  }
+  AType smax;
+  if (!reduction_inside_warp) {
+    scratch[threadIdx.x] = my_max_value;
+    __syncthreads();
+    for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+      if (my_id < size) {
+        scratch[threadIdx.x] = ::max(scratch[threadIdx.x], scratch[threadIdx.x + size]);
+      }
+      __syncthreads();
+    }
+    if (my_id < warp_size) {
+      AType my_value = util::grouped_warp_allreduce(scratch[threadIdx.x],
+                                                    [](AType x, AType y) { return op::max(x, y); },
+                                                    min(threads_per_row, warp_size));
+      scratch[threadIdx.x] = my_value;
+    }
+    __syncthreads();
+    smax = scratch[threadIdx.x - my_id];
+    __syncthreads();
+  } else {
+    smax = util::grouped_warp_allreduce(my_max_value,
+                                        [](AType x, AType y) { return op::max(x, y); },
+                                        threads_per_row);
+  }
+
+  AType my_sum;
+  red::sum::SetInitValue(my_sum);
+
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    const AType val = negate ? -row[i] : row[i];
+    my_sum += op::exp((val - smax) / static_cast<AType>(param.temperature));
+  }
+  AType ssum;
+  if (!reduction_inside_warp) {
+    scratch[threadIdx.x] = my_sum;
+    __syncthreads();
+    for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+      if (my_id < size) {
+        scratch[threadIdx.x] += scratch[threadIdx.x + size];
+      }
+      __syncthreads();
+    }
+    if (my_id < warp_size) {
+      AType my_value = util::grouped_warp_allreduce(scratch[threadIdx.x],
+                                                    [](AType x, AType y) { return x + y;},
+                                                    min(threads_per_row, warp_size));
+      scratch[threadIdx.x] = my_value;
+    }
+    __syncthreads();
+
+    ssum = scratch[threadIdx.x - my_id];
+    __syncthreads();
+  } else {
+      ssum = util::grouped_warp_allreduce(my_sum,
+                                          [](AType x, AType y) { return x + y;},
+                                          threads_per_row);
+  }
+
+  for (index_t i = my_id; i < param.num_elements; i += threads_per_row) {
+    const AType val = negate ? -row[i] : row[i];
+    row[i] = (i < len) ? OP((val - smax)/static_cast<AType>(param.temperature), ssum) :
+                         0;
+  }
+  __syncthreads();
+
+  if (only_full_blocks || blockIdx.x < gridDim.x - 1) {
+    VectorizedStorer<OutputType0, nvec, aligned> storer(
+      reinterpret_cast<OutputType0*>(param.outputs[0]) + base_row * param.num_elements,
+      total_length);
+
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      if (req == OpReqType::kAddTo) {
+        storer.load(i, total_length);
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(op::add(persistent_storage[i*nvec + j],
+                                                   OType::from(storer.separate()[j])));
+        }
+      } else {
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(persistent_storage[i*nvec + j]);
+        }
+      }
+      storer.store(i, total_length);
+    }
+  } else {
+    const index_t real_length = min(total_length,
+                                    (param.total_rows - base_row) * param.num_elements);
+    VectorizedStorer<OutputType0, nvec, false> storer(
+      reinterpret_cast<OutputType0*>(param.outputs[0]) + base_row * param.num_elements,
+      real_length);
+
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      if (req == OpReqType::kAddTo) {
+        storer.load(i, real_length);
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(op::add(persistent_storage[i*nvec + j],
+                                                   OType::from(storer.separate()[j])));
+        }
+      } else {
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(persistent_storage[i*nvec + j]);
+        }
+      }
+      storer.store(i, real_length);
+    }
+  }
+}
+)code";
+
+bool IsPower2(size_t N) {
+  return ((N & (N - 1)) == 0) && N != 0;
+}
+
+index_t RoundToPower2(index_t N) {
+  size_t ret = 1;
+  size_t copyN = N;
+  while (N >= 2) {
+    ret *= 2;
+    N /= 2;
+  }
+  if (ret < copyN) {
+    ret *= 2;
+  }
+  return ret;
+}
+
+int get_rows_per_block(const index_t row_size, const int nvec,
+                       const index_t max_storage, const int num_threads_per_block,
+                       const index_t total_rows, const int dev_id) {
+  CHECK(IsPower2(num_threads_per_block))
+    << "Number of threads in a block must be power of 2 to use get_rows_per_block function";
+  // Minimum number of read instructions each thread should issue
+  const int read_instructions = 16;
+  const size_t row_size_in_vec = (row_size + nvec - 1) / nvec;
+  int desired_num_threads_per_row = (row_size_in_vec + read_instructions - 1) / read_instructions;
+  desired_num_threads_per_row = RoundToPower2(desired_num_threads_per_row);
+  desired_num_threads_per_row = std::min(desired_num_threads_per_row, num_threads_per_block);
+  const int desired_rows_per_block = num_threads_per_block / desired_num_threads_per_row;
+  int actual_rows_per_block = desired_rows_per_block;
+  int num_sms = MultiprocessorCount(dev_id);
+  while (actual_rows_per_block > 1 &&
+         ((max_storage != -1 && max_storage < row_size * actual_rows_per_block) ||
+          (total_rows + actual_rows_per_block - 1) / actual_rows_per_block < num_sms)) {
+    actual_rows_per_block /= 2;
+  }
+  return actual_rows_per_block;
+}
+
+}  // namespace
+
+void SoftmaxRTCCompute::operator()(const nnvm::NodeAttrs& attrs,
+                                   const OpContext& ctx,
+                                   const std::vector<TBlob>& inputs,
+                                   const std::vector<OpReqType>& req,
+                                   const std::vector<TBlob>& outputs) {
+  using namespace mxnet_op;
+  using common::mshadow_type_info;
+  using namespace common::cuda::rtc;
+  using common::div_round;
+  if (req[0] == kNullOp || inputs[0].Size() == 0U) return;
+  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+  int axis = CheckAxis(param.axis, inputs[0].ndim());
+  const double temperature = param.temperature.has_value() ?
+                             param.temperature.value() : 1.0;
+  mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
+
+  void* length_ptr = nullptr;
+  std::string length_typename = "int";
+  if (param.use_length.value()) {
+    CHECK(inputs.size() > 1)
+      << "Mask needs to be provided when using softmax with use_length=True.";
+    length_ptr = inputs[1].dptr_;
+    length_typename = mshadow_type_info(inputs[1].type_flag_).name;
+  }
+  CHECK_EQ(outputs.size(), 1);
+  index_t M = shape[axis];
+  if (M == 0 || shape.Size() == 0) return;
+  index_t stride = 1;
+  if (axis == shape.ndim() - 2) {
+    stride = shape[shape.ndim() - 1];
+  }
+  const index_t N = shape.Size() / M;
+  softmax_params params = {{inputs[0].dptr_, length_ptr, nullptr},
+                           {outputs[0].dptr_},
+                           stride, M,
+                           temperature, 1, N};
+  std::string code = "#define OP " + OP + "\n"
+                     "const OpReqType req = " + util::to_string(req[0]) + ";\n"
+                     "const bool negate = " + std::to_string(negate) + ";\n"
+                     "using InputType1 = " + length_typename + ";\n";
+  Stream<gpu>* s = ctx.get_stream<gpu>();
+
+  constexpr int nvec = 2;
+  // Using 20 kB of shared memory for persistent storage in the optimized case
+  const size_t acc_type_size = std::max(mshadow_type_info(inputs[0].type_flag_).acc_size,
+                                        mshadow_type_info(outputs[0].type_flag_).acc_size);
+  const size_t max_opt_M = 20 * 1024 / acc_type_size;
+  int rows_per_block = get_rows_per_block(M, nvec, max_opt_M,
+                                          vectorized_kernel_thread_num,
+                                          N, ctx.run_ctx.ctx.dev_id);
+  if (stride == 1 &&
+      static_cast<size_t>(M * rows_per_block) <= max_opt_M) {
+    const int warp_size = 32;

Review comment:
       Maybe you could use https://github.com/apache/incubator-mxnet/blob/ba2c3b411e953b650fd919db59b42b5535d80e83/src/common/cuda/rtc/util-inl.h#L280 here?
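
The launch configuration in the hunk above hinges on `get_rows_per_block`. Below is a minimal host-only sketch of that heuristic, written to run without a GPU. It assumes that `MultiprocessorCount(dev_id)` can be replaced by an explicit, hypothetical `num_sms` parameter. The names `round_up_pow2` and `rows_per_block_sketch` are illustrative only, not MXNet functions. The sketch picks enough power-of-two threads per row that each thread issues at least 16 vector loads. It then halves the rows per block while those rows would exceed the shared-memory budget, or while the grid would end up with fewer blocks than SMs.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Smallest power of two >= n (1 for n <= 1); same result as RoundToPower2 in the diff for n >= 0.
int64_t round_up_pow2(int64_t n) {
  int64_t r = 1;
  while (r < n) r *= 2;
  return r;
}

// Host-only sketch of get_rows_per_block. `num_sms` is a hypothetical stand-in for
// MultiprocessorCount(dev_id) so that the heuristic can run without a GPU.
int rows_per_block_sketch(int64_t row_size, int nvec, int64_t max_storage,
                          int threads_per_block, int64_t total_rows, int num_sms) {
  const int64_t read_instructions = 16;  // minimum vector loads per thread
  const int64_t row_size_in_vec = (row_size + nvec - 1) / nvec;
  int64_t threads_per_row =
      round_up_pow2((row_size_in_vec + read_instructions - 1) / read_instructions);
  threads_per_row = std::min<int64_t>(threads_per_row, threads_per_block);
  int rows = static_cast<int>(threads_per_block / threads_per_row);
  // Halve the rows per block while they would not fit in the shared-memory
  // budget (max_storage == -1 means "no limit"), or while the grid would have
  // fewer blocks than there are SMs.
  while (rows > 1 &&
         ((max_storage != -1 && max_storage < row_size * rows) ||
          (total_rows + rows - 1) / rows < num_sms)) {
    rows /= 2;
  }
  return rows;
}
```

As a usage example, take the BERT-like shape from the PR description: 12 * 32 * 128 rows of length 128, `nvec = 2`, and a float budget of 20 * 1024 / 4 = 5120 elements. Add two assumed values, 512 threads per block and 80 SMs; neither comes from this diff. With those inputs, `rows_per_block_sketch(128, 2, 5120, 512, 49152, 80)` settles on 32 rows per block.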

##########
File path: src/operator/nn/softmax.cu
##########
@@ -22,18 +22,809 @@
  * \file softmax.cu
  * \brief GPU Implementation of softmax
  */
+#include <string>
 #include "./softmax-inl.h"
-#include "../tensor/elemwise_unary_op.h"
+#include "../../common/cuda/utils.h"
+#include "../../common/cuda/rtc.h"
+#include "../../common/cuda/rtc/vectorization-inl.h"
 
 namespace mxnet {
 namespace op {
 
+namespace {
+
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+const char softmax_common_functions[] = R"code(
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+softmax_fwd(const DType a, const DType2 b) {
+  return op::exp(a) / b;
+}
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+log_softmax_fwd(const DType a, const DType2 b) {
+  return a - op::log(b);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+    return out * (ograd - sum);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+log_softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+    return ograd - op::exp(out) * sum;
+}
+
+)code";
+
+const char simple_softmax_kernel_fwd[] = R"code(
+__launch_bounds__(kRTCMaxThreadsPerBlock)
+__global__ void simple_softmax_kernel(const softmax_params param,
+                                      const index_t lead_dim) {
+  using LengthType = AccType<InputType1>;
+  const InputType0* input = reinterpret_cast<const InputType0*>(param.inputs[0]);
+  const InputType1* length = reinterpret_cast<const InputType1*>(param.inputs[1]);
+  const index_t len = length == nullptr
+                      ? lead_dim
+                      : static_cast<index_t>(LengthType::from(length[blockIdx.x]));
+  const int my_row = threadIdx.x % param.rows_per_block;
+  const int my_id = threadIdx.x / param.rows_per_block;
+  const int threads_per_row = blockDim.x / param.rows_per_block;
+  const index_t base_x = (blockIdx.x * param.rows_per_block + my_row) % param.stride;
+  const index_t base_n = (blockIdx.x * param.rows_per_block + my_row) / param.stride;
+  const index_t base = base_x + param.stride * lead_dim * base_n;
+  if (base >= param.num_elements * param.total_rows) return;
+  using IType = AccType<InputType0>;
+  using OType = AccType<OutputType0>;
+  using AType = type_util::mixed_type<typename IType::type,
+                                      typename OType::type>;
+  __shared__ AType smem[kRTCMaxThreadsPerBlock];
+  AType max;
+  red::maximum::SetInitValue(max);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    max = op::max(max, negate ? -val : val);
+  }
+  smem[threadIdx.x] = max;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = op::max(smem[threadIdx.x], smem[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y)
+                                                         { return op::max(x, y); },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  AType smax = smem[my_row];
+  __syncthreads();
+
+  AType sum;
+  red::sum::SetInitValue(sum);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    val = negate ? -val : val;
+    sum += op::exp((val - smax) / static_cast<AType>(param.temperature));
+  }
+  smem[threadIdx.x] = sum;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = op::add(smem[threadIdx.x], smem[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y)
+                                                         { return op::add(x, y); },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  sum = smem[my_row];
+  __syncthreads();
+
+  OutputType0* output = reinterpret_cast<OutputType0*>(param.outputs[0]);
+  for (index_t i = my_id; i < lead_dim; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    val = negate ? -val : val;
+    val = (i < len) ? OP((val - smax)/static_cast<AType>(param.temperature), sum) : 0;
+    if (req == OpReqType::kAddTo) {
+      if (i < len) {
+        output[base + i * param.stride] = OType::to(val +
+                                                    OType::from(output[base + i * param.stride]));
+      }
+    } else {
+      output[base + i * param.stride] = OType::to(val);
+    }
+  }
+}
+)code";
+
+const char softmax_stride1_kernel_fwd[] = R"code(
+__launch_bounds__(vector::vectorized_kernel_thread_num)
+__global__ void softmax_stride1_compute_kernel(const softmax_params param,
+                                               const index_t total_length,
+                                               const index_t other_dim,
+                                               const index_t N,
+                                               const index_t num_aligned_elements) {
+  using namespace vector;
+  using IType = AccType<InputType0>;
+  using OType = AccType<OutputType0>;
+  using LengthType = AccType<InputType1>;
+  const InputType1* length = reinterpret_cast<const InputType1*>(param.inputs[1]);
+  using AType = type_util::mixed_type<typename IType::type,
+                                      typename OType::type>;
+  __shared__ AType scratch[vectorized_kernel_thread_num];
+  __shared__ AType persistent_storage[20 * 1024 / sizeof(AType)];
+  const int warp_size = 32;
+  const int threads_per_row = vectorized_kernel_thread_num / param.rows_per_block;
+  const int my_local_row = threadIdx.x / threads_per_row;
+  const int base_row = blockIdx.x * param.rows_per_block;
+  const int my_row = base_row + my_local_row;
+  const index_t len = (length == nullptr ||
+                       my_row >= param.total_rows) ? param.num_elements
+                                                   : LengthType::from(length[my_row]);
+  const int my_id = threadIdx.x % threads_per_row;
+
+  AType* row;
+  if (only_full_blocks || blockIdx.x < gridDim.x - 1) {
+    // full rows_per_block rows to compute
+    VectorizedLoader<InputType0, nvec, aligned> loader(
+      reinterpret_cast<const InputType0*>(param.inputs[0]) + base_row * param.num_elements,
+      total_length);
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      loader.load(i, total_length);
+#pragma unroll
+      for (int j = 0; j < nvec; ++j) {
+        persistent_storage[i*nvec + j] = IType::from(loader.separate()[j]);
+      }
+    }
+    row = persistent_storage + my_local_row * param.num_elements + loader.alignment();
+  } else {
+    // fewer than rows_per_block rows to compute
+    const index_t real_length = min(total_length,
+                                    (param.total_rows - base_row) * param.num_elements);
+    VectorizedLoader<InputType0, nvec, false> loader(
+      reinterpret_cast<const InputType0*>(param.inputs[0]) + base_row * param.num_elements,
+      real_length);
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      loader.load(i, real_length);
+#pragma unroll
+      for (int j = 0; j < nvec; ++j) {
+        persistent_storage[i*nvec + j] = IType::from(loader.separate()[j]);
+      }
+    }
+    row = persistent_storage + my_local_row * param.num_elements + loader.alignment();
+  }
+  __syncthreads();
+
+  AType my_max_value;
+  red::maximum::SetInitValue(my_max_value);
+
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    my_max_value = ::max(my_max_value, negate ? -row[i] : row[i]);
+  }
+  AType smax;
+  if (!reduction_inside_warp) {
+    scratch[threadIdx.x] = my_max_value;
+    __syncthreads();
+    for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+      if (my_id < size) {
+        scratch[threadIdx.x] = ::max(scratch[threadIdx.x], scratch[threadIdx.x + size]);
+      }
+      __syncthreads();
+    }
+    if (my_id < warp_size) {
+      AType my_value = util::grouped_warp_allreduce(scratch[threadIdx.x],
+                                                    [](AType x, AType y) { return op::max(x, y); },
+                                                    min(threads_per_row, warp_size));
+      scratch[threadIdx.x] = my_value;
+    }
+    __syncthreads();
+    smax = scratch[threadIdx.x - my_id];
+    __syncthreads();
+  } else {
+    smax = util::grouped_warp_allreduce(my_max_value,
+                                        [](AType x, AType y) { return op::max(x, y); },
+                                        threads_per_row);
+  }
+
+  AType my_sum;
+  red::sum::SetInitValue(my_sum);
+
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    const AType val = negate ? -row[i] : row[i];
+    my_sum += op::exp((val - smax) / static_cast<AType>(param.temperature));
+  }
+  AType ssum;
+  if (!reduction_inside_warp) {
+    scratch[threadIdx.x] = my_sum;
+    __syncthreads();
+    for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+      if (my_id < size) {
+        scratch[threadIdx.x] += scratch[threadIdx.x + size];
+      }
+      __syncthreads();
+    }
+    if (my_id < warp_size) {
+      AType my_value = util::grouped_warp_allreduce(scratch[threadIdx.x],
+                                                    [](AType x, AType y) { return x + y;},
+                                                    min(threads_per_row, warp_size));
+      scratch[threadIdx.x] = my_value;
+    }
+    __syncthreads();
+
+    ssum = scratch[threadIdx.x - my_id];
+    __syncthreads();
+  } else {
+      ssum = util::grouped_warp_allreduce(my_sum,
+                                          [](AType x, AType y) { return x + y;},
+                                          threads_per_row);
+  }
+
+  for (index_t i = my_id; i < param.num_elements; i += threads_per_row) {
+    const AType val = negate ? -row[i] : row[i];
+    row[i] = (i < len) ? OP((val - smax)/static_cast<AType>(param.temperature), ssum) :
+                         0;
+  }
+  __syncthreads();
+
+  if (only_full_blocks || blockIdx.x < gridDim.x - 1) {
+    VectorizedStorer<OutputType0, nvec, aligned> storer(
+      reinterpret_cast<OutputType0*>(param.outputs[0]) + base_row * param.num_elements,
+      total_length);
+
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      if (req == OpReqType::kAddTo) {
+        storer.load(i, total_length);
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(op::add(persistent_storage[i*nvec + j],
+                                                   OType::from(storer.separate()[j])));
+        }
+      } else {
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(persistent_storage[i*nvec + j]);
+        }
+      }
+      storer.store(i, total_length);
+    }
+  } else {
+    const index_t real_length = min(total_length,
+                                    (param.total_rows - base_row) * param.num_elements);
+    VectorizedStorer<OutputType0, nvec, false> storer(
+      reinterpret_cast<OutputType0*>(param.outputs[0]) + base_row * param.num_elements,
+      real_length);
+
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      if (req == OpReqType::kAddTo) {
+        storer.load(i, real_length);
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(op::add(persistent_storage[i*nvec + j],
+                                                   OType::from(storer.separate()[j])));
+        }
+      } else {
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(persistent_storage[i*nvec + j]);
+        }
+      }
+      storer.store(i, real_length);
+    }
+  }
+}
+)code";
+
+bool IsPower2(size_t N) {

Review comment:
       Maybe we could move this to util-inl.h, if you think it can be reused by other ops?
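
To illustrate what a shared version of these helpers might look like, here is a small sketch. The `sketch` namespace and the function names are hypothetical, and nothing here is existing MXNet API. `is_power2` uses the same bit trick as `IsPower2` in the diff: clearing the lowest set bit of a power of two leaves zero. `round_to_power2` returns the same values as `RoundToPower2` for non-negative inputs. Declaring both `constexpr` would let launch-configuration constants be checked at compile time, which is one possible argument for placing them in a common header.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical shared helpers; names and placement are assumptions, not MXNet API.
namespace sketch {

// True iff n is a positive power of two: clearing the lowest set bit leaves 0.
constexpr bool is_power2(uint64_t n) {
  return n != 0 && (n & (n - 1)) == 0;
}

// Smallest power of two >= n; returns 1 for n <= 1.
// Assumes n <= 2^62 so that the doubling cannot overflow.
constexpr int64_t round_to_power2(int64_t n) {
  int64_t r = 1;
  while (r < n) r <<= 1;
  return r;
}

}  // namespace sketch

// Compile-time checks made possible by constexpr (C++14 or later).
static_assert(sketch::is_power2(64), "64 is a power of two");
static_assert(!sketch::is_power2(0), "0 is not a power of two");
static_assert(sketch::round_to_power2(33) == 64, "33 rounds up to 64");
```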

##########
File path: src/operator/nn/softmax.cu
##########
@@ -22,18 +22,809 @@
  * \file softmax.cu
  * \brief GPU Implementation of softmax
  */
+#include <string>
 #include "./softmax-inl.h"
-#include "../tensor/elemwise_unary_op.h"
+#include "../../common/cuda/utils.h"
+#include "../../common/cuda/rtc.h"
+#include "../../common/cuda/rtc/vectorization-inl.h"
 
 namespace mxnet {
 namespace op {
 
+namespace {
+
+struct softmax_params {
+  const void* inputs[3];

Review comment:
       Why 3 inputs?
   It seems you only make use of 2.
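
For context on the question: in the forward code quoted in this thread, the kernels read `param.inputs[0]` (the data) and `param.inputs[1]` (optional per-row lengths, `nullptr` when `use_length` is off). The host sets the third slot to `nullptr`, and this excerpt does not reference it anywhere. The sketch below makes the role of the length input concrete. It is a host-only reference of what the stride-1 kernel computes for a single row, with three passes: the row max, then the sum of `exp((x - max) / T)`, then `OP`. Positions at or beyond the row length are written as 0. It is an illustration under stated assumptions: it works in `double` throughout instead of the kernels' mixed `AType`, and it ignores `kAddTo`. The function name is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical host-only reference for one contiguous (stride 1) row.
// Only the first `len` elements participate; positions >= len are written as 0.
// `log_softmax` selects log_softmax_fwd(a, b) = a - log(b) instead of
// softmax_fwd(a, b) = exp(a) / b.
std::vector<double> reference_softmax_row(const std::vector<double>& row, std::size_t len,
                                          double temperature, bool negate, bool log_softmax) {
  len = std::min(len, row.size());
  std::vector<double> out(row.size(), 0.0);
  // Pass 1: maximum of the (optionally negated) input, for numerical stability.
  double smax = -std::numeric_limits<double>::infinity();
  for (std::size_t i = 0; i < len; ++i) smax = std::max(smax, negate ? -row[i] : row[i]);
  // Pass 2: sum of exp((x - max) / T).
  double ssum = 0.0;
  for (std::size_t i = 0; i < len; ++i) {
    const double x = negate ? -row[i] : row[i];
    ssum += std::exp((x - smax) / temperature);
  }
  // Pass 3: apply OP(a, sum) with a = (x - max) / T.
  for (std::size_t i = 0; i < len; ++i) {
    const double a = ((negate ? -row[i] : row[i]) - smax) / temperature;
    out[i] = log_softmax ? a - std::log(ssum) : std::exp(a) / ssum;
  }
  return out;
}
```

For example, `reference_softmax_row({1.0, 2.0, 3.0, 4.0}, 2, 1.0, false, false)` normalizes only the first two entries and leaves the last two at 0, matching the `(i < len) ? OP(...) : 0` branch in the kernels.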




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-mxnet] ptrendx commented on pull request #19905: [PERF] Moving GPU softmax to RTC and optimizations

Posted by GitBox <gi...@apache.org>.
ptrendx commented on pull request #19905:
URL: https://github.com/apache/incubator-mxnet/pull/19905#issuecomment-826501721


   @mxnet-bot run ci [unix-cpu]





[GitHub] [incubator-mxnet] ptrendx commented on a change in pull request #19905: [PERF] Moving GPU softmax to RTC and optimizations

Posted by GitBox <gi...@apache.org>.
ptrendx commented on a change in pull request #19905:
URL: https://github.com/apache/incubator-mxnet/pull/19905#discussion_r618735191



##########
File path: src/operator/nn/softmax.cu
##########
@@ -22,18 +22,809 @@
  * \file softmax.cu
  * \brief GPU Implementation of softmax
  */
+#include <string>
 #include "./softmax-inl.h"
-#include "../tensor/elemwise_unary_op.h"
+#include "../../common/cuda/utils.h"
+#include "../../common/cuda/rtc.h"
+#include "../../common/cuda/rtc/vectorization-inl.h"
 
 namespace mxnet {
 namespace op {
 
+namespace {
+
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+const char softmax_common_functions[] = R"code(
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+softmax_fwd(const DType a, const DType2 b) {
+  return op::exp(a) / b;
+}
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+log_softmax_fwd(const DType a, const DType2 b) {
+  return a - op::log(b);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+    return out * (ograd - sum);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+log_softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+    return ograd - op::exp(out) * sum;
+}
+
+)code";
+
+const char simple_softmax_kernel_fwd[] = R"code(
+__launch_bounds__(kRTCMaxThreadsPerBlock)
+__global__ void simple_softmax_kernel(const softmax_params param,
+                                      const index_t lead_dim) {
+  using LengthType = AccType<InputType1>;
+  const InputType0* input = reinterpret_cast<const InputType0*>(param.inputs[0]);
+  const InputType1* length = reinterpret_cast<const InputType1*>(param.inputs[1]);
+  const index_t len = length == nullptr
+                      ? lead_dim
+                      : static_cast<index_t>(LengthType::from(length[blockIdx.x]));
+  const int my_row = threadIdx.x % param.rows_per_block;
+  const int my_id = threadIdx.x / param.rows_per_block;
+  const int threads_per_row = blockDim.x / param.rows_per_block;
+  const index_t base_x = (blockIdx.x * param.rows_per_block + my_row) % param.stride;
+  const index_t base_n = (blockIdx.x * param.rows_per_block + my_row) / param.stride;
+  const index_t base = base_x + param.stride * lead_dim * base_n;
+  if (base >= param.num_elements * param.total_rows) return;
+  using IType = AccType<InputType0>;
+  using OType = AccType<OutputType0>;
+  using AType = type_util::mixed_type<typename IType::type,
+                                      typename OType::type>;
+  __shared__ AType smem[kRTCMaxThreadsPerBlock];
+  AType max;
+  red::maximum::SetInitValue(max);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    max = op::max(max, negate ? -val : val);
+  }
+  smem[threadIdx.x] = max;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = op::max(smem[threadIdx.x], smem[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y)
+                                                         { return op::max(x, y); },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  AType smax = smem[my_row];
+  __syncthreads();
+
+  AType sum;
+  red::sum::SetInitValue(sum);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    val = negate ? -val : val;
+    sum += op::exp((val - smax) / static_cast<AType>(param.temperature));
+  }
+  smem[threadIdx.x] = sum;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = op::add(smem[threadIdx.x], smem[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y)
+                                                         { return op::add(x, y); },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  sum = smem[my_row];
+  __syncthreads();
+
+  OutputType0* output = reinterpret_cast<OutputType0*>(param.outputs[0]);
+  for (index_t i = my_id; i < lead_dim; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    val = negate ? -val : val;
+    val = (i < len) ? OP((val - smax)/static_cast<AType>(param.temperature), sum) : 0;
+    if (req == OpReqType::kAddTo) {
+      if (i < len) {
+        output[base + i * param.stride] = OType::to(val +
+                                                    OType::from(output[base + i * param.stride]));
+      }
+    } else {
+      output[base + i * param.stride] = OType::to(val);
+    }
+  }
+}
+)code";
+
+const char softmax_stride1_kernel_fwd[] = R"code(
+__launch_bounds__(vector::vectorized_kernel_thread_num)
+__global__ void softmax_stride1_compute_kernel(const softmax_params param,
+                                               const index_t total_length,
+                                               const index_t other_dim,
+                                               const index_t N,
+                                               const index_t num_aligned_elements) {
+  using namespace vector;
+  using IType = AccType<InputType0>;
+  using OType = AccType<OutputType0>;
+  using LengthType = AccType<InputType1>;
+  const InputType1* length = reinterpret_cast<const InputType1*>(param.inputs[1]);
+  using AType = type_util::mixed_type<typename IType::type,
+                                      typename OType::type>;
+  __shared__ AType scratch[vectorized_kernel_thread_num];
+  __shared__ AType persistent_storage[20 * 1024 / sizeof(AType)];
+  const int warp_size = 32;
+  const int threads_per_row = vectorized_kernel_thread_num / param.rows_per_block;
+  const int my_local_row = threadIdx.x / threads_per_row;
+  const int base_row = blockIdx.x * param.rows_per_block;
+  const int my_row = base_row + my_local_row;
+  const index_t len = (length == nullptr ||
+                       my_row >= param.total_rows) ? param.num_elements
+                                                   : LengthType::from(length[my_row]);
+  const int my_id = threadIdx.x % threads_per_row;
+
+  AType* row;
+  if (only_full_blocks || blockIdx.x < gridDim.x - 1) {
+    // full rows_per_block rows to compute
+    VectorizedLoader<InputType0, nvec, aligned> loader(
+      reinterpret_cast<const InputType0*>(param.inputs[0]) + base_row * param.num_elements,
+      total_length);
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      loader.load(i, total_length);
+#pragma unroll
+      for (int j = 0; j < nvec; ++j) {
+        persistent_storage[i*nvec + j] = IType::from(loader.separate()[j]);
+      }
+    }
+    row = persistent_storage + my_local_row * param.num_elements + loader.alignment();
+  } else {
+    // fewer than rows_per_block rows to compute
+    const index_t real_length = min(total_length,
+                                    (param.total_rows - base_row) * param.num_elements);
+    VectorizedLoader<InputType0, nvec, false> loader(
+      reinterpret_cast<const InputType0*>(param.inputs[0]) + base_row * param.num_elements,
+      real_length);
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      loader.load(i, real_length);
+#pragma unroll
+      for (int j = 0; j < nvec; ++j) {
+        persistent_storage[i*nvec + j] = IType::from(loader.separate()[j]);
+      }
+    }
+    row = persistent_storage + my_local_row * param.num_elements + loader.alignment();
+  }
+  __syncthreads();
+
+  AType my_max_value;
+  red::maximum::SetInitValue(my_max_value);
+
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    my_max_value = ::max(my_max_value, negate ? -row[i] : row[i]);
+  }
+  AType smax;
+  if (!reduction_inside_warp) {
+    scratch[threadIdx.x] = my_max_value;
+    __syncthreads();
+    for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+      if (my_id < size) {
+        scratch[threadIdx.x] = ::max(scratch[threadIdx.x], scratch[threadIdx.x + size]);
+      }
+      __syncthreads();
+    }
+    if (my_id < warp_size) {
+      AType my_value = util::grouped_warp_allreduce(scratch[threadIdx.x],
+                                                    [](AType x, AType y) { return op::max(x, y); },
+                                                    min(threads_per_row, warp_size));
+      scratch[threadIdx.x] = my_value;
+    }
+    __syncthreads();
+    smax = scratch[threadIdx.x - my_id];
+    __syncthreads();
+  } else {
+    smax = util::grouped_warp_allreduce(my_max_value,
+                                        [](AType x, AType y) { return op::max(x, y); },
+                                        threads_per_row);
+  }
+
+  AType my_sum;
+  red::sum::SetInitValue(my_sum);
+
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    const AType val = negate ? -row[i] : row[i];
+    my_sum += op::exp((val - smax) / static_cast<AType>(param.temperature));
+  }
+  AType ssum;
+  if (!reduction_inside_warp) {
+    scratch[threadIdx.x] = my_sum;
+    __syncthreads();
+    for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+      if (my_id < size) {
+        scratch[threadIdx.x] += scratch[threadIdx.x + size];
+      }
+      __syncthreads();
+    }
+    if (my_id < warp_size) {
+      AType my_value = util::grouped_warp_allreduce(scratch[threadIdx.x],
+                                                    [](AType x, AType y) { return x + y;},
+                                                    min(threads_per_row, warp_size));
+      scratch[threadIdx.x] = my_value;
+    }
+    __syncthreads();
+
+    ssum = scratch[threadIdx.x - my_id];
+    __syncthreads();
+  } else {
+      ssum = util::grouped_warp_allreduce(my_sum,
+                                          [](AType x, AType y) { return x + y;},
+                                          threads_per_row);
+  }
+
+  for (index_t i = my_id; i < param.num_elements; i += threads_per_row) {
+    const AType val = negate ? -row[i] : row[i];
+    row[i] = (i < len) ? OP((val - smax)/static_cast<AType>(param.temperature), ssum) :
+                         0;
+  }
+  __syncthreads();
+
+  if (only_full_blocks || blockIdx.x < gridDim.x - 1) {
+    VectorizedStorer<OutputType0, nvec, aligned> storer(
+      reinterpret_cast<OutputType0*>(param.outputs[0]) + base_row * param.num_elements,
+      total_length);
+
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      if (req == OpReqType::kAddTo) {
+        storer.load(i, total_length);
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(op::add(persistent_storage[i*nvec + j],
+                                                   OType::from(storer.separate()[j])));
+        }
+      } else {
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(persistent_storage[i*nvec + j]);
+        }
+      }
+      storer.store(i, total_length);
+    }
+  } else {
+    const index_t real_length = min(total_length,
+                                    (param.total_rows - base_row) * param.num_elements);
+    VectorizedStorer<OutputType0, nvec, false> storer(
+      reinterpret_cast<OutputType0*>(param.outputs[0]) + base_row * param.num_elements,
+      real_length);
+
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      if (req == OpReqType::kAddTo) {
+        storer.load(i, real_length);
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(op::add(persistent_storage[i*nvec + j],
+                                                   OType::from(storer.separate()[j])));
+        }
+      } else {
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(persistent_storage[i*nvec + j]);
+        }
+      }
+      storer.store(i, real_length);
+    }
+  }
+}
+)code";
+
+bool IsPower2(size_t N) {
+  return ((N & (N - 1)) == 0) && N != 0;
+}
+
+index_t RoundToPower2(index_t N) {
+  size_t ret = 1;
+  size_t copyN = N;
+  while (N >= 2) {
+    ret *= 2;
+    N /= 2;
+  }
+  if (ret < copyN) {
+    ret *= 2;
+  }
+  return ret;
+}
+
+int get_rows_per_block(const index_t row_size, const int nvec,
+                       const index_t max_storage, const int num_threads_per_block,
+                       const index_t total_rows, const int dev_id) {
+  CHECK(IsPower2(num_threads_per_block))
+    << "Number of threads in a block must be power of 2 to use get_rows_per_block function";
+  // How many read instructions should 1 thread at least do
+  const int read_instructions = 16;
+  const size_t row_size_in_vec = (row_size + nvec - 1) / nvec;
+  int desired_num_threads_per_row = (row_size_in_vec + read_instructions - 1) / read_instructions;
+  desired_num_threads_per_row = RoundToPower2(desired_num_threads_per_row);
+  desired_num_threads_per_row = std::min(desired_num_threads_per_row, num_threads_per_block);
+  const int desired_rows_per_block = num_threads_per_block / desired_num_threads_per_row;
+  int actual_rows_per_block = desired_rows_per_block;
+  int num_sms = MultiprocessorCount(dev_id);
+  while (actual_rows_per_block > 1 &&
+         ((max_storage != -1 && max_storage < row_size * actual_rows_per_block) ||
+          (total_rows + actual_rows_per_block - 1) / actual_rows_per_block < num_sms)) {
+    actual_rows_per_block /= 2;
+  }
+  return actual_rows_per_block;
+}
+
+}  // namespace
+
+void SoftmaxRTCCompute::operator()(const nnvm::NodeAttrs& attrs,
+                                   const OpContext& ctx,
+                                   const std::vector<TBlob>& inputs,
+                                   const std::vector<OpReqType>& req,
+                                   const std::vector<TBlob>& outputs) {
+  using namespace mxnet_op;
+  using common::mshadow_type_info;
+  using namespace common::cuda::rtc;
+  using common::div_round;
+  if (req[0] == kNullOp || inputs[0].Size() == 0U) return;
+  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+  int axis = CheckAxis(param.axis, inputs[0].ndim());
+  const double temperature = param.temperature.has_value() ?
+                             param.temperature.value() : 1.0;
+  mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
+
+  void* length_ptr = nullptr;
+  std::string length_typename = "int";
+  if (param.use_length.value()) {
+    CHECK(inputs.size() > 1)
+      << "Mask needs to be provided when using softmax with use_length=True.";
+    length_ptr = inputs[1].dptr_;
+    length_typename = mshadow_type_info(inputs[1].type_flag_).name;
+  }
+  CHECK_EQ(outputs.size(), 1);
+  index_t M = shape[axis];
+  if (M == 0 || shape.Size() == 0) return;
+  index_t stride = 1;
+  if (axis == shape.ndim() - 2) {
+    stride = shape[shape.ndim() - 1];
+  }
+  const index_t N = shape.Size() / M;
+  softmax_params params = {{inputs[0].dptr_, length_ptr, nullptr},
+                           {outputs[0].dptr_},
+                           stride, M,
+                           temperature, 1, N};
+  std::string code = "#define OP " + OP + "\n"
+                     "const OpReqType req = " + util::to_string(req[0]) + ";\n"
+                     "const bool negate = " + std::to_string(negate) + ";\n"
+                     "using InputType1 = " + length_typename + ";\n";
+  Stream<gpu>* s = ctx.get_stream<gpu>();
+
+  constexpr int nvec = 2;
+  // Using 20 kB of shared memory for persistent storage in the optimized case
+  const size_t acc_type_size = std::max(mshadow_type_info(inputs[0].type_flag_).acc_size,
+                                        mshadow_type_info(outputs[0].type_flag_).acc_size);
+  const size_t max_opt_M = 20 * 1024 / acc_type_size;
+  int rows_per_block = get_rows_per_block(M, nvec, max_opt_M,
+                                          vectorized_kernel_thread_num,
+                                          N, ctx.run_ctx.ctx.dev_id);
+  if (stride == 1 &&
+      static_cast<size_t>(M * rows_per_block) <= max_opt_M) {
+    const int warp_size = 32;

Review comment:
       I will use `mxnet::common::cuda::warp_size`, thanks
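The host-side `get_rows_per_block` heuristic quoted above decides how many rows one block handles. It does this in three steps:

1. Aim for about 16 vector loads per thread.
2. Round threads-per-row up to a power of 2.
3. Halve rows-per-block until the rows fit in the shared-memory budget and there are at least as many blocks as SMs.

Below is a standalone sketch of that logic under a few assumptions. `index_t` is taken to be a 64-bit signed integer, and the SM count is passed in as a parameter instead of calling `MultiprocessorCount(dev_id)`. `GetRowsPerBlockSketch` is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>

using index_t = int64_t;  // assumption: stand-in for MXNet's index_t

bool IsPower2(size_t n) { return n != 0 && (n & (n - 1)) == 0; }

// Smallest power of 2 that is >= n (returns 1 for n <= 1), same results as RoundToPower2 above.
index_t RoundToPower2(index_t n) {
  index_t ret = 1;
  while (ret < n) ret *= 2;
  return ret;
}

// Hypothetical host-only restatement of get_rows_per_block; num_sms is a parameter here.
int GetRowsPerBlockSketch(index_t row_size, int nvec, index_t max_storage,
                          int num_threads_per_block, index_t total_rows, int num_sms) {
  assert(IsPower2(num_threads_per_block) && "threads per block must be a power of 2");
  const int read_instructions = 16;  // minimum vector loads each thread should issue
  const index_t row_size_in_vec = (row_size + nvec - 1) / nvec;
  index_t threads_per_row = (row_size_in_vec + read_instructions - 1) / read_instructions;
  threads_per_row = std::min<index_t>(RoundToPower2(threads_per_row), num_threads_per_block);
  int rows_per_block = static_cast<int>(num_threads_per_block / threads_per_row);
  // Shrink until the rows fit in shared memory (max_storage == -1 means "no limit")
  // and the grid has at least one block per SM.
  while (rows_per_block > 1 &&
         ((max_storage != -1 && max_storage < row_size * rows_per_block) ||
          (total_rows + rows_per_block - 1) / rows_per_block < num_sms)) {
    rows_per_block /= 2;
  }
  return rows_per_block;
}
```

Consider the BERT-like shape from the PR description (12 * 32 * 128 rows of length 128) with `nvec = 2` and a float32 budget of 20 * 1024 / 4 = 5120 elements. Assuming 512 threads per block and 80 SMs, the loop halves 128 -> 64 -> 32 rows per block. 32 is the first value where 32 * 128 = 4096 elements fits the budget.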




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-mxnet] ptrendx commented on a change in pull request #19905: [PERF] Moving GPU softmax to RTC and optimizations

Posted by GitBox <gi...@apache.org>.
ptrendx commented on a change in pull request #19905:
URL: https://github.com/apache/incubator-mxnet/pull/19905#discussion_r592659861



##########
File path: src/operator/nn/softmax.cu
##########
@@ -22,18 +22,809 @@
  * \file softmax.cu
  * \brief GPU Implementation of softmax
  */
+#include <string>
 #include "./softmax-inl.h"
-#include "../tensor/elemwise_unary_op.h"
+#include "../../common/cuda/utils.h"
+#include "../../common/cuda/rtc.h"
+#include "../../common/cuda/rtc/vectorization-inl.h"
 
 namespace mxnet {
 namespace op {
 
+namespace {
+
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+const char softmax_common_functions[] = R"code(
+struct softmax_params {

Review comment:
       Not really, unfortunately (or at least not with the current build system). To do that, we would need to be able to import files both as actual headers and as strings (or, alternatively, ship those headers inside the package and have NVRTC include them during compilation). It would be nice to have, though, I agree.
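The thread above asks whether `softmax_params` could be defined only once instead of twice. One possible workaround is a variadic macro that emits a declaration for the host compiler and also stringifies the same tokens into a source string for NVRTC. This PR does not use it, and it has the limits noted in the comments. Below is a minimal sketch, where `DEFINE_WITH_RTC_SOURCE` is a hypothetical name and `index_t` is stubbed as `int64_t`.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

using index_t = int64_t;  // assumption: stand-in for MXNet's index_t

// Hypothetical helper: expands to the declaration itself (for the host compiler)
// followed by a char array holding the same tokens as text (to prepend to NVRTC code).
// Limits: comments and formatting are lost, no #include or other directives may
// appear inside, and every type named (here index_t) must also exist in the RTC preamble.
#define DEFINE_WITH_RTC_SOURCE(name, ...) \
  __VA_ARGS__                             \
  static const char name[] = #__VA_ARGS__;

DEFINE_WITH_RTC_SOURCE(softmax_params_source,
struct softmax_params {
  const void* inputs[3];
  void* outputs[1];
  index_t stride;
  index_t num_elements;
  double temperature;
  int rows_per_block;
  index_t total_rows;
};
)
```

This only covers self-contained declarations. Headers with their own includes would still need the packaging changes described in the comment.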







[GitHub] [incubator-mxnet] ptrendx commented on pull request #19905: [PERF] Moving GPU softmax to RTC and optimizations

Posted by GitBox <gi...@apache.org>.
ptrendx commented on pull request #19905:
URL: https://github.com/apache/incubator-mxnet/pull/19905#issuecomment-826030710


   @mxnet-bot run ci [centos-gpu, unix-cpu]





[GitHub] [incubator-mxnet] ptrendx commented on a change in pull request #19905: [PERF] Moving GPU softmax to RTC and optimizations

Posted by GitBox <gi...@apache.org>.
ptrendx commented on a change in pull request #19905:
URL: https://github.com/apache/incubator-mxnet/pull/19905#discussion_r592660682



##########
File path: src/operator/nn/softmax.cu
##########
@@ -22,18 +22,809 @@
  * \file softmax.cu
  * \brief GPU Implementation of softmax
  */
+#include <string>
 #include "./softmax-inl.h"
-#include "../tensor/elemwise_unary_op.h"
+#include "../../common/cuda/utils.h"
+#include "../../common/cuda/rtc.h"
+#include "../../common/cuda/rtc/vectorization-inl.h"
 
 namespace mxnet {
 namespace op {
 
+namespace {
+
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+const char softmax_common_functions[] = R"code(
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+softmax_fwd(const DType a, const DType2 b) {
+  return op::exp(a) / b;
+}
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+log_softmax_fwd(const DType a, const DType2 b) {
+  return a - op::log(b);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+    return out * (ograd - sum);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+log_softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+    return ograd - op::exp(out) * sum;
+}
+
+)code";
+
+const char simple_softmax_kernel_fwd[] = R"code(
+__launch_bounds__(kRTCMaxThreadsPerBlock)
+__global__ void simple_softmax_kernel(const softmax_params param,
+                                      const index_t lead_dim) {
+  using LengthType = AccType<InputType1>;
+  const InputType0* input = reinterpret_cast<const InputType0*>(param.inputs[0]);
+  const InputType1* length = reinterpret_cast<const InputType1*>(param.inputs[1]);
+  const index_t len = length == nullptr
+                      ? lead_dim
+                      : static_cast<index_t>(LengthType::from(length[blockIdx.x]));
+  const int my_row = threadIdx.x % param.rows_per_block;
+  const int my_id = threadIdx.x / param.rows_per_block;
+  const int threads_per_row = blockDim.x / param.rows_per_block;
+  const index_t base_x = (blockIdx.x * param.rows_per_block + my_row) % param.stride;
+  const index_t base_n = (blockIdx.x * param.rows_per_block + my_row) / param.stride;
+  const index_t base = base_x + param.stride * lead_dim * base_n;
+  if (base >= param.num_elements * param.total_rows) return;
+  using IType = AccType<InputType0>;
+  using OType = AccType<OutputType0>;
+  using AType = type_util::mixed_type<typename IType::type,
+                                      typename OType::type>;
+  __shared__ AType smem[kRTCMaxThreadsPerBlock];
+  AType max;
+  red::maximum::SetInitValue(max);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    max = op::max(max, negate ? -val : val);
+  }
+  smem[threadIdx.x] = max;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = op::max(smem[threadIdx.x], smem[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y)
+                                                         { return op::max(x, y); },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  AType smax = smem[my_row];
+  __syncthreads();
+
+  AType sum;
+  red::sum::SetInitValue(sum);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    val = negate ? -val :val;
+    sum += op::exp((val - smax) / static_cast<AType>(param.temperature));
+  }
+  smem[threadIdx.x] = sum;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = op::add(smem[threadIdx.x], smem[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y)
+                                                         { return op::add(x, y); },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  sum = smem[my_row];
+  __syncthreads();
+
+  OutputType0* output = reinterpret_cast<OutputType0*>(param.outputs[0]);
+  for (index_t i = my_id; i < lead_dim; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    val = negate ? -val : val;
+    val = (i < len) ? OP((val - smax)/static_cast<AType>(param.temperature), sum) : 0;
+    if (req == OpReqType::kAddTo) {
+      if (i < len) {
+        output[base + i * param.stride] = OType::to(val +
+                                                    OType::from(output[base + i * param.stride]));
+      }
+    } else {
+      output[base + i * param.stride] = OType::to(val);
+    }
+  }
+}
+)code";
+
+const char softmax_stride1_kernel_fwd[] = R"code(
+__launch_bounds__(vector::vectorized_kernel_thread_num)
+__global__ void softmax_stride1_compute_kernel(const softmax_params param,
+                                               const index_t total_length,
+                                               const index_t other_dim,
+                                               const index_t N,
+                                               const index_t num_aligned_elements) {
+  using namespace vector;
+  using IType = AccType<InputType0>;
+  using OType = AccType<OutputType0>;
+  using LengthType = AccType<InputType1>;
+  const InputType1* length = reinterpret_cast<const InputType1*>(param.inputs[1]);
+  using AType = type_util::mixed_type<typename IType::type,
+                                      typename OType::type>;
+  __shared__ AType scratch[vectorized_kernel_thread_num];
+  __shared__ AType persistent_storage[20 * 1024 / sizeof(AType)];
+  const int warp_size = 32;

Review comment:
       Should be, I will remove that.
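The stride-1 forward kernel quoted above makes three passes over each row held in shared memory:

1. A max reduction.
2. A sum of `exp((val - smax) / temperature)`.
3. A final pass that applies `OP` (`softmax_fwd` or `log_softmax_fwd`) and writes 0 past the row's `length`.

Below is a single-threaded CPU sketch of the same arithmetic that could serve as a reference when testing the kernel. `ReferenceSoftmax` is a hypothetical helper. It assumes a row-major `[num_rows, M]` float tensor and non-negative lengths.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical reference for the stride-1 forward pass; `lengths` may be empty (no masking).
std::vector<float> ReferenceSoftmax(const std::vector<float>& x, size_t num_rows, size_t M,
                                    const std::vector<int>& lengths, double temperature,
                                    bool log_softmax, bool negate) {
  assert(x.size() == num_rows * M);
  assert(lengths.empty() || lengths.size() == num_rows);
  const float t = static_cast<float>(temperature);
  std::vector<float> out(x.size(), 0.0f);  // masked positions stay 0, as in the kernel
  for (size_t r = 0; r < num_rows; ++r) {
    const size_t len = lengths.empty() ? M : std::min<size_t>(M, lengths[r]);
    const float* row = x.data() + r * M;
    auto val = [&](size_t i) { return negate ? -row[i] : row[i]; };
    // Pass 1: row maximum; subtracting it keeps exp() from overflowing.
    float smax = -std::numeric_limits<float>::infinity();
    for (size_t i = 0; i < len; ++i) smax = std::max(smax, val(i));
    // Pass 2: denominator, sum of exp((v - max) / T).
    float ssum = 0.0f;
    for (size_t i = 0; i < len; ++i) ssum += std::exp((val(i) - smax) / t);
    // Pass 3: softmax_fwd(a, s) = exp(a) / s, or log_softmax_fwd(a, s) = a - log(s).
    for (size_t i = 0; i < len; ++i) {
      const float a = (val(i) - smax) / t;
      out[r * M + i] = log_softmax ? a - std::log(ssum) : std::exp(a) / ssum;
    }
  }
  return out;
}
```

Because the maximum of `x / T` equals `max(x) / T` for `T > 0`, the kernel can subtract `smax` before dividing by the temperature without changing the result.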







[GitHub] [incubator-mxnet] MoisesHer commented on pull request #19905: [PERF] Moving GPU softmax to RTC and optimizations

Posted by GitBox <gi...@apache.org>.
MoisesHer commented on pull request #19905:
URL: https://github.com/apache/incubator-mxnet/pull/19905#issuecomment-826969970


   Thanks @ptrendx, looks good to me.





[GitHub] [incubator-mxnet] mxnet-bot commented on pull request #19905: [PERF] Moving GPU softmax to RTC and optimizations

Posted by GitBox <gi...@apache.org>.
mxnet-bot commented on pull request #19905:
URL: https://github.com/apache/incubator-mxnet/pull/19905#issuecomment-826501754


   Jenkins CI successfully triggered : [unix-cpu]





[GitHub] [incubator-mxnet] ptrendx commented on a change in pull request #19905: [PERF] Moving GPU softmax to RTC and optimizations

Posted by GitBox <gi...@apache.org>.
ptrendx commented on a change in pull request #19905:
URL: https://github.com/apache/incubator-mxnet/pull/19905#discussion_r618735053



##########
File path: src/operator/nn/softmax.cu
##########
@@ -22,18 +22,809 @@
  * \file softmax.cu
  * \brief GPU Implementation of softmax
  */
+#include <string>
 #include "./softmax-inl.h"
-#include "../tensor/elemwise_unary_op.h"
+#include "../../common/cuda/utils.h"
+#include "../../common/cuda/rtc.h"
+#include "../../common/cuda/rtc/vectorization-inl.h"
 
 namespace mxnet {
 namespace op {
 
+namespace {
+
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+const char softmax_common_functions[] = R"code(
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+softmax_fwd(const DType a, const DType2 b) {
+  return op::exp(a) / b;
+}
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+log_softmax_fwd(const DType a, const DType2 b) {
+  return a - op::log(b);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+    return out * (ograd - sum);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+log_softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+    return ograd - op::exp(out) * sum;
+}
+
+)code";
+
+const char simple_softmax_kernel_fwd[] = R"code(
+__launch_bounds__(kRTCMaxThreadsPerBlock)
+__global__ void simple_softmax_kernel(const softmax_params param,
+                                      const index_t lead_dim) {
+  using LengthType = AccType<InputType1>;
+  const InputType0* input = reinterpret_cast<const InputType0*>(param.inputs[0]);
+  const InputType1* length = reinterpret_cast<const InputType1*>(param.inputs[1]);
+  const index_t len = length == nullptr
+                      ? lead_dim
+                      : static_cast<index_t>(LengthType::from(length[blockIdx.x]));
+  const int my_row = threadIdx.x % param.rows_per_block;
+  const int my_id = threadIdx.x / param.rows_per_block;
+  const int threads_per_row = blockDim.x / param.rows_per_block;
+  const index_t base_x = (blockIdx.x * param.rows_per_block + my_row) % param.stride;
+  const index_t base_n = (blockIdx.x * param.rows_per_block + my_row) / param.stride;
+  const index_t base = base_x + param.stride * lead_dim * base_n;
+  if (base >= param.num_elements * param.total_rows) return;
+  using IType = AccType<InputType0>;
+  using OType = AccType<OutputType0>;
+  using AType = type_util::mixed_type<typename IType::type,
+                                      typename OType::type>;
+  __shared__ AType smem[kRTCMaxThreadsPerBlock];
+  AType max;
+  red::maximum::SetInitValue(max);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    max = op::max(max, negate ? -val : val);
+  }
+  smem[threadIdx.x] = max;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = op::max(smem[threadIdx.x], smem[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y)
+                                                         { return op::max(x, y); },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  AType smax = smem[my_row];
+  __syncthreads();
+
+  AType sum;
+  red::sum::SetInitValue(sum);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    val = negate ? -val :val;
+    sum += op::exp((val - smax) / static_cast<AType>(param.temperature));
+  }
+  smem[threadIdx.x] = sum;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = op::add(smem[threadIdx.x], smem[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y)
+                                                         { return op::add(x, y); },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  sum = smem[my_row];
+  __syncthreads();
+
+  OutputType0* output = reinterpret_cast<OutputType0*>(param.outputs[0]);
+  for (index_t i = my_id; i < lead_dim; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    val = negate ? -val : val;
+    val = (i < len) ? OP((val - smax)/static_cast<AType>(param.temperature), sum) : 0;
+    if (req == OpReqType::kAddTo) {
+      if (i < len) {
+        output[base + i * param.stride] = OType::to(val +
+                                                    OType::from(output[base + i * param.stride]));
+      }
+    } else {
+      output[base + i * param.stride] = OType::to(val);
+    }
+  }
+}
+)code";
+
+const char softmax_stride1_kernel_fwd[] = R"code(
+__launch_bounds__(vector::vectorized_kernel_thread_num)
+__global__ void softmax_stride1_compute_kernel(const softmax_params param,
+                                               const index_t total_length,
+                                               const index_t other_dim,
+                                               const index_t N,
+                                               const index_t num_aligned_elements) {
+  using namespace vector;
+  using IType = AccType<InputType0>;
+  using OType = AccType<OutputType0>;
+  using LengthType = AccType<InputType1>;
+  const InputType1* length = reinterpret_cast<const InputType1*>(param.inputs[1]);
+  using AType = type_util::mixed_type<typename IType::type,
+                                      typename OType::type>;
+  __shared__ AType scratch[vectorized_kernel_thread_num];
+  __shared__ AType persistent_storage[20 * 1024 / sizeof(AType)];
+  const int warp_size = 32;
+  const int threads_per_row = vectorized_kernel_thread_num / param.rows_per_block;
+  const int my_local_row = threadIdx.x / threads_per_row;
+  const int base_row = blockIdx.x * param.rows_per_block;
+  const int my_row = base_row + my_local_row;
+  const index_t len = (length == nullptr ||
+                       my_row >= param.total_rows) ? param.num_elements
+                                                   : LengthType::from(length[my_row]);
+  const int my_id = threadIdx.x % threads_per_row;
+
+  AType* row;
+  if (only_full_blocks || blockIdx.x < gridDim.x - 1) {
+    // full rows_per_block rows to compute
+    VectorizedLoader<InputType0, nvec, aligned> loader(
+      reinterpret_cast<const InputType0*>(param.inputs[0]) + base_row * param.num_elements,
+      total_length);
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      loader.load(i, total_length);
+#pragma unroll
+      for (int j = 0; j < nvec; ++j) {
+        persistent_storage[i*nvec + j] = IType::from(loader.separate()[j]);
+      }
+    }
+    row = persistent_storage + my_local_row * param.num_elements + loader.alignment();
+  } else {
+    // less than rows_per_block rows to compute
+    const index_t real_length = min(total_length,
+                                    (param.total_rows - base_row) * param.num_elements);
+    VectorizedLoader<InputType0, nvec, false> loader(
+      reinterpret_cast<const InputType0*>(param.inputs[0]) + base_row * param.num_elements,
+      real_length);
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      loader.load(i, real_length);
+#pragma unroll
+      for (int j = 0; j < nvec; ++j) {
+        persistent_storage[i*nvec + j] = IType::from(loader.separate()[j]);
+      }
+    }
+    row = persistent_storage + my_local_row * param.num_elements + loader.alignment();
+  }
+  __syncthreads();
+
+  AType my_max_value;
+  red::maximum::SetInitValue(my_max_value);
+
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    my_max_value = ::max(my_max_value, negate ? -row[i] : row[i]);
+  }
+  AType smax;
+  if (!reduction_inside_warp) {
+    scratch[threadIdx.x] = my_max_value;
+    __syncthreads();
+    for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+      if (my_id < size) {
+        scratch[threadIdx.x] = ::max(scratch[threadIdx.x], scratch[threadIdx.x + size]);
+      }
+      __syncthreads();
+    }
+    if (my_id < warp_size) {
+      AType my_value = util::grouped_warp_allreduce(scratch[threadIdx.x],
+                                                    [](AType x, AType y) { return op::max(x, y); },
+                                                    min(threads_per_row, warp_size));
+      scratch[threadIdx.x] = my_value;
+    }
+    __syncthreads();
+    smax = scratch[threadIdx.x - my_id];
+    __syncthreads();
+  } else {
+    smax = util::grouped_warp_allreduce(my_max_value,
+                                        [](AType x, AType y) { return op::max(x, y); },
+                                        threads_per_row);
+  }
+
+  AType my_sum;
+  red::sum::SetInitValue(my_sum);
+
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    const AType val = negate ? -row[i] : row[i];
+    my_sum += op::exp((val - smax) / static_cast<AType>(param.temperature));
+  }
+  AType ssum;
+  if (!reduction_inside_warp) {
+    scratch[threadIdx.x] = my_sum;
+    __syncthreads();
+    for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+      if (my_id < size) {
+        scratch[threadIdx.x] += scratch[threadIdx.x + size];
+      }
+      __syncthreads();
+    }
+    if (my_id < warp_size) {
+      AType my_value = util::grouped_warp_allreduce(scratch[threadIdx.x],
+                                                    [](AType x, AType y) { return x + y;},
+                                                    min(threads_per_row, warp_size));
+      scratch[threadIdx.x] = my_value;
+    }
+    __syncthreads();
+
+    ssum = scratch[threadIdx.x - my_id];
+    __syncthreads();
+  } else {
+      ssum = util::grouped_warp_allreduce(my_sum,
+                                          [](AType x, AType y) { return x + y;},
+                                          threads_per_row);
+  }
+
+  for (index_t i = my_id; i < param.num_elements; i += threads_per_row) {
+    const AType val = negate ? -row[i] : row[i];
+    row[i] = (i < len) ? OP((val - smax)/static_cast<AType>(param.temperature), ssum) :
+                         0;
+  }
+  __syncthreads();
+
+  if (only_full_blocks || blockIdx.x < gridDim.x - 1) {
+    VectorizedStorer<OutputType0, nvec, aligned> storer(
+      reinterpret_cast<OutputType0*>(param.outputs[0]) + base_row * param.num_elements,
+      total_length);
+
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      if (req == OpReqType::kAddTo) {
+        storer.load(i, total_length);
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(op::add(persistent_storage[i*nvec + j],
+                                                   OType::from(storer.separate()[j])));
+        }
+      } else {
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(persistent_storage[i*nvec + j]);
+        }
+      }
+      storer.store(i, total_length);
+    }
+  } else {
+    const index_t real_length = min(total_length,
+                                    (param.total_rows - base_row) * param.num_elements);
+    VectorizedStorer<OutputType0, nvec, false> storer(
+      reinterpret_cast<OutputType0*>(param.outputs[0]) + base_row * param.num_elements,
+      real_length);
+
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      if (req == OpReqType::kAddTo) {
+        storer.load(i, real_length);
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(op::add(persistent_storage[i*nvec + j],
+                                                   OType::from(storer.separate()[j])));
+        }
+      } else {
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(persistent_storage[i*nvec + j]);
+        }
+      }
+      storer.store(i, real_length);
+    }
+  }
+}
+)code";
+
+bool IsPower2(size_t N) {

Review comment:
       Moved to common/utils.h
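One property the halving loops in the quoted hunk depend on is that `threads_per_row` is a power of two: each row's partials are combined by repeatedly halving `size` from `threads_per_row / 2`, first in shared memory and then via `util::grouped_warp_allreduce`. The host-side simulation below is only a sketch. The names `is_power2` and `row_sums` are hypothetical (not MXNet code), and both reduction phases are collapsed into one halving loop. It uses the same row layout as the stride1 kernel (`my_id = threadIdx.x % threads_per_row`) and shows that the loop misses lanes when `threads_per_row` is not a power of two.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Same bit trick as IsPower2 in the hunk: a power of two has exactly one set
// bit, and n & (n - 1) clears the lowest set bit.
bool is_power2(std::size_t n) { return n != 0 && (n & (n - 1)) == 0; }

// Hypothetical simulation of the per-row tree reduction. Slot t of `scratch`
// belongs to "thread" t; rows are contiguous groups of threads_per_row slots.
// At each step the lower half of every group adds in the upper half, so for a
// power-of-two group size slot 0 of each group ends up holding the row sum.
// (Sequential order is safe: slot t + size is never written in the same step.)
std::vector<float> row_sums(std::vector<float> scratch, std::size_t threads_per_row) {
  assert(threads_per_row > 0 && scratch.size() % threads_per_row == 0);
  for (std::size_t size = threads_per_row / 2; size >= 1; size /= 2) {
    for (std::size_t t = 0; t < scratch.size(); ++t) {
      const std::size_t my_id = t % threads_per_row;
      if (my_id < size) scratch[t] += scratch[t + size];
    }
  }
  std::vector<float> sums;
  for (std::size_t t = 0; t < scratch.size(); t += threads_per_row) {
    sums.push_back(scratch[t]);  // like reading scratch[threadIdx.x - my_id]
  }
  return sums;
}
```

With 8 slots and `threads_per_row = 4` this yields one sum per row. With `threads_per_row = 6` the loop runs for `size = 3` and `size = 1` only, so two of the six lanes never reach slot 0.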

##########
File path: src/operator/nn/softmax.cu
##########
@@ -22,18 +22,809 @@
  * \file softmax.cu
  * \brief GPU Implementation of softmax
  */
+#include <string>
 #include "./softmax-inl.h"
-#include "../tensor/elemwise_unary_op.h"
+#include "../../common/cuda/utils.h"
+#include "../../common/cuda/rtc.h"
+#include "../../common/cuda/rtc/vectorization-inl.h"
 
 namespace mxnet {
 namespace op {
 
+namespace {
+
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+const char softmax_common_functions[] = R"code(
+struct softmax_params {
+  const void* inputs[3];
+  void* outputs[1];
+  index_t stride;
+  index_t num_elements;
+  double temperature;
+  int rows_per_block;
+  index_t total_rows;
+};
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+softmax_fwd(const DType a, const DType2 b) {
+  return op::exp(a) / b;
+}
+
+template <typename DType, typename DType2>
+__device__ inline type_util::mixed_type<DType, DType2>
+log_softmax_fwd(const DType a, const DType2 b) {
+  return a - op::log(b);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+  return out * (ograd - sum);
+}
+
+template <typename DType, typename DType2, typename DType3>
+__device__ inline type_util::mixed_type<DType, DType2, DType3>
+log_softmax_bwd(DType ograd, DType2 out, DType3 sum) {
+  return ograd - op::exp(out) * sum;
+}
+
+)code";
+
+const char simple_softmax_kernel_fwd[] = R"code(
+__launch_bounds__(kRTCMaxThreadsPerBlock)
+__global__ void simple_softmax_kernel(const softmax_params param,
+                                      const index_t lead_dim) {
+  using LengthType = AccType<InputType1>;
+  const InputType0* input = reinterpret_cast<const InputType0*>(param.inputs[0]);
+  const InputType1* length = reinterpret_cast<const InputType1*>(param.inputs[1]);
+  const index_t len = length == nullptr
+                      ? lead_dim
+                      : static_cast<index_t>(LengthType::from(length[blockIdx.x]));
+  const int my_row = threadIdx.x % param.rows_per_block;
+  const int my_id = threadIdx.x / param.rows_per_block;
+  const int threads_per_row = blockDim.x / param.rows_per_block;
+  const index_t base_x = (blockIdx.x * param.rows_per_block + my_row) % param.stride;
+  const index_t base_n = (blockIdx.x * param.rows_per_block + my_row) / param.stride;
+  const index_t base = base_x + param.stride * lead_dim * base_n;
+  if (base >= param.num_elements * param.total_rows) return;
+  using IType = AccType<InputType0>;
+  using OType = AccType<OutputType0>;
+  using AType = type_util::mixed_type<typename IType::type,
+                                      typename OType::type>;
+  __shared__ AType smem[kRTCMaxThreadsPerBlock];
+  AType max;
+  red::maximum::SetInitValue(max);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    max = op::max(max, negate ? -val : val);
+  }
+  smem[threadIdx.x] = max;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = op::max(smem[threadIdx.x], smem[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y)
+                                                         { return op::max(x, y); },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  AType smax = smem[my_row];
+  __syncthreads();
+
+  AType sum;
+  red::sum::SetInitValue(sum);
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    val = negate ? -val : val;
+    sum += op::exp((val - smax) / static_cast<AType>(param.temperature));
+  }
+  smem[threadIdx.x] = sum;
+  __syncthreads();
+  for (int size = blockDim.x / 2; size >= warp_size; size /= 2) {
+    if (threadIdx.x < size) {
+      smem[threadIdx.x] = op::add(smem[threadIdx.x], smem[threadIdx.x + size]);
+    }
+    __syncthreads();
+  }
+  if (threadIdx.x < warp_size) {
+    AType my_value = util::strided_grouped_warp_reduce(smem[threadIdx.x],
+                                                       [](AType x, AType y)
+                                                         { return op::add(x, y); },
+                                                       param.rows_per_block);
+    smem[threadIdx.x] = my_value;
+  }
+  __syncthreads();
+  sum = smem[my_row];
+  __syncthreads();
+
+  OutputType0* output = reinterpret_cast<OutputType0*>(param.outputs[0]);
+  for (index_t i = my_id; i < lead_dim; i += threads_per_row) {
+    auto val = IType::from(input[base + i * param.stride]);
+    val = negate ? -val : val;
+    val = (i < len) ? OP((val - smax)/static_cast<AType>(param.temperature), sum) : 0;
+    if (req == OpReqType::kAddTo) {
+      if (i < len) {
+        output[base + i * param.stride] = OType::to(val +
+                                                    OType::from(output[base + i * param.stride]));
+      }
+    } else {
+      output[base + i * param.stride] = OType::to(val);
+    }
+  }
+}
+)code";
+
+const char softmax_stride1_kernel_fwd[] = R"code(
+__launch_bounds__(vector::vectorized_kernel_thread_num)
+__global__ void softmax_stride1_compute_kernel(const softmax_params param,
+                                               const index_t total_length,
+                                               const index_t other_dim,
+                                               const index_t N,
+                                               const index_t num_aligned_elements) {
+  using namespace vector;
+  using IType = AccType<InputType0>;
+  using OType = AccType<OutputType0>;
+  using LengthType = AccType<InputType1>;
+  const InputType1* length = reinterpret_cast<const InputType1*>(param.inputs[1]);
+  using AType = type_util::mixed_type<typename IType::type,
+                                      typename OType::type>;
+  __shared__ AType scratch[vectorized_kernel_thread_num];
+  __shared__ AType persistent_storage[20 * 1024 / sizeof(AType)];
+  const int warp_size = 32;
+  const int threads_per_row = vectorized_kernel_thread_num / param.rows_per_block;
+  const int my_local_row = threadIdx.x / threads_per_row;
+  const int base_row = blockIdx.x * param.rows_per_block;
+  const int my_row = base_row + my_local_row;
+  const index_t len = (length == nullptr ||
+                       my_row >= param.total_rows) ? param.num_elements
+                                                   : LengthType::from(length[my_row]);
+  const int my_id = threadIdx.x % threads_per_row;
+
+  AType* row;
+  if (only_full_blocks || blockIdx.x < gridDim.x - 1) {
+    // full rows_per_block rows to compute
+    VectorizedLoader<InputType0, nvec, aligned> loader(
+      reinterpret_cast<const InputType0*>(param.inputs[0]) + base_row * param.num_elements,
+      total_length);
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      loader.load(i, total_length);
+#pragma unroll
+      for (int j = 0; j < nvec; ++j) {
+        persistent_storage[i*nvec + j] = IType::from(loader.separate()[j]);
+      }
+    }
+    row = persistent_storage + my_local_row * param.num_elements + loader.alignment();
+  } else {
+    // less than rows_per_block rows to compute
+    const index_t real_length = min(total_length,
+                                    (param.total_rows - base_row) * param.num_elements);
+    VectorizedLoader<InputType0, nvec, false> loader(
+      reinterpret_cast<const InputType0*>(param.inputs[0]) + base_row * param.num_elements,
+      real_length);
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      loader.load(i, real_length);
+#pragma unroll
+      for (int j = 0; j < nvec; ++j) {
+        persistent_storage[i*nvec + j] = IType::from(loader.separate()[j]);
+      }
+    }
+    row = persistent_storage + my_local_row * param.num_elements + loader.alignment();
+  }
+  __syncthreads();
+
+  AType my_max_value;
+  red::maximum::SetInitValue(my_max_value);
+
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    my_max_value = ::max(my_max_value, negate ? -row[i] : row[i]);
+  }
+  AType smax;
+  if (!reduction_inside_warp) {
+    scratch[threadIdx.x] = my_max_value;
+    __syncthreads();
+    for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+      if (my_id < size) {
+        scratch[threadIdx.x] = ::max(scratch[threadIdx.x], scratch[threadIdx.x + size]);
+      }
+      __syncthreads();
+    }
+    if (my_id < warp_size) {
+      AType my_value = util::grouped_warp_allreduce(scratch[threadIdx.x],
+                                                    [](AType x, AType y) { return op::max(x, y); },
+                                                    min(threads_per_row, warp_size));
+      scratch[threadIdx.x] = my_value;
+    }
+    __syncthreads();
+    smax = scratch[threadIdx.x - my_id];
+    __syncthreads();
+  } else {
+    smax = util::grouped_warp_allreduce(my_max_value,
+                                        [](AType x, AType y) { return op::max(x, y); },
+                                        threads_per_row);
+  }
+
+  AType my_sum;
+  red::sum::SetInitValue(my_sum);
+
+  for (index_t i = my_id; i < len; i += threads_per_row) {
+    const AType val = negate ? -row[i] : row[i];
+    my_sum += op::exp((val - smax) / static_cast<AType>(param.temperature));
+  }
+  AType ssum;
+  if (!reduction_inside_warp) {
+    scratch[threadIdx.x] = my_sum;
+    __syncthreads();
+    for (int size = threads_per_row / 2; size >= warp_size; size /= 2) {
+      if (my_id < size) {
+        scratch[threadIdx.x] += scratch[threadIdx.x + size];
+      }
+      __syncthreads();
+    }
+    if (my_id < warp_size) {
+      AType my_value = util::grouped_warp_allreduce(scratch[threadIdx.x],
+                                                    [](AType x, AType y) { return x + y;},
+                                                    min(threads_per_row, warp_size));
+      scratch[threadIdx.x] = my_value;
+    }
+    __syncthreads();
+
+    ssum = scratch[threadIdx.x - my_id];
+    __syncthreads();
+  } else {
+    ssum = util::grouped_warp_allreduce(my_sum,
+                                        [](AType x, AType y) { return x + y;},
+                                        threads_per_row);
+  }
+
+  for (index_t i = my_id; i < param.num_elements; i += threads_per_row) {
+    const AType val = negate ? -row[i] : row[i];
+    row[i] = (i < len) ? OP((val - smax)/static_cast<AType>(param.temperature), ssum) :
+                         0;
+  }
+  __syncthreads();
+
+  if (only_full_blocks || blockIdx.x < gridDim.x - 1) {
+    VectorizedStorer<OutputType0, nvec, aligned> storer(
+      reinterpret_cast<OutputType0*>(param.outputs[0]) + base_row * param.num_elements,
+      total_length);
+
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      if (req == OpReqType::kAddTo) {
+        storer.load(i, total_length);
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(op::add(persistent_storage[i*nvec + j],
+                                                   OType::from(storer.separate()[j])));
+        }
+      } else {
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(persistent_storage[i*nvec + j]);
+        }
+      }
+      storer.store(i, total_length);
+    }
+  } else {
+    const index_t real_length = min(total_length,
+                                    (param.total_rows - base_row) * param.num_elements);
+    VectorizedStorer<OutputType0, nvec, false> storer(
+      reinterpret_cast<OutputType0*>(param.outputs[0]) + base_row * param.num_elements,
+      real_length);
+
+    for (index_t i = threadIdx.x; i < num_aligned_elements; i += blockDim.x) {
+      if (req == OpReqType::kAddTo) {
+        storer.load(i, real_length);
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(op::add(persistent_storage[i*nvec + j],
+                                                   OType::from(storer.separate()[j])));
+        }
+      } else {
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          storer.separate()[j] = OType::to(persistent_storage[i*nvec + j]);
+        }
+      }
+      storer.store(i, real_length);
+    }
+  }
+}
+)code";
+
+bool IsPower2(size_t N) {
+  return ((N & (N - 1)) == 0) && N != 0;
+}
+
+index_t RoundToPower2(index_t N) {

Review comment:
       Moved to common/utils.h
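A host-side reference can be useful for checking the two forward kernels quoted above against each other. The sketch below assumes a single contiguous row with `len <= x.size()`, and `softmax_row` is a hypothetical name. It follows the same three passes as the kernels: a max over the first `len` elements, a sum of `exp((x - smax) / temperature)`, and then `OP`, where `OP` is either `softmax_fwd` or `log_softmax_fwd` from `softmax_common_functions`. Like the kernels in the non-`kAddTo` case, it writes 0 at positions `>= len`.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical reference for one row of temperature-scaled (log-)softmax.
std::vector<double> softmax_row(const std::vector<double>& x, std::size_t len,
                                double temperature, bool negate, bool log_variant) {
  assert(len <= x.size() && temperature > 0.0);
  std::vector<double> out(x.size(), 0.0);  // masked positions stay 0
  // Pass 1: row maximum (the kernels start from red::maximum's init value).
  double smax = -std::numeric_limits<double>::infinity();
  for (std::size_t i = 0; i < len; ++i) smax = std::max(smax, negate ? -x[i] : x[i]);
  // Pass 2: sum of shifted, temperature-scaled exponentials.
  double ssum = 0.0;
  for (std::size_t i = 0; i < len; ++i) {
    ssum += std::exp(((negate ? -x[i] : x[i]) - smax) / temperature);
  }
  // Pass 3: OP(a, ssum), i.e. exp(a) / ssum or a - log(ssum).
  for (std::size_t i = 0; i < len; ++i) {
    const double a = ((negate ? -x[i] : x[i]) - smax) / temperature;
    out[i] = log_variant ? a - std::log(ssum) : std::exp(a) / ssum;
  }
  return out;
}
```

Subtracting `smax` before exponentiating keeps large inputs (e.g. `{1000, 1000}`) from overflowing, and it does not change the result.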




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-mxnet] ptrendx commented on pull request #19905: [PERF] Moving GPU softmax to RTC and optimizations

Posted by GitBox <gi...@apache.org>.
ptrendx commented on pull request #19905:
URL: https://github.com/apache/incubator-mxnet/pull/19905#issuecomment-825204848


   About the vectorization being independent of RTC: generally I agree with you, and the first approach to vectorization actually predates RTC. The problem, however, was that it produced a large number of kernels, which bloated the library size and increased GPU memory usage (see PR #17767 and then issue https://github.com/apache/incubator-mxnet/issues/18280). That is why I reintroduced it as part of the RTC effort, so that only the kernels that are actually needed get compiled.
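The idea of compiling only the needed kernels can be sketched as a lazily populated cache. This is a hypothetical illustration, not MXNet's RTC interface. `LazyKernelCache`, `CompiledKernel` and `kernel_key` are made-up names, and the `compile` callback stands in for the real runtime-compilation step. An ahead-of-time build has to instantiate every combination of template parameters (types, `nvec`, `aligned`, and so on) up front. Here, a configuration is compiled only the first time it is requested.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>

// Placeholder for a compiled, loaded kernel handle.
struct CompiledKernel {
  std::string source_key;
};

class LazyKernelCache {
 public:
  // `compile` is a stand-in for the runtime compiler invocation.
  explicit LazyKernelCache(std::function<CompiledKernel(const std::string&)> compile)
      : compile_(std::move(compile)) {}

  // Returns the kernel for `key`, compiling it only on first request.
  const CompiledKernel& get(const std::string& key) {
    assert(!key.empty());
    auto it = cache_.find(key);
    if (it == cache_.end()) {
      ++compilations_;
      it = cache_.emplace(key, compile_(key)).first;
    }
    return it->second;  // references into unordered_map survive rehashing
  }

  std::size_t compilations() const { return compilations_; }

 private:
  std::function<CompiledKernel(const std::string&)> compile_;
  std::unordered_map<std::string, CompiledKernel> cache_;
  std::size_t compilations_ = 0;
};

// Hypothetical key: one entry per (input type, output type, nvec, aligned).
std::string kernel_key(const std::string& itype, const std::string& otype,
                       int nvec, bool aligned) {
  return itype + "," + otype + "," + std::to_string(nvec) +
         (aligned ? ",aligned" : ",unaligned");
}
```

Requesting the same configuration twice compiles it once, so the number of kernels built tracks the shapes and types a workload actually uses, not the size of the full template cross product.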





[GitHub] [incubator-mxnet] mxnet-bot commented on pull request #19905: [PERF] Moving GPU softmax to RTC and optimizations

Posted by GitBox <gi...@apache.org>.
mxnet-bot commented on pull request #19905:
URL: https://github.com/apache/incubator-mxnet/pull/19905#issuecomment-780195345


   Hey @ptrendx, thanks for submitting the PR.
   All tests are already queued to run once. If tests fail, you can trigger one or more tests again with the following commands: 
   - To trigger all jobs: @mxnet-bot run ci [all] 
   - To trigger specific jobs: @mxnet-bot run ci [job1, job2] 
   *** 
   **CI supported jobs**: [unix-cpu, centos-cpu, unix-gpu, windows-gpu, clang, website, centos-gpu, edge, miscellaneous, sanity, windows-cpu]
   *** 
   _Note_: 
    Only the following 3 categories can trigger CI: PR Author, MXNet Committer, Jenkins Admin.
   All CI tests must pass before the PR can be merged. 
   





[GitHub] [incubator-mxnet] ptrendx commented on pull request #19905: [PERF] Moving GPU softmax to RTC and optimizations

Posted by GitBox <gi...@apache.org>.
ptrendx commented on pull request #19905:
URL: https://github.com/apache/incubator-mxnet/pull/19905#issuecomment-826501721


   @mxnet-bot run ci [unix-cpu]





[GitHub] [incubator-mxnet] mxnet-bot commented on pull request #19905: [PERF] Moving GPU softmax to RTC and optimizations

Posted by GitBox <gi...@apache.org>.
mxnet-bot commented on pull request #19905:
URL: https://github.com/apache/incubator-mxnet/pull/19905#issuecomment-826501754


   Jenkins CI successfully triggered : [unix-cpu]





[GitHub] [incubator-mxnet] mxnet-bot commented on pull request #19905: [PERF] Moving GPU softmax to RTC and optimizations

Posted by GitBox <gi...@apache.org>.
mxnet-bot commented on pull request #19905:
URL: https://github.com/apache/incubator-mxnet/pull/19905#issuecomment-826030716


   Jenkins CI successfully triggered : [centos-gpu, unix-cpu]





[GitHub] [incubator-mxnet] MoisesHer merged pull request #19905: [PERF] Moving GPU softmax to RTC and optimizations

Posted by GitBox <gi...@apache.org>.
MoisesHer merged pull request #19905:
URL: https://github.com/apache/incubator-mxnet/pull/19905


   





[GitHub] [incubator-mxnet] ptrendx commented on a change in pull request #19905: [PERF] Moving GPU softmax to RTC and optimizations

Posted by GitBox <gi...@apache.org>.
ptrendx commented on a change in pull request #19905:
URL: https://github.com/apache/incubator-mxnet/pull/19905#discussion_r592661320



##########
File path: src/operator/nn/softmax.cu
##########
@@ -22,18 +22,809 @@
  * \file softmax.cu
  * \brief GPU Implementation of softmax
  */
+#include <string>
 #include "./softmax-inl.h"
-#include "../tensor/elemwise_unary_op.h"
+#include "../../common/cuda/utils.h"
+#include "../../common/cuda/rtc.h"
+#include "../../common/cuda/rtc/vectorization-inl.h"
 
 namespace mxnet {
 namespace op {
 
+namespace {
+
+struct softmax_params {
+  const void* inputs[3];

Review comment:
       This lets fwd and bwd share the same struct (bwd uses 3 inputs).
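For reference, the backward helpers in `softmax_common_functions` (`softmax_bwd`, `log_softmax_bwd`) take the output gradient, the forward output, and a per-row sum. The host-side sketch below uses the hypothetical name `softmax_row_bwd` and fixes the temperature at 1. It assumes the per-row sum is the standard one for each gradient: `sum_j g_j * y_j` for softmax and `sum_j g_j` for log-softmax. The kernel that computes that sum is not part of this hunk.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical reference for one row of the backward pass (temperature = 1):
//   softmax:     igrad_i = y_i * (g_i - s),      s = sum_j g_j * y_j
//   log-softmax: igrad_i = g_i - exp(y_i) * s,   s = sum_j g_j
// where y is the forward output and g the output gradient.
std::vector<double> softmax_row_bwd(const std::vector<double>& ograd,
                                    const std::vector<double>& out, bool log_variant) {
  assert(ograd.size() == out.size());
  double s = 0.0;
  for (std::size_t j = 0; j < out.size(); ++j) {
    s += log_variant ? ograd[j] : ograd[j] * out[j];
  }
  std::vector<double> igrad(out.size());
  for (std::size_t i = 0; i < out.size(); ++i) {
    igrad[i] = log_variant ? ograd[i] - std::exp(out[i]) * s  // log_softmax_bwd
                           : out[i] * (ograd[i] - s);         // softmax_bwd
  }
  return igrad;
}
```

A uniform output gradient through softmax gives a zero input gradient because the outputs sum to 1. A one-hot gradient through log-softmax gives the familiar `onehot - probabilities`.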



