You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/03/07 10:56:37 UTC
[incubator-mxnet] branch v1.9.x updated: Fixed issue with batchnorm on even number of channels (#20927)
This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch v1.9.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.9.x by this push:
new ae7a104 Fixed issue with batchnorm on even number of channels (#20927)
ae7a104 is described below
commit ae7a104fc19d3586db932235d5eb8da3b123e8dc
Author: PiotrWolinski - Intel <pi...@intel.com>
AuthorDate: Mon Mar 7 11:54:18 2022 +0100
Fixed issue with batchnorm on even number of channels (#20927)
---
src/operator/contrib/batch_norm_relu.cc | 10 ++++---
src/operator/nn/batch_norm.cc | 10 ++++---
src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h | 36 +++++++++++++++++++++-----
3 files changed, 42 insertions(+), 14 deletions(-)
diff --git a/src/operator/contrib/batch_norm_relu.cc b/src/operator/contrib/batch_norm_relu.cc
index d1f409c..e9a0e95 100644
--- a/src/operator/contrib/batch_norm_relu.cc
+++ b/src/operator/contrib/batch_norm_relu.cc
@@ -138,12 +138,13 @@ void BatchNormWithReLUComputeExCPU(const nnvm::NodeAttrs &attrs,
const std::vector<NDArray> &outputs) {
CHECK_EQ(inputs.size(), 5U);
const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
- bool fuse_relu = true;
+
if (SupportMKLDNNBNReLU(inputs[0], param)) {
CHECK_GT(outputs.size(), 3U);
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNN_REAL_TYPE_SWITCH(inputs[0].dtype(), DTYPE, {
- MKLDNNBatchNormForward<DTYPE>(attrs, ctx, inputs, req, outputs, fuse_relu);
+ MKLDNNRun(MKLDNNBatchNormForward<DTYPE, /*fuse_relu*/ true>, attrs, ctx,
+ inputs, req, outputs);
});
return;
}
@@ -156,11 +157,12 @@ void BatchNormWithReLUGradComputeExCPU(const nnvm::NodeAttrs &attrs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
- bool fuse_relu = true;
+
if (SupportMKLDNNBNReLU(inputs[0], param)) {
CHECK_EQ(inputs.size(), 9U);
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
- MKLDNNBatchNormBackward<float>(attrs, ctx, inputs, req, outputs, fuse_relu);
+ MKLDNNRun(MKLDNNBatchNormBackward<float, /*fuse_relu*/ true>, attrs, ctx,
+ inputs, req, outputs);
return;
}
LOG(FATAL) << "BatchNormWithReLU operator only supports MKL-DNN Backend.";
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 7701099..ef39f22 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -452,11 +452,12 @@ void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs,
const std::vector<NDArray> &outputs) {
CHECK_EQ(inputs.size(), 5U);
const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
- bool fuse_relu = false;
+
if (SupportMKLDNNBN(inputs[0], param)) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNN_REAL_TYPE_SWITCH(inputs[0].dtype(), DTYPE, {
- MKLDNNBatchNormForward<DTYPE>(attrs, ctx, inputs, req, outputs, fuse_relu);
+ MKLDNNRun(MKLDNNBatchNormForward<DTYPE, /*fuse_relu*/ false>, attrs, ctx,
+ inputs, req, outputs);
});
MKLDNN_OPCHECK_RUN(BatchNormCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
@@ -470,10 +471,11 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs &attrs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
- bool fuse_relu = false;
+
if (SupportMKLDNNBN(inputs[0], param)) {
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
- MKLDNNBatchNormBackward<float>(attrs, ctx, inputs, req, outputs, fuse_relu);
+ MKLDNNRun(MKLDNNBatchNormBackward<float, /*fuse_relu*/ false>, attrs, ctx,
+ inputs, req, outputs);
MKLDNN_OPCHECK_RUN(BatchNormGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
}
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 75c7c4d..b443d3c 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -146,9 +146,12 @@ static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
}
template <typename DType>
-void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
- const std::vector<NDArray> &inputs, const std::vector<OpReqType> &req,
- const std::vector<NDArray> &outputs, bool fuse_relu) {
+void MKLDNNBatchNormForwardImpl(const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &outputs,
+ bool fuse_relu) {
const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
@@ -263,6 +266,15 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
}
}
+template <typename DType, bool fuse_relu>
+void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
+ const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &outputs) {
+ MKLDNNBatchNormForwardImpl<DType>(attrs, ctx, inputs, req, outputs,
+ fuse_relu);
+}
+
class MKLDNNBNBackward {
std::shared_ptr<mkldnn::batch_normalization_backward> bwd;
const std::shared_ptr<mkldnn::memory> weight_m;
@@ -310,9 +322,12 @@ static MKLDNNBNBackward &GetBNBackward(
}
template <typename DType>
-void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
- const std::vector<NDArray> &inputs, const std::vector<OpReqType> &req,
- const std::vector<NDArray> &outputs, bool fuse_relu) {
+void MKLDNNBatchNormBackwardImpl(const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &outputs,
+ bool fuse_relu) {
if (fuse_relu) {
CHECK_EQ(inputs.size(), 9U);
} else {
@@ -477,6 +492,15 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
LOG(FATAL) << "MKLDNN batch normalization backward: should not reach here ...";
}
}
+
+template <typename DType, bool fuse_relu>
+void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
+ const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &outputs) {
+ MKLDNNBatchNormBackwardImpl<DType>(attrs, ctx, inputs, req, outputs,
+ fuse_relu);
+}
} // namespace op
} // namespace mxnet
#endif // MXNET_USE_MKLDNN