You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ta...@apache.org on 2019/09/23 14:04:13 UTC

[incubator-mxnet] branch mkldnn-v1.0 updated: [mkldnn-v1.0] Add MKL-DNN BN (#16199)

This is an automated email from the ASF dual-hosted git repository.

taolv pushed a commit to branch mkldnn-v1.0
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/mkldnn-v1.0 by this push:
     new 0b8805a  [mkldnn-v1.0] Add MKL-DNN BN (#16199)
0b8805a is described below

commit 0b8805afc963296ecfeb0599e6c63858ca6a85b0
Author: rongzha1 <ro...@intel.com>
AuthorDate: Mon Sep 23 22:03:04 2019 +0800

    [mkldnn-v1.0] Add MKL-DNN BN (#16199)
    
    * add mkldnn bn
    
    * add static_cast to transform data type
    
    * change mkldnn_args_map_t
    
    * retrigger CI
---
 src/operator/nn/batch_norm.cc                  |  14 +-
 src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h | 232 ++++++++-----------------
 2 files changed, 81 insertions(+), 165 deletions(-)

diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 3214e3b..da042a1 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -28,7 +28,7 @@
 #include <nnvm/op_attr_types.h>
 #include "../elemwise_op_common.h"
 #include "../operator_common.h"
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include "./mkldnn/mkldnn_batch_norm-inl.h"
 #endif
 
@@ -379,7 +379,7 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam &param) {
   mxnet::TShape shape = input.shape();
   return SupportMKLDNN(input) && shape.ndim() == 4
@@ -454,7 +454,7 @@ static inline bool BatchNormStorageType(const nnvm::NodeAttrs &attrs,
   const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
 
   bool dispatched = false;
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
   if (!dispatched) {
     dispatched = MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode,
                                    in_attrs, out_attrs);
@@ -592,11 +592,11 @@ then set ``gamma`` to 1 and its gradient to 0.
 .set_attr<nnvm::FInferType>("FInferType", BatchNormType)
 .set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)
 .set_attr<FCompute>("FCompute<cpu>", BatchNormCompute<cpu>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<FComputeEx>("FComputeEx<cpu>", BatchNormComputeExCPU)
 #endif
 .set_attr<nnvm::FGradient>("FGradient", BatchNormGrad)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
@@ -623,13 +623,13 @@ NNVM_REGISTER_OP(_backward_BatchNorm)
 .set_num_outputs(3)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
 })
 #endif
 .set_attr_parser(ParamParser<BatchNormParam>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", BatchNormGradComputeExCPU)
 #endif
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 61de08f..ef5886e 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -26,7 +26,7 @@
 #ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BATCH_NORM_INL_H_
 #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BATCH_NORM_INL_H_
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include <vector>
 #include <utility>
 #include <mkldnn.hpp>
@@ -44,54 +44,44 @@ typedef mkldnn::batch_normalization_forward::desc               t_bn_f_desc;
 typedef mkldnn::batch_normalization_backward::primitive_desc    t_bn_b_pdesc;
 typedef mkldnn::batch_normalization_backward::desc              t_bn_b_desc;
 
-using mkldnn::use_global_stats;
-using mkldnn::use_scale_shift;
-using mkldnn::forward_training;
-using mkldnn::forward_inference;
-
-inline static unsigned _GetFlags(const std::vector<NDArray> &in_data,
+inline static mkldnn::normalization_flags _GetFlags(const std::vector<NDArray> &in_data,
                                  const std::vector<NDArray> &aux_states,
                                  const BatchNormParam &param, bool is_train_and_not_global_stats) {
-  unsigned flags = 0U;
+  mkldnn::normalization_flags flags = static_cast<mkldnn::normalization_flags>(0U);
   if (in_data.size() == 3U) {
-    flags |= use_scale_shift;
+    flags |=  mkldnn::normalization_flags::use_scale_shift;
   }
 
   // aux_states[0]: inMean
   // aux_states[1]: inVariance
   if (aux_states.size() == 2U && !is_train_and_not_global_stats) {
-    flags |= use_global_stats;
+    flags |=  mkldnn::normalization_flags::use_global_stats;
   }
   return flags;
 }
 
-template <typename DType>
 inline static t_bn_f_pdesc _GetFwd(const mkldnn::memory &data_mem,
                                    bool is_train,
-                                   DType eps,
-                                   unsigned flags) {
-  auto data_mpd   = data_mem.get_primitive_desc();
-  auto data_md    = data_mpd.desc();
-  auto engine     = CpuEngine::Get()->get_engine();
+                                   float eps,
+                                   mkldnn::normalization_flags flags) {
+  auto data_md   = data_mem.get_desc();
+  auto engine    = CpuEngine::Get()->get_engine();
 
   if (is_train) {
-    t_bn_f_desc bnFwd_desc(forward_training, data_md, eps, flags);
+    t_bn_f_desc bnFwd_desc(mkldnn::prop_kind::forward_training, data_md, eps, flags);
     return t_bn_f_pdesc(bnFwd_desc, engine);
   } else {
-    t_bn_f_desc bnFwd_desc(forward_inference, data_md, eps, flags);
+    t_bn_f_desc bnFwd_desc(mkldnn::prop_kind::forward_inference, data_md, eps, flags);
     return t_bn_f_pdesc(bnFwd_desc, engine);
   }
 }
 
-template <typename DType>
 inline static t_bn_b_pdesc _GetBwd(const mkldnn::memory &data_mem,
                                    const mkldnn::memory &diff_mem,
-                                   DType eps,
-                                   unsigned flags) {
-  auto data_mpd   = data_mem.get_primitive_desc();
-  auto data_md    = data_mpd.desc();
-  auto diff_mpd   = diff_mem.get_primitive_desc();
-  auto diff_md    = diff_mpd.desc();
+                                   float eps,
+                                   mkldnn::normalization_flags flags) {
+  auto data_md    = data_mem.get_desc();
+  auto diff_md    = diff_mem.get_desc();
   auto engine     = CpuEngine::Get()->get_engine();
 
   t_bn_b_desc  bnBwd_desc(mkldnn::prop_kind::backward, diff_md, data_md, eps, flags);
@@ -101,18 +91,15 @@ inline static t_bn_b_pdesc _GetBwd(const mkldnn::memory &data_mem,
 typedef ParamOpSign<BatchNormParam> MKLDNNBNSignature;
 
 class MKLDNNBNForward {
-  std::shared_ptr<const mkldnn::memory> data_m;
   std::shared_ptr<const mkldnn::memory> weight_m;
-  std::shared_ptr<const mkldnn::memory> out_m;
-  std::shared_ptr<const mkldnn::memory> mean_m;
-  std::shared_ptr<const mkldnn::memory> var_m;
   std::shared_ptr<mkldnn::batch_normalization_forward> fwd;
   bool is_train_and_not_global_stats;
   t_bn_f_pdesc pd;
 
  public:
   MKLDNNBNForward(const t_bn_f_pdesc &_pd, bool is_train_and_not_global_stats): pd(_pd) {
-    weight_m.reset(new mkldnn::memory(pd.weights_primitive_desc()));
+    weight_m.reset(new mkldnn::memory(pd.weights_desc(), CpuEngine::Get()->get_engine()));
+    fwd.reset(new mkldnn::batch_normalization_forward(pd));
     this->is_train_and_not_global_stats = is_train_and_not_global_stats;
   }
 
@@ -124,59 +111,6 @@ class MKLDNNBNForward {
     return pd;
   }
 
-  const mkldnn::memory &GetMean() const {
-    return *mean_m;
-  }
-
-  const mkldnn::memory &GetVar() const {
-    return *var_m;
-  }
-
-  void SetDataHandle(const mkldnn::memory *data, const mkldnn::memory *mean,
-                     const mkldnn::memory *var, const mkldnn::memory *out) {
-    if (data_m) {
-      data_m->set_data_handle(data->get_data_handle());
-    } else {
-      data_m.reset(new mkldnn::memory(data->get_primitive_desc(),
-                                      data->get_data_handle()));
-    }
-    if (out_m) {
-      out_m->set_data_handle(out->get_data_handle());
-    } else {
-      out_m.reset(new mkldnn::memory(out->get_primitive_desc(),
-                                     out->get_data_handle()));
-    }
-    if (mean_m) {
-      mean_m->set_data_handle(mean->get_data_handle());
-    } else {
-      mean_m.reset(new mkldnn::memory(mean->get_primitive_desc(),
-                                      mean->get_data_handle()));
-    }
-    if (var_m) {
-      var_m->set_data_handle(var->get_data_handle());
-    } else {
-      var_m.reset(new mkldnn::memory(var->get_primitive_desc(),
-                                     var->get_data_handle()));
-    }
-
-    if (fwd == nullptr) {
-      if (!is_train_and_not_global_stats)
-        fwd.reset(new mkldnn::batch_normalization_forward(
-                pd, *data_m, mkldnn::primitive::at(*mean_m),
-                mkldnn::primitive::at(*var_m), *weight_m, *out_m));
-      else
-        fwd.reset(new mkldnn::batch_normalization_forward(
-                pd, mkldnn::primitive::at(*data_m),
-                mkldnn::primitive::at(*weight_m), *out_m,
-                *mean_m, *var_m));
-    }
-  }
-
-  void SetDataHandle(const NDArray &data, const NDArray &mean,
-                     const NDArray &var, const mkldnn::memory &out) {
-    SetDataHandle(data.GetMKLDNNData(), mean.GetMKLDNNData(), var.GetMKLDNNData(), &out);
-  }
-
   const mkldnn::batch_normalization_forward &GetFwd() const {
     return *fwd;
   }
@@ -185,7 +119,7 @@ class MKLDNNBNForward {
 template<typename DType>
 static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
                                      const OpContext &ctx, const mkldnn::memory *data_mem,
-                                     unsigned flags) {
+                                     mkldnn::normalization_flags flags) {
 #if DMLC_CXX11_THREAD_LOCAL
   static thread_local std::unordered_map<MKLDNNBNSignature, MKLDNNBNForward, OpHash> fwds;
 #else
@@ -193,13 +127,12 @@ static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
 #endif
   MKLDNNBNSignature key(param);
   key.AddSign(ctx.is_train);
-  key.AddSign(param.use_global_stats);
   key.AddSign(*data_mem);
 
   auto it = fwds.find(key);
   if (it == fwds.end()) {
     auto fwd_pd = _GetFwd(*data_mem, ctx.is_train,
-                          (DType) param.eps, flags);
+                          param.eps, flags);
     MKLDNNBNForward fwd(fwd_pd, ctx.is_train && !param.use_global_stats);
     it = AddToCache(&fwds, key, fwd);
   }
@@ -209,7 +142,7 @@ static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
 template<typename DType>
 static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
                                      const OpContext &ctx, const NDArray &in_data,
-                                     unsigned flags) {
+                                     mkldnn::normalization_flags flags) {
   return GetBNForward<DType>(param, ctx, in_data.GetMKLDNNData(), flags);
 }
 
@@ -220,18 +153,20 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
                             const std::vector<NDArray>   &out_data,
                             const std::vector<NDArray>   &aux_states) {
   TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
-  unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train && !param.use_global_stats);
+  mkldnn::normalization_flags flags = _GetFlags(in_data,
+                                                aux_states,
+                                                param,
+                                                ctx.is_train && !param.use_global_stats);
   const NDArray &data = in_data[batchnorm::kData];
-
   auto &fwd = GetBNForward<DType>(param, ctx, data, flags);
-  const NDArray &out  = out_data[batchnorm::kOut];
+  const NDArray &out = out_data[batchnorm::kOut];
 
   // for output memory
-  auto out_mem = const_cast<NDArray &>(out).CreateMKLDNNData(fwd.GetPd().dst_primitive_desc());
+  auto out_mem = const_cast<NDArray &>(out).CreateMKLDNNData(fwd.GetPd().dst_desc());
 
   // mxnet will always use scale shift.
   // But if fix_gamma is true, then all scale elements will be set to 1.0f
-  if (flags & use_scale_shift) {
+  if (static_cast<int>(flags) & static_cast<int>(mkldnn::normalization_flags::use_scale_shift)) {
     const NDArray &gamma    = in_data[batchnorm::kGamma];
     const NDArray &beta     = in_data[batchnorm::kBeta];
     CHECK_EQ(gamma.storage_type(), mxnet::kDefaultStorage);
@@ -241,7 +176,7 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
     DType* weight_buf = reinterpret_cast<DType *>(weight_mem.get_data_handle());
 
     nnvm::dim_t channels_ = data.shape()[1];
-    CHECK(weight_mem.get_primitive_desc().get_size() == channels_ * sizeof(DType) * 2);
+    CHECK(weight_mem.get_desc().get_size() == channels_ * sizeof(DType) * 2);
     DType* weight_ptr = gamma.data().dptr<DType>();
     DType* bias_ptr = beta.data().dptr<DType>();
     if (!param.fix_gamma) {
@@ -249,17 +184,22 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
       memcpy(&weight_buf[channels_], bias_ptr, sizeof(weight_buf[0]) * channels_);
     } else if (IsBNWriting(req[batchnorm::kGamma])) {
       for (int i = 0; i < channels_; i++) {
-        weight_buf[i] = (DType)1.0f;
-        weight_ptr[i] = (DType)1.0f;
+        weight_buf[i] = static_cast<DType>(1.0f);
+        weight_ptr[i] = static_cast<DType>(1.0f);
         weight_buf[channels_ + i] = bias_ptr[i];  // bias
       }
     } else {
       for (int i = 0; i < channels_; i++) {
-        weight_buf[i] = (DType)1.0f;
+        weight_buf[i] = static_cast<DType>(1.0f);
         weight_buf[channels_ + i] = bias_ptr[i];  // bias
       }
     }
 
+    mkldnn_args_map_t net_args;
+    net_args[MKLDNN_ARG_SRC] = *data.GetMKLDNNData();
+    net_args[MKLDNN_ARG_SCALE_SHIFT] = weight_mem;
+    net_args[MKLDNN_ARG_DST] = *out_mem;
+
     if (!ctx.is_train || param.use_global_stats) {
       DType* omean    = out_data[batchnorm::kMean].data().dptr<DType>();
       DType* ovar     = out_data[batchnorm::kVar].data().dptr<DType>();
@@ -270,26 +210,21 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
         omean[i] = inmean[i];
         ovar[i] = VARIANCE_TO_INVSTD(invar[i], param.eps);
       }
-
-      fwd.SetDataHandle(data, aux_states[batchnorm::kMovingMean],
-                        aux_states[batchnorm::kMovingVar],
-                        *out_mem);
-      MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
+      net_args[MKLDNN_ARG_MEAN] = *(aux_states[batchnorm::kMovingMean].GetMKLDNNData());
+      net_args[MKLDNN_ARG_VARIANCE] = *(aux_states[batchnorm::kMovingVar].GetMKLDNNData());
+      MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
       MKLDNNStream::Get()->Submit();
     } else {  // training
       const NDArray &outMean  = out_data[batchnorm::kMean];
       const NDArray &outVar   = out_data[batchnorm::kVar];
-      DType* omean    = outMean.data().dptr<DType>();
-      DType* ovar     = outVar.data().dptr<DType>();
-
-      fwd.SetDataHandle(data, outMean, outVar, *out_mem);
-      MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
+      net_args[MKLDNN_ARG_MEAN] = *(outMean.GetMKLDNNData());
+      net_args[MKLDNN_ARG_VARIANCE] = *(outVar.GetMKLDNNData());
+      MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
       MKLDNNStream::Get()->Submit();
-      DType* mean_mem_ptr = reinterpret_cast<DType*>(fwd.GetMean().get_data_handle());
-      DType* var_mem_ptr  = reinterpret_cast<DType*>(fwd.GetVar().get_data_handle());
+
+      DType* ovar = outVar.data().dptr<DType>();
       for (int i = 0; i < channels_; i++) {
-        omean[i] = mean_mem_ptr[i];
-        ovar[i]  = VARIANCE_TO_INVSTD(var_mem_ptr[i], param.eps);
+        ovar[i] = VARIANCE_TO_INVSTD(ovar[i], param.eps);
       }
     }
   } else {  // no input gamma and beta
@@ -299,11 +234,6 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
 
 class MKLDNNBNBackward {
   std::shared_ptr<mkldnn::batch_normalization_backward> bwd;
-  std::shared_ptr<mkldnn::memory> data_m;
-  std::shared_ptr<mkldnn::memory> diff_m;
-  std::shared_ptr<mkldnn::memory> gradi_m;
-  std::shared_ptr<mkldnn::memory> mean_m;
-  std::shared_ptr<mkldnn::memory> var_m;
   const std::shared_ptr<mkldnn::memory> weight_m;
   const std::shared_ptr<mkldnn::memory> gradw_m;
 
@@ -311,41 +241,16 @@ class MKLDNNBNBackward {
   const t_bn_b_pdesc pd;
 
   explicit MKLDNNBNBackward(const t_bn_b_pdesc &_pd)
-      : weight_m(new mkldnn::memory(_pd.weights_primitive_desc())),
-        gradw_m(new mkldnn::memory(_pd.diff_weights_primitive_desc())),
-        pd(_pd) {}
+      : weight_m(new mkldnn::memory(_pd.weights_desc(), CpuEngine::Get()->get_engine())),
+        gradw_m(new mkldnn::memory(_pd.diff_weights_desc(), CpuEngine::Get()->get_engine())),
+        pd(_pd) {
+    bwd.reset(new mkldnn::batch_normalization_backward(pd));
+  }
 
   const mkldnn::memory &GetWeight() const { return *weight_m; }
 
   const mkldnn::memory &GetGradw() const { return *gradw_m; }
 
-  void SetDataHandle(const mkldnn::memory &data, const mkldnn::memory &diff,
-                     const NDArray &mean, const mkldnn::memory &var,
-                     const mkldnn::memory &gradi) {
-    auto mean_ptr = mean.data().dptr_;
-    if (bwd == nullptr) {
-      data_m.reset(new mkldnn::memory(data.get_primitive_desc(),
-                                      data.get_data_handle()));
-      diff_m.reset(new mkldnn::memory(diff.get_primitive_desc(),
-                                      diff.get_data_handle()));
-      gradi_m.reset(new mkldnn::memory(gradi.get_primitive_desc(),
-                                       gradi.get_data_handle()));
-      mean_m.reset(new mkldnn::memory(pd.mean_primitive_desc(), mean_ptr));
-      var_m.reset(new mkldnn::memory(pd.variance_primitive_desc(),
-                                     var.get_data_handle()));
-      bwd.reset(new mkldnn::batch_normalization_backward(
-          pd, *data_m, mkldnn::primitive::at(*mean_m),
-          mkldnn::primitive::at(*var_m), *diff_m, *weight_m, *gradi_m,
-          *gradw_m));
-    } else {
-      data_m->set_data_handle(data.get_data_handle());
-      diff_m->set_data_handle(diff.get_data_handle());
-      gradi_m->set_data_handle(gradi.get_data_handle());
-      mean_m->set_data_handle(mean_ptr);
-      var_m->set_data_handle(var.get_data_handle());
-    }
-  }
-
   const mkldnn::batch_normalization_backward &GetBwd() const { return *bwd; }
 };
 
@@ -353,7 +258,7 @@ template <typename DType>
 static MKLDNNBNBackward &GetBNBackward(
     const BatchNormParam &param, const OpContext &ctx, const NDArray &in_data,
     const mkldnn::memory &in_mem, const NDArray &diff_data,
-    const mkldnn::memory &diff_mem, unsigned flags) {
+    const mkldnn::memory &diff_mem, mkldnn::normalization_flags flags) {
 #if DMLC_CXX11_THREAD_LOCAL
   static thread_local std::unordered_map<MKLDNNBNSignature, MKLDNNBNBackward, OpHash> bwds;
 #else
@@ -385,7 +290,10 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
   CHECK_EQ(in_data.size(), 3U);
   CHECK_EQ(out_data.size(), 3U);
   CHECK_EQ(in_grad.size(), 3U);
-  unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train && !param.use_global_stats);
+  mkldnn::normalization_flags flags = _GetFlags(in_data,
+                                                aux_states,
+                                                param,
+                                                ctx.is_train && !param.use_global_stats);
 
   const NDArray &data         = in_data[batchnorm::kData];
   const NDArray &diff         = out_grad[batchnorm::kOut];
@@ -405,13 +313,13 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
   // MKLDNN batchnorm should run on special layouts. If one of them isn't, we
   // should reorder them.
   if (data.IsDefaultData())
-    data_mem = data.GetMKLDNNDataReorder(diff_mem->get_primitive_desc());
+    data_mem = data.GetMKLDNNDataReorder(diff_mem->get_desc());
   else if (diff.IsDefaultData())
-    diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_primitive_desc());
+    diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_desc());
   auto &bwd = GetBNBackward<DType>(param, ctx, data, *data_mem, diff, *diff_mem, flags);
-  auto gradi_mem = const_cast<NDArray &>(gradIn).CreateMKLDNNData(data_mem->get_primitive_desc());
+  auto gradi_mem = const_cast<NDArray &>(gradIn).CreateMKLDNNData(data_mem->get_desc());
 
-  if (flags & use_scale_shift) {
+  if (static_cast<int>(flags) & static_cast<int>(mkldnn::normalization_flags::use_scale_shift)) {
     const NDArray &gamma    = in_data[batchnorm::kGamma];
     const NDArray &beta     = in_data[batchnorm::kBeta];
     DType *weight_buf = reinterpret_cast<DType *>(bwd.GetWeight().get_data_handle());
@@ -420,20 +328,27 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
       if (!param.fix_gamma)
         weight_buf[i] = (gamma.data().dptr<DType>())[i];   // weight
       else
-        weight_buf[i] = (DType)1.0f;
+        weight_buf[i] = static_cast<DType>(1.0f);
     }
 
     for (int i = 0; i < channels_; i++) {
       weight_buf[channels_ + i] = (beta.data().dptr<DType>())[i];  // bias
     }
 
+    mkldnn_args_map_t net_args;
+    net_args[MKLDNN_ARG_SRC] = *data_mem;
+    net_args[MKLDNN_ARG_DIFF_SRC] = *gradi_mem;
+    net_args[MKLDNN_ARG_SCALE_SHIFT] = bwd.GetWeight();
+    net_args[MKLDNN_ARG_DIFF_SCALE_SHIFT] = bwd.GetGradw();
+    net_args[MKLDNN_ARG_DIFF_DST] = *diff_mem;
+
     // training but no input mean and variance
     if (ctx.is_train && !param.use_global_stats) {
       DType* moving_mean_ptr  = reinterpret_cast<DType *>(moving_mean.data().dptr<DType>());
       DType* moving_var_ptr   = reinterpret_cast<DType *>(moving_var.data().dptr<DType>());
       DType* out_mean_ptr     = reinterpret_cast<DType *>(out_mean.data().dptr<DType>());
       DType* out_var_ptr      = reinterpret_cast<DType *>(out_var.data().dptr<DType>());
-      mkldnn::memory var_mem(bwd.pd.variance_primitive_desc());
+      mkldnn::memory var_mem(bwd.pd.variance_desc(), CpuEngine::Get()->get_engine());
       DType *tmp_var_ptr = reinterpret_cast<DType *>(var_mem.get_data_handle());
 
       DType minus_mom = (1.0f - param.momentum);
@@ -445,13 +360,14 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
         moving_var_ptr[i] = moving_var_ptr[i] * param.momentum +
                             variance * minus_mom;
       }
-      bwd.SetDataHandle(*data_mem, *diff_mem, out_mean, var_mem, *gradi_mem);
-      MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd());
+      net_args[MKLDNN_ARG_MEAN] = *(out_mean.GetMKLDNNData());
+      net_args[MKLDNN_ARG_VARIANCE] = var_mem;
+      MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
       MKLDNNStream::Get()->Submit();
     } else {
-      bwd.SetDataHandle(*data_mem, *diff_mem, moving_mean,
-                        *moving_var.GetMKLDNNData(), *gradi_mem);
-      MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd());
+      net_args[MKLDNN_ARG_MEAN] =  *(moving_mean.GetMKLDNNData());
+      net_args[MKLDNN_ARG_VARIANCE] = *(moving_var.GetMKLDNNData());
+      MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
       MKLDNNStream::Get()->Submit();
     }