You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@mxnet.apache.org by pa...@apache.org on 2019/09/18 03:09:54 UTC

[incubator-mxnet] branch mkldnn-v1.0 updated: [mkldnn-v1.0] Add MKL-DNN Convolution (#16141)

This is an automated email from the ASF dual-hosted git repository.

patriczhao pushed a commit to branch mkldnn-v1.0
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/mkldnn-v1.0 by this push:
     new 1ff9429  [mkldnn-v1.0] Add MKL-DNN Convolution (#16141)
1ff9429 is described below

commit 1ff942948599dd248446fcb610b9fe0cc3070580
Author: rongzha1 <ro...@intel.com>
AuthorDate: Wed Sep 18 11:09:03 2019 +0800

    [mkldnn-v1.0] Add MKL-DNN Convolution (#16141)
    
    * add mkldnn conv
    
    * revert unnecessary change
    
    * fix test case failure on CPU: test_convolution_independent_gradients
    
    * fix failed test cases: test_reshape_transpose_6d and test_weight_async_reorder
    
    * fix comments
    
    * change variable name from weights to weight in mkldnn_conv
---
 include/mxnet/ndarray.h                         |   4 +-
 src/common/exec_utils.h                         |   8 +-
 src/executor/attach_op_execs_pass.cc            |   8 +-
 src/imperative/imperative_utils.h               |  20 +-
 src/ndarray/ndarray.cc                          |  16 +-
 src/operator/nn/convolution.cc                  |  28 +-
 src/operator/nn/mkldnn/mkldnn_base-inl.h        |  22 +
 src/operator/nn/mkldnn/mkldnn_base.cc           |   3 +-
 src/operator/nn/mkldnn/mkldnn_convolution-inl.h |  67 +--
 src/operator/nn/mkldnn/mkldnn_convolution.cc    | 539 ++++++++----------------
 src/operator/nn/mkldnn/mkldnn_ops-inl.h         |  21 +-
 src/operator/operator_common.h                  |  41 +-
 src/operator/tensor/cast_storage-inl.h          |   6 +-
 13 files changed, 349 insertions(+), 434 deletions(-)

diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index fc4375b..16bb7e4 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -761,8 +761,8 @@ class NDArray {
    * It changes the layout of this NDArray, but it happens after all accesses to
    * the array are complete.
    */
-  void Reorder2DefaultAsync();
-  void MKLDNNDataReorderAsync(const mkldnn::memory::desc &md);
+  void Reorder2DefaultAsync() const;
+  void MKLDNNDataReorderAsync(const mkldnn::memory::desc &md) const;
 
   /*
    * This creates a new NDArray with the reordered data.
diff --git a/src/common/exec_utils.h b/src/common/exec_utils.h
index d8b7a33..f0b29e7 100644
--- a/src/common/exec_utils.h
+++ b/src/common/exec_utils.h
@@ -59,7 +59,7 @@ inline bool SetupDefaultBlobsIn(const std::vector<NDArray>& src,
   for (size_t i = 0; i < src.size(); i++) {
     auto& nd = src[i];
     bool is_default = nd.storage_type() == kDefaultStorage;
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
     // We have to make sure it's default storage and default layout.
     is_default = nd.IsDefaultData();
 #endif
@@ -67,7 +67,7 @@ inline bool SetupDefaultBlobsIn(const std::vector<NDArray>& src,
       (*idx_map)[i] = temp_dst->size();
       NDArray temp = bufs != nullptr ? bufs->at(i) : NDArray(nd.shape(), nd.ctx(),
                                                              true, nd.dtype());
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
       CHECK(temp.IsDefaultData());
 #endif
       temp_src->emplace_back(nd);
@@ -91,7 +91,7 @@ inline bool SetupDefaultBlobsOut(const std::vector<NDArray>& src,
   for (size_t i = 0; i < src.size(); i++) {
     auto& nd = src[i];
     bool is_default = nd.storage_type() == kDefaultStorage;
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
     if (req->at(i) == kWriteInplace && nd.IsMKLDNNData())
       // If it's write inplace and the output array doesn't use the default
       // layout, we'll generate a temporary output array below, which means
@@ -102,7 +102,7 @@ inline bool SetupDefaultBlobsOut(const std::vector<NDArray>& src,
     is_default = nd.IsDefaultData();
 #endif
     if (!is_default) {
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
       NDArray temp;
       if (bufs != nullptr) {
         temp = bufs->at(i);
diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc
index 8f47bc2..ebd0328 100644
--- a/src/executor/attach_op_execs_pass.cc
+++ b/src/executor/attach_op_execs_pass.cc
@@ -116,7 +116,7 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor {
  public:
   void Run(RunContext rctx, bool is_gpu) override {
     op_ctx.run_ctx = rctx;
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
     InvalidateOutputs(out_array, req);
 #endif
     PreFCompute(is_gpu);
@@ -155,7 +155,7 @@ class StatefulComputeExExecutor : public OpExecutor {
  public:
   void Run(RunContext rctx, bool is_gpu) override {
     op_ctx.run_ctx = rctx;
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
     InvalidateOutputs(out_array, req);
     // TODO(alex): (MXNET-847) Remove this fallback feature after subgraph implemented
     const auto is_mkldnn = Op::GetAttr<bool>("TIsMKLDNN");
@@ -202,7 +202,7 @@ class FComputeExecutor : public StorageFallbackOpExecutor {
   void Run(RunContext rctx, bool is_gpu) override {
     using namespace common;
     op_ctx.run_ctx = rctx;
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
     InvalidateOutputs(out_array, req);
 #endif
     PreFCompute(is_gpu);
@@ -231,7 +231,7 @@ class FComputeExExecutor : public OpExecutor {
  public:
   void Run(RunContext rctx, bool is_gpu) override {
     op_ctx.run_ctx = rctx;
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
     InvalidateOutputs(out_array, req);
     // TODO(alex): (MXNET-847) Remove this fallback feature after subgraph implemented
     const auto is_mkldnn = Op::GetAttr<bool>("TIsMKLDNN");
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 21caafa..3a2875e 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -418,7 +418,7 @@ inline void PushFCompute(const FCompute& fn,
       std::vector<NDArray> pre_temp_src, pre_temp_dst, post_temp_dst, post_temp_src;
       // mapping from index in input_blobs to index in pre_temp_dst
       std::unordered_map<uint32_t, uint32_t> in_temp_idx_map;
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
       if (exec_type != ExecType::kCrossDeviceCopy) {
         // kCrossDeviceCopy is used for `_copy_to` operator, which doesn't compute immediately in
         // its FCcomputeEx, but AsyncPush the copy operation to engine.
@@ -467,7 +467,7 @@ inline void PushFComputeEx(const FComputeEx& fn,
   DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs);
   const auto& run = [=](RunContext rctx) {
       OpContext opctx{need_grad, is_train, rctx, engine::CallbackOnComplete(), requested};
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
       if (exec_type != ExecType::kCrossDeviceCopy) {
         // kCrossDeviceCopy is used for `_copy_to` operator, which doesn't compute immediately in
         // its FCcomputeEx, but AsyncPush the copy operation to engine.
@@ -476,8 +476,18 @@ inline void PushFComputeEx(const FComputeEx& fn,
         // copying A to B may not happen, and will corrupt A's memory.
         InvalidateOutputs(outputs, req);
       }
+      // add for mkldnn OP + no mkldnn OP
+      const auto is_mkldnn = Op::GetAttr<bool>("TIsMKLDNN");
+      if (!is_mkldnn.get(attrs.op, false)) {
+        std::vector<NDArray> inputs_fallback;
+        CreateDefaultInputs(inputs, &inputs_fallback);
+        fn(attrs, opctx, inputs_fallback, req, outputs);
+      } else {
+#endif
+        fn(attrs, opctx, inputs, req, outputs);
+#if MXNET_USE_MKLDNN == 100
+      }
 #endif
-      fn(attrs, opctx, inputs, req, outputs);
       if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync && !rctx.is_bulk) {
         rctx.get_stream<gpu>()->Wait();
       }
@@ -521,7 +531,7 @@ inline void PushOperator(const OpStatePtr& state,
     const auto& run = [=](RunContext rctx,
                           engine::CallbackOnComplete on_complete) {
       OpContext opctx{need_grad, is_train, rctx, on_complete, requested};
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
       if (exec_type != ExecType::kCrossDeviceCopy) {
         // kCrossDeviceCopy is used for `_copy_to` operator, which doesn't compute immediately in
         // its FCcomputeEx, but AsyncPush the copy operation to engine.
@@ -567,7 +577,7 @@ inline void PushOperator(const OpStatePtr& state,
         std::vector<NDArray> pre_temp_src, pre_temp_dst, post_temp_dst, post_temp_src;
         // mapping from index in input_blobs to index in pre_temp_dst
         std::unordered_map<uint32_t, uint32_t> in_temp_idx_map;
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
       if (exec_type != ExecType::kCrossDeviceCopy) {
         // kCrossDeviceCopy is used for `_copy_to` operator, which doesn't compute immediately in
         // its FCcomputeEx, but AsyncPush the copy operation to engine.
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 97daa29..f174d75 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -474,7 +474,7 @@ void NDArray::Chunk::SetMKLMem(const mxnet::TShape &shape, int dtype) {
 
   mkldnn::memory::dims dims;
   // These are shapes supprted by MKLDNN.
-  if (shape.ndim() >= 1 && shape.ndim() <= 5) {
+  if (shape.ndim() >= 1 && shape.ndim() <= 6) {
     dims.resize(shape.ndim());
     for (size_t i = 0; i < dims.size(); i++)
       dims[i] = shape[i];
@@ -488,6 +488,7 @@ void NDArray::Chunk::SetMKLMem(const mxnet::TShape &shape, int dtype) {
     case 3: layout = mkldnn::memory::format_tag::abc; break;
     case 4: layout = mkldnn::memory::format_tag::abcd; break;
     case 5: layout = mkldnn::memory::format_tag::abcde; break;
+    case 6: layout = mkldnn::memory::format_tag::abcdef; break;
     default:
       LOG(FATAL) << "Not implemented dimension (" << dims.size() << ") for MKLDNN";
   }
@@ -592,7 +593,7 @@ NDArray NDArray::Reorder2Default() const {
   return ret;
 }
 
-void NDArray::Reorder2DefaultAsync() {
+void NDArray::Reorder2DefaultAsync() const {
   std::vector<Engine::VarHandle> const_vars;
   std::vector<Engine::VarHandle> mutable_vars(1, this->var());
   NDArray tmp = *this;
@@ -604,13 +605,18 @@ void NDArray::Reorder2DefaultAsync() {
     FnProperty::kNormal, 0, "Reorder2Default");
 }
 
-void NDArray::MKLDNNDataReorderAsync(const mkldnn::memory::desc &desc) {
+void NDArray::MKLDNNDataReorderAsync(const mkldnn::memory::desc &desc) const {
   std::vector<Engine::VarHandle> const_vars;
   std::vector<Engine::VarHandle> mutable_vars(1, this->var());
   NDArray tmp = *this;
+  const auto version = this->version();
   Engine::Get()->PushAsync(
-    [tmp, desc](RunContext ctx, Engine::CallbackOnComplete on_complete) {
-      tmp.ptr_->MKLDNNDataReorder(desc);
+    [tmp, version, desc](RunContext ctx, Engine::CallbackOnComplete on_complete) {
+      // MXNet will try to reuse NDArray from memory planning, so we need to ensure
+      // the NDArray is still holding the original trunk data.
+      if (tmp.version() == version) {
+        tmp.ptr_->MKLDNNDataReorder(desc);
+      }
       on_complete();
     }, ctx(), const_vars, mutable_vars,
     FnProperty::kNormal, 0, "Reorder");
diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc
index 32ed93e..ad19128 100644
--- a/src/operator/nn/convolution.cc
+++ b/src/operator/nn/convolution.cc
@@ -30,7 +30,7 @@
 #if MXNET_USE_NNPACK == 1
 #include "../nnpack/nnpack_pooling-inl.h"
 #endif  // MXNET_USE_NNPACK
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include "./mkldnn/mkldnn_base-inl.h"
 #include "./mkldnn/mkldnn_ops-inl.h"
 #endif  // MXNET_USE_MKLDNN
@@ -51,7 +51,7 @@ static inline std::vector<std::string> ListArguments(const ConvolutionParam& par
   }
 }
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 static void ConvolutionComputeExCPU(const nnvm::NodeAttrs& attrs,
                                     const OpContext& ctx,
                                     const std::vector<NDArray>& inputs,
@@ -60,7 +60,12 @@ static void ConvolutionComputeExCPU(const nnvm::NodeAttrs& attrs,
   const ConvolutionParam& params = nnvm::get<ConvolutionParam>(attrs.parsed);
   if (SupportMKLDNNConv(params, inputs[0])) {
     MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
-    MKLDNNConvolutionForward(attrs, ctx, inputs, req, outputs);
+    if (CheckMKLDNNInputArrayIsView(inputs)) {
+      const auto mkldnn_inputs = GetMKLDNNInputArray(inputs);
+      MKLDNNConvolutionForward(attrs, ctx, mkldnn_inputs, req, outputs);
+    } else {
+      MKLDNNConvolutionForward(attrs, ctx, inputs, req, outputs);
+    }
     MKLDNN_OPCHECK_RUN(ConvolutionCompute<cpu>, attrs, ctx, inputs, req, outputs);
     return;
   }
@@ -75,7 +80,12 @@ static void ConvolutionGradComputeExCPU(const nnvm::NodeAttrs& attrs,
   const ConvolutionParam& params = nnvm::get<ConvolutionParam>(attrs.parsed);
   if (SupportMKLDNNConv(params, inputs[0])) {
     MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
-    MKLDNNConvolutionBackward(attrs, ctx, inputs, req, outputs);
+    if (CheckMKLDNNInputArrayIsView(inputs)) {
+      const auto mkldnn_inputs = GetMKLDNNInputArray(inputs);
+      MKLDNNConvolutionBackward(attrs, ctx, mkldnn_inputs, req, outputs);
+    } else {
+      MKLDNNConvolutionBackward(attrs, ctx, inputs, req, outputs);
+    }
     MKLDNN_OPCHECK_RUN(ConvolutionGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
     return;
   }
@@ -302,7 +312,7 @@ static bool ConvolutionType(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 inline static bool ConvStorageType(const nnvm::NodeAttrs& attrs,
                                    const int dev_mask,
                                    DispatchMode* dispatch_mode,
@@ -491,11 +501,11 @@ There are other options to tune the performance.
 })
 .set_attr<mxnet::FInferShape>("FInferShape", ConvolutionShape)
 .set_attr<nnvm::FInferType>("FInferType", ConvolutionType)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<FInferStorageType>("FInferStorageType", ConvStorageType)
 #endif
 .set_attr<FCompute>("FCompute<cpu>", ConvolutionCompute<cpu>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", ConvolutionComputeExCPU)
 #endif
@@ -514,14 +524,14 @@ NNVM_REGISTER_OP(_backward_Convolution)
   return params.no_bias ? 2 : 3;
 })
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<FInferStorageType>("FInferStorageType", BackwardConvStorageType)
 #endif
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
 })
 .set_attr_parser(ConvolutionParamParser)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", ConvolutionGradComputeExCPU)
 #endif
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 85d42ff..054f422 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -277,6 +277,28 @@ inline static mkldnn::memory::desc GetWeightDesc(const NDArray &arr,
   }
 }
 
+inline static bool CheckMKLDNNInputArrayIsView(const std::vector<NDArray> &inputs) {
+  for (const auto &in : inputs) {
+    if (in.IsView() && in.IsMKLDNNData()) {
+      return true;
+    }
+  }
+  return false;
+}
+
+inline static const std::vector<NDArray> GetMKLDNNInputArray(const std::vector<NDArray> &inputs) {
+  std::vector<NDArray> ret;
+  ret.reserve(inputs.size());
+  for (const auto &in : inputs) {
+    if (in.IsView() && in.IsMKLDNNData()) {
+      ret.push_back(in.Reorder2Default());
+    } else {
+      ret.push_back(in);
+    }
+  }
+  return ret;
+}
+
 typedef std::shared_ptr<mkldnn::memory> mkldnn_mem_ptr;
 typedef std::shared_ptr<const mkldnn::memory> mkldnn_mem_const_ptr;
 
diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc
index 31ffbbb..cfd7ad7 100644
--- a/src/operator/nn/mkldnn/mkldnn_base.cc
+++ b/src/operator/nn/mkldnn/mkldnn_base.cc
@@ -312,6 +312,7 @@ mkldnn_format_tag_t GetDefaultFormat(int num_dims) {
     case 3: return mkldnn_abc;
     case 4: return mkldnn_abcd;
     case 5: return mkldnn_abcde;
+    case 6: return mkldnn_abcdef;
     default:
       LOG(FATAL) << "Not implemented dimension (" << num_dims << ") for MKLDNN";
       return mkldnn_format_tag_undef;
@@ -530,7 +531,7 @@ bool MKLDNNStorageType(const nnvm::NodeAttrs &attrs,
     if (v == - 1) v = kDefaultStorage;
 
   DispatchMode wanted_mode;
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
   if (dev_mask == mshadow::cpu::kDevMask && !MKLDNNEnvSet())
     wanted_mode = DispatchMode::kFComputeFallback;
   else if (dev_mask == mshadow::cpu::kDevMask && support_mkldnn)
diff --git a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h
index 880b9d1..b3ceb95 100644
--- a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h
@@ -25,7 +25,7 @@
 #ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_CONVOLUTION_INL_H_
 #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_CONVOLUTION_INL_H_
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 
 #include <vector>
 #include <utility>
@@ -79,47 +79,26 @@ struct MKLDNNConvFullParam {
   MKLDNNPostEltwiseParam postsum_act_param;
 };
 
-mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullParam &param,
-                                                           const bool is_train,
-                                                           const NDArray &data,
-                                                           const NDArray &weights,
-                                                           const NDArray *bias,
-                                                           const NDArray &output);
+std::shared_ptr<mkldnn::convolution_forward::primitive_desc> GetConvFwdImpl(
+    const ConvolutionParam &param, const bool is_train, const NDArray &data, const NDArray &weight,
+    const NDArray *bias, const NDArray &output);
 
 class MKLDNNConvForward {
  public:
-  mkldnn::convolution_forward::primitive_desc fwd_pd;
-
   MKLDNNConvForward(const MKLDNNConvFullParam &param, const bool is_train, const NDArray &data,
-                    const NDArray &weights, const NDArray *bias, const NDArray &output);
+                    const NDArray &weight, const NDArray *bias, const NDArray &output);
 
-  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
-                 const mkldnn::memory *bias, const mkldnn::memory &output);
+  const mkldnn::convolution_forward &GetFwd() const { return *fwd_; }
 
-  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) {
-    this->data_->set_data_handle(data.get_data_handle());
-    this->out_->set_data_handle(output.get_data_handle());
-  }
-
-  const mkldnn::convolution_forward &GetFwd() const {
-    return *fwd_;
-  }
+  const mkldnn::convolution_forward::primitive_desc &GetPd() const { return *pd_; }
 
  private:
   std::shared_ptr<mkldnn::convolution_forward> fwd_;
-  std::shared_ptr<mkldnn::memory> data_;
-  std::shared_ptr<mkldnn::memory> weight_;
-  std::shared_ptr<mkldnn::memory> bias_;
-  std::shared_ptr<mkldnn::memory> out_;
+  std::shared_ptr<mkldnn::convolution_forward::primitive_desc> pd_;
 };
 
 typedef ParamOpSign<ConvolutionParam> MKLDNNConvSignature;
 
-MKLDNNConvForward &GetConvFwd(const ConvolutionParam &param,
-                              const bool is_train, const NDArray &data,
-                              const NDArray &weights, const NDArray *bias,
-                              const NDArray &output);
-
 void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam &param,
                                          const OpContext &ctx,
                                          MKLDNNConvForward *fwd,
@@ -127,6 +106,36 @@ void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam &param,
                                          const std::vector<OpReqType> &req,
                                          const std::vector<NDArray> &out_data);
 
+void MKLDNNConvolutionForward(const nnvm::NodeAttrs &attrs,
+                              const OpContext &ctx,
+                              const std::vector<NDArray> &in_data,
+                              const std::vector<OpReqType> &req,
+                              const std::vector<NDArray> &out_data);
+
+class MKLDNNConvBackward {
+ public:
+  MKLDNNConvBackward(const MKLDNNConvFullParam &param, const NDArray &data, const NDArray &weight,
+                     const NDArray *bias, const NDArray &output);
+
+  const mkldnn::convolution_backward_data &GetBwdData() const { return *bwd_data_; }
+
+  const mkldnn::convolution_backward_weights &GetBwdWeights() const { return *bwd_weight_; }
+
+  const mkldnn::convolution_backward_data::primitive_desc &GetDataPd() const {
+    return *bwd_data_pd_;
+  }
+
+  const mkldnn::convolution_backward_weights::primitive_desc &GetWeightsPd() const {
+    return *bwd_weight_pd_;
+  }
+
+ private:
+  std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc> bwd_data_pd_;
+  std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc> bwd_weight_pd_;
+  std::shared_ptr<mkldnn::convolution_backward_data> bwd_data_;
+  std::shared_ptr<mkldnn::convolution_backward_weights> bwd_weight_;
+};
+
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc
index 9cab2dd..4114188 100644
--- a/src/operator/nn/mkldnn/mkldnn_convolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -21,10 +21,9 @@
  * \file mkldnn_convolution.cc
  * \brief
  * \author Da Zheng
-*/
-
+ */
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 
 #include "../convolution-inl.h"
 #include "./mkldnn_ops-inl.h"
@@ -45,8 +44,10 @@ bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input) {
           (input.shape().ndim() == 4));
 }
 
-mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullParam &param,
-                                                           const bool is_train, const NDArray &data,
+std::shared_ptr<mkldnn::convolution_forward::primitive_desc> GetConvFwdImpl(
+                                                           const MKLDNNConvFullParam &param,
+                                                           const bool is_train,
+                                                           const NDArray &data,
                                                            const NDArray &weights,
                                                            const NDArray *bias,
                                                            const NDArray &output) {
@@ -57,7 +58,7 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullP
   auto bias_md =
       bias ? (param.mkldnn_param.quantized ? GetMemDesc(*bias, mshadow::kInt32) : GetMemDesc(*bias))
            : mkldnn::memory::desc{
-             {}, mkldnn::memory::data_type::data_undef, mkldnn::memory::format::any};
+             {}, mkldnn::memory::data_type::undef, mkldnn::memory::format_tag::any};
   auto bias_md_ptr = bias ? &bias_md : nullptr;
 
   mkldnn::memory::dims strides(param.conv_param.kernel.ndim());
@@ -98,19 +99,19 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullP
   if (param.mkldnn_param.quantized && param.requantize_scales.size()) {
     int mask = (param.requantize_scales.size() > 1) ? 2 : 0;
     attr.set_output_scales(mask, param.requantize_scales);
-    attr.set_int_output_round_mode(round_nearest);
   }
   auto GetConvFwdPd = [&param, &data, &weights, &output,
                        &attr](const mkldnn::convolution_forward::desc &desc) {
     auto engine = CpuEngine::Get()->get_engine();
     try {
-      auto conv_pd = mkldnn::convolution_forward::primitive_desc(desc, attr, engine);
-      while (conv_pd.dst_primitive_desc().get_size() != GetArraySize(output) ||
-             conv_pd.src_primitive_desc().get_size() != GetArraySize(data) ||
+      auto conv_pd =
+          std::make_shared<mkldnn::convolution_forward::primitive_desc>(desc, attr, engine);
+      while (conv_pd->dst_desc().get_size() != GetArraySize(output) ||
+             conv_pd->src_desc().get_size() != GetArraySize(data) ||
              (!param.mkldnn_param.quantized &&
-              conv_pd.weights_primitive_desc().get_size() != GetArraySize(weights))) {
+              conv_pd->weights_desc().get_size() != GetArraySize(weights))) {
         // next_impl() will visit desc and engine, please make sure they are still alive here.
-        CHECK(conv_pd.next_impl()) << "No convolution implementation for this request.";
+        CHECK(conv_pd->next_impl()) << "No convolution implementation for this request.";
       }
       return conv_pd;
     } catch (mkldnn::error &e) {
@@ -126,13 +127,12 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullP
 
   if (param.conv_param.dilate.ndim() == 0 && bias_md_ptr == nullptr) {
     mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md,
-                                           weight_md, out_md, strides, padding, padding,
-                                           mkldnn::padding_kind::zero);
+                                           weight_md, out_md, strides, padding, padding);
     return GetConvFwdPd(desc);
   } else if (param.conv_param.dilate.ndim() == 0) {
     mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md,
                                            weight_md, *bias_md_ptr, out_md, strides, padding,
-                                           padding, mkldnn::padding_kind::zero);
+                                           padding);
     return GetConvFwdPd(desc);
   } else {
     mkldnn::memory::dims dilates(param.conv_param.kernel.ndim());
@@ -147,23 +147,22 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullP
     }
     if (bias_md_ptr == nullptr) {
       mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md,
-                                             weight_md, out_md, strides, dilates, padding, padding,
-                                             mkldnn::padding_kind::zero);
+                                             weight_md, out_md, strides, dilates, padding, padding);
       return GetConvFwdPd(desc);
     } else {
       mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md,
                                              weight_md, *bias_md_ptr, out_md, strides, dilates,
-                                             padding, padding, mkldnn::padding_kind::zero);
+                                             padding, padding);
       return GetConvFwdPd(desc);
     }
   }
 }
 
