Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/09/13 05:50:26 UTC

[GitHub] eric-haibin-lin closed pull request #11301: MKLDNN Backward op cache

eric-haibin-lin closed pull request #11301: MKLDNN Backward op cache
URL: https://github.com/apache/incubator-mxnet/pull/11301
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

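The change that repeats across every file below is a per-operator backward cache: each MKLDNN backward op gets a small class that owns its primitive descriptor, primitive, and mkldnn::memory wrappers, plus a thread-local map keyed by an op signature, so the expensive primitive creation happens only once and later calls merely rebind data handles. A minimal, self-contained sketch of the lookup half of that pattern (a string key and a dummy struct stand in for the real ParamOpSign/OpHash and mkldnn primitives):

    #include <string>
    #include <unordered_map>

    struct CachedBwd {                      // stand-in for e.g. MKLDNNActBackward
      explicit CachedBwd(const std::string &sig) : sig_(sig) {}
      void SetNewMem(void *handle) { handle_ = handle; }  // reuse: only handles change
      std::string sig_;
      void *handle_ = nullptr;
    };

    static CachedBwd &GetCachedBwd(const std::string &signature) {
      static thread_local std::unordered_map<std::string, CachedBwd> bwds;
      auto it = bwds.find(signature);
      if (it == bwds.end())                 // first call: build descriptors/primitives once
        it = bwds.emplace(signature, CachedBwd(signature)).first;
      return it->second;                    // later calls reuse the cached object
    }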

diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc
index eb881d29abd..17a485c53f6 100644
--- a/src/operator/nn/fully_connected.cc
+++ b/src/operator/nn/fully_connected.cc
@@ -213,10 +213,9 @@ inline static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs,
   uint32_t out_expected = param.no_bias ? 2 : 3;
   CHECK_EQ(in_attrs->size(), 3U);
   CHECK_EQ(out_attrs->size(), out_expected);
-
-  bool dispatched = false;
   // TODO(zhengda) let's disable MKLDNN for FullyConnected for now.
   // It seems there is a bug.
+  bool dispatched = false;
   if (!dispatched && common::ContainsOnlyStorage(*in_attrs, mxnet::kDefaultStorage)) {
     dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
                                      dispatch_mode, DispatchMode::kFCompute);
diff --git a/src/operator/nn/lrn.cc b/src/operator/nn/lrn.cc
index 587cf930920..41084033191 100644
--- a/src/operator/nn/lrn.cc
+++ b/src/operator/nn/lrn.cc
@@ -116,6 +116,8 @@ void LRNComputeExCPU(const nnvm::NodeAttrs &attrs,
     MKLDNN_OPCHECK_INIT(false, 1, inputs, outputs);
     MKLDNNLRNForward(ctx, param, inputs[0], req[0], outputs[0]);
     MKLDNN_OPCHECK_RUN(LRNCompute<cpu>, attrs, ctx, inputs, req, outputs);
+    // Copy outputs[1] from opcheck reference as backward check needs it.
+    MKLDNN_OPCHECK_COPY_RESULT(outputs, std::vector<size_t>{1});
     return;
   }
   FallBackCompute(LRNCompute<cpu>, attrs, ctx, inputs, req, outputs);
diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc
index 744fed2c299..c914b38b542 100644
--- a/src/operator/nn/mkldnn/mkldnn_act.cc
+++ b/src/operator/nn/mkldnn/mkldnn_act.cc
@@ -84,7 +84,7 @@ static mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl(
                                         alg, data_md, alpha);
     return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine);
   });
-  LOG(INFO) << "Unsupported data type for MKLDNN activation";
+  LOG(FATAL) << "Unsupported data type for MKLDNN activation";
   mkldnn::eltwise_forward::desc desc = mkldnn::eltwise_forward::desc(
       mkldnn::prop_kind::forward_training, alg, data_md, 0.0);
   return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine);
@@ -175,6 +175,100 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   stream->Submit();
 }
 
+static mkldnn::eltwise_backward::primitive_desc GetActBwdDescImpl(
+    const ActivationParam &param, const mkldnn::memory &input_mem,
+    const mkldnn::memory &diff_dst_memory, int dtype) {
+  mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc();
+  mkldnn::memory::desc data_md = data_mpd.desc();
+  mkldnn::memory::desc diff_md = diff_dst_memory.get_primitive_desc().desc();
+  auto cpu_engine = data_mpd.get_engine();
+  auto alg = GetMKLDNNActAlgo(param);
+
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    DType alpha = 0;
+    mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
+                                          alg, data_md, alpha);
+    mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine);
+    mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, alpha);
+    mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine,
+                                                      fw_pdesc);
+    return bw_pdesc;
+  });
+  LOG(FATAL) << "Unsupported data type for MKLDNN activation";
+  mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
+                                        alg, data_md, 0.0);
+  mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine);
+  mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, 0.0);
+  mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine,
+                                                    fw_pdesc);
+  return bw_pdesc;
+}
+
+class MKLDNNActBackward {
+  std::shared_ptr<mkldnn::eltwise_backward> bwd;
+  std::shared_ptr<mkldnn::memory> data;
+  std::shared_ptr<mkldnn::memory> diff_dst_memory;
+  std::shared_ptr<mkldnn::memory> diff_src_memory;
+
+ public:
+  const mkldnn::eltwise_backward::primitive_desc pd;
+
+  explicit MKLDNNActBackward(const ActivationParam &param, const NDArray &data,
+                             const mkldnn::memory &mem,
+                             const mkldnn::memory &diff_dst_memory)
+      : pd(GetActBwdDescImpl(param, mem, diff_dst_memory, data.dtype())) {}
+
+  void SetNewMem(const mkldnn::memory &data,
+                 const mkldnn::memory &diff_dst_memory,
+                 const mkldnn::memory &diff_src_memory) {
+    if (this->bwd != nullptr) {
+      this->data->set_data_handle(data.get_data_handle());
+      this->diff_dst_memory->set_data_handle(diff_dst_memory.get_data_handle());
+      this->diff_src_memory->set_data_handle(diff_src_memory.get_data_handle());
+    } else {
+      this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          data.get_primitive_desc(), data.get_data_handle()));
+      this->diff_dst_memory = std::shared_ptr<mkldnn::memory>(
+          new mkldnn::memory(diff_dst_memory.get_primitive_desc(),
+                             diff_dst_memory.get_data_handle()));
+      this->diff_src_memory = std::shared_ptr<mkldnn::memory>(
+          new mkldnn::memory(diff_src_memory.get_primitive_desc(),
+                             diff_src_memory.get_data_handle()));
+      this->bwd = std::shared_ptr<mkldnn::eltwise_backward>(
+          new mkldnn::eltwise_backward(
+              this->pd, mkldnn::primitive::at(*this->data),
+              *this->diff_dst_memory, *this->diff_src_memory));
+    }
+  }
+
+  const inline mkldnn::eltwise_backward &GetBwd() const { return *bwd; }
+};
+
+static inline MKLDNNActBackward &GetActBackward(const ActivationParam &param,
+                                                const OpContext &ctx,
+                                                const NDArray &in_data,
+                                                const NDArray &out_grad,
+                                                const mkldnn::memory &in_mem) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<MKLDNNActSignature, MKLDNNActBackward, OpHash> bwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<MKLDNNActSignature, MKLDNNActBackward, OpHash> bwds;
+#endif
+  MKLDNNActSignature key(param);
+  key.AddSign(in_data);
+  key.AddSign(out_grad);
+
+  auto it = bwds.find(key);
+  if (it == bwds.end()) {
+    MKLDNNActBackward bwd(param, in_data, in_mem, *out_grad.GetMKLDNNData());
+    auto ins_ret =
+        bwds.insert(std::pair<MKLDNNActSignature, MKLDNNActBackward>(key, bwd));
+    CHECK(ins_ret.second);
+    it = ins_ret.first;
+  }
+  return it->second;
+}
+
 // For backward relu activation, it's okay to pass "out_data" as "in_data" to this
 // function, since the computation only involves non-zeros.
 void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
@@ -200,30 +294,13 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx
   // descriptor. Otherwise, the perf will suffer.
   if (input_mem->get_primitive_desc() != diff_dst_memory->get_primitive_desc())
     input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_primitive_desc());
-  mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc();
-  mkldnn::memory::desc data_md = data_mpd.desc();
-  mkldnn::memory::desc diff_md = diff_dst_memory->get_primitive_desc().desc();
-  auto cpu_engine = data_mpd.get_engine();
-
+  MKLDNNActBackward &bwd =
+      GetActBackward(param, ctx, in_buffer, out_buffer, *input_mem);
   MKLDNNStream *stream = MKLDNNStream::Get();
-  auto alg = GetMKLDNNActAlgo(param);
-  mkldnn_output_t diff_src_memory;
-
-  MSHADOW_REAL_TYPE_SWITCH(in_buffer.dtype(), DType, {
-    DType alpha = 0;
-    mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
-                                          alg, data_md, alpha);
-    mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine);
-    mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, alpha);
-    mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine,
-                                                      fw_pdesc);
-
-    diff_src_memory = CreateMKLDNNMem(in_grad,
-                                      bw_pdesc.diff_src_primitive_desc(), req);
-    stream->RegisterPrim(mkldnn::eltwise_backward(bw_pdesc, *input_mem,
-                                                  *diff_dst_memory,
-                                                  *diff_src_memory.second));
-  });
+  mkldnn_output_t diff_src_memory =
+      CreateMKLDNNMem(in_grad, bwd.pd.diff_src_primitive_desc(), req);
+  bwd.SetNewMem(*input_mem, *diff_dst_memory, *diff_src_memory.second);
+  stream->RegisterPrim(bwd.GetBwd());
   CommitOutput(in_grad, diff_src_memory);
   stream->Submit();
 }
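
The SetNewMem methods introduced above (and in the classes further down) all follow the same lazy-init idiom: wrap the caller's buffers in mkldnn::memory objects on the first call, and on every later call only repoint those wrappers with set_data_handle so the cached primitive built against them stays valid. Reduced to a hedged helper (the name BindOrUpdate is illustrative, not MXNet API):

    #include <memory>
    #include <mkldnn.hpp>  // MKL-DNN 0.x API, as used throughout this PR

    // Lazily wrap the caller's buffer once, then only repoint the handle.
    static void BindOrUpdate(std::shared_ptr<mkldnn::memory> &mem,
                             const mkldnn::memory &src) {
      if (mem == nullptr)
        mem.reset(new mkldnn::memory(src.get_primitive_desc(),
                                     src.get_data_handle()));
      else
        mem->set_data_handle(src.get_data_handle());
    }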
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 273afcd32dc..4d0b8d062af 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -497,6 +497,9 @@ class OpCheck {
            const std::vector<mxnet::NDArray> &inputs_,
            const std::vector<mxnet::OpReqType> &req,
            const std::vector<mxnet::NDArray> &outputs_);
+
+  void CopyResult(const std::vector<mxnet::NDArray> &outputs_,
+                  const std::vector<size_t>& indice);
 };
 
 bool MKLDNNStorageType(const nnvm::NodeAttrs &attrs,
@@ -513,6 +516,8 @@ bool MKLDNNStorageType(const nnvm::NodeAttrs &attrs,
 
 #define MKLDNN_OPCHECK_RUN(fn, attrs, ctx, inputs, req, outputs)    \
     if (debug) check.Run(fn, attrs, ctx, inputs, req, outputs);
+#define MKLDNN_OPCHECK_COPY_RESULT(outputs, indice) \
+    if (debug) check.CopyResult(outputs, indice);
 
 }  // namespace mxnet
 #endif
diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc
index f3facd966aa..029f23bd8f5 100644
--- a/src/operator/nn/mkldnn/mkldnn_base.cc
+++ b/src/operator/nn/mkldnn/mkldnn_base.cc
@@ -525,6 +525,17 @@ void OpCheck::Run(mxnet::FCompute fn, const nnvm::NodeAttrs &attrs,
   }
 }
 
+void OpCheck::CopyResult(const std::vector<mxnet::NDArray> &outputs_,
+                         const std::vector<size_t> &indice) {
+  CHECK(!MKLDNNStream::Get()->HasOps());
+  auto non_const_outputs_ = const_cast<std::vector<mxnet::NDArray> &>(outputs_);
+  for (auto i = indice.begin(); i != indice.end(); ++i) {
+    auto mem = outputs[*i].GetMKLDNNData();
+    non_const_outputs_[*i].CopyFrom(*mem);
+  }
+  MKLDNNStream::Get()->Submit();
+}
+
 bool MKLDNNStorageType(const nnvm::NodeAttrs &attrs,
                        const int dev_mask,
                        bool support_mkldnn,
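
With CopyResult in place, a debug-checked MKLDNN compute function follows the pattern already visible in the lrn.cc hunk: snapshot the inputs, run the MKLDNN path, re-run the reference FCompute for comparison, and copy selected reference outputs (here a workspace consumed by backward) back over the MKLDNN ones. An illustrative outline (the SomeOp* names are placeholders; the macros are the ones defined above):

    void SomeOpComputeExCPU(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                            const std::vector<NDArray> &inputs,
                            const std::vector<OpReqType> &req,
                            const std::vector<NDArray> &outputs) {
      MKLDNN_OPCHECK_INIT(false, 1, inputs, outputs);   // only active in debug mode
      MKLDNNSomeOpForward(ctx, inputs[0], req[0], outputs[0]);
      MKLDNN_OPCHECK_RUN(SomeOpCompute<cpu>, attrs, ctx, inputs, req, outputs);
      // Keep outputs[1] (e.g. a workspace) consistent with the reference run so a
      // later backward check compares against matching data.
      MKLDNN_OPCHECK_COPY_RESULT(outputs, std::vector<size_t>{1});
    }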
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 9046836e8e7..496ff99f4ee 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -290,6 +290,84 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
   }
 }
 
+class MKLDNNBNBackward {
+  std::shared_ptr<mkldnn::batch_normalization_backward> bwd;
+  std::shared_ptr<mkldnn::memory> data_m;
+  std::shared_ptr<mkldnn::memory> diff_m;
+  std::shared_ptr<mkldnn::memory> gradi_m;
+  std::shared_ptr<mkldnn::memory> mean_m;
+  std::shared_ptr<mkldnn::memory> var_m;
+  const std::shared_ptr<mkldnn::memory> weight_m;
+  const std::shared_ptr<mkldnn::memory> gradw_m;
+
+ public:
+  const t_bn_b_pdesc pd;
+
+  explicit MKLDNNBNBackward(const t_bn_b_pdesc &_pd)
+      : weight_m(new mkldnn::memory(_pd.weights_primitive_desc())),
+        gradw_m(new mkldnn::memory(_pd.diff_weights_primitive_desc())),
+        pd(_pd) {}
+
+  const mkldnn::memory &GetWeight() const { return *weight_m; }
+
+  const mkldnn::memory &GetGradw() const { return *gradw_m; }
+
+  void SetDataHandle(const mkldnn::memory &data, const mkldnn::memory &diff,
+                     const NDArray &mean, const mkldnn::memory &var,
+                     const mkldnn::memory &gradi) {
+    auto mean_ptr = mean.data().dptr_;
+    if (bwd == nullptr) {
+      data_m.reset(new mkldnn::memory(data.get_primitive_desc(),
+                                      data.get_data_handle()));
+      diff_m.reset(new mkldnn::memory(diff.get_primitive_desc(),
+                                      diff.get_data_handle()));
+      gradi_m.reset(new mkldnn::memory(gradi.get_primitive_desc(),
+                                       gradi.get_data_handle()));
+      mean_m.reset(new mkldnn::memory(pd.mean_primitive_desc(), mean_ptr));
+      var_m.reset(new mkldnn::memory(pd.variance_primitive_desc(),
+                                     var.get_data_handle()));
+      bwd.reset(new mkldnn::batch_normalization_backward(
+          pd, *data_m, mkldnn::primitive::at(*mean_m),
+          mkldnn::primitive::at(*var_m), *diff_m, *weight_m, *gradi_m,
+          *gradw_m));
+    } else {
+      data_m->set_data_handle(data.get_data_handle());
+      diff_m->set_data_handle(diff.get_data_handle());
+      gradi_m->set_data_handle(gradi.get_data_handle());
+      mean_m->set_data_handle(mean_ptr);
+      var_m->set_data_handle(var.get_data_handle());
+    }
+  }
+
+  const mkldnn::batch_normalization_backward &GetBwd() const { return *bwd; }
+};
+
+template <typename DType>
+static MKLDNNBNBackward &GetBNBackward(
+    const BatchNormParam &param, const OpContext &ctx, const NDArray &in_data,
+    const mkldnn::memory &in_mem, const NDArray &diff_data,
+    const mkldnn::memory &diff_mem, unsigned flags) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<MKLDNNBNSignature, MKLDNNBNBackward, OpHash> bwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<MKLDNNBNSignature, MKLDNNBNBackward, OpHash> bwds;
+#endif
+  MKLDNNBNSignature key(param);
+  key.AddSign(in_data);
+  key.AddSign(diff_data);
+
+  auto it = bwds.find(key);
+  if (it == bwds.end()) {
+    auto bwd_pd = _GetBwd(in_mem, diff_mem, param.eps, flags);
+    MKLDNNBNBackward bwd(bwd_pd);
+    auto ins_ret =
+        bwds.insert(std::pair<MKLDNNBNSignature, MKLDNNBNBackward>(key, bwd));
+    CHECK(ins_ret.second);
+    it = ins_ret.first;
+  }
+  return it->second;
+}
+
 template <typename DType>
 void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
                              const std::vector<NDArray>    &out_grad,
@@ -326,17 +404,13 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
     data_mem = data.GetMKLDNNDataReorder(diff_mem->get_primitive_desc());
   else if (diff.IsDefaultData())
     diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_primitive_desc());
-  auto bwd_pd = _GetBwd(*data_mem, *diff_mem, param.eps, flags);
+  auto &bwd = GetBNBackward<DType>(param, ctx, data, *data_mem, diff, *diff_mem, flags);
   auto gradi_mem = const_cast<NDArray &>(gradIn).CreateMKLDNNData(data_mem->get_primitive_desc());
 
   if (flags & use_scale_shift) {
     const NDArray &gamma    = in_data[batchnorm::kGamma];
     const NDArray &beta     = in_data[batchnorm::kBeta];
-    // TODO(tao): how to reuse this memory?
-    std::shared_ptr<const mkldnn::memory> weight_mem(
-                    new mkldnn::memory(bwd_pd.weights_primitive_desc()));
-
-    DType* weight_buf = reinterpret_cast<DType *>(weight_mem->get_data_handle());
+    DType *weight_buf = reinterpret_cast<DType *>(bwd.GetWeight().get_data_handle());
     nnvm::dim_t channels_ = data.shape()[1];
     for (int i = 0; i < channels_; i++) {
       if (!param.fix_gamma)
@@ -349,15 +423,13 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
       weight_buf[channels_ + i] = (beta.data().dptr<DType>())[i];  // bias
     }
 
-    std::shared_ptr<const mkldnn::memory> gradw_mem(
-                    new mkldnn::memory(bwd_pd.diff_weights_primitive_desc()));
     // training but no input mean and variance
     if (ctx.is_train && !param.use_global_stats) {
       DType* moving_mean_ptr  = reinterpret_cast<DType *>(moving_mean.data().dptr<DType>());
       DType* moving_var_ptr   = reinterpret_cast<DType *>(moving_var.data().dptr<DType>());
       DType* out_mean_ptr     = reinterpret_cast<DType *>(out_mean.data().dptr<DType>());
       DType* out_var_ptr      = reinterpret_cast<DType *>(out_var.data().dptr<DType>());
-      mkldnn::memory var_mem(bwd_pd.variance_primitive_desc());
+      mkldnn::memory var_mem(bwd.pd.variance_primitive_desc());
       DType *tmp_var_ptr = reinterpret_cast<DType *>(var_mem.get_data_handle());
 
       DType minus_mom = (1.0f - param.momentum);
@@ -369,45 +441,18 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
         moving_var_ptr[i] = moving_var_ptr[i] * param.momentum +
                             variance * minus_mom;
       }
-
-      std::shared_ptr<const mkldnn::memory> out_mean_mem(
-                      new mkldnn::memory(bwd_pd.mean_primitive_desc(), out_mean_ptr));
-      std::shared_ptr<const mkldnn::memory> out_var_mem(
-                      new mkldnn::memory(bwd_pd.variance_primitive_desc(), out_var_ptr));
-
-      auto bn_bwd = mkldnn::batch_normalization_backward(bwd_pd,
-                                                         *data_mem,
-                                                         mkldnn::primitive::at(*out_mean_mem),
-                                                         mkldnn::primitive::at(var_mem),
-                                                         *diff_mem,
-                                                         *weight_mem,
-                                                         *gradi_mem,
-                                                         *gradw_mem);
-
-      MKLDNNStream::Get()->RegisterPrim(bn_bwd);
+      bwd.SetDataHandle(*data_mem, *diff_mem, out_mean, var_mem, *gradi_mem);
+      MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd());
       MKLDNNStream::Get()->Submit();
     } else {
-      std::shared_ptr<const mkldnn::memory> imean_mem(
-                      new mkldnn::memory(bwd_pd.mean_primitive_desc(),
-                      moving_mean.data().dptr<DType>()));
-      std::shared_ptr<const mkldnn::memory> ivar_mem(
-                      new mkldnn::memory(bwd_pd.variance_primitive_desc(),
-                      moving_var.data().dptr<DType>()));
-      auto bn_bwd = mkldnn::batch_normalization_backward(bwd_pd,
-                                                         *data_mem,
-                                                         mkldnn::primitive::at(*imean_mem),
-                                                         mkldnn::primitive::at(*ivar_mem),
-                                                         *diff_mem,
-                                                         *weight_mem,
-                                                         *gradi_mem,
-                                                         *gradw_mem);
-
-      MKLDNNStream::Get()->RegisterPrim(bn_bwd);
+      bwd.SetDataHandle(*data_mem, *diff_mem, moving_mean,
+                        *moving_var.GetMKLDNNData(), *gradi_mem);
+      MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd());
       MKLDNNStream::Get()->Submit();
     }
 
     // copy data from gradw_mem to in_grad[1] and in_grad[2]
-    DType* gw_buf = reinterpret_cast<DType *>(gradw_mem->get_data_handle());
+    DType *gw_buf = reinterpret_cast<DType *>(bwd.GetGradw().get_data_handle());
     for (int i = 0; i < channels_; i++) {
       if (!param.fix_gamma)
         (in_grad[1].data().dptr<DType>())[i] = gw_buf[i];
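
The scale/shift handling above relies on MKLDNN's packed weights layout: gamma fills the first channels_ entries of the cached weights buffer, beta the next channels_, and the weight-gradient buffer gradw_m comes back in the same layout, which is what the final copy loop unpacks into in_grad[1] (and, past this excerpt, in_grad[2]). A tiny made-up illustration:

    // channels_ = 2, gamma = {g0, g1}, beta = {b0, b1}
    // weight_buf -> { g0, g1, b0, b1 }        ({ 1, 1, b0, b1 } when fix_gamma is set)
    // gw_buf     -> { dgamma0, dgamma1, dbeta0, dbeta1 }
    // so dgamma is read from gw_buf[i] and dbeta from gw_buf[channels_ + i]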
diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc
index cf04ea8da3d..2e19d3219ab 100644
--- a/src/operator/nn/mkldnn/mkldnn_convolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -283,6 +283,157 @@ void MKLDNNConvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx
   MKLDNNStream::Get()->Submit();
 }
 
+class MKLDNNConvBackward {
+  std::shared_ptr<mkldnn::convolution_backward_data> bwd_data;
+  std::shared_ptr<mkldnn::convolution_backward_weights> bwd_weight;
+  // conv::kData
+  std::shared_ptr<mkldnn::memory> out_grad;
+  std::shared_ptr<mkldnn::memory> in_grad;
+  std::shared_ptr<mkldnn::memory> weight;
+  // conv::kWeight
+  std::shared_ptr<mkldnn::memory> data;
+  std::shared_ptr<mkldnn::memory> output;
+  std::shared_ptr<mkldnn::memory> in_grad_weight;
+  std::shared_ptr<mkldnn::memory> in_grad_bias;
+
+ public:
+  mkldnn::convolution_backward_data::primitive_desc bwdData_pd;
+  mkldnn::convolution_backward_weights::primitive_desc bwdWeights_pd;
+
+  MKLDNNConvBackward(
+      const ConvolutionParam &param, const NDArray &data,
+      const NDArray &weights, const NDArray *bias, const NDArray &output,
+      const mkldnn::convolution_forward::primitive_desc &fwd_pd):
+      bwdData_pd(GetConvBwdData(param, data, weights, output, fwd_pd)),
+      bwdWeights_pd(GetConvBwdWeights(param, data, weights, bias, output, fwd_pd)) {
+  }
+
+  void SetDataNewMem(const mkldnn::memory &out_grad, const mkldnn::memory &weight,
+                     const mkldnn::memory &in_grad) {
+    if (this->out_grad == nullptr)
+      this->out_grad = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+        bwdData_pd.diff_dst_primitive_desc(), out_grad.get_data_handle()));
+    else
+      this->out_grad->set_data_handle(out_grad.get_data_handle());
+    if (this->in_grad == nullptr)
+      this->in_grad = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+        bwdData_pd.diff_src_primitive_desc(), in_grad.get_data_handle()));
+    else
+      this->in_grad->set_data_handle(in_grad.get_data_handle());
+    if (this->weight == nullptr)
+      this->weight = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+         bwdData_pd.weights_primitive_desc(), weight.get_data_handle()));
+    else
+      this->weight->set_data_handle(weight.get_data_handle());
+    if (this->bwd_data == nullptr)
+      this->bwd_data = std::shared_ptr<mkldnn::convolution_backward_data>(
+        new mkldnn::convolution_backward_data(
+          this->bwdData_pd, mkldnn::primitive::at(*this->out_grad),
+          mkldnn::primitive::at(*this->weight), *this->in_grad));
+  }
+
+  void SetWeightNewMem(const mkldnn::memory &data,
+                       const mkldnn::memory &out_grad,
+                       const mkldnn::memory &in_grad_weight) {
+    if (this->data == nullptr)
+      this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          bwdWeights_pd.src_primitive_desc(), data.get_data_handle()));
+    else
+      this->data->set_data_handle(data.get_data_handle());
+    if (this->output == nullptr)
+      this->output = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          bwdWeights_pd.diff_dst_primitive_desc(), out_grad.get_data_handle()));
+    else
+      this->output->set_data_handle(out_grad.get_data_handle());
+    if (this->in_grad_weight == nullptr)
+      this->in_grad_weight = std::shared_ptr<mkldnn::memory>(
+          new mkldnn::memory(bwdWeights_pd.diff_weights_primitive_desc(),
+                             in_grad_weight.get_data_handle()));
+    else
+      this->in_grad_weight->set_data_handle(in_grad_weight.get_data_handle());
+
+    if (this->bwd_weight == nullptr)
+      this->bwd_weight = std::shared_ptr<mkldnn::convolution_backward_weights>(
+          new mkldnn::convolution_backward_weights(
+              this->bwdWeights_pd, mkldnn::primitive::at(*this->data),
+              mkldnn::primitive::at(*this->output), *this->in_grad_weight));
+  }
+
+  void SetWeightNewMem(const mkldnn::memory &data,
+                       const mkldnn::memory &out_grad,
+                       const mkldnn::memory &in_grad_weight,
+                       const mkldnn::memory &in_grad_bias) {
+    if (this->data == nullptr)
+      this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          bwdWeights_pd.src_primitive_desc(), data.get_data_handle()));
+    else
+      this->data->set_data_handle(data.get_data_handle());
+    if (this->output == nullptr)
+      this->output = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          bwdWeights_pd.diff_dst_primitive_desc(), out_grad.get_data_handle()));
+    else
+      this->output->set_data_handle(out_grad.get_data_handle());
+    if (this->in_grad_weight == nullptr)
+      this->in_grad_weight = std::shared_ptr<mkldnn::memory>(
+          new mkldnn::memory(bwdWeights_pd.diff_weights_primitive_desc(),
+                             in_grad_weight.get_data_handle()));
+    else
+      this->in_grad_weight->set_data_handle(in_grad_weight.get_data_handle());
+
+    if (this->in_grad_bias == nullptr)
+      this->in_grad_bias = std::shared_ptr<mkldnn::memory>(
+          new mkldnn::memory(bwdWeights_pd.diff_bias_primitive_desc(),
+                             in_grad_bias.get_data_handle()));
+    else
+      this->in_grad_bias->set_data_handle(in_grad_bias.get_data_handle());
+    if (this->bwd_weight == nullptr)
+      this->bwd_weight = std::shared_ptr<mkldnn::convolution_backward_weights>(
+          new mkldnn::convolution_backward_weights(
+              this->bwdWeights_pd, mkldnn::primitive::at(*this->data),
+              mkldnn::primitive::at(*this->output), *this->in_grad_weight,
+              *this->in_grad_bias));
+  }
+
+  const mkldnn::convolution_backward_data &GetBwdData() const {
+    return *bwd_data;
+  }
+
+  const mkldnn::convolution_backward_weights &GetBwdWeights() const {
+    return *bwd_weight;
+  }
+};
+
+static inline MKLDNNConvBackward &GetConvBwd(
+    const nnvm::NodeAttrs &attrs, const NDArray &data, const NDArray &weights,
+    const NDArray *bias, const NDArray &output,
+    const mkldnn::convolution_forward::primitive_desc &fwd_pd) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<MKLDNNConvSignature, MKLDNNConvBackward, OpHash> bwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<MKLDNNConvSignature, MKLDNNConvBackward, OpHash> bwds;
+#endif
+  const ConvolutionParam& param = nnvm::get<ConvolutionParam>(attrs.parsed);
+  MKLDNNConvSignature key(param);
+  // Here we can sign the conv op with NDArray because conv primitive will
+  // decide the right layout for them, so we only need to get the shape and the
+  // data type of the arrays.
+  key.AddSign(data);
+  key.AddSign(weights);
+  key.AddSign(output);
+  if (bias)
+    key.AddSign(*bias);
+
+  auto it = bwds.find(key);
+  if (it == bwds.end()) {
+    MKLDNNConvBackward bwd(param, data, weights, bias, output, fwd_pd);
+    auto ins_ret = bwds.insert(
+        std::pair<MKLDNNConvSignature, MKLDNNConvBackward>(key, bwd));
+    CHECK(ins_ret.second);
+    it = ins_ret.first;
+  }
+  return it->second;
+}
+
 void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                                const std::vector<NDArray>& inputs,
                                const std::vector<OpReqType>& req,
@@ -295,44 +446,45 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct
       param.no_bias ? nullptr : &inputs[conv::kBias + 1], inputs[conv::kOut]);
 
   CHECK_NE(req[conv::kWeight], kWriteInplace) << "cannot write weight inplace";
-  mkldnn::convolution_backward_data::primitive_desc bwdData_pd
-    = GetConvBwdData(param, inputs[conv::kData + 1], inputs[conv::kWeight + 1],
-        inputs[conv::kOut], fwd_pd);
+  MKLDNNConvBackward &convBwd = GetConvBwd(attrs, inputs[conv::kData + 1],
+             inputs[conv::kWeight + 1], nullptr, inputs[conv::kOut], fwd_pd);
   auto out_grad_mem = inputs[conv::kOut].GetMKLDNNDataReorder(
-      bwdData_pd.diff_dst_primitive_desc());
+      convBwd.bwdData_pd.diff_dst_primitive_desc());
   if (req[conv::kData]) {
     auto weight_mem = GetWeights(inputs[conv::kWeight + 1],
-        bwdData_pd.weights_primitive_desc(), param.num_group);
+        convBwd.bwdData_pd.weights_primitive_desc(), param.num_group);
     auto in_grad_mem = CreateMKLDNNMem(in_grad[conv::kData],
-        bwdData_pd.diff_src_primitive_desc(), req[conv::kData]);
-    MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_data(bwdData_pd,
-          *out_grad_mem, *weight_mem, *in_grad_mem.second));
+        convBwd.bwdData_pd.diff_src_primitive_desc(), req[conv::kData]);
+    convBwd.SetDataNewMem(*out_grad_mem, *weight_mem, *in_grad_mem.second);
+    MKLDNNStream::Get()->RegisterPrim(convBwd.GetBwdData());
     CommitOutput(in_grad[conv::kData], in_grad_mem);
   }
   if (req[conv::kWeight]) {
-    mkldnn::convolution_backward_weights::primitive_desc bwdWeights_pd
-        = GetConvBwdWeights(param, inputs[conv::kData + 1], inputs[conv::kWeight + 1],
-                            param.no_bias ? nullptr : &inputs[conv::kBias + 1],
-                            inputs[conv::kOut], fwd_pd);
-    if (bwdData_pd.diff_dst_primitive_desc() != bwdWeights_pd.diff_dst_primitive_desc())
+    MKLDNNConvBackward &convBwdWeight = GetConvBwd(attrs, inputs[conv::kData + 1],
+             inputs[conv::kWeight + 1], param.no_bias ? nullptr : &inputs[conv::kBias + 1],
+             inputs[conv::kOut], fwd_pd);
+    if (convBwdWeight.bwdData_pd.diff_dst_primitive_desc() !=
+        convBwdWeight.bwdWeights_pd.diff_dst_primitive_desc())
       out_grad_mem = inputs[conv::kOut].GetMKLDNNDataReorder(
-          bwdWeights_pd.diff_dst_primitive_desc());
+          convBwdWeight.bwdWeights_pd.diff_dst_primitive_desc());
     auto data_mem = inputs[conv::kData + 1].GetMKLDNNDataReorder(
-        bwdWeights_pd.src_primitive_desc());
-    auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[conv::kWeight],
-                                                 bwdWeights_pd.diff_weights_primitive_desc(),
-                                                 req[conv::kWeight]);
+        convBwdWeight.bwdWeights_pd.src_primitive_desc());
+    auto in_grad_weight = CreateMKLDNNWeightGrad(
+        in_grad[conv::kWeight],
+        convBwdWeight.bwdWeights_pd.diff_weights_primitive_desc(),
+        req[conv::kWeight]);
     mkldnn_output_t in_grad_bias;
     if (param.no_bias) {
-      MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_weights(
-              bwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second));
+      convBwdWeight.SetWeightNewMem(*data_mem, *out_grad_mem,
+                              *in_grad_weight.second);
+      MKLDNNStream::Get()->RegisterPrim(convBwdWeight.GetBwdWeights());
     } else {
-      in_grad_bias = CreateMKLDNNMem(in_grad[conv::kBias],
-                                     bwdWeights_pd.diff_bias_primitive_desc(),
-                                     req[conv::kBias]);
-      MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_weights(
-              bwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second,
-              *in_grad_bias.second));
+      in_grad_bias = CreateMKLDNNMem(
+          in_grad[conv::kBias],
+          convBwdWeight.bwdWeights_pd.diff_bias_primitive_desc(), req[conv::kBias]);
+      convBwdWeight.SetWeightNewMem(*data_mem, *out_grad_mem,
+                              *in_grad_weight.second, *in_grad_bias.second);
+      MKLDNNStream::Get()->RegisterPrim(convBwdWeight.GetBwdWeights());
       CommitOutput(in_grad[conv::kBias], in_grad_bias);
     }
     CommitOutput(in_grad[conv::kWeight], in_grad_weight);
@@ -342,5 +494,4 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct
 
 }  // namespace op
 }  // namespace mxnet
-
 #endif  // MXNET_USE_MKLDNN == 1
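
On the caller side every backward op now follows the same sequence, shown here for the convolution data gradient (condensed from MKLDNNConvolutionBackward above; variable names are shortened for illustration):

    // 1. Look up, or build once, the cached backward primitives.
    MKLDNNConvBackward &convBwd =
        GetConvBwd(attrs, data, weights, nullptr, out_data, fwd_pd);
    // 2. Bring inputs/outputs into the layouts the cached descriptors expect.
    auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
        convBwd.bwdData_pd.diff_dst_primitive_desc());
    auto weight_mem = GetWeights(weights,
        convBwd.bwdData_pd.weights_primitive_desc(), param.num_group);
    auto in_grad_mem = CreateMKLDNNMem(in_grad_data,
        convBwd.bwdData_pd.diff_src_primitive_desc(), req_data);
    // 3. Repoint the cached primitive at this call's buffers.
    convBwd.SetDataNewMem(*out_grad_mem, *weight_mem, *in_grad_mem.second);
    // 4. Enqueue the cached primitive and commit the output.
    MKLDNNStream::Get()->RegisterPrim(convBwd.GetBwdData());
    CommitOutput(in_grad_data, in_grad_mem);
    // 5. Submit everything queued on the stream.
    MKLDNNStream::Get()->Submit();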
diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
index 7f3676a70dd..54d4f670852 100644
--- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
@@ -93,9 +93,9 @@ static mkldnn::convolution_backward_data::primitive_desc GetDeconvFwdImpl(
   return mkldnn::convolution_backward_data::primitive_desc(desc, engine, bwd_pd);
 }
 
-static mkldnn::convolution_forward::primitive_desc GetDeconvBwdData(
-    const DeconvolutionParam &param, const NDArray &data, const NDArray &weights,
-    bool has_bias, const NDArray &output) {
+static mkldnn::convolution_forward::primitive_desc GetDeconvBwdDataImpl(
+    const DeconvolutionParam &param, const NDArray &data,
+    const NDArray &weights, bool has_bias, const NDArray &output) {
   auto data_md = GetMemDesc(data);
   auto weight_md = GetWeightDesc(weights, param.num_group);
   auto out_md = GetMemDesc(output);
@@ -116,9 +116,10 @@ static mkldnn::convolution_forward::primitive_desc GetDeconvBwdData(
       strides, padding, dilate);
 }
 
-static mkldnn::convolution_backward_weights::primitive_desc GetDeconvBwdWeights(
-    const DeconvolutionParam& param, const NDArray &data, const NDArray &weights,
-    bool has_bias, const NDArray &output,
+static mkldnn::convolution_backward_weights::primitive_desc
+GetDeconvBwdWeightsImpl(
+    const DeconvolutionParam &param, const NDArray &data,
+    const NDArray &weights, bool has_bias, const NDArray &output,
     const mkldnn::convolution_forward::primitive_desc &fwd_pd) {
   auto data_md = GetMemDesc(data);
   auto weight_md = GetWeightDesc(weights, param.num_group);
@@ -308,55 +309,203 @@ void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &c
   MKLDNNDeconvFwdBiasPostProcess(param, ctx, in_data, out_data);
 }
 
-void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
-                                 const std::vector<NDArray>& inputs,
-                                 const std::vector<OpReqType>& req,
-                                 const std::vector<NDArray>& outputs) {
+class MKLDNNDeconvBackwardData {
+  std::shared_ptr<mkldnn::convolution_forward> bwd;
+  std::shared_ptr<mkldnn::memory> data;
+  std::shared_ptr<mkldnn::memory> weight;
+  std::shared_ptr<mkldnn::memory> out;
+
+ public:
+  const mkldnn::convolution_forward::primitive_desc pd;
+
+  MKLDNNDeconvBackwardData(const DeconvolutionParam &param, const NDArray &data,
+                           const NDArray &weights, const NDArray &output)
+      : pd(GetDeconvBwdDataImpl(param, data, weights, false, output)) {
+  }
+
+  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
+                 const mkldnn::memory &output) {
+    if (bwd == nullptr) {
+    this->data = std::shared_ptr<mkldnn::memory>(
+        new mkldnn::memory(pd.src_primitive_desc(), data.get_data_handle()));
+    this->weight = std::shared_ptr<mkldnn::memory>(
+        new mkldnn::memory(pd.weights_primitive_desc(), weight.get_data_handle()));
+    this->out = std::shared_ptr<mkldnn::memory>(
+        new mkldnn::memory(pd.dst_primitive_desc(), output.get_data_handle()));
+    bwd = std::shared_ptr<mkldnn::convolution_forward>(
+        new mkldnn::convolution_forward(pd, mkldnn::primitive::at(*this->data),
+                                        mkldnn::primitive::at(*this->weight),
+                                        *this->out));
+    } else {
+      this->data->set_data_handle(data.get_data_handle());
+      this->weight->set_data_handle(weight.get_data_handle());
+      this->out->set_data_handle(output.get_data_handle());
+    }
+  }
+
+  const mkldnn::convolution_forward &GetBwd() const { return *bwd; }
+};
+
+typedef ParamOpSign<DeconvolutionParam> MKLDNNDeconvSignature;
+
+static inline MKLDNNDeconvBackwardData &GetDeconvBwdData(
+    const DeconvolutionParam &param, const NDArray &data,
+    const NDArray &weights, const NDArray &output) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<MKLDNNDeconvSignature,
+                                         MKLDNNDeconvBackwardData, OpHash>
+      bwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<MKLDNNDeconvSignature,
+                                            MKLDNNDeconvBackwardData, OpHash>
+      bwds;
+#endif
+  MKLDNNDeconvSignature key(param);
+  // Here we can sign the conv op with NDArray because conv primitive will
+  // decide the right layout for them, so we only need to get the shape and the
+  // data type of the arrays.
+  key.AddSign(data);
+  key.AddSign(weights);
+  key.AddSign(output);
+
+  auto it = bwds.find(key);
+  if (it == bwds.end()) {
+    MKLDNNDeconvBackwardData bwd(param, data, weights, output);
+    auto ins_ret = bwds.insert(
+        std::pair<MKLDNNDeconvSignature, MKLDNNDeconvBackwardData>(key, bwd));
+    CHECK(ins_ret.second);
+    it = ins_ret.first;
+  }
+  return it->second;
+}
+
+class MKLDNNDeconvBackwardWeights {
+  std::shared_ptr<mkldnn::convolution_backward_weights> bwd;
+  std::shared_ptr<mkldnn::memory> data;
+  std::shared_ptr<mkldnn::memory> weight;
+  std::shared_ptr<mkldnn::memory> out;
+
+ public:
+  const mkldnn::convolution_backward_weights::primitive_desc pd;
+
+  MKLDNNDeconvBackwardWeights(
+      const DeconvolutionParam &param, const NDArray &data,
+      const NDArray &weights, const NDArray &output,
+      const mkldnn::convolution_forward::primitive_desc &bwd_data_pd)
+      : pd(GetDeconvBwdWeightsImpl(param, data, weights, false, output,
+                                   bwd_data_pd)) {}
+
+  void SetNewMem(
+      const mkldnn::memory &data, const mkldnn::memory &weight,
+      const mkldnn::memory &output,
+      const mkldnn::convolution_forward::primitive_desc &bwd_data_pd) {
+    if (bwd == nullptr) {
+      this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          bwd_data_pd.src_primitive_desc(), data.get_data_handle()));
+      this->weight = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          bwd_data_pd.weights_primitive_desc(), weight.get_data_handle()));
+      this->out = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          bwd_data_pd.dst_primitive_desc(), output.get_data_handle()));
+      bwd = std::shared_ptr<mkldnn::convolution_backward_weights>(
+          new mkldnn::convolution_backward_weights(pd, *this->data,
+                                                   *this->weight, *this->out));
+    } else {
+      this->data->set_data_handle(data.get_data_handle());
+      this->weight->set_data_handle(weight.get_data_handle());
+      this->out->set_data_handle(output.get_data_handle());
+    }
+  }
+
+  const mkldnn::convolution_backward_weights &GetBwd() const { return *bwd; }
+};
+
+static inline MKLDNNDeconvBackwardWeights &GetDeconvBwdWeights(
+    const DeconvolutionParam &param, const NDArray &data,
+    const NDArray &weights, const NDArray &output,
+    const mkldnn::convolution_forward::primitive_desc &bwd_data_pd) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<MKLDNNDeconvSignature,
+                                         MKLDNNDeconvBackwardWeights, OpHash>
+      bwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<MKLDNNDeconvSignature,
+                                            MKLDNNDeconvBackwardWeights, OpHash>
+      bwds;
+#endif
+  MKLDNNDeconvSignature key(param);
+  // Here we can sign the conv op with NDArray because conv primitive will
+  // decide the right layout for them, so we only need to get the shape and the
+  // data type of the arrays.
+  key.AddSign(data);
+  key.AddSign(weights);
+  key.AddSign(output);
+
+  auto it = bwds.find(key);
+  if (it == bwds.end()) {
+    MKLDNNDeconvBackwardWeights bwd(param, data, weights, output, bwd_data_pd);
+    auto ins_ret = bwds.insert(
+        std::pair<MKLDNNDeconvSignature, MKLDNNDeconvBackwardWeights>(key,
+                                                                      bwd));
+    CHECK(ins_ret.second);
+    it = ins_ret.first;
+  }
+  return it->second;
+}
+
+void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs &attrs,
+                                 const OpContext &ctx,
+                                 const std::vector<NDArray> &inputs,
+                                 const std::vector<OpReqType> &req,
+                                 const std::vector<NDArray> &outputs) {
   TmpMemMgr::Get()->Init(ctx.requested[deconv::kTempSpace]);
   const std::vector<NDArray> &in_grad = outputs;
-  const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);
-  CHECK_NE(req[deconv::kWeight], kWriteInplace) << "cannot write weight inplace";
-  mkldnn::convolution_forward::primitive_desc bwdData_pd = GetDeconvBwdData(
-      param, inputs[deconv::kData + 1], inputs[deconv::kWeight + 1], false,
-      inputs[deconv::kOut]);
+  const DeconvolutionParam &param = nnvm::get<DeconvolutionParam>(attrs.parsed);
+  CHECK_NE(req[deconv::kWeight], kWriteInplace)
+      << "cannot write weight inplace";
+  MKLDNNDeconvBackwardData &bwd_data =
+      GetDeconvBwdData(param, inputs[deconv::kData + 1],
+                       inputs[deconv::kWeight + 1], inputs[deconv::kOut]);
   auto out_grad_mem = inputs[deconv::kOut].GetMKLDNNDataReorder(
-      bwdData_pd.src_primitive_desc());
+      bwd_data.pd.src_primitive_desc());
   if (req[deconv::kData]) {
-    auto weight_mem = GetWeights(inputs[deconv::kWeight + 1],
-                                 bwdData_pd.weights_primitive_desc(),
-                                 param.num_group);
-    auto in_grad_mem = CreateMKLDNNMem(in_grad[deconv::kData],
-                                       bwdData_pd.dst_primitive_desc(),
-                                       req[deconv::kData]);
-    MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_forward(bwdData_pd,
-          *out_grad_mem, *weight_mem, *in_grad_mem.second));
+    auto weight_mem =
+        GetWeights(inputs[deconv::kWeight + 1],
+                   bwd_data.pd.weights_primitive_desc(), param.num_group);
+    auto in_grad_mem =
+        CreateMKLDNNMem(in_grad[deconv::kData],
+                        bwd_data.pd.dst_primitive_desc(), req[deconv::kData]);
+    bwd_data.SetNewMem(*out_grad_mem, *weight_mem, *in_grad_mem.second);
+    MKLDNNStream::Get()->RegisterPrim(bwd_data.GetBwd());
     CommitOutput(in_grad[deconv::kData], in_grad_mem);
   }
   if (req[deconv::kWeight]) {
-    mkldnn::convolution_backward_weights::primitive_desc bwdWeights_pd
-      = GetDeconvBwdWeights(param, inputs[deconv::kData + 1],
-          inputs[deconv::kWeight + 1], false, inputs[deconv::kOut], bwdData_pd);
-    if (bwdData_pd.src_primitive_desc() != bwdWeights_pd.src_primitive_desc())
+    MKLDNNDeconvBackwardWeights &bwd_weights = GetDeconvBwdWeights(
+        param, inputs[deconv::kData + 1], inputs[deconv::kWeight + 1],
+        inputs[deconv::kOut], bwd_data.pd);
+    if (bwd_data.pd.src_primitive_desc() != bwd_weights.pd.src_primitive_desc())
       out_grad_mem = inputs[deconv::kOut].GetMKLDNNDataReorder(
-          bwdWeights_pd.src_primitive_desc());
+          bwd_weights.pd.src_primitive_desc());
     auto data_mem = inputs[deconv::kData + 1].GetMKLDNNDataReorder(
-        bwdWeights_pd.diff_dst_primitive_desc());
-    auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[deconv::kWeight],
-                                                 bwdWeights_pd.diff_weights_primitive_desc(),
-                                                 req[deconv::kWeight]);
-    MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_weights(
-          bwdWeights_pd, *out_grad_mem, *data_mem, *in_grad_weight.second));
+        bwd_weights.pd.diff_dst_primitive_desc());
+    auto in_grad_weight = CreateMKLDNNWeightGrad(
+        in_grad[deconv::kWeight], bwd_weights.pd.diff_weights_primitive_desc(),
+        req[deconv::kWeight]);
+    bwd_weights.SetNewMem(*out_grad_mem, *data_mem, *in_grad_weight.second, bwd_data.pd);
+    MKLDNNStream::Get()->RegisterPrim(bwd_weights.GetBwd());
     CommitOutput(in_grad[deconv::kWeight], in_grad_weight);
   }
   MKLDNNStream::Get()->Submit();
   if (!param.no_bias) {
     typedef float DType;
     Stream<cpu> *s = ctx.get_stream<cpu>();
-    Tensor<cpu, 1, DType> gbias = in_grad[deconv::kBias].data().get<cpu, 1, DType>(s);
+    Tensor<cpu, 1, DType> gbias =
+        in_grad[deconv::kBias].data().get<cpu, 1, DType>(s);
     // If there is bias, the out grad has already been converted to the default
     // format, so this shouldn't cause any performance issues.
-    Tensor<cpu, 4, DType> grad = inputs[deconv::kOut].data().get<cpu, 4, DType>(s);
-    Assign(gbias, req[deconv::kBias], mshadow::expr::sumall_except_dim<1>(grad));
+    Tensor<cpu, 4, DType> grad =
+        inputs[deconv::kOut].data().get<cpu, 4, DType>(s);
+    Assign(gbias, req[deconv::kBias],
+           mshadow::expr::sumall_except_dim<1>(grad));
   }
 }
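
One detail worth calling out in the deconvolution hunks: MKLDNN deconvolution forward is expressed with convolution_backward_data (see GetDeconvFwdImpl above), so its data gradient is computed with a plain convolution_forward primitive and its weight gradient with convolution_backward_weights, which is why the two cached classes wrap those types. Schematically:

    // deconv_forward(x, w)          ~ conv_backward_data(x, w)
    // d(deconv)/d(data)   (dy, w)   ~ conv_forward(dy, w)            -> MKLDNNDeconvBackwardData
    // d(deconv)/d(weights)(x, dy)   ~ conv_backward_weights(x, dy)   -> MKLDNNDeconvBackwardWeights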
 
diff --git a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
index 4b179a7fbc9..bc386bedde1 100644
--- a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
@@ -18,7 +18,7 @@
  */
 
 /*!
- * \file mkldnn_lrn-inl.h 
+ * \file mkldnn_lrn-inl.h
  * \brief
  * \Author: Patric Zhao, patric.zhao@intel.com
 */
@@ -40,9 +40,8 @@ inline algorithm GetMKLDNNLRNAlgo(const LRNParam &param) {
   return algorithm::lrn_across_channels;
 }
 
-inline lrn_forward::primitive_desc GetLRNFwdDesc(const LRNParam &param,
-                                                 const bool is_train,
-                                                 const memory::desc &src_md) {
+inline mkldnn::lrn_forward::primitive_desc GetLRNFwdDesc(
+    const LRNParam &param, const bool is_train, const memory::desc &src_md) {
   mkldnn::engine &engine = CpuEngine::Get()->get_engine();
   const algorithm  alg = GetMKLDNNLRNAlgo(param);
   const float alpha = param.alpha;
@@ -59,11 +58,10 @@ inline lrn_forward::primitive_desc GetLRNFwdDesc(const LRNParam &param,
   return mkldnn::lrn_forward::primitive_desc(fwd_desc, engine);
 }
 
-inline mkldnn::lrn_backward::primitive_desc
-GetLRNBwd(const LRNParam &param,
-          const mkldnn::memory::desc &data_in_md,
-          const mkldnn::memory::desc &diff_md,
-          const lrn_forward::primitive_desc &lrnFwd_desc) {
+inline mkldnn::lrn_backward::primitive_desc GetLRNBwdDesc(
+    const LRNParam &param, const mkldnn::memory::desc &data_in_md,
+    const mkldnn::memory::desc &diff_md,
+    const mkldnn::lrn_forward::primitive_desc &lrnFwd_desc) {
   mkldnn::engine &engine = CpuEngine::Get()->get_engine();
   const algorithm alg = GetMKLDNNLRNAlgo(param);
   const float alpha = param.alpha;
@@ -96,8 +94,15 @@ class MKLDNNLRNFwd {
                  const NDArray &output,
                  const OpReqType req);
 
+  void SetNewMem(const NDArray &in_data,
+                 const mkldnn::memory *out_mem);
+
   void Execute(const NDArray &out_data);
 
+  mkldnn::lrn_forward &GetFwd();
+
+  const mkldnn::memory *GetWs();
+
  private:
   std::shared_ptr<mkldnn::lrn_forward> fwd;
   std::shared_ptr<mkldnn::memory> in_mem;
@@ -113,15 +118,17 @@ class MKLDNNLRNFwd {
 void MKLDNNLRNFwd::_Init(const LRNParam &param,
                          bool is_train,
                          const NDArray &in_data) {
-  mkldnn::memory::desc in_data_md = in_data.GetMKLDNNData()->get_primitive_desc().desc();
-  lrn_forward::primitive_desc fwd_pd = GetLRNFwdDesc(param, is_train, in_data_md);
+  mkldnn::memory::desc in_data_md =
+      in_data.GetMKLDNNData()->get_primitive_desc().desc();
+  mkldnn::lrn_forward::primitive_desc fwd_pd =
+      GetLRNFwdDesc(param, is_train, in_data_md);
 
   this->in_mem.reset(new mkldnn::memory(in_data.GetMKLDNNData()
                      ->get_primitive_desc()));
   this->out_mem.reset(new mkldnn::memory(fwd_pd.dst_primitive_desc()));
   if (is_train) {
-    // If it's training, we have to create a workspace memory. Otherwise, MKLDNN
-    // will have segmentation fault.
+    // If it's training, we have to create a workspace memory. Otherwise,
+    // MKLDNN will have segmentation fault.
     ws_mem.reset(new mkldnn::memory(fwd_pd.workspace_primitive_desc()));
     this->fwd = std::shared_ptr<mkldnn::lrn_forward>(
         new mkldnn::lrn_forward(fwd_pd, mkldnn::primitive::at(*this->in_mem),
@@ -142,11 +149,22 @@ void MKLDNNLRNFwd::SetNewMem(const NDArray &in_data,
   this->out_mem->set_data_handle(output_mem_t.second->get_data_handle());
 }
 
+void MKLDNNLRNFwd::SetNewMem(const NDArray &in_data,
+                             const mkldnn::memory *out_mem) {
+  const mkldnn::memory *in_data_mem = in_data.GetMKLDNNData();
+  this->in_mem->set_data_handle(in_data_mem->get_data_handle());
+  this->out_mem->set_data_handle(out_mem->get_data_handle());
+}
+
 void MKLDNNLRNFwd::Execute(const NDArray &out_data) {
   MKLDNNStream::Get()->RegisterPrim(*(this->fwd));
   CommitOutput(out_data, output_mem_t);
   MKLDNNStream::Get()->Submit();
 }
+
+mkldnn::lrn_forward &MKLDNNLRNFwd::GetFwd() { return *this->fwd; }
+
+const mkldnn::memory *MKLDNNLRNFwd::GetWs() { return this->ws_mem.get(); }
 // End of LRN Class and its functions
 
 static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param,
@@ -161,16 +179,10 @@ static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param,
                                             MKLDNNLRNFwd,
                                             OpHash> lrn_fwds;
 #endif
-  auto alg_ = algorithm::lrn_across_channels;
-  auto kind_ = prop_kind::forward_training;
-  if (ctx.is_train) {
-    kind_ = prop_kind::forward_training;
-  } else {
-    kind_ = prop_kind::forward_scoring;
-  }
+  auto kind_ =
+      ctx.is_train ? prop_kind::forward_training : prop_kind::forward_scoring;
 
   MKLDNNLRNSignature key(param);
-  key.AddSign(alg_);
   key.AddSign(kind_);
   key.AddSign(in_data);
 
@@ -185,10 +197,8 @@ static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param,
   return it->second;
 }
 
-void MKLDNNLRNForward(const OpContext &ctx,
-                      const LRNParam &param,
-                      const NDArray &in_data,
-                      const OpReqType req,
+void MKLDNNLRNForward(const OpContext &ctx, const LRNParam &param,
+                      const NDArray &in_data, const OpReqType req,
                       const NDArray &out_data) {
   auto in_buffer = in_data;
   if (in_buffer.IsView() && in_buffer.IsMKLDNNData())
@@ -198,6 +208,90 @@ void MKLDNNLRNForward(const OpContext &ctx,
   fwd.Execute(out_data);
 }
 
+// LRN Backward Class
+class MKLDNNLRNBwd {
+  std::shared_ptr<mkldnn::lrn_backward> bwd;
+  std::shared_ptr<mkldnn::memory> in_data_mem;
+  std::shared_ptr<mkldnn::memory> diff_dst_mem;
+  std::shared_ptr<mkldnn::memory> ws_mem;
+  std::shared_ptr<mkldnn::memory> diff_src_mem;
+
+ public:
+  const mkldnn::lrn_forward::primitive_desc fwd_pd;
+  const mkldnn::lrn_backward::primitive_desc bwd_pd;
+
+  ~MKLDNNLRNBwd() {}
+
+  MKLDNNLRNBwd(const LRNParam &param, const mkldnn::memory::desc in_data_md,
+               const mkldnn::memory::desc diff_md)
+      : fwd_pd(GetLRNFwdDesc(param, true, in_data_md)),
+        bwd_pd(GetLRNBwdDesc(param, in_data_md, diff_md, this->fwd_pd)) {}
+
+  void SetNewMem(const NDArray &in_data, const NDArray &out_grad,
+                 const mkldnn::memory *ws, const mkldnn::memory *diff_src_mem) {
+    if (bwd == nullptr) {
+      this->in_data_mem.reset(
+          new mkldnn::memory(this->fwd_pd.src_primitive_desc(),
+                             in_data.GetMKLDNNData()->get_data_handle()));
+      this->diff_dst_mem.reset(
+          new mkldnn::memory(this->fwd_pd.dst_primitive_desc(),
+                             out_grad.GetMKLDNNData()->get_data_handle()));
+      this->ws_mem.reset(
+          new mkldnn::memory(this->fwd_pd.workspace_primitive_desc(),
+                             ws->get_data_handle()));
+      this->diff_src_mem.reset(
+          new mkldnn::memory(this->bwd_pd.diff_src_primitive_desc(),
+                             diff_src_mem->get_data_handle()));
+      this->bwd.reset(new mkldnn::lrn_backward(
+          this->bwd_pd, mkldnn::primitive::at(*this->in_data_mem),
+          mkldnn::primitive::at(*this->diff_dst_mem), *this->ws_mem,
+          *this->diff_src_mem));
+    } else {
+      this->in_data_mem->set_data_handle(
+          in_data.GetMKLDNNData()->get_data_handle());
+      this->diff_dst_mem->set_data_handle(
+          out_grad.GetMKLDNNData()->get_data_handle());
+      this->ws_mem->set_data_handle(ws->get_data_handle());
+      this->diff_src_mem->set_data_handle(diff_src_mem->get_data_handle());
+    }
+  }
+
+  void Execute(const NDArray &in_grad, const mkldnn_output_t &diff_src_mem_) {
+    MKLDNNStream::Get()->RegisterPrim(*(this->bwd));
+    CommitOutput(in_grad, diff_src_mem_);
+    MKLDNNStream::Get()->Submit();
+  }
+};  // End of LRN Class
+
+static MKLDNNLRNBwd &GetLRNBwd(const LRNParam &param, const NDArray &in_data,
+                               const NDArray &in_grad, const NDArray &out_grad) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local
+      std::unordered_map<MKLDNNLRNSignature, MKLDNNLRNBwd, OpHash> lrn_bwds;
+#else
+  static MX_THREAD_LOCAL
+      std::unordered_map<MKLDNNLRNSignature, MKLDNNLRNBwd, OpHash> lrn_bwds;
+#endif
+  MKLDNNLRNSignature key(param);
+  key.AddSign(in_data);
+  key.AddSign(in_grad);
+  key.AddSign(out_grad);
+
+  auto it = lrn_bwds.find(key);
+  if (it == lrn_bwds.end()) {
+    const mkldnn::memory::desc in_data_md =
+        in_data.GetMKLDNNData()->get_primitive_desc().desc();
+    const mkldnn::memory::desc diff_md =
+        out_grad.GetMKLDNNData()->get_primitive_desc().desc();
+    MKLDNNLRNBwd bwd(param, in_data_md, diff_md);
+    auto ins_ret =
+        lrn_bwds.insert(std::pair<MKLDNNLRNSignature, MKLDNNLRNBwd>(key, bwd));
+    CHECK(ins_ret.second);
+    it = ins_ret.first;
+  }
+  return it->second;
+}
+
 void MKLDNNLRNBackward(const OpContext &ctx, const LRNParam &param,
                        const NDArray &out_grad,
                        const NDArray &in_data,
@@ -206,43 +300,27 @@ void MKLDNNLRNBackward(const OpContext &ctx, const LRNParam &param,
   if (req == kNullOp) {
     return;
   }
-
   // TODO(alex): (MXNET-846) figure out why in_grad output is incorrect when in_data is nchw8c
   auto in_buffer = in_data;
   if (in_buffer.IsMKLDNNData()) {
     in_buffer = in_data.Reorder2Default();
   }
-
+  MKLDNNLRNBwd &bwd = GetLRNBwd(param, in_buffer, in_grad, out_grad);
   // Repeat FW for getting workspace
-  const mkldnn::memory *data_mem = in_buffer.GetMKLDNNData();
-  const mkldnn::memory::desc data_md = data_mem->get_primitive_desc().desc();
-  const lrn_forward::primitive_desc pdesc_fwd = GetLRNFwdDesc(param, ctx.is_train,
-                                                              data_md);
-
   // TODO(Patric): To keep the function stateless, we can't pass workspace
   //               from LRN forward to backward. We have to re-compute
   //               LRN forward to get the workspace.
   //               Will refine this code later.
-  std::shared_ptr<const mkldnn::memory> ws_mem(
-          new mkldnn::memory(pdesc_fwd.workspace_primitive_desc()));
+  MKLDNNLRNFwd fwd = GetLRNFwd(param, ctx, in_buffer);
   std::shared_ptr<const mkldnn::memory> dst_temp(
-          new mkldnn::memory(pdesc_fwd.dst_primitive_desc()));
-  MKLDNNStream::Get()->RegisterPrim(
-          lrn_forward(pdesc_fwd, mkldnn::primitive::at(*data_mem),
-          *ws_mem, *dst_temp));
-
-  const mkldnn::memory *diff_mem = out_grad.GetMKLDNNData();
-  const mkldnn::memory::desc diff_md = diff_mem->get_primitive_desc().desc();
-  const mkldnn::lrn_backward::primitive_desc pdesc_bwd = GetLRNBwd(param, data_md,
-                                                                   diff_md, pdesc_fwd);
-  mkldnn_output_t diff_src_mem = CreateMKLDNNMem(in_grad,
-                                                 pdesc_bwd.diff_src_primitive_desc(), req);
-
-  MKLDNNStream::Get()->RegisterPrim(
-        lrn_backward(pdesc_bwd, mkldnn::primitive::at(*data_mem),
-        mkldnn::primitive::at(*diff_mem), *ws_mem, *diff_src_mem.second));
-  CommitOutput(in_grad, diff_src_mem);
-  MKLDNNStream::Get()->Submit();
+      new mkldnn::memory(bwd.fwd_pd.dst_primitive_desc()));
+  fwd.SetNewMem(in_buffer, dst_temp.get());
+  MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
+
+  mkldnn_output_t diff_src_mem =
+      CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_primitive_desc(), req);
+  bwd.SetNewMem(in_buffer, out_grad, fwd.GetWs(), diff_src_mem.second);
+  bwd.Execute(in_grad, diff_src_mem);
 }
 }  // namespace op
 }  // namespace mxnet
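
GetLRNBwd above is the heart of this change: instead of rebuilding the MKLDNN backward primitive on every call, the primitive is looked up in a thread-local hash map keyed by a signature derived from the operator parameters and the arrays involved. A minimal, self-contained sketch of that caching pattern follows; Signature, SignatureHash, BwdPrimitive, and GetCachedBwd are illustrative placeholders, not the actual MXNet types.

#include <cstddef>
#include <unordered_map>
#include <vector>

// Stand-ins for MKLDNNLRNSignature / OpHash / MKLDNNLRNBwd (names are hypothetical).
struct Signature {
  std::vector<size_t> fields;  // op params plus shapes/dtypes of the arrays
  bool operator==(const Signature &o) const { return fields == o.fields; }
};
struct SignatureHash {
  size_t operator()(const Signature &s) const {
    size_t h = 0;
    for (size_t v : s.fields) h = h * 31 + v;  // simple hash combine
    return h;
  }
};
struct BwdPrimitive {};  // would hold the prepared backward primitive

BwdPrimitive CreateBwdPrimitive(const Signature &) { return BwdPrimitive(); }

// One cache per thread: no locking on the hot path, and each worker thread
// keeps its own primitives.
BwdPrimitive &GetCachedBwd(const Signature &key) {
  static thread_local std::unordered_map<Signature, BwdPrimitive, SignatureHash> cache;
  auto it = cache.find(key);
  if (it == cache.end()) {
    // First call with this parameter/shape combination: build and remember it.
    it = cache.emplace(key, CreateBwdPrimitive(key)).first;
  }
  return it->second;
}

Repeated backward passes with the same shapes then hit the cache and skip descriptor and primitive construction entirely.
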
diff --git a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h
index 5d349d37202..66679613d3a 100644
--- a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h
@@ -80,6 +80,27 @@ class MKLDNNPoolingFwd {
             const int padding_l, const int padding_r);
 };
 
+class MKLDNNPoolingBwd {
+  std::shared_ptr<const mkldnn::pooling_backward> bwd;
+  std::shared_ptr<mkldnn::memory> diff_dst;
+  std::shared_ptr<mkldnn::memory> diff_src;
+  std::shared_ptr<mkldnn::memory> ws;
+  bool with_workspace;
+
+ public:
+  const mkldnn::pooling_backward::primitive_desc pd;
+
+  MKLDNNPoolingBwd(const pooling_backward::primitive_desc &pdesc,
+                   bool with_ws);
+
+  ~MKLDNNPoolingBwd() {}
+  void SetNewMem(const mxnet::NDArray *workspace,
+                 const mxnet::NDArray &out_grad,
+                 const mkldnn::memory *diff_src_mem);
+  const mkldnn::pooling_backward &GetBwd();
+  const mkldnn::pooling_backward::primitive_desc &GetPd();
+};
+
 inline bool SupportMKLDNNPooling(const PoolingParam &param) {
   return param.kernel.ndim() == 2 &&
          (param.pool_type == pool_enum::kMaxPooling ||
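
The new MKLDNNPoolingBwd class, like MKLDNNLRNBwd above, separates primitive construction from buffer binding: the first SetNewMem call wraps the buffers in mkldnn::memory objects and builds the backward primitive, while every later call only swaps the underlying data handles so the cached primitive can be reused. A rough sketch of that create-once / rebind-later idea, using placeholder Memory, Primitive, and PoolingBwd types rather than the real MKLDNN classes:

#include <memory>

// Hypothetical stand-ins for mkldnn::memory and mkldnn::pooling_backward.
struct Memory {
  void *handle;
  explicit Memory(void *h) : handle(h) {}
  void set_data_handle(void *h) { handle = h; }
};
struct Primitive {
  Memory *diff_dst, *diff_src;
  Primitive(Memory *dd, Memory *ds) : diff_dst(dd), diff_src(ds) {}
};

class PoolingBwd {
  std::shared_ptr<Primitive> bwd;
  std::shared_ptr<Memory> diff_dst, diff_src;

 public:
  // Bind (or re-bind) the buffers the cached primitive reads from and writes to.
  void SetNewMem(void *out_grad_data, void *in_grad_data) {
    if (bwd == nullptr) {
      // First use: wrap the buffers and build the primitive exactly once.
      diff_dst = std::make_shared<Memory>(out_grad_data);
      diff_src = std::make_shared<Memory>(in_grad_data);
      bwd = std::make_shared<Primitive>(diff_dst.get(), diff_src.get());
    } else {
      // Cached path: the primitive stays, only the pointers change.
      diff_dst->set_data_handle(out_grad_data);
      diff_src->set_data_handle(in_grad_data);
    }
  }
  const Primitive &GetBwd() const { return *bwd; }
};
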
diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc
index d8d65badc1c..1610944304e 100644
--- a/src/operator/nn/mkldnn/mkldnn_pooling.cc
+++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc
@@ -134,10 +134,9 @@ mkldnn::algorithm GetMKLDNNPoolAlgo(const PoolingParam &param) {
   }
 }
 
-mkldnn::pooling_forward::primitive_desc GetPoolingFwd(const PoolingParam &param,
-                                                      const bool is_train,
-                                                      const memory::desc &data_md,
-                                                      const memory::desc &out_md) {
+mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc(
+    const PoolingParam &param, const bool is_train, const memory::desc &data_md,
+    const memory::desc &out_md) {
   CHECK_EQ(param.kernel.ndim(), 2) << "Not Implemented";
   int kernel_h_, kernel_w_;
   if (param.global_pool) {
@@ -255,11 +254,124 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam &param,
 void MKLDNNPoolingCompute(const OpContext &ctx, const PoolingParam &param,
                           const NDArray &in_data, const OpReqType req,
                           const NDArray &out_data, const NDArray *workspace) {
-  auto fwd = GetPoolingFwd(param, ctx.is_train, in_data, out_data);
+  auto &fwd = GetPoolingFwd(param, ctx.is_train, in_data, out_data);
   fwd.SetNewMem(in_data, out_data, req, workspace);
   fwd.Execute(out_data);
 }
 
+MKLDNNPoolingBwd::MKLDNNPoolingBwd(
+    const pooling_backward::primitive_desc &pdesc, bool with_ws)
+    : with_workspace(with_ws), pd(pdesc) {}
+
+void MKLDNNPoolingBwd::SetNewMem(const mxnet::NDArray *workspace,
+                                 const mxnet::NDArray &out_grad,
+                                 const mkldnn::memory *diff_src_mem) {
+  if (bwd == nullptr) {
+    diff_dst.reset(
+        new mkldnn::memory(out_grad.GetMKLDNNData()->get_primitive_desc(),
+                           out_grad.GetMKLDNNData()->get_data_handle()));
+    diff_src.reset(new mkldnn::memory(pd.diff_src_primitive_desc(),
+                                      diff_src_mem->get_data_handle()));
+    if (with_workspace) {
+      CHECK(workspace != nullptr);
+      ws.reset(
+          new mkldnn::memory(workspace->GetMKLDNNData()->get_primitive_desc(),
+                             workspace->GetMKLDNNData()->get_data_handle()));
+      bwd.reset(
+          new pooling_backward(pd, *diff_dst, primitive::at(*ws), *diff_src));
+    } else {
+      bwd.reset(new pooling_backward(pd, *diff_dst, *diff_src));
+    }
+  } else {
+    diff_dst->set_data_handle(out_grad.GetMKLDNNData()->get_data_handle());
+    diff_src->set_data_handle(diff_src_mem->get_data_handle());
+    if (with_workspace) {
+      CHECK(workspace != nullptr);
+      ws->set_data_handle(workspace->GetMKLDNNData()->get_data_handle());
+    }
+  }
+}
+
+const mkldnn::pooling_backward &MKLDNNPoolingBwd::GetBwd() {
+  return *this->bwd;
+}
+
+MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam &param,
+                                const NDArray &in_data,
+                                const NDArray &in_grad,
+                                const NDArray &out_grad) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local
+      std::unordered_map<MKLDNNPoolingSignature,
+                         MKLDNNPoolingBwd, OpHash> pooling_bwds;
+#else
+  static MX_THREAD_LOCAL
+      std::unordered_map<MKLDNNPoolingSignature,
+                         MKLDNNPoolingBwd, OpHash> pooling_bwds;
+#endif
+
+  bool with_workspace = MKLDNNRequireWorkspace(param);
+  MKLDNNPoolingSignature key(param);
+  key.AddSign(in_data);
+  key.AddSign(in_grad);
+  key.AddSign(out_grad);
+
+  auto it = pooling_bwds.find(key);
+  if (it == pooling_bwds.end()) {
+    auto diff_dst_mem = out_grad.GetMKLDNNData();
+    auto input_mem = in_data.GetMKLDNNData();
+    mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc();
+    const mkldnn::memory::desc data_md = data_mpd.desc();
+    const memory::dims dims = {data_md.data.dims[0], data_md.data.dims[1],
+                               static_cast<int>(out_grad.shape()[2]),
+                               static_cast<int>(out_grad.shape()[3])};
+    const memory::desc out_md(
+        {dims}, static_cast<memory::data_type>(data_md.data.data_type),
+        static_cast<memory::format>(data_md.data.format));
+    auto fwd_pd = GetPoolingFwdPdesc(param, true, data_md, out_md);
+
+    const mkldnn::memory::desc diff_md =
+        diff_dst_mem->get_primitive_desc().desc();
+    const memory::dims dims1 = {diff_md.data.dims[0], diff_md.data.dims[1],
+                                static_cast<int>(in_grad.shape()[2]),
+                                static_cast<int>(in_grad.shape()[3])};
+    const memory::desc diff_in_md(
+        {dims1}, static_cast<memory::data_type>(diff_md.data.data_type),
+        static_cast<memory::format>(diff_md.data.format));
+    const mkldnn::engine cpu_engine = data_mpd.get_engine();
+    const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
+
+    int kernel_h_, kernel_w_;
+    if (param.global_pool) {
+      kernel_h_ = data_md.data.dims[2];
+      kernel_w_ = data_md.data.dims[3];
+    } else {
+      kernel_h_ = param.kernel[0];
+      kernel_w_ = param.kernel[1];
+    }
+
+    int pad_t_ = param.pad[0], pad_b_ = param.pad[0];
+    int pad_l_ = param.pad[1], pad_r_ = param.pad[1];
+    int stride_h_ = param.stride[0], stride_w_ = param.stride[1];
+    if (param.global_pool) {
+      pad_t_ = pad_b_ = pad_l_ = pad_r_ = 0;
+      stride_h_ = stride_w_ = 1;
+    }
+
+    const pooling_backward::desc desc(
+        alg, diff_in_md, diff_md, {stride_h_, stride_w_},
+        {kernel_h_, kernel_w_}, {pad_t_, pad_l_}, {pad_b_, pad_r_},
+        mkldnn::padding_kind::zero);
+    const auto pdesc = pooling_backward::primitive_desc(desc, cpu_engine, fwd_pd);
+    MKLDNNPoolingBwd bwd(pdesc, with_workspace);
+    auto ins_ret = pooling_bwds.insert(
+        std::pair<MKLDNNPoolingSignature, MKLDNNPoolingBwd>(key, bwd));
+    CHECK(ins_ret.second);
+    it = ins_ret.first;
+  }
+  return it->second;
+}
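
GetPoolingBwd re-derives the pooling geometry when it builds the cached descriptor: for global pooling the kernel is widened to the full spatial extent of the input while the strides collapse to 1 and the padding to 0, matching what the forward path does. A small stand-alone sketch of that derivation (PoolGeom and DeriveGeometry are illustrative names, not MXNet code):

#include <array>

struct PoolGeom {
  int kernel_h, kernel_w;
  int stride_h, stride_w;
  int pad_t, pad_b, pad_l, pad_r;
};

// A "global" pool is expressed as a single window spanning the whole input
// plane, applied once with no padding.
PoolGeom DeriveGeometry(bool global_pool, int in_h, int in_w,
                        std::array<int, 2> kernel, std::array<int, 2> stride,
                        std::array<int, 2> pad) {
  PoolGeom g;
  if (global_pool) {
    g.kernel_h = in_h;
    g.kernel_w = in_w;
    g.stride_h = g.stride_w = 1;
    g.pad_t = g.pad_b = g.pad_l = g.pad_r = 0;
  } else {
    g.kernel_h = kernel[0];
    g.kernel_w = kernel[1];
    g.stride_h = stride[0];
    g.stride_w = stride[1];
    g.pad_t = g.pad_b = pad[0];
    g.pad_l = g.pad_r = pad[1];
  }
  return g;
}
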
+
 void MKLDNNPoolingGradCompute(const OpContext &ctx, const PoolingParam &param,
                               const NDArray &out_grad, const NDArray &in_data,
                               const NDArray *workspace, const OpReqType req,
@@ -267,68 +379,14 @@ void MKLDNNPoolingGradCompute(const OpContext &ctx, const PoolingParam &param,
   if (req == kNullOp) {
     return;
   }
-
   TmpMemMgr::Get()->Init(ctx.requested[0]);
-  // mkldnn::memory
-  auto diff_dst_mem = out_grad.GetMKLDNNData();
-  auto input_mem = in_data.GetMKLDNNData();
-  mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc();
-  const mkldnn::memory::desc data_md = data_mpd.desc();
-  const memory::dims dims = {data_md.data.dims[0], data_md.data.dims[1],
-                             static_cast<int>(out_grad.shape()[2]),
-                             static_cast<int>(out_grad.shape()[3])};
-  const memory::desc out_md({dims},
-                            static_cast<memory::data_type>(data_md.data.data_type),
-                            static_cast<memory::format>(data_md.data.format));
-  auto pdesc_fwd = GetPoolingFwd(param, ctx.is_train, data_md, out_md);
-
-  const mkldnn::memory::desc diff_md = diff_dst_mem->get_primitive_desc().desc();
-  const memory::dims dims1 = {diff_md.data.dims[0], diff_md.data.dims[1],
-                              static_cast<int>(in_grad.shape()[2]),
-                              static_cast<int>(in_grad.shape()[3])};
-  const memory::desc diff_in_md(
-      {dims1}, static_cast<memory::data_type>(diff_md.data.data_type),
-      static_cast<memory::format>(diff_md.data.format));
-  const mkldnn::engine  cpu_engine = data_mpd.get_engine();
-  const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
-
-  int kernel_h_, kernel_w_;
-  if (param.global_pool) {
-    kernel_h_ = data_md.data.dims[2];
-    kernel_w_ = data_md.data.dims[3];
-  } else {
-    kernel_h_ = param.kernel[0];
-    kernel_w_ = param.kernel[1];
-  }
-
-  int pad_t_ = param.pad[0], pad_b_ = param.pad[0];
-  int pad_l_ = param.pad[1], pad_r_ = param.pad[1];
-  int stride_h_ = param.stride[0], stride_w_ = param.stride[1];
-  if (param.global_pool) {
-    pad_t_ = pad_b_ = pad_l_ = pad_r_ = 0;
-    stride_h_ = stride_w_ = 1;
-  }
-
-  const pooling_backward::desc desc(alg, diff_in_md, diff_md,
-                                    {stride_h_, stride_w_},
-                                    {kernel_h_, kernel_w_},
-                                    {pad_t_, pad_l_}, {pad_b_, pad_r_},
-                                    mkldnn::padding_kind::zero);
-  const pooling_backward::primitive_desc pdesc(desc, cpu_engine, pdesc_fwd);
 
+  auto &bwd = GetPoolingBwd(param, in_data, in_grad, out_grad);
   auto diff_src_mem =
-      CreateMKLDNNMem(in_grad, pdesc.diff_src_primitive_desc(), req);
-
-  if (MKLDNNRequireWorkspace(param)) {
-    CHECK(workspace != nullptr);
-    auto workspace_mem = workspace->GetMKLDNNData();
-    MKLDNNStream::Get()->RegisterPrim(
-        pooling_backward(pdesc, *diff_dst_mem, primitive::at(*workspace_mem),
-                         *diff_src_mem.second));
-  } else {
-    MKLDNNStream::Get()->RegisterPrim(
-        pooling_backward(pdesc, *diff_dst_mem, *diff_src_mem.second));
-  }
+      CreateMKLDNNMem(in_grad, bwd.pd.diff_src_primitive_desc(), req);
+
+  bwd.SetNewMem(workspace, out_grad, diff_src_mem.second);
+  MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd());
   CommitOutput(in_grad, diff_src_mem);
   MKLDNNStream::Get()->Submit();
 }
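
With the cache in place, MKLDNNPoolingGradCompute (and MKLDNNLRNBackward earlier in the diff) reduces to the same short sequence: look up the cached backward object, point it at this call's buffers, queue the primitive on the stream, and submit. Roughly, in terms of the placeholder types sketched above (Stream, GetPoolingBwdCached, and PoolingGradCompute are illustrative stand-ins; output allocation via CreateMKLDNNMem/CommitOutput is elided):

#include <vector>

// Illustrative placeholders; see the earlier sketches for the cache and rebind logic.
struct Primitive {};
struct PoolingBwd {
  Primitive prim;
  void SetNewMem(void *out_grad, void *in_grad) {
    (void)out_grad; (void)in_grad;  // the real code re-points mkldnn::memory handles here
  }
  const Primitive &GetBwd() const { return prim; }
};
PoolingBwd &GetPoolingBwdCached() {
  static thread_local PoolingBwd bwd;  // signature-keyed lookup omitted for brevity
  return bwd;
}

struct Stream {
  std::vector<const Primitive *> queued;
  void RegisterPrim(const Primitive &p) { queued.push_back(&p); }
  void Submit() { queued.clear(); }  // execute the queued primitives, then reset
};

void PoolingGradCompute(void *out_grad, void *in_grad, Stream *stream) {
  PoolingBwd &bwd = GetPoolingBwdCached();  // 1. cached lookup
  bwd.SetNewMem(out_grad, in_grad);         // 2. bind this call's buffers
  stream->RegisterPrim(bwd.GetBwd());       // 3. queue the primitive
  stream->Submit();                         // 4. run the queued work
}
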


 
