Posted to commits@mxnet.apache.org by zh...@apache.org on 2020/10/25 19:50:36 UTC

[incubator-mxnet] branch v1.7.x updated: [BUGFIX] Fix MKLDNN BatchNorm with even number of channels (#19150) (#19299)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch v1.7.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.7.x by this push:
     new a22abce  [BUGFIX] Fix MKLDNN BatchNorm with even number of channels (#19150) (#19299)
a22abce is described below

commit a22abce0ce576ef4630aaea00cc9ad4d844f99f9
Author: Anna Karbownik <69...@users.noreply.github.com>
AuthorDate: Sun Oct 25 20:47:30 2020 +0100

    [BUGFIX] Fix MKLDNN BatchNorm with even number of channels (#19150) (#19299)
    
    * Fix MKLDNN BatchNorm with even number of channels (#19150)
    
    An even number of channels results in data reordering before the
    batch norm operation. Therefore, if the BatchNorm data array is a
    view of another array and the data is stored in the MKLDNN format,
    the data needs to be converted to the default format (a repro
    sketch of this scenario follows the commit message).
    
    * Add/update tests to verify BatchNorm with odd & even numbers of channels
    
    * Fix the context handling for the BatchNorm odd & even channel-count test
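
    A minimal repro sketch of the scenario above, assuming an
    MKLDNN-enabled CPU build of MXNet 1.x; the shapes below are
    illustrative and not taken from the commit:

        import mxnet as mx

        x = mx.nd.random.normal(shape=(1, 3, 8, 8))
        w = mx.nd.random.normal(shape=(4, 3, 3, 3))   # 4 output channels (even)
        # A convolution output is typically held in an MKLDNN-internal layout.
        conv = mx.nd.Convolution(x, weight=w, no_bias=True,
                                 kernel=(3, 3), num_filter=4)
        # Basic slicing along the first axis returns a view of another array.
        view = conv[0:1]
        gamma = mx.nd.ones((4,))
        beta = mx.nd.zeros((4,))
        mmean = mx.nd.zeros((4,))
        mvar = mx.nd.ones((4,))
        # BatchNorm on an MKLDNN-formatted view with an even channel count
        # exercises the reorder-to-default path added in this commit.
        out = mx.nd.BatchNorm(view, gamma, beta, mmean, mvar)
        print(out.shape)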
---
 src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h | 17 ++++++-------
 tests/python/mkl/test_mkldnn.py                |  2 +-
 tests/python/unittest/test_gluon.py            | 35 ++++++++++++++++++++++++++
 3 files changed, 43 insertions(+), 11 deletions(-)

diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 18055ca..0e7a056 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -139,13 +139,6 @@ static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
   return it->second;
 }
 
-template<typename DType>
-static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
-                                     const OpContext &ctx, const NDArray &in_data,
-                                     mkldnn::normalization_flags flags) {
-  return GetBNForward<DType>(param, ctx, in_data.GetMKLDNNData(), flags);
-}
-
 template <typename DType>
 void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                             const std::vector<NDArray> &inputs, const std::vector<OpReqType> &req,
@@ -176,8 +169,12 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                                                 aux_states,
                                                 param,
                                                 ctx.is_train && !param.use_global_stats);
-  const NDArray &data = in_data[batchnorm::kData];
-  auto &fwd = GetBNForward<DType>(param, ctx, data, flags);
+
+  NDArray &data = in_data[batchnorm::kData];
+  if (data.IsMKLDNNData() && data.IsView())
+    data = data.Reorder2Default();
+  auto data_mem = data.GetMKLDNNData();
+  auto &fwd = GetBNForward<DType>(param, ctx, data_mem, flags);
 
   // for output memory
   auto out_mem = const_cast<NDArray &>(out).CreateMKLDNNData(fwd.GetPd().dst_desc());
@@ -215,7 +212,7 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
     }
 
     mkldnn_args_map_t net_args;
-    net_args[MKLDNN_ARG_SRC] = *data.GetMKLDNNData();
+    net_args[MKLDNN_ARG_SRC] = *data_mem;
     net_args[MKLDNN_ARG_SCALE_SHIFT] = weight_mem;
     net_args[MKLDNN_ARG_DST] = *out_mem;
 
diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py
index 44e7d3c..3be71f4 100644
--- a/tests/python/mkl/test_mkldnn.py
+++ b/tests/python/mkl/test_mkldnn.py
@@ -294,7 +294,7 @@ def test_mkldnn_sum_inplace_with_cpu_layout():
 @with_seed()
 def test_batchnorm():
     def check_batchnorm_training(stype):
-        for shape in [(2, 3), (2, 3, 2, 2)]:
+        for shape in [(2, 3), (2, 4), (2, 3, 2, 2), (2, 4, 2, 2)]:
             data_tmp = np.random.normal(-0.1, 0.1, size=shape)
             s = shape[1],
             gamma = np.ones(s)
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index a026825..ae5af33 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -20,6 +20,7 @@ import tempfile
 
 import mxnet as mx
 from mxnet import gluon
+from mxnet import init
 from mxnet.gluon import nn
 from mxnet.base import py_str, MXNetError
 from mxnet.test_utils import assert_almost_equal
@@ -2161,6 +2162,40 @@ def test_batchnorm_16c():
 
 
 @with_seed()
+def test_batchnorm_chnls():
+    chn_list = [1024, 512, 256, 128, 64, 45, 32, 16, 3]
+    class Net(gluon.HybridBlock):
+        def __init__(self,
+                     chn_num,
+                     norm_kwargs=None,
+                     in_channels=3,
+                     **kwargs):
+            super(Net, self).__init__(**kwargs)
+            self.in_channels = in_channels
+            self.conv1 = gluon.nn.Conv3D(
+                    in_channels=self.in_channels,
+                    channels=chn_num,
+                    kernel_size=(1, 7, 7),
+                    strides=(1, 2, 2),
+                    padding=(0, 3, 3),
+                    use_bias=False,
+                    )
+            self.bn1 = gluon.nn.BatchNorm(in_channels=chn_num, **({} if norm_kwargs is None else norm_kwargs))
+
+        def hybrid_forward(self, F, x):
+            """Hybrid forward of R2+1D net"""
+            conv = self.conv1(x)
+            out = self.bn1(conv)
+            return out
+
+    for i in range(len(chn_list)):
+        net = Net(chn_list[i])
+        net.initialize(init=init.Constant(1))
+        x = mx.nd.zeros((1, 3, 8, 160, 160))
+        net(x).asnumpy()
+
+
+@with_seed()
 def test_concat():
     chn_list = [16, 64]
     shapes = [1, 3, 5]