You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2020/08/18 06:09:21 UTC

[incubator-mxnet] branch master updated: Faster GPU frozen BatchNorm (#17368)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new e06ee4e  Faster GPU frozen BatchNorm (#17368)
e06ee4e is described below

commit e06ee4e1a725aed62a66b525036676500010e7f0
Author: Przemyslaw Tredak <pt...@nvidia.com>
AuthorDate: Mon Aug 17 23:07:37 2020 -0700

    Faster GPU frozen BatchNorm (#17368)
    
    * Better frozen batchnorm
    
    * Continue FreezeBN
    
    * Optimizations
    
    * Reduce number of mod operations
    
    * Cleaning
    
    * Fixing frozen bn with fix_gamma=False
    
    * Fix lint in BN
    
    * Backward frozen batchnorm
    
    * More work on backward of Frozen BN
    
    * Let it compile
    
    * NCHW Frozen BN backward
    
    * Frozen BN backward NHWC
    
    * Cleaning
    
    * Remove the change to Makefile
    
    * Fix from rebase
    
    * Temp space for BN backward
    
    * Fix from review
    
    * Fix lint
    
    * Changes from review
---
 src/common/cuda_utils.h       |  83 ++++++-
 src/operator/nn/batch_norm.cc |   2 -
 src/operator/nn/batch_norm.cu | 563 ++++++++++++++++++++++++++++++++++++------
 src/operator/nn/softmax-inl.h |   6 +-
 4 files changed, 561 insertions(+), 93 deletions(-)

diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h
index 0971cfd..22ac42c 100644
--- a/src/common/cuda_utils.h
+++ b/src/common/cuda_utils.h
@@ -783,27 +783,109 @@ __device__ inline DType ldg(const DType* address) {
 #endif
 }
 
-template <typename OP, typename T>
+namespace mxnet {
+namespace common {
+/*! \brief common utils for cuda */
+namespace cuda {
+
+static constexpr const int warp_size = 32;
+
+/*! \brief Reduction inside a warp.
+ * Must be called by all 32 lanes of the warp (full shuffle mask). Only lane 0
+ * is guaranteed to hold the complete result. When NVALUES is not a power of
+ * two, lanes up to the next power of two also take part, so the caller has
+ * to fill lanes [NVALUES, next power of two) with the identity of redfun.
+ * Template parameters:
+ * NVALUES - number of values to reduce (defaults to warp_size).
+ * \param value - values to be reduced.
+ * \param redfun - function used to perform reduction.
+ */
+template <int NVALUES = warp_size, typename OP, typename T>
 __device__ inline T warp_reduce(T value, OP redfun) {
-  value = redfun(value, __shfl_down_sync(0xffffffff, value, 16));
-  value = redfun(value, __shfl_down_sync(0xffffffff, value, 8));
-  value = redfun(value, __shfl_down_sync(0xffffffff, value, 4));
-  value = redfun(value, __shfl_down_sync(0xffffffff, value, 2));
-  value = redfun(value, __shfl_down_sync(0xffffffff, value, 1));
+#pragma unroll
+  for (int i = warp_size / 2; i >= 1; i /= 2) {
+    if (NVALUES > i) value = redfun(value, __shfl_down_sync(0xffffffff, value, i));
+  }
   return value;
 }
 
-template <typename OP>
+/*! \brief half_t overload of warp_reduce: the reduction is carried out in float
+ *         and the result converted back; lane requirements match the generic one.
+ */
+template <int NValues = warp_size, typename OP>
 __device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, OP redfun) {
   float v = static_cast<float>(value);
-  v = redfun(v, __shfl_down_sync(0xffffffff, v, 16));
-  v = redfun(v, __shfl_down_sync(0xffffffff, v, 8));
-  v = redfun(v, __shfl_down_sync(0xffffffff, v, 4));
-  v = redfun(v, __shfl_down_sync(0xffffffff, v, 2));
-  v = redfun(v, __shfl_down_sync(0xffffffff, v, 1));
+#pragma unroll
+  for (int i = warp_size / 2; i >= 1; i /= 2) {
+    if (NValues > i) v = redfun(v, __shfl_down_sync(0xffffffff, v, i));
+  }
   return mshadow::half::half_t(v);
 }
 
+/*! \brief Reduction inside a block, requires all threads in a block to participate.
+ *         It uses a 2 step approach:
+ *          - all warps in a block perform intermediate reduction
+ *          - first warp reduces the intermediate results.
+ *         It contains __syncthreads() and reuses one __shared__ scratch buffer
+ *         across calls, so it must not be called from divergent code.
+ * Template parameters:
+ * NTHREADS - number of threads in the (1D) block; must equal blockDim.x, be a
+ *            multiple of warp_size and not exceed warp_size * warp_size.
+ * all_reduce - whether all threads need the result of the reduction. If set to
+ *              true, then all threads return with the same value. If set to
+ *              false, then only thread 0 has the valid result. Defaults to true.
+ * \param value - value from each thread to be reduced
+ * \param redfun - function used to perform reduction. Unused slots of the
+ *                 second step are padded with T(0), so when NTHREADS / warp_size
+ *                 is not a power of two, 0 must be the identity of redfun
+ *                 (true for a sum, not for e.g. max over negative values).
+ */
+template <int NTHREADS, bool all_reduce = true, typename OP, typename T>
+__device__ inline T reduce(const T& value, OP redfun) {
+  static_assert(NTHREADS % warp_size == 0,
+                "Number of threads must be a multiple of the warp size");
+  static_assert(NTHREADS <= warp_size * warp_size,
+                "Number of threads too large for reduction");
+  constexpr int num_warps = NTHREADS / warp_size;
+  __shared__ T scratch[num_warps];
+  const int thread_idx_in_warp = threadIdx.x % warp_size;
+  const int warp_id = threadIdx.x / warp_size;
+  // Step 1: each warp reduces its values; lane 0 publishes the warp's result.
+  const T warp_val = warp_reduce<warp_size>(value, redfun);
+  if (thread_idx_in_warp == 0) {
+    scratch[warp_id] = warp_val;
+  }
+  __syncthreads();
+  T ret = 0;
+  if (warp_id == 0) {
+    // Step 2: the first warp combines the per-warp results.
+    const T prev_val = threadIdx.x < num_warps ? scratch[threadIdx.x] : 0;
+    const T block_val = warp_reduce<num_warps>(prev_val, redfun);
+    if (all_reduce) {
+      // Only lane 0 holds the complete result and scratch has just num_warps
+      // entries, so a single thread publishes it; letting every lane write
+      // scratch[threadIdx.x] would overrun the buffer when num_warps < warp_size.
+      if (threadIdx.x == 0) {
+        scratch[0] = block_val;
+      }
+    } else {
+      ret = block_val;
+    }
+  }
+  // Necessary to synchronize in order to use this function again
+  // as the shared memory scratch space is reused between calls
+  __syncthreads();
+  if (all_reduce) {
+    ret = scratch[0];
+    __syncthreads();
+  }
+  return ret;
+}
+
+}  // namespace cuda
+}  // namespace common
+}  // namespace mxnet
+
 #endif  // __CUDACC__
 
 #endif  // MXNET_COMMON_CUDA_UTILS_H_
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 2fdd31e..2a91a37 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -662,11 +662,9 @@ NNVM_REGISTER_OP(_backward_BatchNorm)
 })
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)
-#if MXNET_USE_MKLDNN == 1
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
 })
