You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/03/08 09:11:42 UTC
[incubator-mxnet] branch v1.x updated: Fixed issue with batchnorm on even number of channels (#20895)
This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new 6fa7314 Fixed issue with batchnorm on even number of channels (#20895)
6fa7314 is described below
commit 6fa7314a8188c9f860f7ec1120696ee2a62f834f
Author: PiotrWolinski - Intel <pi...@intel.com>
AuthorDate: Tue Mar 8 10:07:54 2022 +0100
Fixed issue with batchnorm on even number of channels (#20895)
---
src/operator/contrib/batch_norm_relu.cc | 10 +++---
src/operator/nn/batch_norm.cc | 10 +++---
src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h | 42 ++++++++++++++++++--------
3 files changed, 42 insertions(+), 20 deletions(-)
diff --git a/src/operator/contrib/batch_norm_relu.cc b/src/operator/contrib/batch_norm_relu.cc
index 51aa4c5..21bbcb1 100644
--- a/src/operator/contrib/batch_norm_relu.cc
+++ b/src/operator/contrib/batch_norm_relu.cc
@@ -139,12 +139,13 @@ void BatchNormWithReLUComputeExCPU(const nnvm::NodeAttrs &attrs,
const std::vector<NDArray> &outputs) {
CHECK_EQ(inputs.size(), 5U);
const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
- bool fuse_relu = true;
+
if (SupportMKLDNNBNReLU(inputs[0], param)) {
CHECK_GT(outputs.size(), 3U);
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNN_REAL_TYPE_SWITCH(inputs[0].dtype(), DTYPE, {
- MKLDNNBatchNormForward<DTYPE>(attrs, ctx, inputs, req, outputs, fuse_relu);
+ MKLDNNRun(MKLDNNBatchNormForward<DTYPE, /*fuse_relu*/ true>, attrs, ctx,
+ inputs, req, outputs);
});
return;
}
@@ -157,11 +158,12 @@ void BatchNormWithReLUGradComputeExCPU(const nnvm::NodeAttrs &attrs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
- bool fuse_relu = true;
+
if (SupportMKLDNNBNReLU(inputs[0], param)) {
CHECK_EQ(inputs.size(), 9U);
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
- MKLDNNBatchNormBackward<float>(attrs, ctx, inputs, req, outputs, fuse_relu);
+ MKLDNNRun(MKLDNNBatchNormBackward<float, /*fuse_relu*/ true>, attrs, ctx,
+ inputs, req, outputs);
return;
}
LOG(FATAL) << "BatchNormWithReLU operator only supports MKL-DNN Backend.";
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index b62da0f..1550bb2 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -467,11 +467,12 @@ void BatchNormComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& outputs) {
CHECK_EQ(inputs.size(), 5U);
const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
- bool fuse_relu = false;
+
if (SupportMKLDNNBN(inputs[0], param)) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNN_REAL_TYPE_SWITCH(inputs[0].dtype(), DTYPE, {
- MKLDNNBatchNormForward<DTYPE>(attrs, ctx, inputs, req, outputs, fuse_relu);
+ MKLDNNRun(MKLDNNBatchNormForward<DTYPE, /*fuse_relu*/ false>, attrs, ctx,
+ inputs, req, outputs);
});
MKLDNN_OPCHECK_RUN(BatchNormCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
@@ -485,10 +486,11 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
- bool fuse_relu = false;
+
if (SupportMKLDNNBN(inputs[0], param)) {
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
- MKLDNNBatchNormBackward<float>(attrs, ctx, inputs, req, outputs, fuse_relu);
+ MKLDNNRun(MKLDNNBatchNormBackward<float, /*fuse_relu*/ false>, attrs, ctx,
+ inputs, req, outputs);
MKLDNN_OPCHECK_RUN(BatchNormGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
}
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 10b27be..ea7ec6f 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -145,12 +145,12 @@ static MKLDNNBNForward& GetBNForward(const BatchNormParam& param,
}
template <typename DType>
-void MKLDNNBatchNormForward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs,
- bool fuse_relu) {
+void MKLDNNBatchNormForwardImpl(const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &outputs,
+ bool fuse_relu) {
const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
@@ -267,6 +267,15 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs& attrs,
}
}
+template <typename DType, bool fuse_relu>
+void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
+ const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &outputs) {
+ MKLDNNBatchNormForwardImpl<DType>(attrs, ctx, inputs, req, outputs,
+ fuse_relu);
+}
+
class MKLDNNBNBackward {
std::shared_ptr<mkldnn::batch_normalization_backward> bwd;
const std::shared_ptr<mkldnn::memory> weight_m;
@@ -323,12 +332,12 @@ static MKLDNNBNBackward& GetBNBackward(const BatchNormParam& param,
}
template <typename DType>
-void MKLDNNBatchNormBackward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs,
- bool fuse_relu) {
+void MKLDNNBatchNormBackwardImpl(const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &outputs,
+ bool fuse_relu) {
if (fuse_relu) {
CHECK_EQ(inputs.size(), 9U);
} else {
@@ -493,6 +502,15 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs& attrs,
LOG(FATAL) << "MKLDNN batch normalization backward: should not reach here ...";
}
}
+
+template <typename DType, bool fuse_relu>
+void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
+ const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &outputs) {
+ MKLDNNBatchNormBackwardImpl<DType>(attrs, ctx, inputs, req, outputs,
+ fuse_relu);
+}
} // namespace op
} // namespace mxnet
#endif // MXNET_USE_MKLDNN