-static mkldnn::convolution_backward_data::primitive_desc GetConvBwdData(
-    const ConvolutionParam& param, const NDArray &data, const NDArray &weights,
+static std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc> GetConvBwdData(
+    const ConvolutionParam &param, const NDArray &data, const NDArray &weight,
     const NDArray &output, const mkldnn::convolution_forward::primitive_desc &fwd_pd) {
   auto data_md = GetMemDesc(data);
-  auto weight_md = GetWeightDesc(weights, param.num_group);
+  auto weight_md = GetWeightDesc(weight, param.num_group);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
   mkldnn::memory::dims strides(param.kernel.ndim());
@@ -187,21 +186,29 @@ static mkldnn::convolution_backward_data::primitive_desc GetConvBwdData(
                << ", supporting only 1 or 2.";
   }
 
-  // MKL-DNN introduced padded formats since 0.15 which require more memory
-  // for computation compared with the actual tensor size. Currently, MKL-DNN
-  // operators are still reusing those memory from memory planning and the
-  // memory size may smaller than what MKL-DNN kernels require. So here we need
-  // select suboptimal kernel for computation according to tensor sizes.
-  if (param.dilate.ndim() == 0) {
-    mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct,
-        data_md, weight_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero);
-    auto conv_pd = mkldnn::convolution_backward_data::primitive_desc(desc, engine, fwd_pd);
-    while (conv_pd.diff_dst_primitive_desc().get_size() != GetArraySize(output) ||
-           conv_pd.diff_src_primitive_desc().get_size() != GetArraySize(data) ||
-           conv_pd.weights_primitive_desc().get_size() != GetArraySize(weights)) {
-      CHECK(conv_pd.next_impl()) << "No implementation";
+  auto GetConvBwdDataPd = [&data, &weight, &output,
+                           &fwd_pd](const mkldnn::convolution_backward_data::desc &desc) {
+    auto engine = CpuEngine::Get()->get_engine();
+    try {
+      auto conv_pd =
+          std::make_shared<mkldnn::convolution_backward_data::primitive_desc>(desc, engine, fwd_pd);
+      while (conv_pd->diff_dst_desc().get_size() != GetArraySize(output) ||
+             conv_pd->diff_src_desc().get_size() != GetArraySize(data) ||
+             conv_pd->weights_desc().get_size() != GetArraySize(weight)) {
+        // next_impl() will visit desc and engine, please make sure they are still alive here.
+        CHECK(conv_pd->next_impl()) << "No convolution backward implementation for this request.";
+      }
+      return conv_pd;
+    } catch (mkldnn::error &e) {
+      LOG(ERROR) << e.message;
+      throw;
     }
-    return conv_pd;
+  };
+
+  if (param.dilate.ndim() == 0) {
+    mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct, data_md,
+                                                 weight_md, out_md, strides, padding, padding);
+    return GetConvBwdDataPd(desc);
   } else {
     mkldnn::memory::dims dilates(param.kernel.ndim());
     if (param.dilate.ndim() == 1) {
@@ -213,25 +220,18 @@ static mkldnn::convolution_backward_data::primitive_desc GetConvBwdData(
       LOG(FATAL) << "Unexpected MKL-DNN Conv dilate size "
                  << param.dilate.ndim() << ", supporting only 1 or 2.";
     }
-    mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct,
-        data_md, weight_md, out_md, strides, dilates, padding, padding,
-        mkldnn::padding_kind::zero);
-    auto conv_pd = mkldnn::convolution_backward_data::primitive_desc(desc, engine, fwd_pd);
-    while (conv_pd.diff_dst_primitive_desc().get_size() != GetArraySize(output) ||
-           conv_pd.diff_src_primitive_desc().get_size() != GetArraySize(data) ||
-           conv_pd.weights_primitive_desc().get_size() != GetArraySize(weights)) {
-      CHECK(conv_pd.next_impl()) << "No implementation";
-    }
-    return conv_pd;
+    mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct, data_md,
+                                                 weight_md, out_md, strides, dilates, padding,
+                                                 padding);
+    return GetConvBwdDataPd(desc);
   }
 }
 
