You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/07/17 16:32:46 UTC

[GitHub] [incubator-mxnet] haojin2 commented on a change in pull request #15545: Softmax fwd optimization for GPU

haojin2 commented on a change in pull request #15545: Softmax fwd optimization for GPU
URL: https://github.com/apache/incubator-mxnet/pull/15545#discussion_r304520971
 
 

 ##########
 File path: src/operator/nn/softmax-inl.h
 ##########
 @@ -218,6 +219,157 @@ __global__ void softmax_compute_kernel(DType *in, OType *out, index_t M, int axi
   }
 }
 
+const int softmax_threads_per_block = 512;
+
+template <typename OP, typename T>
+__device__ inline T warp_reduce(T value, OP redfun) {
+  value = redfun(value, __shfl_down_sync(0xffffffff, value, 16));
+  value = redfun(value, __shfl_down_sync(0xffffffff, value, 8));
+  value = redfun(value, __shfl_down_sync(0xffffffff, value, 4));
+  value = redfun(value, __shfl_down_sync(0xffffffff, value, 2));
+  value = redfun(value, __shfl_down_sync(0xffffffff, value, 1));
+  return value;
+}
+
+template <typename OP>
+__device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, OP redfun) {
+  float v = static_cast<float>(value);
+  v = redfun(v, __shfl_down_sync(0xffffffff, v, 16));
+  v = redfun(v, __shfl_down_sync(0xffffffff, v, 8));
+  v = redfun(v, __shfl_down_sync(0xffffffff, v, 4));
+  v = redfun(v, __shfl_down_sync(0xffffffff, v, 2));
+  v = redfun(v, __shfl_down_sync(0xffffffff, v, 1));
+  return mshadow::half::half_t(v);
+}
+
+template<typename OP, bool negate, typename AType, typename LType,
+  typename DType, typename OType>
+__global__ void softmax_compute_kernel2(const DType *in, OType *out, const index_t M,
 
 Review comment:
  +1 to Haibin's comment. Maybe rename `softmax_compute_kernel2` to something like `softmax_stride_1_compute_kernel`?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to this message, please log in to GitHub and
use the URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services