Posted to commits@mxnet.apache.org by zh...@apache.org on 2020/08/18 06:09:21 UTC
[incubator-mxnet] branch master updated: Faster GPU frozen BatchNorm (#17368)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new e06ee4e Faster GPU frozen BatchNorm (#17368)
e06ee4e is described below
commit e06ee4e1a725aed62a66b525036676500010e7f0
Author: Przemyslaw Tredak <pt...@nvidia.com>
AuthorDate: Mon Aug 17 23:07:37 2020 -0700
Faster GPU frozen BatchNorm (#17368)
* Better frozen batchnorm
* Continue FreezeBN
* Optimizations
* Reduce number of mod operations
* Cleaning
* Fixing frozen bn with fix_gamma=False
* Fix lint in BN
* Backward frozen batchnorm
* More work on backward of Frozen BN
* Let it compile
* NCHW Frozen BN backward
* Frozen BN backward NHWC
* Cleaning
* Remove the change to Makefile
* Fix from rebase
* Temp space for BN backward
* Fix from review
* Fix lint
* Changes from review
---
src/common/cuda_utils.h | 83 ++++++-
src/operator/nn/batch_norm.cc | 2 -
src/operator/nn/batch_norm.cu | 563 ++++++++++++++++++++++++++++++++++++------
src/operator/nn/softmax-inl.h | 6 +-
4 files changed, 561 insertions(+), 93 deletions(-)
diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h
index 0971cfd..22ac42c 100644
--- a/src/common/cuda_utils.h
+++ b/src/common/cuda_utils.h
@@ -783,27 +783,86 @@ __device__ inline DType ldg(const DType* address) {
#endif
}
-template <typename OP, typename T>
+namespace mxnet {
+namespace common {
+/*! \brief common utils for cuda */
+namespace cuda {
+
+static constexpr const int warp_size = 32;
+
+/*! \brief Reduction inside a warp.
+ * Template parameters:
+ * NVALUES - number of values to reduce (defaults to warp_size).
+ * \param value - values to be reduced.
+ * \param redfun - function used to perform reduction.
+ */
+template <int NVALUES = warp_size, typename OP, typename T>
__device__ inline T warp_reduce(T value, OP redfun) {
- value = redfun(value, __shfl_down_sync(0xffffffff, value, 16));
- value = redfun(value, __shfl_down_sync(0xffffffff, value, 8));
- value = redfun(value, __shfl_down_sync(0xffffffff, value, 4));
- value = redfun(value, __shfl_down_sync(0xffffffff, value, 2));
- value = redfun(value, __shfl_down_sync(0xffffffff, value, 1));
+#pragma unroll
+ for (int i = warp_size / 2; i >= 1; i /= 2) {
+ if (NVALUES > i) value = redfun(value, __shfl_down_sync(0xffffffff, value, i));
+ }
return value;
}
-template <typename OP>
+template <int NValues = warp_size, typename OP>
__device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, OP redfun) {
float v = static_cast<float>(value);
- v = redfun(v, __shfl_down_sync(0xffffffff, v, 16));
- v = redfun(v, __shfl_down_sync(0xffffffff, v, 8));
- v = redfun(v, __shfl_down_sync(0xffffffff, v, 4));
- v = redfun(v, __shfl_down_sync(0xffffffff, v, 2));
- v = redfun(v, __shfl_down_sync(0xffffffff, v, 1));
+#pragma unroll
+ for (int i = warp_size / 2; i >= 1; i /= 2) {
+ if (NValues > i) v = redfun(v, __shfl_down_sync(0xffffffff, v, i));
+ }
return mshadow::half::half_t(v);
}
+/*! \brief Reduction inside a block, requires all threads in a block to participate.
+ * It uses a 2 step approach:
+ * - all warps in a block perform intermediate reduction
+ * - first warp reduces the intermediate results.
+ * Template parameters:
+ * NTHREADS - number of threads in a block.
+ * all_reduce - whether all threads need the result of the reduction. If set to
+ * true, then all threads return with the same value. If set to
+ * false, then only thread 0 has the valid result. Defaults to true.
+ * \param value - value from each thread to be reduced
+ * \param redfun - function used to perform reduction
+ */
+template <int NTHREADS, bool all_reduce = true, typename OP, typename T>
+__device__ inline T reduce(const T& value, OP redfun) {
+ static_assert(NTHREADS <= warp_size * warp_size,
+ "Number of threads too large for reduction");
+ __shared__ T scratch[NTHREADS / warp_size];
+ const int thread_idx_in_warp = threadIdx.x % warp_size;
+ const int warp_id = threadIdx.x / warp_size;
+ const T my_val = warp_reduce<warp_size>(value, redfun);
+ if (thread_idx_in_warp == 0) {
+ scratch[warp_id] = my_val;
+ }
+ __syncthreads();
+ T ret = 0;
+ if (warp_id == 0) {
+ const T prev_val = threadIdx.x < (NTHREADS / warp_size) ? scratch[threadIdx.x] : 0;
+ const T my_val = warp_reduce<NTHREADS / warp_size>(prev_val, redfun);
+ if (all_reduce) {
+ scratch[threadIdx.x] = my_val;
+ } else {
+ ret = my_val;
+ }
+ }
+ // Necessary to synchronize in order to use this function again
+ // as the shared memory scratch space is reused between calls
+ __syncthreads();
+ if (all_reduce) {
+ ret = scratch[0];
+ __syncthreads();
+ }
+ return ret;
+}
+
+} // namespace cuda
+} // namespace common
+} // namespace mxnet
+
#endif // __CUDACC__
#endif // MXNET_COMMON_CUDA_UTILS_H_
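Note (not part of the diff): a minimal sketch of how the new block-wide reduce helper composes with warp_reduce from a kernel. The kernel name and launch configuration below are hypothetical; it only assumes the helpers from src/common/cuda_utils.h above are in scope.

// Hypothetical usage sketch of mxnet::common::cuda::reduce.
// NTHREADS must equal blockDim.x and stay <= warp_size * warp_size (see static_assert above).
template <int NTHREADS>
__global__ void block_sum_sketch(const float* in, float* out) {
  const float my_val = in[blockIdx.x * NTHREADS + threadIdx.x];
  // all_reduce defaults to true, so every thread in the block receives the sum.
  const float block_sum = mxnet::common::cuda::reduce<NTHREADS>(
      my_val, [](float a, float b) { return a + b; });
  if (threadIdx.x == 0) out[blockIdx.x] = block_sum;
}
// e.g. block_sum_sketch<256><<<num_blocks, 256, 0, stream>>>(d_in, d_out);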
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 2fdd31e..2a91a37 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -662,11 +662,9 @@ NNVM_REGISTER_OP(_backward_BatchNorm)
})
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)
-#if MXNET_USE_MKLDNN == 1
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
-#endif
.set_attr_parser(ParamParser<BatchNormParam>)
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
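Note (not part of the diff): the #if MXNET_USE_MKLDNN guard around FResourceRequest is dropped because the new GPU frozen-BatchNorm backward also needs temporary workspace, not only the MKLDNN path. A hedged sketch of how a backward implementation consumes such a kTempSpace request (mirroring the get_space_typed call that appears in batch_norm.cu below; the shape and element type here are illustrative):

// Sketch only: the registered kTempSpace request is consumed at run time via the OpContext.
mshadow::Tensor<gpu, 1, AccReal> workspace =
    ctx.requested[batchnorm::kTempSpace]
       .get_space_typed<gpu, 1, AccReal>(mshadow::Shape1(workspace_elements), s);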
diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu
index c7e991f..72e4a76 100644
--- a/src/operator/nn/batch_norm.cu
+++ b/src/operator/nn/batch_norm.cu
@@ -27,6 +27,8 @@
#include <cuda_runtime_api.h>
#include <algorithm>
#include "batch_norm-inl.h"
+#include "../../common/cuda_utils.h"
+
#define WRITE_DATA_FLAG 1
#define WRITE_GAMMA_FLAG 2
@@ -47,9 +49,30 @@
using namespace mxnet;
+namespace {
+
/*! \brief inverse standard deviation <-> variance */
-#define VARIANCE_TO_INVSTD(__var$, __eps$) (1.0/sqrt((__var$) + DType(__eps$)))
-#define INVSTD_TO_VARIANCE(__invstd$, __eps$) ((1.0 / ((__invstd$) * (__invstd$))) - (__eps$))
+template <typename DType, typename AccReal>
+MSHADOW_XINLINE AccReal variance_to_invstd(DType var, AccReal eps) {
+ return rsqrtf(static_cast<AccReal>(var) + eps);
+}
+
+template <>
+MSHADOW_XINLINE double variance_to_invstd(double var, double eps) {
+ return rsqrt(var + eps);
+}
+
+template <typename AccReal>
+MSHADOW_XINLINE AccReal invstd_to_variance(AccReal invstd, AccReal eps) {
+ return static_cast<AccReal>(1.0) / (invstd * invstd) - eps;
+}
+
+template <>
+MSHADOW_XINLINE double invstd_to_variance(double invstd, double eps) {
+ return 1.0 / (invstd * invstd) - eps;
+}
+
+} // namespace
namespace mxnet {
namespace op {
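Note (not part of the diff): these helpers replace the old VARIANCE_TO_INVSTD / INVSTD_TO_VARIANCE macros. The relationship is invstd = 1 / sqrt(var + eps) and var = 1 / invstd^2 - eps, so the two conversions are inverses of each other up to rounding. A minimal host-side sanity check of that relationship (illustrative only, written without the CUDA rsqrt intrinsics):

// Round-trip check of the variance <-> inverse-stddev conversion above.
#include <cassert>
#include <cmath>

int main() {
  const double var = 0.25, eps = 1e-5;
  const double invstd = 1.0 / std::sqrt(var + eps);   // variance_to_invstd
  const double back = 1.0 / (invstd * invstd) - eps;  // invstd_to_variance
  assert(std::fabs(back - var) < 1e-12);
  return 0;
}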
@@ -206,41 +229,90 @@ static __device__ T reduce(Op op, DeviceTensor tensor, int plane) {
return shared[0];
}
-template <typename DType, typename AccReal, typename DeviceTensor1, typename DeviceTensor>
+namespace {
+ constexpr int inference_forward_threads = 512;
+ constexpr int shmem_elements = 1536;
+} // namespace
+
+template <typename DType, typename AType, typename LType, bool small_num_channels>
+__launch_bounds__(inference_forward_threads)
__global__ void BatchNormalizationUpdateOutputInferenceKernel(
- DeviceTensor input,
- DeviceTensor output,
- DeviceTensor1 runningMean,
- DeviceTensor1 runningVar,
- DeviceTensor1 saveMean,
- DeviceTensor1 saveInvStd,
- DeviceTensor1 weight,
- DeviceTensor1 bias,
- const DType epsilon,
+ const DType* input,
+ DType* output,
+ const index_t size,
+ const index_t outer_size,
+ const index_t num_channels,
+ const index_t inner_size,
+ const AType* runningMean,
+ const AType* runningVar,
+ AType* saveMean,
+ AType* saveInvStd,
+ AType* weight,
+ AType* bias,
+ const AType epsilon,
const uint32_t flags) {
- int plane = blockIdx.x;
-
- AccReal invstd = VARIANCE_TO_INVSTD(runningVar[plane], epsilon);
- AccReal mean = ScalarConvert<DType, AccReal>::to(runningMean[plane]);
- AccReal gamma = ((flags & FIX_GAMMA_FLAG) == 0 && weight.numElements() > 0)
- ? ScalarConvert<DType, AccReal>::to(weight[plane])
- : ScalarConvert<int, AccReal>::to(1);
- AccReal beta = bias.numElements() > 0 ? ScalarConvert<DType, AccReal>::to(bias[plane])
- : ScalarConvert<int, AccReal>::to(0);
- if (threadIdx.x == 0) {
- saveMean[plane] = runningMean[plane];
- saveInvStd[plane] = VARIANCE_TO_INVSTD(runningVar[plane], epsilon);
- if ((flags & WRITE_GAMMA_FLAG) != 0 && (flags & FIX_GAMMA_FLAG) != 0
- && weight.numElements() > 0) {
- weight[plane] = AccReal(1);
+ constexpr int nvec = sizeof(LType) / sizeof(DType);
+ __shared__ AType saved_invstd[shmem_elements];
+ __shared__ AType saved_mean[shmem_elements];
+ __shared__ AType saved_weight[shmem_elements];
+ __shared__ AType saved_bias[shmem_elements];
+ union vectorized_loader {
+ LType aligned;
+ DType separate[nvec]; // NOLINT(*)
+
+ __device__ inline vectorized_loader() {}
+ __device__ inline ~vectorized_loader() {}
+ } scratch;
+
+ if (small_num_channels) {
+ for (int i = threadIdx.x; i < num_channels; i += blockDim.x) {
+ saved_invstd[i] = variance_to_invstd(runningVar[i], epsilon);
+ saved_mean[i] = runningMean[i];
+ saved_weight[i] = (weight != nullptr && (flags & FIX_GAMMA_FLAG) == 0)
+ ? weight[i]
+ : 1;
+ saved_bias[i] = (bias != nullptr) ? bias[i] : 0;
}
+ __syncthreads();
}
- // Write normalized and update the output
- for (int batch = 0, nbatch = input.OuterSize(); batch < nbatch; ++batch) {
- for (int x = threadIdx.x, nx = input.InnerSize(); x < nx; x += blockDim.x) {
- const DType inp = input.get_ref(batch, plane, x);
- output.get_ref(batch, plane, x) =
- ScalarConvert<AccReal, DType>::to(gamma * (inp - mean) * invstd + beta);
+
+ const index_t tid = threadIdx.x + blockIdx.x * blockDim.x;
+ const index_t stride = blockDim.x * gridDim.x;
+ const LType* input_aligned = reinterpret_cast<const LType*>(input);
+ LType* output_aligned = reinterpret_cast<LType*>(output);
+ for (index_t i = tid; i < size / nvec; i += stride) {
+ scratch.aligned = input_aligned[i];
+ const index_t my_channel_base = (nvec * i) % (inner_size * num_channels);
+#pragma unroll
+ for (int j = 0; j < nvec; ++j) {
+ index_t my_channel = (my_channel_base + j) / inner_size;
+ if (my_channel >= num_channels) my_channel = my_channel % num_channels;
+ AType current_input = static_cast<AType>(scratch.separate[j]);
+
+ AType invstd = small_num_channels ? saved_invstd[my_channel]
+ : variance_to_invstd(runningVar[my_channel], epsilon);
+ AType mean = small_num_channels ? saved_mean[my_channel]
+ : runningMean[my_channel];
+ AType gamma = small_num_channels ? saved_weight[my_channel]
+ : ((weight != nullptr && (flags & FIX_GAMMA_FLAG) == 0)
+ ? weight[my_channel]
+ : 1);
+ AType beta = small_num_channels ? saved_bias[my_channel]
+ : ((bias != nullptr) ? bias[my_channel]
+ : 0);
+ current_input = gamma * (current_input - mean) * invstd + beta;
+ scratch.separate[j] = current_input;
+ }
+
+ output_aligned[i] = scratch.aligned;
+
+ if (i < num_channels) {
+ saveMean[i] = runningMean[i];
+ saveInvStd[i] = variance_to_invstd(runningVar[i], epsilon);
+ if ((flags & WRITE_GAMMA_FLAG) != 0 && (flags & FIX_GAMMA_FLAG) != 0
+ && weight != nullptr) {
+ weight[i] = 1;
+ }
}
}
}
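Note (not part of the diff): the vectorized_loader union in the kernel above is the central optimization. One aligned LType load (double-sized in this commit) brings in nvec consecutive DType elements, which are normalized individually and written back with a single wide store. A stripped-down sketch of the same pattern with hypothetical names; it assumes a POD DType, LType-aligned pointers, and a length that is a multiple of nvec:

template <typename DType, typename LType>
__global__ void vectorized_scale_sketch(const DType* in, DType* out, const index_t n) {
  constexpr int nvec = sizeof(LType) / sizeof(DType);
  union loader {
    LType aligned;
    DType separate[nvec];  // NOLINT(*)
  } scratch;

  const index_t tid = threadIdx.x + blockIdx.x * blockDim.x;
  const index_t stride = blockDim.x * gridDim.x;
  const LType* in_aligned = reinterpret_cast<const LType*>(in);
  LType* out_aligned = reinterpret_cast<LType*>(out);
  for (index_t i = tid; i < n / nvec; i += stride) {
    scratch.aligned = in_aligned[i];     // one wide load -> nvec elements
#pragma unroll
    for (int j = 0; j < nvec; ++j) {
      scratch.separate[j] = scratch.separate[j] * DType(2);
    }
    out_aligned[i] = scratch.aligned;    // one wide store
  }
}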
@@ -312,6 +384,266 @@ struct CUDATensors {
DeviceTensor1 saveInvStd;
};
+namespace {
+ inline int ceil_div(int x, int y) {
+ return (x + y - 1) / y;
+ }
+} // namespace
+
+template<int NTHREADS, typename DType, typename AType, typename LType>
+__global__ void FrozenBatchNormalizationBackwardKernelCLastPhase1(
+ const DType* input, const DType* gradOutput, AType* temp_space,
+ DType* gradInput, const AType* weight, const AType* runningMean,
+ const AType* runningVar, const index_t outer, const index_t num_channels,
+ const AType eps, const uint32_t flags) {
+ using mxnet::common::cuda::warp_size;
+ constexpr int num_warps = NTHREADS / warp_size;
+ constexpr int nvec = sizeof(LType) >= sizeof(DType) ? sizeof(LType) / sizeof(DType) : 1;
+ const size_t stride = num_channels / nvec;
+
+ union vectorized_loader {
+ LType aligned;
+ DType separate[nvec]; // NOLINT(*)
+
+ __device__ inline vectorized_loader() {}
+ __device__ inline ~vectorized_loader() {}
+ };
+
+ vectorized_loader vec_input, vec_gradOutput;
+
+ __shared__ AType scratch[NTHREADS * 2 * nvec];
+ AType * my_values_gamma = &(scratch[threadIdx.x * nvec]);
+ AType * my_values_beta = &(scratch[(NTHREADS + threadIdx.x) * nvec]);
+
+ AType sum_gamma[nvec]; // NOLINT(*)
+ AType sum_beta[nvec]; // NOLINT(*)
+#pragma unroll
+ for (int i = 0; i < nvec; ++i) {
+ sum_gamma[i] = 0;
+ sum_beta[i] = 0;
+ }
+
+ const size_t offset = blockIdx.x * warp_size;
+ const int my_warp = threadIdx.x / warp_size;
+ const int thread_idx_in_warp = threadIdx.x % warp_size;
+
+ AType invstd[nvec]; // NOLINT(*)
+ AType mean[nvec]; // NOLINT(*)
+ AType gamma[nvec]; // NOLINT(*)
+ size_t channel_offset = (offset + thread_idx_in_warp) * nvec;
+
+ if (channel_offset < num_channels) {
+#pragma unroll
+ for (int i = 0; i < nvec; ++i) {
+ invstd[i] = variance_to_invstd(runningVar[channel_offset + i], eps);
+ mean[i] = runningMean[channel_offset + i];
+ gamma[i] = weight != nullptr ? weight[channel_offset + i] : 1;
+ }
+ }
+
+ const LType* aligned_gradOutput = reinterpret_cast<const LType*>(gradOutput);
+ const LType* aligned_input = reinterpret_cast<const LType*>(input);
+ LType* gradInput_aligned = reinterpret_cast<LType*>(gradInput);
+
+ const int rows_per_block = (outer + gridDim.y - 1) / gridDim.y;
+ const size_t start_row = my_warp + rows_per_block * blockIdx.y;
+ const size_t end_row = min(outer, static_cast<index_t>(rows_per_block * (blockIdx.y + 1)));
+ if (offset + thread_idx_in_warp < stride) {
+ for (size_t i = start_row; i < end_row; i += num_warps) {
+ const index_t idx = i * stride + offset + thread_idx_in_warp;
+ vec_gradOutput.aligned = aligned_gradOutput[idx];
+ vec_input.aligned = aligned_input[idx];
+#pragma unroll
+ for (int j = 0; j < nvec; ++j) {
+ sum_beta[j] += static_cast<AType>(vec_gradOutput.separate[j]);
+ sum_gamma[j] += static_cast<AType>(vec_gradOutput.separate[j]) *
+ (static_cast<AType>(vec_input.separate[j]) - mean[j]);
+ }
+ if (flags & (WRITE_DATA_FLAG | ADDTO_DATA_FLAG)) {
+ // Gradient to input
+#pragma unroll
+ for (int j = 0; j < nvec; ++j) {
+ vec_gradOutput.separate[j] *= invstd[j] * gamma[j];
+ }
+ if (flags & ADDTO_DATA_FLAG) {
+ vec_input.aligned = gradInput_aligned[idx];
+#pragma unroll
+ for (int j = 0; j < nvec; ++j) {
+ vec_gradOutput.separate[j] += vec_input.separate[j];
+ }
+ }
+ gradInput_aligned[idx] = vec_gradOutput.aligned;
+ }
+ }
+ }
+ __syncthreads();
+#pragma unroll
+ for (int i = 0; i < nvec; ++i) {
+ my_values_gamma[i] = sum_gamma[i];
+ my_values_beta[i] = sum_beta[i];
+ }
+
+ __syncthreads();
+
+ for (int i = num_warps / 2; i > 0; i /= 2) {
+ if (my_warp < i) {
+ const int shared_offset = nvec * i * warp_size;
+#pragma unroll
+ for (int j = 0; j < nvec; ++j) {
+ my_values_gamma[j] += my_values_gamma[j + shared_offset];
+ my_values_beta[j] += my_values_beta[j + shared_offset];
+ }
+ }
+ __syncthreads();
+ }
+
+ if (threadIdx.x < min(warp_size * nvec,
+ static_cast<int>(num_channels - nvec * offset))) {
+ const size_t offset_out = nvec * offset +
+ blockIdx.y * num_channels;
+ const size_t offset_beta = gridDim.y * num_channels;
+ temp_space[offset_out + threadIdx.x] = scratch[threadIdx.x];
+ temp_space[offset_beta + offset_out + threadIdx.x] = scratch[NTHREADS * nvec + threadIdx.x];
+ }
+}
+
+template <typename AType>
+__global__ void FrozenBatchNormalizationBackwardKernelCLastPhase2(const AType * temp_space,
+ const AType * runningVar,
+ AType * out_gamma,
+ AType * out_beta,
+ int lead_dim, int n_blocks,
+ AType epsilon, uint32_t flags) {
+ int tid = threadIdx.x + blockIdx.x * blockDim.x;
+ if (tid < lead_dim) {
+ AType sum_gamma = 0;
+ AType sum_beta = 0;
+ for (int i = tid; i < lead_dim * n_blocks; i += lead_dim) {
+ sum_gamma += temp_space[i];
+ sum_beta += temp_space[i + lead_dim * n_blocks];
+ }
+ if (flags & (WRITE_GAMMA_FLAG | ADDTO_GAMMA_FLAG)) {
+ if ((flags & FIX_GAMMA_FLAG) == 0) {
+ const AType invstd = variance_to_invstd(runningVar[tid], epsilon);
+ if (flags & WRITE_GAMMA_FLAG) {
+ out_gamma[tid] = sum_gamma * invstd;
+ } else {
+ out_gamma[tid] += sum_gamma * invstd;
+ }
+ } else {
+ if (flags & WRITE_GAMMA_FLAG) {
+ out_gamma[tid] = 0;
+ }
+ }
+ }
+ if (flags & WRITE_BETA_FLAG) {
+ out_beta[tid] = sum_beta;
+ } else if (flags & ADDTO_BETA_FLAG) {
+ out_beta[tid] += sum_beta;
+ }
+ }
+}
+
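Note (not part of the diff): between the two phases, temp_space acts as a [2 x gridDim.y x num_channels] buffer. Phase 1 stores per-blockIdx.y partial sums for the gamma gradient in the first gridDim.y * num_channels elements and for the beta gradient in the next gridDim.y * num_channels, and Phase 2 folds them per channel. That is why the host-side launch further down sizes the requested workspace as C * blocks_y * 2 elements. A sketch of the sizing only (identifiers mirror that launch code):

// Scratch layout: [blocks_y][C] gamma partials, then [blocks_y][C] beta partials.
const index_t scratch_elements = static_cast<index_t>(C) * blocks_y * 2;
// allocated via ctx.requested[batchnorm::kTempSpace]
//     .get_space_typed<gpu, 1, AccReal>(mshadow::Shape1(scratch_elements), s);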
+template<int NTHREADS, typename DType, typename AType, typename LType>
+__global__ void FrozenBatchNormalizationBackwardKernel(
+ const DType* input,
+ const DType* gradOutput,
+ DType* gradInput,
+ AType* gradWeight,
+ AType* gradBias,
+ const AType* weight,
+ const AType* runningMean,
+ const AType* runningVar,
+ const index_t outer,
+ const index_t inner,
+ const index_t num_channels,
+ const index_t NHW_div_nvec,
+ const AType eps,
+ const uint32_t flags) {
+ const index_t my_channel = blockIdx.x;
+ const AType invstd = variance_to_invstd(runningVar[my_channel], eps);
+ const AType mean = runningMean[my_channel];
+ const AType gamma = weight != nullptr ? weight[my_channel] : 1;
+ constexpr int nvec = sizeof(LType) > sizeof(DType) ? sizeof(LType) / sizeof(DType)
+ : 1;
+ union vectorized_loader {
+ LType aligned;
+ DType separate[nvec]; // NOLINT(*)
+
+ __device__ inline vectorized_loader() {}
+ __device__ inline ~vectorized_loader() {}
+ };
+
+ vectorized_loader vec_input, vec_gradOutput;
+
+ const LType* input_aligned = reinterpret_cast<const LType*>(input);
+ const LType* gradOutput_aligned = reinterpret_cast<const LType*>(gradOutput);
+ LType* gradInput_aligned = reinterpret_cast<LType*>(gradInput);
+
+ const index_t inner_div_nvec = inner / nvec;
+
+ AType sum_gamma = 0;
+ AType sum_beta = 0;
+
+
+ for (index_t i = threadIdx.x; i < NHW_div_nvec; i += blockDim.x) {
+ const index_t inner_idx = i % inner_div_nvec;
+ const index_t outer_idx = i / inner_div_nvec;
+ const index_t idx = inner_idx +
+ (my_channel + outer_idx * num_channels) * inner_div_nvec;
+ vec_gradOutput.aligned = gradOutput_aligned[idx];
+ vec_input.aligned = input_aligned[idx];
+#pragma unroll
+ for (int j = 0; j < nvec; ++j) {
+ sum_beta += static_cast<AType>(vec_gradOutput.separate[j]);
+ sum_gamma += static_cast<AType>(vec_gradOutput.separate[j]) *
+ (static_cast<AType>(vec_input.separate[j]) - mean);
+ }
+
+ if (flags & (WRITE_DATA_FLAG | ADDTO_DATA_FLAG)) {
+ // Gradient to input
+#pragma unroll
+ for (int j = 0; j < nvec; ++j) {
+ vec_gradOutput.separate[j] *= invstd * gamma;
+ }
+ if (flags & ADDTO_DATA_FLAG) {
+ vec_input.aligned = gradInput_aligned[idx];
+#pragma unroll
+ for (int j = 0; j < nvec; ++j) {
+ vec_gradOutput.separate[j] += vec_input.separate[j];
+ }
+ }
+ gradInput_aligned[idx] = vec_gradOutput.aligned;
+ }
+ }
+
+ sum_gamma = common::cuda::reduce<NTHREADS, false>(sum_gamma,
+ [](AType a, AType b) { return a + b; });
+ sum_beta = common::cuda::reduce<NTHREADS, false>(sum_beta,
+ [](AType a, AType b) { return a + b; });
+
+ if (threadIdx.x == 0) {
+ if (flags & (WRITE_GAMMA_FLAG | ADDTO_GAMMA_FLAG)) {
+ if ((flags & FIX_GAMMA_FLAG) == 0) {
+ if (flags & WRITE_GAMMA_FLAG) {
+ gradWeight[my_channel] = sum_gamma * invstd;
+ } else {
+ gradWeight[my_channel] += sum_gamma * invstd;
+ }
+ } else {
+ if (flags & WRITE_GAMMA_FLAG) {
+ gradWeight[my_channel] = 0;
+ }
+ }
+ }
+ if (flags & WRITE_BETA_FLAG) {
+ gradBias[my_channel] = sum_beta;
+ } else if (flags & ADDTO_BETA_FLAG) {
+ gradBias[my_channel] += sum_beta;
+ }
+ }
+}
+
template<typename DType, typename AccReal, typename DeviceTensor1, typename DeviceTensor>
static __global__ void BatchNormalizationBackwardKernel(
const DeviceTensor input,
@@ -320,21 +652,13 @@ static __global__ void BatchNormalizationBackwardKernel(
CUDATensors<DeviceTensor1> tensors,
const uint32_t flags,
const AccReal momentum,
- const double eps) {
+ const AccReal eps) {
int plane = blockIdx.x;
int N = gradOutput.OuterSize() * gradOutput.InnerSize();
- const bool is_train_and_not_global_stats =
- (flags & IS_TRAINING_FLAG) != 0 && (flags & USE_GLOBAL_STATS_FLAG) == 0;
-
AccReal mean, invstd;
- if (is_train_and_not_global_stats) {
- mean = ScalarConvert<DType, AccReal>::to(tensors.saveMean[plane]);
- invstd = tensors.saveInvStd[plane];
- } else {
- mean = ScalarConvert<DType, AccReal>::to(tensors.runningMean[plane]);
- invstd = VARIANCE_TO_INVSTD(tensors.runningVar[plane], eps);
- }
+ mean = ScalarConvert<DType, AccReal>::to(tensors.saveMean[plane]);
+ invstd = tensors.saveInvStd[plane];
const AccReal weightVal = ((flags & FIX_GAMMA_FLAG) == 0 && tensors.weight.numElements() > 0) ?
ScalarConvert<DType, AccReal>::to(tensors.weight[plane]) : AccReal(1);
@@ -353,8 +677,8 @@ static __global__ void BatchNormalizationBackwardKernel(
const AccReal projScale = dotP * norm * invstd * invstd;
const AccReal gradScale = invstd * weightVal;
- if (threadIdx.x == 0 && is_train_and_not_global_stats) {
- const AccReal localVariance = INVSTD_TO_VARIANCE(tensors.saveInvStd[plane], eps);
+ if (threadIdx.x == 0) {
+ const AccReal localVariance = invstd_to_variance(tensors.saveInvStd[plane], eps);
const AccReal localMean = tensors.saveMean[plane];
// update running averages
@@ -370,15 +694,10 @@ static __global__ void BatchNormalizationBackwardKernel(
for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) {
for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) {
const DType gradOut = gradOutput.get_ref(batch, plane, x);
- if (is_train_and_not_global_stats) {
- const DType inp = input.get_ref(batch, plane, x);
- const AccReal proj = (inp - mean) * projScale;
- gradInput.get_ref(batch, plane, x) =
- ScalarConvert<AccReal, DType>::to((gradOut - proj - gradMean) * gradScale);
- } else {
- gradInput.get_ref(batch, plane, x) = ScalarConvert<AccReal, DType>::to(
- gradOut * gradScale);
- }
+ const DType inp = input.get_ref(batch, plane, x);
+ const AccReal proj = (inp - mean) * projScale;
+ gradInput.get_ref(batch, plane, x) =
+ ScalarConvert<AccReal, DType>::to((gradOut - proj - gradMean) * gradScale);
}
}
} else {
@@ -386,15 +705,10 @@ static __global__ void BatchNormalizationBackwardKernel(
for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) {
for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) {
const DType gradOut = gradOutput.get_ref(batch, plane, x);
- if (is_train_and_not_global_stats) {
- const DType inp = input.get_ref(batch, plane, x);
- const AccReal proj = (inp - mean) * projScale;
- gradInput.get_ref(batch, plane, x) +=
- ScalarConvert<AccReal, DType>::to((gradOut - proj - gradMean) * gradScale);
- } else {
- gradInput.get_ref(batch, plane, x) += ScalarConvert<AccReal, DType>::to(
- gradOut * gradScale);
- }
+ const DType inp = input.get_ref(batch, plane, x);
+ const AccReal proj = (inp - mean) * projScale;
+ gradInput.get_ref(batch, plane, x) +=
+ ScalarConvert<AccReal, DType>::to((gradOut - proj - gradMean) * gradScale);
}
}
}
@@ -537,13 +851,35 @@ static void BatchNormalizationUpdateOutput(mshadow::Stream<gpu> *s,
DCHECK_GT(weight.numElements(), 0);
if ((flags & IS_TRAINING_FLAG) == 0 || (flags & USE_GLOBAL_STATS_FLAG) != 0) {
- dim3 blocks(input.ChannelCount());
- dim3 threads(batchnorm::cuda::getNumThreads(input.InnerSize()));
- BatchNormalizationUpdateOutputInferenceKernel<DType, AccReal, DeviceTensor1,
- batchnorm::BNTensor3<DType>>
- <<< blocks, threads, 0, mshadow::Stream<gpu>::GetStream(s) >>> (
- input, output, runningMean, runningVar, saveMean,
- saveInvStd, weight, bias, eps, flags);
+ AccReal* bias_ptr = bias.numElements() > 0 ? bias.dptr_ : nullptr;
+ AccReal* gamma_ptr = weight.numElements() > 0 ? weight.dptr_ : nullptr;
+ int nvec = sizeof(double) / sizeof(DType);
+ index_t size = input.InnerSize() * input.OuterSize() * input.ChannelCount();
+ index_t aligned_size = ((size + nvec - 1) / nvec) * nvec;
+ index_t blocks = std::min((size + nvec * inference_forward_threads - 1) /
+ (nvec * inference_forward_threads),
+ static_cast<index_t>(512));
+ if (input.ChannelCount() < shmem_elements) {
+ BatchNormalizationUpdateOutputInferenceKernel<DType, AccReal, double, true>
+ <<<blocks, inference_forward_threads, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
+ input.dptr_, output.dptr_,
+ aligned_size, input.OuterSize(),
+ input.ChannelCount(), input.InnerSize(),
+ runningMean.dptr_, runningVar.dptr_,
+ saveMean.dptr_, saveInvStd.dptr_,
+ gamma_ptr, bias_ptr,
+ eps, flags);
+ } else {
+ BatchNormalizationUpdateOutputInferenceKernel<DType, AccReal, double, false>
+ <<<blocks, inference_forward_threads, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
+ input.dptr_, output.dptr_,
+ aligned_size, input.OuterSize(),
+ input.ChannelCount(), input.InnerSize(),
+ runningMean.dptr_, runningVar.dptr_,
+ saveMean.dptr_, saveInvStd.dptr_,
+ gamma_ptr, bias_ptr,
+ eps, flags);
+ }
} else {
dim3 blocks(input.ChannelCount());
dim3 threads(batchnorm::cuda::getNumThreads(input.InnerSize()));
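Note (not part of the diff): a concrete sizing example for the inference launch above, with a hypothetical NCHW fp16 input of 32 x 64 x 56 x 56:

// Worked sizing example (hypothetical shape):
//   nvec = sizeof(double) / sizeof(DType) = 8 / 2 = 4
//   size = 32 * 64 * 56 * 56 = 6422528, already a multiple of nvec
//   blocks = min(ceil_div(6422528, 4 * 512), 512) = min(3136, 512) = 512
//   ChannelCount() = 64 < shmem_elements (1536) -> small_num_channels variant,
//   i.e. the per-channel mean/invstd/gamma/beta are cached in shared memory.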
@@ -588,11 +924,86 @@ static void BatchNormalizationBackward(mshadow::Stream<gpu> *s,
tensors.saveInvStd = devicetensor<AccReal, 1>(out_data[batchnorm::kVar]);
DCHECK_GT(tensors.weight.numElements(), 0);
- dim3 blocks(gradOutput.ChannelCount());
- dim3 threads(batchnorm::cuda::getNumThreads(gradOutput.InnerSize()));
- BatchNormalizationBackwardKernel<DType, AccReal, DeviceTensor1, batchnorm::BNTensor3<DType>>
- <<< blocks, threads, 0, mshadow::Stream<gpu>::GetStream(s) >>> (
- input, gradOutput, gradInput, tensors, flags, momentum, eps);
+ const bool is_train_and_not_global_stats =
+ (flags & IS_TRAINING_FLAG) != 0 && (flags & USE_GLOBAL_STATS_FLAG) == 0;
+
+ if (is_train_and_not_global_stats) {
+#ifdef NDEBUG
+ constexpr bool SMALLER_THREADS = false;
+#else
+ constexpr bool SMALLER_THREADS = true;
+#endif
+ dim3 blocks(gradOutput.ChannelCount());
+ dim3 threads(batchnorm::cuda::getNumThreads(gradOutput.InnerSize()));
+ BatchNormalizationBackwardKernel<DType, AccReal, DeviceTensor1, batchnorm::BNTensor3<DType>>
+ <<< blocks, threads, 0, mshadow::Stream<gpu>::GetStream(s) >>> (
+ input, gradOutput, gradInput, tensors, flags, momentum, eps);
+ } else {
+ uint32_t flags_copy = flags;
+ if (gradInput.Size() <= 0) {
+ flags_copy = (flags_copy & ~WRITE_DATA_FLAG);
+ }
+ if (tensors.gradWeight.numElements() <= 0) {
+ flags_copy = (flags_copy & ~WRITE_GAMMA_FLAG);
+ }
+ if (tensors.gradBias.numElements() <= 0) {
+ flags_copy = (flags_copy & ~WRITE_BETA_FLAG);
+ }
+ AccReal* gamma = ((flags & FIX_GAMMA_FLAG) == 0 && tensors.weight.numElements() > 0)
+ ? tensors.weight.dptr_
+ : nullptr;
+
+ if (param.axis == -1 || param.axis == in_data[batchnorm::kData].shape_.ndim() - 1) {
+ const int C = gradOutput.ChannelCount();
+ int ltype = mxnet::common::cuda::get_load_type(C * sizeof(DType));
+ const int M = gradOutput.OuterSize();
+ MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
+ const unsigned int blocks_x = ceil_div(C * sizeof(DType),
+ mxnet::common::cuda::warp_size * sizeof(LType));
+ const unsigned int preferred_number_of_blocks = 2 *
+ MultiprocessorCount(ctx.run_ctx.ctx.dev_id);
+ const unsigned int blocks_y = std::max(preferred_number_of_blocks / blocks_x, 1u);
+ const dim3 n_blocks = {blocks_x, blocks_y, 1};
+ auto scratch_space = ctx.requested[batchnorm::kTempSpace]
+ .get_space_typed<gpu, 1, AccReal>(mshadow::Shape1(C * blocks_y * 2),
+ s);
+ auto stream = mshadow::Stream<gpu>::GetStream(s);
+ constexpr int nthreads_phase1 = 512;
+ constexpr int nthreads_phase2 = 128;
+ FrozenBatchNormalizationBackwardKernelCLastPhase1<nthreads_phase1, DType, AccReal, LType>
+ <<<n_blocks, nthreads_phase1, 0, stream>>>(input.dptr_, gradOutput.dptr_,
+ scratch_space.dptr_,
+ gradInput.dptr_,
+ gamma,
+ tensors.runningMean.dptr_,
+ tensors.runningVar.dptr_,
+ M, C, eps, flags_copy);
+ const int nblocks_phase2 = ceil_div(C, nthreads_phase2);
+ FrozenBatchNormalizationBackwardKernelCLastPhase2<AccReal>
+ <<<nblocks_phase2, nthreads_phase2, 0, stream>>>(scratch_space.dptr_,
+ tensors.runningVar.dptr_,
+ tensors.gradWeight.dptr_,
+ tensors.gradBias.dptr_, C,
+ blocks_y, eps, flags_copy);
+ });
+ } else {
+ dim3 blocks(gradOutput.ChannelCount());
+ int ltype = mxnet::common::cuda::get_load_type(gradOutput.InnerSize() * sizeof(DType));
+ MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
+ constexpr int nvec = sizeof(LType) > sizeof(DType) ? sizeof(LType) / sizeof(DType) : 1;
+ const index_t NHW_div_nvec = gradOutput.OuterSize() * gradOutput.InnerSize() / nvec;
+ constexpr int threads = 512;
+ FrozenBatchNormalizationBackwardKernel<threads, DType, AccReal, LType>
+ <<< blocks, threads, 0, mshadow::Stream<gpu>::GetStream(s) >>> (
+ input.dptr_, gradOutput.dptr_, gradInput.dptr_,
+ tensors.gradWeight.dptr_, tensors.gradBias.dptr_,
+ gamma, tensors.runningMean.dptr_,
+ tensors.runningVar.dptr_,
+ gradOutput.OuterSize(), gradOutput.InnerSize(),
+ gradOutput.ChannelCount(), NHW_div_nvec, eps, flags_copy);
+ });
+ }
+ }
MSHADOW_CUDA_POST_KERNEL_CHECK(BatchNormalizationBackward);
}
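Note (not part of the diff): a worked example of the channels-last grid sizing above, with hypothetical device properties and shapes. Assuming C = 64 half-precision channels (sizeof(DType) = 2) and get_load_type selecting an 8-byte LType, on a GPU with 80 SMs:

// Worked grid-sizing example (hypothetical):
//   blocks_x = ceil_div(C * sizeof(DType), warp_size * sizeof(LType))
//            = ceil_div(64 * 2, 32 * 8) = ceil_div(128, 256) = 1
//   preferred_number_of_blocks = 2 * 80 = 160
//   blocks_y = max(160 / 1, 1) = 160
//   scratch elements = C * blocks_y * 2 = 64 * 160 * 2 = 20480 AccReal values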
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index 9a67e82..ee27006 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -350,7 +350,7 @@ __global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, ITyp
__syncthreads();
}
if (my_id < warp_size) {
- AType my_value = warp_reduce(scratch[threadIdx.x],
+ AType my_value = common::cuda::warp_reduce(scratch[threadIdx.x],
[](AType x, AType y) { return ::max(x, y); });
scratch[threadIdx.x] = my_value;
}
@@ -374,7 +374,7 @@ __global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, ITyp
__syncthreads();
}
if (my_id < warp_size) {
- AType my_value = warp_reduce(scratch[threadIdx.x],
+ AType my_value = common::cuda::warp_reduce(scratch[threadIdx.x],
[](AType x, AType y) { return x + y;});
scratch[threadIdx.x] = my_value;
}
@@ -488,7 +488,7 @@ __global__ void softmax_stride1_grad_kernel(const OType *out, const OType *ograd
__syncthreads();
}
if (my_id < warp_size) {
- AType my_value = warp_reduce(scratch[threadIdx.x],
+ AType my_value = common::cuda::warp_reduce(scratch[threadIdx.x],
[](AType x, AType y) { return x + y; });
scratch[threadIdx.x] = my_value;
}