You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/09/13 05:50:36 UTC

[incubator-mxnet] branch master updated: MKLDNN Backward op cache (#11301)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 741635a  MKLDNN Backward op cache (#11301)
741635a is described below

commit 741635abdb51daad7d87e386423dd22fdd1eceda
Author: Zhennan Qin <zh...@intel.com>
AuthorDate: Thu Sep 13 13:50:24 2018 +0800

    MKLDNN Backward op cache (#11301)
    
    * Enable primitive allocation cache for _backward_Activation.
    
    Change-Id: I545628ff68a54cb01b7fef323dc3de9bd47b1a19
    
    * Enable primitive allocation cache for _backward_Deconvolution.
    
    Change-Id: I1e9bf1b9b44bae52068a9c564dff037851e896e5
    
    * Enable primitive allocation cache for _backward_Pooling.
    
    Change-Id: Idbe94e21f1e2ddf711523767194b95beda19b120
    
    * Enable primitive allocation cache for _backward_LRN.
    
    Change-Id: Iefe9f720de719ec2e2f5d24a006602425136711b
    
    * Enable primitive allocation cache for _backward_BatchNorm.
    
    Change-Id: I9e52651bd830b8cb5d2f193076ef51606c9056f9
    
    * Enable primitive allocation cache for _backward_Convolution
    
    Change-Id: I0496fa2394ee036d05c58f3abc1d74af544c7bca
    
    * Enable primitive allocation cache for _backward_Fully_Connected
    
    Change-Id: I8347527ec1271b1518921a74e3581d7d84187429
    
    * remove fc forward and fix indent problem
    
    * remove fc forward and fix convolution indent problem
    
    * Change log level to FATAL for unreachable code in mkldnn_act.cc
    
    * remove fc forward and fix convolution indent problem
    
    * remove useless hint in fc
    
    * Merge branch 'master' into backward_op_cache
    
    * Empty commit to retrigger the CI.
    
    * Change LOG(INFO) to LOG(FATAL) for unreachable code in mkldnn_act.cc
    
    * Fix build issue after code merge.
    
    * Fix lint after merge
    
    * Fix mkldnn act.
---
 src/operator/nn/fully_connected.cc             |   3 +-
 src/operator/nn/lrn.cc                         |   2 +
 src/operator/nn/mkldnn/mkldnn_act.cc           | 125 +++++++++++---
 src/operator/nn/mkldnn/mkldnn_base-inl.h       |   5 +
 src/operator/nn/mkldnn/mkldnn_base.cc          |  11 ++
 src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h | 129 +++++++++-----
 src/operator/nn/mkldnn/mkldnn_convolution.cc   | 205 +++++++++++++++++++---
 src/operator/nn/mkldnn/mkldnn_deconvolution.cc | 225 ++++++++++++++++++++-----
 src/operator/nn/mkldnn/mkldnn_lrn-inl.h        | 180 ++++++++++++++------
 src/operator/nn/mkldnn/mkldnn_pooling-inl.h    |  21 +++
 src/operator/nn/mkldnn/mkldnn_pooling.cc       | 186 +++++++++++++-------
 11 files changed, 844 insertions(+), 248 deletions(-)

diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc
index d8a32f0..a178b27 100644
--- a/src/operator/nn/fully_connected.cc
+++ b/src/operator/nn/fully_connected.cc
@@ -213,10 +213,9 @@ inline static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs,
   uint32_t out_expected = param.no_bias ? 2 : 3;
   CHECK_EQ(in_attrs->size(), 3U);
   CHECK_EQ(out_attrs->size(), out_expected);
-
-  bool dispatched = false;
   // TODO(zhengda) let's disable MKLDNN for FullyConnected for now.
   // It seems there is a bug.
+  bool dispatched = false;
   if (!dispatched && common::ContainsOnlyStorage(*in_attrs, mxnet::kDefaultStorage)) {
     dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
                                      dispatch_mode, DispatchMode::kFCompute);
diff --git a/src/operator/nn/lrn.cc b/src/operator/nn/lrn.cc
index a428eb1..99f0dc4 100644
--- a/src/operator/nn/lrn.cc
+++ b/src/operator/nn/lrn.cc
@@ -116,6 +116,8 @@ void LRNComputeExCPU(const nnvm::NodeAttrs &attrs,
     MKLDNN_OPCHECK_INIT(false, 1, inputs, outputs);
     MKLDNNLRNForward(ctx, param, inputs[0], req[0], outputs[0]);
     MKLDNN_OPCHECK_RUN(LRNCompute<cpu>, attrs, ctx, inputs, req, outputs);
+    // Copy outputs[1] from opcheck reference as backward check needs it.
+    MKLDNN_OPCHECK_COPY_RESULT(outputs, std::vector<size_t>{1});
     return;
   }
   FallBackCompute(LRNCompute<cpu>, attrs, ctx, inputs, req, outputs);
diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc
index 744fed2..c914b38 100644
--- a/src/operator/nn/mkldnn/mkldnn_act.cc
+++ b/src/operator/nn/mkldnn/mkldnn_act.cc
@@ -84,7 +84,7 @@ static mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl(
                                         alg, data_md, alpha);
     return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine);
   });
-  LOG(INFO) << "Unsupported data type for MKLDNN activation";
+  LOG(FATAL) << "Unsupported data type for MKLDNN activation";
   mkldnn::eltwise_forward::desc desc = mkldnn::eltwise_forward::desc(
       mkldnn::prop_kind::forward_training, alg, data_md, 0.0);
   return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine);
@@ -175,6 +175,100 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   stream->Submit();
 }
 
+static mkldnn::eltwise_backward::primitive_desc GetActBwdDescImpl(
+    const ActivationParam &param, const mkldnn::memory &input_mem,
+    const mkldnn::memory &diff_dst_memory, int dtype) {
+  mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc();
+  mkldnn::memory::desc data_md = data_mpd.desc();
+  mkldnn::memory::desc diff_md = diff_dst_memory.get_primitive_desc().desc();
+  auto cpu_engine = data_mpd.get_engine();
+  auto alg = GetMKLDNNActAlgo(param);
+
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    DType alpha = 0;
+    mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
+                                          alg, data_md, alpha);
+    mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine);
+    mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, alpha);
+    mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine,
+                                                      fw_pdesc);
+    return bw_pdesc;
+  });
+  LOG(FATAL) << "Unsupported data type for MKLDNN activation";
+  mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
+                                        alg, data_md, 0.0);
+  mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine);
+  mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, 0.0);
+  mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine,
+                                                    fw_pdesc);
+  return bw_pdesc;
+}
+
+class MKLDNNActBackward {
+  std::shared_ptr<mkldnn::eltwise_backward> bwd;
+  std::shared_ptr<mkldnn::memory> data;
+  std::shared_ptr<mkldnn::memory> diff_dst_memory;
+  std::shared_ptr<mkldnn::memory> diff_src_memory;
+
+ public:
+  const mkldnn::eltwise_backward::primitive_desc pd;
+
+  explicit MKLDNNActBackward(const ActivationParam &param, const NDArray &data,
+                             const mkldnn::memory &mem,
+                             const mkldnn::memory &diff_dst_memory)
+      : pd(GetActBwdDescImpl(param, mem, diff_dst_memory, data.dtype())) {}
+
+  void SetNewMem(const mkldnn::memory &data,
+                 const mkldnn::memory &diff_dst_memory,
+                 const mkldnn::memory &diff_src_memory) {
+    if (this->bwd != nullptr) {
+      this->data->set_data_handle(data.get_data_handle());
+      this->diff_dst_memory->set_data_handle(diff_dst_memory.get_data_handle());
+      this->diff_src_memory->set_data_handle(diff_src_memory.get_data_handle());
+    } else {
+      this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          data.get_primitive_desc(), data.get_data_handle()));
+      this->diff_dst_memory = std::shared_ptr<mkldnn::memory>(
+          new mkldnn::memory(diff_dst_memory.get_primitive_desc(),
+                             diff_dst_memory.get_data_handle()));
+      this->diff_src_memory = std::shared_ptr<mkldnn::memory>(
+          new mkldnn::memory(diff_src_memory.get_primitive_desc(),
+                             diff_src_memory.get_data_handle()));
+      this->bwd = std::shared_ptr<mkldnn::eltwise_backward>(
+          new mkldnn::eltwise_backward(
+              this->pd, mkldnn::primitive::at(*this->data),
+              *this->diff_dst_memory, *this->diff_src_memory));
+    }
+  }
+
+  const inline mkldnn::eltwise_backward &GetBwd() const { return *bwd; }
+};
+
+static inline MKLDNNActBackward &GetActBackward(const ActivationParam &param,
+                                                const OpContext &ctx,
+                                                const NDArray &in_data,
+                                                const NDArray &out_grad,
+                                                const mkldnn::memory &in_mem) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<MKLDNNActSignature, MKLDNNActBackward, OpHash> bwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<MKLDNNActSignature, MKLDNNActBackward, OpHash> bwds;
+#endif
+  MKLDNNActSignature key(param);
+  key.AddSign(in_data);
+  key.AddSign(out_grad);
+
+  auto it = bwds.find(key);
+  if (it == bwds.end()) {
+    MKLDNNActBackward bwd(param, in_data, in_mem, *out_grad.GetMKLDNNData());
+    auto ins_ret =
+        bwds.insert(std::pair<MKLDNNActSignature, MKLDNNActBackward>(key, bwd));
+    CHECK(ins_ret.second);
+    it = ins_ret.first;
+  }
+  return it->second;
+}
+
 // For backward relu activation, it's okay to pass "out_data" as "in_data" to this
 // function, since the computation only involes non-zeros.
 void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
