Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/07/19 20:22:57 UTC

[GitHub] [incubator-mxnet] ptrendx commented on a change in pull request #15545: Softmax fwd optimization for GPU

ptrendx commented on a change in pull request #15545: Softmax fwd optimization for GPU
URL: https://github.com/apache/incubator-mxnet/pull/15545#discussion_r305512405
 
 

 ##########
 File path: src/operator/nn/softmax-inl.h
 ##########
 @@ -218,6 +219,157 @@ __global__ void softmax_compute_kernel(DType *in, OType *out, index_t M, int axi
   }
 }
 
+const int softmax_threads_per_block = 512;
+
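+// Reduce 'value' across the 32 lanes of a warp with shuffle-down steps of
+// 16, 8, 4, 2, 1; after the final step lane 0 holds the fully reduced value.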
+template <typename OP, typename T>
+__device__ inline T warp_reduce(T value, OP redfun) {
+  value = redfun(value, __shfl_down_sync(0xffffffff, value, 16));
+  value = redfun(value, __shfl_down_sync(0xffffffff, value, 8));
+  value = redfun(value, __shfl_down_sync(0xffffffff, value, 4));
+  value = redfun(value, __shfl_down_sync(0xffffffff, value, 2));
+  value = redfun(value, __shfl_down_sync(0xffffffff, value, 1));
+  return value;
+}
+
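+// Half-precision overload: promote to float so the shuffle intrinsic
+// operates on a natively supported type and the reduction keeps extra
+// precision, then convert back to half on return.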
+template <typename OP>
+__device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, OP redfun) {
+  float v = static_cast<float>(value);
+  v = redfun(v, __shfl_down_sync(0xffffffff, v, 16));
+  v = redfun(v, __shfl_down_sync(0xffffffff, v, 8));
+  v = redfun(v, __shfl_down_sync(0xffffffff, v, 4));
+  v = redfun(v, __shfl_down_sync(0xffffffff, v, 2));
+  v = redfun(v, __shfl_down_sync(0xffffffff, v, 1));
+  return mshadow::half::half_t(v);
+}
+
+template<typename OP, bool negate, typename AType, typename LType,
+  typename DType, typename OType>
+__global__ void softmax_compute_kernel2(const DType *in, OType *out, const index_t M,
+                                       const double temperature, int rows_per_block,
+                                       const index_t total_rows) {
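+  // Shared memory: 'scratch' backs cross-warp reductions within the block;
+  // 'persistent_storage' is a 20 kB buffer for staging the block's row data.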
+  __shared__ AType scratch[softmax_threads_per_block];
+  __shared__ LType persistent_storage[20*1024 / sizeof(LType)];
+  const int warp_size = 32;
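+  // Each block covers 'rows_per_block' rows, with 'threads_per_row' threads
+  // cooperating on each row; 'my_row' is this thread's row and 'my_id' its
+  // position within that row's thread group.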
+  const int threads_per_row = softmax_threads_per_block / rows_per_block;
+  const int my_local_row = threadIdx.x / threads_per_row;
+  const int my_row = blockIdx.x * rows_per_block + my_local_row;
+  if (my_row >= total_rows) return;
+  const int my_id = threadIdx.x % threads_per_row;
+  const int entries_per_load = sizeof(LType)/sizeof(DType);
+  // Due to the use of the MSHADOW_TYPE_SWITCH macro we also generate
+  // kernel instantiations where sizeof(LType) is less than
+  // sizeof(DType), which makes entries_per_load 0. That combination
+  // is invalid and the launcher code guards against it; the check
+  // here only silences the division-by-zero warning those
+  // instantiations would otherwise trigger.
+  const int row_length = entries_per_load > 0 ? M / entries_per_load : 0;
+
+  const LType * in_aligned = reinterpret_cast<const LType *>(in);
 
 Review comment:
   @sxjscience That is why the code that launches this kernel chooses LType based on the array dimensions: if the leading dimension is odd, it will not choose an LType larger than DType.
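
For concreteness, here is a minimal sketch of the selection rule described above. The helper name, signature, and candidate widths are illustrative assumptions rather than the PR's actual launcher code (which dispatches through MSHADOW_TYPE_SWITCH): pick the widest load type whose element count evenly divides the leading dimension M.

    #include <cstdint>

    // Hypothetical helper mirroring the rule described above: return the
    // widest vector-load width (in bytes) whose element count divides M.
    template <typename DType>
    int choose_load_type_bytes(const int64_t M) {
      const int candidates[] = {8, 4, 2, 1};  // candidate widths, widest first
      for (const int bytes : candidates) {
        const int entries_per_load = bytes / static_cast<int>(sizeof(DType));
        // Skip widths narrower than DType (entries_per_load == 0) and widths
        // whose element count does not evenly divide the row length. For an
        // odd M every entries_per_load > 1 fails the divisibility test, so
        // the chosen LType never exceeds DType in that case.
        if (entries_per_load >= 1 && M % entries_per_load == 0) {
          return bytes;
        }
      }
      return static_cast<int>(sizeof(DType));
    }

The chosen width would then presumably be mapped to a concrete LType (for example, 8 bytes to uint64_t) before instantiating softmax_compute_kernel2.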
