Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/10/10 18:38:36 UTC

[GitHub] sandeep-krishnamurthy closed pull request #12625: [MXNET-979] Add fix_beta support in BatchNorm

sandeep-krishnamurthy closed pull request #12625: [MXNET-979] Add fix_beta support in BatchNorm
URL: https://github.com/apache/incubator-mxnet/pull/12625
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index d26841977ac..26ef64dfd0b 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -324,7 +324,7 @@ def __init__(self, axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
                  in_channels=0, **kwargs):
         super(BatchNorm, self).__init__(**kwargs)
         self._kwargs = {'axis': axis, 'eps': epsilon, 'momentum': momentum,
-                        'fix_gamma': not scale, 'use_global_stats': use_global_stats}
+                        'fix_gamma': not scale, 'fix_beta': not center, 'use_global_stats': use_global_stats}
         if in_channels != 0:
             self.in_channels = in_channels
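
For context (not part of the diff): with this change, Gluon's center and scale
constructor arguments map directly onto the operator's fix_beta and fix_gamma
flags. A minimal usage sketch, assuming a build that includes this PR:

    import mxnet as mx
    from mxnet.gluon import nn

    # center=False -> fix_beta=True, scale=False -> fix_gamma=True
    net = nn.BatchNorm(center=False, scale=False, in_channels=3)
    net.initialize()
    x = mx.nd.random.uniform(shape=(2, 3, 8, 8))
    y = net(x)   # beta stays at 0 and gamma stays at 1 during training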
 
diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h
index 3f47d58bb8c..f8b381c87be 100644
--- a/src/operator/nn/batch_norm-inl.h
+++ b/src/operator/nn/batch_norm-inl.h
@@ -62,6 +62,7 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
   double eps;
   float momentum;
   bool fix_gamma;
+  bool fix_beta;
   bool use_global_stats;
   bool output_mean_var;
   int axis;
@@ -75,6 +76,8 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
     .describe("Momentum for moving average");
     DMLC_DECLARE_FIELD(fix_gamma).set_default(true)
     .describe("Fix gamma while training");
+    DMLC_DECLARE_FIELD(fix_beta).set_default(false)
+    .describe("Fix beta while training");
     DMLC_DECLARE_FIELD(use_global_stats).set_default(false)
     .describe("Whether use global moving statistics instead of local batch-norm. "
               "This will force change batch-norm into a scale shift operator.");
@@ -90,6 +93,7 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
     return this->eps == other.eps &&
            this->momentum == other.momentum &&
            this->fix_gamma == other.fix_gamma &&
+           this->fix_beta == other.fix_beta &&
            this->use_global_stats == other.use_global_stats &&
            this->output_mean_var == other.output_mean_var &&
            this->axis == other.axis &&
@@ -107,6 +111,7 @@ struct hash<mxnet::op::BatchNormParam> {
     size_t ret = 0;
     ret = dmlc::HashCombine(ret, val.momentum);
     ret = dmlc::HashCombine(ret, val.fix_gamma);
+    ret = dmlc::HashCombine(ret, val.fix_beta);
     ret = dmlc::HashCombine(ret, val.use_global_stats);
     ret = dmlc::HashCombine(ret, val.output_mean_var);
     ret = dmlc::HashCombine(ret, val.axis);
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index f28f5d7a436..dc1d123fd00 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -155,35 +155,34 @@ void BatchNormForwardImpl(mshadow::Stream<cpu> *,
 
     // compute output
     AccReal *w = weights.dptr<AccReal>();
-    const AccReal *b = bias.dptr<AccReal>();
+    AccReal *b = bias.dptr<AccReal>();
+
+    // fix_gamma: overwrite the stored gamma with 1 (when gamma is writable) so
+    // the shared output computation below can be used unchanged
+    if (param_.fix_gamma) {
+      if (IsBNWriting(req[batchnorm::kGamma])) {
+        w[channel] = AccReal(1);
+      }
+    }
+
+    // fix_beta: overwrite the stored beta with 0 (when beta is writable)
+    if (param_.fix_beta) {
+      if (IsBNWriting(req[batchnorm::kBeta])) {
+        b[channel] = AccReal(0);
+      }
+    }
 
     const AccReal thisMean = mean[channel];
     const AccReal thisInvstd = var[channel];
     const AccReal thisWeight = w[channel];
     const AccReal thisBias = b[channel];
 
-    // note that var is still invstd
-    if (!param_.fix_gamma) {
-      if (IsBNWriting(req[batchnorm::kData])) {
-        ForEachFast(inputData, outputData, channel,
-                    [thisWeight, thisBias, thisMean, thisInvstd](const DType *in_data,
-                                                                 DType *out_data) {
-                      *out_data = static_cast<DType>(
-                        ((*in_data - thisMean) * thisInvstd) * thisWeight + thisBias);
-                    });
-      }
-    } else {
-      if (IsBNWriting(req[batchnorm::kGamma])) {
-        w[channel] = AccReal(1);
-      }
-      if (IsBNWriting(req[batchnorm::kData])) {
-        ForEachFast(inputData, outputData, channel,
-                    [thisWeight, thisBias, thisMean, thisInvstd](const DType *in_data,
-                                                                 DType *out_data) {
-                      *out_data = static_cast<DType>(
-                        ((*in_data - thisMean) * thisInvstd) + thisBias);
-                    });
-      }
+    if (IsBNWriting(req[batchnorm::kData])) {
+      ForEachFast(inputData, outputData, channel,
+                  [thisWeight, thisBias, thisMean, thisInvstd](const DType *in_data,
+                                                               DType *out_data) {
+                    *out_data = static_cast<DType>(
+                      ((*in_data - thisMean) * thisInvstd) * thisWeight + thisBias);
+                  });
     }
   }
 }
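
For illustration (not part of the diff): the refactored forward path above always
evaluates out = (in - mean) * invstd * gamma + beta; fix_gamma and fix_beta merely
overwrite the stored gamma with 1 and beta with 0 beforehand. A rough NumPy
reference of the per-channel computation, illustration only:

    import numpy as np

    def batchnorm_forward_ref(x, gamma, beta, eps=1e-5,
                              fix_gamma=False, fix_beta=False):
        # x has layout (N, C, ...) with the channel axis at 1 (the operator default).
        axes = (0,) + tuple(range(2, x.ndim))
        mean = x.mean(axis=axes, keepdims=True)
        invstd = 1.0 / np.sqrt(x.var(axis=axes, keepdims=True) + eps)
        g = np.ones_like(gamma) if fix_gamma else gamma   # fix_gamma: gamma treated as 1
        b = np.zeros_like(beta) if fix_beta else beta     # fix_beta: beta treated as 0
        shape = [1, -1] + [1] * (x.ndim - 2)
        return (x - mean) * invstd * g.reshape(shape) + b.reshape(shape)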
@@ -309,7 +308,11 @@ void BatchNormBackwardImpl(mshadow::Stream<cpu> *,
     }
 
     if (IsBNWriting(req[batchnorm::kBeta])) {
-      gradBiasData[channel] = scale * sumGradOut;
+      if (!param_.fix_beta) {
+        gradBiasData[channel] = scale * sumGradOut;
+      } else {
+        gradBiasData[channel] = AccReal(0);
+      }
     }
   }
 }
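
On the backward side, the only change is that the beta gradient is forced to zero
when fix_beta is set, mirroring what fix_gamma already does for the gamma gradient.
A minimal NumPy sketch of that rule (the scale factor from the code above is
omitted here):

    import numpy as np

    def batchnorm_beta_grad_ref(grad_out, beta, fix_beta=False):
        # d(beta) is the output gradient summed over every axis except channel (axis 1);
        # with fix_beta it is simply zeroed, as in the change above.
        axes = (0,) + tuple(range(2, grad_out.ndim))
        return np.zeros_like(beta) if fix_beta else grad_out.sum(axis=axes)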
@@ -478,6 +481,9 @@ static inline bool BatchNormStorageType(const nnvm::NodeAttrs &attrs,
   if (!common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) && param.fix_gamma) {
     LOG(FATAL) << "fix_gamma=True is not supported for sparse ndarrays. Tracked at #11647";
   }
+  if (!common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) && param.fix_beta) {
+    LOG(FATAL) << "fix_beta=True is not supported for sparse ndarrays. Tracked at #11647";
+  }
   return dispatched;
 }
 
@@ -565,11 +571,12 @@ the 'channel' (separately normalized groups).  The default is 1.  Specifying -1
 axis to be the last item in the input shape.
 
 Both ``gamma`` and ``beta`` are learnable parameters. But if ``fix_gamma`` is true,
-then set ``gamma`` to 1 and its gradient to 0.
+then set ``gamma`` to 1 and its gradient to 0. If ``fix_beta`` is true, then set ``beta`` to 0
+and its gradient to 0.
 
 Note::
 
-When fix_gamma is set to True, no sparse support is provided. If fix_gamma is set to False,
+When fix_gamma/fix_beta is set to True, no sparse support is provided. If fix_gamma/fix_beta is set to False,
 the sparse tensors will fallback.
 
 )code" ADD_FILELINE)
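
To make the documented semantics concrete, a hedged symbol-level sketch (names and
shapes are illustrative, not taken from the PR):

    import mxnet as mx

    data = mx.sym.Variable('data')
    net = mx.sym.BatchNorm(data, fix_gamma=False, fix_beta=True, name='bn')
    # With fix_beta=True, beta is held at 0 and its gradient is written as 0,
    # exactly as fix_gamma=True does for gamma. As the note above says, setting
    # either flag to True also disables the sparse (row_sparse/csr) code path.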
diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu
index 03962cbc0f3..309542d33c2 100644
--- a/src/operator/nn/batch_norm.cu
+++ b/src/operator/nn/batch_norm.cu
@@ -32,8 +32,9 @@
 #define WRITE_GAMMA_FLAG      2
 #define WRITE_BETA_FLAG       4
 #define FIX_GAMMA_FLAG        8
-#define IS_TRAINING_FLAG      16
-#define USE_GLOBAL_STATS_FLAG 32
+#define FIX_BETA_FLAG         16
+#define IS_TRAINING_FLAG      32
+#define USE_GLOBAL_STATS_FLAG 64
 
 #if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
 #include "./cudnn/cudnn_batch_norm-inl.h"
@@ -223,8 +224,9 @@ __global__ void BatchNormalizationUpdateOutputInferenceKernel(
   AccReal gamma = ((flags & FIX_GAMMA_FLAG) == 0 && weight.numElements() > 0)
                   ? ScalarConvert<DType, AccReal>::to(weight[plane])
                   : ScalarConvert<int, AccReal>::to(1);
-  AccReal beta = bias.numElements() > 0 ? ScalarConvert<DType, AccReal>::to(bias[plane])
-                                        : ScalarConvert<int, AccReal>::to(0);
+  AccReal beta = ((flags & FIX_BETA_FLAG) == 0 && bias.numElements() > 0)
+                  ? ScalarConvert<DType, AccReal>::to(bias[plane])
+                  : ScalarConvert<int, AccReal>::to(0);
   if (threadIdx.x == 0) {
     saveMean[plane] = runningMean[plane];
     saveInvStd[plane] = VARIANCE_TO_INVSTD(runningVar[plane], epsilon);
@@ -232,6 +234,10 @@ __global__ void BatchNormalizationUpdateOutputInferenceKernel(
         && weight.numElements() > 0) {
       weight[plane] = AccReal(1);
     }
+    if ((flags & WRITE_BETA_FLAG) != 0 && (flags & FIX_BETA_FLAG) != 0
+        && bias.numElements() > 0) {
+      bias[plane] = AccReal(0);
+    }
   }
   // Write normalized and update the output
   for (int batch = 0, nbatch = input.OuterSize(); batch < nbatch; ++batch) {
@@ -282,14 +288,19 @@ __global__ void BatchNormalizationUpdateOutputKernel(
         && weight.numElements() > 0) {
       weight[plane] = AccReal(1);
     }
+    if ((flags & WRITE_BETA_FLAG) != 0 && (flags & FIX_BETA_FLAG) != 0
+        && bias.numElements() > 0) {
+      bias[plane] = AccReal(0);
+    }
   }
 
   // Write normalized and update the output
   const AccReal gamma = ((flags & FIX_GAMMA_FLAG) == 0 && weight.numElements() > 0)
                         ? ScalarConvert<DType, AccReal>::to(weight[plane])
                         : ScalarConvert<int, AccReal>::to(1);
-  const AccReal beta = bias.numElements() > 0 ? ScalarConvert<DType, AccReal>::to(bias[plane])
-                                              : ScalarConvert<int, AccReal>::to(0);
+  const AccReal beta = ((flags & FIX_BETA_FLAG) == 0 && bias.numElements() > 0)
+                        ? ScalarConvert<DType, AccReal>::to(bias[plane])
+                        : ScalarConvert<int, AccReal>::to(0);
   for (int batch = 0, nbatch = input.OuterSize(); batch < nbatch; ++batch) {
     for (int x = threadIdx.x, nx = input.InnerSize(); x < nx; x += blockDim.x) {
       const DType inp = input.get_ref(batch, plane, x);
@@ -388,7 +399,11 @@ static __global__ void BatchNormalizationBackwardKernel(
   }
 
   if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_BETA_FLAG) != 0) {
-    tensors.gradBias[plane] = ScalarConvert<AccReal, DType>::to(gradOutputSum);
+    if ((flags & FIX_BETA_FLAG) == 0) {
+      tensors.gradBias[plane] = ScalarConvert<AccReal, DType>::to(gradOutputSum);
+    } else {
+      tensors.gradBias[plane] = DType(0);
+    }
   }
 }
 
@@ -582,6 +597,7 @@ static inline uint32_t SetupFlags(const OpContext &ctx,
   uint32_t flags = 0;
   flags |= ctx.is_train ? IS_TRAINING_FLAG : 0;
   flags |= params.fix_gamma ? FIX_GAMMA_FLAG : 0;
+  flags |= params.fix_beta ? FIX_BETA_FLAG : 0;
   flags |= params.use_global_stats ? USE_GLOBAL_STATS_FLAG : 0;
   if (IsBNWriting(req[batchnorm::kData])) {
     flags |= WRITE_DATA_FLAG;
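
The flag renumbering above is needed because the CUDA kernels receive a single
uint32 bitmask: inserting FIX_BETA_FLAG at bit value 16 pushes the two flags that
previously used 16 and 32 up to 32 and 64. A small Python mirror of the macros,
for illustration only:

    WRITE_GAMMA_FLAG, WRITE_BETA_FLAG = 2, 4
    FIX_GAMMA_FLAG, FIX_BETA_FLAG = 8, 16
    IS_TRAINING_FLAG, USE_GLOBAL_STATS_FLAG = 32, 64

    flags = IS_TRAINING_FLAG | FIX_BETA_FLAG | WRITE_BETA_FLAG
    assert flags & FIX_BETA_FLAG         # kernels reset beta to 0 ...
    assert flags & WRITE_BETA_FLAG       # ... and write the zeroed value back
    assert not (flags & FIX_GAMMA_FLAG)  # gamma is handled normally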
diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
index d4b9f84ed2f..9caa9d3ddd3 100644
--- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
+++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
@@ -115,6 +115,8 @@ class CuDNNBatchNormOp {
 
       if (param_.fix_gamma) gamma = 1.f;
 
+      if (param_.fix_beta) beta = 0.f;
+
       if (ctx.is_train) {
         Tensor<gpu, 1, DTypeParam> save_mean =
           out_data[cudnnbatchnorm::kMean].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
@@ -229,6 +231,7 @@ class CuDNNBatchNormOp {
         global_stats ? nullptr : save_mean.dptr_,
         global_stats ? nullptr : save_inv_var.dptr_));
       if (param_.fix_gamma) dgamma = 0.f;
+      if (param_.fix_beta) dbeta = 0.f;
     })
 #else  // CUDNN_VERSION < 4007
     MSHADOW_REAL_TYPE_SWITCH(dtype_param_, DTypeParam, {
@@ -267,6 +270,7 @@ class CuDNNBatchNormOp {
                                                  global_stats ? nullptr : save_mean.dptr_,
                                                  global_stats ? nullptr : save_inv_var.dptr_));
       if (param_.fix_gamma) dgamma = 0.f;
+      if (param_.fix_beta) dbeta = 0.f;
     })
 #endif
   }
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index dd7ec985c7c..13022c108c4 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -303,35 +303,52 @@ def test_batchnorm_with_type():
 
 
   # V2, 2D
-  sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, cudnn_off=True)
+  sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, fix_beta=False, cudnn_off=True)
   check_consistency(sym, ctx_list_v2_2D)
-  sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, cudnn_off=True)
+  sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, fix_beta=True, cudnn_off=True)
   check_consistency(sym, ctx_list_v2_2D)
-  sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, cudnn_off=True)
+  sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, fix_beta=False, cudnn_off=True)
+  check_consistency(sym, ctx_list_v2_2D)
+  sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, fix_beta=True, cudnn_off=True)
   check_consistency(sym, ctx_list_v2_2D)
+  # Don't specify fix_beta, so the default (fix_beta=False) is verified.
   sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, cudnn_off=True)
   check_consistency(sym, ctx_list_v2_2D)
+  # Don't specify fix_gamma, so the default (fix_gamma=False) is verified.
+  sym = mx.sym.BatchNorm(name='norm', fix_beta=True, cudnn_off=True)
+  check_consistency(sym, ctx_list_v2_2D)
 
   # V2, 1D
-  sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, cudnn_off=True)
+  sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, fix_beta=False, cudnn_off=True)
   check_consistency(sym, ctx_list_v2_1D)
-  sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, cudnn_off=True)
+  sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, fix_beta=True, cudnn_off=True)
   check_consistency(sym, ctx_list_v2_1D)
-  sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, cudnn_off=True)
+  sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, fix_beta=False, cudnn_off=True)
   check_consistency(sym, ctx_list_v2_1D)
+  sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, fix_beta=True, cudnn_off=True)
+  check_consistency(sym, ctx_list_v2_1D)
+  # Don't specify fix_beta, so the default (fix_beta=False) is verified.
   sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, cudnn_off=True)
   check_consistency(sym, ctx_list_v2_1D)
-  #
+  # Don't specify fix_gamma, so the default (fix_gamma=False) is verified.
+  sym = mx.sym.BatchNorm(name='norm', fix_beta=True, cudnn_off=True)
+  check_consistency(sym, ctx_list_v2_1D)
+
   # # V2, 3D
-  sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, cudnn_off=True)
+  sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, fix_beta=True, cudnn_off=True)
   check_consistency(sym, ctx_list_v2_3D)
+  sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, fix_beta=False, cudnn_off=True)
+  check_consistency(sym, ctx_list_v2_3D)
+  # Don't specify fix_beta, so the default (fix_beta=False) is verified.
   sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, cudnn_off=True)
   check_consistency(sym, ctx_list_v2_3D)
-
+  # Don't specify fix_gamma, so the default (fix_gamma=False) is verified.
+  sym = mx.sym.BatchNorm(name='norm', fix_beta=False, cudnn_off=True)
+  check_consistency(sym, ctx_list_v2_3D)
 
 @with_seed()
 def test_batchnorm_versions():
-  def test_batchnorm_versions_helper(batchnorm_op_list, data, fix_gamma, use_global_stats):
+  def test_batchnorm_versions_helper(batchnorm_op_list, data, fix_gamma, fix_beta, use_global_stats):
     ctx_list = []
     sym_list = []
     # BatchNormV1 cpu
@@ -352,6 +369,7 @@ def test_batchnorm_versions_helper(batchnorm_op_list, data, fix_gamma, use_globa
     if 'batchnorm_cpu' in batchnorm_op_list:
       ctx_list.append({'ctx': mx.cpu(0), 'batchnorm_data': data, 'type_dict': {'batchnorm_data': np.float32}})
       sym_list.append(mx.sym.BatchNorm(fix_gamma=fix_gamma,
+                                       fix_beta=fix_beta,
                                        use_global_stats=use_global_stats,
                                        name='batchnorm'))
 
@@ -359,6 +377,7 @@ def test_batchnorm_versions_helper(batchnorm_op_list, data, fix_gamma, use_globa
     if 'batchnorm_gpu' in batchnorm_op_list:
       ctx_list.append({'ctx': mx.gpu(0), 'batchnorm_data': data, 'type_dict': {'batchnorm_data': np.float32}})
       sym_list.append(mx.sym.BatchNorm(fix_gamma=fix_gamma,
+                                       fix_beta=fix_beta,
                                        use_global_stats=use_global_stats,
                                        name='batchnorm', cudnn_off=True))
 
@@ -366,47 +385,54 @@ def test_batchnorm_versions_helper(batchnorm_op_list, data, fix_gamma, use_globa
     if 'batchnorm_cudnn' in batchnorm_op_list:
       ctx_list.append({'ctx': mx.gpu(0), 'batchnorm_data': data, 'type_dict': {'batchnorm_data': np.float32}})
       sym_list.append(mx.sym.BatchNorm(fix_gamma=fix_gamma,
+                                       fix_beta=fix_beta,
                                        use_global_stats=use_global_stats,
                                        name='batchnorm', cudnn_off=False))
 
     check_consistency(sym_list, ctx_list)
 
-  def test_1d_batchnorm(fix_gamma, use_global_stats):
+  def test_1d_batchnorm(fix_gamma, fix_beta, use_global_stats):
     data = (2, 3, 20)
     test_batchnorm_versions_helper(batchnorm_op_list=['batchnorm_cpu',
                                                       'batchnorm_gpu', 'batchnorm_cudnn'],
                                    data=data,
-                                   fix_gamma=fix_gamma, use_global_stats=use_global_stats)
+                                   fix_gamma=fix_gamma, fix_beta=fix_beta, use_global_stats=use_global_stats)
 
-  def test_2d_batchnorm(fix_gamma, use_global_stats):
+  def test_2d_batchnorm(fix_gamma, fix_beta, use_global_stats):
     data = (2, 3, 10, 10)
-    test_batchnorm_versions_helper(batchnorm_op_list=['batchnorm_v1_cpu', 'batchnorm_v1_gpu',
-                                                      'batchnorm_cpu',
+    # batchnorm_v1 is deprecated.
+    # The `fix_beta` parameter is available only in the new BatchNorm operator.
+    # Check consistency separately for batchnorm_v1 and batchnorm.
+    test_batchnorm_versions_helper(batchnorm_op_list=['batchnorm_v1_cpu', 'batchnorm_v1_gpu'],
+                                   data=data,
+                                   fix_gamma=fix_gamma, fix_beta=fix_beta, use_global_stats=use_global_stats)
+
+    test_batchnorm_versions_helper(batchnorm_op_list=['batchnorm_cpu',
                                                       'batchnorm_gpu', 'batchnorm_cudnn'],
                                    data=data,
-                                   fix_gamma=fix_gamma, use_global_stats=use_global_stats)
+                                   fix_gamma=fix_gamma, fix_beta=fix_beta, use_global_stats=use_global_stats)
 
-  def test_3d_batchnorm(fix_gamma, use_global_stats):
+  def test_3d_batchnorm(fix_gamma, fix_beta, use_global_stats):
     data = (2, 3, 3, 5, 5)
     test_batchnorm_versions_helper(batchnorm_op_list=['batchnorm_cpu',
                                                       'batchnorm_gpu'],
                                    data=data,
-                                   fix_gamma=fix_gamma, use_global_stats=use_global_stats)
-
-  test_1d_batchnorm(True,  False)
-  test_1d_batchnorm(False, False)
-  test_1d_batchnorm(False, True)
-  test_1d_batchnorm(True,  True)
-
-  test_2d_batchnorm(True,  False)
-  test_2d_batchnorm(False, False)
-  test_2d_batchnorm(False, True)
-  test_2d_batchnorm(True,  True)
-
-  test_3d_batchnorm(True,  False)
-  test_3d_batchnorm(False, False)
-  test_3d_batchnorm(False, True)
-  test_3d_batchnorm(True,  True)
+                                   fix_gamma=fix_gamma, fix_beta=fix_beta, use_global_stats=use_global_stats)
+
+  test_1d_batchnorm(True,  False, False)
+  test_1d_batchnorm(False, True, False)
+  test_1d_batchnorm(False, False, True)
+  test_1d_batchnorm(True,  True, True)
+
+  test_2d_batchnorm(True,  False, False)
+  test_2d_batchnorm(False, True, False)
+  test_2d_batchnorm(False, False, True)
+  test_2d_batchnorm(True,  True, True)
+
+  test_3d_batchnorm(True,  False, False)
+  test_3d_batchnorm(False, True, False)
+  test_3d_batchnorm(False, False, True)
+  test_3d_batchnorm(True,  True, True)
 
 
 @with_seed(1234)
diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py
index 53e4051fc07..16480874831 100644
--- a/tests/python/mkl/test_mkldnn.py
+++ b/tests/python/mkl/test_mkldnn.py
@@ -235,7 +235,7 @@ def check_batchnorm_training(stype):
                            mx.nd.array(beta).tostype(stype)]
             mean_std = [mx.nd.array(rolling_mean).tostype(stype), mx.nd.array(rolling_std).tostype(stype)]
 
-            test = mx.symbol.BatchNorm(data, fix_gamma=False)
+            test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
     stypes = ['row_sparse', 'default']
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index b5a7303195f..639691568db 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1534,25 +1534,25 @@ def check_batchnorm_training(stype):
             test = mx.symbol.BatchNorm_v1(data, fix_gamma=True)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
-            test = mx.symbol.BatchNorm(data, fix_gamma=True)
+            test = mx.symbol.BatchNorm(data, fix_gamma=True, fix_beta=True)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
             test = mx.symbol.BatchNorm_v1(data, fix_gamma=True, use_global_stats=True)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
-            test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True)
+            test = mx.symbol.BatchNorm(data, fix_gamma=True, fix_beta=True, use_global_stats=True)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
             test = mx.symbol.BatchNorm_v1(data, fix_gamma=False)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
-            test = mx.symbol.BatchNorm(data, fix_gamma=False)
+            test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
             test = mx.symbol.BatchNorm_v1(data, fix_gamma=False, use_global_stats=True)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
-            test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True)
+            test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False, use_global_stats=True)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
             # Test varying channel axis
@@ -1581,16 +1581,16 @@ def check_batchnorm_training(stype):
                 xmean_std = [mx.nd.array(xrolling_mean).tostype(stype),
                              mx.nd.array(xrolling_std).tostype(stype)]
 
-                test = mx.symbol.BatchNorm(data, fix_gamma=True, axis=chaxis)
+                test = mx.symbol.BatchNorm(data, fix_gamma=True, fix_beta=True, axis=chaxis)
                 check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
 
-                test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True, axis=chaxis)
+                test = mx.symbol.BatchNorm(data, fix_gamma=True, fix_beta=False, use_global_stats=True, axis=chaxis)
                 check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
 
-                test = mx.symbol.BatchNorm(data, fix_gamma=False, axis=chaxis)
+                test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=True, axis=chaxis)
                 check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
 
-                test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True, axis=chaxis)
+                test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False, use_global_stats=True, axis=chaxis)
                 check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
 
     check_batchnorm_training('default')
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 57808248b08..bddab11f95d 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -2124,13 +2124,19 @@ def test_batchnorm_fallback():
         test = mx.symbol.BatchNorm(data, fix_gamma=True)
         assertRaises(MXNetError, check_numeric_gradient, test, in_location, mean_std, numeric_eps=1e-3, rtol=0.16, atol=1e-2)
 
+        test = mx.symbol.BatchNorm(data, fix_beta=True)
+        assertRaises(MXNetError, check_numeric_gradient, test, in_location, mean_std, numeric_eps=1e-3, rtol=0.16, atol=1e-2)
+
         test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True)
         assertRaises(MXNetError, check_numeric_gradient, test, in_location, mean_std, numeric_eps=1e-3, rtol=0.16, atol=1e-2)
 
-        test = mx.symbol.BatchNorm(data, fix_gamma=False)
+        test = mx.symbol.BatchNorm(data, fix_beta=True, use_global_stats=True)
+        assertRaises(MXNetError, check_numeric_gradient, test, in_location, mean_std, numeric_eps=1e-3, rtol=0.16, atol=1e-2)
+
+        test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False)
         check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-3, rtol=0.16, atol=1e-2)
 
-        test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True)
+        test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False, use_global_stats=True)
         check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-3, rtol=0.16, atol=1e-2)
 
         # Test varying channel axis
@@ -2161,14 +2167,20 @@ def test_batchnorm_fallback():
 
             test = mx.symbol.BatchNorm(data, fix_gamma=True, axis=chaxis)
             assertRaises(MXNetError, check_numeric_gradient, test, in_location, xmean_std, numeric_eps=1e-3, rtol=0.2, atol=0.01)
+            
+            test = mx.symbol.BatchNorm(data, fix_beta=True, axis=chaxis)
+            assertRaises(MXNetError, check_numeric_gradient, test, in_location, xmean_std, numeric_eps=1e-3, rtol=0.2, atol=0.01)
 
             test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True, axis=chaxis)
             assertRaises(MXNetError, check_numeric_gradient, test, in_location, xmean_std, numeric_eps=1e-3, rtol=0.2, atol=0.01)
 
-            test = mx.symbol.BatchNorm(data, fix_gamma=False, axis=chaxis)
+            test = mx.symbol.BatchNorm(data, fix_beta=True, use_global_stats=True, axis=chaxis)
+            assertRaises(MXNetError, check_numeric_gradient, test, in_location, xmean_std, numeric_eps=1e-3, rtol=0.2, atol=0.01)
+
+            test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False, axis=chaxis)
             check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-3, rtol=0.2, atol=0.01)
 
-            test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True, axis=chaxis)
+            test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False, use_global_stats=True, axis=chaxis)
             check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-3, rtol=0.2, atol=0.01)
 
 


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services