Posted to commits@mxnet.apache.org by la...@apache.org on 2021/07/06 14:45:48 UTC

[incubator-mxnet] branch master updated: [BUGFIX] Add checks in BatchNorm's infer shape (#20415)

This is an automated email from the ASF dual-hosted git repository.

lausen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new b74491f  [BUGFIX] Add checks in BatchNorm's infer shape (#20415)
b74491f is described below

commit b74491fb1c39c90aebb2e4162f44ef995ab1b30a
Author: bgawrych <ba...@intel.com>
AuthorDate: Tue Jul 6 16:44:27 2021 +0200

    [BUGFIX] Add checks in BatchNorm's infer shape (#20415)
---
 src/operator/nn/batch_norm.cc                  | 18 ++++++++--------
 src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h | 30 +++++++++++++-------------
 2 files changed, 24 insertions(+), 24 deletions(-)
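
The core of the fix replaces direct writes into in_shape/out_shape with SHAPE_ASSIGN_CHECK, so shapes supplied for gamma, beta, moving_mean and moving_var are validated against the channel count inferred from the data instead of being silently overwritten. A minimal standalone sketch of that guarded-assignment idea follows (illustrative only; Shape and shape_assign below are stand-ins for this example, not the MXNet definitions):

    #include <cstdio>
    #include <vector>

    // Stand-in for a tensor shape; an empty vector means "not inferred yet".
    using Shape = std::vector<long>;

    // Accept `proposed` if `target` is still unknown or already identical;
    // report a conflict instead of overwriting a known shape.
    bool shape_assign(Shape* target, const Shape& proposed) {
      if (target->empty()) {
        *target = proposed;
        return true;
      }
      return *target == proposed;
    }

    int main() {
      Shape gamma  = {64};   // shape the user passed for gamma
      Shape wanted = {128};  // channel count inferred from the data

      // Plain assignment (`gamma = wanted;`) would hide the mismatch;
      // the guarded version surfaces it so shape inference can fail loudly.
      if (!shape_assign(&gamma, wanted)) {
        std::printf("gamma has %ld elements, expected %ld\n",
                    gamma[0], wanted[0]);
        return 1;
      }
      return 0;
    }

In the patch the same pattern is also applied to the outputs, so kOut, kMean and kVar are checked for consistency rather than rebuilt unconditionally.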

diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 87456dd..be0b015 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -375,15 +375,15 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs,
 
   const index_t channelCount = dshape[channelAxis];
 
-  in_shape->at(batchnorm::kGamma) = mxnet::TShape(Shape1(channelCount));
-  in_shape->at(batchnorm::kBeta) = mxnet::TShape(Shape1(channelCount));
-  in_shape->at(batchnorm::kInMovingMean) = mxnet::TShape(Shape1(channelCount));  // kMovingMean
-  in_shape->at(batchnorm::kInMovingVar) = mxnet::TShape(Shape1(channelCount));  // kMovingVar
-
-  out_shape->clear();
-  out_shape->push_back(dshape);                // kOut
-  out_shape->push_back(Shape1(channelCount));  // kMean
-  out_shape->push_back(Shape1(channelCount));  // kVar
+  SHAPE_ASSIGN_CHECK(*in_shape, batchnorm::kGamma, Shape1(channelCount));
+  SHAPE_ASSIGN_CHECK(*in_shape, batchnorm::kBeta, Shape1(channelCount));
+  SHAPE_ASSIGN_CHECK(*in_shape, batchnorm::kInMovingMean, Shape1(channelCount));  // kMovingMean
+  SHAPE_ASSIGN_CHECK(*in_shape, batchnorm::kInMovingVar, Shape1(channelCount));   // kMovingVar
+
+
+  SHAPE_ASSIGN_CHECK(*out_shape, batchnorm::kOut, dshape);
+  SHAPE_ASSIGN_CHECK(*out_shape, batchnorm::kMean, Shape1(channelCount));
+  SHAPE_ASSIGN_CHECK(*out_shape, batchnorm::kVar, Shape1(channelCount));
 
   return true;
 }
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 963ed2c..5a6f84c 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -159,10 +159,10 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
   if (param.axis != 1 || shape.ndim() != 4) {
     // reshape to (N, C, 1, D)
     mxnet::TShape new_shape{
-      static_cast<dim_t>(shape.ProdShape(0, real_axis)),
+      static_cast<index_t>(shape.ProdShape(0, real_axis)),
       shape[real_axis],
       1,
-      static_cast<dim_t>(shape.ProdShape(real_axis + 1,
+      static_cast<index_t>(shape.ProdShape(real_axis + 1,
             static_cast<int>(shape.ndim())))
     };
     in_data[batchnorm::kData] = in_data[batchnorm::kData].Reshape(new_shape);
@@ -195,7 +195,7 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
     const mkldnn::memory &weight_mem = fwd.GetWeight();
     float* weight_buf = reinterpret_cast<float *>(weight_mem.get_data_handle());
 
-    nnvm::dim_t channels_ = data.shape()[1];
+    index_t channels_ = data.shape()[1];
     CHECK(weight_mem.get_desc().get_size() == channels_ * sizeof(float) * 2);
     float* weight_ptr = gamma.data().dptr<float>();
     float* bias_ptr = beta.data().dptr<float>();
@@ -204,13 +204,13 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
       memcpy(weight_buf, weight_ptr, copy_size);
       memcpy(&weight_buf[channels_], bias_ptr, copy_size);
     } else if (IsBNWriting(req[batchnorm::kGamma])) {
-      for (int i = 0; i < channels_; i++) {
+      for (index_t i = 0; i < channels_; i++) {
         weight_buf[i] = 1.0f;
         weight_ptr[i] = 1.0f;
         weight_buf[channels_ + i] = bias_ptr[i];  // bias
       }
     } else {
-      for (int i = 0; i < channels_; i++) {
+      for (index_t i = 0; i < channels_; i++) {
         weight_buf[i] = 1.0f;
         weight_buf[channels_ + i] = bias_ptr[i];  // bias
       }
@@ -237,7 +237,7 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
       float* inmean   = aux_states[batchnorm::kMovingMean].data().dptr<float>();
       float* invar    = aux_states[batchnorm::kMovingVar].data().dptr<float>();
       // to align with origin implmentation: batch_norm.cc: L164
-      for (int i = 0; i < channels_; i++) {
+      for (index_t i = 0; i < channels_; i++) {
         omean[i] = inmean[i];
         ovar[i] = VARIANCE_TO_INVSTD(invar[i], param.eps);
       }
@@ -254,7 +254,7 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
       MKLDNNStream::Get()->Submit();
 
       float* ovar = outVar.data().dptr<float>();
-      for (int i = 0; i < channels_; i++) {
+      for (index_t i = 0; i < channels_; i++) {
         ovar[i] = VARIANCE_TO_INVSTD(ovar[i], param.eps);
       }
     }
@@ -357,10 +357,10 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
   if (param.axis != 1 || shape.ndim() != 4) {
     // reshape to (N, C, 1, D)
     mxnet::TShape new_shape{
-      static_cast<dim_t>(shape.ProdShape(0, real_axis)),
+      static_cast<index_t>(shape.ProdShape(0, real_axis)),
       shape[real_axis],
       1,
-      static_cast<dim_t>(shape.ProdShape(real_axis + 1,
+      static_cast<index_t>(shape.ProdShape(real_axis + 1,
             static_cast<int>(shape.ndim())))
     };
     data = data.Reshape(new_shape);
@@ -384,7 +384,7 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
     const NDArray &gamma    = in_data[batchnorm::kGamma];
     const NDArray &beta     = in_data[batchnorm::kBeta];
     DType *weight_buf = reinterpret_cast<DType *>(bwd.GetWeight().get_data_handle());
-    nnvm::dim_t channels_ = data.shape()[1];
+    index_t channels_ = data.shape()[1];
     DType *weight_ptr = gamma.data().dptr<DType>();
     DType* bias_ptr = beta.data().dptr<DType>();
     const size_t copy_size = sizeof(DType) * channels_;
@@ -392,7 +392,7 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
       memcpy(weight_buf, weight_ptr, copy_size);
       memcpy(&weight_buf[channels_], bias_ptr, copy_size);
     } else {
-      for (int i = 0; i < channels_; i++) {
+      for (index_t i = 0; i < channels_; i++) {
         weight_buf[i] = static_cast<DType>(1.0f);
       }
       memcpy(&weight_buf[channels_], bias_ptr, copy_size);
@@ -422,7 +422,7 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
       DType *tmp_var_ptr = reinterpret_cast<DType *>(var_mem.get_data_handle());
 
       DType minus_mom = (1.0f - param.momentum);
-      for (int i = 0; i < channels_; i++) {
+      for (index_t i = 0; i < channels_; i++) {
         moving_mean_ptr[i] = moving_mean_ptr[i] * param.momentum +
                              out_mean_ptr[i] * minus_mom;
         float variance = INVSTD_TO_VARIANCE(out_var_ptr[i], param.eps);
@@ -451,13 +451,13 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
         if (req[batchnorm::kGamma] != kAddTo) {
           memcpy(w_grad_1, gw_buf, copy_size);
         } else {
-          for (int i = 0; i < channels_; i++) {
+          for (index_t i = 0; i < channels_; i++) {
             w_grad_1[i] += gw_buf[i];
           }
         }
       }
     } else {
-      for (int i = 0; i < channels_; i++) {
+      for (index_t i = 0; i < channels_; i++) {
         (in_grad[1].data().dptr<DType>())[i] = 0.0f;
       }
     }
@@ -468,7 +468,7 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
         memcpy(w_grad_2, &gw_buf[channels_], copy_size);
       } else {
         DType *grad_beta = &gw_buf[channels_];
-        for (int i = 0; i < channels_; i++) {
+        for (index_t i = 0; i < channels_; i++) {
           w_grad_2[i] += grad_beta[i];
         }
       }
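
The remaining hunks widen the per-channel loop counters and reshape dimensions from int / nnvm::dim_t to index_t, presumably to keep the counter type consistent with the type returned by the shape. A small self-contained sketch of why that consistency matters (illustrative only; the index_t alias below is local to this example, not the MXNet typedef):

    #include <cstdint>
    #include <cstdio>

    using index_t = std::int64_t;  // stand-in alias for this sketch only

    int main() {
      // Extents taken from a tensor shape may not fit in a 32-bit int.
      const index_t channels = 3000000000LL;  // hypothetical large extent

      // With `int i`, the comparison `i < channels` promotes i, but i itself
      // would overflow (undefined behaviour) before reaching `channels`.
      // Matching the counter type to the shape type avoids the problem.
      index_t visited = 0;
      for (index_t i = 0; i < channels; i += 1000000000LL) {
        ++visited;
      }
      std::printf("visited %lld strides\n", static_cast<long long>(visited));
      return 0;
    }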