-static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights(
-    const ConvolutionParam& param, const NDArray &data,
-    const NDArray &weights, const NDArray *bias, const NDArray &output,
-    const mkldnn::convolution_forward::primitive_desc &fwd_pd) {
+static std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc> GetConvBwdWeights(
+    const ConvolutionParam &param, const NDArray &data, const NDArray &weight, const NDArray *bias,
+    const NDArray &output, const mkldnn::convolution_forward::primitive_desc &fwd_pd) {
   auto data_md = GetMemDesc(data);
-  auto weight_md = GetWeightDesc(weights, param.num_group);
+  auto weight_md = GetWeightDesc(weight, param.num_group);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
   mkldnn::memory::dims strides(param.kernel.ndim());
@@ -255,33 +255,35 @@ static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights(
                << ", supporting only 1 or 2.";
   }
 
-  // MKL-DNN introduced padded formats since 0.15 which require more memory
-  // for computation compared with the actual tensor size. Currently, MKL-DNN
-  // operators are still reusing those memory from memory planning and the
-  // memory size may smaller than what MKL-DNN kernels require. So here we need
-  // select suboptimal kernel for computation according to tensor sizes.
-  if (param.dilate.ndim() == 0 && bias == nullptr) {
-    mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct,
-        data_md, weight_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero);
-    auto conv_pd = mkldnn::convolution_backward_weights::primitive_desc(desc, engine, fwd_pd);
-    while (conv_pd.diff_dst_primitive_desc().get_size() != GetArraySize(output) ||
-           conv_pd.src_primitive_desc().get_size() != GetArraySize(data) ||
-           conv_pd.diff_weights_primitive_desc().get_size() != GetArraySize(weights)) {
-      CHECK(conv_pd.next_impl()) << "No implementation";
+  auto GetConvBwdWeightsPd = [&data, &weight, &output,
+                              &fwd_pd](const mkldnn::convolution_backward_weights::desc &desc) {
+    auto engine = CpuEngine::Get()->get_engine();
+    try {
+      auto conv_pd = std::make_shared<mkldnn::convolution_backward_weights::primitive_desc>(
+          desc, engine, fwd_pd);
+      while (conv_pd->diff_dst_desc().get_size() != GetArraySize(output) ||
+             conv_pd->src_desc().get_size() != GetArraySize(data) ||
+             conv_pd->diff_weights_desc().get_size() != GetArraySize(weight)) {
+        // next_impl() will visit desc and engine, please make sure they are still alive here.
+        CHECK(conv_pd->next_impl()) << "No convolution backward implementation for this request.";
+      }
+      return conv_pd;
+    } catch (const mkldnn::error &e) {
+      LOG(ERROR) << e.message;
+      throw;
     }
-    return conv_pd;
+  };
+
+  if (param.dilate.ndim() == 0 && bias == nullptr) {
+    mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct, data_md,
+                                                    weight_md, out_md, strides, padding, padding);
+    return GetConvBwdWeightsPd(desc);
   } else if (param.dilate.ndim() == 0) {
     auto bias_md = GetMemDesc(*bias);
-    mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct,
-        data_md, weight_md, bias_md, out_md, strides, padding, padding,
-        mkldnn::padding_kind::zero);
-    auto conv_pd = mkldnn::convolution_backward_weights::primitive_desc(desc, engine, fwd_pd);
-    while (conv_pd.diff_dst_primitive_desc().get_size() != GetArraySize(output) ||
-           conv_pd.src_primitive_desc().get_size() != GetArraySize(data) ||
-           conv_pd.diff_weights_primitive_desc().get_size() != GetArraySize(weights)) {
-      CHECK(conv_pd.next_impl()) << "No implementation";
-    }
-    return conv_pd;
+    mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct, data_md,
+                                                    weight_md, bias_md, out_md, strides, padding,
+                                                    padding);
+    return GetConvBwdWeightsPd(desc);
   } else {
     mkldnn::memory::dims dilates(param.kernel.ndim());
     if (param.dilate.ndim() == 1) {
@@ -295,313 +297,154 @@ static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights(
     }
     if (bias == nullptr) {
       mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct,
-          data_md, weight_md, out_md, strides, dilates, padding, padding,
-          mkldnn::padding_kind::zero);
-      auto conv_pd = mkldnn::convolution_backward_weights::primitive_desc(desc, engine, fwd_pd);
-      while (conv_pd.diff_dst_primitive_desc().get_size() != GetArraySize(output) ||
-             conv_pd.src_primitive_desc().get_size() != GetArraySize(data) ||
-             conv_pd.diff_weights_primitive_desc().get_size() != GetArraySize(weights)) {
-        CHECK(conv_pd.next_impl()) << "No implementation";
-      }
-      return conv_pd;
+                                                      data_md, weight_md, out_md, strides, dilates,
+                                                      padding, padding);
+      return GetConvBwdWeightsPd(desc);
     } else {
       auto bias_md = GetMemDesc(*bias);
       mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct,
-                                                      data_md, weight_md, bias_md, out_md,
-                                                      strides, dilates, padding, padding,
-                                                      mkldnn::padding_kind::zero);
-      auto conv_pd = mkldnn::convolution_backward_weights::primitive_desc(desc, engine, fwd_pd);
-      while (conv_pd.diff_dst_primitive_desc().get_size() != GetArraySize(output) ||
-             conv_pd.src_primitive_desc().get_size() != GetArraySize(data) ||
-             conv_pd.diff_weights_primitive_desc().get_size() != GetArraySize(weights)) {
-        CHECK(conv_pd.next_impl()) << "No implementation";
-      }
-      return conv_pd;
+                                                      data_md, weight_md, bias_md, out_md, strides,
+                                                      dilates, padding, padding);
+      return GetConvBwdWeightsPd(desc);
     }
   }
 }
 
 MKLDNNConvForward::MKLDNNConvForward(const MKLDNNConvFullParam &param, const bool is_train,
-                                     const NDArray &data, const NDArray &weights,
+                                     const NDArray &data, const NDArray &weight,
                                      const NDArray *bias, const NDArray &output)
