You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/07/19 07:48:18 UTC

[GitHub] [incubator-mxnet] apeforest commented on a change in pull request #15545: Softmax fwd optimization for GPU

apeforest commented on a change in pull request #15545: Softmax fwd optimization for GPU
URL: https://github.com/apache/incubator-mxnet/pull/15545#discussion_r305238144
 
 

 ##########
 File path: src/operator/nn/softmax-inl.h
 ##########
 @@ -218,6 +219,157 @@ __global__ void softmax_compute_kernel(DType *in, OType *out, index_t M, int axi
   }
 }
 
+const int softmax_threads_per_block = 512;
+
+template <typename OP, typename T>
+__device__ inline T warp_reduce(T value, OP redfun) {
+  value = redfun(value, __shfl_down_sync(0xffffffff, value, 16));
+  value = redfun(value, __shfl_down_sync(0xffffffff, value, 8));
+  value = redfun(value, __shfl_down_sync(0xffffffff, value, 4));
+  value = redfun(value, __shfl_down_sync(0xffffffff, value, 2));
+  value = redfun(value, __shfl_down_sync(0xffffffff, value, 1));
+  return value;
+}
+
+template <typename OP>
+__device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, OP redfun) {
+  float v = static_cast<float>(value);
+  v = redfun(v, __shfl_down_sync(0xffffffff, v, 16));
+  v = redfun(v, __shfl_down_sync(0xffffffff, v, 8));
+  v = redfun(v, __shfl_down_sync(0xffffffff, v, 4));
+  v = redfun(v, __shfl_down_sync(0xffffffff, v, 2));
+  v = redfun(v, __shfl_down_sync(0xffffffff, v, 1));
+  return mshadow::half::half_t(v);
+}
+
+template<typename OP, bool negate, typename AType, typename LType,
+  typename DType, typename OType>
+__global__ void softmax_compute_kernel2(const DType *in, OType *out, const index_t M,
+                                       const double temperature, int rows_per_block,
+                                       const index_t total_rows) {
+  __shared__ AType scratch[softmax_threads_per_block];
+  __shared__ LType persistent_storage[20*1024 / sizeof(LType)];
 
 Review comment:
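   nit: please add spaces between the '*' operator and its operands (e.g. `20 * 1024 / sizeof(LType)` instead of `20*1024 / sizeof(LType)`).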

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services