-#endif
 .set_attr_parser(ParamParser<BatchNormParam>)
 #if MXNET_USE_MKLDNN == 1
 .set_attr<bool>("TIsMKLDNN", true)
diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu
index c7e991f..72e4a76 100644
--- a/src/operator/nn/batch_norm.cu
+++ b/src/operator/nn/batch_norm.cu
@@ -27,6 +27,8 @@
 #include <cuda_runtime_api.h>
 #include <algorithm>
 #include "batch_norm-inl.h"
+#include "../../common/cuda_utils.h"
+
 
 #define WRITE_DATA_FLAG       1
 #define WRITE_GAMMA_FLAG      2
@@ -47,9 +49,30 @@
 
 using namespace mxnet;
 
+namespace {
+
 /*! \brief inverse standard deviation <-> variance */
-#define VARIANCE_TO_INVSTD(__var$,    __eps$)   (1.0/sqrt((__var$) + DType(__eps$)))
-#define INVSTD_TO_VARIANCE(__invstd$, __eps$)   ((1.0 / ((__invstd$) * (__invstd$))) - (__eps$))
+template <typename DType, typename AccReal>  // invstd = 1 / sqrt(var + eps)
+MSHADOW_XINLINE AccReal variance_to_invstd(DType var, AccReal eps) {
+  return rsqrtf(static_cast<AccReal>(var) + eps);  // rsqrtf works in float precision
+}
+// Both arguments double: rsqrt keeps the whole computation in double.
+template <>
+MSHADOW_XINLINE double variance_to_invstd(double var, double eps) {
+  return rsqrt(var + eps);
+}
+// Inverse of variance_to_invstd: var = 1 / invstd^2 - eps.
+template <typename AccReal>
+MSHADOW_XINLINE AccReal invstd_to_variance(AccReal invstd, AccReal eps) {
+  return static_cast<AccReal>(1.0) / (invstd * invstd) - eps;
+}
+// double specialization of invstd_to_variance (same formula).
+template <>
+MSHADOW_XINLINE double invstd_to_variance(double invstd, double eps) {
+  return 1.0 / (invstd * invstd) - eps;
+}
+
+}  // namespace
 
 namespace mxnet {
 namespace op {
@@ -206,41 +229,131 @@ static __device__ T reduce(Op op, DeviceTensor tensor, int plane) {
   return shared[0];
 }
 
-template <typename DType, typename AccReal, typename DeviceTensor1, typename DeviceTensor>
+namespace {
+  constexpr int inference_forward_threads = 512;
+  constexpr int shmem_elements = 1536;
+}  // namespace
+
+/*! \brief Inference (frozen statistics) BatchNorm forward:
+ *         output = gamma * (input - runningMean) * invstd + beta,
+ *         invstd = 1 / sqrt(runningVar + epsilon).
+ * The flattened (outer_size, num_channels, inner_size) tensor is traversed
+ * with grid-stride loops, nvec = sizeof(LType) / sizeof(DType) elements per
+ * access. Elements not covered by whole vectors (the tail, or everything when
+ * input/output are not aligned to sizeof(LType)) are processed one by one.
+ * saveMean / saveInvStd (and weight = 1 when gamma is fixed and WRITE_GAMMA is
+ * requested) are written for every channel, independently of the data size.
+ * Template parameters:
+ * small_num_channels - cache per-channel parameters in shared memory; only
+ *                      valid when num_channels <= shmem_elements.
+ * \param size - element count; callers may round it up to a multiple of nvec,
+ *               so it is clamped to outer_size * num_channels * inner_size to
+ *               keep every access inside the tensor.
+ * \param weight, bias - may be nullptr (then gamma = 1 / beta = 0); gamma is
+ *                       also 1 when FIX_GAMMA_FLAG is set.
+ */
+template <typename DType, typename AType, typename LType, bool small_num_channels>
+__launch_bounds__(inference_forward_threads)
 __global__ void BatchNormalizationUpdateOutputInferenceKernel(
-  DeviceTensor input,
-  DeviceTensor output,
-  DeviceTensor1 runningMean,
-  DeviceTensor1 runningVar,
-  DeviceTensor1 saveMean,
-  DeviceTensor1 saveInvStd,
-  DeviceTensor1 weight,
-  DeviceTensor1 bias,
-  const DType epsilon,
+  const DType* input,
+  DType* output,
+  const index_t size,
+  const index_t outer_size,
+  const index_t num_channels,
+  const index_t inner_size,
+  const AType* runningMean,
+  const AType* runningVar,
+  AType* saveMean,
+  AType* saveInvStd,
+  AType* weight,
+  AType* bias,
+  const AType epsilon,
   const uint32_t flags) {
-  int plane = blockIdx.x;
-
-  AccReal invstd = VARIANCE_TO_INVSTD(runningVar[plane], epsilon);
-  AccReal mean = ScalarConvert<DType, AccReal>::to(runningMean[plane]);
-  AccReal gamma = ((flags & FIX_GAMMA_FLAG) == 0 && weight.numElements() > 0)
-                  ? ScalarConvert<DType, AccReal>::to(weight[plane])
-                  : ScalarConvert<int, AccReal>::to(1);
-  AccReal beta = bias.numElements() > 0 ? ScalarConvert<DType, AccReal>::to(bias[plane])
-                                        : ScalarConvert<int, AccReal>::to(0);
-  if (threadIdx.x == 0) {
-    saveMean[plane] = runningMean[plane];
-    saveInvStd[plane] = VARIANCE_TO_INVSTD(runningVar[plane], epsilon);
-    if ((flags & WRITE_GAMMA_FLAG) != 0 && (flags & FIX_GAMMA_FLAG) != 0
-        && weight.numElements() > 0) {
-      weight[plane] = AccReal(1);
+  constexpr int nvec = sizeof(LType) / sizeof(DType);
+  __shared__ AType saved_invstd[shmem_elements];
+  __shared__ AType saved_mean[shmem_elements];
+  __shared__ AType saved_weight[shmem_elements];
+  __shared__ AType saved_bias[shmem_elements];
+  union vectorized_loader {
+    LType aligned;
+    DType separate[nvec];  // NOLINT(*)
+
+    __device__ inline vectorized_loader() {}
+    __device__ inline ~vectorized_loader() {}
+  } scratch;
+
+  // Cache the per-channel parameters (fix_gamma / missing weight and missing
+  // bias already folded in) so the data loops read them from shared memory.
+  if (small_num_channels) {
+    for (int i = threadIdx.x; i < num_channels; i += blockDim.x) {
+      saved_invstd[i] = variance_to_invstd(runningVar[i], epsilon);
+      saved_mean[i] = runningMean[i];
+      saved_weight[i] = (weight != nullptr && (flags & FIX_GAMMA_FLAG) == 0)
+                        ? weight[i]
+                        : 1;
+      saved_bias[i] = (bias != nullptr) ? bias[i] : 0;
     }
+    __syncthreads();
   }
-  // Write normalized and update the output
-  for (int batch = 0, nbatch = input.OuterSize(); batch < nbatch; ++batch) {
-    for (int x = threadIdx.x, nx = input.InnerSize(); x < nx; x += blockDim.x) {
-      const DType inp = input.get_ref(batch, plane, x);
-      output.get_ref(batch, plane, x) =
-        ScalarConvert<AccReal, DType>::to(gamma * (inp - mean) * invstd + beta);
+
+  // y = gamma * (x - mean) * invstd + beta for one element of `channel`.
+  auto normalize = [&](const AType x, const index_t channel) -> AType {
+    const AType invstd = small_num_channels
+                         ? saved_invstd[channel]
+                         : variance_to_invstd(runningVar[channel], epsilon);
+    const AType mean = small_num_channels ? saved_mean[channel] : runningMean[channel];
+    const AType gamma = small_num_channels
+                        ? saved_weight[channel]
+                        : ((weight != nullptr && (flags & FIX_GAMMA_FLAG) == 0)
+                           ? weight[channel]
+                           : 1);
+    const AType beta = small_num_channels
+                       ? saved_bias[channel]
+                       : ((bias != nullptr) ? bias[channel] : 0);
+    return gamma * (x - mean) * invstd + beta;
+  };
+
+  const index_t tid = threadIdx.x + static_cast<index_t>(blockIdx.x) * blockDim.x;
+  const index_t stride = static_cast<index_t>(blockDim.x) * gridDim.x;
+  const index_t elems_per_outer = inner_size * num_channels;
+  const index_t total = min(size, outer_size * elems_per_outer);
+  // Vector accesses require both buffers to be aligned to sizeof(LType);
+  // otherwise every element goes through the element-wise loop below.
+  const bool can_vectorize = reinterpret_cast<uintptr_t>(input) % sizeof(LType) == 0 &&
+                             reinterpret_cast<uintptr_t>(output) % sizeof(LType) == 0;
+  const index_t num_vectors = can_vectorize ? total / nvec : 0;
+  const LType* input_aligned = reinterpret_cast<const LType*>(input);
+  LType* output_aligned = reinterpret_cast<LType*>(output);
+  for (index_t i = tid; i < num_vectors; i += stride) {
+    scratch.aligned = input_aligned[i];
+    // Offset of the vector's first element inside its outer slice; the vector
+    // may cross into the next slice, hence the extra modulo below.
+    const index_t my_channel_base = (nvec * i) % elems_per_outer;
+#pragma unroll
+    for (int j = 0; j < nvec; ++j) {
+      index_t my_channel = (my_channel_base + j) / inner_size;
+      if (my_channel >= num_channels) my_channel = my_channel % num_channels;
+      scratch.separate[j] = normalize(static_cast<AType>(scratch.separate[j]), my_channel);
+    }
+    output_aligned[i] = scratch.aligned;
+  }
+
+  // Elements not covered by whole vectors: the tail of the tensor, or all of it
+  // when the buffers are misaligned.
+  for (index_t k = num_vectors * nvec + tid; k < total; k += stride) {
+    const index_t my_channel = (k / inner_size) % num_channels;
+    output[k] = normalize(static_cast<AType>(input[k]), my_channel);
+  }
+
+  // Saved statistics are written for every channel, independently of how many
+  // iterations the data loops had (there can be fewer vectors than channels,
+  // e.g. batch size 1 with a 1x1 spatial size).
+  for (index_t c = tid; c < num_channels; c += stride) {
+    saveMean[c] = runningMean[c];
+    saveInvStd[c] = variance_to_invstd(runningVar[c], epsilon);
+    if ((flags & WRITE_GAMMA_FLAG) != 0 && (flags & FIX_GAMMA_FLAG) != 0
+        && weight != nullptr) {
+      weight[c] = 1;
     }
   }
 }
@@ -312,6 +384,310 @@ struct CUDATensors {
   DeviceTensor1 saveInvStd;
 };
 
+namespace {
+  /*! \brief x / y rounded up; intended for x >= 0 and y > 0. */
+  inline int ceil_div(int x, int y) {
+    return (x + y - 1) / y;
+  }
+}  // namespace
+
+/*! \brief Frozen (global statistics) BatchNorm backward for channel-last
+ *         layouts, phase 1 of 2.
+ * Grid: blockIdx.x selects warp_size * nvec consecutive channels and blockIdx.y
+ * a contiguous chunk of the `outer` rows; the warps of a block stride over the
+ * rows of that chunk. Each thread handles nvec channels with LType-wide loads,
+ * so num_channels must be divisible by nvec.
+ * If flags request it, gradInput = gradOutput * gamma * invstd is written (or
+ * accumulated with ADDTO_DATA_FLAG). Per-(blockIdx.y, channel) partial sums of
+ * gradOutput * (input - mean) and of gradOutput are stored into temp_space,
+ * which must hold 2 * gridDim.y * num_channels elements; phase 2 finishes.
+ * NTHREADS must equal blockDim.x.
+ * \param weight - gamma, or nullptr to use gamma = 1.
+ */
+template<int NTHREADS, typename DType, typename AType, typename LType>
+__global__ void FrozenBatchNormalizationBackwardKernelCLastPhase1(
+    const DType* input, const DType* gradOutput, AType* temp_space,
+    DType* gradInput, const AType* weight, const AType* runningMean,
+    const AType* runningVar, const index_t outer, const index_t num_channels,
+    const AType eps, const uint32_t flags) {
+  using mxnet::common::cuda::warp_size;
+  constexpr int num_warps = NTHREADS / warp_size;
+  constexpr int nvec = sizeof(LType) >= sizeof(DType) ? sizeof(LType) / sizeof(DType) : 1;
+  static_assert(num_warps > 0 && (num_warps & (num_warps - 1)) == 0,
+                "The number of warps must be a power of two");
+  static_assert(NTHREADS >= warp_size * nvec,
+                "Not enough threads to write out the partial sums of a block");
+  const size_t stride = num_channels / nvec;
+
+  union vectorized_loader {
+    LType aligned;
+    DType separate[nvec];  // NOLINT(*)
+
+    __device__ inline vectorized_loader() {}
+    __device__ inline ~vectorized_loader() {}
+  };
+
+  vectorized_loader vec_input, vec_gradOutput;
+
+  // Per-thread partial sums are exchanged through shared memory: the first
+  // NTHREADS * nvec entries hold gamma sums, the next NTHREADS * nvec beta sums.
+  __shared__ AType scratch[NTHREADS * 2 * nvec];
+  AType * my_values_gamma = &(scratch[threadIdx.x * nvec]);
+  AType * my_values_beta = &(scratch[(NTHREADS + threadIdx.x) * nvec]);
+
+  AType sum_gamma[nvec];  // NOLINT(*)
+  AType sum_beta[nvec];  // NOLINT(*)
+#pragma unroll
+  for (int i = 0; i < nvec; ++i) {
+    sum_gamma[i] = 0;
+    sum_beta[i] = 0;
+  }
+
+  const size_t offset = blockIdx.x * warp_size;
+  const int my_warp = threadIdx.x / warp_size;
+  const int thread_idx_in_warp = threadIdx.x % warp_size;
+
+  AType invstd[nvec];  // NOLINT(*)
+  AType mean[nvec];  // NOLINT(*)
+  AType gamma[nvec];  // NOLINT(*)
+  size_t channel_offset = (offset + thread_idx_in_warp) * nvec;
+
+  if (channel_offset < num_channels) {
+#pragma unroll
+    for (int i = 0; i < nvec; ++i) {
+      invstd[i] = variance_to_invstd(runningVar[channel_offset + i], eps);
+      mean[i] = runningMean[channel_offset + i];
+      gamma[i] = weight != nullptr ? weight[channel_offset + i] : 1;
+    }
+  }
+
+  const LType* aligned_gradOutput = reinterpret_cast<const LType*>(gradOutput);
+  const LType* aligned_input = reinterpret_cast<const LType*>(input);
+  LType* gradInput_aligned = reinterpret_cast<LType*>(gradInput);
+
+  // Row bounds stay in index_t: outer (and rows_per_block) may not fit in int.
+  const index_t rows_per_block = (outer + gridDim.y - 1) / gridDim.y;
+  const index_t start_row = my_warp + rows_per_block * static_cast<index_t>(blockIdx.y);
+  const index_t end_row = min(outer, rows_per_block * static_cast<index_t>(blockIdx.y + 1));
+  if (offset + thread_idx_in_warp < stride) {
+    for (index_t i = start_row; i < end_row; i += num_warps) {
+      const index_t idx = i * stride + offset + thread_idx_in_warp;
+      vec_gradOutput.aligned = aligned_gradOutput[idx];
+      vec_input.aligned = aligned_input[idx];
+#pragma unroll
+      for (int j = 0; j < nvec; ++j) {
+        sum_beta[j]  += static_cast<AType>(vec_gradOutput.separate[j]);
+        sum_gamma[j] += static_cast<AType>(vec_gradOutput.separate[j]) *
+                        (static_cast<AType>(vec_input.separate[j]) - mean[j]);
+      }
+      if (flags & (WRITE_DATA_FLAG | ADDTO_DATA_FLAG)) {
+        // Gradient to input
+#pragma unroll
+        for (int j = 0; j < nvec; ++j) {
+          vec_gradOutput.separate[j] *= invstd[j] * gamma[j];
+        }
+        if (flags & ADDTO_DATA_FLAG) {
+           vec_input.aligned = gradInput_aligned[idx];
+#pragma unroll
+           for (int j = 0; j < nvec; ++j) {
+             vec_gradOutput.separate[j] += vec_input.separate[j];
+           }
+        }
+        gradInput_aligned[idx] = vec_gradOutput.aligned;
+      }
+    }
+  }
+  __syncthreads();
+#pragma unroll
+  for (int i = 0; i < nvec; ++i) {
+    my_values_gamma[i] = sum_gamma[i];
+    my_values_beta[i] = sum_beta[i];
+  }
+
+  __syncthreads();
+
+  // Tree reduction over warps: afterwards warp 0's slots hold the block sums.
+  for (int i = num_warps / 2; i > 0; i /= 2) {
+    if (my_warp < i) {
+      const int shared_offset = nvec * i * warp_size;
+#pragma unroll
+      for (int j = 0; j < nvec; ++j) {
+        my_values_gamma[j] += my_values_gamma[j + shared_offset];
+        my_values_beta[j] += my_values_beta[j + shared_offset];
+      }
+    }
+    __syncthreads();
+  }
+
+  // scratch[0, warp_size * nvec) now holds this block's sums in channel order;
+  // write the ones for existing channels into this blockIdx.y's slice.
+  if (threadIdx.x < min(warp_size * nvec,
+                        static_cast<int>(num_channels - nvec * offset))) {
+    const size_t offset_out = nvec * offset +
+                              blockIdx.y * num_channels;
+    const size_t offset_beta = gridDim.y * num_channels;
+    temp_space[offset_out + threadIdx.x] = scratch[threadIdx.x];
+    temp_space[offset_beta + offset_out + threadIdx.x] = scratch[NTHREADS * nvec + threadIdx.x];
+  }
+}
+
+/*! \brief Frozen BatchNorm backward for channel-last layouts, phase 2 of 2.
+ * One thread per channel (tid < lead_dim) adds up the n_blocks partial sums
+ * written by phase 1 (gamma sums first, followed by lead_dim * n_blocks beta
+ * sums) and writes or accumulates out_gamma / out_beta according to flags.
+ * With FIX_GAMMA_FLAG, out_gamma is zeroed for WRITE and left as is for ADDTO.
+ */
+template <typename AType>
+__global__ void FrozenBatchNormalizationBackwardKernelCLastPhase2(const AType * temp_space,
+                                                                  const AType * runningVar,
+                                                                  AType * out_gamma,
+                                                                  AType * out_beta,
+                                                                  int lead_dim, int n_blocks,
+                                                                  AType epsilon, uint32_t flags) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  if (tid < lead_dim) {
+    AType sum_gamma = 0;
+    AType sum_beta = 0;
+    for (int i = tid; i < lead_dim * n_blocks; i += lead_dim) {
+      sum_gamma += temp_space[i];
+      sum_beta += temp_space[i + lead_dim * n_blocks];
+    }
+    if (flags & (WRITE_GAMMA_FLAG | ADDTO_GAMMA_FLAG)) {
+      if ((flags & FIX_GAMMA_FLAG) == 0) {
+        const AType invstd = variance_to_invstd(runningVar[tid], epsilon);
+        if (flags & WRITE_GAMMA_FLAG) {
+          out_gamma[tid] = sum_gamma * invstd;
+        } else {
+          out_gamma[tid] += sum_gamma * invstd;
+        }
+      } else {
+        if (flags & WRITE_GAMMA_FLAG) {
+          out_gamma[tid] = 0;
+        }
+      }
+    }
+    if (flags & WRITE_BETA_FLAG) {
+      out_beta[tid] = sum_beta;
+    } else if (flags & ADDTO_BETA_FLAG) {
+      out_beta[tid] += sum_beta;
+    }
+  }
+}
+
+/*! \brief Frozen (global statistics) BatchNorm backward for layouts where the
+ *         channel is not the last axis.
+ * One block per channel (blockIdx.x); the block's threads stride over the
+ * outer * inner elements of that channel, nvec = sizeof(LType) / sizeof(DType)
+ * at a time, so inner must be divisible by nvec and NHW_div_nvec must equal
+ * outer * inner / nvec.
+ * If flags request it, gradInput = gradOutput * gamma * invstd is written (or
+ * accumulated with ADDTO_DATA_FLAG). Thread 0 then writes or accumulates
+ * gradWeight = sum(gradOutput * (input - mean)) * invstd (zeroed on WRITE with
+ * FIX_GAMMA_FLAG) and gradBias = sum(gradOutput) for the channel.
+ * NTHREADS must equal blockDim.x (it sizes the common::cuda::reduce scratch).
+ * \param weight - gamma, or nullptr to use gamma = 1.
+ */
+template<int NTHREADS, typename DType, typename AType, typename LType>
+__global__ void FrozenBatchNormalizationBackwardKernel(
+  const DType* input,
+  const DType* gradOutput,
+  DType* gradInput,
+  AType* gradWeight,
+  AType* gradBias,
+  const AType* weight,
+  const AType* runningMean,
+  const AType* runningVar,
+  const index_t outer,
+  const index_t inner,
+  const index_t num_channels,
+  const index_t NHW_div_nvec,
+  const AType eps,
+  const uint32_t flags) {
+  const index_t my_channel = blockIdx.x;
+  const AType invstd = variance_to_invstd(runningVar[my_channel], eps);
+  const AType mean = runningMean[my_channel];
+  const AType gamma = weight != nullptr ? weight[my_channel] : 1;
+  constexpr int nvec = sizeof(LType) > sizeof(DType) ? sizeof(LType) / sizeof(DType)
+                                                     : 1;
+  union vectorized_loader {
+    LType aligned;
+    DType separate[nvec];  // NOLINT(*)
+
+    __device__ inline vectorized_loader() {}
+    __device__ inline ~vectorized_loader() {}
+  };
+
+  vectorized_loader vec_input, vec_gradOutput;
+
+  const LType* input_aligned = reinterpret_cast<const LType*>(input);
+  const LType* gradOutput_aligned = reinterpret_cast<const LType*>(gradOutput);
+  LType* gradInput_aligned = reinterpret_cast<LType*>(gradInput);
+
+  const index_t inner_div_nvec = inner / nvec;
+
+  AType sum_gamma = 0;
+  AType sum_beta = 0;
+
+  // Each thread accumulates a strided subset of this channel's vectors.
+  for (index_t i = threadIdx.x; i < NHW_div_nvec; i += blockDim.x) {
+    const index_t inner_idx = i % inner_div_nvec;
+    const index_t outer_idx = i / inner_div_nvec;
+    const index_t idx = inner_idx +
+                        (my_channel + outer_idx * num_channels) * inner_div_nvec;
+    vec_gradOutput.aligned = gradOutput_aligned[idx];
+    vec_input.aligned = input_aligned[idx];
+#pragma unroll
+    for (int j = 0; j < nvec; ++j) {
+      sum_beta += static_cast<AType>(vec_gradOutput.separate[j]);
+      sum_gamma += static_cast<AType>(vec_gradOutput.separate[j]) *
+                   (static_cast<AType>(vec_input.separate[j]) - mean);
+    }
+
+    if (flags & (WRITE_DATA_FLAG | ADDTO_DATA_FLAG)) {
+      // Gradient to input
+#pragma unroll
+      for (int j = 0; j < nvec; ++j) {
+        vec_gradOutput.separate[j] *= invstd * gamma;
+      }
+      if (flags & ADDTO_DATA_FLAG) {
+         vec_input.aligned = gradInput_aligned[idx];
+#pragma unroll
+         for (int j = 0; j < nvec; ++j) {
+           vec_gradOutput.separate[j] += vec_input.separate[j];
+         }
+      }
+      gradInput_aligned[idx] = vec_gradOutput.aligned;
+    }
+  }
+
+  // Block-wide sums; with all_reduce = false only thread 0 gets the result.
+  sum_gamma = common::cuda::reduce<NTHREADS, false>(sum_gamma,
+                                                    [](AType a, AType b) { return a + b; });
+  sum_beta = common::cuda::reduce<NTHREADS, false>(sum_beta,
+                                                   [](AType a, AType b) { return a + b; });
+
+  if (threadIdx.x == 0) {
+    if (flags & (WRITE_GAMMA_FLAG | ADDTO_GAMMA_FLAG)) {
+      if ((flags & FIX_GAMMA_FLAG) == 0) {
+        if (flags & WRITE_GAMMA_FLAG) {
+          gradWeight[my_channel] = sum_gamma * invstd;
+        } else {
+          gradWeight[my_channel] += sum_gamma * invstd;
+        }
+      } else {
+        if (flags & WRITE_GAMMA_FLAG) {
+          gradWeight[my_channel] = 0;
+        }
+      }
+    }
+    if (flags & WRITE_BETA_FLAG) {
+      gradBias[my_channel] = sum_beta;
+    } else if (flags & ADDTO_BETA_FLAG) {
+      gradBias[my_channel] += sum_beta;
+    }
+  }
+}
+
 template<typename DType, typename AccReal, typename DeviceTensor1, typename DeviceTensor>
 static __global__ void BatchNormalizationBackwardKernel(
   const DeviceTensor input,
@@ -320,21 +652,13 @@ static __global__ void BatchNormalizationBackwardKernel(
   CUDATensors<DeviceTensor1> tensors,
   const uint32_t flags,
   const AccReal momentum,
-  const double eps) {
+  const AccReal eps) {
   int plane = blockIdx.x;
   int N = gradOutput.OuterSize() * gradOutput.InnerSize();
 
-  const bool is_train_and_not_global_stats =
-    (flags & IS_TRAINING_FLAG) != 0 && (flags & USE_GLOBAL_STATS_FLAG) == 0;
-
   AccReal mean, invstd;
-  if (is_train_and_not_global_stats) {
-    mean = ScalarConvert<DType, AccReal>::to(tensors.saveMean[plane]);
-    invstd = tensors.saveInvStd[plane];
-  } else {
-    mean = ScalarConvert<DType, AccReal>::to(tensors.runningMean[plane]);
-    invstd = VARIANCE_TO_INVSTD(tensors.runningVar[plane], eps);
-  }
+  mean = ScalarConvert<DType, AccReal>::to(tensors.saveMean[plane]);
+  invstd = tensors.saveInvStd[plane];
 
   const AccReal weightVal = ((flags & FIX_GAMMA_FLAG) == 0 && tensors.weight.numElements() > 0) ?
                       ScalarConvert<DType, AccReal>::to(tensors.weight[plane]) : AccReal(1);
@@ -353,8 +677,8 @@ static __global__ void BatchNormalizationBackwardKernel(
   const AccReal projScale = dotP * norm * invstd * invstd;
   const AccReal gradScale = invstd * weightVal;
 
-  if (threadIdx.x == 0 && is_train_and_not_global_stats) {
-    const AccReal localVariance = INVSTD_TO_VARIANCE(tensors.saveInvStd[plane], eps);
+  if (threadIdx.x == 0) {
+    const AccReal localVariance = invstd_to_variance(tensors.saveInvStd[plane], eps);
     const AccReal localMean = tensors.saveMean[plane];
 
     // update running averages
@@ -370,15 +694,10 @@ static __global__ void BatchNormalizationBackwardKernel(
       for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) {
         for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) {
           const DType gradOut = gradOutput.get_ref(batch, plane, x);
-          if (is_train_and_not_global_stats) {
-            const DType inp = input.get_ref(batch, plane, x);
-            const AccReal proj = (inp - mean) * projScale;
-            gradInput.get_ref(batch, plane, x) =
-              ScalarConvert<AccReal, DType>::to((gradOut - proj - gradMean) * gradScale);
-          } else {
-            gradInput.get_ref(batch, plane, x) = ScalarConvert<AccReal, DType>::to(
-              gradOut * gradScale);
-          }
+          const DType inp = input.get_ref(batch, plane, x);
+          const AccReal proj = (inp - mean) * projScale;
+          gradInput.get_ref(batch, plane, x) =
+            ScalarConvert<AccReal, DType>::to((gradOut - proj - gradMean) * gradScale);
         }
       }
     } else {
@@ -386,15 +705,10 @@ static __global__ void BatchNormalizationBackwardKernel(
       for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) {
         for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) {
           const DType gradOut = gradOutput.get_ref(batch, plane, x);
-          if (is_train_and_not_global_stats) {
-            const DType inp = input.get_ref(batch, plane, x);
-            const AccReal proj = (inp - mean) * projScale;
-            gradInput.get_ref(batch, plane, x) +=
-              ScalarConvert<AccReal, DType>::to((gradOut - proj - gradMean) * gradScale);
-          } else {
-            gradInput.get_ref(batch, plane, x) += ScalarConvert<AccReal, DType>::to(
-              gradOut * gradScale);
-          }
+          const DType inp = input.get_ref(batch, plane, x);
+          const AccReal proj = (inp - mean) * projScale;
+          gradInput.get_ref(batch, plane, x) +=
+            ScalarConvert<AccReal, DType>::to((gradOut - proj - gradMean) * gradScale);
         }
       }
     }
@@ -537,13 +851,35 @@ static void BatchNormalizationUpdateOutput(mshadow::Stream<gpu> *s,
   DCHECK_GT(weight.numElements(), 0);
 
   if ((flags & IS_TRAINING_FLAG) == 0 || (flags & USE_GLOBAL_STATS_FLAG) != 0) {
-    dim3 blocks(input.ChannelCount());
-    dim3 threads(batchnorm::cuda::getNumThreads(input.InnerSize()));
-    BatchNormalizationUpdateOutputInferenceKernel<DType, AccReal, DeviceTensor1,
-      batchnorm::BNTensor3<DType>>
-      <<< blocks, threads, 0, mshadow::Stream<gpu>::GetStream(s) >>> (
-      input, output, runningMean, runningVar, saveMean,
-        saveInvStd, weight, bias, eps, flags);
+    AccReal* bias_ptr = bias.numElements() > 0 ? bias.dptr_ : nullptr;
+    AccReal* gamma_ptr = weight.numElements() > 0 ? weight.dptr_ : nullptr;
+    int nvec = sizeof(double) / sizeof(DType);
+    index_t size = input.InnerSize() * input.OuterSize() * input.ChannelCount();
+    index_t aligned_size = ((size + nvec - 1) / nvec) * nvec;
+    index_t blocks = std::min((size + nvec * inference_forward_threads - 1) /
+                              (nvec * inference_forward_threads),
+                              static_cast<index_t>(512));
+    if (input.ChannelCount() < shmem_elements) {
+      BatchNormalizationUpdateOutputInferenceKernel<DType, AccReal, double, true>
+        <<<blocks, inference_forward_threads, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
+        input.dptr_, output.dptr_,
+        aligned_size, input.OuterSize(),
+        input.ChannelCount(), input.InnerSize(),
+        runningMean.dptr_, runningVar.dptr_,
+        saveMean.dptr_, saveInvStd.dptr_,
+        gamma_ptr, bias_ptr,
+        eps, flags);
+    } else {
+      BatchNormalizationUpdateOutputInferenceKernel<DType, AccReal, double, false>
+        <<<blocks, inference_forward_threads, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
+        input.dptr_, output.dptr_,
+        aligned_size, input.OuterSize(),
+        input.ChannelCount(), input.InnerSize(),
+        runningMean.dptr_, runningVar.dptr_,
+        saveMean.dptr_, saveInvStd.dptr_,
+        gamma_ptr, bias_ptr,
+        eps, flags);
+    }
   } else {
     dim3 blocks(input.ChannelCount());
     dim3 threads(batchnorm::cuda::getNumThreads(input.InnerSize()));
@@ -588,11 +924,86 @@ static void BatchNormalizationBackward(mshadow::Stream<gpu> *s,
   tensors.saveInvStd = devicetensor<AccReal, 1>(out_data[batchnorm::kVar]);
 
   DCHECK_GT(tensors.weight.numElements(), 0);
-  dim3 blocks(gradOutput.ChannelCount());
-  dim3 threads(batchnorm::cuda::getNumThreads(gradOutput.InnerSize()));
-  BatchNormalizationBackwardKernel<DType, AccReal, DeviceTensor1, batchnorm::BNTensor3<DType>>
-    <<< blocks, threads, 0, mshadow::Stream<gpu>::GetStream(s) >>> (
-    input, gradOutput, gradInput, tensors, flags, momentum, eps);
+  const bool is_train_and_not_global_stats =
+    (flags & IS_TRAINING_FLAG) != 0 && (flags & USE_GLOBAL_STATS_FLAG) == 0;
+
+  if (is_train_and_not_global_stats) {
+#ifdef NDEBUG
+    constexpr bool SMALLER_THREADS = false;
+#else
+    constexpr bool SMALLER_THREADS = true;
+#endif
+    dim3 blocks(gradOutput.ChannelCount());
+    dim3 threads(batchnorm::cuda::getNumThreads(gradOutput.InnerSize()));
+    BatchNormalizationBackwardKernel<DType, AccReal, DeviceTensor1, batchnorm::BNTensor3<DType>>
+      <<< blocks, threads, 0, mshadow::Stream<gpu>::GetStream(s) >>> (
+      input, gradOutput, gradInput, tensors, flags, momentum, eps);
+  } else {
+    uint32_t flags_copy = flags;
+    if (gradInput.Size() <= 0) {
+      flags_copy = (flags_copy & ~WRITE_DATA_FLAG);
+    }
+    if (tensors.gradWeight.numElements() <= 0) {
+      flags_copy = (flags_copy & ~WRITE_GAMMA_FLAG);
+    }
+    if (tensors.gradBias.numElements() <= 0) {
+      flags_copy = (flags_copy & ~WRITE_BETA_FLAG);
+    }
+    AccReal* gamma = ((flags & FIX_GAMMA_FLAG) == 0 && tensors.weight.numElements() > 0)
+                     ? tensors.weight.dptr_
+                     : nullptr;
+
+    if (param.axis == -1 || param.axis == in_data[batchnorm::kData].shape_.ndim() - 1) {
+      const int C = gradOutput.ChannelCount();
+      int ltype = mxnet::common::cuda::get_load_type(C * sizeof(DType));
+      const int M = gradOutput.OuterSize();
+      MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
+        const unsigned int blocks_x = ceil_div(C * sizeof(DType),
+                                               mxnet::common::cuda::warp_size * sizeof(LType));
+        const unsigned int preferred_number_of_blocks = 2 *
+                                                        MultiprocessorCount(ctx.run_ctx.ctx.dev_id);
+        const unsigned int blocks_y = std::max(preferred_number_of_blocks / blocks_x, 1u);
+        const dim3 n_blocks = {blocks_x, blocks_y, 1};
+        auto scratch_space = ctx.requested[batchnorm::kTempSpace]
+                                .get_space_typed<gpu, 1, AccReal>(mshadow::Shape1(C * blocks_y * 2),
+                                                                                  s);
+        auto stream = mshadow::Stream<gpu>::GetStream(s);
+        constexpr int nthreads_phase1 = 512;
+        constexpr int nthreads_phase2 = 128;
+        FrozenBatchNormalizationBackwardKernelCLastPhase1<nthreads_phase1, DType, AccReal, LType>
+          <<<n_blocks, nthreads_phase1, 0, stream>>>(input.dptr_, gradOutput.dptr_,
+                                                     scratch_space.dptr_,
+                                                     gradInput.dptr_,
+                                                     gamma,
+                                                     tensors.runningMean.dptr_,
+                                                     tensors.runningVar.dptr_,
+                                                     M, C, eps, flags_copy);
+        const int nblocks_phase2 = ceil_div(C, nthreads_phase2);
+        FrozenBatchNormalizationBackwardKernelCLastPhase2<AccReal>
+          <<<nblocks_phase2, nthreads_phase2, 0, stream>>>(scratch_space.dptr_,
+                                                           tensors.runningVar.dptr_,
+                                                           tensors.gradWeight.dptr_,
+                                                           tensors.gradBias.dptr_, C,
+                                                           blocks_y, eps, flags_copy);
+      });
+    } else {
+    dim3 blocks(gradOutput.ChannelCount());
+    int ltype = mxnet::common::cuda::get_load_type(gradOutput.InnerSize() * sizeof(DType));
+    MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
+      constexpr int nvec = sizeof(LType) > sizeof(DType) ? sizeof(LType) / sizeof(DType) : 1;
+      const index_t NHW_div_nvec = gradOutput.OuterSize() * gradOutput.InnerSize() / nvec;
+      constexpr int threads = 512;
+      FrozenBatchNormalizationBackwardKernel<threads, DType, AccReal, LType>
+        <<< blocks, threads, 0, mshadow::Stream<gpu>::GetStream(s) >>> (
+        input.dptr_, gradOutput.dptr_, gradInput.dptr_,
+        tensors.gradWeight.dptr_, tensors.gradBias.dptr_,
+        gamma, tensors.runningMean.dptr_,
+        tensors.runningVar.dptr_,
+        gradOutput.OuterSize(), gradOutput.InnerSize(),
+        gradOutput.ChannelCount(), NHW_div_nvec, eps, flags_copy);
+      });
+    }
+  }
   MSHADOW_CUDA_POST_KERNEL_CHECK(BatchNormalizationBackward);
 }
 
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index 9a67e82..ee27006 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -350,7 +350,7 @@ __global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, ITyp
     __syncthreads();
   }
   if (my_id < warp_size) {
-    AType my_value = warp_reduce(scratch[threadIdx.x],
+    AType my_value = common::cuda::warp_reduce(scratch[threadIdx.x],
                                  [](AType x, AType y) { return ::max(x, y); });
     scratch[threadIdx.x] = my_value;
   }
@@ -374,7 +374,7 @@ __global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, ITyp
     __syncthreads();
   }
   if (my_id < warp_size) {
-    AType my_value = warp_reduce(scratch[threadIdx.x],
+    AType my_value = common::cuda::warp_reduce(scratch[threadIdx.x],
                                  [](AType x, AType y) { return x + y;});
     scratch[threadIdx.x] = my_value;
   }
@@ -488,7 +488,7 @@ __global__ void softmax_stride1_grad_kernel(const OType *out, const OType *ograd
     __syncthreads();
   }
   if (my_id < warp_size) {
-    AType my_value = warp_reduce(scratch[threadIdx.x],
+    AType my_value = common::cuda::warp_reduce(scratch[threadIdx.x],
                                  [](AType x, AType y) { return x + y; });
     scratch[threadIdx.x] = my_value;
   }