-    : fwd_pd(GetConvFwdImpl(param, is_train, data, weights, bias, output)) {
-  data_ = std::make_shared<mkldnn::memory>(fwd_pd.src_primitive_desc(), nullptr);
-  weight_ = std::make_shared<mkldnn::memory>(fwd_pd.weights_primitive_desc(), nullptr);
-  out_ = std::make_shared<mkldnn::memory>(fwd_pd.dst_primitive_desc(), nullptr);
-  if (bias) {
-    bias_ = std::make_shared<mkldnn::memory>(fwd_pd.bias_primitive_desc(), nullptr);
-    fwd_ = std::make_shared<mkldnn::convolution_forward>(fwd_pd, *this->data_, *this->weight_,
-                                                         *this->bias_, *this->out_);
-  } else {
-    fwd_ = std::make_shared<mkldnn::convolution_forward>(fwd_pd, *this->data_, *this->weight_,
-                                                         *this->out_);
-  }
+    : pd_(GetConvFwdImpl(param, is_train, data, weight, bias, output)) {
+  fwd_ = std::make_shared<mkldnn::convolution_forward>(GetPd());
 }
 
-void MKLDNNConvForward::SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
-                                  const mkldnn::memory *bias, const mkldnn::memory &output) {
-  data_->set_data_handle(data.get_data_handle());
-  weight_->set_data_handle(weight.get_data_handle());
-  out_->set_data_handle(output.get_data_handle());
-  if (bias != nullptr) bias_->set_data_handle(bias->get_data_handle());
-}
-
-MKLDNNConvForward &GetConvFwd(const ConvolutionParam &param,
-                              const bool is_train, const NDArray &data,
-                              const NDArray &weights, const NDArray *bias,
+MKLDNNConvForward &GetConvFwd(const MKLDNNConvFullParam &param, const bool is_train,
+                              const NDArray &data, const NDArray &weight, const NDArray *bias,
                               const NDArray &output) {
+  using conv_fwd_map = std::unordered_map<MKLDNNConvSignature, MKLDNNConvForward, OpHash>;
 #if DMLC_CXX11_THREAD_LOCAL
-  static thread_local std::unordered_map<MKLDNNConvSignature, MKLDNNConvForward, OpHash> fwds;
+  static thread_local conv_fwd_map fwds;
 #else
-  static MX_THREAD_LOCAL std::unordered_map<MKLDNNConvSignature, MKLDNNConvForward, OpHash> fwds;
+  static MX_THREAD_LOCAL conv_fwd_map fwds;
 #endif
-  MKLDNNConvSignature key(param);
+  // TODO(zhennan): Only conv_param is hashed into the key; mkldnn_param must be hashed as well
+  // before this cache can be shared with fused conv.
+  MKLDNNConvSignature key(param.conv_param);
   key.AddSign(is_train);
-  // Here we can sign the conv op with NDArray because conv primitive will
-  // decide the right layout for the, so we only need to get the shape and the
-  // data type of the arrays.
+  // Here we can sign the conv op with NDArray because conv primitive will decide the right layout
+  // for them, so we only need to get the shape and the data type of the arrays.
   key.AddSign(data);
-  key.AddSign(weights);
+  key.AddSign(weight);
   key.AddSign(output);
-  if (bias)
-    key.AddSign(*bias);
+  if (bias) key.AddSign(*bias);
 
   auto it = fwds.find(key);
   if (it == fwds.end()) {
-    MKLDNNConvFullParam full_param;
-    full_param.conv_param = param;
-    full_param.mkldnn_param.Init(std::unordered_map<std::string, std::string>());
-    MKLDNNConvForward fwd(full_param, is_train, data, weights, bias, output);
+    auto fwd = MKLDNNConvForward(param, is_train, data, weight, bias, output);
     it = AddToCache(&fwds, key, fwd);
   }
   return it->second;
 }
 
-void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam &param,
-                                         const OpContext &ctx,
+void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam &param, const OpContext &ctx,
                                          MKLDNNConvForward *fwd,
                                          const std::vector<NDArray> &in_data,
                                          const std::vector<OpReqType> &req,
                                          const std::vector<NDArray> &out_data) {
   TmpMemMgr::Get()->Init(ctx.requested[conv::kTempSpace]);
 
-  auto data = in_data[conv::kData];
-  if (data.IsView() && data.IsMKLDNNData())
-    data = data.Reorder2Default();
-
-  auto weight = in_data[conv::kWeight];
-  if (weight.IsView() && weight.IsMKLDNNData())
-    weight = weight.Reorder2Default();
-
+  auto &data = in_data[conv::kData];
+  auto &weight = in_data[conv::kWeight];
   bool no_bias = param.conv_param.no_bias && !param.mkldnn_param.with_bn;
 
-  auto data_mem = data.GetMKLDNNDataReorder(
-      fwd->fwd_pd.src_primitive_desc());
+  auto data_mem = data.GetMKLDNNDataReorder(fwd->GetPd().src_desc());
   const mkldnn::memory *weight_mem;
   if (ctx.is_train) {
-    // TODO(zhengda) kvstore doesn't handle MKLDNN correctly. Let's reorder it
-    // to the default format for now.
+    // TODO(zhengda) kvstore doesn't handle MKLDNN correctly. Let's reorder it to the default format
+    // for now.
     if (weight.IsMKLDNNData())
-      // This asks the engine to change the layout of the weight array after
-      // it's used.
+      // This asks the engine to change the layout of the weight array after it's used.
       weight.Reorder2DefaultAsync();
-    weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(),
-                            param.conv_param.num_group);
+    weight_mem = GetWeights(weight, fwd->GetPd().weights_desc(), param.conv_param.num_group);
   } else {
-    // For inference, we want to reorder the weight array so we don't need to
-    // reorder data every time.
+    // For inference, reorder the weight array in place once so it doesn't have to be reordered on
+    // every call.
     if (weight.IsDefaultData()) {
-      // We also need to modify the layout on the original weight array. The
-      // data conversion happens after the weight array is used.
-      weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_primitive_desc());
-      weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(),
-                              param.conv_param.num_group);
-
+      // We also need to modify the layout on the original weight array. The data conversion happens
+      // after the weight array is used.
+      weight.MKLDNNDataReorderAsync(fwd->GetPd().weights_desc());
+      weight_mem = GetWeights(weight, fwd->GetPd().weights_desc(), param.conv_param.num_group);
     } else {
       weight_mem = weight.GetMKLDNNData();
-      CHECK(weight_mem->get_primitive_desc() == fwd->fwd_pd.weights_primitive_desc());
+      CHECK(weight_mem->get_desc() == fwd->GetPd().weights_desc());
     }
   }
   mkldnn_output_t out_mem;
   if (param.mkldnn_param.with_sum) {
-    out_mem = mkldnn_output_t(
-        OutDataOp::Noop,
-        const_cast<mkldnn::memory *>(out_data[conv::kOut].GetMKLDNNData()));
+    out_mem = mkldnn_output_t(OutDataOp::Noop,
+                              const_cast<mkldnn::memory *>(out_data[conv::kOut].GetMKLDNNData()));
   } else {
-    out_mem = CreateMKLDNNMem(out_data[conv::kOut],
-                              fwd->fwd_pd.dst_primitive_desc(), req[conv::kOut]);
+    out_mem = CreateMKLDNNMem(out_data[conv::kOut], fwd->GetPd().dst_desc(), req[conv::kOut]);
   }
 