@@ -200,30 +294,13 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx
   // descriptor. Otherwise, the perf will suffer.
   if (input_mem->get_primitive_desc() != diff_dst_memory->get_primitive_desc())
     input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_primitive_desc());
-  mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc();
-  mkldnn::memory::desc data_md = data_mpd.desc();
-  mkldnn::memory::desc diff_md = diff_dst_memory->get_primitive_desc().desc();
-  auto cpu_engine = data_mpd.get_engine();
-
+  MKLDNNActBackward &bwd =
+      GetActBackward(param, ctx, in_buffer, out_buffer, *input_mem);
   MKLDNNStream *stream = MKLDNNStream::Get();
-  auto alg = GetMKLDNNActAlgo(param);
-  mkldnn_output_t diff_src_memory;
-
-  MSHADOW_REAL_TYPE_SWITCH(in_buffer.dtype(), DType, {
-    DType alpha = 0;
-    mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
-                                          alg, data_md, alpha);
-    mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine);
-    mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, alpha);
-    mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine,
-                                                      fw_pdesc);
-
-    diff_src_memory = CreateMKLDNNMem(in_grad,
-                                      bw_pdesc.diff_src_primitive_desc(), req);
-    stream->RegisterPrim(mkldnn::eltwise_backward(bw_pdesc, *input_mem,
-                                                  *diff_dst_memory,
-                                                  *diff_src_memory.second));
-  });
+  mkldnn_output_t diff_src_memory =
+      CreateMKLDNNMem(in_grad, bwd.pd.diff_src_primitive_desc(), req);
+  bwd.SetNewMem(*input_mem, *diff_dst_memory, *diff_src_memory.second);
+  stream->RegisterPrim(bwd.GetBwd());
   CommitOutput(in_grad, diff_src_memory);
   stream->Submit();
 }
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 6eb90f8..c4a4e52 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -509,6 +509,9 @@ class OpCheck {
            const std::vector<mxnet::NDArray> &inputs_,
            const std::vector<mxnet::OpReqType> &req,
            const std::vector<mxnet::NDArray> &outputs_);
+
+  void CopyResult(const std::vector<mxnet::NDArray> &outputs_,
+                  const std::vector<size_t>& indice);
 };
 
 bool MKLDNNStorageType(const nnvm::NodeAttrs &attrs,
@@ -525,6 +528,8 @@ bool MKLDNNStorageType(const nnvm::NodeAttrs &attrs,
 
 #define MKLDNN_OPCHECK_RUN(fn, attrs, ctx, inputs, req, outputs)    \
     if (debug) check.Run(fn, attrs, ctx, inputs, req, outputs);
+#define MKLDNN_OPCHECK_COPY_RESULT(outputs, indice) \
+    if (debug) check.CopyResult(outputs, indice);
 
 }  // namespace mxnet
 #endif
diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc
index f3facd9..029f23b 100644
--- a/src/operator/nn/mkldnn/mkldnn_base.cc
+++ b/src/operator/nn/mkldnn/mkldnn_base.cc
@@ -525,6 +525,17 @@ void OpCheck::Run(mxnet::FCompute fn, const nnvm::NodeAttrs &attrs,
   }
 }
 
+void OpCheck::CopyResult(const std::vector<mxnet::NDArray> &outputs_,
+                         const std::vector<size_t> &indice) {
+  CHECK(!MKLDNNStream::Get()->HasOps());
+  auto non_const_outputs_ = const_cast<std::vector<mxnet::NDArray> &>(outputs_);
+  for (auto i = indice.begin(); i != indice.end(); ++i) {
+    auto mem = outputs[*i].GetMKLDNNData();
+    non_const_outputs_[*i].CopyFrom(*mem);
+  }
+  MKLDNNStream::Get()->Submit();
+}
+
 bool MKLDNNStorageType(const nnvm::NodeAttrs &attrs,
                        const int dev_mask,
                        bool support_mkldnn,
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 9046836..496ff99 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -290,6 +290,84 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
   }
 }
 
+class MKLDNNBNBackward {
+  std::shared_ptr<mkldnn::batch_normalization_backward> bwd;
+  std::shared_ptr<mkldnn::memory> data_m;
+  std::shared_ptr<mkldnn::memory> diff_m;
+  std::shared_ptr<mkldnn::memory> gradi_m;
+  std::shared_ptr<mkldnn::memory> mean_m;
+  std::shared_ptr<mkldnn::memory> var_m;
+  const std::shared_ptr<mkldnn::memory> weight_m;
+  const std::shared_ptr<mkldnn::memory> gradw_m;
+
+ public:
+  const t_bn_b_pdesc pd;
+
+  explicit MKLDNNBNBackward(const t_bn_b_pdesc &_pd)
+      : weight_m(new mkldnn::memory(_pd.weights_primitive_desc())),
+        gradw_m(new mkldnn::memory(_pd.diff_weights_primitive_desc())),
+        pd(_pd) {}
+
+  const mkldnn::memory &GetWeight() const { return *weight_m; }
+
+  const mkldnn::memory &GetGradw() const { return *gradw_m; }
+
+  void SetDataHandle(const mkldnn::memory &data, const mkldnn::memory &diff,
+                     const NDArray &mean, const mkldnn::memory &var,
+                     const mkldnn::memory &gradi) {
+    auto mean_ptr = mean.data().dptr_;
+    if (bwd == nullptr) {
+      data_m.reset(new mkldnn::memory(data.get_primitive_desc(),
+                                      data.get_data_handle()));
+      diff_m.reset(new mkldnn::memory(diff.get_primitive_desc(),
+                                      diff.get_data_handle()));
+      gradi_m.reset(new mkldnn::memory(gradi.get_primitive_desc(),
+                                       gradi.get_data_handle()));
+      mean_m.reset(new mkldnn::memory(pd.mean_primitive_desc(), mean_ptr));
+      var_m.reset(new mkldnn::memory(pd.variance_primitive_desc(),
+                                     var.get_data_handle()));
+      bwd.reset(new mkldnn::batch_normalization_backward(
+          pd, *data_m, mkldnn::primitive::at(*mean_m),
+          mkldnn::primitive::at(*var_m), *diff_m, *weight_m, *gradi_m,
+          *gradw_m));
+    } else {
+      data_m->set_data_handle(data.get_data_handle());
+      diff_m->set_data_handle(diff.get_data_handle());
+      gradi_m->set_data_handle(gradi.get_data_handle());
+      mean_m->set_data_handle(mean_ptr);
+      var_m->set_data_handle(var.get_data_handle());
+    }
+  }
+
+  const mkldnn::batch_normalization_backward &GetBwd() const { return *bwd; }
+};
+
+template <typename DType>
+static MKLDNNBNBackward &GetBNBackward(
+    const BatchNormParam &param, const OpContext &ctx, const NDArray &in_data,
+    const mkldnn::memory &in_mem, const NDArray &diff_data,
+    const mkldnn::memory &diff_mem, unsigned flags) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<MKLDNNBNSignature, MKLDNNBNBackward, OpHash> bwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<MKLDNNBNSignature, MKLDNNBNBackward, OpHash> bwds;
+#endif
+  MKLDNNBNSignature key(param);
+  key.AddSign(in_data);
+  key.AddSign(diff_data);
+
+  auto it = bwds.find(key);
+  if (it == bwds.end()) {
+    auto bwd_pd = _GetBwd(in_mem, diff_mem, param.eps, flags);
+    MKLDNNBNBackward bwd(bwd_pd);
+    auto ins_ret =
+        bwds.insert(std::pair<MKLDNNBNSignature, MKLDNNBNBackward>(key, bwd));
+    CHECK(ins_ret.second);
+    it = ins_ret.first;
+  }
+  return it->second;
+}
+
 template <typename DType>
 void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
                              const std::vector<NDArray>    &out_grad,
@@ -326,17 +404,13 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
     data_mem = data.GetMKLDNNDataReorder(diff_mem->get_primitive_desc());
   else if (diff.IsDefaultData())
     diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_primitive_desc());
