Posted to commits@mxnet.apache.org by pt...@apache.org on 2021/09/30 17:12:03 UTC

[incubator-mxnet] branch master updated: Fast cuDNN BatchNorm NHWC kernels support (#20615)

This is an automated email from the ASF dual-hosted git repository.

ptrendx pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 23af413  Fast cuDNN BatchNorm NHWC kernels support (#20615)
23af413 is described below

commit 23af413d848c7a2d1cd6b20c80f9d4d8346c180a
Author: Vladimir Cherepanov <56...@users.noreply.github.com>
AuthorDate: Thu Sep 30 10:10:18 2021 -0700

    Fast cuDNN BatchNorm NHWC kernels support (#20615)
    
    * Fast cuDNN NHWC kernels support
    
    * Fix lint errors
    
    * Get rid of a warning
    
    * Remove CuDNNBatchNorm from AMP lists
    
    Co-authored-by: Vladimir Cherepanov <vc...@nvidia.com>
---
 python/mxnet/amp/lists/symbol_fp16.py        |   5 -
 src/operator/nn/batch_norm.cc                |   2 +-
 src/operator/nn/batch_norm.cu                |  32 +--
 src/operator/nn/cudnn/cudnn_batch_norm-inl.h | 307 ---------------------------
 src/operator/nn/cudnn/cudnn_batch_norm.cc    | 125 -----------
 src/operator/nn/cudnn/cudnn_batch_norm.cu    | 210 ++++++++++++++++++
 src/operator/nn/cudnn/cudnn_batch_norm.h     |  56 +++++
 7 files changed, 274 insertions(+), 463 deletions(-)
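
In brief: this commit deletes the per-dtype, thread-local CuDNNBatchNormOp class and the standalone CuDNNBatchNorm NNVM operator, replacing them with three free functions that BatchNorm's GPU compute path calls directly; channels-last (NHWC) inputs now qualify for the fast cuDNN path alongside NCHW. For orientation, the new public surface, as declared in src/operator/nn/cudnn/cudnn_batch_norm.h below:

    bool CudnnBatchNormSupports(const BatchNormParam& param, const TBlob& x);

    void CudnnBatchNormForward(const BatchNormParam& param, const OpContext& ctx,
                               const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
                               const std::vector<TBlob>& outputs);

    void CudnnBatchNormBackward(const BatchNormParam& param, const OpContext& ctx,
                                const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
                                const std::vector<TBlob>& outputs);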

diff --git a/python/mxnet/amp/lists/symbol_fp16.py b/python/mxnet/amp/lists/symbol_fp16.py
index d942051..009586e 100644
--- a/python/mxnet/amp/lists/symbol_fp16.py
+++ b/python/mxnet/amp/lists/symbol_fp16.py
@@ -459,11 +459,6 @@ FP16_FP32_FUNCS = [
     'zeros_like',
     ]
 
-if Features().is_enabled('CUDNN'):
-    FP16_FP32_FUNCS.extend([
-        'CuDNNBatchNorm',
-    ])
-
 # Functions that have to be cast to FP32 due to possible
 # overflows
 FP32_FUNCS = [
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index fb12180..5a18363 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -649,11 +649,11 @@ then set ``gamma`` to 1 and its gradient to 0.
     .set_attr<nnvm::FGradient>("FGradient", BatchNormGrad)
 #if MXNET_USE_ONEDNN == 1
     .set_attr<bool>("TIsMKLDNN", true)
+#endif
     .set_attr<FResourceRequest>("FResourceRequest",
                                 [](const NodeAttrs& n) {
                                   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
                                 })
-#endif
     .add_argument("data", "NDArray-or-Symbol", "Input data to batch normalization")
     .add_argument("gamma", "NDArray-or-Symbol", "gamma array")
     .add_argument("beta", "NDArray-or-Symbol", "beta array")
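
Note that the hunk above moves the #endif so the kTempSpace resource request is registered unconditionally rather than only in oneDNN builds; the new cuDNN path below draws its workspace from that request. A minimal sketch of the pairing (both halves appear verbatim elsewhere in this diff):

    // Registration side (batch_norm.cc): declare that the operator may need
    // scratch memory.
    .set_attr<FResourceRequest>("FResourceRequest",
                                [](const NodeAttrs& n) {
                                  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
                                })

    // Compute side (cudnn_batch_norm.cu): the requested resource surfaces as
    // ctx.requested[0], from which the cuDNN workspace is carved out.
    auto workspace = ctx.requested[0].get_space_internal(workspace_size, "CudnnBatchNormForward");
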
diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu
index 7807691..195423b 100644
--- a/src/operator/nn/batch_norm.cu
+++ b/src/operator/nn/batch_norm.cu
@@ -39,7 +39,7 @@
 #define ADDTO_BETA_FLAG       (1 << 8)
 
 #if MXNET_USE_CUDNN == 1
-#include "./cudnn/cudnn_batch_norm-inl.h"
+#include "./cudnn/cudnn_batch_norm.h"
 #endif
 
 #include "../../../include/mxnet/tensor_blob.h"
@@ -935,11 +935,6 @@ static void BatchNormalizationBackward(mshadow::Stream<gpu>* s,
       (flags & IS_TRAINING_FLAG) != 0 && (flags & USE_GLOBAL_STATS_FLAG) == 0;
 
   if (is_train_and_not_global_stats) {
-#ifdef NDEBUG
-    constexpr bool SMALLER_THREADS = false;
-#else
-    constexpr bool SMALLER_THREADS = true;
-#endif
     dim3 blocks(gradOutput.ChannelCount());
     dim3 threads(batchnorm::cuda::getNumThreads(gradOutput.InnerSize()));
     BatchNormalizationBackwardKernel<DType, AccReal, DeviceTensor1, batchnorm::BNTensor3<DType>>
@@ -1104,19 +1099,6 @@ void BatchNormBackwardImpl(mshadow::Stream<gpu>* stream,
   MSHADOW_CUDA_POST_KERNEL_CHECK(BatchNormOp_DoBackward_gpu);
 }
 
-#if MXNET_USE_CUDNN == 1
-template <typename DType>
-static CuDNNBatchNormOp<DType>& GetCuDNNOp(const BatchNormParam& param) {
-#if DMLC_CXX11_THREAD_LOCAL
-  static thread_local CuDNNBatchNormOp<DType> op;
-#else
-  static MX_THREAD_LOCAL CuDNNBatchNormOp<DType> op;
-#endif
-  op.Init(param);
-  return op;
-}
-#endif
-
 template <>
 void BatchNormCompute<gpu>(const nnvm::NodeAttrs& attrs,
                            const OpContext& ctx,
@@ -1132,9 +1114,9 @@ void BatchNormCompute<gpu>(const nnvm::NodeAttrs& attrs,
 
   param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
 #if MXNET_USE_CUDNN == 1
-  if (!param.use_global_stats && !param.cudnn_off) {
-    MSHADOW_REAL_TYPE_SWITCH(
-        dtype, DType, { GetCuDNNOp<DType>(param).Forward(ctx, in_data, req, outputs, aux_states); })
+  if (!param.use_global_stats && !param.cudnn_off &&
+      CudnnBatchNormSupports(param, inputs[batchnorm::kData])) {
+    CudnnBatchNormForward(param, ctx, inputs, req, outputs);
   } else {
     MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DType, AccReal, {
       BatchNormForward<gpu, DType, AccReal>(ctx, param, in_data, req, outputs, aux_states);
@@ -1160,9 +1142,9 @@ void BatchNormGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
 
   param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
 #if MXNET_USE_CUDNN == 1
-  if (!param.use_global_stats && !param.cudnn_off) {
-    MSHADOW_REAL_TYPE_SWITCH(
-        dtype, DType, { GetCuDNNOp<DType>(param).Backward(ctx, inputs, req, outputs); })
+  if (!param.use_global_stats && !param.cudnn_off &&
+      CudnnBatchNormSupports(param, inputs[3 + batchnorm::kData])) {
+    CudnnBatchNormBackward(param, ctx, inputs, req, outputs);
   } else {
     MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DType, AccReal, {
       BatchNormBackward<gpu, DType, AccReal>(ctx, param, inputs, req, outputs);
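
The backward hunk indexes the forward tensors at inputs[3 + ...] because the gradient node receives the output gradient plus the saved mean and inverse variance ahead of the forward inputs. A sketch of the layout, inferred from this diff:

    // Input layout for BatchNormGradCompute<gpu>, inferred from this diff
    // (the deleted Backward() below names indices 0-4; the 3 + batchnorm::kBeta
    // use and the FMutateInputs{6, 7} registration pin down the rest):
    //   inputs[0] dy (out_grad)      inputs[3] data   (3 + batchnorm::kData)
    //   inputs[1] saved mean         inputs[4] gamma  (3 + batchnorm::kGamma)
    //   inputs[2] saved inv-var      inputs[5] beta   (3 + batchnorm::kBeta)
    //                                inputs[6..7] moving mean / moving var
    const TBlob& x = inputs[3 + batchnorm::kData];  // hence the "3 +" offsets above
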
diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
deleted file mode 100644
index 0f79430..0000000
--- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
+++ /dev/null
@@ -1,307 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file cudnn_batch_norm-inl.h
- * \brief
- * \author Junyuan Xie
- */
-
-#ifndef MXNET_OPERATOR_NN_CUDNN_CUDNN_BATCH_NORM_INL_H_
-#define MXNET_OPERATOR_NN_CUDNN_CUDNN_BATCH_NORM_INL_H_
-#include <vector>
-#include <map>
-#include <string>
-#include <utility>
-#include "../batch_norm-inl.h"
-
-namespace mxnet {
-namespace op {
-#if MXNET_USE_CUDNN == 1
-namespace cudnnbatchnorm {
-enum CuDNNBatchNormOpInputs { kData, kGamma, kBeta };
-enum CuDNNBatchNormOpOutputs { kOut, kMean, kInvVar };
-enum CuDNNBatchNormOpAuxiliary { kMovingMean, kMovingInvVar };
-}  // namespace cudnnbatchnorm
-
-#if defined(__CUDACC__)
-template <typename DType>
-class CuDNNBatchNormOp {
-  STATIC_ASSERT_CUDNN_VERSION_GE(5000);
-
- public:
-  CuDNNBatchNormOp() {
-    using namespace mshadow;
-    dtype_ = DataType<DType>::kCudnnFlag;
-    // For float16 input type beta, gamma, mean, and average are stored in float32.
-    // For other input types, these parameters have the same type as input
-    dtype_param_ = (dtype_ == CUDNN_DATA_HALF) ? kFloat32 : DataType<DType>::kFlag;
-    CUDNN_CALL(cudnnCreateTensorDescriptor(&io_desc_));
-    CUDNN_CALL(cudnnCreateTensorDescriptor(&mean_desc_));
-    internal_aux_states_lock_ = false;
-  }
-
-  void Init(const BatchNormParam& param) {
-    CHECK_GE(param.eps, CUDNN_BN_MIN_EPSILON)
-        << "CuDNN requires eps to be no less than " << CUDNN_BN_MIN_EPSILON;
-    this->param_ = param;
-  }
-
-  ~CuDNNBatchNormOp() {
-    CUDNN_CALL(cudnnDestroyTensorDescriptor(io_desc_));
-    CUDNN_CALL(cudnnDestroyTensorDescriptor(mean_desc_));
-  }
-
-  void Forward(const OpContext& ctx,
-               const std::vector<TBlob>& in_data,
-               const std::vector<OpReqType>& req,
-               const std::vector<TBlob>& out_data,
-               const std::vector<TBlob>& aux_states) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    CHECK_EQ(in_data.size(), 3U);
-    CHECK_EQ(aux_states.size(), 2U);
-    if (ctx.is_train) {
-      CHECK_EQ(out_data.size(), 3U);
-      CHECK_EQ(req.size(), 3U);
-    } else {
-      CHECK_GE(out_data.size(), 1U);
-      CHECK_GE(req.size(), 1U);
-    }
-    CHECK_EQ(req[cudnnbatchnorm::kOut], kWriteTo);
-    CHECK_GE(in_data[cudnnbatchnorm::kData].ndim(), 2);
-
-    Init(in_data[cudnnbatchnorm::kData]);
-    Stream<gpu>* s = ctx.get_stream<gpu>();
-    Tensor<gpu, 4, DType> x =
-        in_data[cudnnbatchnorm::kData].get_with_shape<gpu, 4, DType>(shape_, s);
-
-    Tensor<gpu, 4, DType> y =
-        out_data[cudnnbatchnorm::kOut].get_with_shape<gpu, 4, DType>(shape_, s);
-#if CUDNN_VERSION >= 7002
-    auto mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
-#else
-    auto mode = CUDNN_BATCHNORM_SPATIAL;
-#endif
-
-    MSHADOW_REAL_TYPE_SWITCH(dtype_param_, DTypeParam, {
-      Tensor<gpu, 1, DTypeParam> gamma =
-          in_data[cudnnbatchnorm::kGamma].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
-      Tensor<gpu, 1, DTypeParam> beta =
-          in_data[cudnnbatchnorm::kBeta].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
-      Tensor<gpu, 1, DTypeParam> moving_mean =
-          aux_states[cudnnbatchnorm::kMovingMean].get_with_shape<gpu, 1, DTypeParam>(
-              Shape1(shape_[1]), s);
-      Tensor<gpu, 1, DTypeParam> moving_inv_var =
-          aux_states[cudnnbatchnorm::kMovingInvVar].get_with_shape<gpu, 1, DTypeParam>(
-              Shape1(shape_[1]), s);
-      typename DataType<DType>::ScaleType a = 1.0f;
-      typename DataType<DType>::ScaleType b = 0.0f;
-
-      if (param_.fix_gamma)
-        gamma = 1.f;
-
-      if (ctx.is_train) {
-        Tensor<gpu, 1, DTypeParam> save_mean =
-            out_data[cudnnbatchnorm::kMean].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]),
-                                                                               s);
-        Tensor<gpu, 1, DTypeParam> save_inv_var =
-            out_data[cudnnbatchnorm::kInvVar].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]),
-                                                                                 s);
-        // If the lock on the auxiliary states is set, then this implies that
-        // the preceding call is also a `Forward()` call, which further
-        // indicates that we are in the backward mirroring mode, and therefore
-        // update to the auxiliary states is disabled. This is done by setting
-        // the `momentum` to `1` (or `factor` to `0`).
-        float factor =
-            ((dmlc::GetEnv("MXNET_BACKWARD_DO_MIRROR", 0) || dmlc::GetEnv("MXNET_MEMORY_OPT", 0)) &&
-             internal_aux_states_lock_)
-                ? 0
-                : (1 - param_.momentum);
-        CUDNN_CALL(cudnnBatchNormalizationForwardTraining(s->dnn_handle_,
-                                                          mode,
-                                                          &a,
-                                                          &b,
-                                                          io_desc_,
-                                                          x.dptr_,
-                                                          io_desc_,
-                                                          y.dptr_,
-                                                          mean_desc_,
-                                                          gamma.dptr_,
-                                                          beta.dptr_,
-                                                          factor,
-                                                          moving_mean.dptr_,
-                                                          moving_inv_var.dptr_,
-                                                          param_.eps,
-                                                          save_mean.dptr_,
-                                                          save_inv_var.dptr_));
-      } else {
-        CUDNN_CALL(cudnnBatchNormalizationForwardInference(s->dnn_handle_,
-                                                           CUDNN_BATCHNORM_SPATIAL,
-                                                           &a,
-                                                           &b,
-                                                           io_desc_,
-                                                           x.dptr_,
-                                                           io_desc_,
-                                                           y.dptr_,
-                                                           mean_desc_,
-                                                           gamma.dptr_,
-                                                           beta.dptr_,
-                                                           moving_mean.dptr_,
-                                                           moving_inv_var.dptr_,
-                                                           param_.eps));
-      }
-    })
-    // Set the lock on the auxiliary states.
-    // If the next call to the operator is a `Forward()` call,
-    // then `momentum` will be set to `1` and hence auxiliary states will not be updated.
-    internal_aux_states_lock_ = true;
-  }
-
-  void Backward(const OpContext& ctx,
-                const std::vector<TBlob>& inputs,
-                const std::vector<OpReqType>& req,
-                const std::vector<TBlob>& outputs) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    CHECK_EQ(inputs.size(), 8U);
-    CHECK_EQ(outputs.size(), 3U);
-
-    // Rename the inputs and outputs.
-    const TBlob& out_grad             = inputs[0];
-    const TBlob& out_mean             = inputs[1];
-    const TBlob& out_var              = inputs[2];
-    const TBlob& in_data              = inputs[3];
-    const TBlob& in_gamma             = inputs[4];
-    const std::vector<TBlob>& in_grad = outputs;
-
-    Init(in_data);
-    Stream<gpu>* s          = ctx.get_stream<gpu>();
-    Tensor<gpu, 4, DType> x = in_data.get_with_shape<gpu, 4, DType>(shape_, s);
-    Tensor<gpu, 4, DType> dx =
-        in_grad[cudnnbatchnorm::kData].get_with_shape<gpu, 4, DType>(shape_, s);
-    Tensor<gpu, 4, DType> dy = out_grad.get_with_shape<gpu, 4, DType>(shape_, s);
-
-    const bool global_stats = !ctx.is_train || param_.use_global_stats;
-
-#if CUDNN_VERSION >= 7002
-    auto mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
-#else
-    auto mode = CUDNN_BATCHNORM_SPATIAL;
-#endif
-    MSHADOW_REAL_TYPE_SWITCH(dtype_param_, DTypeParam, {
-      Tensor<gpu, 1, DTypeParam> gamma =
-          in_gamma.get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
-      Tensor<gpu, 1, DTypeParam> dbeta =
-          in_grad[cudnnbatchnorm::kBeta].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
-      Tensor<gpu, 1, DTypeParam> dgamma =
-          in_grad[cudnnbatchnorm::kGamma].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
-      Tensor<gpu, 1, DTypeParam> save_mean =
-          out_mean.get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
-      Tensor<gpu, 1, DTypeParam> save_inv_var =
-          out_var.get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
-
-      typename DataType<DType>::ScaleType a     = 1.0f;
-      typename DataType<DType>::ScaleType b     = 0.0f;
-      typename DataType<DType>::ScaleType b_add = 1.0f;
-      CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
-
-      if (param_.fix_gamma)
-        gamma = 1.f;
-
-      bool grad_add_gamma_beta =
-          (req[cudnnbatchnorm::kGamma] == kAddTo) || (req[cudnnbatchnorm::kBeta] == kAddTo);
-      if (grad_add_gamma_beta) {
-        if (IsBNWriting(req[cudnnbatchnorm::kGamma])) {
-          dgamma = 0.f;
-        }
-        if (IsBNWriting(req[cudnnbatchnorm::kBeta])) {
-          dbeta = 0.f;
-        }
-      }
-
-      CUDNN_CALL(
-          cudnnBatchNormalizationBackward(s->dnn_handle_,
-                                          mode,
-                                          &a,
-                                          req[cudnnbatchnorm::kData] == kAddTo ? &b_add : &b,
-                                          &a,
-                                          grad_add_gamma_beta ? &b_add : &b,  // gamma and beta
-                                          io_desc_,
-                                          x.dptr_,
-                                          io_desc_,
-                                          dy.dptr_,
-                                          io_desc_,
-                                          dx.dptr_,
-                                          mean_desc_,
-                                          gamma.dptr_,
-                                          dgamma.dptr_,
-                                          dbeta.dptr_,
-                                          param_.eps,
-                                          global_stats ? nullptr : save_mean.dptr_,
-                                          global_stats ? nullptr : save_inv_var.dptr_));
-      if (param_.fix_gamma)
-        dgamma = 0.f;
-    })
-    // Release the lock on the auxiliary states, so that the next forward pass
-    // will be able to update the auxiliary states normally.
-    internal_aux_states_lock_ = false;
-  }
-
- private:
-  void Init(const TBlob& in_data) {
-    CHECK_GE(param_.axis, 0);
-    CHECK_LT(param_.axis, in_data.ndim());
-    if (param_.axis == 1) {
-      if (in_data.ndim() == 4) {
-        for (int i = 0; i < 4; ++i)
-          shape_[i] = in_data.shape_[i];
-      } else {
-        // when in_data.ndim() != 4
-        shape_[0] = in_data.shape_[0];
-        shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1;
-        shape_[2] = 1;
-        shape_[3] = static_cast<dim_t>(in_data.shape_.ProdShape(2, in_data.ndim()));
-      }
-    } else {
-      // reshape to (N, C, 1, D), C is the `param_.axis` dimension
-      shape_[0] = static_cast<dim_t>(in_data.shape_.ProdShape(0, param_.axis));
-      shape_[1] = in_data.shape_[param_.axis];
-      shape_[2] = 1;
-      shape_[3] = static_cast<dim_t>(in_data.shape_.ProdShape(param_.axis + 1, in_data.ndim()));
-    }
-
-    CUDNN_CALL(cudnnSetTensor4dDescriptor(
-        io_desc_, CUDNN_TENSOR_NCHW, dtype_, shape_[0], shape_[1], shape_[2], shape_[3]));
-    CUDNN_CALL(cudnnDeriveBNTensorDescriptor(mean_desc_, io_desc_, CUDNN_BATCHNORM_SPATIAL));
-  }
-
-  cudnnDataType_t dtype_;
-  int dtype_param_;
-  cudnnTensorDescriptor_t io_desc_, mean_desc_;
-  mshadow::Shape<4> shape_;
-  BatchNormParam param_;
-  bool internal_aux_states_lock_;
-};
-#endif  // defined(__CUDACC__)
-
-#endif  // MXNET_USE_CUDNN == 1
-}  // namespace op
-}  // namespace mxnet
-#endif  // MXNET_OPERATOR_NN_CUDNN_CUDNN_BATCH_NORM_INL_H_
diff --git a/src/operator/nn/cudnn/cudnn_batch_norm.cc b/src/operator/nn/cudnn/cudnn_batch_norm.cc
deleted file mode 100644
index 5ea46f2..0000000
--- a/src/operator/nn/cudnn/cudnn_batch_norm.cc
+++ /dev/null
@@ -1,125 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file cudnn_batch_norm.cc
- * \brief
- * \author Junyuan Xie, Da Zheng
- */
-
-#include "./cudnn_batch_norm-inl.h"
-#include <nnvm/op_attr_types.h>
-#include "../../elemwise_op_common.h"
-
-namespace mxnet {
-namespace op {
-#if MXNET_USE_CUDNN == 1
-
-static bool BatchNormShape(const nnvm::NodeAttrs& attrs,
-                           mxnet::ShapeVector* in_shape,
-                           mxnet::ShapeVector* out_shape) {
-  using namespace mshadow;
-  CHECK_EQ(in_shape->size(), 5U) << "Input:[data, gamma, beta, moving_mean, moving_var]";
-  const mxnet::TShape& dshape = in_shape->at(0);
-  if (!mxnet::ndim_is_known(dshape))
-    return false;
-  in_shape->at(1) = mxnet::TShape(Shape1(dshape[1]));
-  in_shape->at(2) = mxnet::TShape(Shape1(dshape[1]));
-  in_shape->at(3) = mxnet::TShape(Shape1(dshape[1]));
-  in_shape->at(4) = mxnet::TShape(Shape1(dshape[1]));
-
-  out_shape->clear();
-  out_shape->push_back(dshape);
-  out_shape->push_back(Shape1(dshape[1]));
-  out_shape->push_back(Shape1(dshape[1]));
-
-  return true;
-}
-
-static void BatchNormCompute_CPU(const nnvm::NodeAttrs& attrs,
-                                 const OpContext& ctx,
-                                 const std::vector<TBlob>& inputs,
-                                 const std::vector<OpReqType>& req,
-                                 const std::vector<TBlob>& outputs) {
-  LOG(FATAL) << "CuDNNBatchNormOp is only available for gpu.";
-}
-
-static void BatchNormGradCompute_CPU(const nnvm::NodeAttrs& attrs,
-                                     const OpContext& ctx,
-                                     const std::vector<TBlob>& inputs,
-                                     const std::vector<OpReqType>& req,
-                                     const std::vector<TBlob>& outputs) {
-  LOG(FATAL) << "CuDNNBatchNormOp is only available for gpu.";
-}
-
-NNVM_REGISTER_OP(CuDNNBatchNorm)
-    .describe("Apply batch normalization to input.")
-    .set_num_inputs(5)
-    .set_num_outputs(3)
-    .set_attr_parser(ParamParser<BatchNormParam>)
-    .set_attr<nnvm::FListInputNames>(
-        "FListInputNames",
-        [](const NodeAttrs& attrs) {
-          return std::vector<std::string>{"data", "gamma", "beta", "moving_mean", "moving_var"};
-        })
-    .set_attr<nnvm::FListOutputNames>("FListOutputNames",
-                                      [](const NodeAttrs& attrs) {
-                                        return std::vector<std::string>{"output", "mean", "var"};
-                                      })
-    .set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs",
-                                        [](const NodeAttrs& attrs) { return 1; })
-    .set_attr<nnvm::FMutateInputs>("FMutateInputs",
-                                   [](const nnvm::NodeAttrs& attrs) {
-                                     return std::vector<uint32_t>{3, 4};
-                                   })
-    .set_attr<mxnet::FInferShape>("FInferShape", BatchNormShape)
-    .set_attr<FCompute>("FCompute<cpu>", BatchNormCompute_CPU)
-    .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseInOut{"_backward_CuDNNBatchNorm"})
-    .add_argument("data", "NDArray-or-Symbol", "Input data to batch normalization")
-    .add_argument("gamma", "NDArray-or-Symbol", "gamma array")
-    .add_argument("beta", "NDArray-or-Symbol", "beta array")
-    .add_argument("moving_mean", "NDArray-or-Symbol", "running mean of input")
-    .add_argument("moving_var", "NDArray-or-Symbol", "running variance of input")
-    .add_arguments(BatchNormParam::__FIELDS__())
-    .set_attr<nnvm::FSetInputVarAttrOnCompose>(
-        "FSetInputVarAttrOnCompose",
-        [](const nnvm::NodeAttrs& attrs, nnvm::ObjectPtr var, const int index) {
-          if (var->attrs.dict.find("__init__") != var->attrs.dict.end())
-            return;
-          if (index == 3) {
-            var->attrs.dict["__init__"] = "[\"zero\", {}]";
-          } else if (index == 4) {
-            var->attrs.dict["__init__"] = "[\"one\", {}]";
-          }
-        });
-
-NNVM_REGISTER_OP(_backward_CuDNNBatchNorm)
-    .set_num_outputs(5)
-    .set_attr<nnvm::FMutateInputs>("FMutateInputs",
-                                   [](const nnvm::NodeAttrs& attrs) {
-                                     return std::vector<uint32_t>{6, 7};
-                                   })
-    .set_attr<nnvm::TIsBackward>("TIsBackward", true)
-    .set_attr_parser(ParamParser<BatchNormParam>)
-    .set_attr<FCompute>("FCompute<cpu>", BatchNormGradCompute_CPU);
-
-#endif  // MXNET_USE_CUDNN
-
-}  // namespace op
-}  // namespace mxnet
diff --git a/src/operator/nn/cudnn/cudnn_batch_norm.cu b/src/operator/nn/cudnn/cudnn_batch_norm.cu
new file mode 100644
index 0000000..bed274f
--- /dev/null
+++ b/src/operator/nn/cudnn/cudnn_batch_norm.cu
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file cudnn_batch_norm.cu
+ * \brief
+ * \author Junyuan Xie, Da Zheng
+ */
+
+#include "cudnn_batch_norm.h"
+
+#include "../../../common/cuda/utils.h"
+
+namespace mxnet {
+namespace op {
+
+#if MXNET_USE_CUDNN == 1
+
+namespace {
+
+struct Globals {
+  cudnnTensorDescriptor_t io_desc;
+  cudnnTensorDescriptor_t mean_desc;
+  bool internal_aux_states_lock = false;
+
+  static Globals& Get() {
+    thread_local Globals ret;
+    return ret;
+  }
+
+  Globals() {
+    CUDNN_CALL(cudnnCreateTensorDescriptor(&io_desc));
+    CUDNN_CALL(cudnnCreateTensorDescriptor(&mean_desc));
+  }
+
+  ~Globals() {
+    CUDNN_CALL(cudnnDestroyTensorDescriptor(io_desc));
+    CUDNN_CALL(cudnnDestroyTensorDescriptor(mean_desc));
+  }
+};
+
+void SetDescriptors(const BatchNormParam& param, const TBlob& x) {
+  CHECK_GE(x.shape_.ndim(), 3);
+  CHECK(param.axis == 1 || param.axis == x.shape_.ndim() - 1);
+
+  cudnnTensorFormat_t format = param.axis == 1 ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
+  int n = x.shape_[0];
+  int c = x.shape_[param.axis];
+  size_t last_spatial_i = param.axis == 1 ? x.shape_.ndim() - 1 : x.shape_.ndim() - 2;
+  int w = x.shape_[last_spatial_i];
+  int h = x.shape_.ProdShape(last_spatial_i - (x.shape_.ndim() - 3), last_spatial_i);
+
+  MSHADOW_REAL_TYPE_SWITCH(x.type_flag_, DType, {
+    CUDNN_CALL(cudnnSetTensor4dDescriptor(Globals::Get().io_desc, format,
+                                          mshadow::DataType<DType>::kCudnnFlag, n, c, h, w));
+  })
+  CUDNN_CALL(cudnnDeriveBNTensorDescriptor(Globals::Get().mean_desc, Globals::Get().io_desc,
+                                           CUDNN_BATCHNORM_SPATIAL));
+}
+
+mshadow::TypeFlag ParamType(int x_type) {
+  auto xt = static_cast<mshadow::TypeFlag>(x_type);
+  return xt == mshadow::kFloat16 ? mshadow::kFloat32 : xt;
+}
+
+}  // namespace
+
+bool CudnnBatchNormSupports(const BatchNormParam& param, const TBlob& x) {
+  int n = x.shape_.ndim();
+  return n >= 3 && (param.axis == 1 || param.axis == n - 1);
+}
+
+void CudnnBatchNormForward(const BatchNormParam& param, const OpContext& ctx,
+                           const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
+                           const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 5);
+  if (ctx.is_train) {
+    CHECK_EQ(outputs.size(), 3);
+    CHECK_EQ(req.size(), 3);
+  } else {
+    CHECK_GE(outputs.size(), 1);
+    CHECK_GE(req.size(), 1);
+  }
+  CHECK_EQ(req[batchnorm::kOut], kWriteTo);
+  CHECK_GE(inputs[batchnorm::kData].ndim(), 2);
+
+  SetDescriptors(param, inputs[batchnorm::kData]);
+
+  auto s = ctx.get_stream<gpu>();
+  MSHADOW_REAL_TYPE_SWITCH(ParamType(inputs[batchnorm::kData].type_flag_), DType, {
+    DType a = 1.0f;
+    DType b = 0.0f;
+    if (param.fix_gamma) inputs[batchnorm::kGamma].FlatTo1D<gpu, DType>(s) = 1.0f;
+    if (ctx.is_train) {
+      size_t workspace_size = 0;
+      CUDNN_CALL(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
+          s->dnn_handle_, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CUDNN_BATCHNORM_OPS_BN,
+          Globals::Get().io_desc, nullptr, Globals::Get().io_desc, Globals::Get().mean_desc,
+          nullptr, &workspace_size));
+      auto workspace = ctx.requested[0].get_space_internal(workspace_size, "CudnnBatchNormForward");
+
+      // If the lock on the auxiliary states is set, then this implies that
+      // the preceding call is also a `Forward()` call, which further
+      // indicates that we are in the backward mirroring mode, and therefore
+      // update to the auxiliary states is disabled. This is done by setting
+      // the `momentum` to `1` (or `factor` to `0`).
+      double factor =
+          ((dmlc::GetEnv("MXNET_BACKWARD_DO_MIRROR", 0) || dmlc::GetEnv("MXNET_MEMORY_OPT", 0)) &&
+           Globals::Get().internal_aux_states_lock)
+              ? 0
+              : (1 - param.momentum);
+      CUDNN_CALL(cudnnBatchNormalizationForwardTrainingEx(
+          s->dnn_handle_, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CUDNN_BATCHNORM_OPS_BN, &a, &b,
+          Globals::Get().io_desc, inputs[batchnorm::kData].dptr_,
+          nullptr, nullptr,  // zDesc, zData
+          Globals::Get().io_desc, outputs[batchnorm::kOut].dptr_,
+          Globals::Get().mean_desc,
+          inputs[batchnorm::kGamma].dptr_, inputs[batchnorm::kBeta].dptr_,
+          factor, inputs[batchnorm::kInMovingMean].dptr_, inputs[batchnorm::kInMovingVar].dptr_,
+          param.eps, outputs[batchnorm::kMean].dptr_, outputs[batchnorm::kVar].dptr_,
+          nullptr,  // activation desc
+          workspace, workspace_size,
+          nullptr, 0));  // reserveSpace, reserveSpaceSizeInBytes
+    } else {
+      CUDNN_CALL(cudnnBatchNormalizationForwardInference(
+          s->dnn_handle_, CUDNN_BATCHNORM_SPATIAL, &a, &b,
+          Globals::Get().io_desc, inputs[batchnorm::kData].dptr_,
+          Globals::Get().io_desc, outputs[batchnorm::kOut].dptr_,
+          Globals::Get().mean_desc,
+          inputs[batchnorm::kGamma].dptr_, inputs[batchnorm::kBeta].dptr_,
+          inputs[batchnorm::kInMovingMean].dptr_, inputs[batchnorm::kInMovingVar].dptr_,
+          param.eps));
+    }
+  })
+  // Set the lock on the auxiliary states.
+  // If the next call to the operator is a `Forward()` call,
+  // then `momentum` will be set to `1` and hence auxiliary states will not be updated.
+  Globals::Get().internal_aux_states_lock = true;
+}
+
+void CudnnBatchNormBackward(const BatchNormParam& param, const OpContext& ctx,
+                            const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
+                            const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 8);
+  CHECK_EQ(outputs.size(), 3);
+  CHECK_EQ(req.size(), 3);
+
+  SetDescriptors(param, inputs[3 + batchnorm::kData]);
+  auto s = ctx.get_stream<gpu>();
+  size_t workspace_size = 0;
+  CUDNN_CALL(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
+      s->dnn_handle_, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CUDNN_BATCHNORM_OPS_BN,
+      Globals::Get().io_desc, Globals::Get().io_desc, Globals::Get().io_desc, nullptr,
+      Globals::Get().io_desc, Globals::Get().mean_desc, nullptr, &workspace_size));
+  auto workspace = ctx.requested[0].get_space_internal(workspace_size, "CudnnBatchNormBackward");
+  MSHADOW_REAL_TYPE_SWITCH(ParamType(inputs[3 + batchnorm::kData].type_flag_), DType, {
+    if (param.fix_gamma) inputs[3 + batchnorm::kGamma].FlatTo1D<gpu, DType>(s) = 1.0f;
+    bool grad_add_gamma_beta = req[batchnorm::kGamma] == kAddTo || req[batchnorm::kBeta] == kAddTo;
+    if (grad_add_gamma_beta) {
+      if (IsBNWriting(req[batchnorm::kGamma]))
+        outputs[batchnorm::kGamma].FlatTo1D<gpu, DType>(s) = 0.0f;
+      if (IsBNWriting(req[batchnorm::kBeta]))
+        outputs[batchnorm::kBeta].FlatTo1D<gpu, DType>(s) = 0.0f;
+    }
+    DType a = 1.0f;
+    DType b = 0.0f;
+    DType b_add = 1.0f;
+    const bool global_stats = !ctx.is_train || param.use_global_stats;
+    CUDNN_CALL(cudnnBatchNormalizationBackwardEx(
+        s->dnn_handle_, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CUDNN_BATCHNORM_OPS_BN,
+        &a, req[batchnorm::kData] == kAddTo ? &b_add : &b,
+        &a, grad_add_gamma_beta ? &b_add : &b,
+        Globals::Get().io_desc, inputs[3 + batchnorm::kData].dptr_,
+        nullptr, nullptr,  // yDesc, yData
+        Globals::Get().io_desc, inputs[batchnorm::kOut].dptr_,
+        nullptr, nullptr,  // dzDesc, dzData
+        Globals::Get().io_desc, outputs[batchnorm::kData].dptr_,
+        Globals::Get().mean_desc,
+        inputs[3 + batchnorm::kGamma].dptr_, inputs[3 + batchnorm::kBeta].dptr_,
+        outputs[batchnorm::kGamma].dptr_, outputs[batchnorm::kBeta].dptr_, param.eps,
+        global_stats ? nullptr : inputs[batchnorm::kMean].dptr_,
+        global_stats ? nullptr : inputs[batchnorm::kVar].dptr_,
+        nullptr,  // activationDesc
+        workspace, workspace_size,
+        nullptr, 0));  // reserveSpace, reserveSpaceSizeInBytes
+    if (param.fix_gamma) outputs[batchnorm::kGamma].FlatTo1D<gpu, DType>(s) = 0.0f;
+  })
+  Globals::Get().internal_aux_states_lock = false;
+}
+
+#endif  // MXNET_USE_CUDNN == 1
+}  // namespace op
+}  // namespace mxnet
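
The index arithmetic in SetDescriptors is terse; the following standalone sketch (not part of the commit, and CollapseForCudnn is a hypothetical name) mirrors how an arbitrary-rank input is folded into the 4-D (n, c, h, w) descriptor cuDNN expects, multiplying every dimension between the batch axis and the last spatial axis into h:

    #include <cstdio>
    #include <functional>
    #include <numeric>
    #include <vector>

    struct Collapsed { int n, c, h, w; bool nhwc; };

    // Mirrors the shape logic of SetDescriptors above: axis == 1 selects
    // CUDNN_TENSOR_NCHW, axis == ndim - 1 selects CUDNN_TENSOR_NHWC.
    Collapsed CollapseForCudnn(const std::vector<int>& shape, int axis) {
      const int ndim = static_cast<int>(shape.size());
      const bool nhwc = (axis == ndim - 1);
      const int last_spatial = nhwc ? ndim - 2 : ndim - 1;
      const int first_spatial = last_spatial - (ndim - 3);
      // Product of the collapsed "height" dimensions; 1 for a rank-3 input.
      const int h = std::accumulate(shape.begin() + first_spatial,
                                    shape.begin() + last_spatial,
                                    1, std::multiplies<int>());
      return {shape[0], shape[axis], h, shape[last_spatial], nhwc};
    }

    int main() {
      // 5-D NHWC input (N, D, H, W, C) with axis == 4:
      Collapsed c = CollapseForCudnn({8, 16, 32, 32, 64}, 4);
      std::printf("n=%d c=%d h=%d w=%d nhwc=%d\n", c.n, c.c, c.h, c.w, c.nhwc);
      // Prints: n=8 c=64 h=512 w=32 nhwc=1  (h == D * H == 16 * 32)
    }
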
diff --git a/src/operator/nn/cudnn/cudnn_batch_norm.h b/src/operator/nn/cudnn/cudnn_batch_norm.h
new file mode 100644
index 0000000..57249b1
--- /dev/null
+++ b/src/operator/nn/cudnn/cudnn_batch_norm.h
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file cudnn_batch_norm.h
+ * \brief
+ * \author Junyuan Xie
+*/
+
+#ifndef MXNET_OPERATOR_NN_CUDNN_CUDNN_BATCH_NORM_H_
+#define MXNET_OPERATOR_NN_CUDNN_CUDNN_BATCH_NORM_H_
+
+#include <mxnet/base.h>
+#include <vector>
+#include "../batch_norm-inl.h"
+
+namespace mxnet {
+namespace op {
+
+#if MXNET_USE_CUDNN == 1
+
+STATIC_ASSERT_CUDNN_VERSION_GE(7401);
+
+bool CudnnBatchNormSupports(const BatchNormParam& param, const TBlob& x);
+
+void CudnnBatchNormForward(const BatchNormParam& param, const OpContext& ctx,
+                           const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
+                           const std::vector<TBlob>& outputs);
+
+void CudnnBatchNormBackward(const BatchNormParam& param, const OpContext& ctx,
+                            const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
+                            const std::vector<TBlob>& outputs);
+
+#endif  // MXNET_USE_CUDNN == 1
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_NN_CUDNN_CUDNN_BATCH_NORM_H_
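
The STATIC_ASSERT_CUDNN_VERSION_GE(7401) above raises the minimum cuDNN version, presumably because the batch-norm "Ex" entry points used in cudnn_batch_norm.cu (cudnnBatchNormalizationForwardTrainingEx, cudnnBatchNormalizationBackwardEx, and their workspace-size queries) first appeared in cuDNN 7.4.1, the release that also brought the fast NHWC kernels this commit enables; the diff itself does not state this. The caller-side dispatch, condensed from the batch_norm.cu hunks above:

    #if MXNET_USE_CUDNN == 1
      if (!param.use_global_stats && !param.cudnn_off &&
          CudnnBatchNormSupports(param, inputs[batchnorm::kData])) {
        // Fast path: rank >= 3 and axis == 1 (NCHW) or axis == ndim - 1 (NHWC).
        CudnnBatchNormForward(param, ctx, inputs, req, outputs);
      } else {
        // Otherwise fall back to MXNet's native GPU BatchNorm kernels.
        BatchNormForward<gpu, DType, AccReal>(ctx, param, in_data, req, outputs, aux_states);
      }
    #endif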