-  const mkldnn::memory *bias_mem = nullptr;
+  mkldnn_args_map_t net_args;
   if (!no_bias) {
-    bias_mem = in_data[conv::kBias].GetMKLDNNData();
+    const mkldnn::memory *bias_mem = in_data[conv::kBias].GetMKLDNNData();
+    net_args.insert({MKLDNN_ARG_BIAS, *bias_mem});
   }
-  fwd->SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second);
-  MKLDNNStream::Get()->RegisterPrim(fwd->GetFwd());
 
+  net_args.insert({MKLDNN_ARG_SRC, *data_mem});
+  net_args.insert({MKLDNN_ARG_WEIGHTS, *weight_mem});
+  net_args.insert({MKLDNN_ARG_DST, *out_mem.second});
+  MKLDNNStream::Get()->RegisterPrimArgs(fwd->GetFwd(), net_args);
   CommitOutput(out_data[conv::kOut], out_mem);
   MKLDNNStream::Get()->Submit();
 }
 
-void MKLDNNConvolutionForward(const nnvm::NodeAttrs &attrs,
-                              const OpContext &ctx,
+void MKLDNNConvolutionForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                               const std::vector<NDArray> &in_data,
                               const std::vector<OpReqType> &req,
                               const std::vector<NDArray> &out_data) {
   MKLDNNConvFullParam param;
   param.conv_param = nnvm::get<ConvolutionParam>(attrs.parsed);
   param.mkldnn_param.Init(std::unordered_map<std::string, std::string>());
-  auto &fwd = GetConvFwd(
-      param.conv_param, ctx.is_train, in_data[conv::kData], in_data[conv::kWeight],
-      param.conv_param.no_bias ? nullptr : &in_data[conv::kBias],
-      out_data[conv::kOut]);
+  auto &fwd =
+      GetConvFwd(param, ctx.is_train, in_data[conv::kData], in_data[conv::kWeight],
+                 param.conv_param.no_bias ? nullptr : &in_data[conv::kBias], out_data[conv::kOut]);
   MKLDNNConvolutionForwardFullFeature(param, ctx, &fwd, in_data, req, out_data);
 }
 