-  auto bwd_pd = _GetBwd(*data_mem, *diff_mem, param.eps, flags);
+  auto &bwd = GetBNBackward<DType>(param, ctx, data, *data_mem, diff, *diff_mem, flags);
   auto gradi_mem = const_cast<NDArray &>(gradIn).CreateMKLDNNData(data_mem->get_primitive_desc());
 
   if (flags & use_scale_shift) {
     const NDArray &gamma    = in_data[batchnorm::kGamma];
     const NDArray &beta     = in_data[batchnorm::kBeta];
-    // TODO(tao): how to reuse this memory?
-    std::shared_ptr<const mkldnn::memory> weight_mem(
-                    new mkldnn::memory(bwd_pd.weights_primitive_desc()));
-
-    DType* weight_buf = reinterpret_cast<DType *>(weight_mem->get_data_handle());
+    DType *weight_buf = reinterpret_cast<DType *>(bwd.GetWeight().get_data_handle());
     nnvm::dim_t channels_ = data.shape()[1];
     for (int i = 0; i < channels_; i++) {
       if (!param.fix_gamma)
@@ -349,15 +423,13 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
       weight_buf[channels_ + i] = (beta.data().dptr<DType>())[i];  // bias
     }
 
-    std::shared_ptr<const mkldnn::memory> gradw_mem(
-                    new mkldnn::memory(bwd_pd.diff_weights_primitive_desc()));
     // training but no input mean and variance
     if (ctx.is_train && !param.use_global_stats) {
       DType* moving_mean_ptr  = reinterpret_cast<DType *>(moving_mean.data().dptr<DType>());
       DType* moving_var_ptr   = reinterpret_cast<DType *>(moving_var.data().dptr<DType>());
       DType* out_mean_ptr     = reinterpret_cast<DType *>(out_mean.data().dptr<DType>());
       DType* out_var_ptr      = reinterpret_cast<DType *>(out_var.data().dptr<DType>());
-      mkldnn::memory var_mem(bwd_pd.variance_primitive_desc());
+      mkldnn::memory var_mem(bwd.pd.variance_primitive_desc());
       DType *tmp_var_ptr = reinterpret_cast<DType *>(var_mem.get_data_handle());
 
       DType minus_mom = (1.0f - param.momentum);
@@ -369,45 +441,18 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
         moving_var_ptr[i] = moving_var_ptr[i] * param.momentum +
                             variance * minus_mom;
       }
-
-      std::shared_ptr<const mkldnn::memory> out_mean_mem(
-                      new mkldnn::memory(bwd_pd.mean_primitive_desc(), out_mean_ptr));
-      std::shared_ptr<const mkldnn::memory> out_var_mem(
-                      new mkldnn::memory(bwd_pd.variance_primitive_desc(), out_var_ptr));
-
-      auto bn_bwd = mkldnn::batch_normalization_backward(bwd_pd,
-                                                         *data_mem,
-                                                         mkldnn::primitive::at(*out_mean_mem),
-                                                         mkldnn::primitive::at(var_mem),
-                                                         *diff_mem,
-                                                         *weight_mem,
-                                                         *gradi_mem,
-                                                         *gradw_mem);
-
-      MKLDNNStream::Get()->RegisterPrim(bn_bwd);
+      bwd.SetDataHandle(*data_mem, *diff_mem, out_mean, var_mem, *gradi_mem);
+      MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd());
       MKLDNNStream::Get()->Submit();
     } else {
-      std::shared_ptr<const mkldnn::memory> imean_mem(
-                      new mkldnn::memory(bwd_pd.mean_primitive_desc(),
-                      moving_mean.data().dptr<DType>()));
-      std::shared_ptr<const mkldnn::memory> ivar_mem(
-                      new mkldnn::memory(bwd_pd.variance_primitive_desc(),
-                      moving_var.data().dptr<DType>()));
-      auto bn_bwd = mkldnn::batch_normalization_backward(bwd_pd,
-                                                         *data_mem,
-                                                         mkldnn::primitive::at(*imean_mem),
-                                                         mkldnn::primitive::at(*ivar_mem),
-                                                         *diff_mem,
-                                                         *weight_mem,
-                                                         *gradi_mem,
-                                                         *gradw_mem);
-
-      MKLDNNStream::Get()->RegisterPrim(bn_bwd);
+      bwd.SetDataHandle(*data_mem, *diff_mem, moving_mean,
+                        *moving_var.GetMKLDNNData(), *gradi_mem);
+      MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd());
       MKLDNNStream::Get()->Submit();
     }
 
     // copy data from gradw_mem to in_grad[1] and in_grad[2]
-    DType* gw_buf = reinterpret_cast<DType *>(gradw_mem->get_data_handle());
+    DType *gw_buf = reinterpret_cast<DType *>(bwd.GetGradw().get_data_handle());
     for (int i = 0; i < channels_; i++) {
       if (!param.fix_gamma)
         (in_grad[1].data().dptr<DType>())[i] = gw_buf[i];
diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc
index cf04ea8..2e19d32 100644
--- a/src/operator/nn/mkldnn/mkldnn_convolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -283,6 +283,157 @@ void MKLDNNConvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx
   MKLDNNStream::Get()->Submit();
 }
 
+class MKLDNNConvBackward {
+  std::shared_ptr<mkldnn::convolution_backward_data> bwd_data;
+  std::shared_ptr<mkldnn::convolution_backward_weights> bwd_weight;
+  // conv::kData
+  std::shared_ptr<mkldnn::memory> out_grad;
+  std::shared_ptr<mkldnn::memory> in_grad;
+  std::shared_ptr<mkldnn::memory> weight;
+  // conv::kWeight
+  std::shared_ptr<mkldnn::memory> data;
+  std::shared_ptr<mkldnn::memory> output;
+  std::shared_ptr<mkldnn::memory> in_grad_weight;
+  std::shared_ptr<mkldnn::memory> in_grad_bias;
+
+ public:
+  mkldnn::convolution_backward_data::primitive_desc bwdData_pd;
+  mkldnn::convolution_backward_weights::primitive_desc bwdWeights_pd;
+
+  MKLDNNConvBackward(
+      const ConvolutionParam &param, const NDArray &data,
+      const NDArray &weights, const NDArray *bias, const NDArray &output,
+      const mkldnn::convolution_forward::primitive_desc &fwd_pd):
+      bwdData_pd(GetConvBwdData(param, data, weights, output, fwd_pd)),
+      bwdWeights_pd(GetConvBwdWeights(param, data, weights, bias, output, fwd_pd)) {
+  }
+
+  void SetDataNewMem(const mkldnn::memory &out_grad, const mkldnn::memory &weight,
+                     const mkldnn::memory &in_grad) {
+    if (this->out_grad == nullptr)
+      this->out_grad = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+        bwdData_pd.diff_dst_primitive_desc(), out_grad.get_data_handle()));
+    else
+      this->out_grad->set_data_handle(out_grad.get_data_handle());
+    if (this->in_grad == nullptr)
+      this->in_grad = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+        bwdData_pd.diff_src_primitive_desc(), in_grad.get_data_handle()));
+    else
+      this->in_grad->set_data_handle(in_grad.get_data_handle());
+    if (this->weight == nullptr)
+      this->weight = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+         bwdData_pd.weights_primitive_desc(), weight.get_data_handle()));
+    else
+      this->weight->set_data_handle(weight.get_data_handle());
+    if (this->bwd_data == nullptr)
+      this->bwd_data = std::shared_ptr<mkldnn::convolution_backward_data>(
+        new mkldnn::convolution_backward_data(
+          this->bwdData_pd, mkldnn::primitive::at(*this->out_grad),
+          mkldnn::primitive::at(*this->weight), *this->in_grad));
+  }
+
+void SetWeightNewMem(const mkldnn::memory &data,
+                     const mkldnn::memory &out_grad,
+                     const mkldnn::memory &in_grad_weight) {
+    if (this->data == nullptr)
+      this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          bwdWeights_pd.src_primitive_desc(), data.get_data_handle()));
+    else
+      this->data->set_data_handle(data.get_data_handle());
+    if (this->output == nullptr)
+      this->output = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          bwdWeights_pd.diff_dst_primitive_desc(), out_grad.get_data_handle()));
+    else
+      this->output->set_data_handle(out_grad.get_data_handle());
+    if (this->in_grad_weight == nullptr)
+      this->in_grad_weight = std::shared_ptr<mkldnn::memory>(
+          new mkldnn::memory(bwdWeights_pd.diff_weights_primitive_desc(),
+                             in_grad_weight.get_data_handle()));
+    else
+      this->in_grad_weight->set_data_handle(in_grad_weight.get_data_handle());
+
+    if (this->bwd_weight == nullptr)
+      this->bwd_weight = std::shared_ptr<mkldnn::convolution_backward_weights>(
+          new mkldnn::convolution_backward_weights(
+              this->bwdWeights_pd, mkldnn::primitive::at(*this->data),
+              mkldnn::primitive::at(*this->output), *this->in_grad_weight));
+  }
+
+  void SetWeightNewMem(const mkldnn::memory &data,
+                       const mkldnn::memory &out_grad,
+                       const mkldnn::memory &in_grad_weight,
+                       const mkldnn::memory &in_grad_bias) {
+    if (this->data == nullptr)
+      this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          bwdWeights_pd.src_primitive_desc(), data.get_data_handle()));
+    else
+      this->data->set_data_handle(data.get_data_handle());
+    if (this->output == nullptr)
+      this->output = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          bwdWeights_pd.diff_dst_primitive_desc(), out_grad.get_data_handle()));
+    else
+      this->output->set_data_handle(out_grad.get_data_handle());
+    if (this->in_grad_weight == nullptr)
+      this->in_grad_weight = std::shared_ptr<mkldnn::memory>(
+          new mkldnn::memory(bwdWeights_pd.diff_weights_primitive_desc(),
+                             in_grad_weight.get_data_handle()));
+    else
+      this->in_grad_weight->set_data_handle(in_grad_weight.get_data_handle());
+
+    if (this->in_grad_bias == nullptr)
+      this->in_grad_bias = std::shared_ptr<mkldnn::memory>(
+          new mkldnn::memory(bwdWeights_pd.diff_bias_primitive_desc(),
+                             in_grad_bias.get_data_handle()));
+    else
+      this->in_grad_bias->set_data_handle(in_grad_bias.get_data_handle());
+    if (this->bwd_weight == nullptr)
+      this->bwd_weight = std::shared_ptr<mkldnn::convolution_backward_weights>(
+          new mkldnn::convolution_backward_weights(
+              this->bwdWeights_pd, mkldnn::primitive::at(*this->data),
+              mkldnn::primitive::at(*this->output), *this->in_grad_weight,
+              *this->in_grad_bias));
+  }
+
+  const mkldnn::convolution_backward_data &GetBwdData() const {
+    return *bwd_data;
+  }
+
+  const mkldnn::convolution_backward_weights &GetBwdWeights() const {
+    return *bwd_weight;
+  }
+};
+
+static inline MKLDNNConvBackward &GetConvBwd(
+    const nnvm::NodeAttrs &attrs, const NDArray &data, const NDArray &weights,
+    const NDArray *bias, const NDArray &output,
+    const mkldnn::convolution_forward::primitive_desc &fwd_pd) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<MKLDNNConvSignature, MKLDNNConvBackward, OpHash> bwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<MKLDNNConvSignature, MKLDNNConvBackward, OpHash> bwds;
+#endif
+  const ConvolutionParam& param = nnvm::get<ConvolutionParam>(attrs.parsed);
+  MKLDNNConvSignature key(param);
+  // Here we can sign the conv op with NDArray because conv primitive will
+  // decide the right layout for them, so we only need to get the shape and the
+  // data type of the arrays.
+  key.AddSign(data);
+  key.AddSign(weights);
+  key.AddSign(output);
+  if (bias)
+    key.AddSign(*bias);
+
+  auto it = bwds.find(key);
+  if (it == bwds.end()) {
+    MKLDNNConvBackward bwd(param, data, weights, bias, output, fwd_pd);
+    auto ins_ret = bwds.insert(
+        std::pair<MKLDNNConvSignature, MKLDNNConvBackward>(key, bwd));
+    CHECK(ins_ret.second);
+    it = ins_ret.first;
+  }
+  return it->second;
+}
+
 void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                                const std::vector<NDArray>& inputs,
                                const std::vector<OpReqType>& req,
