Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/08/22 11:54:02 UTC

[incubator-mxnet] branch master updated: Fix fused resnet low accuracy (#21122)

This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new daac02c785 Fix fused resnet low accuracy (#21122)
daac02c785 is described below

commit daac02c7854ffa71bc11fd950c2d6c9ea356b394
Author: hankaj <ha...@intel.com>
AuthorDate: Mon Aug 22 13:53:44 2022 +0200

    Fix fused resnet low accuracy (#21122)
    
    * Change flag to postop
    
    * Add attributes to batch norm relu node
    
    * Refactor code with batch norm relu op
    
    * Delete fuse_norm_relu flag
    
    * Delete fuse_norm_relu flag
    
    * Refactor BN operator
    
    * Review suggestions
    
    * Review suggestions once again
    
    * Fix formatting
    
    * Fix lint errors
---
 python/mxnet/amp/lists/symbol_bf16.py              |   2 +-
 python/mxnet/amp/lists/symbol_fp16.py              |   2 +-
 python/mxnet/gluon/nn/basic_layers.py              |  81 +---
 src/operator/nn/batch_norm-inl.h                   |   2 +-
 src/operator/nn/batch_norm.cc                      |   6 +-
 src/operator/nn/dnnl/dnnl_batch_norm-inl.h         | 420 ++++-----------------
 .../{dnnl_batch_norm-inl.h => dnnl_batch_norm.cc}  | 287 +++++---------
 .../quantization/dnnl/dnnl_quantized_batch_norm.cc |   2 +-
 .../dnnl/dnnl_bn_relu.cc}                          |  86 +----
 src/operator/subgraph/dnnl/dnnl_bn_relu_property.h |   7 +-
 tests/python/dnnl/op_cfg.py                        |   2 +-
 tests/python/dnnl/test_dnnl.py                     |  58 ---
 12 files changed, 189 insertions(+), 766 deletions(-)
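
A note on the mechanics, since the diff below is long. Previously the fused BN+ReLU went through oneDNN's dnnl::normalization_flags::fuse_norm_relu flag, which also threads a workspace tensor through the operator; this patch drops the flag and instead attaches an eltwise ReLU post-op to the forward-inference primitive descriptor (see the _GetFwd hunk in dnnl_batch_norm-inl.h). A minimal standalone sketch of that pattern, assuming the oneDNN 2.x C++ API; the tensor shape and epsilon are illustrative, not MXNet's:

    #include <dnnl.hpp>

    dnnl::batch_normalization_forward::primitive_desc MakeBnReluInferencePd(
        const dnnl::engine& engine) {
      // Illustrative NCHW tensor; the real code takes shape/layout from the NDArray.
      const dnnl::memory::desc data_md({8, 64, 28, 28},
                                       dnnl::memory::data_type::f32,
                                       dnnl::memory::format_tag::nchw);
      const float eps  = 1e-5f;
      const auto flags = dnnl::normalization_flags::use_scale_shift |
                         dnnl::normalization_flags::use_global_stats;

      // ReLU as a post-op: scale = 1, alpha = beta = 0 (plain max(x, 0)).
      dnnl::post_ops ops;
      ops.append_eltwise(1.f, dnnl::algorithm::eltwise_relu, 0.f, 0.f);
      dnnl::primitive_attr attr;
      attr.set_post_ops(ops);

      dnnl::batch_normalization_forward::desc desc(
          dnnl::prop_kind::forward_inference, data_md, eps, flags);
      // Passing the attr here is what fuses the ReLU; no workspace is involved,
      // unlike the old fuse_norm_relu flag.
      return dnnl::batch_normalization_forward::primitive_desc(desc, attr, engine);
    }

Note that training still builds the plain primitive: the post-op branch in _GetFwd only fires on the forward_inference path.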

diff --git a/python/mxnet/amp/lists/symbol_bf16.py b/python/mxnet/amp/lists/symbol_bf16.py
index 89ddea2820..5b9df27497 100644
--- a/python/mxnet/amp/lists/symbol_bf16.py
+++ b/python/mxnet/amp/lists/symbol_bf16.py
@@ -41,7 +41,6 @@ if Features.instance.is_enabled('ONEDNN'):
 # are dtype neutral (can work in both bf16 and fp32)
 BF16_FP32_FUNCS = [
     '_contrib_AdaptiveAvgPooling2D',
-    '_contrib_BatchNormWithReLU',
     'Activation',
     'BatchNorm',
     'LayerNorm',
@@ -102,6 +101,7 @@ WIDEST_TYPE_CASTS = [
 if Features.instance.is_enabled('ONEDNN'):
     WIDEST_TYPE_CASTS.extend([
         '_sg_onednn_batch_dot',
+        '_sg_onednn_batch_norm',
     ])
 
 # Functions that when running with Bfloat16, the params that still need float32.
diff --git a/python/mxnet/amp/lists/symbol_fp16.py b/python/mxnet/amp/lists/symbol_fp16.py
index 1cd5316361..76e9488f69 100644
--- a/python/mxnet/amp/lists/symbol_fp16.py
+++ b/python/mxnet/amp/lists/symbol_fp16.py
@@ -43,7 +43,6 @@ FP16_FP32_FUNCS = [
     'BlockGrad',
     'Cast',
     'cast_storage',
-    '_contrib_BatchNormWithReLU',
     '_contrib_allclose',
     '_contrib_arange_like',
     '_contrib_dynamic_reshape',
@@ -637,6 +636,7 @@ if Features().is_enabled('ONEDNN'):
         '_sg_onednn_selfatt_qk',
         '_sg_onednn_selfatt_valatt',
         '_sg_onednn_batch_dot',
+        '_sg_onednn_batch_norm',
         '_sg_pow_mul_scalar'
     ])
 
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index 4afffb7d47..883b714b16 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -19,7 +19,7 @@
 # pylint: disable= arguments-differ
 """Basic neural network layers."""
 __all__ = ['Sequential', 'HybridSequential', 'Dense', 'Dropout', 'Embedding',
-           'BatchNorm', 'SyncBatchNorm', 'BatchNormReLU', 'InstanceNorm', 'LayerNorm', 'GroupNorm',
+           'BatchNorm', 'SyncBatchNorm', 'InstanceNorm', 'LayerNorm', 'GroupNorm',
            'Flatten', 'Lambda', 'HybridLambda', 'Concatenate', 'HybridConcatenate', 'Identity']
 import warnings
 import uuid
@@ -322,8 +322,6 @@ class _BatchNorm(HybridBlock):
         If True, use global moving statistics instead of local batch-norm. This will force
         change batch-norm into a scale shift operator.
         If False, use local batch-norm.
-    fuse_relu: bool, default False
-        If True, this operator is equal to `BN+ReLU`.
     beta_initializer: str or `Initializer`, default 'zeros'
         Initializer for the beta weight.
     gamma_initializer: str or `Initializer`, default 'ones'
@@ -345,14 +343,13 @@ class _BatchNorm(HybridBlock):
         - **out**: output tensor with the same shape as `data`.
     """
     def __init__(self, axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
-                 use_global_stats=False, fuse_relu=False,
+                 use_global_stats=False,
                  beta_initializer='zeros', gamma_initializer='ones',
                  running_mean_initializer='zeros', running_variance_initializer='ones',
                  in_channels=0, **kwargs):
         super(_BatchNorm, self).__init__(**kwargs)
         self._kwargs = {'axis': axis, 'eps': epsilon, 'momentum': momentum,
                         'fix_gamma': not scale, 'use_global_stats': use_global_stats}
-        self.fuse_relu = fuse_relu
         self._axis = axis
         if in_channels != 0:
             self.in_channels = in_channels
@@ -383,13 +380,7 @@ class _BatchNorm(HybridBlock):
 
     def forward(self, x):
         device = x.device
-        if self.fuse_relu:
-            return npx.batch_norm_with_relu(x, self.gamma.data(device), self.beta.data(device),
-                                            self.running_mean.data(device),
-                                            self.running_var.data(device),
-                                            name='fwd', **self._kwargs)
-        else:
-            return npx.batch_norm(x, self.gamma.data(device), self.beta.data(device),
+        return npx.batch_norm(x, self.gamma.data(device), self.beta.data(device),
                                   self.running_mean.data(device),
                                   self.running_var.data(device),
                                   name='fwd', **self._kwargs)
@@ -467,71 +458,7 @@ class BatchNorm(_BatchNorm):
         super(BatchNorm, self).__init__(
             axis=axis, momentum=momentum, epsilon=epsilon, center=center,
             scale=scale,
-            use_global_stats=use_global_stats, fuse_relu=False,
-            beta_initializer=beta_initializer,
-            gamma_initializer=gamma_initializer,
-            running_mean_initializer=running_mean_initializer,
-            running_variance_initializer=running_variance_initializer,
-            in_channels=in_channels, **kwargs)
-
-
-class BatchNormReLU(_BatchNorm):
-    """Batch normalization layer (Ioffe and Szegedy, 2014).
-    Normalizes the input at each batch, i.e. applies a transformation
-    that maintains the mean activation close to 0 and the activation
-    standard deviation close to 1.
-
-    Parameters
-    ----------
-    axis : int, default 1
-        The axis that should be normalized. This is typically the channels
-        (C) axis. For instance, after a `Conv2D` layer with `layout='NCHW'`,
-        set `axis=1` in `BatchNorm`. If `layout='NHWC'`, then set `axis=3`.
-    momentum: float, default 0.9
-        Momentum for the moving average.
-    epsilon: float, default 1e-5
-        Small float added to variance to avoid dividing by zero.
-    center: bool, default True
-        If True, add offset of `beta` to normalized tensor.
-        If False, `beta` is ignored.
-    scale: bool, default True
-        If True, multiply by `gamma`. If False, `gamma` is not used.
-        When the next layer is linear (also e.g. `nn.relu`),
-        this can be disabled since the scaling
-        will be done by the next layer.
-    use_global_stats: bool, default False
-        If True, use global moving statistics instead of local batch-norm. This will force
-        change batch-norm into a scale shift operator.
-        If False, use local batch-norm.
-    beta_initializer: str or `Initializer`, default 'zeros'
-        Initializer for the beta weight.
-    gamma_initializer: str or `Initializer`, default 'ones'
-        Initializer for the gamma weight.
-    running_mean_initializer: str or `Initializer`, default 'zeros'
-        Initializer for the running mean.
-    running_variance_initializer: str or `Initializer`, default 'ones'
-        Initializer for the running variance.
-    in_channels : int, default 0
-        Number of channels (feature maps) in input data. If not specified,
-        initialization will be deferred to the first time `forward` is called
-        and `in_channels` will be inferred from the shape of input data.
-
-
-    Inputs:
-        - **data**: input tensor with arbitrary shape.
-
-    Outputs:
-        - **out**: output tensor with the same shape as `data`.
-    """
-    def __init__(self, axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
-                 use_global_stats=False,
-                 beta_initializer='zeros', gamma_initializer='ones',
-                 running_mean_initializer='zeros', running_variance_initializer='ones',
-                 in_channels=0, **kwargs):
-        super(BatchNormReLU, self).__init__(
-            axis=axis, momentum=momentum, epsilon=epsilon,
-            center=center, scale=scale,
-            use_global_stats=use_global_stats, fuse_relu=True,
+            use_global_stats=use_global_stats,
             beta_initializer=beta_initializer,
             gamma_initializer=gamma_initializer,
             running_mean_initializer=running_mean_initializer,
diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h
index 92eded093d..d863302040 100644
--- a/src/operator/nn/batch_norm-inl.h
+++ b/src/operator/nn/batch_norm-inl.h
@@ -46,7 +46,7 @@
 #endif
 
 /*! \brief inverse standard deviation <-> variance */
-#define VARIANCE_TO_INVSTD(__var$, __eps$)    (1.0 / std::sqrt((__var$) + DType(__eps$)))
+#define VARIANCE_TO_INVSTD(__var$, __eps$)    (1.0 / std::sqrt((__var$) + (__eps$)))
 #define INVSTD_TO_VARIANCE(__invstd$, __eps$) ((1.0 / ((__invstd$) * (__invstd$))) - (__eps$))
 
 namespace mxnet {
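
The one-character-looking change to VARIANCE_TO_INVSTD above is a precision fix. The macro pair implements

    \[
      \mathrm{invstd} = \frac{1}{\sqrt{\sigma^2 + \varepsilon}},
      \qquad
      \sigma^2 = \frac{1}{\mathrm{invstd}^2} - \varepsilon .
    \]

With the old DType(__eps$) cast, epsilon was first rounded to the operand's dtype; for float16, eps = 1e-5 lies below the smallest normal value (about 6.1e-5) and is rounded as a subnormal, so the epsilon added here no longer matches the float32 epsilon subtracted back in INVSTD_TO_VARIANCE. Dropping the cast keeps the whole expression in the wider arithmetic type. (That reading is inferred from the macro itself; the commit message doesn't single this hunk out.)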
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index dc09ebeb22..154471f109 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -475,9 +475,7 @@ void BatchNormComputeExCPU(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(inputs.size(), 5U);
   if (SupportDNNLBN(inputs[0])) {
     DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
-    DNNL_REAL_TYPE_SWITCH(inputs[0].dtype(), DTYPE, {
-      DNNLRun(DNNLBatchNormForward<DTYPE, /*fuse_relu*/ false>, attrs, ctx, inputs, req, outputs);
-    });
+    DNNLRun(DNNLBatchNormForward</*fuse_relu*/ false>, attrs, ctx, inputs, req, outputs);
     DNNL_OPCHECK_RUN(BatchNormCompute<cpu>, attrs, ctx, inputs, req, outputs);
     return;
   }
@@ -491,7 +489,7 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                const std::vector<NDArray>& outputs) {
   if (SupportDNNLBN(inputs[0])) {
     DNNL_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
-    DNNLRun(DNNLBatchNormBackward<float, /*fuse_relu*/ false>, attrs, ctx, inputs, req, outputs);
+    DNNLRun(DNNLBatchNormBackward, attrs, ctx, inputs, req, outputs);
     DNNL_OPCHECK_RUN(BatchNormGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
     return;
   }
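
Both dispatch sites lose the DNNL_REAL_TYPE_SWITCH: the forward kernel is now templated on fuse_relu alone, and the backward kernel is a plain function, since the oneDNN path always handles weights and statistics as float. Templating on a bool keeps the FComputeEx-style signature intact while giving the fused and unfused operators two distinct symbols to register. A schematic of the idea, with hypothetical names:

    #include <iostream>
    #include <vector>

    struct OpContext { bool is_train; };

    template <bool fuse_relu>
    void BatchNormForward(const OpContext& ctx, const std::vector<float>& in) {
      // fuse_relu is a compile-time constant: the branch is resolved at
      // compile time, and each instantiation is a separate function that a
      // registry can store as an ordinary function pointer.
      if (fuse_relu)
        std::cout << "BN + ReLU post-op path\n";
      else
        std::cout << "plain BN path\n";
    }

    using FComputeEx = void (*)(const OpContext&, const std::vector<float>&);

    int main() {
      FComputeEx bn      = BatchNormForward<false>;  // BatchNorm
      FComputeEx bn_relu = BatchNormForward<true>;   // fused BatchNorm + ReLU
      bn({true}, {});
      bn_relu({false}, {});
    }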
diff --git a/src/operator/nn/dnnl/dnnl_batch_norm-inl.h b/src/operator/nn/dnnl/dnnl_batch_norm-inl.h
index 2780c9685f..40cf88ade1 100644
--- a/src/operator/nn/dnnl/dnnl_batch_norm-inl.h
+++ b/src/operator/nn/dnnl/dnnl_batch_norm-inl.h
@@ -18,7 +18,7 @@
  */
 
 /*!
- * \file dnnl_batch_norm.cc
+ * \file dnnl_batch_norm-inl.h
  * \brief
  * \author Tao Lv
  */
@@ -28,11 +28,12 @@
 
 #if MXNET_USE_ONEDNN == 1
 #include <dnnl.hpp>
+
 #include <utility>
 #include <vector>
 
-#include "operator/nn/batch_norm-inl.h"
 #include "dnnl_base-inl.h"
+#include "operator/nn/batch_norm-inl.h"
 
 namespace mxnet {
 namespace op {
@@ -44,8 +45,7 @@ typedef dnnl::batch_normalization_backward::desc t_bn_b_desc;
 
 inline static dnnl::normalization_flags _GetFlags(const std::vector<NDArray>& in_data,
                                                   const std::vector<NDArray>& aux_states,
-                                                  bool is_train_and_not_global_stats,
-                                                  bool fuse_relu) {
+                                                  bool is_train_and_not_global_stats) {
   dnnl::normalization_flags flags = static_cast<dnnl::normalization_flags>(0U);
   if (in_data.size() == 3U) {
     flags |= dnnl::normalization_flags::use_scale_shift;
@@ -56,15 +56,12 @@ inline static dnnl::normalization_flags _GetFlags(const std::vector<NDArray>& in
   if (aux_states.size() == 2U && !is_train_and_not_global_stats) {
     flags |= dnnl::normalization_flags::use_global_stats;
   }
-
-  if (fuse_relu) {
-    flags |= dnnl::normalization_flags::fuse_norm_relu;
-  }
   return flags;
 }
 
 inline static t_bn_f_pdesc _GetFwd(const dnnl::memory& data_mem,
                                    bool is_train,
+                                   bool fuse_relu,
                                    float eps,
                                    dnnl::normalization_flags flags) {
   auto data_md = data_mem.get_desc();
@@ -73,6 +70,18 @@ inline static t_bn_f_pdesc _GetFwd(const dnnl::memory& data_mem,
   if (is_train) {
     t_bn_f_desc bnFwd_desc(dnnl::prop_kind::forward_training, data_md, eps, flags);
     return t_bn_f_pdesc(bnFwd_desc, engine);
+  }
+
+  if (fuse_relu) {
+    const float scale = 1.f;
+    const float alpha = 0.f;
+    const float beta  = 0.f;
+    dnnl::post_ops post_ops;
+    post_ops.append_eltwise(scale, dnnl::algorithm::eltwise_relu, alpha, beta);
+    dnnl::primitive_attr attr;
+    attr.set_post_ops(post_ops);
+    t_bn_f_desc bnFwd_desc(dnnl::prop_kind::forward_inference, data_md, eps, flags);
+    return t_bn_f_pdesc(bnFwd_desc, attr, engine);
   } else {
     t_bn_f_desc bnFwd_desc(dnnl::prop_kind::forward_inference, data_md, eps, flags);
     return t_bn_f_pdesc(bnFwd_desc, engine);
@@ -88,7 +97,7 @@ inline static t_bn_b_pdesc _GetBwd(const dnnl::memory& data_mem,
   auto engine  = CpuEngine::Get()->get_engine();
 
   t_bn_b_desc bnBwd_desc(dnnl::prop_kind::backward, diff_md, data_md, eps, flags);
-  return t_bn_b_pdesc(bnBwd_desc, engine, _GetFwd(data_mem, true, eps, flags));
+  return t_bn_b_pdesc(bnBwd_desc, engine, _GetFwd(data_mem, true, false, eps, flags));
 }
 
 typedef ParamOpSign<BatchNormParam> DNNLBNSignature;
@@ -100,63 +109,39 @@ class DNNLBNForward {
   t_bn_f_pdesc pd;
 
  public:
-  DNNLBNForward(const t_bn_f_pdesc& _pd, bool is_train_and_not_global_stats) : pd(_pd) {
-    weight_m.reset(new dnnl::memory(pd.weights_desc(), CpuEngine::Get()->get_engine()));
-    fwd.reset(new dnnl::batch_normalization_forward(pd));
-    this->is_train_and_not_global_stats = is_train_and_not_global_stats;
-  }
+  DNNLBNForward(const t_bn_f_pdesc& _pd, bool is_train_and_not_global_stats);
 
-  const dnnl::memory& GetWeight() const {
-    return *weight_m;
-  }
+  const dnnl::memory& GetWeight() const;
 
-  const t_bn_f_pdesc& GetPd() const {
-    return pd;
-  }
+  const t_bn_f_pdesc& GetPd() const;
 
-  const dnnl::batch_normalization_forward& GetFwd() const {
-    return *fwd;
-  }
-};
+  const dnnl::batch_normalization_forward& GetFwd() const;
 
-template <typename DType>
-static DNNLBNForward& GetBNForward(const BatchNormParam& param,
-                                   const OpContext& ctx,
-                                   const dnnl::memory* data_mem,
-                                   dnnl::normalization_flags flags) {
-#if DMLC_CXX11_THREAD_LOCAL
-  static thread_local std::unordered_map<DNNLBNSignature, DNNLBNForward, OpHash> fwds;
-#else
-  static MX_THREAD_LOCAL std::unordered_map<DNNLBNSignature, DNNLBNForward, OpHash> fwds;
-#endif
-  DNNLBNSignature key(param);
-  key.AddSign(ctx.is_train);
-  key.AddSign(*data_mem);
-  key.AddSign(static_cast<int>(flags));
-
-  auto it = fwds.find(key);
-  if (it == fwds.end()) {
-    auto fwd_pd = _GetFwd(*data_mem, ctx.is_train, param.eps, flags);
-    DNNLBNForward fwd(fwd_pd, ctx.is_train && !param.use_global_stats);
-    it = AddToCache(&fwds, key, fwd);
-  }
-  return it->second;
-}
+  static DNNLBNForward& GetCached(const BatchNormParam& param,
+                                  const OpContext& ctx,
+                                  const dnnl::memory* data_mem,
+                                  bool fuse_relu,
+                                  dnnl::normalization_flags flags);
+  void Execute(const OpContext& ctx,
+               const BatchNormParam& param,
+               const std::vector<NDArray>& inputs,
+               const std::vector<OpReqType>& req,
+               const std::vector<NDArray>& outputs,
+               bool fuse_relu);
+};
 
-template <typename DType>
-void DNNLBatchNormForwardImpl(const nnvm::NodeAttrs& attrs,
-                              const OpContext& ctx,
-                              const std::vector<NDArray>& inputs,
-                              const std::vector<OpReqType>& req,
-                              const std::vector<NDArray>& outputs,
-                              bool fuse_relu) {
+template <bool fuse_relu>
+void DNNLBatchNormForward(const nnvm::NodeAttrs& attrs,
+                          const OpContext& ctx,
+                          const std::vector<NDArray>& inputs,
+                          const std::vector<OpReqType>& req,
+                          const std::vector<NDArray>& outputs) {
   const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
   std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
 
   mxnet::TShape shape = inputs[batchnorm::kData].shape();
   const int real_axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
   CHECK_LT(real_axis, shape.ndim());
-  NDArray out = outputs[batchnorm::kOut];
   if (param.axis != 1 || shape.ndim() != 4) {
     // reshape to (N, C, 1, D)
     mxnet::TShape new_shape{
@@ -165,109 +150,18 @@ void DNNLBatchNormForwardImpl(const nnvm::NodeAttrs& attrs,
         1,
         static_cast<index_t>(shape.ProdShape(real_axis + 1, static_cast<int>(shape.ndim())))};
     in_data[batchnorm::kData] = in_data[batchnorm::kData].Reshape(new_shape);
-    out                       = out.Reshape(new_shape);
   }
 
   const std::vector<NDArray> aux_states(inputs.begin() + batchnorm::kInMovingMean, inputs.end());
   TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
   dnnl::normalization_flags flags =
-      _GetFlags(in_data, aux_states, ctx.is_train && !param.use_global_stats, fuse_relu);
+      _GetFlags(in_data, aux_states, ctx.is_train && !param.use_global_stats);
   NDArray& data = in_data[batchnorm::kData];
   if (data.IsDNNLData() && data.IsView())
     data = data.Reorder2Default();
-  auto data_mem = data.GetDNNLData();
-  auto& fwd     = GetBNForward<DType>(param, ctx, data_mem, flags);
-
-  // for output memory
-  auto fwd_dst_desc = fwd.GetPd().dst_desc();
-  auto out_mem      = const_cast<NDArray&>(out).CreateDNNLData(&fwd_dst_desc);
-
-  // mxnet will always use scale shift.
-  // But if fix_gamma is true, then all scale elements will be set to 1.0f
-  if (static_cast<int>(flags) & static_cast<int>(dnnl::normalization_flags::use_scale_shift)) {
-    const NDArray& gamma = in_data[batchnorm::kGamma];
-    const NDArray& beta  = in_data[batchnorm::kBeta];
-    CHECK_EQ(gamma.storage_type(), mxnet::kDefaultStorage);
-    CHECK_EQ(beta.storage_type(), mxnet::kDefaultStorage);
-
-    const dnnl::memory& weight_mem = fwd.GetWeight();
-    float* weight_buf              = reinterpret_cast<float*>(weight_mem.get_data_handle());
-
-    index_t channels_ = data.shape()[1];
-    CHECK(weight_mem.get_desc().get_size() == channels_ * sizeof(float) * 2);
-    float* weight_ptr      = gamma.data().dptr<float>();
-    float* bias_ptr        = beta.data().dptr<float>();
-    const size_t copy_size = sizeof(weight_buf[0]) * channels_;
-    if (!param.fix_gamma) {
-      memcpy(weight_buf, weight_ptr, copy_size);
-      memcpy(&weight_buf[channels_], bias_ptr, copy_size);
-    } else if (IsBNWriting(req[batchnorm::kGamma])) {
-      for (index_t i = 0; i < channels_; i++) {
-        weight_buf[i]             = 1.0f;
-        weight_ptr[i]             = 1.0f;
-        weight_buf[channels_ + i] = bias_ptr[i];  // bias
-      }
-    } else {
-      for (index_t i = 0; i < channels_; i++) {
-        weight_buf[i]             = 1.0f;
-        weight_buf[channels_ + i] = bias_ptr[i];  // bias
-      }
-    }
-
-    dnnl_args_map_t net_args;
-    net_args[DNNL_ARG_SRC]         = *data_mem;
-    net_args[DNNL_ARG_SCALE_SHIFT] = weight_mem;
-    net_args[DNNL_ARG_DST]         = *out_mem;
-    if (fuse_relu) {
-      const NDArray* workspace = nullptr;
-      workspace                = &outputs[3];
-      auto engine              = CpuEngine::Get()->get_engine();
-      if (workspace == nullptr) {
-        LOG(FATAL) << "oneDNN BatchNorm: incorrect workspace input";
-      }
-      auto ws = std::make_shared<dnnl::memory>(
-          fwd.GetPd().workspace_desc(), engine, workspace->GetDNNLData()->get_data_handle());
-      net_args[DNNL_ARG_WORKSPACE] = *ws;
-    }
-    if (!ctx.is_train || param.use_global_stats) {
-      float* omean  = outputs[batchnorm::kMean].data().dptr<float>();
-      float* ovar   = outputs[batchnorm::kVar].data().dptr<float>();
-      float* inmean = aux_states[batchnorm::kMovingMean].data().dptr<float>();
-      float* invar  = aux_states[batchnorm::kMovingVar].data().dptr<float>();
-      // to align with origin implmentation: batch_norm.cc: L164
-      for (index_t i = 0; i < channels_; i++) {
-        omean[i] = inmean[i];
-        ovar[i]  = VARIANCE_TO_INVSTD(invar[i], param.eps);
-      }
-      net_args[DNNL_ARG_MEAN]     = *(aux_states[batchnorm::kMovingMean].GetDNNLData());
-      net_args[DNNL_ARG_VARIANCE] = *(aux_states[batchnorm::kMovingVar].GetDNNLData());
-      DNNLStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
-      DNNLStream::Get()->Submit();
-    } else {  // training
-      const NDArray& outMean      = outputs[batchnorm::kMean];
-      const NDArray& outVar       = outputs[batchnorm::kVar];
-      net_args[DNNL_ARG_MEAN]     = *(outMean.GetDNNLData());
-      net_args[DNNL_ARG_VARIANCE] = *(outVar.GetDNNLData());
-      DNNLStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
-      DNNLStream::Get()->Submit();
-
-      float* ovar = outVar.data().dptr<float>();
-      for (index_t i = 0; i < channels_; i++) {
-        ovar[i] = VARIANCE_TO_INVSTD(ovar[i], param.eps);
-      }
-    }
-  } else {  // no input gamma and beta
-    LOG(FATAL) << "oneDNN batch normalization: should not reach here ...";
-  }
-}
-
-template <typename DType, bool fuse_relu>
-void DNNLBatchNormForward(const nnvm::NodeAttrs& attrs,
-                          const OpContext& ctx,
-                          const std::vector<NDArray>& inputs,
-                          const std::vector<OpReqType>& req,
-                          const std::vector<NDArray>& outputs) {
-  DNNLBatchNormForwardImpl<DType>(attrs, ctx, inputs, req, outputs, fuse_relu);
+  auto data_mem      = data.GetDNNLData();
+  DNNLBNForward& fwd = DNNLBNForward::GetCached(param, ctx, data_mem, fuse_relu, flags);
+  fwd.Execute(ctx, param, inputs, req, outputs, fuse_relu);
 }
 
 class DNNLBNBackward {
@@ -278,95 +172,50 @@ class DNNLBNBackward {
  public:
   const t_bn_b_pdesc pd;
 
-  explicit DNNLBNBackward(const t_bn_b_pdesc& _pd)
-      : weight_m(new dnnl::memory(_pd.weights_desc(), CpuEngine::Get()->get_engine())),
-        gradw_m(new dnnl::memory(_pd.diff_weights_desc(), CpuEngine::Get()->get_engine())),
-        pd(_pd) {
-    bwd.reset(new dnnl::batch_normalization_backward(pd));
-  }
+  explicit DNNLBNBackward(const t_bn_b_pdesc& _pd);
 
-  const dnnl::memory& GetWeight() const {
-    return *weight_m;
-  }
+  const dnnl::memory& GetWeight() const;
 
-  const dnnl::memory& GetGradw() const {
-    return *gradw_m;
-  }
+  const dnnl::memory& GetGradw() const;
 
-  const dnnl::batch_normalization_backward& GetBwd() const {
-    return *bwd;
-  }
-};
+  const dnnl::batch_normalization_backward& GetBwd() const;
 
-template <typename DType>
-static DNNLBNBackward& GetBNBackward(const BatchNormParam& param,
-                                     const OpContext& ctx,
-                                     const NDArray& in_data,
-                                     const dnnl::memory& in_mem,
-                                     const NDArray& diff_data,
-                                     const dnnl::memory& diff_mem,
-                                     dnnl::normalization_flags flags) {
-#if DMLC_CXX11_THREAD_LOCAL
-  static thread_local std::unordered_map<DNNLBNSignature, DNNLBNBackward, OpHash> bwds;
-#else
-  static MX_THREAD_LOCAL std::unordered_map<DNNLBNSignature, DNNLBNBackward, OpHash> bwds;
-#endif
-  DNNLBNSignature key(param);
-  key.AddSign(in_data);
-  key.AddSign(diff_data);
-  key.AddSign(static_cast<int>(flags));
-
-  auto it = bwds.find(key);
-  if (it == bwds.end()) {
-    auto bwd_pd = _GetBwd(in_mem, diff_mem, param.eps, flags);
-    DNNLBNBackward bwd(bwd_pd);
-    it = AddToCache(&bwds, key, bwd);
-  }
-  return it->second;
-}
+  static DNNLBNBackward& GetCached(const BatchNormParam& param,
+                                   const OpContext& ctx,
+                                   const NDArray& in_data,
+                                   const dnnl::memory& in_mem,
+                                   const NDArray& diff_data,
+                                   const dnnl::memory& diff_mem,
+                                   dnnl::normalization_flags flags);
 
-template <typename DType>
-void DNNLBatchNormBackwardImpl(const nnvm::NodeAttrs& attrs,
-                               const OpContext& ctx,
-                               const std::vector<NDArray>& inputs,
-                               const std::vector<OpReqType>& req,
-                               const std::vector<NDArray>& outputs,
-                               bool fuse_relu) {
-  if (fuse_relu) {
-    CHECK_EQ(inputs.size(), 9U);
-  } else {
-    CHECK_EQ(inputs.size(), 8U);
-  }
+  void Execute(const BatchNormParam& param,
+               const OpContext& ctx,
+               const std::vector<NDArray>& inputs,
+               const std::vector<OpReqType>& req,
+               const std::vector<NDArray>& outputs);
+};
+
+inline void DNNLBatchNormBackward(const nnvm::NodeAttrs& attrs,
+                                  const OpContext& ctx,
+                                  const std::vector<NDArray>& inputs,
+                                  const std::vector<OpReqType>& req,
+                                  const std::vector<NDArray>& outputs) {
   const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
   std::vector<NDArray> out_grad(1);
-  std::vector<NDArray> out_data(3);
   std::vector<NDArray> in_data(3);
   std::vector<NDArray> aux_states(2);
-  out_grad[0]                         = inputs[0];
-  out_data[batchnorm::kMean]          = inputs[1];
-  out_data[batchnorm::kVar]           = inputs[2];
-  in_data[batchnorm::kData]           = inputs[3];
-  in_data[batchnorm::kGamma]          = inputs[4];
-  in_data[batchnorm::kBeta]           = inputs[5];
-  aux_states[batchnorm::kMovingMean]  = inputs[6];
-  aux_states[batchnorm::kMovingVar]   = inputs[7];
-  const std::vector<NDArray>& in_grad = outputs;
+  out_grad[0]                        = inputs[0];
+  in_data[batchnorm::kData]          = inputs[3];
+  in_data[batchnorm::kGamma]         = inputs[4];
+  in_data[batchnorm::kBeta]          = inputs[5];
+  aux_states[batchnorm::kMovingMean] = inputs[6];
+  aux_states[batchnorm::kMovingVar]  = inputs[7];
   TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
   dnnl::normalization_flags flags =
-      _GetFlags(in_data, aux_states, ctx.is_train && !param.use_global_stats, fuse_relu);
-
-  NDArray data               = in_data[batchnorm::kData];
-  NDArray diff               = out_grad[batchnorm::kOut];
-  NDArray gradIn             = in_grad[batchnorm::kData];
-  const NDArray& moving_mean = aux_states[batchnorm::kMovingMean];
-  const NDArray& moving_var  = aux_states[batchnorm::kMovingVar];
-  const NDArray& out_mean    = out_data[batchnorm::kMean];
-  const NDArray& out_var     = out_data[batchnorm::kVar];
+      _GetFlags(in_data, aux_states, ctx.is_train && !param.use_global_stats);
 
-  CHECK(out_mean.IsDefaultData());
-  CHECK(out_var.IsDefaultData());
-  CHECK(moving_mean.IsDefaultData());
-  CHECK(moving_var.IsDefaultData());
+  NDArray data = in_data[batchnorm::kData];
+  NDArray diff = out_grad[batchnorm::kOut];
 
   mxnet::TShape shape = data.shape();
   const int real_axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
@@ -378,15 +227,12 @@ void DNNLBatchNormBackwardImpl(const nnvm::NodeAttrs& attrs,
         shape[real_axis],
         1,
         static_cast<index_t>(shape.ProdShape(real_axis + 1, static_cast<int>(shape.ndim())))};
-    data   = data.Reshape(new_shape);
-    diff   = diff.Reshape(new_shape);
-    gradIn = gradIn.Reshape(new_shape);
+    data = data.Reshape(new_shape);
+    diff = diff.Reshape(new_shape);
   }
 
   auto data_mem = data.GetDNNLData();
   auto diff_mem = diff.GetDNNLData();
-  // DNNL batchnorm should run on special layouts. If one of them isn't, we
-  // should reorder them.
   if (data.IsDefaultData()) {
     auto diff_desc = diff_mem->get_desc();
     data_mem       = data.GetDNNLDataReorder(&diff_desc);
@@ -394,113 +240,9 @@ void DNNLBatchNormBackwardImpl(const nnvm::NodeAttrs& attrs,
     auto data_desc = data_mem->get_desc();
     diff_mem       = diff.GetDNNLDataReorder(&data_desc);
   }
-  auto& bwd = GetBNBackward<DType>(param, ctx, data, *data_mem, diff, *diff_mem, flags);
-  auto gradi_mem =
-      CreateDNNLMem(const_cast<NDArray&>(gradIn), bwd.pd.diff_src_desc(), req[batchnorm::kData]);
-
-  if (static_cast<int>(flags) & static_cast<int>(dnnl::normalization_flags::use_scale_shift)) {
-    const NDArray& gamma   = in_data[batchnorm::kGamma];
-    const NDArray& beta    = in_data[batchnorm::kBeta];
-    DType* weight_buf      = reinterpret_cast<DType*>(bwd.GetWeight().get_data_handle());
-    index_t channels_      = data.shape()[1];
-    DType* weight_ptr      = gamma.data().dptr<DType>();
-    DType* bias_ptr        = beta.data().dptr<DType>();
-    const size_t copy_size = sizeof(DType) * channels_;
-    if (!param.fix_gamma) {
-      memcpy(weight_buf, weight_ptr, copy_size);
-      memcpy(&weight_buf[channels_], bias_ptr, copy_size);
-    } else {
-      for (index_t i = 0; i < channels_; i++) {
-        weight_buf[i] = static_cast<DType>(1.0f);
-      }
-      memcpy(&weight_buf[channels_], bias_ptr, copy_size);
-    }
-    dnnl_args_map_t net_args;
-    net_args[DNNL_ARG_SRC]              = *data_mem;
-    net_args[DNNL_ARG_DIFF_SRC]         = *gradi_mem.second;
-    net_args[DNNL_ARG_SCALE_SHIFT]      = bwd.GetWeight();
-    net_args[DNNL_ARG_DIFF_SCALE_SHIFT] = bwd.GetGradw();
-    net_args[DNNL_ARG_DIFF_DST]         = *diff_mem;
-
-    if (fuse_relu) {
-      const NDArray* workspace = nullptr;
-      workspace                = &inputs[8];
-      if (workspace != nullptr) {
-        net_args[DNNL_ARG_WORKSPACE] = *(workspace->GetDNNLData());
-      }
-    }
-
-    // training but no input mean and variance
-    if (ctx.is_train && !param.use_global_stats) {
-      DType* moving_mean_ptr = moving_mean.data().dptr<DType>();
-      DType* moving_var_ptr  = moving_var.data().dptr<DType>();
-      DType* out_mean_ptr    = out_mean.data().dptr<DType>();
-      DType* out_var_ptr     = out_var.data().dptr<DType>();
-      dnnl::memory var_mem(bwd.pd.variance_desc(), CpuEngine::Get()->get_engine());
-      DType* tmp_var_ptr = reinterpret_cast<DType*>(var_mem.get_data_handle());
-
-      DType minus_mom = (1.0f - param.momentum);
-      for (index_t i = 0; i < channels_; i++) {
-        moving_mean_ptr[i] = moving_mean_ptr[i] * param.momentum + out_mean_ptr[i] * minus_mom;
-        float variance     = INVSTD_TO_VARIANCE(out_var_ptr[i], param.eps);
-        tmp_var_ptr[i]     = variance;
-        moving_var_ptr[i]  = moving_var_ptr[i] * param.momentum + variance * minus_mom;
-      }
-      net_args[DNNL_ARG_MEAN]     = *(out_mean.GetDNNLData());
-      net_args[DNNL_ARG_VARIANCE] = var_mem;
-    } else {
-      net_args[DNNL_ARG_MEAN]     = *(moving_mean.GetDNNLData());
-      net_args[DNNL_ARG_VARIANCE] = *(moving_var.GetDNNLData());
-    }
-    DNNLStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
-    CommitOutput(gradIn, gradi_mem);
-    DNNLStream::Get()->Submit();
-
-    // copy data from gradw_mem to in_grad[1] and in_grad[2]
-    DType* gw_buf   = reinterpret_cast<DType*>(bwd.GetGradw().get_data_handle());
-    DType* w_grad_1 = in_grad[batchnorm::kGamma].data().dptr<DType>();
-    DType* w_grad_2 = in_grad[batchnorm::kBeta].data().dptr<DType>();
-
-    // the gradient of gamma
-    if (!param.fix_gamma) {
-      if (req[batchnorm::kGamma] != kNullOp) {
-        if (req[batchnorm::kGamma] != kAddTo) {
-          memcpy(w_grad_1, gw_buf, copy_size);
-        } else {
-          for (index_t i = 0; i < channels_; i++) {
-            w_grad_1[i] += gw_buf[i];
-          }
-        }
-      }
-    } else {
-      for (index_t i = 0; i < channels_; i++) {
-        (in_grad[1].data().dptr<DType>())[i] = 0.0f;
-      }
-    }
-
-    // the gradient of beta
-    if (req[batchnorm::kBeta] != kNullOp) {
-      if (req[batchnorm::kBeta] != kAddTo) {
-        memcpy(w_grad_2, &gw_buf[channels_], copy_size);
-      } else {
-        DType* grad_beta = &gw_buf[channels_];
-        for (index_t i = 0; i < channels_; i++) {
-          w_grad_2[i] += grad_beta[i];
-        }
-      }
-    }
-  } else {
-    LOG(FATAL) << "oneDNN batch normalization backward: should not reach here ...";
-  }
-}
-
-template <typename DType, bool fuse_relu>
-void DNNLBatchNormBackward(const nnvm::NodeAttrs& attrs,
-                           const OpContext& ctx,
-                           const std::vector<NDArray>& inputs,
-                           const std::vector<OpReqType>& req,
-                           const std::vector<NDArray>& outputs) {
-  DNNLBatchNormBackwardImpl<DType>(attrs, ctx, inputs, req, outputs, fuse_relu);
+  DNNLBNBackward& bwd =
+      DNNLBNBackward::GetCached(param, ctx, data, *data_mem, diff, *diff_mem, flags);
+  bwd.Execute(param, ctx, inputs, req, outputs);
 }
 }  // namespace op
 }  // namespace mxnet
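
Structurally, the header above now only declares the DNNLBNForward and DNNLBNBackward members; their definitions move to the new dnnl_batch_norm.cc (the diffstat's copy of dnnl_batch_norm-inl.h, next). The function template DNNLBatchNormForward<fuse_relu> has to stay in the header, since each call site instantiates it. A toy sketch of that split, with hypothetical names:

    // bn_fused.h -- the template stays here so call sites can instantiate it;
    // the class members can live in a .cc file.
    #pragma once
    #include <vector>

    class BNForward {
     public:
      explicit BNForward(bool fused);                    // defined in bn_fused.cc
      void Execute(const std::vector<float>& in) const;  // defined in bn_fused.cc
     private:
      bool fused_;
    };

    template <bool fuse_relu>
    void BatchNormForward(const std::vector<float>& in) {
      BNForward fwd(fuse_relu);
      fwd.Execute(in);
    }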
diff --git a/src/operator/nn/dnnl/dnnl_batch_norm-inl.h b/src/operator/nn/dnnl/dnnl_batch_norm.cc
similarity index 56%
copy from src/operator/nn/dnnl/dnnl_batch_norm-inl.h
copy to src/operator/nn/dnnl/dnnl_batch_norm.cc
index 2780c9685f..e4a6e1691a 100644
--- a/src/operator/nn/dnnl/dnnl_batch_norm-inl.h
+++ b/src/operator/nn/dnnl/dnnl_batch_norm.cc
@@ -19,20 +19,17 @@
 
 /*!
  * \file dnnl_batch_norm.cc
- * \brief
- * \author Tao Lv
  */
 
-#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_BATCH_NORM_INL_H_
-#define MXNET_OPERATOR_NN_DNNL_DNNL_BATCH_NORM_INL_H_
-
 #if MXNET_USE_ONEDNN == 1
 #include <dnnl.hpp>
+
 #include <utility>
 #include <vector>
 
-#include "operator/nn/batch_norm-inl.h"
 #include "dnnl_base-inl.h"
+#include "dnnl_batch_norm-inl.h"
+#include "operator/nn/batch_norm-inl.h"
 
 namespace mxnet {
 namespace op {
@@ -42,115 +39,57 @@ typedef dnnl::batch_normalization_forward::desc t_bn_f_desc;
 typedef dnnl::batch_normalization_backward::primitive_desc t_bn_b_pdesc;
 typedef dnnl::batch_normalization_backward::desc t_bn_b_desc;
 
-inline static dnnl::normalization_flags _GetFlags(const std::vector<NDArray>& in_data,
-                                                  const std::vector<NDArray>& aux_states,
-                                                  bool is_train_and_not_global_stats,
-                                                  bool fuse_relu) {
-  dnnl::normalization_flags flags = static_cast<dnnl::normalization_flags>(0U);
-  if (in_data.size() == 3U) {
-    flags |= dnnl::normalization_flags::use_scale_shift;
-  }
-
-  // aux_states[0]: inMean
-  // aux_states[1]: inVariance
-  if (aux_states.size() == 2U && !is_train_and_not_global_stats) {
-    flags |= dnnl::normalization_flags::use_global_stats;
-  }
-
-  if (fuse_relu) {
-    flags |= dnnl::normalization_flags::fuse_norm_relu;
-  }
-  return flags;
+DNNLBNForward::DNNLBNForward(const t_bn_f_pdesc& _pd, bool is_train_and_not_global_stats)
+    : pd(_pd) {
+  weight_m.reset(new dnnl::memory(pd.weights_desc(), CpuEngine::Get()->get_engine()));
+  fwd.reset(new dnnl::batch_normalization_forward(pd));
+  this->is_train_and_not_global_stats = is_train_and_not_global_stats;
 }
 
-inline static t_bn_f_pdesc _GetFwd(const dnnl::memory& data_mem,
-                                   bool is_train,
-                                   float eps,
-                                   dnnl::normalization_flags flags) {
-  auto data_md = data_mem.get_desc();
-  auto engine  = CpuEngine::Get()->get_engine();
-
-  if (is_train) {
-    t_bn_f_desc bnFwd_desc(dnnl::prop_kind::forward_training, data_md, eps, flags);
-    return t_bn_f_pdesc(bnFwd_desc, engine);
-  } else {
-    t_bn_f_desc bnFwd_desc(dnnl::prop_kind::forward_inference, data_md, eps, flags);
-    return t_bn_f_pdesc(bnFwd_desc, engine);
-  }
+const dnnl::memory& DNNLBNForward::GetWeight() const {
+  return *weight_m;
 }
 
-inline static t_bn_b_pdesc _GetBwd(const dnnl::memory& data_mem,
-                                   const dnnl::memory& diff_mem,
-                                   float eps,
-                                   dnnl::normalization_flags flags) {
-  auto data_md = data_mem.get_desc();
-  auto diff_md = diff_mem.get_desc();
-  auto engine  = CpuEngine::Get()->get_engine();
-
-  t_bn_b_desc bnBwd_desc(dnnl::prop_kind::backward, diff_md, data_md, eps, flags);
-  return t_bn_b_pdesc(bnBwd_desc, engine, _GetFwd(data_mem, true, eps, flags));
+const t_bn_f_pdesc& DNNLBNForward::GetPd() const {
+  return pd;
 }
 
-typedef ParamOpSign<BatchNormParam> DNNLBNSignature;
-
-class DNNLBNForward {
-  std::shared_ptr<const dnnl::memory> weight_m;
-  std::shared_ptr<dnnl::batch_normalization_forward> fwd;
-  bool is_train_and_not_global_stats;
-  t_bn_f_pdesc pd;
-
- public:
-  DNNLBNForward(const t_bn_f_pdesc& _pd, bool is_train_and_not_global_stats) : pd(_pd) {
-    weight_m.reset(new dnnl::memory(pd.weights_desc(), CpuEngine::Get()->get_engine()));
-    fwd.reset(new dnnl::batch_normalization_forward(pd));
-    this->is_train_and_not_global_stats = is_train_and_not_global_stats;
-  }
-
-  const dnnl::memory& GetWeight() const {
-    return *weight_m;
-  }
-
-  const t_bn_f_pdesc& GetPd() const {
-    return pd;
-  }
-
-  const dnnl::batch_normalization_forward& GetFwd() const {
-    return *fwd;
-  }
-};
+const dnnl::batch_normalization_forward& DNNLBNForward::GetFwd() const {
+  return *fwd;
+}
 
-template <typename DType>
-static DNNLBNForward& GetBNForward(const BatchNormParam& param,
-                                   const OpContext& ctx,
-                                   const dnnl::memory* data_mem,
-                                   dnnl::normalization_flags flags) {
+DNNLBNForward& DNNLBNForward::GetCached(const BatchNormParam& param,
+                                        const OpContext& ctx,
+                                        const dnnl::memory* data_mem,
+                                        bool fuse_relu,
+                                        dnnl::normalization_flags flags) {
 #if DMLC_CXX11_THREAD_LOCAL
   static thread_local std::unordered_map<DNNLBNSignature, DNNLBNForward, OpHash> fwds;
 #else
   static MX_THREAD_LOCAL std::unordered_map<DNNLBNSignature, DNNLBNForward, OpHash> fwds;
 #endif
+
   DNNLBNSignature key(param);
   key.AddSign(ctx.is_train);
   key.AddSign(*data_mem);
   key.AddSign(static_cast<int>(flags));
+  key.AddSign(fuse_relu);
 
   auto it = fwds.find(key);
   if (it == fwds.end()) {
-    auto fwd_pd = _GetFwd(*data_mem, ctx.is_train, param.eps, flags);
+    auto fwd_pd = _GetFwd(*data_mem, ctx.is_train, fuse_relu, param.eps, flags);
     DNNLBNForward fwd(fwd_pd, ctx.is_train && !param.use_global_stats);
     it = AddToCache(&fwds, key, fwd);
   }
   return it->second;
 }
 
-template <typename DType>
-void DNNLBatchNormForwardImpl(const nnvm::NodeAttrs& attrs,
-                              const OpContext& ctx,
-                              const std::vector<NDArray>& inputs,
-                              const std::vector<OpReqType>& req,
-                              const std::vector<NDArray>& outputs,
-                              bool fuse_relu) {
-  const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
+void DNNLBNForward::Execute(const OpContext& ctx,
+                            const BatchNormParam& param,
+                            const std::vector<NDArray>& inputs,
+                            const std::vector<OpReqType>& req,
+                            const std::vector<NDArray>& outputs,
+                            bool fuse_relu) {
   std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
 
   mxnet::TShape shape = inputs[batchnorm::kData].shape();
@@ -171,15 +110,14 @@ void DNNLBatchNormForwardImpl(const nnvm::NodeAttrs& attrs,
   const std::vector<NDArray> aux_states(inputs.begin() + batchnorm::kInMovingMean, inputs.end());
   TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
   dnnl::normalization_flags flags =
-      _GetFlags(in_data, aux_states, ctx.is_train && !param.use_global_stats, fuse_relu);
+      _GetFlags(in_data, aux_states, ctx.is_train && !param.use_global_stats);
   NDArray& data = in_data[batchnorm::kData];
   if (data.IsDNNLData() && data.IsView())
     data = data.Reorder2Default();
   auto data_mem = data.GetDNNLData();
-  auto& fwd     = GetBNForward<DType>(param, ctx, data_mem, flags);
 
   // for output memory
-  auto fwd_dst_desc = fwd.GetPd().dst_desc();
+  auto fwd_dst_desc = GetPd().dst_desc();
   auto out_mem      = const_cast<NDArray&>(out).CreateDNNLData(&fwd_dst_desc);
 
   // mxnet will always use scale shift.
@@ -190,7 +128,7 @@ void DNNLBatchNormForwardImpl(const nnvm::NodeAttrs& attrs,
     CHECK_EQ(gamma.storage_type(), mxnet::kDefaultStorage);
     CHECK_EQ(beta.storage_type(), mxnet::kDefaultStorage);
 
-    const dnnl::memory& weight_mem = fwd.GetWeight();
+    const dnnl::memory& weight_mem = GetWeight();
     float* weight_buf              = reinterpret_cast<float*>(weight_mem.get_data_handle());
 
     index_t channels_ = data.shape()[1];
@@ -218,17 +156,6 @@ void DNNLBatchNormForwardImpl(const nnvm::NodeAttrs& attrs,
     net_args[DNNL_ARG_SRC]         = *data_mem;
     net_args[DNNL_ARG_SCALE_SHIFT] = weight_mem;
     net_args[DNNL_ARG_DST]         = *out_mem;
-    if (fuse_relu) {
-      const NDArray* workspace = nullptr;
-      workspace                = &outputs[3];
-      auto engine              = CpuEngine::Get()->get_engine();
-      if (workspace == nullptr) {
-        LOG(FATAL) << "oneDNN BatchNorm: incorrect workspace input";
-      }
-      auto ws = std::make_shared<dnnl::memory>(
-          fwd.GetPd().workspace_desc(), engine, workspace->GetDNNLData()->get_data_handle());
-      net_args[DNNL_ARG_WORKSPACE] = *ws;
-    }
     if (!ctx.is_train || param.use_global_stats) {
       float* omean  = outputs[batchnorm::kMean].data().dptr<float>();
       float* ovar   = outputs[batchnorm::kVar].data().dptr<float>();
@@ -241,14 +168,14 @@ void DNNLBatchNormForwardImpl(const nnvm::NodeAttrs& attrs,
       }
       net_args[DNNL_ARG_MEAN]     = *(aux_states[batchnorm::kMovingMean].GetDNNLData());
       net_args[DNNL_ARG_VARIANCE] = *(aux_states[batchnorm::kMovingVar].GetDNNLData());
-      DNNLStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
+      DNNLStream::Get()->RegisterPrimArgs(GetFwd(), net_args);
       DNNLStream::Get()->Submit();
     } else {  // training
       const NDArray& outMean      = outputs[batchnorm::kMean];
       const NDArray& outVar       = outputs[batchnorm::kVar];
       net_args[DNNL_ARG_MEAN]     = *(outMean.GetDNNLData());
       net_args[DNNL_ARG_VARIANCE] = *(outVar.GetDNNLData());
-      DNNLStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
+      DNNLStream::Get()->RegisterPrimArgs(GetFwd(), net_args);
       DNNLStream::Get()->Submit();
 
       float* ovar = outVar.data().dptr<float>();
@@ -261,51 +188,32 @@ void DNNLBatchNormForwardImpl(const nnvm::NodeAttrs& attrs,
   }
 }
 
-template <typename DType, bool fuse_relu>
-void DNNLBatchNormForward(const nnvm::NodeAttrs& attrs,
-                          const OpContext& ctx,
-                          const std::vector<NDArray>& inputs,
-                          const std::vector<OpReqType>& req,
-                          const std::vector<NDArray>& outputs) {
-  DNNLBatchNormForwardImpl<DType>(attrs, ctx, inputs, req, outputs, fuse_relu);
+DNNLBNBackward::DNNLBNBackward(const t_bn_b_pdesc& _pd)
+    : weight_m(new dnnl::memory(_pd.weights_desc(), CpuEngine::Get()->get_engine())),
+      gradw_m(new dnnl::memory(_pd.diff_weights_desc(), CpuEngine::Get()->get_engine())),
+      pd(_pd) {
+  bwd.reset(new dnnl::batch_normalization_backward(pd));
 }
 
-class DNNLBNBackward {
-  std::shared_ptr<dnnl::batch_normalization_backward> bwd;
-  const std::shared_ptr<dnnl::memory> weight_m;
-  const std::shared_ptr<dnnl::memory> gradw_m;
-
- public:
-  const t_bn_b_pdesc pd;
-
-  explicit DNNLBNBackward(const t_bn_b_pdesc& _pd)
-      : weight_m(new dnnl::memory(_pd.weights_desc(), CpuEngine::Get()->get_engine())),
-        gradw_m(new dnnl::memory(_pd.diff_weights_desc(), CpuEngine::Get()->get_engine())),
-        pd(_pd) {
-    bwd.reset(new dnnl::batch_normalization_backward(pd));
-  }
+const dnnl::memory& DNNLBNBackward::GetWeight() const {
+  return *weight_m;
+}
 
-  const dnnl::memory& GetWeight() const {
-    return *weight_m;
-  }
+const dnnl::memory& DNNLBNBackward::GetGradw() const {
+  return *gradw_m;
+}
 
-  const dnnl::memory& GetGradw() const {
-    return *gradw_m;
-  }
+const dnnl::batch_normalization_backward& DNNLBNBackward::GetBwd() const {
+  return *bwd;
+}
 
-  const dnnl::batch_normalization_backward& GetBwd() const {
-    return *bwd;
-  }
-};
-
-template <typename DType>
-static DNNLBNBackward& GetBNBackward(const BatchNormParam& param,
-                                     const OpContext& ctx,
-                                     const NDArray& in_data,
-                                     const dnnl::memory& in_mem,
-                                     const NDArray& diff_data,
-                                     const dnnl::memory& diff_mem,
-                                     dnnl::normalization_flags flags) {
+DNNLBNBackward& DNNLBNBackward::GetCached(const BatchNormParam& param,
+                                          const OpContext& ctx,
+                                          const NDArray& in_data,
+                                          const dnnl::memory& in_mem,
+                                          const NDArray& diff_data,
+                                          const dnnl::memory& diff_mem,
+                                          dnnl::normalization_flags flags) {
 #if DMLC_CXX11_THREAD_LOCAL
   static thread_local std::unordered_map<DNNLBNSignature, DNNLBNBackward, OpHash> bwds;
 #else
@@ -325,19 +233,12 @@ static DNNLBNBackward& GetBNBackward(const BatchNormParam& param,
   return it->second;
 }
 
-template <typename DType>
-void DNNLBatchNormBackwardImpl(const nnvm::NodeAttrs& attrs,
-                               const OpContext& ctx,
-                               const std::vector<NDArray>& inputs,
-                               const std::vector<OpReqType>& req,
-                               const std::vector<NDArray>& outputs,
-                               bool fuse_relu) {
-  if (fuse_relu) {
-    CHECK_EQ(inputs.size(), 9U);
-  } else {
-    CHECK_EQ(inputs.size(), 8U);
-  }
-  const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
+void DNNLBNBackward::Execute(const BatchNormParam& param,
+                             const OpContext& ctx,
+                             const std::vector<NDArray>& inputs,
+                             const std::vector<OpReqType>& req,
+                             const std::vector<NDArray>& outputs) {
+  CHECK_EQ(inputs.size(), 8U);
   std::vector<NDArray> out_grad(1);
   std::vector<NDArray> out_data(3);
   std::vector<NDArray> in_data(3);
@@ -353,7 +254,7 @@ void DNNLBatchNormBackwardImpl(const nnvm::NodeAttrs& attrs,
   const std::vector<NDArray>& in_grad = outputs;
   TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
   dnnl::normalization_flags flags =
-      _GetFlags(in_data, aux_states, ctx.is_train && !param.use_global_stats, fuse_relu);
+      _GetFlags(in_data, aux_states, ctx.is_train && !param.use_global_stats);
 
   NDArray data               = in_data[batchnorm::kData];
   NDArray diff               = out_grad[batchnorm::kOut];
@@ -394,52 +295,43 @@ void DNNLBatchNormBackwardImpl(const nnvm::NodeAttrs& attrs,
     auto data_desc = data_mem->get_desc();
     diff_mem       = diff.GetDNNLDataReorder(&data_desc);
   }
-  auto& bwd = GetBNBackward<DType>(param, ctx, data, *data_mem, diff, *diff_mem, flags);
   auto gradi_mem =
-      CreateDNNLMem(const_cast<NDArray&>(gradIn), bwd.pd.diff_src_desc(), req[batchnorm::kData]);
+      CreateDNNLMem(const_cast<NDArray&>(gradIn), pd.diff_src_desc(), req[batchnorm::kData]);
 
   if (static_cast<int>(flags) & static_cast<int>(dnnl::normalization_flags::use_scale_shift)) {
     const NDArray& gamma   = in_data[batchnorm::kGamma];
     const NDArray& beta    = in_data[batchnorm::kBeta];
-    DType* weight_buf      = reinterpret_cast<DType*>(bwd.GetWeight().get_data_handle());
+    float* weight_buf      = reinterpret_cast<float*>(GetWeight().get_data_handle());
     index_t channels_      = data.shape()[1];
-    DType* weight_ptr      = gamma.data().dptr<DType>();
-    DType* bias_ptr        = beta.data().dptr<DType>();
-    const size_t copy_size = sizeof(DType) * channels_;
+    float* weight_ptr      = gamma.data().dptr<float>();
+    float* bias_ptr        = beta.data().dptr<float>();
+    const size_t copy_size = sizeof(weight_buf[0]) * channels_;
     if (!param.fix_gamma) {
       memcpy(weight_buf, weight_ptr, copy_size);
       memcpy(&weight_buf[channels_], bias_ptr, copy_size);
     } else {
       for (index_t i = 0; i < channels_; i++) {
-        weight_buf[i] = static_cast<DType>(1.0f);
+        weight_buf[i] = 1.0f;
       }
       memcpy(&weight_buf[channels_], bias_ptr, copy_size);
     }
     dnnl_args_map_t net_args;
     net_args[DNNL_ARG_SRC]              = *data_mem;
     net_args[DNNL_ARG_DIFF_SRC]         = *gradi_mem.second;
-    net_args[DNNL_ARG_SCALE_SHIFT]      = bwd.GetWeight();
-    net_args[DNNL_ARG_DIFF_SCALE_SHIFT] = bwd.GetGradw();
+    net_args[DNNL_ARG_SCALE_SHIFT]      = GetWeight();
+    net_args[DNNL_ARG_DIFF_SCALE_SHIFT] = GetGradw();
     net_args[DNNL_ARG_DIFF_DST]         = *diff_mem;
 
-    if (fuse_relu) {
-      const NDArray* workspace = nullptr;
-      workspace                = &inputs[8];
-      if (workspace != nullptr) {
-        net_args[DNNL_ARG_WORKSPACE] = *(workspace->GetDNNLData());
-      }
-    }
-
     // training but no input mean and variance
     if (ctx.is_train && !param.use_global_stats) {
-      DType* moving_mean_ptr = moving_mean.data().dptr<DType>();
-      DType* moving_var_ptr  = moving_var.data().dptr<DType>();
-      DType* out_mean_ptr    = out_mean.data().dptr<DType>();
-      DType* out_var_ptr     = out_var.data().dptr<DType>();
-      dnnl::memory var_mem(bwd.pd.variance_desc(), CpuEngine::Get()->get_engine());
-      DType* tmp_var_ptr = reinterpret_cast<DType*>(var_mem.get_data_handle());
-
-      DType minus_mom = (1.0f - param.momentum);
+      float* moving_mean_ptr = moving_mean.data().dptr<float>();
+      float* moving_var_ptr  = moving_var.data().dptr<float>();
+      float* out_mean_ptr    = out_mean.data().dptr<float>();
+      float* out_var_ptr     = out_var.data().dptr<float>();
+      dnnl::memory var_mem(pd.variance_desc(), CpuEngine::Get()->get_engine());
+      float* tmp_var_ptr = reinterpret_cast<float*>(var_mem.get_data_handle());
+
+      float minus_mom = (1.0f - param.momentum);
       for (index_t i = 0; i < channels_; i++) {
         moving_mean_ptr[i] = moving_mean_ptr[i] * param.momentum + out_mean_ptr[i] * minus_mom;
         float variance     = INVSTD_TO_VARIANCE(out_var_ptr[i], param.eps);
@@ -452,14 +344,14 @@ void DNNLBatchNormBackwardImpl(const nnvm::NodeAttrs& attrs,
       net_args[DNNL_ARG_MEAN]     = *(moving_mean.GetDNNLData());
       net_args[DNNL_ARG_VARIANCE] = *(moving_var.GetDNNLData());
     }
-    DNNLStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
+    DNNLStream::Get()->RegisterPrimArgs(GetBwd(), net_args);
     CommitOutput(gradIn, gradi_mem);
     DNNLStream::Get()->Submit();
 
     // copy data from gradw_mem to in_grad[1] and in_grad[2]
-    DType* gw_buf   = reinterpret_cast<DType*>(bwd.GetGradw().get_data_handle());
-    DType* w_grad_1 = in_grad[batchnorm::kGamma].data().dptr<DType>();
-    DType* w_grad_2 = in_grad[batchnorm::kBeta].data().dptr<DType>();
+    float* gw_buf   = reinterpret_cast<float*>(GetGradw().get_data_handle());
+    float* w_grad_1 = in_grad[batchnorm::kGamma].data().dptr<float>();
+    float* w_grad_2 = in_grad[batchnorm::kBeta].data().dptr<float>();
 
     // the gradient of gamma
     if (!param.fix_gamma) {
@@ -474,7 +366,7 @@ void DNNLBatchNormBackwardImpl(const nnvm::NodeAttrs& attrs,
       }
     } else {
       for (index_t i = 0; i < channels_; i++) {
-        (in_grad[1].data().dptr<DType>())[i] = 0.0f;
+        (in_grad[1].data().dptr<float>())[i] = 0.0f;
       }
     }
 
@@ -483,7 +375,7 @@ void DNNLBatchNormBackwardImpl(const nnvm::NodeAttrs& attrs,
       if (req[batchnorm::kBeta] != kAddTo) {
         memcpy(w_grad_2, &gw_buf[channels_], copy_size);
       } else {
-        DType* grad_beta = &gw_buf[channels_];
+        float* grad_beta = &gw_buf[channels_];
         for (index_t i = 0; i < channels_; i++) {
           w_grad_2[i] += grad_beta[i];
         }
@@ -494,15 +386,6 @@ void DNNLBatchNormBackwardImpl(const nnvm::NodeAttrs& attrs,
   }
 }
 
-template <typename DType, bool fuse_relu>
-void DNNLBatchNormBackward(const nnvm::NodeAttrs& attrs,
-                           const OpContext& ctx,
-                           const std::vector<NDArray>& inputs,
-                           const std::vector<OpReqType>& req,
-                           const std::vector<NDArray>& outputs) {
-  DNNLBatchNormBackwardImpl<DType>(attrs, ctx, inputs, req, outputs, fuse_relu);
-}
 }  // namespace op
 }  // namespace mxnet
-#endif  // MXNET_USE_ONEDNN
-#endif  // MXNET_OPERATOR_NN_DNNL_DNNL_BATCH_NORM_INL_H_
+#endif  // MXNET_USE_ONEDNN == 1
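
One detail worth calling out in GetCached above: key.AddSign(fuse_relu). Fused and unfused forwards now share a single thread-local cache, so the flag must be part of the lookup key, otherwise a plain BatchNorm could be handed a primitive with the ReLU post-op baked in (or vice versa). A stripped-down sketch of such a signature-keyed cache, with hypothetical types (MXNet's real key is ParamOpSign<BatchNormParam>):

    #include <cstdint>
    #include <functional>
    #include <unordered_map>
    #include <vector>

    struct Signature {
      std::vector<int64_t> fields;
      void AddSign(int64_t v) { fields.push_back(v); }
      bool operator==(const Signature& o) const { return fields == o.fields; }
    };

    struct SignatureHash {
      size_t operator()(const Signature& s) const {
        size_t h = 0;
        for (int64_t v : s.fields)  // simple polynomial hash combine
          h = h * 31 + std::hash<int64_t>()(v);
        return h;
      }
    };

    struct Primitive { bool fuse_relu; };

    Primitive& GetCached(bool fuse_relu, int flags) {
      static thread_local std::unordered_map<Signature, Primitive, SignatureHash> cache;
      Signature key;
      key.AddSign(fuse_relu);  // without this, fused/unfused entries collide
      key.AddSign(flags);
      auto it = cache.find(key);
      if (it == cache.end())
        it = cache.emplace(key, Primitive{fuse_relu}).first;
      return it->second;
    }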
diff --git a/src/operator/quantization/dnnl/dnnl_quantized_batch_norm.cc b/src/operator/quantization/dnnl/dnnl_quantized_batch_norm.cc
index ba3b2f84ce..39bbb03d78 100644
--- a/src/operator/quantization/dnnl/dnnl_quantized_batch_norm.cc
+++ b/src/operator/quantization/dnnl/dnnl_quantized_batch_norm.cc
@@ -83,7 +83,7 @@ static void DNNLQuantizedBatchNormForward(const nnvm::NodeAttrs& attrs,
 
   dnnl::normalization_flags flags =
       dnnl::normalization_flags::use_global_stats | dnnl::normalization_flags::use_scale_shift;
-  auto& fwd                      = GetBNForward<float>(param, ctx, data_mem, flags);
+  auto& fwd                      = DNNLBNForward::GetCached(param, ctx, data_mem, false, flags);
   const dnnl::memory& weight_mem = fwd.GetWeight();
   CHECK_EQ(weight_mem.get_desc().get_size(), channel_count * sizeof(float) * 2);
   float* weight_buf = reinterpret_cast<float*>(weight_mem.get_data_handle());
diff --git a/src/operator/contrib/batch_norm_relu.cc b/src/operator/subgraph/dnnl/dnnl_bn_relu.cc
similarity index 73%
rename from src/operator/contrib/batch_norm_relu.cc
rename to src/operator/subgraph/dnnl/dnnl_bn_relu.cc
index a0f158f42b..383beb0eab 100644
--- a/src/operator/contrib/batch_norm_relu.cc
+++ b/src/operator/subgraph/dnnl/dnnl_bn_relu.cc
@@ -18,17 +18,18 @@
  */
 
 /*!
- * \file batch_norm_relu.cc
+ * \file dnnl_bn_relu.cc
  * \brief
  * \author Xinyu Chen
  */
 
-#include "../nn/batch_norm-inl.h"
 #include <nnvm/op_attr_types.h>
-#include "../elemwise_op_common.h"
-#include "../operator_common.h"
+
+#include "operator/elemwise_op_common.h"
+#include "operator/nn/batch_norm-inl.h"
+#include "operator/operator_common.h"
 #if MXNET_USE_ONEDNN == 1
-#include "../nn/dnnl/dnnl_batch_norm-inl.h"
+#include "operator/nn/dnnl/dnnl_batch_norm-inl.h"
 #endif
 
 namespace mxnet {
@@ -144,27 +145,12 @@ void BatchNormWithReLUComputeExCPU(const nnvm::NodeAttrs& attrs,
   if (SupportDNNLBNReLU(inputs[0])) {
     CHECK_GT(outputs.size(), 3U);
     DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
-    DNNL_REAL_TYPE_SWITCH(inputs[0].dtype(), DTYPE, {
-      DNNLRun(DNNLBatchNormForward<DTYPE, /*fuse_relu*/ true>, attrs, ctx, inputs, req, outputs);
-    });
+    DNNLRun(DNNLBatchNormForward</*fuse_relu*/ true>, attrs, ctx, inputs, req, outputs);
     return;
   }
   LOG(FATAL) << "BatchNormWithReLU operator only supports oneDNN Backend.";
 }
 
-void BatchNormWithReLUGradComputeExCPU(const nnvm::NodeAttrs& attrs,
-                                       const OpContext& ctx,
-                                       const std::vector<NDArray>& inputs,
-                                       const std::vector<OpReqType>& req,
-                                       const std::vector<NDArray>& outputs) {
-  if (SupportDNNLBNReLU(inputs[0])) {
-    CHECK_EQ(inputs.size(), 9U);
-    DNNL_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
-    DNNLRun(DNNLBatchNormBackward<float, /*fuse_relu*/ true>, attrs, ctx, inputs, req, outputs);
-    return;
-  }
-  LOG(FATAL) << "BatchNormWithReLU operator only supports oneDNN Backend.";
-}
 #endif
 
 static inline bool BatchNormWithReLUStorageType(const nnvm::NodeAttrs& attrs,
@@ -200,47 +186,7 @@ static inline bool BatchNormWithReLUStorageType(const nnvm::NodeAttrs& attrs,
   return dispatched;
 }
 
-std::vector<nnvm::NodeEntry> BatchNormWithReLUGrad(const nnvm::ObjectPtr& n,
-                                                   const std::vector<nnvm::NodeEntry>& ograds) {
-  std::vector<nnvm::NodeEntry> out_data;
-  out_data.reserve(n->num_outputs());
-  for (size_t i = 0; i < n->num_outputs(); ++i)
-    out_data.emplace_back(n, i, 0);
-  std::vector<nnvm::NodeEntry> heads;
-  heads.reserve(9);
-  heads.emplace_back(ograds.at(0));
-  heads.emplace_back(out_data.at(batchnormrelu::kMean));
-  heads.emplace_back(out_data.at(batchnormrelu::kVar));
-  heads.emplace_back(n->inputs.at(batchnormrelu::kData));
-  heads.emplace_back(n->inputs.at(batchnormrelu::kGamma));
-  heads.emplace_back(n->inputs.at(batchnormrelu::kBeta));
-  heads.emplace_back(n->inputs.at(batchnormrelu::kInMovingMean));
-  heads.emplace_back(n->inputs.at(batchnormrelu::kInMovingVar));
-  heads.emplace_back(out_data.at(batchnormrelu::kWorkspace));
-
-  nnvm::ObjectPtr gnode = nnvm::Node::Create();
-  gnode->inputs         = std::move(heads);
-  gnode->control_deps.emplace_back(n);
-  gnode->attrs      = n->attrs;
-  gnode->attrs.op   = nnvm::Op::Get("_backward_contrib_BatchNormWithReLU");
-  gnode->attrs.name = n->attrs.name + "_backward";
-  // The input of batchnorm
-  std::vector<nnvm::NodeEntry> in_grad;
-  in_grad.reserve(5);
-  for (size_t i = 0; i < 3; ++i)
-    in_grad.emplace_back(gnode, i, 0);
-  // attach no gradient node to forbid gradient on aux_state
-  nnvm::ObjectPtr ng = nnvm::Node::Create();
-  ng->attrs.op       = Op::Get("_NoGradient");
-  ng->attrs.name     = "NoGradient";
-  // the aux state of batchnorm
-  for (size_t i = 3; i < 5; ++i)
-    in_grad.emplace_back(ng);
-  return in_grad;
-}
-
-NNVM_REGISTER_OP(_contrib_BatchNormWithReLU)
-    .add_alias("_npx_batch_norm_with_relu")
+NNVM_REGISTER_OP(_sg_onednn_batch_norm)
     .describe(R"code(Batch normalization with ReLU fusion.
 
 An extended operator of batch normalization which can fuse ReLU activation.
@@ -275,7 +221,6 @@ An extented operator of Batch normalization which can fuse ReLU activation.
 #if MXNET_USE_ONEDNN == 1
     .set_attr<FComputeEx>("FComputeEx<cpu>", BatchNormWithReLUComputeExCPU)
 #endif
-    .set_attr<nnvm::FGradient>("FGradient", BatchNormWithReLUGrad)
 #if MXNET_USE_ONEDNN == 1
     .set_attr<bool>("TIsDNNL", true)
     .set_attr<FResourceRequest>("FResourceRequest",
@@ -301,20 +246,5 @@ An extented operator of Batch normalization which can fuse ReLU activation.
           }
         });
 
-NNVM_REGISTER_OP(_backward_contrib_BatchNormWithReLU)
-    .set_num_inputs(9)
-    .set_num_outputs(3)
-    .set_attr<nnvm::TIsBackward>("TIsBackward", true)
-    .set_attr<FInferStorageType>("FInferStorageType", BatchNormWithReLUStorageType)
-#if MXNET_USE_ONEDNN == 1
-    .set_attr<FResourceRequest>("FResourceRequest",
-                                [](const NodeAttrs& n) {
-                                  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
-                                })
-    .set_attr<bool>("TIsDNNL", true)
-    .set_attr<FComputeEx>("FComputeEx<cpu>", BatchNormWithReLUGradComputeExCPU)
-#endif
-    .set_attr_parser(ParamParser<BatchNormParam>);
-
 }  // namespace op
 }  // namespace mxnet
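
With the contrib operator removed, _sg_onednn_batch_norm is meant to be produced by the ONEDNN partitioning pass rather than called by users. A minimal sketch of how the fused node is expected to appear (assuming the 'ONEDNN' backend name used elsewhere in this commit; not a test from this change):

    import mxnet as mx

    x = mx.sym.Variable('x')
    bn = mx.sym.BatchNorm(data=x, fix_gamma=False)
    relu = mx.sym.Activation(data=bn, act_type='relu')

    # SgDNNLBNReLUProperty should rewrite the BN + relu pair into a single
    # _sg_onednn_batch_norm node; exact matching depends on the backend.
    fused = relu.optimize_for('ONEDNN')
    assert '_sg_onednn_batch_norm' in fused.tojson()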
diff --git a/src/operator/subgraph/dnnl/dnnl_bn_relu_property.h b/src/operator/subgraph/dnnl/dnnl_bn_relu_property.h
index 792236a50d..c350fb90a4 100644
--- a/src/operator/subgraph/dnnl/dnnl_bn_relu_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_bn_relu_property.h
@@ -25,10 +25,10 @@
 #include <string>
 #include <vector>
 
+#include "dnnl_subgraph_base-inl.h"
 #include "operator/nn/dnnl/dnnl_act-inl.h"
 #include "operator/nn/dnnl/dnnl_batch_norm-inl.h"
 #include "operator/subgraph/common.h"
-#include "dnnl_subgraph_base-inl.h"
 
 namespace mxnet {
 namespace op {
@@ -116,10 +116,11 @@ class SgDNNLBNReLUProperty : public SubgraphProperty {
     });
 
     n->attrs.name = node_name.str();
-    n->attrs.op   = Op::Get("_contrib_BatchNormWithReLU");
+    n->attrs.op   = Op::Get("_sg_onednn_batch_norm");
     CHECK(n->attrs.op);
     n->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(sym));
-    n->attrs.parsed = param;
+    param.SetAttrDict(&(n->attrs.dict));
+    n->op()->attr_parser(&(n->attrs));
     return n;
   }
 
diff --git a/tests/python/dnnl/op_cfg.py b/tests/python/dnnl/op_cfg.py
index 9effb305b1..c4739ab807 100644
--- a/tests/python/dnnl/op_cfg.py
+++ b/tests/python/dnnl/op_cfg.py
@@ -197,7 +197,6 @@ def get_all_ops_cfgs(dtype):
                 TensorArg(lambda shape, dtype: mx.nd.random.uniform(0, 1, shape, dtype))
             ],
         },
-        '_contrib_BatchNormWithReLU': {CFG_BASED_ON: 'BatchNorm'},
         'LRN': {
             'data,nsize': [(default_tensor(2, dtype), 3)]
         },
@@ -289,6 +288,7 @@ def get_all_ops_cfgs(dtype):
             CFG_BASED_ON: 'batch_dot',
             CFG_SUBGRAPH: [SubgraphCfg('batch_dot', 'ONEDNN')],
         },
+        '_sg_onednn_batch_norm': {CFG_BASED_ON: 'BatchNorm'},
         '_sg_onednn_selfatt_qk': {
             CFG_SUBGRAPH: [SubgraphCfg('_sg_onednn_selfatt_qk', 'ONEDNN')],
             'queries_keys_values': [mx.nd.random.normal(0, 1, (1, 4, 3*2*8), dtype)],
diff --git a/tests/python/dnnl/test_dnnl.py b/tests/python/dnnl/test_dnnl.py
index 102af2e38d..8eb8373784 100644
--- a/tests/python/dnnl/test_dnnl.py
+++ b/tests/python/dnnl/test_dnnl.py
@@ -288,64 +288,6 @@ def test_batchnorm():
     for stype in stypes:
         check_batchnorm_training(stype)
 
-def test_batchnorm_relu_fusion():
-    def check_batchnorm_relu_fusion(shape):
-        x = mx.sym.Variable('x')
-        in_data = mx.nd.random.normal(shape=shape)
-        grad_out = mx.nd.random.uniform(0, 1, shape)
-        bn = mx.sym.BatchNorm(data=x, fix_gamma=False)
-        relu = mx.sym.Activation(data=bn, act_type='relu', name='relu')
-        exe = relu._simple_bind(ctx=mx.cpu(), x=shape, grad_req='write')
-        exe.arg_arrays[0][:] = in_data
-        exe.forward(is_train=True)
-        exe.backward(grad_out)
-        no_fuse_outputs = exe.outputs
-        no_fuse_grads = exe.grad_arrays
-
-        bnrelu = mx.sym.contrib.BatchNormWithReLU(data=x, fix_gamma=False)
-        exe_fuse = bnrelu._simple_bind(ctx=mx.cpu(), x=shape, grad_req='write')
-        exe_fuse.arg_arrays[0][:] = in_data
-        exe_fuse.forward(is_train=True)
-        exe_fuse.backward(grad_out)
-        fuse_outputs = exe_fuse.outputs
-        fuse_grads = exe_fuse.grad_arrays
-
-        for i in range(len(no_fuse_outputs)):
-            assert_almost_equal(no_fuse_outputs[i], fuse_outputs[i])
-        for i in range(len(no_fuse_grads)):
-            assert_almost_equal(no_fuse_grads[i], fuse_grads[i])
-
-    def check_batchnorm_relu_fusion_gluon(shape):
-        class BNNet(gluon.HybridBlock):
-            def __init__(self, fuse_relu):
-                super(BNNet, self).__init__()
-                self.fuse_relu = fuse_relu
-                if self.fuse_relu:
-                    self.bn = gluon.nn.BatchNormReLU()
-                else:
-                    self.bn = gluon.nn.BatchNorm()
-                self.relu = gluon.nn.Activation('relu')
-
-            def forward(self, x):
-                y = self.bn(x)
-                if not self.fuse_relu:
-                    y = self.relu(y)
-                return y
-        fused_net = BNNet(fuse_relu=True)
-        unfused_net = BNNet(fuse_relu=False)
-        fused_net.initialize()
-        unfused_net.initialize()
-        in_data = mx.np.random.normal(size=shape)
-        no_fuse_outputs = unfused_net.forward(in_data)
-        fuse_outputs = fused_net.forward(in_data)
-
-        for i in range(len(no_fuse_outputs)):
-            assert_almost_equal(no_fuse_outputs[i], fuse_outputs[i])
-
-    check_batchnorm_relu_fusion((1, 3, 224, 224))
-    check_batchnorm_relu_fusion((8, 3, 224, 224))
-    check_batchnorm_relu_fusion_gluon((1, 3, 224, 224))
-    check_batchnorm_relu_fusion_gluon((8, 3, 224, 224))
 
 def test_softmax():
     def check_softmax_training(stype):
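
The deleted test compared the contrib operator against an unfused BatchNorm + relu pair. Under the new flow the equivalent check would go through the partitioning API instead; a sketch of that idea (not part of this commit, and assuming the ONEDNN backend fuses the pair as registered above):

    import mxnet as mx
    from mxnet import gluon
    from mxnet.test_utils import assert_almost_equal

    net = gluon.nn.HybridSequential()
    net.add(gluon.nn.BatchNorm(), gluon.nn.Activation('relu'))
    net.initialize()

    x = mx.np.random.normal(size=(8, 3, 224, 224))
    ref = net(x)                            # unfused reference
    net.optimize_for(x, backend='ONEDNN')   # apply the subgraph pass
    assert_almost_equal(ref, net(x))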