-class MKLDNNConvBackward {
-  std::shared_ptr<mkldnn::convolution_backward_data> bwd_data;
-  std::shared_ptr<mkldnn::convolution_backward_weights> bwd_weight;
-  // conv::kData
-  std::shared_ptr<mkldnn::memory> out_grad;
-  std::shared_ptr<mkldnn::memory> in_grad;
-  std::shared_ptr<mkldnn::memory> weight;
-  // conv::kWeight
-  std::shared_ptr<mkldnn::memory> data;
-  std::shared_ptr<mkldnn::memory> output;
-  std::shared_ptr<mkldnn::memory> in_grad_weight;
-  std::shared_ptr<mkldnn::memory> in_grad_bias;
-
- public:
-  mkldnn::convolution_backward_data::primitive_desc bwdData_pd;
-  mkldnn::convolution_backward_weights::primitive_desc bwdWeights_pd;
-
-  MKLDNNConvBackward(
-      const ConvolutionParam &param, const NDArray &data,
-      const NDArray &weights, const NDArray *bias, const NDArray &output,
-      const mkldnn::convolution_forward::primitive_desc &fwd_pd):
-      bwdData_pd(GetConvBwdData(param, data, weights, output, fwd_pd)),
-      bwdWeights_pd(GetConvBwdWeights(param, data, weights, bias, output, fwd_pd)) {
-  }
-
-  void SetDataNewMem(const mkldnn::memory &out_grad, const mkldnn::memory &weight,
-                     const mkldnn::memory &in_grad) {
-    if (this->out_grad == nullptr)
-      this->out_grad = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-        bwdData_pd.diff_dst_primitive_desc(), out_grad.get_data_handle()));
-    else
-      this->out_grad->set_data_handle(out_grad.get_data_handle());
-    if (this->in_grad == nullptr)
-      this->in_grad = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-        bwdData_pd.diff_src_primitive_desc(), in_grad.get_data_handle()));
-    else
-      this->in_grad->set_data_handle(in_grad.get_data_handle());
-    if (this->weight == nullptr)
-      this->weight = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-         bwdData_pd.weights_primitive_desc(), weight.get_data_handle()));
-    else
-      this->weight->set_data_handle(weight.get_data_handle());
-    if (this->bwd_data == nullptr)
-      this->bwd_data = std::shared_ptr<mkldnn::convolution_backward_data>(
-        new mkldnn::convolution_backward_data(
-          this->bwdData_pd, mkldnn::primitive::at(*this->out_grad),
-          mkldnn::primitive::at(*this->weight), *this->in_grad));
-  }
-
-  void SetWeightNewMem(const mkldnn::memory &data,
-                       const mkldnn::memory &out_grad,
-                       const mkldnn::memory &in_grad_weight) {
-    if (this->data == nullptr)
-      this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-          bwdWeights_pd.src_primitive_desc(), data.get_data_handle()));
-    else
-      this->data->set_data_handle(data.get_data_handle());
-    if (this->output == nullptr)
-      this->output = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-          bwdWeights_pd.diff_dst_primitive_desc(), out_grad.get_data_handle()));
-    else
-      this->output->set_data_handle(out_grad.get_data_handle());
-    if (this->in_grad_weight == nullptr)
-      this->in_grad_weight = std::shared_ptr<mkldnn::memory>(
-          new mkldnn::memory(bwdWeights_pd.diff_weights_primitive_desc(),
-                             in_grad_weight.get_data_handle()));
-    else
-      this->in_grad_weight->set_data_handle(in_grad_weight.get_data_handle());
-
-    if (this->bwd_weight == nullptr)
-      this->bwd_weight = std::shared_ptr<mkldnn::convolution_backward_weights>(
-          new mkldnn::convolution_backward_weights(
-              this->bwdWeights_pd, mkldnn::primitive::at(*this->data),
-              mkldnn::primitive::at(*this->output), *this->in_grad_weight));
-  }
-
-  void SetWeightNewMem(const mkldnn::memory &data,
-                       const mkldnn::memory &out_grad,
-                       const mkldnn::memory &in_grad_weight,
-                       const mkldnn::memory &in_grad_bias) {
-    if (this->data == nullptr)
-      this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-          bwdWeights_pd.src_primitive_desc(), data.get_data_handle()));
-    else
-      this->data->set_data_handle(data.get_data_handle());
-    if (this->output == nullptr)
-      this->output = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-          bwdWeights_pd.diff_dst_primitive_desc(), out_grad.get_data_handle()));
-    else
-      this->output->set_data_handle(out_grad.get_data_handle());
-    if (this->in_grad_weight == nullptr)
-      this->in_grad_weight = std::shared_ptr<mkldnn::memory>(
-          new mkldnn::memory(bwdWeights_pd.diff_weights_primitive_desc(),
-                             in_grad_weight.get_data_handle()));
-    else
-      this->in_grad_weight->set_data_handle(in_grad_weight.get_data_handle());
-
-    if (this->in_grad_bias == nullptr)
-      this->in_grad_bias = std::shared_ptr<mkldnn::memory>(
-          new mkldnn::memory(bwdWeights_pd.diff_bias_primitive_desc(),
-                             in_grad_bias.get_data_handle()));
-    else
-      this->in_grad_bias->set_data_handle(in_grad_bias.get_data_handle());
-    if (this->bwd_weight == nullptr)
-      this->bwd_weight = std::shared_ptr<mkldnn::convolution_backward_weights>(
-          new mkldnn::convolution_backward_weights(
-              this->bwdWeights_pd, mkldnn::primitive::at(*this->data),
-              mkldnn::primitive::at(*this->output), *this->in_grad_weight,
-              *this->in_grad_bias));
-  }
-
-  const mkldnn::convolution_backward_data &GetBwdData() const {
-    return *bwd_data;
-  }
-
-  const mkldnn::convolution_backward_weights &GetBwdWeights() const {
-    return *bwd_weight;
-  }
-};
+MKLDNNConvBackward::MKLDNNConvBackward(const MKLDNNConvFullParam &param, const NDArray &data,
+                                       const NDArray &weight, const NDArray *bias,
+                                       const NDArray &output) {
+  const auto fwd_pd = GetConvFwdImpl(param, /*is_train=*/true, data, weight, bias, output);
+  bwd_data_pd_ = GetConvBwdData(param.conv_param, data, weight, output, *fwd_pd);
+  bwd_weight_pd_ = GetConvBwdWeights(param.conv_param, data, weight, bias, output, *fwd_pd);
+  bwd_data_ = std::make_shared<mkldnn::convolution_backward_data>(GetDataPd());
+  bwd_weight_ = std::make_shared<mkldnn::convolution_backward_weights>(GetWeightsPd());
+}
 