@@ -295,44 +446,45 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct
       param.no_bias ? nullptr : &inputs[conv::kBias + 1], inputs[conv::kOut]);
 
   CHECK_NE(req[conv::kWeight], kWriteInplace) << "cannot write weight inplace";
-  mkldnn::convolution_backward_data::primitive_desc bwdData_pd
-    = GetConvBwdData(param, inputs[conv::kData + 1], inputs[conv::kWeight + 1],
-        inputs[conv::kOut], fwd_pd);
+  MKLDNNConvBackward &convBwd = GetConvBwd(attrs, inputs[conv::kData + 1],
+             inputs[conv::kWeight + 1], nullptr, inputs[conv::kOut], fwd_pd);
   auto out_grad_mem = inputs[conv::kOut].GetMKLDNNDataReorder(
-      bwdData_pd.diff_dst_primitive_desc());
+      convBwd.bwdData_pd.diff_dst_primitive_desc());
   if (req[conv::kData]) {
     auto weight_mem = GetWeights(inputs[conv::kWeight + 1],
-        bwdData_pd.weights_primitive_desc(), param.num_group);
+        convBwd.bwdData_pd.weights_primitive_desc(), param.num_group);
     auto in_grad_mem = CreateMKLDNNMem(in_grad[conv::kData],
-        bwdData_pd.diff_src_primitive_desc(), req[conv::kData]);
-    MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_data(bwdData_pd,
-          *out_grad_mem, *weight_mem, *in_grad_mem.second));
+        convBwd.bwdData_pd.diff_src_primitive_desc(), req[conv::kData]);
+    convBwd.SetDataNewMem(*out_grad_mem, *weight_mem, *in_grad_mem.second);
+    MKLDNNStream::Get()->RegisterPrim(convBwd.GetBwdData());
     CommitOutput(in_grad[conv::kData], in_grad_mem);
   }
   if (req[conv::kWeight]) {
-    mkldnn::convolution_backward_weights::primitive_desc bwdWeights_pd
-        = GetConvBwdWeights(param, inputs[conv::kData + 1], inputs[conv::kWeight + 1],
-                            param.no_bias ? nullptr : &inputs[conv::kBias + 1],
-                            inputs[conv::kOut], fwd_pd);
-    if (bwdData_pd.diff_dst_primitive_desc() != bwdWeights_pd.diff_dst_primitive_desc())
+    MKLDNNConvBackward &convBwdWeight = GetConvBwd(attrs, inputs[conv::kData + 1],
+             inputs[conv::kWeight + 1], param.no_bias ? nullptr : &inputs[conv::kBias + 1],
+             inputs[conv::kOut], fwd_pd);
+    if (convBwdWeight.bwdData_pd.diff_dst_primitive_desc() !=
+        convBwdWeight.bwdWeights_pd.diff_dst_primitive_desc())
       out_grad_mem = inputs[conv::kOut].GetMKLDNNDataReorder(
-          bwdWeights_pd.diff_dst_primitive_desc());
+          convBwdWeight.bwdWeights_pd.diff_dst_primitive_desc());
     auto data_mem = inputs[conv::kData + 1].GetMKLDNNDataReorder(
-        bwdWeights_pd.src_primitive_desc());
-    auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[conv::kWeight],
-                                                 bwdWeights_pd.diff_weights_primitive_desc(),
-                                                 req[conv::kWeight]);
+        convBwdWeight.bwdWeights_pd.src_primitive_desc());
+    auto in_grad_weight = CreateMKLDNNWeightGrad(
+        in_grad[conv::kWeight],
+        convBwdWeight.bwdWeights_pd.diff_weights_primitive_desc(),
+        req[conv::kWeight]);
     mkldnn_output_t in_grad_bias;
     if (param.no_bias) {
-      MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_weights(
-              bwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second));
+      convBwdWeight.SetWeightNewMem(*data_mem, *out_grad_mem,
+                              *in_grad_weight.second);
+      MKLDNNStream::Get()->RegisterPrim(convBwdWeight.GetBwdWeights());
     } else {
-      in_grad_bias = CreateMKLDNNMem(in_grad[conv::kBias],
-                                     bwdWeights_pd.diff_bias_primitive_desc(),
-                                     req[conv::kBias]);
-      MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_weights(
-              bwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second,
-              *in_grad_bias.second));
+      in_grad_bias = CreateMKLDNNMem(
+          in_grad[conv::kBias],
+          convBwdWeight.bwdWeights_pd.diff_bias_primitive_desc(), req[conv::kBias]);
+      convBwdWeight.SetWeightNewMem(*data_mem, *out_grad_mem,
+                              *in_grad_weight.second, *in_grad_bias.second);
+      MKLDNNStream::Get()->RegisterPrim(convBwdWeight.GetBwdWeights());
       CommitOutput(in_grad[conv::kBias], in_grad_bias);
     }
     CommitOutput(in_grad[conv::kWeight], in_grad_weight);
@@ -342,5 +494,4 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct
 
 }  // namespace op
 }  // namespace mxnet
-
 #endif  // MXNET_USE_MKLDNN == 1
diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
index 7f3676a..54d4f67 100644
--- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
@@ -93,9 +93,9 @@ static mkldnn::convolution_backward_data::primitive_desc GetDeconvFwdImpl(
   return mkldnn::convolution_backward_data::primitive_desc(desc, engine, bwd_pd);
 }
 
