Posted to commits@mxnet.apache.org by la...@apache.org on 2020/07/23 19:19:38 UTC

[incubator-mxnet] branch v1.6.x updated: [v1.6.x][Bug Fixed] Fix batch norm when grad_req is `add` (#18518) (#18714)

This is an automated email from the ASF dual-hosted git repository.

lausen pushed a commit to branch v1.6.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.6.x by this push:
     new c765931  [v1.6.x][Bug Fixed] Fix batch norm when grad_req is `add` (#18518) (#18714)
c765931 is described below

commit c76593162c2a5e421243dd09a8c3cde044efaa3f
Author: Chaitanya Prakash Bapat <ch...@gmail.com>
AuthorDate: Thu Jul 23 12:17:39 2020 -0700

    [v1.6.x][Bug Fixed] Fix batch norm when grad_req is `add` (#18518) (#18714)
    
    * [Bug Fixed] Fix batch norm when grad_req is `add` (#18500)
    
    * fix batch norm when fix_gamma is True
    
    * support gradient accumulation for batch norm
    
    * mkldnn batchnorm support grad add
    
    * unittest for bn
    
    * fix bn arg
    
    * fix lint
    
    * fix mkldnn
    
    * fix mkldnn bn
    
    * fix grad when fixing gamma
    
    * fix naive gpu bn
    
    * fix lint
    
    * fix cudnn bn
    
    * fix flag
    
    * fix lint
    
    * fix testcase
    
    * fix
    
    * use @pytest.mark.parametrize
    
    * combination
    
    * remove redundant test in batchnorm
    
    * npx.batch_norm test
    
    * try to fix test
    
    * reduce the number of tests for batchnorm
    
    * fix
    
    * Revert "[Bug Fixed] Fix batch norm when grad_req is `add` (#18500)"
    
    This reverts commit 8e32cd6959461290c1698e02466fcc16f61ad237.
    
    * [v1.x] backport #18500 - [Bug Fixed] Fix batch norm when grad_req is `add` (#18518)
    
    * Fix batch norm when grad_req is `add`
    
    * fix
    
    * remove softmax test
    
    * fix
    
    * add copy_size
    
    * Fix init method for TestBatchNorm
    
    Co-authored-by: JackieWu <wk...@live.cn>
---
 src/operator/nn/batch_norm-inl.h               |   1 +
 src/operator/nn/batch_norm.cc                  |  83 +++++++++----
 src/operator/nn/batch_norm.cu                  |  68 ++++++++---
 src/operator/nn/cudnn/cudnn_batch_norm-inl.h   |  15 ++-
 src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h |  46 ++++++--
 tests/python/unittest/test_numpy_op.py         | 157 +++++++++++++++++++++++++
 tests/python/unittest/test_operator.py         | 103 ++++++++++------
 7 files changed, 389 insertions(+), 84 deletions(-)

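Background for readers outside the MXNet codebase: grad_req='add' corresponds to kAddTo in the C++ OpReqType enum (alongside kNullOp, kWriteTo and kWriteInplace) and asks an operator's backward pass to accumulate its gradient into the existing buffer instead of overwriting it. This is what Gluon uses for gradient accumulation across several forward/backward passes. Before this patch the BatchNorm backends did not consistently honor kAddTo; the CPU path, for example, only computed the data gradient when the request was a plain write. A minimal Gluon sketch of the accumulation workflow this fixes (layer sizes, optimizer settings and batch sizes are arbitrary, not taken from the patch):

    import mxnet as mx
    from mxnet import autograd, gluon

    net = gluon.nn.HybridSequential()
    net.add(gluon.nn.Conv2D(8, kernel_size=3),
            gluon.nn.BatchNorm(),
            gluon.nn.Activation('relu'),
            gluon.nn.Dense(10))
    net.initialize()
    # Ask every parameter's backward to accumulate (kAddTo) instead of overwrite.
    net.collect_params().setattr('grad_req', 'add')

    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.01})
    loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

    accumulate_steps = 4
    for step in range(accumulate_steps):
        data = mx.nd.random.uniform(shape=(2, 3, 16, 16))
        label = mx.nd.array([1, 3])
        with autograd.record():
            loss = loss_fn(net(data), label)
        loss.backward()          # gradients add up across the micro-batches
    trainer.step(2 * accumulate_steps)
    for param in net.collect_params().values():
        param.zero_grad()        # reset before the next accumulation window
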
diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h
index 17a16db..485b3b3 100644
--- a/src/operator/nn/batch_norm-inl.h
+++ b/src/operator/nn/batch_norm-inl.h
@@ -259,6 +259,7 @@ void BatchNormBackward(const OpContext &ctx, const BatchNormParam& param,
                        const std::vector<TBlob> &outputs) {
   CHECK_EQ(inputs.size(), 8U);
   CHECK_EQ(outputs.size(), 3U);
+
   std::vector<TBlob> out_grad(1);
   std::vector<TBlob> out_data(3);
   std::vector<TBlob> in_data(3);
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 3214e3b..fc65476 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -85,6 +85,31 @@ static inline void ForEachFast(const BNTensor3<DType1> &in_data,
   }
 }
 
+template<typename DType1, typename DType2, typename DType3, typename OnData>
+static inline void ForEachFast(const BNTensor3<DType1> &in_data,
+                               const BNTensor3<DType2> &in_data2,
+                               const BNTensor3<DType3> &out_data,
+                               const size_t channel,
+                               OnData onData) {
+  const size_t num         = in_data.OuterSize();
+  const size_t matrixSize  = in_data.InnerSize();
+  const size_t skipLength  = in_data.SkipLengthToNextSameChannelData();
+  const size_t startOffset = in_data.StartOffset(channel);
+
+  DType1 *data = in_data.dptr_  + startOffset;
+  DType2 *data2 = in_data2.dptr_  + startOffset;
+  DType3 *odata = out_data.dptr_ + startOffset;
+
+  for (size_t outer = 0; outer < num; ++outer) {
+    for (size_t i = 0; i < matrixSize; ++i) {
+      onData(data++, data2++, odata++);
+    }
+    data  += skipLength;
+    data2 += skipLength;
+    odata += skipLength;
+  }
+}
+
 }  // namespace batchnorm
 
 /*! \brief Forward CPU */
@@ -264,7 +289,7 @@ void BatchNormBackwardImpl(mshadow::Stream<cpu> *,
                   dotp += (*thisInputData - mean) * (*gradOut_data);
                 });
 
-    if (!gradIn.IsEmpty() && IsBNWriting(req[batchnorm::kData])) {  // if there's a grad input
+    if (!gradIn.IsEmpty() && req[batchnorm::kData] != kNullOp) {  // if there's a grad input
       if (is_train_and_not_global_stats) {
         // when in training mode
         // Q(X) = X - E[x] ; i.e. input centered to zero mean
@@ -273,44 +298,60 @@ void BatchNormBackwardImpl(mshadow::Stream<cpu> *,
 
         // projection of gradOutput on to output scaled by std
         const AccReal k = dotp * invstd * invstd / itemCount;
-        ForEachFast(inputData, gradIn, static_cast<size_t>(channel),
-                    [&mean, &k](const DType *inputDataPtr, DType *gradIn_data) {
-                      *gradIn_data = (*inputDataPtr - mean) * k;
-                    });
-
         const AccReal iw = invstd * w;
         const AccReal gradMean = sumGradOut / itemCount;
-        ForEachFast(gradOut, gradIn, static_cast<size_t>(channel),
-                    [iw, gradMean](const DType *gradOut_data, DType *gradIn_data) {
-                      *gradIn_data = (*gradOut_data - gradMean - *gradIn_data) * iw;
-                    });
+        if (req[batchnorm::kData] != kAddTo) {
+          ForEachFast(inputData, gradIn, static_cast<size_t>(channel),
+                      [&mean, &k](const DType *inputDataPtr, DType *gradIn_data) {
+                        *gradIn_data = (*inputDataPtr - mean) * k;
+                      });
+
+          ForEachFast(gradOut, gradIn, static_cast<size_t>(channel),
+                      [iw, gradMean](const DType *gradOut_data, DType *gradIn_data) {
+                        *gradIn_data = (*gradOut_data - gradMean - *gradIn_data) * iw;
+                      });
+        } else {
+          ForEachFast(inputData, gradOut, gradIn, static_cast<size_t>(channel),
+                      [&mean, &k, iw, gradMean](const DType *inputDataPtr,
+                                                const DType *gradOut_data,
+                                                DType *gradIn_data) {
+                        DType normal_val = (*inputDataPtr - mean) * k;
+                        *gradIn_data += (*gradOut_data - gradMean -
+                            normal_val) * iw;
+                      });
+        }
       } else {
         // when in evaluation mode
         // Q(X) = X - running_mean  ; i.e. input centered to zero mean
         // Y = Q(X) / running_std    ; i.e. BN output before weight and bias
         // dL/dX = w / running_std
         const AccReal iw = invstd * w;
-        ForEachFast(gradOut, gradIn, static_cast<size_t>(channel),
-                    [iw](const DType *gradOut_data, DType *gradIn_data) {
-                      *gradIn_data = *gradOut_data * iw;
-                    });
+        if (req[batchnorm::kData] != kAddTo) {
+          ForEachFast(gradOut, gradIn, static_cast<size_t>(channel),
+                      [iw](const DType *gradOut_data, DType *gradIn_data) {
+                        *gradIn_data = *gradOut_data * iw;
+                      });
+        } else {
+          ForEachFast(gradOut, gradIn, static_cast<size_t>(channel),
+                      [iw](const DType *gradOut_data, DType *gradIn_data) {
+                        *gradIn_data += *gradOut_data * iw;
+                      });
+        }
       }
     }
 
     // May want to make this a param eventually
     const AccReal scale = 1.0f;
 
-    if (IsBNWriting(req[batchnorm::kGamma])) {
-      if (!param_.fix_gamma) {
-        gradWeightData[channel] = scale * dotp * invstd;
-      } else {
+    if (!param_.fix_gamma) {
+      KERNEL_ASSIGN(gradWeightData[channel], req[batchnorm::kGamma], scale * dotp * invstd);
+    } else {
+      if (IsBNWriting(req[batchnorm::kGamma])) {
         gradWeightData[channel] = AccReal(0);
       }
     }
 
-    if (IsBNWriting(req[batchnorm::kBeta])) {
-      gradBiasData[channel] = scale * sumGradOut;
-    }
+    KERNEL_ASSIGN(gradBiasData[channel], req[batchnorm::kBeta], scale * sumGradOut);
   }
 }
 
diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu
index be9309c..7b36d25 100644
--- a/src/operator/nn/batch_norm.cu
+++ b/src/operator/nn/batch_norm.cu
@@ -34,6 +34,9 @@
 #define FIX_GAMMA_FLAG        8
 #define IS_TRAINING_FLAG      16
 #define USE_GLOBAL_STATS_FLAG 32
+#define ADDTO_DATA_FLAG       (1 << 6)
+#define ADDTO_GAMMA_FLAG      (1 << 7)
+#define ADDTO_BETA_FLAG       (1 << 8)
 
 #if MXNET_USE_CUDNN == 1
 #include "./cudnn/cudnn_batch_norm-inl.h"
@@ -362,33 +365,60 @@ static __global__ void BatchNormalizationBackwardKernel(
                                 * momentum + localVariance * (AccReal(1) - momentum);
   }
 
-  if (gradInput.Size() > 0 && (flags & WRITE_DATA_FLAG) != 0) {
-    for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) {
-      for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) {
-        const DType gradOut = gradOutput.get_ref(batch, plane, x);
-        if (is_train_and_not_global_stats) {
-          const DType inp = input.get_ref(batch, plane, x);
-          const AccReal proj = (inp - mean) * projScale;
-          gradInput.get_ref(batch, plane, x) =
-            ScalarConvert<AccReal, DType>::to((gradOut - proj - gradMean) * gradScale);
-        } else {
-          gradInput.get_ref(batch, plane, x) = ScalarConvert<AccReal, DType>::to(
-            gradOut * gradScale);
+  if (gradInput.Size() > 0 && (flags & (WRITE_DATA_FLAG | ADDTO_DATA_FLAG)) != 0) {
+    const bool grad_write = flags & WRITE_DATA_FLAG;
+    if (grad_write) {
+      for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) {
+        for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) {
+          const DType gradOut = gradOutput.get_ref(batch, plane, x);
+          if (is_train_and_not_global_stats) {
+            const DType inp = input.get_ref(batch, plane, x);
+            const AccReal proj = (inp - mean) * projScale;
+            gradInput.get_ref(batch, plane, x) =
+              ScalarConvert<AccReal, DType>::to((gradOut - proj - gradMean) * gradScale);
+          } else {
+            gradInput.get_ref(batch, plane, x) = ScalarConvert<AccReal, DType>::to(
+              gradOut * gradScale);
+          }
+        }
+      }
+    } else {
+      // grad addto
+      for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) {
+        for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) {
+          const DType gradOut = gradOutput.get_ref(batch, plane, x);
+          if (is_train_and_not_global_stats) {
+            const DType inp = input.get_ref(batch, plane, x);
+            const AccReal proj = (inp - mean) * projScale;
+            gradInput.get_ref(batch, plane, x) +=
+              ScalarConvert<AccReal, DType>::to((gradOut - proj - gradMean) * gradScale);
+          } else {
+            gradInput.get_ref(batch, plane, x) += ScalarConvert<AccReal, DType>::to(
+              gradOut * gradScale);
+          }
         }
       }
     }
   }
 
-  if (tensors.gradWeight.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_GAMMA_FLAG) != 0) {
+  if (tensors.gradWeight.numElements() > 0 && threadIdx.x == 0 &&
+      (flags & (WRITE_GAMMA_FLAG | ADDTO_GAMMA_FLAG)) != 0) {
     if ((flags & FIX_GAMMA_FLAG) == 0) {
-      tensors.gradWeight[plane] = ScalarConvert<AccReal, DType>::to(dotP * invstd);
+      if (flags & WRITE_GAMMA_FLAG)
+        tensors.gradWeight[plane] = ScalarConvert<AccReal, DType>::to(dotP * invstd);
+      else
+        tensors.gradWeight[plane] += ScalarConvert<AccReal, DType>::to(dotP * invstd);
     } else {
       tensors.gradWeight[plane] = DType(0);
     }
   }
 
-  if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_BETA_FLAG) != 0) {
-    tensors.gradBias[plane] = ScalarConvert<AccReal, DType>::to(gradOutputSum);
+  if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 &&
+      (flags & (WRITE_BETA_FLAG | ADDTO_BETA_FLAG)) != 0) {
+    if (flags & WRITE_BETA_FLAG)
+      tensors.gradBias[plane] = ScalarConvert<AccReal, DType>::to(gradOutputSum);
+    else
+      tensors.gradBias[plane] += ScalarConvert<AccReal, DType>::to(gradOutputSum);
   }
 }
 
@@ -585,12 +615,18 @@ static inline uint32_t SetupFlags(const OpContext &ctx,
   flags |= params.use_global_stats ? USE_GLOBAL_STATS_FLAG : 0;
   if (IsBNWriting(req[batchnorm::kData])) {
     flags |= WRITE_DATA_FLAG;
+  } else if (req[batchnorm::kData] == kAddTo) {
+    flags |= ADDTO_DATA_FLAG;
   }
   if (IsBNWriting(req[batchnorm::kGamma])) {
     flags |= WRITE_GAMMA_FLAG;
+  } else if (req[batchnorm::kGamma] == kAddTo) {
+    flags |= ADDTO_GAMMA_FLAG;
   }
   if (IsBNWriting(req[batchnorm::kBeta])) {
     flags |= WRITE_BETA_FLAG;
+  } else if (req[batchnorm::kBeta] == kAddTo) {
+    flags |= ADDTO_BETA_FLAG;
   }
   return flags;
 }
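
In the CUDA path the per-output OpReqType is packed into a bit mask before the kernel launch, so kAddTo needs its own bits: ADDTO_DATA_FLAG, ADDTO_GAMMA_FLAG and ADDTO_BETA_FLAG (1 << 6 through 1 << 8) sit next to the existing WRITE_* flags, and SetupFlags now sets at most one of write/add per output; neither bit means kNullOp. A small Python mirror of the selection each kernel branch performs (the WRITE_* values are defined earlier in batch_norm.cu and are not reproduced here; this is a sketch, not CUDA):

    ADDTO_DATA_FLAG  = 1 << 6   # values copied from the #defines above
    ADDTO_GAMMA_FLAG = 1 << 7
    ADDTO_BETA_FLAG  = 1 << 8

    def apply_grad(buf, value, flags, write_flag, addto_flag):
        """Write-or-accumulate selection mirroring the kernel branches."""
        if flags & write_flag:
            buf[:] = value           # kWriteTo / kWriteInplace
        elif flags & addto_flag:
            buf += value             # kAddTo
        # neither bit set: kNullOp, leave the buffer untouched
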
diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
index 3fc9119..5dad073 100644
--- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
+++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
@@ -208,13 +208,24 @@ class CuDNNBatchNormOp {
 
       if (param_.fix_gamma) gamma = 1.f;
 
+      bool grad_add_gamma_beta = (req[cudnnbatchnorm::kGamma] == kAddTo) ||
+                                 (req[cudnnbatchnorm::kBeta] == kAddTo);
+      if (grad_add_gamma_beta) {
+        if (IsBNWriting(req[cudnnbatchnorm::kGamma])) {
+          dgamma = 0.f;
+        }
+        if (IsBNWriting(req[cudnnbatchnorm::kBeta])) {
+          dbeta = 0.f;
+        }
+      }
+
       CUDNN_CALL(cudnnBatchNormalizationBackward(
         s->dnn_handle_,
         mode,
         &a,
-        &b,
+        req[cudnnbatchnorm::kData] == kAddTo ? &b_add : &b,
         &a,
-        req[cudnnbatchnorm::kGamma] == kWriteTo ? &b: &b_add,
+        grad_add_gamma_beta ? &b_add : &b,  // gamma and beta
         io_desc_,
         x.dptr_,
         io_desc_,
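
cudnnBatchNormalizationBackward blends its results with whatever is already in the output buffers: dx = alphaDataDiff * dx_new + betaDataDiff * dx_old, and a single alphaParamDiff/betaParamDiff pair applies to both dgamma and dbeta. Assuming a, b and b_add are the usual 1, 0 and 1 constants defined earlier in this file, passing b_add selects accumulation and b selects overwrite. Because one beta covers both parameter gradients, when only one of gamma/beta requests kAddTo the patch zero-fills the other (the one that is a plain write) before the call, so "accumulate into zero" degenerates to a write. A one-line sketch of the blending rule:

    def cudnn_blend(old, new, alpha=1.0, beta=0.0):
        # beta = 0.0 -> overwrite (kWriteTo); beta = 1.0 -> accumulate (kAddTo)
        return alpha * new + beta * old
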
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 26637c7..2e0fb64 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -317,13 +317,15 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
   else if (diff.IsDefaultData())
     diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_desc());
   auto &bwd = GetBNBackward<DType>(param, ctx, data, *data_mem, diff, *diff_mem, flags);
-  auto gradi_mem = const_cast<NDArray &>(gradIn).CreateMKLDNNData(data_mem->get_desc());
+  auto gradi_mem = CreateMKLDNNMem(const_cast<NDArray &>(gradIn),
+      bwd.pd.diff_src_desc(), req[batchnorm::kData]);
 
   if (static_cast<int>(flags) & static_cast<int>(mkldnn::normalization_flags::use_scale_shift)) {
     const NDArray &gamma    = in_data[batchnorm::kGamma];
     const NDArray &beta     = in_data[batchnorm::kBeta];
     DType *weight_buf = reinterpret_cast<DType *>(bwd.GetWeight().get_data_handle());
     nnvm::dim_t channels_ = data.shape()[1];
+    const size_t copy_size = sizeof(DType) * channels_;
     for (int i = 0; i < channels_; i++) {
       if (!param.fix_gamma)
         weight_buf[i] = (gamma.data().dptr<DType>())[i];   // weight
@@ -337,7 +339,7 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
 
     mkldnn_args_map_t net_args;
     net_args[MKLDNN_ARG_SRC] = *data_mem;
-    net_args[MKLDNN_ARG_DIFF_SRC] = *gradi_mem;
+    net_args[MKLDNN_ARG_DIFF_SRC] = *gradi_mem.second;
     net_args[MKLDNN_ARG_SCALE_SHIFT] = bwd.GetWeight();
     net_args[MKLDNN_ARG_DIFF_SCALE_SHIFT] = bwd.GetGradw();
     net_args[MKLDNN_ARG_DIFF_DST] = *diff_mem;
@@ -362,26 +364,46 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
       }
       net_args[MKLDNN_ARG_MEAN] = *(out_mean.GetMKLDNNData());
       net_args[MKLDNN_ARG_VARIANCE] = var_mem;
-      MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
-      MKLDNNStream::Get()->Submit();
     } else {
       net_args[MKLDNN_ARG_MEAN] =  *(moving_mean.GetMKLDNNData());
       net_args[MKLDNN_ARG_VARIANCE] = *(moving_var.GetMKLDNNData());
-      MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
-      MKLDNNStream::Get()->Submit();
     }
+    MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
+    CommitOutput(gradIn, gradi_mem);
+    MKLDNNStream::Get()->Submit();
 
     // copy data from gradw_mem to in_grad[1] and in_grad[2]
     DType *gw_buf = reinterpret_cast<DType *>(bwd.GetGradw().get_data_handle());
-    for (int i = 0; i < channels_; i++) {
-      if (!param.fix_gamma)
-        (in_grad[1].data().dptr<DType>())[i] = gw_buf[i];
-      else
+    DType *w_grad_1 = in_grad[batchnorm::kGamma].data().dptr<DType>();
+    DType *w_grad_2 = in_grad[batchnorm::kBeta].data().dptr<DType>();
+
+    // the gradient of gamma
+    if (!param.fix_gamma) {
+      if (req[batchnorm::kGamma] != kNullOp) {
+        if (req[batchnorm::kGamma] != kAddTo) {
+          memcpy(w_grad_1, gw_buf, copy_size);
+        } else {
+          for (int i = 0; i < channels_; i++) {
+            w_grad_1[i] += gw_buf[i];
+          }
+        }
+      }
+    } else {
+      for (int i = 0; i < channels_; i++) {
         (in_grad[1].data().dptr<DType>())[i] = 0.0f;
+      }
     }
 
-    for (int i = 0; i < channels_; i++) {
-      (in_grad[2].data().dptr<DType>())[i] = gw_buf[i + channels_];
+    // the gradient of beta
+    if (req[batchnorm::kBeta] != kNullOp) {
+      if (req[batchnorm::kBeta] != kAddTo) {
+        memcpy(w_grad_2, &gw_buf[channels_], copy_size);
+      } else {
+        DType *grad_beta = &gw_buf[channels_];
+        for (int i = 0; i < channels_; i++) {
+          w_grad_2[i] += grad_beta[i];
+        }
+      }
     }
   } else {
     LOG(FATAL) << "MKLDNN batch normalization backward: should not reach here ...";
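
On the MKL-DNN side the data gradient now goes through the CreateMKLDNNMem/CommitOutput pair, so kAddTo is handled by the common helper (a temporary destination that is summed back into gradIn when the stream is submitted), while the gamma/beta gradients are copied or accumulated by hand out of the packed DIFF_SCALE_SHIFT buffer: gamma occupies the first channels_ entries, beta the next channels_. A NumPy sketch of that copy-vs-accumulate split (names are illustrative, not MXNet API):

    import numpy as np

    def commit_scale_shift_grads(gw_buf, w_grad_gamma, w_grad_beta,
                                 gamma_req, beta_req, fix_gamma):
        # gw_buf packs [dgamma(0..C-1), dbeta(C..2C-1)] like DIFF_SCALE_SHIFT;
        # reqs are 'null' | 'write' | 'add'.
        C = w_grad_gamma.shape[0]
        dgamma, dbeta = gw_buf[:C], gw_buf[C:2 * C]
        if fix_gamma:
            w_grad_gamma[:] = 0.0          # fixed gamma: its gradient is zeroed
        elif gamma_req == 'write':
            w_grad_gamma[:] = dgamma       # the memcpy(copy_size) branch
        elif gamma_req == 'add':
            w_grad_gamma += dgamma         # kAddTo accumulation
        if beta_req == 'write':
            w_grad_beta[:] = dbeta
        elif beta_req == 'add':
            w_grad_beta += dbeta
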
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 1ff1b61..fe2df9e 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1073,6 +1073,163 @@ def test_npx_batch_dot():
 
 @with_seed()
 @use_np
+def test_npx_batch_norm():
+    momentum = 0.9
+    epsilon = 1e-5
+    class TestBatchNorm(HybridBlock):
+        def __init__(self, eps=1e-5, fix_gamma=False, momentum=0.9, **kwargs):
+            super(TestBatchNorm, self).__init__()
+            self.eps = eps
+            self.fix_gamma = fix_gamma
+            self.momentum = momentum
+            self.kwargs = kwargs
+        def hybrid_forward(self, F, data, bn_gamma, bn_beta,
+                           bn_running_mean, bn_running_var):
+            op = F.npx.batch_norm
+            output = op(data, bn_gamma, bn_beta,
+                        bn_running_mean, bn_running_var,
+                        momentum=self.momentum, eps=self.eps,
+                        fix_gamma=self.fix_gamma, **self.kwargs)
+            return output
+
+    def _test_batchnorm_impl(shape, fix_gamma, cudnn_off, output_mean_var,
+                             axis,
+                             data_grad_req, gamma_grad_req, beta_grad_req):
+        kwargs = dict(output_mean_var=output_mean_var)
+        kwargs.update(dict(axis=axis, cudnn_off=cudnn_off))
+        op = TestBatchNorm(eps=epsilon, fix_gamma=fix_gamma, momentum=momentum, **kwargs)
+        nch = shape[axis]
+
+        if not fix_gamma:
+            bn_gamma = np.random.uniform(size=(nch,))
+            bn_gamma.attach_grad(grad_req=gamma_grad_req)
+        else:
+            bn_gamma = np.ones((nch,))
+
+        bn_beta = np.random.uniform(size=(nch,))
+        bn_beta.attach_grad(grad_req=beta_grad_req)
+
+        bn_running_mean = np.zeros(nch)
+        bn_running_var = np.ones(nch)
+
+        running_mean = np.zeros(nch)
+        running_var = np.ones(nch)
+        num_iters = 10
+        expand_shape = [1] * len(shape)
+        expand_shape[axis] = shape[axis]
+        expand_shape = tuple(expand_shape)
+        data = np.random.uniform(size=shape)
+        data.attach_grad(grad_req=data_grad_req)
+        adX, adW, adb = 0, 0, 0
+        is_train = data_grad_req != 'null' or \
+            (not fix_gamma and gamma_grad_req != 'null') or \
+            beta_grad_req != 'null'
+        for _ in range(num_iters):
+            if data_grad_req != 'add':
+                data = np.random.uniform(size=shape)
+                data.attach_grad(grad_req=data_grad_req)
+            ograd = np.random.uniform(size=shape)
+            with mx.autograd.record():
+                output = op(data, bn_gamma, bn_beta,
+                            bn_running_mean, bn_running_var)
+                if output_mean_var:
+                    output, output_mean, output_std = output
+                if is_train:
+                    output.backward(ograd)
+            mx.nd.waitall()
+
+            assert 0 <= axis < data.ndim
+            reduce_axis = tuple(i for i in range(data.ndim) if i != axis)
+            assert len(reduce_axis) == data.ndim - 1
+            data_mean = data.mean(
+                axis=reduce_axis, keepdims=True)
+            data_var = ((data - data_mean) ** 2).mean(axis=reduce_axis,
+                                                        keepdims=True)
+
+            target_output = (data - data_mean) / \
+                np.sqrt(data_var + epsilon) * \
+                bn_gamma.reshape(expand_shape) + \
+                bn_beta.reshape(expand_shape)
+
+            # squeeze data_mean and data_var
+            data_mean_flat = data_mean.squeeze()
+            data_var_flat = data_var.squeeze()
+
+            running_mean = running_mean * momentum + \
+                data_mean_flat * (1 - momentum)
+            running_var = running_var * momentum + \
+                data_var_flat * (1 - momentum)
+
+            W = bn_gamma.reshape(expand_shape)
+            dnx = ograd * W
+            xsm = data - data_mean
+            nd = 1.0 / np.sqrt(data_var + epsilon)
+            nx = xsm * nd
+            m = _np.prod(shape) / shape[axis]
+            dvar = np.sum(dnx * xsm, axis=reduce_axis, keepdims=True,
+                                  ) * (-0.5) * np.power(nd, 3)
+            dmean = -nd * np.sum(dnx, axis=reduce_axis, keepdims=True) - \
+                dvar * xsm.mean(axis=reduce_axis, keepdims=True,
+                                ) * 2.0
+            dX = dnx * nd + dvar * xsm * (2.0 / m) + dmean * (1.0 / m)
+            dW = np.sum(ograd * nx, axis=reduce_axis)
+            db = np.sum(ograd, axis=reduce_axis)
+            adX = dX if data_grad_req != 'add' else adX + dX
+            adW = dW if gamma_grad_req != 'add' else adW + dW
+            adb = db if beta_grad_req != 'add' else adb + db
+
+            atol, rtol = 5e-2, 5e-2
+
+            if output_mean_var:
+                assert_almost_equal(output_mean.asnumpy(),
+                                    data_mean_flat.asnumpy(),
+                                    atol=atol, rtol=rtol)
+                assert_almost_equal(output_std.asnumpy(),
+                                    (1.0 / np.sqrt(data_var_flat +
+                                            epsilon)).asnumpy(),
+                                    atol=atol, rtol=rtol)
+            assert_almost_equal(output.asnumpy(), target_output.asnumpy(),
+                                atol=atol, rtol=rtol)
+            if is_train:
+                assert_almost_equal(bn_running_mean.asnumpy(
+                ), running_mean.asnumpy(), atol=atol, rtol=rtol)
+                assert_almost_equal(bn_running_var.asnumpy(
+                ), running_var.asnumpy(), atol=atol, rtol=rtol)
+
+            if data_grad_req != 'null':
+                assert_almost_equal(data.grad.asnumpy(),
+                                    adX.asnumpy(), atol=atol, rtol=rtol)
+            if not fix_gamma:
+                if gamma_grad_req != 'null':
+                    assert_almost_equal(
+                        bn_gamma.grad.asnumpy(), adW.asnumpy(),
+                        atol=atol, rtol=rtol)
+            else:
+                assert((bn_gamma.asnumpy() == 1).all())
+            if beta_grad_req != 'null':
+                assert_almost_equal(
+                    bn_beta.grad.asnumpy(), adb.asnumpy(), atol=atol, rtol=rtol)
+
+    shapes = [(24, 2), (24, 3, 4), (24, 8, 4, 5), (24, 5, 6, 4, 5)]
+    bools = [False, True]
+    for shape, fix_gamma, cudnn_off, output_mean_var in itertools.product(
+            shapes, bools, bools, bools):
+        grad_reqs = ['write'] if len(shape) != 4 else ['null', 'write', 'add']
+        for data_grad_req in grad_reqs:
+            for gamma_grad_req in grad_reqs:
+                if fix_gamma and gamma_grad_req != 'null':
+                    continue
+                for beta_grad_req in grad_reqs:
+                    for axis in range(len(shape)):
+                        _test_batchnorm_impl(
+                            shape, fix_gamma, cudnn_off, output_mean_var,
+                            axis,
+                            data_grad_req,
+                            gamma_grad_req, beta_grad_req)
+
+
+@with_seed()
+@use_np
 def test_npi_boolean_assign():
     class TestBooleanAssignScalar(HybridBlock):
         def __init__(self, val):
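
The new test drives npx.batch_norm through every combination of grad_req for data, gamma and beta, reusing the same input across iterations when the request is 'add' and summing the reference gradients (adX, adW, adb) accordingly. A standalone check in the same spirit (shapes, tolerances and the helper name are arbitrary):

    import mxnet as mx
    from mxnet import np, npx
    from mxnet.test_utils import assert_almost_equal
    npx.set_np()

    x0 = np.random.uniform(size=(4, 3, 8, 8))
    gamma, beta = np.ones((3,)), np.zeros((3,))
    og = np.ones(x0.shape)

    def data_grad(grad_req, n_passes):
        x = x0.copy()
        x.attach_grad(grad_req=grad_req)
        for _ in range(n_passes):
            with mx.autograd.record():
                y = npx.batch_norm(x, gamma, beta, np.zeros((3,)), np.ones((3,)),
                                   momentum=0.9, eps=1e-5, fix_gamma=False)
            y.backward(og)
        return x.grad

    # Two accumulating passes must equal twice the single written gradient.
    once = data_grad('write', 1)
    twice = data_grad('add', 2)
    assert_almost_equal(twice.asnumpy(), (2 * once).asnumpy(), atol=5e-2, rtol=5e-2)
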
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 39fd16d..0dcb476 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1826,11 +1826,18 @@ def test_batchnorm():
     momentum = 0.9
     epsilon = 1e-5
 
-    def _test_batchnorm_impl(op, shape, axis, cudnn_off, output_mean_var):
-        print(str((op, shape, axis, cudnn_off)))
-
+    def _test_batchnorm_impl(op_name, shape, fix_gamma, cudnn_off, output_mean_var,
+                             axis,
+                             data_grad_req, gamma_grad_req, beta_grad_req):
+
+        if op_name == 'BatchNorm':
+            op = mx.nd.BatchNorm
+        elif op_name == 'SyncBatchNorm':
+            op = mx.nd.contrib.SyncBatchNorm
+        else:
+            raise ValueError('Not supported {}'.format(op_name))
         kwargs = dict(output_mean_var=output_mean_var)
-        if op == mx.nd.contrib.SyncBatchNorm:
+        if op_name == 'SyncBatchNorm':
             if axis != 1:
                 return
             key = str(op) + str(shape) + str(axis)
@@ -1841,11 +1848,14 @@ def test_batchnorm():
             kwargs.update(dict(axis=axis, cudnn_off=cudnn_off))
         nch = shape[axis]
 
-        bn_gamma = mx.nd.random.uniform(shape=(nch,))
-        bn_gamma.attach_grad()
+        if not fix_gamma:
+            bn_gamma = mx.nd.random.uniform(shape=(nch,))
+            bn_gamma.attach_grad(grad_req=gamma_grad_req)
+        else:
+            bn_gamma = mx.nd.ones(shape=(nch,))
 
         bn_beta = mx.nd.random.uniform(shape=(nch,))
-        bn_beta.attach_grad()
+        bn_beta.attach_grad(grad_req=beta_grad_req)
 
         bn_running_mean = mx.nd.zeros(nch)
         bn_running_var = mx.nd.ones(nch)
@@ -1855,18 +1865,26 @@ def test_batchnorm():
         num_iters = 10
         expand_shape = [1] * len(shape)
         expand_shape[axis] = shape[axis]
+        data = mx.nd.random.uniform(shape=shape)
+        data.attach_grad(grad_req=data_grad_req)
+        adX, adW, adb = 0, 0, 0
+        is_train = data_grad_req != 'null' or \
+            (not fix_gamma and gamma_grad_req != 'null') or \
+            beta_grad_req != 'null'
         for _ in range(num_iters):
-            data = mx.nd.random.uniform(shape=shape)
-            data.attach_grad()
+            if data_grad_req != 'add':
+                data = mx.nd.random.uniform(shape=shape)
+                data.attach_grad(grad_req=data_grad_req)
             ograd = mx.nd.random.uniform(shape=shape)
             with mx.autograd.record():
                 output = op(data, bn_gamma, bn_beta,
                             bn_running_mean, bn_running_var,
                             momentum=momentum, eps=epsilon,
-                            fix_gamma=False, **kwargs)
+                            fix_gamma=fix_gamma, **kwargs)
                 if output_mean_var:
                     output, output_mean, output_std = output
-                output.backward(ograd)
+                if is_train:
+                    output.backward(ograd)
             mx.nd.waitall()
 
             data_mean = data.mean(
@@ -1903,9 +1921,11 @@ def test_batchnorm():
             dX = dnx * nd + dvar * xsm * (2.0 / m) + dmean * (1.0 / m)
             dW = (ograd * nx).sum(axis=axis, exclude=True)
             db = ograd.sum(axis=axis, exclude=True)
+            adX = dX if data_grad_req != 'add' else adX + dX
+            adW = dW if gamma_grad_req != 'add' else adW + dW
+            adb = db if beta_grad_req != 'add' else adb + db
 
-            atol = 1e-2
-            rtol = 1e-2
+            atol, rtol = 5e-2, 5e-2
 
             if output_mean_var:
                 assert_almost_equal(output_mean.asnumpy(),
@@ -1922,26 +1942,43 @@ def test_batchnorm():
                                         atol=atol, rtol=rtol)
             assert_almost_equal(output.asnumpy(), target_output.asnumpy(),
                                 atol=atol, rtol=rtol)
-            assert_almost_equal(bn_running_mean.asnumpy(
-            ), running_mean.asnumpy(), atol=atol, rtol=rtol)
-            assert_almost_equal(bn_running_var.asnumpy(
-            ), running_var.asnumpy(), atol=atol, rtol=rtol)
-
-            assert_almost_equal(data.grad.asnumpy(),
-                                dX.asnumpy(), atol=atol, rtol=rtol)
-            assert_almost_equal(
-                bn_gamma.grad.asnumpy(), dW.asnumpy(), atol=atol, rtol=rtol)
-            assert_almost_equal(
-                bn_beta.grad.asnumpy(), db.asnumpy(), atol=atol, rtol=rtol)
-
-    for op in [mx.nd.BatchNorm, mx.nd.contrib.SyncBatchNorm]:
-        for shape in [(24, 2), (24, 3, 4), (24, 4, 4, 4), (24, 8, 4, 4), (24, 5, 6, 4, 4)]:
-            for axis in range(len(shape)):
-                for cudnn_off in [False, True]:
-                    for output_mean_var in [False, True]:
-                        _test_batchnorm_impl(op, shape, axis,
-                                             cudnn_off, output_mean_var)
-
+            if is_train:
+                assert_almost_equal(bn_running_mean.asnumpy(
+                ), running_mean.asnumpy(), atol=atol, rtol=rtol)
+                assert_almost_equal(bn_running_var.asnumpy(
+                ), running_var.asnumpy(), atol=atol, rtol=rtol)
+
+            if data_grad_req != 'null':
+                assert_almost_equal(data.grad.asnumpy(),
+                                    adX.asnumpy(), atol=atol, rtol=rtol)
+            if not fix_gamma:
+                if gamma_grad_req != 'null':
+                    assert_almost_equal(
+                        bn_gamma.grad.asnumpy(), adW.asnumpy(),
+                        atol=atol, rtol=rtol)
+            else:
+                assert((bn_gamma.asnumpy() == 1).all())
+            if beta_grad_req != 'null':
+                assert_almost_equal(
+                    bn_beta.grad.asnumpy(), adb.asnumpy(), atol=atol, rtol=rtol)
+
+    op_names = ['BatchNorm', 'SyncBatchNorm']
+    shapes = [(24, 2), (24, 3, 4), (24, 8, 4, 5), (24, 5, 6, 4, 5)]
+    bools = [False, True]
+    for op_name, shape, fix_gamma, cudnn_off, output_mean_var in itertools.product(
+            op_names, shapes, bools, bools, bools):
+        grad_reqs = ['write'] if len(shape) != 4 else ['null', 'write', 'add']
+        for data_grad_req in grad_reqs:
+            for gamma_grad_req in grad_reqs:
+                if fix_gamma and gamma_grad_req != 'null':
+                    continue
+                for beta_grad_req in grad_reqs:
+                    for axis in range(len(shape)):
+                        _test_batchnorm_impl(
+                            op_name, shape, fix_gamma, cudnn_off, output_mean_var,
+                            axis,
+                            data_grad_req,
+                            gamma_grad_req, beta_grad_req)
 
 @with_seed()
 def test_groupnorm():