-static inline MKLDNNConvBackward &GetConvBwd(
-    const nnvm::NodeAttrs &attrs, const NDArray &data, const NDArray &weights,
-    const NDArray *bias, const NDArray &output,
-    const mkldnn::convolution_forward::primitive_desc &fwd_pd) {
+static inline MKLDNNConvBackward &GetConvBwd(const MKLDNNConvFullParam &param, const NDArray &data,
+                                             const NDArray &weight, const NDArray *bias,
+                                             const NDArray &output) {
+  using mkldnn_conv_bwd_map = std::unordered_map<MKLDNNConvSignature, MKLDNNConvBackward, OpHash>;
 #if DMLC_CXX11_THREAD_LOCAL
-  static thread_local std::unordered_map<MKLDNNConvSignature, MKLDNNConvBackward, OpHash> bwds;
+  static thread_local mkldnn_conv_bwd_map bwds;
 #else
-  static MX_THREAD_LOCAL std::unordered_map<MKLDNNConvSignature, MKLDNNConvBackward, OpHash> bwds;
+  static MX_THREAD_LOCAL mkldnn_conv_bwd_map bwds;
 #endif
-  const ConvolutionParam& param = nnvm::get<ConvolutionParam>(attrs.parsed);
-  MKLDNNConvSignature key(param);
-  // Here we can sign the conv op with NDArray because conv primitive will
-  // decide the right layout for the, so we only need to get the shape and the
-  // data type of the arrays.
+  // TODO(zhennan): Only conv_param is hashed into the key; mkldnn_param must be hashed as well
+  // before this cache can be shared with fused conv.
+  MKLDNNConvSignature key(param.conv_param);
+  // Here we can sign the conv op with NDArray because conv primitive will decide the right layout
+  // for them, so we only need to get the shape and the data type of the arrays.
   key.AddSign(data);
-  key.AddSign(weights);
+  key.AddSign(weight);
   key.AddSign(output);
-  if (bias)
-    key.AddSign(*bias);
-
+  if (bias) key.AddSign(*bias);
 
   auto it = bwds.find(key);
   if (it == bwds.end()) {
-    MKLDNNConvBackward bwd(param, data, weights, bias, output, fwd_pd);
+    auto bwd = MKLDNNConvBackward(param, data, weight, bias, output);
     it = AddToCache(&bwds, key, bwd);
   }
   return it->second;
@@ -617,69 +460,53 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct
   full_param.conv_param = nnvm::get<ConvolutionParam>(attrs.parsed);
   full_param.mkldnn_param.Init(std::unordered_map<std::string, std::string>());
 
-  auto data = inputs[conv::kData + 1];
-  if (data.IsView() && data.IsMKLDNNData())
-    data = data.Reorder2Default();
+  auto &data = inputs[conv::kData + 1];
+  auto &weight = inputs[conv::kWeight + 1];
+  const auto *bias = full_param.conv_param.no_bias ? nullptr : &inputs[conv::kBias + 1];
+  auto &out_grad = inputs[conv::kOut];
 
-  auto weight = inputs[conv::kWeight + 1];
-  if (weight.IsView() && weight.IsMKLDNNData())
-    weight = weight.Reorder2Default();
-
-  const NDArray* bias = full_param.conv_param.no_bias ? nullptr : &inputs[conv::kBias + 1];
-
-  auto out_grad = inputs[conv::kOut];
-  if (out_grad.IsView() && out_grad.IsMKLDNNData())
-    out_grad = out_grad.Reorder2Default();
-
-  mkldnn::convolution_forward::primitive_desc fwd_pd = GetConvFwdImpl(
-      full_param, ctx.is_train, data, weight, bias, out_grad);
   const ConvolutionParam &param = full_param.conv_param;
 
   CHECK_NE(req[conv::kWeight], kWriteInplace) << "cannot write weight inplace";
-  MKLDNNConvBackward &convBwd = GetConvBwd(attrs, data,
-      weight, bias, out_grad, fwd_pd);
-  auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
-      convBwd.bwdData_pd.diff_dst_primitive_desc());
+  MKLDNNConvBackward &convBwd = GetConvBwd(full_param, data, weight, bias, out_grad);
+  auto out_grad_mem = out_grad.GetMKLDNNDataReorder(convBwd.GetDataPd().diff_dst_desc());
   if (req[conv::kData]) {
-    auto weight_mem = GetWeights(weight,
-        convBwd.bwdData_pd.weights_primitive_desc(), param.num_group);
-    auto in_grad_mem = CreateMKLDNNMem(in_grad[conv::kData],
-        convBwd.bwdData_pd.diff_src_primitive_desc(), req[conv::kData]);
-    convBwd.SetDataNewMem(*out_grad_mem, *weight_mem, *in_grad_mem.second);
-    MKLDNNStream::Get()->RegisterPrim(convBwd.GetBwdData());
+    auto weight_mem = GetWeights(weight, convBwd.GetDataPd().weights_desc(), param.num_group);
+    auto in_grad_mem = CreateMKLDNNMem(in_grad[conv::kData], convBwd.GetDataPd().diff_src_desc(),
+                                       req[conv::kData]);
+    MKLDNNStream::Get()->RegisterPrimArgs(convBwd.GetBwdData(),
+                                          {{MKLDNN_ARG_DIFF_DST, *out_grad_mem},
+                                           {MKLDNN_ARG_WEIGHTS, *weight_mem},
+                                           {MKLDNN_ARG_DIFF_SRC, *in_grad_mem.second}});
     CommitOutput(in_grad[conv::kData], in_grad_mem);
   }
   if (req[conv::kWeight] || req[conv::kBias]) {
-    MKLDNNConvBackward &convBwdWeight = GetConvBwd(attrs, data,
-        weight, bias, out_grad, fwd_pd);
-    if (convBwdWeight.bwdData_pd.diff_dst_primitive_desc() !=
-        convBwdWeight.bwdWeights_pd.diff_dst_primitive_desc())
-      out_grad_mem = out_grad.GetMKLDNNDataReorder(
-          convBwdWeight.bwdWeights_pd.diff_dst_primitive_desc());
-    auto data_mem = data.GetMKLDNNDataReorder(
-        convBwdWeight.bwdWeights_pd.src_primitive_desc());
+    if (convBwd.GetDataPd().diff_dst_desc() != convBwd.GetWeightsPd().diff_dst_desc())
+      out_grad_mem = out_grad.GetMKLDNNDataReorder(convBwd.GetWeightsPd().diff_dst_desc());
+    auto data_mem = data.GetMKLDNNDataReorder(convBwd.GetWeightsPd().src_desc());
     auto in_grad_weight = CreateMKLDNNWeightGrad(
-        in_grad[conv::kWeight],
-        convBwdWeight.bwdWeights_pd.diff_weights_primitive_desc(),
-        req[conv::kWeight]);
-    if (param.no_bias) {
-      convBwdWeight.SetWeightNewMem(*data_mem, *out_grad_mem,
-          *in_grad_weight.second);
-      MKLDNNStream::Get()->RegisterPrim(convBwdWeight.GetBwdWeights());
-    } else {
-      auto in_grad_bias = CreateMKLDNNMem(
-          in_grad[conv::kBias],
-          convBwdWeight.bwdWeights_pd.diff_bias_primitive_desc(), req[conv::kBias]);
-      convBwdWeight.SetWeightNewMem(*data_mem, *out_grad_mem,
-          *in_grad_weight.second, *in_grad_bias.second);
-      MKLDNNStream::Get()->RegisterPrim(convBwdWeight.GetBwdWeights());
-      CommitOutput(in_grad[conv::kBias], in_grad_bias);
+        in_grad[conv::kWeight], convBwd.GetWeightsPd().diff_weights_desc(), req[conv::kWeight]);
+
+    mkldnn_args_map_t net_args = {{MKLDNN_ARG_DIFF_DST, *out_grad_mem},
+                                  {MKLDNN_ARG_SRC, *data_mem},
+                                  {MKLDNN_ARG_DIFF_WEIGHTS, *in_grad_weight.second}};
+    mkldnn_output_t in_grad_bias;
+    if (!param.no_bias) {
+      in_grad_bias = CreateMKLDNNMem(in_grad[conv::kBias],
+                                     convBwd.GetWeightsPd().diff_bias_desc(),
+                                     req[conv::kBias]);
+      net_args.insert({MKLDNN_ARG_DIFF_BIAS, *in_grad_bias.second});
     }
+    MKLDNNStream::Get()->RegisterPrimArgs(convBwd.GetBwdWeights(), net_args);
     CommitOutput(in_grad[conv::kWeight], in_grad_weight);
+    // CommitOutput should run after RegisterPrimArgs because of the memory dependency.
+    if (!param.no_bias) {
+      CommitOutput(in_grad[conv::kBias], in_grad_bias);
+    }
   }
   MKLDNNStream::Get()->Submit();
 }
 
 }  // namespace op
 }  // namespace mxnet
