You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by pt...@apache.org on 2021/09/30 17:12:03 UTC
[incubator-mxnet] branch master updated: Fast cuDNN BatchNorm NHWC kernels support (#20615)
This is an automated email from the ASF dual-hosted git repository.
ptrendx pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 23af413 Fast cuDNN BatchNorm NHWC kernels support (#20615)
23af413 is described below
commit 23af413d848c7a2d1cd6b20c80f9d4d8346c180a
Author: Vladimir Cherepanov <56...@users.noreply.github.com>
AuthorDate: Thu Sep 30 10:10:18 2021 -0700
Fast cuDNN BatchNorm NHWC kernels support (#20615)
* Fast cuDNN NHWC kernels support
* Fix lint errors
* Get rid of a warning
* Remove CuDNNBatchNorm from AMP lists
Co-authored-by: Vladimir Cherepanov <vc...@nvidia.com>
---
python/mxnet/amp/lists/symbol_fp16.py | 5 -
src/operator/nn/batch_norm.cc | 2 +-
src/operator/nn/batch_norm.cu | 32 +--
src/operator/nn/cudnn/cudnn_batch_norm-inl.h | 307 ---------------------------
src/operator/nn/cudnn/cudnn_batch_norm.cc | 125 -----------
src/operator/nn/cudnn/cudnn_batch_norm.cu | 210 ++++++++++++++++++
src/operator/nn/cudnn/cudnn_batch_norm.h | 56 +++++
7 files changed, 274 insertions(+), 463 deletions(-)
diff --git a/python/mxnet/amp/lists/symbol_fp16.py b/python/mxnet/amp/lists/symbol_fp16.py
index d942051..009586e 100644
--- a/python/mxnet/amp/lists/symbol_fp16.py
+++ b/python/mxnet/amp/lists/symbol_fp16.py
@@ -459,11 +459,6 @@ FP16_FP32_FUNCS = [
'zeros_like',
]
-if Features().is_enabled('CUDNN'):
- FP16_FP32_FUNCS.extend([
- 'CuDNNBatchNorm',
- ])
-
# Functions that have to be cast to FP32 due to possible
# overflows
FP32_FUNCS = [
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index fb12180..5a18363 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -649,11 +649,11 @@ then set ``gamma`` to 1 and its gradient to 0.
.set_attr<nnvm::FGradient>("FGradient", BatchNormGrad)
#if MXNET_USE_ONEDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
+#endif
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
-#endif
.add_argument("data", "NDArray-or-Symbol", "Input data to batch normalization")
.add_argument("gamma", "NDArray-or-Symbol", "gamma array")
.add_argument("beta", "NDArray-or-Symbol", "beta array")
diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu
index 7807691..195423b 100644
--- a/src/operator/nn/batch_norm.cu
+++ b/src/operator/nn/batch_norm.cu
@@ -39,7 +39,7 @@
#define ADDTO_BETA_FLAG (1 << 8)
#if MXNET_USE_CUDNN == 1
-#include "./cudnn/cudnn_batch_norm-inl.h"
+#include "./cudnn/cudnn_batch_norm.h"
#endif
#include "../../../include/mxnet/tensor_blob.h"
@@ -935,11 +935,6 @@ static void BatchNormalizationBackward(mshadow::Stream<gpu>* s,
(flags & IS_TRAINING_FLAG) != 0 && (flags & USE_GLOBAL_STATS_FLAG) == 0;
if (is_train_and_not_global_stats) {
-#ifdef NDEBUG
- constexpr bool SMALLER_THREADS = false;
-#else
- constexpr bool SMALLER_THREADS = true;
-#endif
dim3 blocks(gradOutput.ChannelCount());
dim3 threads(batchnorm::cuda::getNumThreads(gradOutput.InnerSize()));
BatchNormalizationBackwardKernel<DType, AccReal, DeviceTensor1, batchnorm::BNTensor3<DType>>
@@ -1104,19 +1099,6 @@ void BatchNormBackwardImpl(mshadow::Stream<gpu>* stream,
MSHADOW_CUDA_POST_KERNEL_CHECK(BatchNormOp_DoBackward_gpu);
}
-#if MXNET_USE_CUDNN == 1
-template <typename DType>
-static CuDNNBatchNormOp<DType>& GetCuDNNOp(const BatchNormParam& param) {
-#if DMLC_CXX11_THREAD_LOCAL
- static thread_local CuDNNBatchNormOp<DType> op;
-#else
- static MX_THREAD_LOCAL CuDNNBatchNormOp<DType> op;
-#endif
- op.Init(param);
- return op;
-}
-#endif
-
template <>
void BatchNormCompute<gpu>(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
@@ -1132,9 +1114,9 @@ void BatchNormCompute<gpu>(const nnvm::NodeAttrs& attrs,
param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
#if MXNET_USE_CUDNN == 1
- if (!param.use_global_stats && !param.cudnn_off) {
- MSHADOW_REAL_TYPE_SWITCH(
- dtype, DType, { GetCuDNNOp<DType>(param).Forward(ctx, in_data, req, outputs, aux_states); })
+ if (!param.use_global_stats && !param.cudnn_off &&
+ CudnnBatchNormSupports(param, inputs[batchnorm::kData])) {
+ CudnnBatchNormForward(param, ctx, inputs, req, outputs);
} else {
MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DType, AccReal, {
BatchNormForward<gpu, DType, AccReal>(ctx, param, in_data, req, outputs, aux_states);
@@ -1160,9 +1142,9 @@ void BatchNormGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
#if MXNET_USE_CUDNN == 1
- if (!param.use_global_stats && !param.cudnn_off) {
- MSHADOW_REAL_TYPE_SWITCH(
- dtype, DType, { GetCuDNNOp<DType>(param).Backward(ctx, inputs, req, outputs); })
+ if (!param.use_global_stats && !param.cudnn_off &&
+ CudnnBatchNormSupports(param, inputs[3 + batchnorm::kData])) {
+ CudnnBatchNormBackward(param, ctx, inputs, req, outputs);
} else {
MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DType, AccReal, {
BatchNormBackward<gpu, DType, AccReal>(ctx, param, inputs, req, outputs);
diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
deleted file mode 100644
index 0f79430..0000000
--- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
+++ /dev/null
@@ -1,307 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file cudnn_batch_norm-inl.h
- * \brief
- * \author Junyuan Xie
- */
-
-#ifndef MXNET_OPERATOR_NN_CUDNN_CUDNN_BATCH_NORM_INL_H_
-#define MXNET_OPERATOR_NN_CUDNN_CUDNN_BATCH_NORM_INL_H_
-#include <vector>
-#include <map>
-#include <string>
-#include <utility>
-#include "../batch_norm-inl.h"
-
-namespace mxnet {
-namespace op {
-#if MXNET_USE_CUDNN == 1
-namespace cudnnbatchnorm {
-enum CuDNNBatchNormOpInputs { kData, kGamma, kBeta };
-enum CuDNNBatchNormOpOutputs { kOut, kMean, kInvVar };
-enum CuDNNBatchNormOpAuxiliary { kMovingMean, kMovingInvVar };
-} // namespace cudnnbatchnorm
-
-#if defined(__CUDACC__)
-template <typename DType>
-class CuDNNBatchNormOp {
- STATIC_ASSERT_CUDNN_VERSION_GE(5000);
-
- public:
- CuDNNBatchNormOp() {
- using namespace mshadow;
- dtype_ = DataType<DType>::kCudnnFlag;
- // For float16 input type beta, gamma, mean, and average are stored in float32.
- // For other input types, these parameters have the same type as input
- dtype_param_ = (dtype_ == CUDNN_DATA_HALF) ? kFloat32 : DataType<DType>::kFlag;
- CUDNN_CALL(cudnnCreateTensorDescriptor(&io_desc_));
- CUDNN_CALL(cudnnCreateTensorDescriptor(&mean_desc_));
- internal_aux_states_lock_ = false;
- }
-
- void Init(const BatchNormParam& param) {
- CHECK_GE(param.eps, CUDNN_BN_MIN_EPSILON)
- << "CuDNN requires eps to be no less than " << CUDNN_BN_MIN_EPSILON;
- this->param_ = param;
- }
-
- ~CuDNNBatchNormOp() {
- CUDNN_CALL(cudnnDestroyTensorDescriptor(io_desc_));
- CUDNN_CALL(cudnnDestroyTensorDescriptor(mean_desc_));
- }
-
- void Forward(const OpContext& ctx,
- const std::vector<TBlob>& in_data,
- const std::vector<OpReqType>& req,
- const std::vector<TBlob>& out_data,
- const std::vector<TBlob>& aux_states) {
- using namespace mshadow;
- using namespace mshadow::expr;
- CHECK_EQ(in_data.size(), 3U);
- CHECK_EQ(aux_states.size(), 2U);
- if (ctx.is_train) {
- CHECK_EQ(out_data.size(), 3U);
- CHECK_EQ(req.size(), 3U);
- } else {
- CHECK_GE(out_data.size(), 1U);
- CHECK_GE(req.size(), 1U);
- }
- CHECK_EQ(req[cudnnbatchnorm::kOut], kWriteTo);
- CHECK_GE(in_data[cudnnbatchnorm::kData].ndim(), 2);
-
- Init(in_data[cudnnbatchnorm::kData]);
- Stream<gpu>* s = ctx.get_stream<gpu>();
- Tensor<gpu, 4, DType> x =
- in_data[cudnnbatchnorm::kData].get_with_shape<gpu, 4, DType>(shape_, s);
-
- Tensor<gpu, 4, DType> y =
- out_data[cudnnbatchnorm::kOut].get_with_shape<gpu, 4, DType>(shape_, s);
-#if CUDNN_VERSION >= 7002
- auto mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
-#else
- auto mode = CUDNN_BATCHNORM_SPATIAL;
-#endif
-
- MSHADOW_REAL_TYPE_SWITCH(dtype_param_, DTypeParam, {
- Tensor<gpu, 1, DTypeParam> gamma =
- in_data[cudnnbatchnorm::kGamma].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
- Tensor<gpu, 1, DTypeParam> beta =
- in_data[cudnnbatchnorm::kBeta].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
- Tensor<gpu, 1, DTypeParam> moving_mean =
- aux_states[cudnnbatchnorm::kMovingMean].get_with_shape<gpu, 1, DTypeParam>(
- Shape1(shape_[1]), s);
- Tensor<gpu, 1, DTypeParam> moving_inv_var =
- aux_states[cudnnbatchnorm::kMovingInvVar].get_with_shape<gpu, 1, DTypeParam>(
- Shape1(shape_[1]), s);
- typename DataType<DType>::ScaleType a = 1.0f;
- typename DataType<DType>::ScaleType b = 0.0f;
-
- if (param_.fix_gamma)
- gamma = 1.f;
-
- if (ctx.is_train) {
- Tensor<gpu, 1, DTypeParam> save_mean =
- out_data[cudnnbatchnorm::kMean].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]),
- s);
- Tensor<gpu, 1, DTypeParam> save_inv_var =
- out_data[cudnnbatchnorm::kInvVar].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]),
- s);
- // If the lock on the auxiliary states is set, then this implies that
- // the preceding call is also a `Forward()` call, which further
- // indicates that we are in the backward mirroring mode, and therefore
- // update to the auxiliary states is disabled. This is done by setting
- // the `momentum` to `1` (or `factor` to `0`).
- float factor =
- ((dmlc::GetEnv("MXNET_BACKWARD_DO_MIRROR", 0) || dmlc::GetEnv("MXNET_MEMORY_OPT", 0)) &&
- internal_aux_states_lock_)
- ? 0
- : (1 - param_.momentum);
- CUDNN_CALL(cudnnBatchNormalizationForwardTraining(s->dnn_handle_,
- mode,
- &a,
- &b,
- io_desc_,
- x.dptr_,
- io_desc_,
- y.dptr_,
- mean_desc_,
- gamma.dptr_,
- beta.dptr_,
- factor,
- moving_mean.dptr_,
- moving_inv_var.dptr_,
- param_.eps,
- save_mean.dptr_,
- save_inv_var.dptr_));
- } else {
- CUDNN_CALL(cudnnBatchNormalizationForwardInference(s->dnn_handle_,
- CUDNN_BATCHNORM_SPATIAL,
- &a,
- &b,
- io_desc_,
- x.dptr_,
- io_desc_,
- y.dptr_,
- mean_desc_,
- gamma.dptr_,
- beta.dptr_,
- moving_mean.dptr_,
- moving_inv_var.dptr_,
- param_.eps));
- }
- })
- // Set the lock on the auxiliary states.
- // If the next call to the operator is a `Forward()` call,
- // then `momentum` will be set to `1` and hence auxiliary states will not be updated.
- internal_aux_states_lock_ = true;
- }
-
- void Backward(const OpContext& ctx,
- const std::vector<TBlob>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<TBlob>& outputs) {
- using namespace mshadow;
- using namespace mshadow::expr;
- CHECK_EQ(inputs.size(), 8U);
- CHECK_EQ(outputs.size(), 3U);
-
- // Rename the inputs and outputs.
- const TBlob& out_grad = inputs[0];
- const TBlob& out_mean = inputs[1];
- const TBlob& out_var = inputs[2];
- const TBlob& in_data = inputs[3];
- const TBlob& in_gamma = inputs[4];
- const std::vector<TBlob>& in_grad = outputs;
-
- Init(in_data);
- Stream<gpu>* s = ctx.get_stream<gpu>();
- Tensor<gpu, 4, DType> x = in_data.get_with_shape<gpu, 4, DType>(shape_, s);
- Tensor<gpu, 4, DType> dx =
- in_grad[cudnnbatchnorm::kData].get_with_shape<gpu, 4, DType>(shape_, s);
- Tensor<gpu, 4, DType> dy = out_grad.get_with_shape<gpu, 4, DType>(shape_, s);
-
- const bool global_stats = !ctx.is_train || param_.use_global_stats;
-
-#if CUDNN_VERSION >= 7002
- auto mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
-#else
- auto mode = CUDNN_BATCHNORM_SPATIAL;
-#endif
- MSHADOW_REAL_TYPE_SWITCH(dtype_param_, DTypeParam, {
- Tensor<gpu, 1, DTypeParam> gamma =
- in_gamma.get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
- Tensor<gpu, 1, DTypeParam> dbeta =
- in_grad[cudnnbatchnorm::kBeta].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
- Tensor<gpu, 1, DTypeParam> dgamma =
- in_grad[cudnnbatchnorm::kGamma].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
- Tensor<gpu, 1, DTypeParam> save_mean =
- out_mean.get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
- Tensor<gpu, 1, DTypeParam> save_inv_var =
- out_var.get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
-
- typename DataType<DType>::ScaleType a = 1.0f;
- typename DataType<DType>::ScaleType b = 0.0f;
- typename DataType<DType>::ScaleType b_add = 1.0f;
- CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
-
- if (param_.fix_gamma)
- gamma = 1.f;
-
- bool grad_add_gamma_beta =
- (req[cudnnbatchnorm::kGamma] == kAddTo) || (req[cudnnbatchnorm::kBeta] == kAddTo);
- if (grad_add_gamma_beta) {
- if (IsBNWriting(req[cudnnbatchnorm::kGamma])) {
- dgamma = 0.f;
- }
- if (IsBNWriting(req[cudnnbatchnorm::kBeta])) {
- dbeta = 0.f;
- }
- }
-
- CUDNN_CALL(
- cudnnBatchNormalizationBackward(s->dnn_handle_,
- mode,
- &a,
- req[cudnnbatchnorm::kData] == kAddTo ? &b_add : &b,
- &a,
- grad_add_gamma_beta ? &b_add : &b, // gamma and beta
- io_desc_,
- x.dptr_,
- io_desc_,
- dy.dptr_,
- io_desc_,
- dx.dptr_,
- mean_desc_,
- gamma.dptr_,
- dgamma.dptr_,
- dbeta.dptr_,
- param_.eps,
- global_stats ? nullptr : save_mean.dptr_,
- global_stats ? nullptr : save_inv_var.dptr_));
- if (param_.fix_gamma)
- dgamma = 0.f;
- })
- // Release the lock on the auxiliary states, so that the next forward pass
- // will be able to update the auxiliary states normally.
- internal_aux_states_lock_ = false;
- }
-
- private:
- void Init(const TBlob& in_data) {
- CHECK_GE(param_.axis, 0);
- CHECK_LT(param_.axis, in_data.ndim());
- if (param_.axis == 1) {
- if (in_data.ndim() == 4) {
- for (int i = 0; i < 4; ++i)
- shape_[i] = in_data.shape_[i];
- } else {
- // when in_data.ndim() != 4
- shape_[0] = in_data.shape_[0];
- shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1;
- shape_[2] = 1;
- shape_[3] = static_cast<dim_t>(in_data.shape_.ProdShape(2, in_data.ndim()));
- }
- } else {
- // reshape to (N, C, 1, D), C is the `param_.axis` dimension
- shape_[0] = static_cast<dim_t>(in_data.shape_.ProdShape(0, param_.axis));
- shape_[1] = in_data.shape_[param_.axis];
- shape_[2] = 1;
- shape_[3] = static_cast<dim_t>(in_data.shape_.ProdShape(param_.axis + 1, in_data.ndim()));
- }
-
- CUDNN_CALL(cudnnSetTensor4dDescriptor(
- io_desc_, CUDNN_TENSOR_NCHW, dtype_, shape_[0], shape_[1], shape_[2], shape_[3]));
- CUDNN_CALL(cudnnDeriveBNTensorDescriptor(mean_desc_, io_desc_, CUDNN_BATCHNORM_SPATIAL));
- }
-
- cudnnDataType_t dtype_;
- int dtype_param_;
- cudnnTensorDescriptor_t io_desc_, mean_desc_;
- mshadow::Shape<4> shape_;
- BatchNormParam param_;
- bool internal_aux_states_lock_;
-};
-#endif // defined(__CUDACC__)
-
-#endif // MXNET_USE_CUDNN == 1
-} // namespace op
-} // namespace mxnet
-#endif // MXNET_OPERATOR_NN_CUDNN_CUDNN_BATCH_NORM_INL_H_
diff --git a/src/operator/nn/cudnn/cudnn_batch_norm.cc b/src/operator/nn/cudnn/cudnn_batch_norm.cc
deleted file mode 100644
index 5ea46f2..0000000
--- a/src/operator/nn/cudnn/cudnn_batch_norm.cc
+++ /dev/null
@@ -1,125 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file cudnn_batch_norm.cc
- * \brief
- * \author Junyuan Xie, Da Zheng
- */
-
-#include "./cudnn_batch_norm-inl.h"
-#include <nnvm/op_attr_types.h>
-#include "../../elemwise_op_common.h"
-
-namespace mxnet {
-namespace op {
-#if MXNET_USE_CUDNN == 1
-
-static bool BatchNormShape(const nnvm::NodeAttrs& attrs,
- mxnet::ShapeVector* in_shape,
- mxnet::ShapeVector* out_shape) {
- using namespace mshadow;
- CHECK_EQ(in_shape->size(), 5U) << "Input:[data, gamma, beta, moving_mean, moving_var]";
- const mxnet::TShape& dshape = in_shape->at(0);
- if (!mxnet::ndim_is_known(dshape))
- return false;
- in_shape->at(1) = mxnet::TShape(Shape1(dshape[1]));
- in_shape->at(2) = mxnet::TShape(Shape1(dshape[1]));
- in_shape->at(3) = mxnet::TShape(Shape1(dshape[1]));
- in_shape->at(4) = mxnet::TShape(Shape1(dshape[1]));
-
- out_shape->clear();
- out_shape->push_back(dshape);
- out_shape->push_back(Shape1(dshape[1]));
- out_shape->push_back(Shape1(dshape[1]));
-
- return true;
-}
-
-static void BatchNormCompute_CPU(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<TBlob>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<TBlob>& outputs) {
- LOG(FATAL) << "CuDNNBatchNormOp is only available for gpu.";
-}
-
-static void BatchNormGradCompute_CPU(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<TBlob>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<TBlob>& outputs) {
- LOG(FATAL) << "CuDNNBatchNormOp is only available for gpu.";
-}
-
-NNVM_REGISTER_OP(CuDNNBatchNorm)
- .describe("Apply batch normalization to input.")
- .set_num_inputs(5)
- .set_num_outputs(3)
- .set_attr_parser(ParamParser<BatchNormParam>)
- .set_attr<nnvm::FListInputNames>(
- "FListInputNames",
- [](const NodeAttrs& attrs) {
- return std::vector<std::string>{"data", "gamma", "beta", "moving_mean", "moving_var"};
- })
- .set_attr<nnvm::FListOutputNames>("FListOutputNames",
- [](const NodeAttrs& attrs) {
- return std::vector<std::string>{"output", "mean", "var"};
- })
- .set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs",
- [](const NodeAttrs& attrs) { return 1; })
- .set_attr<nnvm::FMutateInputs>("FMutateInputs",
- [](const nnvm::NodeAttrs& attrs) {
- return std::vector<uint32_t>{3, 4};
- })
- .set_attr<mxnet::FInferShape>("FInferShape", BatchNormShape)
- .set_attr<FCompute>("FCompute<cpu>", BatchNormCompute_CPU)
- .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseInOut{"_backward_CuDNNBatchNorm"})
- .add_argument("data", "NDArray-or-Symbol", "Input data to batch normalization")
- .add_argument("gamma", "NDArray-or-Symbol", "gamma array")
- .add_argument("beta", "NDArray-or-Symbol", "beta array")
- .add_argument("moving_mean", "NDArray-or-Symbol", "running mean of input")
- .add_argument("moving_var", "NDArray-or-Symbol", "running variance of input")
- .add_arguments(BatchNormParam::__FIELDS__())
- .set_attr<nnvm::FSetInputVarAttrOnCompose>(
- "FSetInputVarAttrOnCompose",
- [](const nnvm::NodeAttrs& attrs, nnvm::ObjectPtr var, const int index) {
- if (var->attrs.dict.find("__init__") != var->attrs.dict.end())
- return;
- if (index == 3) {
- var->attrs.dict["__init__"] = "[\"zero\", {}]";
- } else if (index == 4) {
- var->attrs.dict["__init__"] = "[\"one\", {}]";
- }
- });
-
-NNVM_REGISTER_OP(_backward_CuDNNBatchNorm)
- .set_num_outputs(5)
- .set_attr<nnvm::FMutateInputs>("FMutateInputs",
- [](const nnvm::NodeAttrs& attrs) {
- return std::vector<uint32_t>{6, 7};
- })
- .set_attr<nnvm::TIsBackward>("TIsBackward", true)
- .set_attr_parser(ParamParser<BatchNormParam>)
- .set_attr<FCompute>("FCompute<cpu>", BatchNormGradCompute_CPU);
-
-#endif // MXNET_USE_CUDNN
-
-} // namespace op
-} // namespace mxnet
diff --git a/src/operator/nn/cudnn/cudnn_batch_norm.cu b/src/operator/nn/cudnn/cudnn_batch_norm.cu
new file mode 100644
index 0000000..bed274f
--- /dev/null
+++ b/src/operator/nn/cudnn/cudnn_batch_norm.cu
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file cudnn_batch_norm.cu
+ * \brief
+ * \author Junyuan Xie, Da Zheng
+ */
+
+#include "cudnn_batch_norm.h"
+
+#include "../../../common/cuda/utils.h"
+
+namespace mxnet {
+namespace op {
+
+#if MXNET_USE_CUDNN == 1
+
+namespace {
+
+struct Globals {
+ cudnnTensorDescriptor_t io_desc;
+ cudnnTensorDescriptor_t mean_desc;
+ bool internal_aux_states_lock = false;
+
+ static Globals& Get() {
+ thread_local Globals ret;
+ return ret;
+ }
+
+ Globals() {
+ CUDNN_CALL(cudnnCreateTensorDescriptor(&io_desc));
+ CUDNN_CALL(cudnnCreateTensorDescriptor(&mean_desc));
+ }
+
+ ~Globals() {
+ CUDNN_CALL(cudnnDestroyTensorDescriptor(io_desc));
+ CUDNN_CALL(cudnnDestroyTensorDescriptor(mean_desc));
+ }
+};
+
+void SetDescriptors(const BatchNormParam& param, const TBlob& x) {
+ CHECK_GE(x.shape_.ndim(), 3);
+ CHECK(param.axis == 1 || param.axis == x.shape_.ndim() - 1);
+
+ cudnnTensorFormat_t format = param.axis == 1 ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
+ int n = x.shape_[0];
+ int c = x.shape_[param.axis];
+ size_t last_spatial_i = param.axis == 1 ? x.shape_.ndim() - 1 : x.shape_.ndim() - 2;
+ int w = x.shape_[last_spatial_i];
+ int h = x.shape_.ProdShape(last_spatial_i - (x.shape_.ndim() - 3), last_spatial_i);
+
+ MSHADOW_REAL_TYPE_SWITCH(x.type_flag_, DType, {
+ CUDNN_CALL(cudnnSetTensor4dDescriptor(Globals::Get().io_desc, format,
+ mshadow::DataType<DType>::kCudnnFlag, n, c, h, w));
+ })
+ CUDNN_CALL(cudnnDeriveBNTensorDescriptor(Globals::Get().mean_desc, Globals::Get().io_desc,
+ CUDNN_BATCHNORM_SPATIAL));
+}
+
+mshadow::TypeFlag ParamType(int x_type) {
+ auto xt = static_cast<mshadow::TypeFlag>(x_type);
+ return xt == mshadow::kFloat16 ? mshadow::kFloat32 : xt;
+}
+
+} // namespace
+
+bool CudnnBatchNormSupports(const BatchNormParam& param, const TBlob& x) {
+ int n = x.shape_.ndim();
+ return n >= 3 && (param.axis == 1 || param.axis == n - 1);
+}
+
+void CudnnBatchNormForward(const BatchNormParam& param, const OpContext& ctx,
+ const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs) {
+ CHECK_EQ(inputs.size(), 5);
+ if (ctx.is_train) {
+ CHECK_EQ(outputs.size(), 3);
+ CHECK_EQ(req.size(), 3);
+ } else {
+ CHECK_GE(outputs.size(), 1);
+ CHECK_GE(req.size(), 1);
+ }
+ CHECK_EQ(req[batchnorm::kOut], kWriteTo);
+ CHECK_GE(inputs[batchnorm::kData].ndim(), 2);
+
+ SetDescriptors(param, inputs[batchnorm::kData]);
+
+ auto s = ctx.get_stream<gpu>();
+ MSHADOW_REAL_TYPE_SWITCH(ParamType(inputs[batchnorm::kData].type_flag_), DType, {
+ DType a = 1.0f;
+ DType b = 0.0f;
+ if (param.fix_gamma) inputs[batchnorm::kGamma].FlatTo1D<gpu, DType>(s) = 1.0f;
+ if (ctx.is_train) {
+ size_t workspace_size = 0;
+ CUDNN_CALL(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
+ s->dnn_handle_, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CUDNN_BATCHNORM_OPS_BN,
+ Globals::Get().io_desc, nullptr, Globals::Get().io_desc, Globals::Get().mean_desc,
+ nullptr, &workspace_size));
+ auto workspace = ctx.requested[0].get_space_internal(workspace_size, "CudnnBatchNormForward");
+
+ // If the lock on the auxiliary states is set, then this implies that
+ // the preceding call is also a `Forward()` call, which further
+ // indicates that we are in the backward mirroring mode, and therefore
+ // update to the auxiliary states is disabled. This is done by setting
+ // the `momentum` to `1` (or `factor` to `0`).
+ double factor =
+ ((dmlc::GetEnv("MXNET_BACKWARD_DO_MIRROR", 0) || dmlc::GetEnv("MXNET_MEMORY_OPT", 0)) &&
+ Globals::Get().internal_aux_states_lock)
+ ? 0
+ : (1 - param.momentum);
+ CUDNN_CALL(cudnnBatchNormalizationForwardTrainingEx(
+ s->dnn_handle_, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CUDNN_BATCHNORM_OPS_BN, &a, &b,
+ Globals::Get().io_desc, inputs[batchnorm::kData].dptr_,
+ nullptr, nullptr, // zDesc, zData
+ Globals::Get().io_desc, outputs[batchnorm::kOut].dptr_,
+ Globals::Get().mean_desc,
+ inputs[batchnorm::kGamma].dptr_, inputs[batchnorm::kBeta].dptr_,
+ factor, inputs[batchnorm::kInMovingMean].dptr_, inputs[batchnorm::kInMovingVar].dptr_,
+ param.eps, outputs[batchnorm::kMean].dptr_, outputs[batchnorm::kVar].dptr_,
+ nullptr, // activation desc
+ workspace, workspace_size,
+ nullptr, 0)); // reserveSpace, reserveSpaceSizeInBytes
+ } else {
+ CUDNN_CALL(cudnnBatchNormalizationForwardInference(
+ s->dnn_handle_, CUDNN_BATCHNORM_SPATIAL, &a, &b,
+ Globals::Get().io_desc, inputs[batchnorm::kData].dptr_,
+ Globals::Get().io_desc, outputs[batchnorm::kOut].dptr_,
+ Globals::Get().mean_desc,
+ inputs[batchnorm::kGamma].dptr_, inputs[batchnorm::kBeta].dptr_,
+ inputs[batchnorm::kInMovingMean].dptr_, inputs[batchnorm::kInMovingVar].dptr_,
+ param.eps));
+ }
+ })
+ // Set the lock on the auxiliary states.
+ // If the next call to the operator is a `Forward()` call,
+ // then `momentum` will be set to `1` and hence auxiliary states will not be updated.
+ Globals::Get().internal_aux_states_lock = true;
+}
+
+void CudnnBatchNormBackward(const BatchNormParam& param, const OpContext& ctx,
+ const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs) {
+ CHECK_EQ(inputs.size(), 8);
+ CHECK_EQ(outputs.size(), 3);
+ CHECK_EQ(req.size(), 3);
+
+ SetDescriptors(param, inputs[3 + batchnorm::kData]);
+ auto s = ctx.get_stream<gpu>();
+ size_t workspace_size = 0;
+ CUDNN_CALL(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
+ s->dnn_handle_, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CUDNN_BATCHNORM_OPS_BN,
+ Globals::Get().io_desc, Globals::Get().io_desc, Globals::Get().io_desc, nullptr,
+ Globals::Get().io_desc, Globals::Get().mean_desc, nullptr, &workspace_size));
+ auto workspace = ctx.requested[0].get_space_internal(workspace_size, "CudnnBatchNormBackward");
+ MSHADOW_REAL_TYPE_SWITCH(ParamType(inputs[3 + batchnorm::kData].type_flag_), DType, {
+ if (param.fix_gamma) inputs[3 + batchnorm::kGamma].FlatTo1D<gpu, DType>(s) = 1.0f;
+ bool grad_add_gamma_beta = req[batchnorm::kGamma] == kAddTo || req[batchnorm::kBeta] == kAddTo;
+ if (grad_add_gamma_beta) {
+ if (IsBNWriting(req[batchnorm::kGamma]))
+ outputs[batchnorm::kGamma].FlatTo1D<gpu, DType>(s) = 0.0f;
+ if (IsBNWriting(req[batchnorm::kBeta]))
+ outputs[batchnorm::kBeta].FlatTo1D<gpu, DType>(s) = 0.0f;
+ }
+ DType a = 1.0f;
+ DType b = 0.0f;
+ DType b_add = 1.0f;
+ const bool global_stats = !ctx.is_train || param.use_global_stats;
+ CUDNN_CALL(cudnnBatchNormalizationBackwardEx(
+ s->dnn_handle_, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CUDNN_BATCHNORM_OPS_BN,
+ &a, req[batchnorm::kData] == kAddTo ? &b_add : &b,
+ &a, grad_add_gamma_beta ? &b_add : &b,
+ Globals::Get().io_desc, inputs[3 + batchnorm::kData].dptr_,
+ nullptr, nullptr, // yDesc, yData
+ Globals::Get().io_desc, inputs[batchnorm::kOut].dptr_,
+ nullptr, nullptr, // dzDesc, dzData
+ Globals::Get().io_desc, outputs[batchnorm::kData].dptr_,
+ Globals::Get().mean_desc,
+ inputs[3 + batchnorm::kGamma].dptr_, inputs[3 + batchnorm::kBeta].dptr_,
+ outputs[batchnorm::kGamma].dptr_, outputs[batchnorm::kBeta].dptr_, param.eps,
+ global_stats ? nullptr : inputs[batchnorm::kMean].dptr_,
+ global_stats ? nullptr : inputs[batchnorm::kVar].dptr_,
+ nullptr, // activationDesc
+ workspace, workspace_size,
+ nullptr, 0)); // reserveSpace, reserveSpaceSizeInBytes
+ if (param.fix_gamma) outputs[batchnorm::kGamma].FlatTo1D<gpu, DType>(s) = 0.0f;
+ })
+ Globals::Get().internal_aux_states_lock = false;
+}
+
+#endif // MXNET_USE_CUDNN == 1
+} // namespace op
+} // namespace mxnet
diff --git a/src/operator/nn/cudnn/cudnn_batch_norm.h b/src/operator/nn/cudnn/cudnn_batch_norm.h
new file mode 100644
index 0000000..57249b1
--- /dev/null
+++ b/src/operator/nn/cudnn/cudnn_batch_norm.h
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file cudnn_batch_norm.h
+ * \brief
+ * \author Junyuan Xie
+*/
+
+#ifndef MXNET_OPERATOR_NN_CUDNN_CUDNN_BATCH_NORM_H_
+#define MXNET_OPERATOR_NN_CUDNN_CUDNN_BATCH_NORM_H_
+
+#include <mxnet/base.h>
+#include <vector>
+#include "../batch_norm-inl.h"
+
+namespace mxnet {
+namespace op {
+
+#if MXNET_USE_CUDNN == 1
+
+STATIC_ASSERT_CUDNN_VERSION_GE(7401);
+
+bool CudnnBatchNormSupports(const BatchNormParam& param, const TBlob& x);
+
+void CudnnBatchNormForward(const BatchNormParam& param, const OpContext& ctx,
+ const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs);
+
+void CudnnBatchNormBackward(const BatchNormParam& param, const OpContext& ctx,
+ const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs);
+
+#endif // MXNET_USE_CUDNN == 1
+
+} // namespace op
+} // namespace mxnet
+
+#endif // MXNET_OPERATOR_NN_CUDNN_CUDNN_BATCH_NORM_H_