You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text copy.)
Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/03/07 10:56:37 UTC

[incubator-mxnet] branch v1.9.x updated: Fixed issue with batchnorm on even number of channels (#20927)

This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch v1.9.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.9.x by this push:
     new ae7a104  Fixed issue with batchnorm on even number of channels (#20927)
ae7a104 is described below

commit ae7a104fc19d3586db932235d5eb8da3b123e8dc
Author: PiotrWolinski - Intel <pi...@intel.com>
AuthorDate: Mon Mar 7 11:54:18 2022 +0100

    Fixed issue with batchnorm on even number of channels (#20927)
---
 src/operator/contrib/batch_norm_relu.cc        | 10 ++++---
 src/operator/nn/batch_norm.cc                  | 10 ++++---
 src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h | 36 +++++++++++++++++++++-----
 3 files changed, 42 insertions(+), 14 deletions(-)

diff --git a/src/operator/contrib/batch_norm_relu.cc b/src/operator/contrib/batch_norm_relu.cc
index d1f409c..e9a0e95 100644
--- a/src/operator/contrib/batch_norm_relu.cc
+++ b/src/operator/contrib/batch_norm_relu.cc
@@ -138,12 +138,13 @@ void BatchNormWithReLUComputeExCPU(const nnvm::NodeAttrs &attrs,
                                    const std::vector<NDArray> &outputs) {
   CHECK_EQ(inputs.size(), 5U);
   const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
-  bool fuse_relu = true;
+
   if (SupportMKLDNNBNReLU(inputs[0], param)) {
     CHECK_GT(outputs.size(), 3U);
     MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     MKLDNN_REAL_TYPE_SWITCH(inputs[0].dtype(), DTYPE, {
-      MKLDNNBatchNormForward<DTYPE>(attrs, ctx, inputs, req, outputs, fuse_relu);
+      MKLDNNRun(MKLDNNBatchNormForward<DTYPE, /*fuse_relu*/ true>, attrs, ctx,
+                inputs, req, outputs);
     });
     return;
   }
@@ -156,11 +157,12 @@ void BatchNormWithReLUGradComputeExCPU(const nnvm::NodeAttrs &attrs,
                                        const std::vector<OpReqType> &req,
                                        const std::vector<NDArray> &outputs) {
   const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
-  bool fuse_relu = true;
+
   if (SupportMKLDNNBNReLU(inputs[0], param)) {
       CHECK_EQ(inputs.size(), 9U);
       MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
-      MKLDNNBatchNormBackward<float>(attrs, ctx, inputs, req, outputs, fuse_relu);
+      MKLDNNRun(MKLDNNBatchNormBackward<float, /*fuse_relu*/ true>, attrs, ctx,
+                inputs, req, outputs);
       return;
   }
   LOG(FATAL) << "BatchNormWithReLU operator only supports MKL-DNN Backend.";
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 7701099..ef39f22 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -452,11 +452,12 @@ void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs,
                            const std::vector<NDArray> &outputs) {
   CHECK_EQ(inputs.size(), 5U);
   const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
-  bool fuse_relu = false;
+
   if (SupportMKLDNNBN(inputs[0], param)) {
     MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     MKLDNN_REAL_TYPE_SWITCH(inputs[0].dtype(), DTYPE, {
-        MKLDNNBatchNormForward<DTYPE>(attrs, ctx, inputs, req, outputs, fuse_relu);
+      MKLDNNRun(MKLDNNBatchNormForward<DTYPE, /*fuse_relu*/ false>, attrs, ctx,
+                inputs, req, outputs);
     });
     MKLDNN_OPCHECK_RUN(BatchNormCompute<cpu>, attrs, ctx, inputs, req, outputs);
     return;
@@ -470,10 +471,11 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs &attrs,
                                const std::vector<OpReqType> &req,
                                const std::vector<NDArray> &outputs) {
   const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
-  bool fuse_relu = false;
+
   if (SupportMKLDNNBN(inputs[0], param)) {
       MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
-      MKLDNNBatchNormBackward<float>(attrs, ctx, inputs, req, outputs, fuse_relu);
+      MKLDNNRun(MKLDNNBatchNormBackward<float, /*fuse_relu*/ false>, attrs, ctx,
+                inputs, req, outputs);
       MKLDNN_OPCHECK_RUN(BatchNormGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
       return;
   }
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 75c7c4d..b443d3c 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -146,9 +146,12 @@ static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
 }
 
 template <typename DType>
-void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
-                            const std::vector<NDArray> &inputs, const std::vector<OpReqType> &req,
-                            const std::vector<NDArray> &outputs, bool fuse_relu) {
+void MKLDNNBatchNormForwardImpl(const nnvm::NodeAttrs &attrs,
+                                const OpContext &ctx,
+                                const std::vector<NDArray> &inputs,
+                                const std::vector<OpReqType> &req,
+                                const std::vector<NDArray> &outputs,
+                                bool fuse_relu) {
   const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
   std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
 
@@ -263,6 +266,15 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
   }
 }
 
+template <typename DType, bool fuse_relu>
+void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
+                            const std::vector<NDArray> &inputs,
+                            const std::vector<OpReqType> &req,
+                            const std::vector<NDArray> &outputs) {
+  MKLDNNBatchNormForwardImpl<DType>(attrs, ctx, inputs, req, outputs,
+                                    fuse_relu);
+}
+
 class MKLDNNBNBackward {
   std::shared_ptr<mkldnn::batch_normalization_backward> bwd;
   const std::shared_ptr<mkldnn::memory> weight_m;
@@ -310,9 +322,12 @@ static MKLDNNBNBackward &GetBNBackward(
 }
 
 template <typename DType>
-void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
-                             const std::vector<NDArray> &inputs, const std::vector<OpReqType> &req,
-                             const std::vector<NDArray> &outputs, bool fuse_relu) {
+void MKLDNNBatchNormBackwardImpl(const nnvm::NodeAttrs &attrs,
+                                 const OpContext &ctx,
+                                 const std::vector<NDArray> &inputs,
+                                 const std::vector<OpReqType> &req,
+                                 const std::vector<NDArray> &outputs,
+                                 bool fuse_relu) {
   if (fuse_relu) {
     CHECK_EQ(inputs.size(), 9U);
   } else {
@@ -477,6 +492,15 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
     LOG(FATAL) << "MKLDNN batch normalization backward: should not reach here ...";
   }
 }
+
+template <typename DType, bool fuse_relu>
+void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
+                             const std::vector<NDArray> &inputs,
+                             const std::vector<OpReqType> &req,
+                             const std::vector<NDArray> &outputs) {
+  MKLDNNBatchNormBackwardImpl<DType>(attrs, ctx, inputs, req, outputs,
+                                     fuse_relu);
+}
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_USE_MKLDNN