-#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_USE_MKLDNN == 100
diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
index 122ad9f..ddfcecc 100644
--- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -54,16 +54,6 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                       const std::vector<OpReqType> &req,
                       const std::vector<NDArray> &outputs);
 
-/* For convolution. */
-void MKLDNNConvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
-                              const std::vector<NDArray> &in_data,
-                              const std::vector<OpReqType> &req,
-                              const std::vector<NDArray> &out_data);
-void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
-                               const std::vector<NDArray>& inputs,
-                               const std::vector<OpReqType>& req,
-                               const std::vector<NDArray>& outputs);
-
 /* For deconvolution */
 void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                                 const std::vector<NDArray> &in_data,
@@ -133,6 +123,17 @@ void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
 #endif
 
 #if MXNET_USE_MKLDNN == 100
+/* For convolution (implemented in mkldnn_convolution.cc). */
+void MKLDNNConvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+                              const std::vector<NDArray> &in_data,
+                              const std::vector<OpReqType> &req,
+                              const std::vector<NDArray> &out_data);
+void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+                               const std::vector<NDArray>& inputs,
+                               const std::vector<OpReqType>& req,
+                               const std::vector<NDArray>& outputs);
+
+
 void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2,
         const mkldnn::memory &out);
 #endif
diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index 5290c09..753be48 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -526,17 +526,46 @@ class OpSignature {
    * and the layout to sign the op.
    */
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
   void AddSign(const mkldnn::memory &mem) {
-    auto desc = mem.get_primitive_desc().desc();
-    hash = hash * 2 + desc.data.format;
-    eles.push_back(desc.data.format);
+    auto desc = mem.get_desc();
+    hash = hash * 2 + desc.data.format_kind;
+    eles.push_back(desc.data.format_kind);
     hash = hash * 2 + desc.data.data_type;
     eles.push_back(desc.data.data_type);
     for (int i = 0; i < desc.data.ndims; i++) {
       hash = hash * 2 + desc.data.dims[i];
       eles.push_back(desc.data.dims[i]);
     }
+    switch (desc.data.format_kind) {  // Also sign the format-specific layout fields.
+      case mkldnn_blocked:
+        hash = hash * 2 + desc.data.ndims;
+        eles.push_back(desc.data.ndims);
+        for (int i = 0; i < desc.data.ndims; i++) {
+          hash = hash * 2 + desc.data.format_desc.blocking.strides[i];
+          eles.push_back(desc.data.format_desc.blocking.strides[i]);
+        }
+        hash = hash * 2 + desc.data.format_desc.blocking.inner_nblks;
+        eles.push_back(desc.data.format_desc.blocking.inner_nblks);
+        for (int i = 0; i < desc.data.format_desc.blocking.inner_nblks; i++) {
+          hash = hash * 2 + desc.data.format_desc.blocking.inner_blks[i];
+          hash = hash * 2 + desc.data.format_desc.blocking.inner_idxs[i];
+          eles.push_back(desc.data.format_desc.blocking.inner_blks[i]);
+          eles.push_back(desc.data.format_desc.blocking.inner_idxs[i]);
+        }
+        break;
+      case mkldnn_format_kind_wino:
+        hash = hash * 2 + desc.data.format_desc.wino_desc.wino_format;
+        eles.push_back(desc.data.format_desc.wino_desc.wino_format);
+        break;
+      case mkldnn_format_kind_rnn_packed:
+        hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.format;
+        eles.push_back(desc.data.format_desc.rnn_packed_desc.format);
+        break;
+      default:
+        // No format-specific fields are signed for other format kinds.
+        break;
+    }
   }
 #endif
 
@@ -547,7 +576,7 @@ class OpSignature {
   }
 
   void AddSign(const NDArray &arr) {
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
     if (arr.IsMKLDNNData()) {
       AddSign(*(arr.GetMKLDNNData()));
     } else {
@@ -555,7 +584,7 @@ class OpSignature {
       hash = hash * 2 + arr.dtype();
       eles.push_back(arr.dtype());
       AddSign(arr.shape());
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
     }
 #endif
   }
diff --git a/src/operator/tensor/cast_storage-inl.h b/src/operator/tensor/cast_storage-inl.h
index 93606fc..4a8a273 100644
--- a/src/operator/tensor/cast_storage-inl.h
+++ b/src/operator/tensor/cast_storage-inl.h
@@ -34,7 +34,7 @@
 #ifdef __CUDACC__
 #include "./cast_storage-inl.cuh"
 #endif  // __CUDACC__
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include "../nn/mkldnn/mkldnn_base-inl.h"
 #endif
 
@@ -397,7 +397,7 @@ void CastStorageComputeImpl(const OpContext& ctx,
   } else if (src_stype == kRowSparseStorage && dst_stype == kRowSparseStorage) {
     NDArray ret = output;
     CastStorageRspRspImpl<xpu>(ctx, input, &ret);
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
   } else if (src_stype == kDefaultStorage && dst_stype == kDefaultStorage) {
     CHECK_EQ(output.ctx().dev_type, input.ctx().dev_type);
     // If one of them uses the MKLDNN layout.
@@ -449,7 +449,7 @@ inline bool CastStorageInferStorageType(const nnvm::NodeAttrs& attrs,
   if (!dispatched && in_stype == kDefaultStorage && param_stype == kDefaultStorage) {
     // dns -> dns
     DispatchMode mode = DispatchMode::kFCompute;
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
     // If we use MKLDNN and the arrays are in CPU memory, the array may store
     // MKLDNN layout, we should convert its layout explicitly.
     if (dev_mask == kCPU)