-static mkldnn::convolution_forward::primitive_desc GetDeconvBwdData(
-    const DeconvolutionParam &param, const NDArray &data, const NDArray &weights,
-    bool has_bias, const NDArray &output) {
+static mkldnn::convolution_forward::primitive_desc GetDeconvBwdDataImpl(
+    const DeconvolutionParam &param, const NDArray &data,
+    const NDArray &weights, bool has_bias, const NDArray &output) {
   auto data_md = GetMemDesc(data);
   auto weight_md = GetWeightDesc(weights, param.num_group);
   auto out_md = GetMemDesc(output);
@@ -116,9 +116,10 @@ static mkldnn::convolution_forward::primitive_desc GetDeconvBwdData(
       strides, padding, dilate);
 }
 
-static mkldnn::convolution_backward_weights::primitive_desc GetDeconvBwdWeights(
-    const DeconvolutionParam& param, const NDArray &data, const NDArray &weights,
-    bool has_bias, const NDArray &output,
+static mkldnn::convolution_backward_weights::primitive_desc
+GetDeconvBwdWeightsImpl(
+    const DeconvolutionParam &param, const NDArray &data,
+    const NDArray &weights, bool has_bias, const NDArray &output,
     const mkldnn::convolution_forward::primitive_desc &fwd_pd) {
   auto data_md = GetMemDesc(data);
   auto weight_md = GetWeightDesc(weights, param.num_group);
@@ -308,55 +309,203 @@ void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &c
   MKLDNNDeconvFwdBiasPostProcess(param, ctx, in_data, out_data);
 }
 
-void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
-                                 const std::vector<NDArray>& inputs,
-                                 const std::vector<OpReqType>& req,
-                                 const std::vector<NDArray>& outputs) {
+class MKLDNNDeconvBackwardData {
+  std::shared_ptr<mkldnn::convolution_forward> bwd;
+  std::shared_ptr<mkldnn::memory> data;
+  std::shared_ptr<mkldnn::memory> weight;
+  std::shared_ptr<mkldnn::memory> out;
+
+ public:
+  const mkldnn::convolution_forward::primitive_desc pd;
+
+  MKLDNNDeconvBackwardData(const DeconvolutionParam &param, const NDArray &data,
+                           const NDArray &weights, const NDArray &output)
+      : pd(GetDeconvBwdDataImpl(param, data, weights, false, output)) {
+  }
+
+  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
+                 const mkldnn::memory &output) {
+    if (bwd == nullptr) {
+    this->data = std::shared_ptr<mkldnn::memory>(
+        new mkldnn::memory(pd.src_primitive_desc(), data.get_data_handle()));
+    this->weight = std::shared_ptr<mkldnn::memory>(
+        new mkldnn::memory(pd.weights_primitive_desc(), weight.get_data_handle()));
+    this->out = std::shared_ptr<mkldnn::memory>(
+        new mkldnn::memory(pd.dst_primitive_desc(), output.get_data_handle()));
+    bwd = std::shared_ptr<mkldnn::convolution_forward>(
+        new mkldnn::convolution_forward(pd, mkldnn::primitive::at(*this->data),
+                                        mkldnn::primitive::at(*this->weight),
+                                        *this->out));
+    } else {
+      this->data->set_data_handle(data.get_data_handle());
+      this->weight->set_data_handle(weight.get_data_handle());
+      this->out->set_data_handle(output.get_data_handle());
+    }
+  }
+
+  const mkldnn::convolution_forward &GetBwd() const { return *bwd; }
+};
+
+typedef ParamOpSign<DeconvolutionParam> MKLDNNDeconvSignature;
+
+static inline MKLDNNDeconvBackwardData &GetDeconvBwdData(
+    const DeconvolutionParam &param, const NDArray &data,
+    const NDArray &weights, const NDArray &output) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<MKLDNNDeconvSignature,
+                                         MKLDNNDeconvBackwardData, OpHash>
+      bwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<MKLDNNDeconvSignature,
+                                            MKLDNNDeconvBackwardData, OpHash>
+      bwds;
+#endif
+  MKLDNNDeconvSignature key(param);
+  // Here we can sign the conv op with NDArray because conv primitive will
+  // decide the right layout for them, so we only need to get the shape and the
+  // data type of the arrays.
+  key.AddSign(data);
+  key.AddSign(weights);
+  key.AddSign(output);
+
+  auto it = bwds.find(key);
+  if (it == bwds.end()) {
+    MKLDNNDeconvBackwardData bwd(param, data, weights, output);
+    auto ins_ret = bwds.insert(
+        std::pair<MKLDNNDeconvSignature, MKLDNNDeconvBackwardData>(key, bwd));
+    CHECK(ins_ret.second);
+    it = ins_ret.first;
+  }
+  return it->second;
+}
+
+class MKLDNNDeconvBackwardWeights {
+  std::shared_ptr<mkldnn::convolution_backward_weights> bwd;
+  std::shared_ptr<mkldnn::memory> data;
+  std::shared_ptr<mkldnn::memory> weight;
+  std::shared_ptr<mkldnn::memory> out;
+
+ public:
+  const mkldnn::convolution_backward_weights::primitive_desc pd;
+
+  MKLDNNDeconvBackwardWeights(
+      const DeconvolutionParam &param, const NDArray &data,
+      const NDArray &weights, const NDArray &output,
+      const mkldnn::convolution_forward::primitive_desc &bwd_data_pd)
+      : pd(GetDeconvBwdWeightsImpl(param, data, weights, false, output,
+                                   bwd_data_pd)) {}
+
+  void SetNewMem(
+      const mkldnn::memory &data, const mkldnn::memory &weight,
+      const mkldnn::memory &output,
+      const mkldnn::convolution_forward::primitive_desc &bwd_data_pd) {
+    if (bwd == nullptr) {
+      this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          bwd_data_pd.src_primitive_desc(), data.get_data_handle()));
+      this->weight = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          bwd_data_pd.weights_primitive_desc(), weight.get_data_handle()));
+      this->out = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          bwd_data_pd.dst_primitive_desc(), output.get_data_handle()));
+      bwd = std::shared_ptr<mkldnn::convolution_backward_weights>(
+          new mkldnn::convolution_backward_weights(pd, *this->data,
+                                                   *this->weight, *this->out));
+    } else {
+      this->data->set_data_handle(data.get_data_handle());
+      this->weight->set_data_handle(weight.get_data_handle());
+      this->out->set_data_handle(output.get_data_handle());
+    }
+  }
+
+  const mkldnn::convolution_backward_weights &GetBwd() const { return *bwd; }
+};
+
+static inline MKLDNNDeconvBackwardWeights &GetDeconvBwdWeights(
+    const DeconvolutionParam &param, const NDArray &data,
+    const NDArray &weights, const NDArray &output,
+    const mkldnn::convolution_forward::primitive_desc &bwd_data_pd) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<MKLDNNDeconvSignature,
+                                         MKLDNNDeconvBackwardWeights, OpHash>
+      bwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<MKLDNNDeconvSignature,
+                                            MKLDNNDeconvBackwardWeights, OpHash>
+      bwds;
+#endif
+  MKLDNNDeconvSignature key(param);
+  // Here we can sign the conv op with NDArray because conv primitive will
+  // decide the right layout for them, so we only need to get the shape and the
+  // data type of the arrays.
+  key.AddSign(data);
+  key.AddSign(weights);
+  key.AddSign(output);
+
+  auto it = bwds.find(key);
+  if (it == bwds.end()) {
+    MKLDNNDeconvBackwardWeights bwd(param, data, weights, output, bwd_data_pd);
+    auto ins_ret = bwds.insert(
+        std::pair<MKLDNNDeconvSignature, MKLDNNDeconvBackwardWeights>(key,
+                                                                      bwd));
+    CHECK(ins_ret.second);
+    it = ins_ret.first;
+  }
+  return it->second;
+}
+
+void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs &attrs,
+                                 const OpContext &ctx,
+                                 const std::vector<NDArray> &inputs,
+                                 const std::vector<OpReqType> &req,
+                                 const std::vector<NDArray> &outputs) {
   TmpMemMgr::Get()->Init(ctx.requested[deconv::kTempSpace]);
   const std::vector<NDArray> &in_grad = outputs;
-  const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);
-  CHECK_NE(req[deconv::kWeight], kWriteInplace) << "cannot write weight inplace";
-  mkldnn::convolution_forward::primitive_desc bwdData_pd = GetDeconvBwdData(
-      param, inputs[deconv::kData + 1], inputs[deconv::kWeight + 1], false,
-      inputs[deconv::kOut]);
+  const DeconvolutionParam &param = nnvm::get<DeconvolutionParam>(attrs.parsed);
+  CHECK_NE(req[deconv::kWeight], kWriteInplace)
+      << "cannot write weight inplace";
+  MKLDNNDeconvBackwardData &bwd_data =
+      GetDeconvBwdData(param, inputs[deconv::kData + 1],
+                       inputs[deconv::kWeight + 1], inputs[deconv::kOut]);
   auto out_grad_mem = inputs[deconv::kOut].GetMKLDNNDataReorder(
-      bwdData_pd.src_primitive_desc());
+      bwd_data.pd.src_primitive_desc());
   if (req[deconv::kData]) {
-    auto weight_mem = GetWeights(inputs[deconv::kWeight + 1],
-                                 bwdData_pd.weights_primitive_desc(),
-                                 param.num_group);
-    auto in_grad_mem = CreateMKLDNNMem(in_grad[deconv::kData],
-                                       bwdData_pd.dst_primitive_desc(),
-                                       req[deconv::kData]);
-    MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_forward(bwdData_pd,
-          *out_grad_mem, *weight_mem, *in_grad_mem.second));
+    auto weight_mem =
+        GetWeights(inputs[deconv::kWeight + 1],
+                   bwd_data.pd.weights_primitive_desc(), param.num_group);
+    auto in_grad_mem =
+        CreateMKLDNNMem(in_grad[deconv::kData],
+                        bwd_data.pd.dst_primitive_desc(), req[deconv::kData]);
+    bwd_data.SetNewMem(*out_grad_mem, *weight_mem, *in_grad_mem.second);
+    MKLDNNStream::Get()->RegisterPrim(bwd_data.GetBwd());
     CommitOutput(in_grad[deconv::kData], in_grad_mem);
   }
   if (req[deconv::kWeight]) {
-    mkldnn::convolution_backward_weights::primitive_desc bwdWeights_pd
-      = GetDeconvBwdWeights(param, inputs[deconv::kData + 1],
-          inputs[deconv::kWeight + 1], false, inputs[deconv::kOut], bwdData_pd);
-    if (bwdData_pd.src_primitive_desc() != bwdWeights_pd.src_primitive_desc())
+    MKLDNNDeconvBackwardWeights &bwd_weights = GetDeconvBwdWeights(
+        param, inputs[deconv::kData + 1], inputs[deconv::kWeight + 1],
+        inputs[deconv::kOut], bwd_data.pd);
+    if (bwd_data.pd.src_primitive_desc() != bwd_weights.pd.src_primitive_desc())
       out_grad_mem = inputs[deconv::kOut].GetMKLDNNDataReorder(
-          bwdWeights_pd.src_primitive_desc());
+          bwd_weights.pd.src_primitive_desc());
     auto data_mem = inputs[deconv::kData + 1].GetMKLDNNDataReorder(
-        bwdWeights_pd.diff_dst_primitive_desc());
-    auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[deconv::kWeight],
-                                                 bwdWeights_pd.diff_weights_primitive_desc(),
-                                                 req[deconv::kWeight]);
-    MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_weights(
-          bwdWeights_pd, *out_grad_mem, *data_mem, *in_grad_weight.second));
+        bwd_weights.pd.diff_dst_primitive_desc());
+    auto in_grad_weight = CreateMKLDNNWeightGrad(
+        in_grad[deconv::kWeight], bwd_weights.pd.diff_weights_primitive_desc(),
+        req[deconv::kWeight]);
+    bwd_weights.SetNewMem(*out_grad_mem, *data_mem, *in_grad_weight.second, bwd_data.pd);
+    MKLDNNStream::Get()->RegisterPrim(bwd_weights.GetBwd());
     CommitOutput(in_grad[deconv::kWeight], in_grad_weight);
   }
   MKLDNNStream::Get()->Submit();
   if (!param.no_bias) {
     typedef float DType;
     Stream<cpu> *s = ctx.get_stream<cpu>();
-    Tensor<cpu, 1, DType> gbias = in_grad[deconv::kBias].data().get<cpu, 1, DType>(s);
+    Tensor<cpu, 1, DType> gbias =
+        in_grad[deconv::kBias].data().get<cpu, 1, DType>(s);
     // If there is bias, the out grad has already been converted to the default
     // format, so this shouldn't cause any performance issues.
-    Tensor<cpu, 4, DType> grad = inputs[deconv::kOut].data().get<cpu, 4, DType>(s);
-    Assign(gbias, req[deconv::kBias], mshadow::expr::sumall_except_dim<1>(grad));
+    Tensor<cpu, 4, DType> grad =
+        inputs[deconv::kOut].data().get<cpu, 4, DType>(s);
+    Assign(gbias, req[deconv::kBias],
+           mshadow::expr::sumall_except_dim<1>(grad));
   }
 }
 
diff --git a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
index 4b179a7..bc386be 100644
--- a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
@@ -18,7 +18,7 @@
  */
 
 /*!
- * \file mkldnn_lrn-inl.h 
+ * \file mkldnn_lrn-inl.h
  * \brief
  * \Author: Patric Zhao, patric.zhao@intel.com
 */
@@ -40,9 +40,8 @@ inline algorithm GetMKLDNNLRNAlgo(const LRNParam &param) {
   return algorithm::lrn_across_channels;
 }
 
-inline lrn_forward::primitive_desc GetLRNFwdDesc(const LRNParam &param,
-                                                 const bool is_train,
-                                                 const memory::desc &src_md) {
+inline mkldnn::lrn_forward::primitive_desc GetLRNFwdDesc(
+    const LRNParam &param, const bool is_train, const memory::desc &src_md) {
   mkldnn::engine &engine = CpuEngine::Get()->get_engine();
   const algorithm  alg = GetMKLDNNLRNAlgo(param);
   const float alpha = param.alpha;
@@ -59,11 +58,10 @@ inline lrn_forward::primitive_desc GetLRNFwdDesc(const LRNParam &param,
   return mkldnn::lrn_forward::primitive_desc(fwd_desc, engine);
 }
 
-inline mkldnn::lrn_backward::primitive_desc
-GetLRNBwd(const LRNParam &param,
-          const mkldnn::memory::desc &data_in_md,
-          const mkldnn::memory::desc &diff_md,
-          const lrn_forward::primitive_desc &lrnFwd_desc) {
+inline mkldnn::lrn_backward::primitive_desc GetLRNBwdDesc(
+    const LRNParam &param, const mkldnn::memory::desc &data_in_md,
+    const mkldnn::memory::desc &diff_md,
+    const mkldnn::lrn_forward::primitive_desc &lrnFwd_desc) {
   mkldnn::engine &engine = CpuEngine::Get()->get_engine();
   const algorithm alg = GetMKLDNNLRNAlgo(param);
   const float alpha = param.alpha;
@@ -96,8 +94,15 @@ class MKLDNNLRNFwd {
                  const NDArray &output,
                  const OpReqType req);
 
+  void SetNewMem(const NDArray &in_data,
+                 const mkldnn::memory *out_mem);
+
   void Execute(const NDArray &out_data);
 
+  mkldnn::lrn_forward &GetFwd();
+
+  const mkldnn::memory *GetWs();
+
  private:
   std::shared_ptr<mkldnn::lrn_forward> fwd;
   std::shared_ptr<mkldnn::memory> in_mem;
@@ -113,15 +118,17 @@ class MKLDNNLRNFwd {
 void MKLDNNLRNFwd::_Init(const LRNParam &param,
                          bool is_train,
                          const NDArray &in_data) {
-  mkldnn::memory::desc in_data_md = in_data.GetMKLDNNData()->get_primitive_desc().desc();
-  lrn_forward::primitive_desc fwd_pd = GetLRNFwdDesc(param, is_train, in_data_md);
+  mkldnn::memory::desc in_data_md =
+      in_data.GetMKLDNNData()->get_primitive_desc().desc();
+  mkldnn::lrn_forward::primitive_desc fwd_pd =
+      GetLRNFwdDesc(param, is_train, in_data_md);
 
   this->in_mem.reset(new mkldnn::memory(in_data.GetMKLDNNData()
                      ->get_primitive_desc()));
   this->out_mem.reset(new mkldnn::memory(fwd_pd.dst_primitive_desc()));
   if (is_train) {
-    // If it's training, we have to create a workspace memory. Otherwise, MKLDNN
-    // will have segmentation fault.
+    // If it's training, we have to create a workspace memory. Otherwise,
+    // MKLDNN will hit a segmentation fault.
     ws_mem.reset(new mkldnn::memory(fwd_pd.workspace_primitive_desc()));
     this->fwd = std::shared_ptr<mkldnn::lrn_forward>(
         new mkldnn::lrn_forward(fwd_pd, mkldnn::primitive::at(*this->in_mem),
@@ -142,11 +149,22 @@ void MKLDNNLRNFwd::SetNewMem(const NDArray &in_data,
   this->out_mem->set_data_handle(output_mem_t.second->get_data_handle());
 }
 
+void MKLDNNLRNFwd::SetNewMem(const NDArray &in_data,
+                             const mkldnn::memory *out_mem) {
+  const mkldnn::memory *in_data_mem = in_data.GetMKLDNNData();
+  this->in_mem->set_data_handle(in_data_mem->get_data_handle());
+  this->out_mem->set_data_handle(out_mem->get_data_handle());
+}
+
 void MKLDNNLRNFwd::Execute(const NDArray &out_data) {
   MKLDNNStream::Get()->RegisterPrim(*(this->fwd));
   CommitOutput(out_data, output_mem_t);
   MKLDNNStream::Get()->Submit();
 }
+
+mkldnn::lrn_forward &MKLDNNLRNFwd::GetFwd() { return *this->fwd; }
+
+const mkldnn::memory *MKLDNNLRNFwd::GetWs() { return this->ws_mem.get(); }
 // End of LRN Class and its functions
 
 static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param,
@@ -161,16 +179,10 @@ static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param,
                                             MKLDNNLRNFwd,
                                             OpHash> lrn_fwds;
 #endif
-  auto alg_ = algorithm::lrn_across_channels;
-  auto kind_ = prop_kind::forward_training;
-  if (ctx.is_train) {
-    kind_ = prop_kind::forward_training;
-  } else {
-    kind_ = prop_kind::forward_scoring;
-  }
+  auto kind_ =
+      ctx.is_train ? prop_kind::forward_training : prop_kind::forward_scoring;
 
   MKLDNNLRNSignature key(param);
-  key.AddSign(alg_);
   key.AddSign(kind_);
   key.AddSign(in_data);
 
@@ -185,10 +197,8 @@ static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param,
   return it->second;
 }
 
-void MKLDNNLRNForward(const OpContext &ctx,
-                      const LRNParam &param,
-                      const NDArray &in_data,
-                      const OpReqType req,
+void MKLDNNLRNForward(const OpContext &ctx, const LRNParam &param,
+                      const NDArray &in_data, const OpReqType req,
                       const NDArray &out_data) {
   auto in_buffer = in_data;
   if (in_buffer.IsView() && in_buffer.IsMKLDNNData())
@@ -198,6 +208,90 @@ void MKLDNNLRNForward(const OpContext &ctx,
   fwd.Execute(out_data);
 }
 
+// LRN Backward Class
+class MKLDNNLRNBwd {
+  std::shared_ptr<mkldnn::lrn_backward> bwd;
+  std::shared_ptr<mkldnn::memory> in_data_mem;
+  std::shared_ptr<mkldnn::memory> diff_dst_mem;
+  std::shared_ptr<mkldnn::memory> ws_mem;
+  std::shared_ptr<mkldnn::memory> diff_src_mem;
+
+ public:
+  const mkldnn::lrn_forward::primitive_desc fwd_pd;
+  const mkldnn::lrn_backward::primitive_desc bwd_pd;
+
+  ~MKLDNNLRNBwd() {}
+
+  MKLDNNLRNBwd(const LRNParam &param, const mkldnn::memory::desc in_data_md,
+               const mkldnn::memory::desc diff_md)
+      : fwd_pd(GetLRNFwdDesc(param, true, in_data_md)),
+        bwd_pd(GetLRNBwdDesc(param, in_data_md, diff_md, this->fwd_pd)) {}
+
+  void SetNewMem(const NDArray &in_data, const NDArray &out_grad,
+                 const mkldnn::memory *ws, const mkldnn::memory *diff_src_mem) {
+    if (bwd == nullptr) {
+      this->in_data_mem.reset(
+          new mkldnn::memory(this->fwd_pd.src_primitive_desc(),
+                             in_data.GetMKLDNNData()->get_data_handle()));
+      this->diff_dst_mem.reset(
+          new mkldnn::memory(this->fwd_pd.dst_primitive_desc(),
+                             out_grad.GetMKLDNNData()->get_data_handle()));
+      this->ws_mem.reset(
+          new mkldnn::memory(this->fwd_pd.workspace_primitive_desc(),
+                             ws->get_data_handle()));
+      this->diff_src_mem.reset(
+          new mkldnn::memory(this->bwd_pd.diff_src_primitive_desc(),
+                             diff_src_mem->get_data_handle()));
+      this->bwd.reset(new mkldnn::lrn_backward(
+          this->bwd_pd, mkldnn::primitive::at(*this->in_data_mem),
+          mkldnn::primitive::at(*this->diff_dst_mem), *this->ws_mem,
+          *this->diff_src_mem));
+    } else {
+      this->in_data_mem->set_data_handle(
+          in_data.GetMKLDNNData()->get_data_handle());
+      this->diff_dst_mem->set_data_handle(
+          out_grad.GetMKLDNNData()->get_data_handle());
+      this->ws_mem->set_data_handle(ws->get_data_handle());
+      this->diff_src_mem->set_data_handle(diff_src_mem->get_data_handle());
+    }
+  }
+
+  void Execute(const NDArray &in_grad, const mkldnn_output_t &diff_src_mem_) {
+    MKLDNNStream::Get()->RegisterPrim(*(this->bwd));
+    CommitOutput(in_grad, diff_src_mem_);
+    MKLDNNStream::Get()->Submit();
+  }
+};  // End of LRN Class
+
+static MKLDNNLRNBwd &GetLRNBwd(const LRNParam &param, const NDArray &in_data,
+                               const NDArray &in_grad, const NDArray &out_grad) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local
+      std::unordered_map<MKLDNNLRNSignature, MKLDNNLRNBwd, OpHash> lrn_bwds;
+#else
+  static MX_THREAD_LOCAL
+      std::unordered_map<MKLDNNLRNSignature, MKLDNNLRNBwd, OpHash> lrn_bwds;
+#endif
+  MKLDNNLRNSignature key(param);
+  key.AddSign(in_data);
+  key.AddSign(in_grad);
+  key.AddSign(out_grad);
+
+  auto it = lrn_bwds.find(key);
+  if (it == lrn_bwds.end()) {
+    const mkldnn::memory::desc in_data_md =
+        in_data.GetMKLDNNData()->get_primitive_desc().desc();
+    const mkldnn::memory::desc diff_md =
+        out_grad.GetMKLDNNData()->get_primitive_desc().desc();
+    MKLDNNLRNBwd bwd(param, in_data_md, diff_md);
+    auto ins_ret =
+        lrn_bwds.insert(std::pair<MKLDNNLRNSignature, MKLDNNLRNBwd>(key, bwd));
+    CHECK(ins_ret.second);
+    it = ins_ret.first;
+  }
+  return it->second;
+}
+
 void MKLDNNLRNBackward(const OpContext &ctx, const LRNParam &param,
                        const NDArray &out_grad,
                        const NDArray &in_data,
@@ -206,43 +300,27 @@ void MKLDNNLRNBackward(const OpContext &ctx, const LRNParam &param,
   if (req == kNullOp) {
     return;
   }
-
   // TODO(alex): (MXNET-846) figure out why in_grad output incorrect when in_data is nchw8c
   auto in_buffer = in_data;
   if (in_buffer.IsMKLDNNData()) {
     in_buffer = in_data.Reorder2Default();
   }
-
+  MKLDNNLRNBwd &bwd = GetLRNBwd(param, in_buffer, in_grad, out_grad);
   // Repeat FW for getting workspace
-  const mkldnn::memory *data_mem = in_buffer.GetMKLDNNData();
-  const mkldnn::memory::desc data_md = data_mem->get_primitive_desc().desc();
-  const lrn_forward::primitive_desc pdesc_fwd = GetLRNFwdDesc(param, ctx.is_train,
-                                                              data_md);
-
   // TODO(Patric): To keep the function stateless, we can't pass workspace
   //               from LRN forward to backward. We have to re-compute
   //               LRN forward to get the workspace.
   //               Will refine this code later.
-  std::shared_ptr<const mkldnn::memory> ws_mem(
-          new mkldnn::memory(pdesc_fwd.workspace_primitive_desc()));
+  MKLDNNLRNFwd fwd = GetLRNFwd(param, ctx, in_buffer);
   std::shared_ptr<const mkldnn::memory> dst_temp(
-          new mkldnn::memory(pdesc_fwd.dst_primitive_desc()));
-  MKLDNNStream::Get()->RegisterPrim(
-          lrn_forward(pdesc_fwd, mkldnn::primitive::at(*data_mem),
-          *ws_mem, *dst_temp));
-
-  const mkldnn::memory *diff_mem = out_grad.GetMKLDNNData();
-  const mkldnn::memory::desc diff_md = diff_mem->get_primitive_desc().desc();
-  const mkldnn::lrn_backward::primitive_desc pdesc_bwd = GetLRNBwd(param, data_md,
-                                                                   diff_md, pdesc_fwd);
-  mkldnn_output_t diff_src_mem = CreateMKLDNNMem(in_grad,
-                                                 pdesc_bwd.diff_src_primitive_desc(), req);
-
-  MKLDNNStream::Get()->RegisterPrim(
-        lrn_backward(pdesc_bwd, mkldnn::primitive::at(*data_mem),
-        mkldnn::primitive::at(*diff_mem), *ws_mem, *diff_src_mem.second));
-  CommitOutput(in_grad, diff_src_mem);
-  MKLDNNStream::Get()->Submit();
+      new mkldnn::memory(bwd.fwd_pd.dst_primitive_desc()));
+  fwd.SetNewMem(in_buffer, dst_temp.get());
+  MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
+
+  mkldnn_output_t diff_src_mem =
+      CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_primitive_desc(), req);
+  bwd.SetNewMem(in_buffer, out_grad, fwd.GetWs(), diff_src_mem.second);
+  bwd.Execute(in_grad, diff_src_mem);
 }
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h
index 5d349d3..6667961 100644
--- a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h
@@ -80,6 +80,27 @@ class MKLDNNPoolingFwd {
             const int padding_l, const int padding_r);
 };
 
+class MKLDNNPoolingBwd {
+  std::shared_ptr<const mkldnn::pooling_backward> bwd;
+  std::shared_ptr<mkldnn::memory> diff_dst;
+  std::shared_ptr<mkldnn::memory> diff_src;
+  std::shared_ptr<mkldnn::memory> ws;
+  bool with_workspace;
+
+ public:
+  const mkldnn::pooling_backward::primitive_desc pd;
+
+  MKLDNNPoolingBwd(const pooling_backward::primitive_desc &pdesc,
+                   bool with_ws);
+
+  ~MKLDNNPoolingBwd() {}
+  void SetNewMem(const mxnet::NDArray *workspace,
+                 const mxnet::NDArray &out_grad,
+                 const mkldnn::memory *diff_src_mem);
+  const mkldnn::pooling_backward &GetBwd();
+  const mkldnn::pooling_backward::primitive_desc &GetPd();
+};
+
 inline bool SupportMKLDNNPooling(const PoolingParam &param) {
   return param.kernel.ndim() == 2 &&
          (param.pool_type == pool_enum::kMaxPooling ||
diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc
index d8d65ba..1610944 100644
--- a/src/operator/nn/mkldnn/mkldnn_pooling.cc
+++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc
@@ -134,10 +134,9 @@ mkldnn::algorithm GetMKLDNNPoolAlgo(const PoolingParam &param) {
   }
 }
 
-mkldnn::pooling_forward::primitive_desc GetPoolingFwd(const PoolingParam &param,
-                                                      const bool is_train,
-                                                      const memory::desc &data_md,
-                                                      const memory::desc &out_md) {
+mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc(
+    const PoolingParam &param, const bool is_train, const memory::desc &data_md,
+    const memory::desc &out_md) {
   CHECK_EQ(param.kernel.ndim(), 2) << "Not Implemented";
   int kernel_h_, kernel_w_;
   if (param.global_pool) {
@@ -255,11 +254,124 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam &param,
 void MKLDNNPoolingCompute(const OpContext &ctx, const PoolingParam &param,
                           const NDArray &in_data, const OpReqType req,
                           const NDArray &out_data, const NDArray *workspace) {
-  auto fwd = GetPoolingFwd(param, ctx.is_train, in_data, out_data);
+  auto &fwd = GetPoolingFwd(param, ctx.is_train, in_data, out_data);  // bind to the returned MKLDNNPoolingFwd& instead of copying it
   fwd.SetNewMem(in_data, out_data, req, workspace);
   fwd.Execute(out_data);
 }
 
+MKLDNNPoolingBwd::MKLDNNPoolingBwd(
+    const pooling_backward::primitive_desc &pdesc, bool with_ws)
+    : with_workspace(with_ws), pd(pdesc) {}  // only stores the arguments; bwd and the memory pointers stay empty until SetNewMem()
+
+void MKLDNNPoolingBwd::SetNewMem(const mxnet::NDArray *workspace,  // must be non-null when with_workspace is true
+                                 const mxnet::NDArray &out_grad,  // source of the diff_dst buffer
+                                 const mkldnn::memory *diff_src_mem) {  // source of the diff_src buffer
+  if (bwd == nullptr) {  // first call: create the memories and the primitive
+    diff_dst.reset(
+        new mkldnn::memory(out_grad.GetMKLDNNData()->get_primitive_desc(),
+                           out_grad.GetMKLDNNData()->get_data_handle()));
+    diff_src.reset(new mkldnn::memory(pd.diff_src_primitive_desc(),
+                                      diff_src_mem->get_data_handle()));  // descriptor comes from pd, only the handle from diff_src_mem
+    if (with_workspace) {
+      CHECK(workspace != nullptr);
+      ws.reset(
+          new mkldnn::memory(workspace->GetMKLDNNData()->get_primitive_desc(),
+                             workspace->GetMKLDNNData()->get_data_handle()));
+      bwd.reset(
+          new pooling_backward(pd, *diff_dst, primitive::at(*ws), *diff_src));  // workspace variant takes ws as an extra input
+    } else {
+      bwd.reset(new pooling_backward(pd, *diff_dst, *diff_src));
+    }
+  } else {  // primitive already built: only repoint the existing memories at the new buffers
+    diff_dst->set_data_handle(out_grad.GetMKLDNNData()->get_data_handle());
+    diff_src->set_data_handle(diff_src_mem->get_data_handle());
+    if (with_workspace) {
+      CHECK(workspace != nullptr);
+      ws->set_data_handle(workspace->GetMKLDNNData()->get_data_handle());
+    }
+  }
+}
+
+const mkldnn::pooling_backward &MKLDNNPoolingBwd::GetBwd() {
+  return *this->bwd;  // bwd is only set inside SetNewMem(), so that must be called before this
+}
+
+MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam &param,
+                                const NDArray &in_data,
+                                const NDArray &in_grad,
+                                const NDArray &out_grad) {  // returns the cached backward op for this signature, building it on a miss
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local
+      std::unordered_map<MKLDNNPoolingSignature,
+                         MKLDNNPoolingBwd, OpHash> pooling_bwds;  // one cache per thread
+#else
+  static MX_THREAD_LOCAL
+      std::unordered_map<MKLDNNPoolingSignature,
+                         MKLDNNPoolingBwd, OpHash> pooling_bwds;
+#endif
+
+  bool with_workspace = MKLDNNRequireWorkspace(param);  // forwarded to the MKLDNNPoolingBwd constructor on a miss
+  MKLDNNPoolingSignature key(param);  // key = pooling params plus the three arrays added below
+  key.AddSign(in_data);
+  key.AddSign(in_grad);
+  key.AddSign(out_grad);
+  // Look up the cached entry; the descriptor work below runs only on a miss.
+  auto it = pooling_bwds.find(key);
+  if (it == pooling_bwds.end()) {
+    auto diff_dst_mem = out_grad.GetMKLDNNData();
+    auto input_mem = in_data.GetMKLDNNData();
+    mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc();
+    const mkldnn::memory::desc data_md = data_mpd.desc();
+    const memory::dims dims = {data_md.data.dims[0], data_md.data.dims[1],
+                               static_cast<int>(out_grad.shape()[2]),
+                               static_cast<int>(out_grad.shape()[3])};  // first two dims from in_data, last two from out_grad's shape
+    const memory::desc out_md(
+        {dims}, static_cast<memory::data_type>(data_md.data.data_type),
+        static_cast<memory::format>(data_md.data.format));
+    auto fwd_pd = GetPoolingFwdPdesc(param, true, data_md, out_md);  // is_train is fixed to true here; used as the hint for the backward descriptor
+    // diff_src descriptor: first two dims from diff_dst, last two from in_grad's shape.
+    const mkldnn::memory::desc diff_md =
+        diff_dst_mem->get_primitive_desc().desc();
+    const memory::dims dims1 = {diff_md.data.dims[0], diff_md.data.dims[1],
+                                static_cast<int>(in_grad.shape()[2]),
+                                static_cast<int>(in_grad.shape()[3])};
+    const memory::desc diff_in_md(
+        {dims1}, static_cast<memory::data_type>(diff_md.data.data_type),
+        static_cast<memory::format>(diff_md.data.format));
+    const mkldnn::engine cpu_engine = data_mpd.get_engine();
+    const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
+
+    int kernel_h_, kernel_w_;
+    if (param.global_pool) {  // global pooling: kernel spans the input's dims[2] x dims[3]
+      kernel_h_ = data_md.data.dims[2];
+      kernel_w_ = data_md.data.dims[3];
+    } else {
+      kernel_h_ = param.kernel[0];
+      kernel_w_ = param.kernel[1];
+    }
+    // Same padding on both sides of each axis; global_pool forces pad 0 and stride 1.
+    int pad_t_ = param.pad[0], pad_b_ = param.pad[0];
+    int pad_l_ = param.pad[1], pad_r_ = param.pad[1];
+    int stride_h_ = param.stride[0], stride_w_ = param.stride[1];
+    if (param.global_pool) {
+      pad_t_ = pad_b_ = pad_l_ = pad_r_ = 0;
+      stride_h_ = stride_w_ = 1;
+    }
+
+    const pooling_backward::desc desc(
+        alg, diff_in_md, diff_md, {stride_h_, stride_w_},
+        {kernel_h_, kernel_w_}, {pad_t_, pad_l_}, {pad_b_, pad_r_},
+        mkldnn::padding_kind::zero);
+    const auto pdesc = pooling_backward::primitive_desc(desc, cpu_engine, fwd_pd);
+    MKLDNNPoolingBwd bwd(pdesc, with_workspace);
+    auto ins_ret = pooling_bwds.insert(
+        std::pair<MKLDNNPoolingSignature, MKLDNNPoolingBwd>(key, bwd));  // the map stores a copy of bwd
+    CHECK(ins_ret.second);  // find() missed above, so the insertion is expected to succeed
+    it = ins_ret.first;
+  }
+  return it->second;  // reference to the entry held in the static map
+}
+
 void MKLDNNPoolingGradCompute(const OpContext &ctx, const PoolingParam &param,
                               const NDArray &out_grad, const NDArray &in_data,
                               const NDArray *workspace, const OpReqType req,
@@ -267,68 +379,14 @@ void MKLDNNPoolingGradCompute(const OpContext &ctx, const PoolingParam &param,
   if (req == kNullOp) {
     return;
   }
-
   TmpMemMgr::Get()->Init(ctx.requested[0]);
-  // mkldnn::memory
-  auto diff_dst_mem = out_grad.GetMKLDNNData();
-  auto input_mem = in_data.GetMKLDNNData();
-  mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc();
-  const mkldnn::memory::desc data_md = data_mpd.desc();
-  const memory::dims dims = {data_md.data.dims[0], data_md.data.dims[1],
-                             static_cast<int>(out_grad.shape()[2]),
-                             static_cast<int>(out_grad.shape()[3])};
-  const memory::desc out_md({dims},
-                            static_cast<memory::data_type>(data_md.data.data_type),
-                            static_cast<memory::format>(data_md.data.format));
-  auto pdesc_fwd = GetPoolingFwd(param, ctx.is_train, data_md, out_md);
-
-  const mkldnn::memory::desc diff_md = diff_dst_mem->get_primitive_desc().desc();
-  const memory::dims dims1 = {diff_md.data.dims[0], diff_md.data.dims[1],
-                              static_cast<int>(in_grad.shape()[2]),
-                              static_cast<int>(in_grad.shape()[3])};
-  const memory::desc diff_in_md(
-      {dims1}, static_cast<memory::data_type>(diff_md.data.data_type),
-      static_cast<memory::format>(diff_md.data.format));
-  const mkldnn::engine  cpu_engine = data_mpd.get_engine();
-  const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
-
-  int kernel_h_, kernel_w_;
-  if (param.global_pool) {
-    kernel_h_ = data_md.data.dims[2];
-    kernel_w_ = data_md.data.dims[3];
-  } else {
-    kernel_h_ = param.kernel[0];
-    kernel_w_ = param.kernel[1];
-  }
-
-  int pad_t_ = param.pad[0], pad_b_ = param.pad[0];
-  int pad_l_ = param.pad[1], pad_r_ = param.pad[1];
-  int stride_h_ = param.stride[0], stride_w_ = param.stride[1];
-  if (param.global_pool) {
-    pad_t_ = pad_b_ = pad_l_ = pad_r_ = 0;
-    stride_h_ = stride_w_ = 1;
-  }
-
-  const pooling_backward::desc desc(alg, diff_in_md, diff_md,
-                                    {stride_h_, stride_w_},
-                                    {kernel_h_, kernel_w_},
-                                    {pad_t_, pad_l_}, {pad_b_, pad_r_},
-                                    mkldnn::padding_kind::zero);
-  const pooling_backward::primitive_desc pdesc(desc, cpu_engine, pdesc_fwd);
 
+  auto &bwd = GetPoolingBwd(param, in_data, in_grad, out_grad);
   auto diff_src_mem =
-      CreateMKLDNNMem(in_grad, pdesc.diff_src_primitive_desc(), req);
-
-  if (MKLDNNRequireWorkspace(param)) {
-    CHECK(workspace != nullptr);
-    auto workspace_mem = workspace->GetMKLDNNData();
-    MKLDNNStream::Get()->RegisterPrim(
-        pooling_backward(pdesc, *diff_dst_mem, primitive::at(*workspace_mem),
-                         *diff_src_mem.second));
-  } else {
-    MKLDNNStream::Get()->RegisterPrim(
-        pooling_backward(pdesc, *diff_dst_mem, *diff_src_mem.second));
-  }
+      CreateMKLDNNMem(in_grad, bwd.pd.diff_src_primitive_desc(), req);
+
+  bwd.SetNewMem(workspace, out_grad, diff_src_mem.second);
+  MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd());
   CommitOutput(in_grad, diff_src_mem);
   MKLDNNStream::Get()->Submit();
 }