You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ak...@apache.org on 2021/09/24 09:45:56 UTC

[incubator-mxnet] branch v1.x updated: Decouple OneDNN data structures in MXNet C++ API (#20349)

This is an automated email from the ASF dual-hosted git repository.

akarbown pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 868d06b  Decouple OneDNN data structures in MXNet C++ API (#20349)
868d06b is described below

commit 868d06b236ee1005d40b43b7529171ce432605f8
Author: mozga <ma...@intel.com>
AuthorDate: Fri Sep 24 11:43:53 2021 +0200

    Decouple OneDNN data structures in MXNet C++ API (#20349)
    
    * NDArray file has been modified, there are a few changes:
    1. OneDNN header was moved into *cc file
    2. OneDNN objects are created on-the-fly: static_cast is needed
    
    * Sanity check
    
    * Adaptive pooling: fix
---
 include/mxnet/ndarray.h                            |  24 +-
 src/ndarray/ndarray.cc                             | 112 ++++-----
 src/operator/nn/concat.cc                          | 259 +++++++++++----------
 src/operator/nn/mkldnn/mkldnn_act.cc               |  25 +-
 src/operator/nn/mkldnn/mkldnn_adaptive_pooling.cc  |   9 +-
 src/operator/nn/mkldnn/mkldnn_base-inl.h           |   9 +-
 src/operator/nn/mkldnn/mkldnn_base.cc              |  95 ++++----
 src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h     |  50 ++--
 src/operator/nn/mkldnn/mkldnn_concat.cc            |   6 +-
 src/operator/nn/mkldnn/mkldnn_convolution.cc       |  41 ++--
 src/operator/nn/mkldnn/mkldnn_copy.cc              |  15 +-
 src/operator/nn/mkldnn/mkldnn_deconvolution-inl.h  |  59 ++---
 src/operator/nn/mkldnn/mkldnn_deconvolution.cc     |  12 +-
 src/operator/nn/mkldnn/mkldnn_fully_connected.cc   |  34 ++-
 src/operator/nn/mkldnn/mkldnn_log_softmax.cc       |  22 +-
 src/operator/nn/mkldnn/mkldnn_lrn-inl.h            |  25 +-
 src/operator/nn/mkldnn/mkldnn_pooling.cc           |  22 +-
 src/operator/nn/mkldnn/mkldnn_reshape.cc           |  16 +-
 src/operator/nn/mkldnn/mkldnn_slice.cc             |   8 +-
 src/operator/nn/mkldnn/mkldnn_softmax.cc           |  22 +-
 src/operator/nn/mkldnn/mkldnn_softmax_output.cc    |   4 +-
 src/operator/nn/mkldnn/mkldnn_sum.cc               |   2 +-
 src/operator/nn/mkldnn/mkldnn_transpose.cc         |   5 +-
 src/operator/operator_common.h                     |   6 +-
 .../quantization/mkldnn/mkldnn_dequantize-inl.h    |   2 +-
 .../quantization/mkldnn/mkldnn_quantize-inl.h      |   2 +-
 .../quantization/mkldnn/mkldnn_quantize_v2-inl.h   |   5 +-
 .../mkldnn/mkldnn_quantized_batch_norm.cc          |  12 +-
 .../quantization/mkldnn/mkldnn_quantized_concat.cc |   4 +-
 .../quantization/mkldnn/mkldnn_quantized_conv.cc   |  16 +-
 .../mkldnn/mkldnn_quantized_elemwise_add.cc        |   6 +-
 .../mkldnn/mkldnn_quantized_fully_connected.cc     |  13 +-
 .../quantization/mkldnn/mkldnn_requantize-inl.h    |   2 +-
 src/operator/subgraph/mkldnn/mkldnn_common.h       |  12 +-
 src/operator/subgraph/mkldnn/mkldnn_conv.cc        |  37 +--
 src/operator/subgraph/mkldnn/mkldnn_fc.cc          |  53 ++---
 src/operator/tensor/amp_cast.cc                    |   7 +-
 src/operator/tensor/cast_storage-inl.h             |   4 +-
 tests/cpp/include/test_mkldnn.h                    |   6 +-
 tests/cpp/operator/mkldnn_operator_test.cc         |   3 +-
 tests/cpp/operator/mkldnn_test.cc                  | 169 +++++++-------
 41 files changed, 665 insertions(+), 570 deletions(-)

diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index 0febc65..92aefd6 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -37,9 +37,7 @@
 #include <memory>
 #include <string>
 #include <vector>
-#if MXNET_USE_MKLDNN == 1
-#include <mkldnn.hpp>
-#endif
+
 #include "./base.h"
 #include "./engine.h"
 #include "./storage.h"
@@ -711,12 +709,12 @@ class NDArray {
    * Create NDArray from mkldnn memory.
    * mkldnn_mem The mkldnn memory to be managed.
    */
-  explicit NDArray(const std::shared_ptr<mkldnn::memory>& mkldnn_mem);
+  explicit NDArray(const std::shared_ptr<void>& mkldnn_mem);
   /*
    * Create NDArray from mkldnn memory descriptor.
    * mem_pd The mkldnn memory descriptor to be created.
    */
-  explicit NDArray(const mkldnn::memory::desc& md);
+  explicit NDArray(const void* md);
   /*
    * Test if the data is stored in one of special MKLDNN format.
    */
@@ -739,29 +737,29 @@ class NDArray {
   /*
    * This function returns mkldnn::memory with the default primitive_desc.
    */
-  const mkldnn::memory* GetMKLDNNData() const;
+  const void* GetMKLDNNData() const;
   /*
    * This function returns mkldnn::memory with the given primitive_desc
    * as long as the array size meets the required size in the given
    * primitive_desc.
    */
-  const mkldnn::memory* GetMKLDNNData(const mkldnn::memory::desc& md) const;
+  const void* GetMKLDNNData(const void* md) const;
   /*
    * This function returns mkldnn::memory with the given primitive_desc.
    * The returned mkldnn::memory will have the same physical layout as
    * the given primitive_desc.
    */
-  const mkldnn::memory* GetMKLDNNDataReorder(const mkldnn::memory::desc& md) const;
+  const void* GetMKLDNNDataReorder(const void* md) const;
 
   /*
    * This function copies data from mkldnn memory.
    */
-  void CopyFrom(const mkldnn::memory& mem);
+  void CopyFrom(const void* mem);
   /*
    * This function allocates memory for array and creates mkldnn memory
    * with the specified format.
    */
-  mkldnn::memory* CreateMKLDNNData(const mkldnn::memory::desc& md);
+  void* CreateMKLDNNData(const void* md);
 
   /*
    * These are the async version of the methods above.
@@ -769,7 +767,7 @@ class NDArray {
    * the array are complete.
    */
   void Reorder2DefaultAsync() const;
-  void MKLDNNDataReorderAsync(const mkldnn::memory::desc& md) const;
+  void MKLDNNDataReorderAsync(const void* md) const;
 
   /*
    * This creates a new NDArray with the reordered data.
@@ -800,7 +798,7 @@ class NDArray {
   /*!
    * \ Fix mkldnn memory descriptor mismatch from NDArray.
    */
-  void UpdateMKLDNNMemDesc(const mkldnn::memory::desc& desc);
+  void UpdateMKLDNNMemDesc(const void* desc);
 #endif
 
   /*!
@@ -1087,7 +1085,7 @@ class NDArray {
     // save the result in shandle.
     void Reorder2Default();
     // Reroder data to a specified layout.
-    void MKLDNNDataReorder(const mkldnn::memory::desc& md);
+    void MKLDNNDataReorder(const void* md);
     bool IsMKLDNN() const;
     bool IsDefault() const;
 #endif
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index c7188e3..46262ed 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -37,7 +37,9 @@
 #include "../operator/tensor/init_op.h"
 #include "../operator/tensor/matrix_op-inl.h"
 #include "./ndarray_function.h"
-
+#if MXNET_USE_MKLDNN == 1
+#include <mkldnn.hpp>
+#endif
 #if MXNET_USE_OPENCV
 #include <opencv2/opencv.hpp>
 #endif  // MXNET_USE_OPENCV
@@ -147,8 +149,7 @@ void NDArray::Chunk::CheckAndAllocData(const mxnet::TShape& shape, int dtype) {
   if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
     CHECK_LT(shape.Size(), (int64_t{1} << 31) - 1)
         << "[CheckAndAllocData] Size of tensor you are trying to allocate is "
-           "larger than "
-           "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
+           "larger than 2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
   }
   if (shandle.size < dbytes) {
     // free storage
@@ -187,16 +188,19 @@ nnvm::Symbol NDArray::get_autograd_symbol() const {
 
 #if MXNET_USE_MKLDNN == 1
 
-NDArray::NDArray(const mkldnn::memory::desc& md) : storage_type_(kDefaultStorage), entry_(nullptr) {
-  shape_ = mxnet::TShape(md.data.dims, md.data.dims + md.data.ndims);
-  dtype_ = get_mxnet_type(md.data.data_type);
-  ptr_   = std::make_shared<Chunk>(shape_, Context::CPU(), true, dtype_);
+NDArray::NDArray(const void* md_desc) : storage_type_(kDefaultStorage), entry_(nullptr) {
+  mkldnn::memory::desc md = *static_cast<const mkldnn::memory::desc*>(md_desc);
+  shape_                  = mxnet::TShape(md.data.dims, md.data.dims + md.data.ndims);
+  dtype_                  = get_mxnet_type(md.data.data_type);
+  ptr_                    = std::make_shared<Chunk>(shape_, Context::CPU(), true, dtype_);
   ptr_->CheckAndAlloc(md.get_size());
   ptr_->mkl_mem_ = std::make_shared<MKLDNNMemory>(md, ptr_->shandle.dptr);
 }
 
-NDArray::NDArray(const std::shared_ptr<mkldnn::memory>& mkldnn_mem)
+NDArray::NDArray(const std::shared_ptr<void>& mkldnn_mem_ptr)
     : storage_type_(kDefaultStorage), entry_(nullptr) {
+  std::shared_ptr<mkldnn::memory> mkldnn_mem =
+      std::static_pointer_cast<mkldnn::memory>(mkldnn_mem_ptr);
   auto mem_desc      = mkldnn_mem->get_desc();
   shape_             = mxnet::TShape(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims);
   dtype_             = get_mxnet_type(mem_desc.data.data_type);
@@ -399,8 +403,7 @@ bool NDArray::fresh_out_grad() const {
 
 void NDArray::set_fresh_out_grad(bool state) const {
   CHECK(!Imperative::AGInfo::IsNone(*this))
-      << "NDArray has not been marked as a variable and does not have gradient "
-         "state";
+      << "NDArray has not been marked as a variable and does not have gradient state";
   Imperative::AGInfo& info = Imperative::AGInfo::Get(entry_.node);
   info.fresh_out_grad      = state;
 }
@@ -444,7 +447,8 @@ void NDArray::Chunk::Reorder2Default() {
   mkl_mem_ = nullptr;
 }
 
-void NDArray::Chunk::MKLDNNDataReorder(const mkldnn::memory::desc& md) {
+void NDArray::Chunk::MKLDNNDataReorder(const void* mem_desc) {
+  const mkldnn::memory::desc md = *static_cast<const mkldnn::memory::desc*>(mem_desc);
   // If the memory already uses the specified layout, don't do anything.
   if (mkl_mem_ != nullptr && mkl_mem_->SameFormat(md))
     return;
@@ -530,12 +534,13 @@ void NDArray::Chunk::SetMKLMem(const mxnet::TShape& shape, int dtype) {
   mkl_mem_.reset(new MKLDNNMemory(data_md, shandle.dptr));
 }
 
-const mkldnn::memory* NDArray::GetMKLDNNData(const mkldnn::memory::desc& desc) const {
+const void* NDArray::GetMKLDNNData(const void* mem_desc) const {
+  const mkldnn::memory::desc desc = *static_cast<const mkldnn::memory::desc*>(mem_desc);
   if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) {
     LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc";
     return nullptr;
   }
-  const mkldnn::memory* mem  = GetMKLDNNData();
+  const mkldnn::memory* mem  = static_cast<const mkldnn::memory*>(GetMKLDNNData());
   mkldnn::memory::desc desc1 = mem->get_desc();
   // The MKL memory has the same format and shape as required,
   // or both use the default format, we can return the MKL memory.
@@ -546,10 +551,11 @@ const mkldnn::memory* NDArray::GetMKLDNNData(const mkldnn::memory::desc& desc) c
   }
 }
 
-const mkldnn::memory* NDArray::GetMKLDNNDataReorder(const mkldnn::memory::desc& new_desc) const {
+const void* NDArray::GetMKLDNNDataReorder(const void* mem_desc) const {
+  mkldnn::memory::desc new_desc = *static_cast<const mkldnn::memory::desc*>(mem_desc);
   CHECK(storage_type() == kDefaultStorage);
 
-  const mkldnn::memory* mem = GetMKLDNNData();
+  const mkldnn::memory* mem = static_cast<const mkldnn::memory*>(GetMKLDNNData());
   // If the memory descriptor matches, it's easy.
   MKLDNNStream* stream = MKLDNNStream::Get();
   if (mem->get_desc() == new_desc) {
@@ -578,7 +584,7 @@ const mkldnn::memory* NDArray::GetMKLDNNDataReorder(const mkldnn::memory::desc&
     for (int i = 0; i < new_desc.data.ndims; i++)
       required_shape[i] = new_desc.data.dims[i];
     NDArray reshaped          = MKLDNNDataReshape(required_shape);
-    const mkldnn::memory* ret = reshaped.GetMKLDNNData();
+    const mkldnn::memory* ret = static_cast<const mkldnn::memory*>(reshaped.GetMKLDNNData());
     if (ret->get_desc() == new_desc) {
       return GetMKLDNNExact(ret, new_desc);
     } else {
@@ -641,24 +647,25 @@ NDArray NDArray::Reorder2DefaultFloatFormat() const {
     return Reorder2Default();
   }
   NDArray ret(shape(), ctx(), false, mshadow::DataType<float>::kFlag);
-  auto src_mem = GetMKLDNNData();
-  auto dst_mem = ret.GetMKLDNNData();
+  auto src_mem = static_cast<const mkldnn::memory*>(GetMKLDNNData());
+  auto dst_mem = static_cast<const mkldnn::memory*>(ret.GetMKLDNNData());
   ReorderTo(src_mem, dst_mem);
 
   return ret;
 }
 
-void NDArray::MKLDNNDataReorderAsync(const mkldnn::memory::desc& desc) const {
+void NDArray::MKLDNNDataReorderAsync(const void* mem_desc) const {
+  mkldnn::memory::desc desc = *static_cast<const mkldnn::memory::desc*>(mem_desc);
   std::vector<Engine::VarHandle> const_vars;
   std::vector<Engine::VarHandle> mutable_vars(1, this->var());
   NDArray tmp        = *this;
   const auto version = this->version();
   Engine::Get()->PushAsync(
       [tmp, version, desc](RunContext ctx, Engine::CallbackOnComplete on_complete) {
-        // MXNet will try to reuse NDArray from memory planning, so we need to
-        // ensure the NDArray is still holding the original trunk data.
+        // MXNet will try to reuse NDArray from memory planning, so we need to ensure
+        // the NDArray is still holding the original trunk data.
         if (tmp.version() == version) {
-          tmp.ptr_->MKLDNNDataReorder(desc);
+          tmp.ptr_->MKLDNNDataReorder(&desc);
         }
         on_complete();
       },
@@ -670,7 +677,7 @@ void NDArray::MKLDNNDataReorderAsync(const mkldnn::memory::desc& desc) const {
       "Reorder");
 }
 
-const mkldnn::memory* NDArray::GetMKLDNNData() const {
+const void* NDArray::GetMKLDNNData() const {
   CHECK(storage_type() == kDefaultStorage);
   bool is_view = IsView();
   if (IsMKLDNNData()) {
@@ -715,7 +722,8 @@ void NDArray::InvalidateMKLDNNData() {
     ptr_->mkl_mem_ = nullptr;
 }
 
-void NDArray::CopyFrom(const mkldnn::memory& mem) {
+void NDArray::CopyFrom(const void* memory) {
+  auto mem = *static_cast<const mkldnn::memory*>(memory);
   CHECK(ptr_ != nullptr) << "The NDArray hasn't been initialized";
   if (ptr_->mkl_mem_ && ptr_->mkl_mem_->GetRaw() == &mem)
     return;
@@ -728,14 +736,14 @@ void NDArray::CopyFrom(const mkldnn::memory& mem) {
   if (IsMKLDNNData() && IsView())
     ptr_->Reorder2Default();
 
-  const mkldnn::memory* this_mem = GetMKLDNNData();
+  const mkldnn::memory* this_mem = static_cast<const mkldnn::memory*>(GetMKLDNNData());
   MKLDNNMemoryCopy(mem, this_mem);
 }
 
-mkldnn::memory* NDArray::CreateMKLDNNData(const mkldnn::memory::desc& desc) {
+void* NDArray::CreateMKLDNNData(const void* mem_desc) {
+  mkldnn::memory::desc desc = *static_cast<const mkldnn::memory::desc*>(mem_desc);
   if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) {
-    LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN "
-                  "memory desc. "
+    LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc. "
                << "MKLDNN memory requests for " << desc.get_size() << " bytes, but got "
                << shape().Size() * GetTypeSize(dtype_) << " bytes from NDArray";
     return nullptr;
@@ -779,10 +787,11 @@ mkldnn::memory* NDArray::CreateMKLDNNData(const mkldnn::memory::desc& desc) {
   return ptr_->mkl_mem_->GetRaw();
 }
 
-void NDArray::UpdateMKLDNNMemDesc(const mkldnn::memory::desc& desc) {
-  auto new_desc           = desc;
-  auto this_dtype         = get_mkldnn_type(dtype());
-  new_desc.data.data_type = static_cast<mkldnn_data_type_t>(this_dtype);
+void NDArray::UpdateMKLDNNMemDesc(const void* mem_desc) {
+  mkldnn::memory::desc desc = *static_cast<const mkldnn::memory::desc*>(mem_desc);
+  auto new_desc             = desc;
+  auto this_dtype           = get_mkldnn_type(dtype());
+  new_desc.data.data_type   = static_cast<mkldnn_data_type_t>(this_dtype);
   ptr_->mkl_mem_.reset(new MKLDNNMemory(new_desc, ptr_->shandle.dptr));
   MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
 }
@@ -1203,15 +1212,15 @@ inline void CopyFromToDnsImpl(const NDArray& from, const NDArray& to, RunContext
 #if MXNET_USE_MKLDNN == 1
   } else if (SupportMKLDNN(from.dtype(), from.shape()) && SupportMKLDNN(to.dtype(), to.shape()) &&
              from.ctx().dev_mask() == cpu::kDevMask && to.ctx().dev_mask() == cpu::kDevMask) {
-    // If we copy data directly, we need to make sure both NDArrays are
-    // supported by MKLDNN.
-    auto from_mem = from.GetMKLDNNData();
-    auto to_mem   = to.GetMKLDNNData();
+    // If we copy data directly, we need to make sure both NDArrays are supported
+    // by MKLDNN.
+    auto from_mem = static_cast<const mkldnn::memory*>(from.GetMKLDNNData());
+    auto to_mem   = static_cast<const mkldnn::memory*>(to.GetMKLDNNData());
     if (from_mem->get_desc() == to_mem->get_desc()) {
       size_t size = std::min(from_mem->get_desc().get_size(), to_mem->get_desc().get_size());
       memcpy(to_mem->get_data_handle(), from_mem->get_data_handle(), size);
     } else {
-      const_cast<NDArray&>(to).CopyFrom(*from_mem);
+      const_cast<NDArray&>(to).CopyFrom(from_mem);
       MKLDNNStream::Get()->Submit();
     }
   } else {
@@ -1222,8 +1231,8 @@ inline void CopyFromToDnsImpl(const NDArray& from, const NDArray& to, RunContext
     if (tmp_from.IsMKLDNNData()) {
       // TODO(zhengda) tmp_from should be cached.
       tmp_from     = NDArray(from.shape(), from.ctx(), false, from.dtype());
-      auto tmp_mem = from.GetMKLDNNData();
-      tmp_from.CopyFrom(*tmp_mem);
+      auto tmp_mem = static_cast<const mkldnn::memory*>(from.GetMKLDNNData());
+      tmp_from.CopyFrom(tmp_mem);
       MKLDNNStream::Get()->Submit();
     }
     CHECK(tmp_from.IsDefaultData());
@@ -2053,8 +2062,7 @@ void NDArray::SyncCopyFromCPU(const void* data, size_t size) const {
   mxnet::TShape dshape = this->shape();
   if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
     CHECK_LT(size, (int64_t{1} << 31) - 1)
-        << "[SyncCopyFromCPU] Size of tensor you are trying to allocate is "
-           "larger than "
+        << "[SyncCopyFromCPU] Size of tensor you are trying to allocate is larger than "
            "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
   }
   CHECK_EQ(dshape.Size(), size) << "Memory size do not match";
@@ -2217,8 +2225,7 @@ void NDArray::SyncCopyToCPU(void* data, size_t size) const {
   mxnet::TShape dshape = this->shape();
   if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
     CHECK_LT(size, (int64_t{1} << 31) - 1)
-        << "[SyncCopyToCPU] Size of tensor you are trying to allocate is "
-           "larger than "
+        << "[SyncCopyToCPU] Size of tensor you are trying to allocate is larger than "
            "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
   }
   CHECK_EQ(dshape.Size(), size) << "Memory size do not match";
@@ -2290,12 +2297,12 @@ void NDArray::SyncCheckFormat(const bool full_check) const {
   }
   this->WaitToWrite();
   CHECK_NE(err, kCSRShapeErr) << "Shape mismatch of this csr NDArray";
-  CHECK_NE(err, kCSRIndPtrErr) << "IndPtr of csr NDArray should be non-negative, in non-decreasing "
-                                  "order, "
-                               << "start with 0, and end with value equal with size of indices.";
-  CHECK_NE(err, kCSRIdxErr) << "Indices of csr NDArray should be non-negative, "
-                               "in ascending order per row "
-                            << " and less than the number of columns.";
+  CHECK_NE(err, kCSRIndPtrErr)
+      << "IndPtr of csr NDArray should be non-negative, in non-decreasing order, "
+      << "start with 0, and end with value equal with size of indices.";
+  CHECK_NE(err, kCSRIdxErr)
+      << "Indices of csr NDArray should be non-negative, in ascending order per row "
+      << " and less than the number of columns.";
   CHECK_NE(err, kRSPShapeErr) << "Shape mismatch of this row_sparse NDArray";
   CHECK_NE(err, kRSPIdxErr) << "Indices of row_sparse NDArray should be non-negative, "
                             << "less than the size of first dimension and in ascending order";
@@ -2313,8 +2320,7 @@ MXNET_REGISTER_NDARRAY_FUN(fill_element_0index)
     .set_function(TernaryOp<ndarray::MatFillRowElem>)
     .describe(
         "Fill one element of each line(row for python, column for R/Julia)"
-        " in lhs according to index indicated by rhs and values indicated by "
-        "mhs."
+        " in lhs according to index indicated by rhs and values indicated by mhs."
         " This function assume rhs uses 0-based index.");
 
 // register API function
@@ -2464,9 +2470,7 @@ MXNET_REGISTER_NDARRAY_FUN(_imdecode)
     .set_num_use_vars(1)
     .set_num_scalars(7)
     .set_num_mutate_vars(1)
-    .describe(
-        "Decode an image, clip to (x0, y0, x1, y1), subtract mean, and write "
-        "to buffer")
+    .describe("Decode an image, clip to (x0, y0, x1, y1), subtract mean, and write to buffer")
     .add_argument("mean", "NDArray-or-Symbol", "image mean")
     .add_argument("index", "int", "buffer position for output")
     .add_argument("x0", "int", "x0")
diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc
index 5a7ece1..de02bd6 100644
--- a/src/operator/nn/concat.cc
+++ b/src/operator/nn/concat.cc
@@ -22,30 +22,30 @@
  * \file concat.cc
  * \brief
  * \author Bing Xu
-*/
+ */
 
+#include "../../common/utils.h"
 #include "./concat-inl.h"
-#include "./mkldnn/mkldnn_ops-inl.h"
 #include "./mkldnn/mkldnn_base-inl.h"
-#include "../../common/utils.h"
+#include "./mkldnn/mkldnn_ops-inl.h"
 
 namespace mxnet {
 namespace op {
 
 bool ConcatShape(const nnvm::NodeAttrs& attrs,
-                 mxnet::ShapeVector *in_shape,
-                 mxnet::ShapeVector *out_shape) {
+                 mxnet::ShapeVector* in_shape,
+                 mxnet::ShapeVector* out_shape) {
   using namespace mshadow;
   const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
   CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
   mxnet::TShape dshape;
-  dim_t size = 0;
+  dim_t size                = 0;
   bool has_unknown_dim_size = false;
-  int axis = -1;
+  int axis                  = -1;
   for (int i = 0; i < param_.num_args; ++i) {
     mxnet::TShape tmp = (*in_shape)[i];
     if (tmp.ndim() > 0) {
-      axis = CheckAxis(param_.dim, tmp.ndim());
+      axis                 = CheckAxis(param_.dim, tmp.ndim());
       has_unknown_dim_size = !mxnet::dim_size_is_known(tmp, axis) || has_unknown_dim_size;
       size += tmp[axis];
       tmp[axis] = -1;
@@ -55,12 +55,13 @@ bool ConcatShape(const nnvm::NodeAttrs& attrs,
 
   mxnet::TShape tmp = (*out_shape)[0];
   if (tmp.ndim() > 0) {
-    axis = CheckAxis(param_.dim, tmp.ndim());
+    axis      = CheckAxis(param_.dim, tmp.ndim());
     tmp[axis] = -1;
     shape_assign(&dshape, tmp);
   }
 
-  if (dshape.ndim() == -1) return false;
+  if (dshape.ndim() == -1)
+    return false;
   CHECK_NE(dshape.ndim(), 0) << "zero-dimensional arrays cannot be concatenated";
 
   for (int i = 0; i < param_.num_args; ++i) {
@@ -68,7 +69,8 @@ bool ConcatShape(const nnvm::NodeAttrs& attrs,
         << "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i];
   }
 
-  if (!has_unknown_dim_size) dshape[axis] = size;
+  if (!has_unknown_dim_size)
+    dshape[axis] = size;
   CHECK(shape_assign(&(*out_shape)[0], dshape))
       << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0];
 
@@ -80,8 +82,8 @@ bool ConcatShape(const nnvm::NodeAttrs& attrs,
 // The first (and sometimes the second) input may be unknown on the target axis.
 // If the two inputs are unknown, they always have the same shape.
 static bool RNNParamConcatShape(const nnvm::NodeAttrs& attrs,
-                                mxnet::ShapeVector *in_shape,
-                                mxnet::ShapeVector *out_shape) {
+                                mxnet::ShapeVector* in_shape,
+                                mxnet::ShapeVector* out_shape) {
   using namespace mshadow;
   const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
   CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
@@ -106,31 +108,32 @@ static bool RNNParamConcatShape(const nnvm::NodeAttrs& attrs,
 
   mxnet::TShape tmp = (*out_shape)[0];
   if (tmp.ndim() > 0) {
-    axis = CheckAxis(param_.dim, tmp.ndim());
+    axis      = CheckAxis(param_.dim, tmp.ndim());
     tmp[axis] = -1;
     shape_assign(&dshape, tmp);
   }
 
-  if (!mxnet::ndim_is_known(dshape)) return false;
+  if (!mxnet::ndim_is_known(dshape))
+    return false;
 
   for (int i = 0; i < param_.num_args; ++i) {
     CHECK(shape_assign(&(*in_shape)[i], dshape))
         << "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i];
   }
 
-  if (zero_indices.empty()) dshape[axis] = size;
+  if (zero_indices.empty())
+    dshape[axis] = size;
   CHECK(shape_assign(&(*out_shape)[0], dshape))
       << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0];
   if ((*out_shape)[0][axis] != -1 && !zero_indices.empty()) {
     int residual = (*out_shape)[0][axis] - size;
-    CHECK_GE(residual, 0)
-        << "Input size already exceeds output size. Residual: " << residual;
+    CHECK_GE(residual, 0) << "Input size already exceeds output size. Residual: " << residual;
     CHECK(zero_indices.size() <= 2 && zero_indices.size() > 0)
         << "Expecting 1 or 2 inputs that need shape inference. Got: " << zero_indices.size();
     bool need_infer = !shape_is_known((*out_shape)[0]);
     for (int i : zero_indices) {
       (*in_shape)[i][axis] = residual / zero_indices.size();
-      need_infer = need_infer || !shape_is_known((*in_shape)[i]);
+      need_infer           = need_infer || !shape_is_known((*in_shape)[i]);
     }
     return !need_infer;
   }
@@ -139,21 +142,20 @@ static bool RNNParamConcatShape(const nnvm::NodeAttrs& attrs,
 }
 
 bool ConcatType(const nnvm::NodeAttrs& attrs,
-                std::vector<int> *in_type,
-                std::vector<int> *out_type) {
+                std::vector<int>* in_type,
+                std::vector<int>* out_type) {
   const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
-  int dtype = -1;
+  int dtype                 = -1;
 
   // checks uniformity of input
-  for (size_t i =0; i < in_type->size(); ++i) {
+  for (size_t i = 0; i < in_type->size(); ++i) {
     if (dtype == -1) {
       dtype = in_type->at(i);
     } else {
       CHECK(in_type->at(i) == dtype || in_type->at(i) == -1)
-          << "Non-uniform data type in "  << attrs.op->name
-          << ", expected data type " << mxnet::op::type_string(dtype)
-          << ", got data type " << mxnet::op::type_string(in_type->at(i))
-          << " for input " << i;
+          << "Non-uniform data type in " << attrs.op->name << ", expected data type "
+          << mxnet::op::type_string(dtype) << ", got data type "
+          << mxnet::op::type_string(in_type->at(i)) << " for input " << i;
     }
   }
 
@@ -166,18 +168,17 @@ bool ConcatType(const nnvm::NodeAttrs& attrs,
     for (size_t i = 0; i < nin; ++i) {
       in_type->push_back(dtype);
     }
-  // if out types are known in types are unknown
+    // if out types are known in types are unknown
   } else if ((*out_type)[0] != -1 && dtype == -1) {
     in_type->clear();
     for (size_t i = 0; i < nin; ++i) {
       in_type->push_back((*out_type)[0]);
     }
-  // if both out_types and in_types are known, and different
+    // if both out_types and in_types are known, and different
   } else if ((*out_type)[0] != -1 && dtype != -1 && ((*out_type)[0] != dtype)) {
     std::ostringstream os;
-    os << "Type inconsistent, Provided output type = "
-       << mxnet::op::type_string((*out_type)[0]) << ','
-       << " inferred type = " << mxnet::op::type_string(dtype);
+    os << "Type inconsistent, Provided output type = " << mxnet::op::type_string((*out_type)[0])
+       << ',' << " inferred type = " << mxnet::op::type_string(dtype);
     throw mxnet::op::InferTypeError(os.str(), 0);
   }
   return true;
@@ -186,29 +187,27 @@ bool ConcatType(const nnvm::NodeAttrs& attrs,
 inline static bool ConcatForwardInferStorageType(const nnvm::NodeAttrs& attrs,
                                                  const int dev_mask,
                                                  DispatchMode* dispatch_mode,
-                                                 std::vector<int> *in_attrs,
-                                                 std::vector<int> *out_attrs) {
+                                                 std::vector<int>* in_attrs,
+                                                 std::vector<int>* out_attrs) {
   CHECK(!in_attrs->empty());
   CHECK_EQ(out_attrs->size(), 1U);
-  auto& out_stype = out_attrs->at(0);
-  bool dispatched = false;
+  auto& out_stype          = out_attrs->at(0);
+  bool dispatched          = false;
   const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
-  if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kCSRStorage)
-      && param.dim == 0) {
-    dispatched = storage_type_assign(&out_stype, kCSRStorage,
-                                     dispatch_mode, DispatchMode::kFComputeEx);
+  if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kCSRStorage) && param.dim == 0) {
+    dispatched =
+        storage_type_assign(&out_stype, kCSRStorage, dispatch_mode, DispatchMode::kFComputeEx);
   }
 #if MXNET_USE_MKLDNN == 1
-  if (!dispatched && dev_mask == mshadow::cpu::kDevMask
-      && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)
-      && param.dim > 0) {
-    dispatched = storage_type_assign(&out_stype, kDefaultStorage,
-                                     dispatch_mode, DispatchMode::kFComputeEx);
+  if (!dispatched && dev_mask == mshadow::cpu::kDevMask &&
+      common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) && param.dim > 0) {
+    dispatched =
+        storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFComputeEx);
   }
 #endif  // MXNET_USE_MKLDNN == 1
   if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
-    dispatched = storage_type_assign(&out_stype, kDefaultStorage,
-                                     dispatch_mode, DispatchMode::kFCompute);
+    dispatched =
+        storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute);
   }
   if (!dispatched) {
     dispatched = dispatch_fallback(out_attrs, dispatch_mode);
@@ -223,15 +222,14 @@ inline static bool ConcatForwardInferStorageType(const nnvm::NodeAttrs& attrs,
 inline static bool BackwardConcatStorageType(const nnvm::NodeAttrs& attrs,
                                              const int dev_mask,
                                              DispatchMode* dispatch_mode,
-                                             std::vector<int> *in_attrs,
-                                             std::vector<int> *out_attrs) {
+                                             std::vector<int>* in_attrs,
+                                             std::vector<int>* out_attrs) {
   DispatchMode wanted_mode;
 #if MXNET_USE_MKLDNN == 1
   const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
   CHECK_EQ(out_attrs->size(), in_attrs->size() - 1);
-  if (dev_mask == mshadow::cpu::kDevMask
-      && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)
-      && param.dim > 0)
+  if (dev_mask == mshadow::cpu::kDevMask &&
+      common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) && param.dim > 0)
     wanted_mode = DispatchMode::kFComputeEx;
   else
 #endif  // MXNET_USE_MKLDNN == 1
@@ -240,19 +238,23 @@ inline static bool BackwardConcatStorageType(const nnvm::NodeAttrs& attrs,
   if (!MKLDNNEnvSet())
     wanted_mode = DispatchMode::kFComputeFallback;
 #endif  // MXNET_USE_MKLDNN == 1
-  return storage_type_assign(out_attrs, mxnet::kDefaultStorage,
-                             dispatch_mode, wanted_mode);
+  return storage_type_assign(out_attrs, mxnet::kDefaultStorage, dispatch_mode, wanted_mode);
 }
 #if MXNET_USE_MKLDNN == 1
-bool SupportMKLDNNConcat(const std::vector<NDArray> &arrs) {
-  for (auto &arr : arrs) {
-    if (arr.IsView()) return false;
-    if (!(arr.dtype() == mshadow::kFloat32 || arr.dtype() == mshadow::kBfloat16)) return false;
+bool SupportMKLDNNConcat(const std::vector<NDArray>& arrs) {
+  for (auto& arr : arrs) {
+    if (arr.IsView())
+      return false;
+    if (!(arr.dtype() == mshadow::kFloat32 || arr.dtype() == mshadow::kBfloat16))
+      return false;
     // DO not support zero-size tensors.
-    if (arr.shape().Size() == 0) return false;
+    if (arr.shape().Size() == 0)
+      return false;
     int ndim = arr.shape().ndim();
-    const int mkldnn_ndims = arr.GetMKLDNNData()->get_desc().data.ndims;
-    if (!(ndim == 2 || ndim == 4) || ndim != mkldnn_ndims) return false;
+    const int mkldnn_ndims =
+        static_cast<const mkldnn::memory*>(arr.GetMKLDNNData())->get_desc().data.ndims;
+    if (!(ndim == 2 || ndim == 4) || ndim != mkldnn_ndims)
+      return false;
   }
   return true;
 }
@@ -265,7 +267,8 @@ static void ConcatComputeExCPU(const nnvm::NodeAttrs& attrs,
   CHECK(!inputs.empty());
   CHECK_EQ(outputs.size(), 1U);
   CHECK_EQ(req.size(), 1U);
-  if (req[0] == kNullOp) return;
+  if (req[0] == kNullOp)
+    return;
   if (common::ContainsOnlyStorage(inputs, kCSRStorage) &&
       outputs[0].storage_type() == kCSRStorage) {
     ConcatCSRImpl<cpu>(attrs, op_ctx, inputs, req, outputs);
@@ -299,7 +302,7 @@ static void ConcatGradComputeExCPU(const nnvm::NodeAttrs& attrs,
 #endif  // MXNET_USE_MKLDNN == 1
 
 struct ConcatGrad {
-  const char *op_name;
+  const char* op_name;
   std::vector<nnvm::NodeEntry> operator()(const nnvm::ObjectPtr& n,
                                           const std::vector<nnvm::NodeEntry>& ograds) const {
     CHECK_EQ(ograds.size(), 1);
@@ -315,38 +318,37 @@ struct ConcatGrad {
 
 DMLC_REGISTER_PARAMETER(ConcatParam);
 
-#define CONCAT_FORWARD_ATTRS \
-.set_num_inputs([](const NodeAttrs& attrs) { \
-  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed); \
-  return params.num_args; \
-}) \
-.set_num_outputs(1) \
-.set_attr_parser(ParamParser<ConcatParam>) \
-.set_attr<nnvm::FListInputNames>("FListInputNames", \
-    [](const NodeAttrs& attrs) { \
-  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed); \
-  std::vector<std::string> ret; \
-  for (int i = 0; i < params.num_args; ++i) { \
-    ret.push_back(std::string("arg") + std::to_string(i)); \
-  } \
-  return ret; \
-}) \
-.set_attr<nnvm::FListOutputNames>("FListOutputNames", \
-    [](const NodeAttrs& attrs) { \
-    return std::vector<std::string>{"output"}; \
-}) \
-.set_attr<nnvm::FInferType>("FInferType", ConcatType) \
-.set_attr<FInferStorageType>("FInferStorageType", ConcatForwardInferStorageType) \
-.set_attr<FCompute>("FCompute<cpu>", ConcatCompute<cpu>) \
-.set_attr<FComputeEx>("FComputeEx<cpu>", ConcatComputeExCPU) \
-.set_attr<nnvm::FGradient>("FGradient", ConcatGrad{"_backward_Concat"}) \
-.set_attr<std::string>("key_var_num_args", "num_args")
-
+#define CONCAT_FORWARD_ATTRS                                                                      \
+  .set_num_inputs([](const NodeAttrs& attrs) {                                                    \
+    const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);                             \
+    return params.num_args;                                                                       \
+  })                                                                                              \
+      .set_num_outputs(1)                                                                         \
+      .set_attr_parser(ParamParser<ConcatParam>)                                                  \
+      .set_attr<nnvm::FListInputNames>("FListInputNames",                                         \
+                                       [](const NodeAttrs& attrs) {                               \
+                                         const ConcatParam& params =                              \
+                                             nnvm::get<ConcatParam>(attrs.parsed);                \
+                                         std::vector<std::string> ret;                            \
+                                         for (int i = 0; i < params.num_args; ++i) {              \
+                                           ret.push_back(std::string("arg") + std::to_string(i)); \
+                                         }                                                        \
+                                         return ret;                                              \
+                                       })                                                         \
+      .set_attr<nnvm::FListOutputNames>(                                                          \
+          "FListOutputNames",                                                                     \
+          [](const NodeAttrs& attrs) { return std::vector<std::string>{"output"}; })              \
+      .set_attr<nnvm::FInferType>("FInferType", ConcatType)                                       \
+      .set_attr<FInferStorageType>("FInferStorageType", ConcatForwardInferStorageType)            \
+      .set_attr<FCompute>("FCompute<cpu>", ConcatCompute<cpu>)                                    \
+      .set_attr<FComputeEx>("FComputeEx<cpu>", ConcatComputeExCPU)                                \
+      .set_attr<nnvm::FGradient>("FGradient", ConcatGrad{"_backward_Concat"})                     \
+      .set_attr<std::string>("key_var_num_args", "num_args")
 
 NNVM_REGISTER_OP(Concat)
 MXNET_ADD_SPARSE_OP_ALIAS(concat)
-.add_alias("concat")
-.describe(R"code(Joins input arrays along a given axis.
+    .add_alias("concat")
+    .describe(R"code(Joins input arrays along a given axis.
 
 .. note:: `Concat` is deprecated. Use `concat` instead.
 
@@ -384,59 +386,60 @@ Example::
 
 )code" ADD_FILELINE)
 #if MXNET_USE_MKLDNN == 1
-.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
-  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
-})
-.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
-.set_attr<bool>("TIsMKLDNN", true)
+    .set_attr<FResourceRequest>("FResourceRequest",
+                                [](const NodeAttrs& n) {
+                                  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+                                })
+    .set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
+    .set_attr<bool>("TIsMKLDNN", true)
 #endif  // MXNET_USE_MKLDNN == 1
-CONCAT_FORWARD_ATTRS
-.set_attr<mxnet::FInferShape>("FInferShape", ConcatShape)
-.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
-.add_arguments(ConcatParam::__FIELDS__());
+        CONCAT_FORWARD_ATTRS.set_attr<mxnet::FInferShape>("FInferShape", ConcatShape)
+    .add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
+    .add_arguments(ConcatParam::__FIELDS__());
 
 NNVM_REGISTER_OP(_backward_Concat)
-.set_num_inputs([](const NodeAttrs& attrs) {
+    .set_num_inputs([](const NodeAttrs& attrs) {
 #if MXNET_USE_MKLDNN == 1
-  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
-  return 1 + params.num_args;
+      const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
+      return 1 + params.num_args;
 #else
-  return 1;
+      return 1;
 #endif
-})
-.set_num_outputs([](const NodeAttrs& attrs) {
-  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
-  return params.num_args;
-})
-.set_attr_parser(ParamParser<ConcatParam>)
+    })
+    .set_num_outputs([](const NodeAttrs& attrs) {
+      const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
+      return params.num_args;
+    })
+    .set_attr_parser(ParamParser<ConcatParam>)
 #if MXNET_USE_MKLDNN == 1
-.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
-  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
-})
+    .set_attr<FResourceRequest>("FResourceRequest",
+                                [](const NodeAttrs& n) {
+                                  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+                                })
 #endif  // MXNET_USE_MKLDNN == 1
-.set_attr<nnvm::TIsBackward>("TIsBackward", true)
-.set_attr<FInferStorageType>("FInferStorageType", BackwardConcatStorageType)
+    .set_attr<nnvm::TIsBackward>("TIsBackward", true)
+    .set_attr<FInferStorageType>("FInferStorageType", BackwardConcatStorageType)
 #if MXNET_USE_MKLDNN == 1
-.set_attr<bool>("TIsMKLDNN", true)
-.set_attr<FComputeEx>("FComputeEx<cpu>", ConcatGradComputeExCPU)
+    .set_attr<bool>("TIsMKLDNN", true)
+    .set_attr<FComputeEx>("FComputeEx<cpu>", ConcatGradComputeExCPU)
 #endif  // MXNET_USE_MKLDNN == 1
-.set_attr<FCompute>("FCompute<cpu>", ConcatGradCompute<cpu>);
+    .set_attr<FCompute>("FCompute<cpu>", ConcatGradCompute<cpu>);
 
 // _rnn_param_concat is a custom concat op with specialized infer_shape,
 // which handles the case where the first one or two inputs may have
 // unknown shape that can be inferred from output shape.
 NNVM_REGISTER_OP(_rnn_param_concat)
-.add_alias("_npi_rnn_param_concat")
+    .add_alias("_npi_rnn_param_concat")
 #if MXNET_USE_MKLDNN == 1
-.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
-  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
-})
+    .set_attr<FResourceRequest>("FResourceRequest",
+                                [](const NodeAttrs& n) {
+                                  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+                                })
 #endif  // MXNET_USE_MKLDNN == 1
-CONCAT_FORWARD_ATTRS
-.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
-.set_attr<mxnet::FInferShape>("FInferShape", RNNParamConcatShape)
-.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
-.add_arguments(ConcatParam::__FIELDS__());
+        CONCAT_FORWARD_ATTRS.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
+    .set_attr<mxnet::FInferShape>("FInferShape", RNNParamConcatShape)
+    .add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
+    .add_arguments(ConcatParam::__FIELDS__());
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc
index 6f4ac3d..3f73310 100644
--- a/src/operator/nn/mkldnn/mkldnn_act.cc
+++ b/src/operator/nn/mkldnn/mkldnn_act.cc
@@ -152,7 +152,7 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs,
   param_.alg               = GetMKLDNNActAlgo(param);
   const NDArray& in_buffer = in_data;
   MKLDNNStream* stream     = MKLDNNStream::Get();
-  auto input_mem           = in_buffer.GetMKLDNNData();
+  auto input_mem           = static_cast<const mkldnn::memory*>(in_buffer.GetMKLDNNData());
   MKLDNNActForward& fwd    = GetActForward(param_, ctx, in_buffer, *input_mem);
   auto out_mem_t           = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_desc(), req, &in_buffer);
   stream->RegisterPrimArgs(fwd.GetFwd(),
@@ -177,7 +177,7 @@ void MKLDNNLeakyReluForward(const nnvm::NodeAttrs& attrs,
   if (in_data.IsView() && in_data.IsMKLDNNData())
     in_buffer = in_data.Reorder2Default();
 
-  auto input_mem        = in_buffer.GetMKLDNNData();
+  auto input_mem        = static_cast<const mkldnn::memory*>(in_buffer.GetMKLDNNData());
   MKLDNNActForward& fwd = GetActForward(param_, ctx, in_buffer, *input_mem);
   auto out_mem_t        = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_desc(), req, &in_buffer);
   stream->RegisterPrimArgs(fwd.GetFwd(),
@@ -222,14 +222,15 @@ static inline MKLDNNActBackward& GetActBackward(const MKLDNNActParam& param,
 
   auto it = bwds.find(key);
   if (it == bwds.end()) {
-    MKLDNNActBackward bwd(param, in_data, in_mem, *out_grad.GetMKLDNNData());
+    MKLDNNActBackward bwd(
+        param, in_data, in_mem, *static_cast<const mkldnn::memory*>(out_grad.GetMKLDNNData()));
     it = AddToCache(&bwds, key, bwd);
   }
   return it->second;
 }
 
-// For backward relu activation, it's okay to pass "out_data" as "in_data" to
-// this function, since the computation only involes non-zeros.
+// For backward relu activation, it's okay to pass "out_data" as "in_data" to this
+// function, since the computation only involves non-zeros.
 void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs,
                               const OpContext& ctx,
                               const std::vector<NDArray>& inputs,
@@ -247,12 +248,13 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs,
   MKLDNNActParam param_;
   param_.alg = GetMKLDNNActAlgo(param);
   TmpMemMgr::Get()->Init(ctx.requested[activation::kTempSpace]);
-  auto diff_dst_memory = out_buffer.GetMKLDNNData();
-  auto input_mem       = in_buffer.GetMKLDNNData();
+  auto diff_dst_memory = static_cast<const mkldnn::memory*>(out_buffer.GetMKLDNNData());
+  auto input_mem       = static_cast<const mkldnn::memory*>(in_buffer.GetMKLDNNData());
   // We need to make sure the two inputs to eltwise_backward has the same memory
   // descriptor. Otherwise, the perf will suffer.
+  auto diff_dst_desc = diff_dst_memory->get_desc();
   if (input_mem->get_desc() != diff_dst_memory->get_desc())
-    input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_desc());
+    input_mem = static_cast<const mkldnn::memory*>(in_buffer.GetMKLDNNDataReorder(&diff_dst_desc));
   MKLDNNActBackward& bwd          = GetActBackward(param_, ctx, in_buffer, out_buffer, *input_mem);
   MKLDNNStream* stream            = MKLDNNStream::Get();
   mkldnn_output_t diff_src_memory = CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_desc(), req[0]);
@@ -286,12 +288,13 @@ void MKLDNNLeakyReluBackward(const nnvm::NodeAttrs& attrs,
   param_.slope = param.slope;
 
   TmpMemMgr::Get()->Init(ctx.requested[leakyrelu::kRandom]);
-  auto diff_dst_memory = out_buffer.GetMKLDNNData();
-  auto input_mem       = in_buffer.GetMKLDNNData();
+  auto diff_dst_memory = static_cast<const mkldnn::memory*>(out_buffer.GetMKLDNNData());
+  auto input_mem       = static_cast<const mkldnn::memory*>(in_buffer.GetMKLDNNData());
   // We need to make sure the two inputs to eltwise_backward has the same memory
   // descriptor. Otherwise, the perf will suffer.
+  auto diff_dst_desc = diff_dst_memory->get_desc();
   if (input_mem->get_desc() != diff_dst_memory->get_desc())
-    input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_desc());
+    input_mem = static_cast<const mkldnn::memory*>(in_buffer.GetMKLDNNDataReorder(&diff_dst_desc));
   MKLDNNActBackward& bwd          = GetActBackward(param_, ctx, in_buffer, out_buffer, *input_mem);
   MKLDNNStream* stream            = MKLDNNStream::Get();
   mkldnn_output_t diff_src_memory = CreateMKLDNNMem(output, bwd.bwd_pd.diff_src_desc(), req[0]);
diff --git a/src/operator/nn/mkldnn/mkldnn_adaptive_pooling.cc b/src/operator/nn/mkldnn/mkldnn_adaptive_pooling.cc
index 3b8f234..13c842e 100644
--- a/src/operator/nn/mkldnn/mkldnn_adaptive_pooling.cc
+++ b/src/operator/nn/mkldnn/mkldnn_adaptive_pooling.cc
@@ -36,7 +36,8 @@ void MKLDNNAdaptivePoolingFwd::Init(const mxnet::NDArray& input,
                                     const mkldnn::memory::dims& pad_r,
                                     const bool is_train,
                                     const mkldnn::algorithm alg_kind) {
-  const auto src_md           = input.GetMKLDNNData()->get_desc();
+  const mkldnn::memory* mem   = static_cast<const mkldnn::memory*>(input.GetMKLDNNData());
+  const auto src_md           = mem->get_desc();
   const auto dst_md           = GetMemDesc(output);
   const mkldnn::engine engine = CpuEngine::Get()->get_engine();
 
@@ -69,7 +70,7 @@ void MKLDNNAdaptivePoolingFwd::Execute(const NDArray& input,
     in_buffer = input.Reorder2Default();
   }
 
-  auto input_mem    = in_buffer.GetMKLDNNData();
+  auto input_mem    = static_cast<const mkldnn::memory*>(in_buffer.GetMKLDNNData());
   auto output_mem_t = CreateMKLDNNMem(output, this->fwd_pd_->dst_desc(), req);
 
   mkldnn_args_map_t args = {{MKLDNN_ARG_SRC, *input_mem}, {MKLDNN_ARG_DST, *(output_mem_t.second)}};
@@ -80,7 +81,9 @@ void MKLDNNAdaptivePoolingFwd::Execute(const NDArray& input,
       LOG(FATAL) << "MKLDNN Average Pooling: incorrect worskapce input";
     }
     auto ws = std::make_shared<mkldnn::memory>(
-        this->fwd_pd_->workspace_desc(), engine, workspace->GetMKLDNNData()->get_data_handle());
+        this->fwd_pd_->workspace_desc(),
+        engine,
+        static_cast<const mkldnn::memory*>(workspace->GetMKLDNNData())->get_data_handle());
     args[MKLDNN_ARG_WORKSPACE] = *ws;
   }
   if (this->fwd_) {
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index e206e9e..698cd21 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -156,8 +156,7 @@ static inline int GetMKLDNNCacheSize() {
   return mkldnn_cache_size;
 }
 
-// TODO(alex): (MXNET-1075) Will remove env variable and calculate cache size
-// during runtime
+// TODO(alex): (MXNET-1075) Will remove env variable and calculate cache size during runtime
 template <typename S, typename I, typename H>
 static typename std::unordered_map<S, I, H>::iterator AddToCache(std::unordered_map<S, I, H>* cache,
                                                                  const S& key,
@@ -208,7 +207,8 @@ static int GetTypeSize(int dtype) {
 
 static inline size_t GetArraySize(const NDArray& arr) {
   if (arr.IsMKLDNNData()) {
-    return arr.GetMKLDNNData()->get_desc().get_size();
+    auto arr_data = static_cast<const mkldnn::memory*>(arr.GetMKLDNNData());
+    return arr_data->get_desc().get_size();
   }
   return arr.shape().Size() * GetTypeSize(arr.dtype());
 }
@@ -532,8 +532,7 @@ static inline void InvalidateOutputs(const std::vector<NDArray>& arrs,
   }
 }
 
-// TODO(alexzai): (MXNET-856) Remove helper function after subgraph feature
-// added
+// TODO(alexzai): (MXNET-856) Remove helper function after subgraph feature added
 static inline void CreateDefaultInputs(const std::vector<NDArray>& arrs,
                                        std::vector<NDArray>* out_arrs) {
   out_arrs->clear();
diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc
index 5a65c94..8e1dd74 100644
--- a/src/operator/nn/mkldnn/mkldnn_base.cc
+++ b/src/operator/nn/mkldnn/mkldnn_base.cc
@@ -159,9 +159,12 @@ void MKLDNNMemoryCopy(const mkldnn::memory& mem, const mkldnn::memory* this_mem)
 }
 
 bool CanWriteTo(const NDArray& out_arr, const NDArray& in_arr, const mkldnn::memory::desc& desc) {
-  auto in_mem     = in_arr.GetMKLDNNData();
-  bool add_same   = in_mem->get_data_handle() == out_arr.GetMKLDNNData()->get_data_handle();
-  bool pdesc_same = out_arr.GetMKLDNNData()->get_desc() == desc && in_mem->get_desc() == desc;
+  auto in_mem   = static_cast<const mkldnn::memory*>(in_arr.GetMKLDNNData());
+  bool add_same = in_mem->get_data_handle() ==
+                  static_cast<const mkldnn::memory*>(out_arr.GetMKLDNNData())->get_data_handle();
+  bool pdesc_same =
+      static_cast<const mkldnn::memory*>(out_arr.GetMKLDNNData())->get_desc() == desc &&
+      in_mem->get_desc() == desc;
   return add_same && pdesc_same;
 }
 
@@ -173,7 +176,8 @@ mkldnn_output_t CreateMKLDNNMem(const NDArray& out_arr,
     auto tmp = TmpMemMgr::Get()->Alloc(desc);
     return mkldnn_output_t(OutDataOp::AddBack, tmp);
   } else if (kWriteInplace == req && in_arr != nullptr && CanWriteTo(out_arr, *in_arr, desc)) {
-    mkldnn::memory* mem = const_cast<NDArray&>(out_arr).CreateMKLDNNData(desc);
+    mkldnn::memory* mem =
+        static_cast<mkldnn::memory*>(const_cast<NDArray&>(out_arr).CreateMKLDNNData(&desc));
     // mem is nullptr if out_arr is view and desc is MKLDNN format.
     // need to Reorder2Default before calling CreateMKLDNNMem
     CHECK(mem != nullptr);
@@ -182,7 +186,8 @@ mkldnn_output_t CreateMKLDNNMem(const NDArray& out_arr,
     auto tmp = TmpMemMgr::Get()->Alloc(desc);
     return mkldnn_output_t(OutDataOp::CopyBack, tmp);
   } else if (kWriteTo == req) {
-    mkldnn::memory* mem = const_cast<NDArray&>(out_arr).CreateMKLDNNData(desc);
+    mkldnn::memory* mem =
+        static_cast<mkldnn::memory*>(const_cast<NDArray&>(out_arr).CreateMKLDNNData(&desc));
     if (nullptr == mem) {
       auto tmp = TmpMemMgr::Get()->Alloc(desc);
       return mkldnn_output_t(OutDataOp::CopyBack, tmp);
@@ -205,7 +210,7 @@ mkldnn_output_t CreateMKLDNNWeightGrad(const NDArray& out_arr,
   } else {
     mkldnn::memory* mem = nullptr;
     if (IsDefaultFormat(desc)) {
-      mem = const_cast<NDArray&>(out_arr).CreateMKLDNNData(desc);
+      mem = static_cast<mkldnn::memory*>(const_cast<NDArray&>(out_arr).CreateMKLDNNData(&desc));
     }
     if (mem == nullptr) {
       auto tmp = TmpMemMgr::Get()->Alloc(desc);
@@ -218,16 +223,17 @@ mkldnn_output_t CreateMKLDNNWeightGrad(const NDArray& out_arr,
 
 void CommitOutput(const NDArray& arr, const mkldnn_output_t& res) {
   if (res.first == CopyBack) {
-    const_cast<NDArray&>(arr).CopyFrom(*res.second);
+    const_cast<NDArray&>(arr).CopyFrom(res.second);
   } else if (res.first == AddBack) {
     auto res_memory = res.second;
-    auto target_pd  = arr.GetMKLDNNData()->get_desc();
-    auto mem        = arr.GetMKLDNNData(res.second->get_desc());
+    auto target_pd  = static_cast<const mkldnn::memory*>(arr.GetMKLDNNData())->get_desc();
+    auto res_desc   = res.second->get_desc();
+    auto mem        = static_cast<const mkldnn::memory*>(arr.GetMKLDNNData(&res_desc));
     if (mem == nullptr) {
       auto tmp_memory = TmpMemMgr::Get()->Alloc(target_pd);
       MKLDNNMemoryCopy(*res_memory, tmp_memory);
       res_memory = tmp_memory;
-      mem        = arr.GetMKLDNNData();
+      mem        = static_cast<const mkldnn::memory*>(arr.GetMKLDNNData());
     }
     op::MKLDNNSum(*mem, *res_memory, *mem);
   }
@@ -283,22 +289,24 @@ const mkldnn::memory* GetWeights(const NDArray& arr, int num_groups) {
     LOG(FATAL) << "The weight array has an unsupported number of dimensions";
   }
   const auto md = mkldnn::memory::desc{tz, type, format_tag};
-  return arr.GetMKLDNNData(md);
+  return static_cast<const mkldnn::memory*>(arr.GetMKLDNNData(&md));
 }
 
 const mkldnn::memory* GetWeights(const NDArray& arr,
                                  const mkldnn::memory::desc& target_desc,
                                  int num_groups) {
-  const mkldnn::memory* mem = arr.GetMKLDNNData(target_desc);
-  // If the weight array already uses the target layout, simply return it
-  // directly.
-  if (mem)
+  const mkldnn::memory* mem = static_cast<const mkldnn::memory*>(arr.GetMKLDNNData(&target_desc));
+  // If the weight array already uses the target layout, simply return it directly.
+  if (mem) {
     return mem;
+  }
   mem = GetWeights(arr, num_groups);
-  if (mem == nullptr)
-    mem = arr.GetMKLDNNDataReorder(target_desc);
-  if (mem->get_desc() == target_desc)
+  if (mem == nullptr) {
+    mem = static_cast<const mkldnn::memory*>(arr.GetMKLDNNDataReorder(&target_desc));
+  }
+  if (mem->get_desc() == target_desc) {
     return mem;
+  }
 
   auto ret = TmpMemMgr::Get()->Alloc(target_desc);
   std::unordered_map<int, mkldnn::memory> args({{MKLDNN_ARG_FROM, *mem}, {MKLDNN_ARG_TO, *ret}});
@@ -307,8 +315,7 @@ const mkldnn::memory* GetWeights(const NDArray& arr,
 }
 
 // default: block and dims' stride increase monotonically
-// mkldnn: 1.winograd 2.rnn packed 3. block and dims'stride is not increase
-// monotonically
+// mkldnn: 1.winograd 2.rnn packed 3. block and dims' stride do not increase monotonically
 bool IsMKLDNN(const mkldnn::memory::desc& desc) {
   bool rslt = true;
   if (desc.data.format_kind == mkldnn_blocked) {
@@ -454,8 +461,8 @@ void FallBackCompute(Compute fn,
   fn(attrs_states, ctx, in_blobs, new_req, out_blobs);
   for (size_t i = 0, bf16_pos = 0; i < out_blobs.size(); i++) {
     if (outputs[i].dtype() == mshadow::kBfloat16) {
-      auto src_mem = temp_bf16_src[bf16_pos].GetMKLDNNData();
-      auto dst_mem = temp_bf16_dst[bf16_pos].GetMKLDNNData();
+      auto src_mem = static_cast<const mkldnn::memory*>(temp_bf16_src[bf16_pos].GetMKLDNNData());
+      auto dst_mem = static_cast<const mkldnn::memory*>(temp_bf16_dst[bf16_pos].GetMKLDNNData());
       bf16_pos++;
       ReorderTo(src_mem, dst_mem);
     } else if (req[i] == kAddTo && outputs[i].IsMKLDNNData()) {
@@ -489,13 +496,13 @@ static bool SimilarArray(const mxnet::NDArray& arr1,
   NDArray buf1, buf2;
   if (arr1.IsMKLDNNData()) {
     buf1     = NDArray(arr1.shape(), arr1.ctx(), false, arr1.dtype());
-    auto mem = arr1.GetMKLDNNData();
-    buf1.CopyFrom(*mem);
+    auto mem = static_cast<const mkldnn::memory*>(arr1.GetMKLDNNData());
+    buf1.CopyFrom(mem);
   }
   if (arr2.IsMKLDNNData()) {
     buf2     = NDArray(arr2.shape(), arr2.ctx(), false, arr2.dtype());
-    auto mem = arr2.GetMKLDNNData();
-    buf2.CopyFrom(*mem);
+    auto mem = static_cast<const mkldnn::memory*>(arr2.GetMKLDNNData());
+    buf2.CopyFrom(mem);
   }
   MKLDNNStream::Get()->Submit();
 
@@ -519,25 +526,25 @@ static bool SimilarArray(const mxnet::NDArray& arr1,
 
 template void FallBackCompute(void (*)(nnvm::NodeAttrs const&,
                                        OpContext const&,
-                                       std::vector<TBlob, std::allocator<TBlob>> const&,
-                                       std::vector<OpReqType, std::allocator<OpReqType>> const&,
-                                       std::vector<TBlob, std::allocator<TBlob>> const&),
+                                       std::vector<TBlob, std::allocator<TBlob> > const&,
+                                       std::vector<OpReqType, std::allocator<OpReqType> > const&,
+                                       std::vector<TBlob, std::allocator<TBlob> > const&),
                               nnvm::NodeAttrs const&,
                               OpContext const&,
-                              std::vector<NDArray, std::allocator<NDArray>> const&,
-                              std::vector<OpReqType, std::allocator<OpReqType>> const&,
-                              std::vector<NDArray, std::allocator<NDArray>> const&);
+                              std::vector<NDArray, std::allocator<NDArray> > const&,
+                              std::vector<OpReqType, std::allocator<OpReqType> > const&,
+                              std::vector<NDArray, std::allocator<NDArray> > const&);
 
 template void FallBackCompute(void (*)(OpStatePtr const&,
                                        OpContext const&,
-                                       std::vector<TBlob, std::allocator<TBlob>> const&,
-                                       std::vector<OpReqType, std::allocator<OpReqType>> const&,
-                                       std::vector<TBlob, std::allocator<TBlob>> const&),
+                                       std::vector<TBlob, std::allocator<TBlob> > const&,
+                                       std::vector<OpReqType, std::allocator<OpReqType> > const&,
+                                       std::vector<TBlob, std::allocator<TBlob> > const&),
                               OpStatePtr const&,
                               OpContext const&,
-                              std::vector<NDArray, std::allocator<NDArray>> const&,
-                              std::vector<OpReqType, std::allocator<OpReqType>> const&,
-                              std::vector<NDArray, std::allocator<NDArray>> const&);
+                              std::vector<NDArray, std::allocator<NDArray> > const&,
+                              std::vector<OpReqType, std::allocator<OpReqType> > const&,
+                              std::vector<NDArray, std::allocator<NDArray> > const&);
 
 void OpCheck::Init(const std::vector<mxnet::NDArray>& inputs_,
                    const std::vector<mxnet::NDArray>& outputs_) {
@@ -548,14 +555,14 @@ void OpCheck::Init(const std::vector<mxnet::NDArray>& inputs_,
     inputs.emplace_back(data.shape(), ctx, false, data.dtype());
     if (data.IsMKLDNNData() && data.IsView())
       data = data.Reorder2Default();
-    auto mem = data.GetMKLDNNData();
-    inputs[i].CopyFrom(*mem);
+    auto mem = static_cast<const mkldnn::memory*>(data.GetMKLDNNData());
+    inputs[i].CopyFrom(mem);
   }
   for (size_t i = 0; i < outputs_.size(); i++) {
     outputs.emplace_back(outputs_[i].shape(), ctx, false, outputs_[i].dtype());
     if (backward) {
-      auto mem = outputs_[i].GetMKLDNNData();
-      outputs[i].CopyFrom(*mem);
+      auto mem = static_cast<const mkldnn::memory*>(outputs_[i].GetMKLDNNData());
+      outputs[i].CopyFrom(mem);
     }
   }
   MKLDNNStream::Get()->Submit();
@@ -602,8 +609,8 @@ void OpCheck::CopyResult(const std::vector<mxnet::NDArray>& outputs_,
   CHECK(!MKLDNNStream::Get()->HasOps());
   auto non_const_outputs_ = const_cast<std::vector<mxnet::NDArray>&>(outputs_);
   for (auto i = indice.begin(); i != indice.end(); ++i) {
-    auto mem = outputs[*i].GetMKLDNNData();
-    non_const_outputs_[*i].CopyFrom(*mem);
+    auto mem = static_cast<const mkldnn::memory*>(outputs[*i].GetMKLDNNData());
+    non_const_outputs_[*i].CopyFrom(mem);
   }
   MKLDNNStream::Get()->Submit();
 }
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 9b25b13..56eed58 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -176,11 +176,13 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs& attrs,
   NDArray& data = in_data[batchnorm::kData];
   if (data.IsMKLDNNData() && data.IsView())
     data = data.Reorder2Default();
-  auto data_mem = data.GetMKLDNNData();
+  auto data_mem = static_cast<const mkldnn::memory*>(data.GetMKLDNNData());
   auto& fwd     = GetBNForward<DType>(param, ctx, data_mem, flags);
 
   // for output memory
-  auto out_mem = const_cast<NDArray&>(out).CreateMKLDNNData(fwd.GetPd().dst_desc());
+  auto fwd_dst_desc = fwd.GetPd().dst_desc();
+  auto out_mem =
+      static_cast<mkldnn::memory*>(const_cast<NDArray&>(out).CreateMKLDNNData(&fwd_dst_desc));
 
   // mxnet will always use scale shift.
   // But if fix_gamma is true, then all scale elements will be set to 1.0f
@@ -226,7 +228,9 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs& attrs,
         LOG(FATAL) << "MKLDNN BatchNorm: incorrect workspace input";
       }
       auto ws = std::make_shared<mkldnn::memory>(
-          fwd.GetPd().workspace_desc(), engine, workspace->GetMKLDNNData()->get_data_handle());
+          fwd.GetPd().workspace_desc(),
+          engine,
+          static_cast<const mkldnn::memory*>(workspace->GetMKLDNNData())->get_data_handle());
       net_args[MKLDNN_ARG_WORKSPACE] = *ws;
     }
     if (!ctx.is_train || param.use_global_stats) {
@@ -239,15 +243,17 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs& attrs,
         omean[i] = inmean[i];
         ovar[i]  = VARIANCE_TO_INVSTD(invar[i], param.eps);
       }
-      net_args[MKLDNN_ARG_MEAN]     = *(aux_states[batchnorm::kMovingMean].GetMKLDNNData());
-      net_args[MKLDNN_ARG_VARIANCE] = *(aux_states[batchnorm::kMovingVar].GetMKLDNNData());
+      net_args[MKLDNN_ARG_MEAN] =
+          *(static_cast<const mkldnn::memory*>(aux_states[batchnorm::kMovingMean].GetMKLDNNData()));
+      net_args[MKLDNN_ARG_VARIANCE] =
+          *(static_cast<const mkldnn::memory*>(aux_states[batchnorm::kMovingVar].GetMKLDNNData()));
       MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
       MKLDNNStream::Get()->Submit();
     } else {  // training
-      const NDArray& outMean        = outputs[batchnorm::kMean];
-      const NDArray& outVar         = outputs[batchnorm::kVar];
-      net_args[MKLDNN_ARG_MEAN]     = *(outMean.GetMKLDNNData());
-      net_args[MKLDNN_ARG_VARIANCE] = *(outVar.GetMKLDNNData());
+      const NDArray& outMean    = outputs[batchnorm::kMean];
+      const NDArray& outVar     = outputs[batchnorm::kVar];
+      net_args[MKLDNN_ARG_MEAN] = *(static_cast<const mkldnn::memory*>(outMean.GetMKLDNNData()));
+      net_args[MKLDNN_ARG_VARIANCE] = *(static_cast<const mkldnn::memory*>(outVar.GetMKLDNNData()));
       MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
       MKLDNNStream::Get()->Submit();
 
@@ -374,14 +380,17 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs& attrs,
     gradIn = gradIn.Reshape(new_shape);
   }
 
-  auto data_mem = data.GetMKLDNNData();
-  auto diff_mem = diff.GetMKLDNNData();
+  auto data_mem = static_cast<const mkldnn::memory*>(data.GetMKLDNNData());
+  auto diff_mem = static_cast<const mkldnn::memory*>(diff.GetMKLDNNData());
   // MKLDNN batchnorm should run on special layouts. If one of them isn't, we
   // should reorder them.
-  if (data.IsDefaultData())
-    data_mem = data.GetMKLDNNDataReorder(diff_mem->get_desc());
-  else if (diff.IsDefaultData())
-    diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_desc());
+  if (data.IsDefaultData()) {
+    auto diff_desc = diff_mem->get_desc();
+    data_mem       = static_cast<const mkldnn::memory*>(data.GetMKLDNNDataReorder(&diff_desc));
+  } else if (diff.IsDefaultData()) {
+    auto data_mem_desc = data_mem->get_desc();
+    diff_mem = static_cast<const mkldnn::memory*>(diff.GetMKLDNNDataReorder(&data_mem_desc));
+  }
   auto& bwd = GetBNBackward<DType>(param, ctx, data, *data_mem, diff, *diff_mem, flags);
   auto gradi_mem =
       CreateMKLDNNMem(const_cast<NDArray&>(gradIn), bwd.pd.diff_src_desc(), req[batchnorm::kData]);
@@ -414,7 +423,8 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs& attrs,
       const NDArray* workspace = nullptr;
       workspace                = &inputs[8];
       if (workspace != nullptr) {
-        net_args[MKLDNN_ARG_WORKSPACE] = *(workspace->GetMKLDNNData());
+        net_args[MKLDNN_ARG_WORKSPACE] =
+            *(static_cast<const mkldnn::memory*>(workspace->GetMKLDNNData()));
       }
     }
 
@@ -434,11 +444,13 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs& attrs,
         tmp_var_ptr[i]     = variance;
         moving_var_ptr[i]  = moving_var_ptr[i] * param.momentum + variance * minus_mom;
       }
-      net_args[MKLDNN_ARG_MEAN]     = *(out_mean.GetMKLDNNData());
+      net_args[MKLDNN_ARG_MEAN] = *(static_cast<const mkldnn::memory*>(out_mean.GetMKLDNNData()));
       net_args[MKLDNN_ARG_VARIANCE] = var_mem;
     } else {
-      net_args[MKLDNN_ARG_MEAN]     = *(moving_mean.GetMKLDNNData());
-      net_args[MKLDNN_ARG_VARIANCE] = *(moving_var.GetMKLDNNData());
+      net_args[MKLDNN_ARG_MEAN] =
+          *(static_cast<const mkldnn::memory*>(moving_mean.GetMKLDNNData()));
+      net_args[MKLDNN_ARG_VARIANCE] =
+          *(static_cast<const mkldnn::memory*>(moving_var.GetMKLDNNData()));
     }
     MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
     CommitOutput(gradIn, gradi_mem);
diff --git a/src/operator/nn/mkldnn/mkldnn_concat.cc b/src/operator/nn/mkldnn/mkldnn_concat.cc
index 689888a..4904ad3 100644
--- a/src/operator/nn/mkldnn/mkldnn_concat.cc
+++ b/src/operator/nn/mkldnn/mkldnn_concat.cc
@@ -71,7 +71,7 @@ void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs,
   data_md.reserve(num_in_data);
   data_mem.reserve(num_in_data);
   for (int i = 0; i < num_in_data; i++) {
-    const mkldnn::memory* tmp_mem = in_data[i].GetMKLDNNData();
+    const mkldnn::memory* tmp_mem = static_cast<const mkldnn::memory*>(in_data[i].GetMKLDNNData());
     mkldnn::memory::desc tmp_md   = tmp_mem->get_desc();
     data_md.push_back(tmp_md);
     data_mem.push_back(tmp_mem);
@@ -98,7 +98,7 @@ void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs,
   const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
   const int num_in_data    = param.num_args;
   const int axis           = param.dim;
-  const auto gradz_mem     = inputs[0].GetMKLDNNData();
+  const auto gradz_mem     = static_cast<const mkldnn::memory*>(inputs[0].GetMKLDNNData());
   /* init the offset */
   mkldnn::memory::dims offsets(outputs[0].shape().ndim());
   for (auto& v : offsets) {
@@ -107,7 +107,7 @@ void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs,
 
   for (int i = 0; i < num_in_data; i++) {
     mkldnn::memory::dims diff_src_tz(outputs[i].shape().begin(), outputs[i].shape().end());
-    auto diff_src_md = outputs[i].GetMKLDNNData()->get_desc();
+    auto diff_src_md = static_cast<const mkldnn::memory*>(outputs[i].GetMKLDNNData())->get_desc();
     auto gradi_mem   = CreateMKLDNNMem(outputs[i], diff_src_md, req[i]);
 
     auto from_md = gradz_mem->get_desc().submemory_desc(diff_src_tz, offsets);
diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc
index 966ba21..dcacf24 100644
--- a/src/operator/nn/mkldnn/mkldnn_convolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -116,16 +116,14 @@ std::shared_ptr<mkldnn::convolution_forward::primitive_desc> GetConvFwdImpl(
           // MKL-DNN introduced padded formats since 0.15 which require more memory
           // compared to the actual size of the tensor. Currently, MKL-DNN operators
           // still reuse memory from memory planning, so here we need to select a
-          // suboptimal kernel for computation that has the expected memory size
-          // requirements
+          // suboptimal kernel for computation that has the expected memory size requirements
           auto conv_pd =
               std::make_shared<mkldnn::convolution_forward::primitive_desc>(desc, attr, engine);
           while (conv_pd->dst_desc().get_size() != GetArraySize(output) ||
                  conv_pd->src_desc().get_size() != GetArraySize(data) ||
                  (!param.mkldnn_param.quantized &&
                   conv_pd->weights_desc().get_size() != GetArraySize(weights))) {
-            // next_impl() will visit desc and engine, please make sure they are
-            // still alive here.
+            // next_impl() will visit desc and engine, please make sure they are still alive here.
             CHECK(conv_pd->next_impl()) << "No convolution implementation for this request.";
           }
           return conv_pd;
@@ -488,7 +486,8 @@ void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam& param,
   auto& weight = in_data[conv::kWeight];
   bool no_bias = param.conv_param.no_bias && !param.mkldnn_param.with_bn;
 
-  auto data_mem = data.GetMKLDNNDataReorder(fwd->GetPd().src_desc());
+  auto fwd_src_desc = fwd->GetPd().src_desc();
+  auto data_mem     = static_cast<const mkldnn::memory*>(data.GetMKLDNNDataReorder(&fwd_src_desc));
   const mkldnn::memory* weight_mem;
   if (ctx.is_train) {
     // TODO(zhengda) kvstore doesn't handle MKLDNN correctly. Let's reorder it
@@ -502,25 +501,30 @@ void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam& param,
     // For inference, we want to reorder the weight array so we don't need to
     // reorder data every time.
     if (weight.IsDefaultData()) {
-      // We also need to modify the layout on the original weight array. The
-      // data conversion happens after the weight array is used.
-      weight.MKLDNNDataReorderAsync(fwd->GetPd().weights_desc());
+      // We also need to modify the layout on the original weight array. The data conversion happens
+      // after the weight array is used.
+      auto fwd_weight_desc = fwd->GetPd().weights_desc();
+      weight.MKLDNNDataReorderAsync(&fwd_weight_desc);
       weight_mem = GetWeights(weight, fwd->GetPd().weights_desc(), param.conv_param.num_group);
     } else {
-      weight_mem = weight.GetMKLDNNDataReorder(fwd->GetPd().weights_desc());
+      auto fwd_weight_desc = fwd->GetPd().weights_desc();
+      weight_mem =
+          static_cast<const mkldnn::memory*>(weight.GetMKLDNNDataReorder(&fwd_weight_desc));
     }
   }
   mkldnn_output_t out_mem;
   if (param.mkldnn_param.with_sum) {
     out_mem = mkldnn_output_t(OutDataOp::Noop,
-                              const_cast<mkldnn::memory*>(out_data[conv::kOut].GetMKLDNNData()));
+                              const_cast<mkldnn::memory*>(static_cast<const mkldnn::memory*>(
+                                  out_data[conv::kOut].GetMKLDNNData())));
   } else {
     out_mem = CreateMKLDNNMem(out_data[conv::kOut], fwd->GetPd().dst_desc(), req[conv::kOut]);
   }
 
   mkldnn_args_map_t net_args;
   if (!no_bias) {
-    const mkldnn::memory* bias_mem = in_data[conv::kBias].GetMKLDNNData();
+    const mkldnn::memory* bias_mem =
+        static_cast<const mkldnn::memory*>(in_data[conv::kBias].GetMKLDNNData());
     net_args.insert({MKLDNN_ARG_BIAS, *bias_mem});
   }
 
@@ -612,7 +616,9 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs,
 
   CHECK_NE(req[conv::kWeight], kWriteInplace) << "cannot write weight inplace";
   MKLDNNConvBackward& convBwd = GetConvBwd(full_param, data, weight, bias, out_grad);
-  auto out_grad_mem           = out_grad.GetMKLDNNDataReorder(convBwd.GetDataPd().diff_dst_desc());
+  auto convBwd_data_diff_desc = convBwd.GetDataPd().diff_dst_desc();
+  auto out_grad_mem =
+      static_cast<const mkldnn::memory*>(out_grad.GetMKLDNNDataReorder(&convBwd_data_diff_desc));
   if (req[conv::kData]) {
     auto weight_mem  = GetWeights(weight, convBwd.GetDataPd().weights_desc(), param.num_group);
     auto in_grad_mem = CreateMKLDNNMem(
@@ -624,9 +630,14 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs,
     CommitOutput(in_grad[conv::kData], in_grad_mem);
   }
   if (req[conv::kWeight] || req[conv::kBias]) {
-    if (convBwd.GetDataPd().diff_dst_desc() != convBwd.GetWeightsPd().diff_dst_desc())
-      out_grad_mem = out_grad.GetMKLDNNDataReorder(convBwd.GetWeightsPd().diff_dst_desc());
-    auto data_mem       = data.GetMKLDNNDataReorder(convBwd.GetWeightsPd().src_desc());
+    if (convBwd.GetDataPd().diff_dst_desc() != convBwd.GetWeightsPd().diff_dst_desc()) {
+      auto convBwd_weight_diff_desc = convBwd.GetWeightsPd().diff_dst_desc();
+      out_grad_mem                  = static_cast<const mkldnn::memory*>(
+          out_grad.GetMKLDNNDataReorder(&convBwd_weight_diff_desc));
+    }
+    auto convBwd_weight_src_desc = convBwd.GetWeightsPd().src_desc();
+    auto data_mem =
+        static_cast<const mkldnn::memory*>(data.GetMKLDNNDataReorder(&convBwd_weight_src_desc));
     auto in_grad_weight = CreateMKLDNNWeightGrad(
         in_grad[conv::kWeight], convBwd.GetWeightsPd().diff_weights_desc(), req[conv::kWeight]);
 
diff --git a/src/operator/nn/mkldnn/mkldnn_copy.cc b/src/operator/nn/mkldnn/mkldnn_copy.cc
index 601df3c..4051144 100644
--- a/src/operator/nn/mkldnn/mkldnn_copy.cc
+++ b/src/operator/nn/mkldnn/mkldnn_copy.cc
@@ -38,18 +38,21 @@ void MKLDNNCopy(const nnvm::NodeAttrs& attrs,
   if (req == kNullOp || req == kWriteInplace)
     return;
   TmpMemMgr::Get()->Init(ctx.requested[0]);
-  auto in_mem = in_data.GetMKLDNNData();
+  auto in_mem = static_cast<const mkldnn::memory*>(in_data.GetMKLDNNData());
   if (req == kAddTo) {
     TmpMemMgr::Get()->Init(ctx.requested[0]);
     // We should try and force the input memory has the same format
     // as the input output. If not, we'll have to reorder memory.
-    auto out_mem = out_data.GetMKLDNNData();
-    in_mem       = in_data.GetMKLDNNData(out_mem->get_desc());
-    if (in_mem == nullptr)
-      in_mem = in_data.GetMKLDNNDataReorder(out_mem->get_desc());
+    auto out_mem      = static_cast<const mkldnn::memory*>(out_data.GetMKLDNNData());
+    auto out_mem_desc = out_mem->get_desc();
+    in_mem            = static_cast<const mkldnn::memory*>(in_data.GetMKLDNNData(&out_mem_desc));
+    if (in_mem == nullptr) {
+      auto out_mem_desc = out_mem->get_desc();
+      in_mem = static_cast<const mkldnn::memory*>(in_data.GetMKLDNNDataReorder(&out_mem_desc));
+    }
     MKLDNNSum(*out_mem, *in_mem, *out_mem);
   } else {
-    const_cast<NDArray&>(out_data).CopyFrom(*in_mem);
+    const_cast<NDArray&>(out_data).CopyFrom(in_mem);
   }
   MKLDNNStream::Get()->Submit();
 }
diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution-inl.h b/src/operator/nn/mkldnn/mkldnn_deconvolution-inl.h
index b048c13..f1cd322 100644
--- a/src/operator/nn/mkldnn/mkldnn_deconvolution-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_deconvolution-inl.h
@@ -58,9 +58,8 @@ using deconv_bwd_data_pd_t = mkldnn::deconvolution_backward_data::primitive_desc
 using deconv_bwd_weights_t    = mkldnn::deconvolution_backward_weights;
 using deconv_bwd_weights_pd_t = mkldnn::deconvolution_backward_weights::primitive_desc;
 
-// Swaps the logical order of dimensions that in plain format would correspond
-// to input and output channels (for example: oihw => iohw, iohw => oihw, goihw
-// => giohw).
+// Swaps the logical order of dimensions that in plain format would correspond to input and output
+// channels (for example: oihw => iohw, iohw => oihw, goihw => giohw).
 inline mkldnn::memory::desc IOLogicalSwapDesc(const mkldnn::memory::desc& desc,
                                               const uint32_t num_group) {
   std::vector<int> order(desc.data.ndims);
@@ -74,18 +73,18 @@ inline mkldnn::memory::desc IOLogicalSwapDesc(const mkldnn::memory::desc& desc,
 inline void IOLogicalSwapMKLDNNMem(const NDArray& arr, const uint32_t num_group) {
   mkldnn::memory::desc desc;
   if (arr.IsMKLDNNData()) {
-    desc = arr.GetMKLDNNData()->get_desc();
+    desc = static_cast<const mkldnn::memory*>(arr.GetMKLDNNData())->get_desc();
   } else {
-    // GetMKLDNNData won't take groups into account when creating
-    // mkldnn::memory, we need to use descriptor from GetWeightDesc but with
-    // default format
+    // GetMKLDNNData won't take groups into account when creating mkldnn::memory, we need to use
+    // descriptor from GetWeightDesc but with default format
     const auto& temp = GetWeightDesc(arr, num_group);
     desc             = mkldnn::memory::desc(
         temp.dims(),
         temp.data_type(),
         static_cast<mkldnn::memory::format_tag>(GetDefaultFormat(temp.data.ndims)));
   }
-  const_cast<NDArray&>(arr).UpdateMKLDNNMemDesc(IOLogicalSwapDesc(desc, num_group));
+  auto logical_swap = IOLogicalSwapDesc(desc, num_group);
+  const_cast<NDArray&>(arr).UpdateMKLDNNMemDesc(&logical_swap);
 }
 
 // Version of GetWeightsDesc for deconvolution (with swap)
@@ -152,7 +151,8 @@ MKLDNNDeconvFwd::MKLDNNDeconvFwd(const DeconvolutionParam& param, const Tensors&
 }
 
 inline const mkldnn::memory* MKLDNNDeconvFwd::DataMem(const NDArray& data) const {
-  return data.GetMKLDNNDataReorder(fwd_pd->src_desc());
+  auto fwd_src_desc = fwd_pd->src_desc();
+  return static_cast<const mkldnn::memory*>(data.GetMKLDNNDataReorder(&fwd_src_desc));
 }
 
 inline const mkldnn::memory* MKLDNNDeconvFwd::WeightsMem(const uint32_t num_group,
@@ -161,7 +161,7 @@ inline const mkldnn::memory* MKLDNNDeconvFwd::WeightsMem(const uint32_t num_grou
 }
 
 inline const mkldnn::memory* MKLDNNDeconvFwd::BiasMem(const NDArray& bias) const {
-  return bias.GetMKLDNNData();
+  return static_cast<const mkldnn::memory*>(bias.GetMKLDNNData());
 }
 
 inline mkldnn_output_t MKLDNNDeconvFwd::OutMem(const OpReqType req, const NDArray& out) const {
@@ -210,8 +210,8 @@ class MKLDNNDeconvBwd {
                             const NDArray& weights,
                             const NDArray& weights_grad) const;
 
-  // returns the output gradient memory used to calculate the data (input)
-  // gradient, which might be reused when calculating the gradient of weights
+  // returns the output gradient memory used to calculate the data (input) gradient,
+  // which might be reused when calculating the gradient of weights
   const mkldnn::memory* ScheduleBwdData(const uint32_t num_group,
                                         const OpReqType req,
                                         const ReadTensors& read_tensors,
@@ -279,7 +279,8 @@ inline void MKLDNNDeconvBwd::IOSwapWeightsTensors(const uint32_t num_group,
 }
 
 inline const mkldnn::memory* MKLDNNDeconvBwd::DataMem(const NDArray& data) const {
-  return data.GetMKLDNNDataReorder(bwd_weights_pd->src_desc());
+  auto bwd_weight_src_desc = bwd_weights_pd->src_desc();
+  return static_cast<const mkldnn::memory*>(data.GetMKLDNNDataReorder(&bwd_weight_src_desc));
 }
 
 inline const mkldnn::memory* MKLDNNDeconvBwd::WeightsMem(const uint32_t num_group,
@@ -288,15 +289,18 @@ inline const mkldnn::memory* MKLDNNDeconvBwd::WeightsMem(const uint32_t num_grou
 }
 
 inline const mkldnn::memory* MKLDNNDeconvBwd::OutGradMem(const NDArray& out_grad) const {
-  return out_grad.GetMKLDNNDataReorder(bwd_data_pd->diff_dst_desc());
+  auto bwd_data_diff_desc = bwd_data_pd->diff_dst_desc();
+  return static_cast<const mkldnn::memory*>(out_grad.GetMKLDNNDataReorder(&bwd_data_diff_desc));
 }
 
 inline const mkldnn::memory* MKLDNNDeconvBwd::OutGradMem(
     const NDArray& out_grad,
     const mkldnn::memory* const out_grad_mem) const {
+  auto bwd_weight_diff_desc = bwd_weights_pd->diff_dst_desc();
   return (out_grad_mem && out_grad_mem->get_desc() == bwd_weights_pd->diff_dst_desc())
              ? out_grad_mem
-             : out_grad.GetMKLDNNDataReorder(bwd_weights_pd->diff_dst_desc());
+             : static_cast<const mkldnn::memory*>(
+                   out_grad.GetMKLDNNDataReorder(&bwd_weight_diff_desc));
 }
 
 inline mkldnn_output_t MKLDNNDeconvBwd::DataGradMem(const OpReqType req,
@@ -307,14 +311,15 @@ inline mkldnn_output_t MKLDNNDeconvBwd::DataGradMem(const OpReqType req,
 inline mkldnn_output_t MKLDNNDeconvBwd::WeightsGradMem(const uint32_t num_group,
                                                        const OpReqType req,
                                                        const NDArray& weights_grad) const {
-  // CreateMKLDNNWeightGrad always creates a new tensor as IsDefaultFormat
-  // always fails (because of the logical swap - explained in
-  // MKLDNNDeconvFwd::Execute). We try to reuse weights_grad memory (which, when
-  // not swapped, is always in default format), so here we check if after a
+  // CreateMKLDNNWeightGrad always creates a new tensor as IsDefaultFormat always fails (because
+  // of the logical swap - explained in MKLDNNDeconvFwd::Execute). We try to reuse weights_grad
+  // memory (which, when not swapped, is always in default format), so here we check if after a
   // swap, weights_md will have a default format
   const auto& weights_md = bwd_weights_pd->diff_weights_desc();
   if (req == OpReqType::kWriteTo && IsDefaultFormat(IOLogicalSwapDesc(weights_md, num_group))) {
-    return {OutDataOp::Noop, const_cast<NDArray&>(weights_grad).CreateMKLDNNData(weights_md)};
+    return {OutDataOp::Noop,
+            static_cast<mkldnn::memory*>(
+                const_cast<NDArray&>(weights_grad).CreateMKLDNNData(&weights_md))};
   }
   return CreateMKLDNNWeightGrad(weights_grad, weights_md, req);
 }
@@ -334,13 +339,13 @@ class DeconvDescCreator {
                     const NDArray* const bias,
                     const NDArray& out);
 
-  // Imposes plain formats on memory descriptors with padding (so the next
-  // selected implementation will pass CheckImplSizeReq). After calling this
-  // method, new primitive descriptor (with new operator descriptor) should be
-  // created, which should select an implementation with matching size
-  // requirements. data_size, weights_size, out_size - size requirements of
-  // current implementation Returns whether successfully imposed a plain format
-  // on any of the data, weights, and output memory descriptors.
+  // Imposes plain formats on memory descriptors with padding (so the next selected implementation
+  // will pass CheckImplSizeReq). After calling this method, new primitive descriptor (with new
+  // operator descriptor) should be created, which should select an implementation with matching
+  // size requirements.
+  // data_size, weights_size, out_size - size requirements of current implementation
+  // Returns whether successfully imposed a plain format on any of the data, weights, and output
+  // memory descriptors.
   bool ImposePlainWherePadding(const size_t data_size,
                                const size_t weights_size,
                                const size_t out_size);
diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
index 211ccd6..8428549 100644
--- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
@@ -112,9 +112,10 @@ void MKLDNNDeconvFwd::ControlWeightsFormat(const uint32_t num_group,
     if (weights.IsDefaultData()) {
       // We also need to modify the layout on the original weights array.
       // The data conversion happens after the weights array is used.
-      weights.MKLDNNDataReorderAsync(IOLogicalSwapDesc(fwd_pd->weights_desc(), num_group));
+      auto logical_swap_desc = IOLogicalSwapDesc(fwd_pd->weights_desc(), num_group);
+      weights.MKLDNNDataReorderAsync(&logical_swap_desc);
     } else {
-      CHECK(weights.GetMKLDNNData()->get_desc() ==
+      CHECK(static_cast<const mkldnn::memory*>(weights.GetMKLDNNData())->get_desc() ==
             IOLogicalSwapDesc(fwd_pd->weights_desc(), num_group));
     }
   }
@@ -123,10 +124,9 @@ void MKLDNNDeconvFwd::ControlWeightsFormat(const uint32_t num_group,
 void MKLDNNDeconvFwd::Execute(const uint32_t num_group,
                               const OpReqType req,
                               const Tensors& tensors) const {
-  // MXNet (correctly) assumes that deconvolution is implemented using
-  // convolution primitives. For that, we would pass input tensor in place of
-  // output and output tensor in place of input (for appropriate convolution
-  // primitives: deconvolution forward = convolution backward data,
+  // MXNet (correctly) assumes that deconvolution is implemented using convolution primitives.
+  // For that, we would pass input tensor in place of output and output tensor in place of input
+  // (for appropriate convolution primitives: deconvolution forward = convolution backward data,
   // deconvolution backward data = convolution forward).
   // The convolution primitive expects weights tensor with the shape of
   // (primitive_out_channels, primitive_in_channels, h, w), but with swapped
diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
index ccd1287..a3ed6c8 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -80,8 +80,7 @@ mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(const MKLDNNFCFullPar
       return mkldnn::inner_product_forward::primitive_desc(desc, attr, engine);
     } catch (mkldnn::error& e) {
       if (e.status == mkldnn_unimplemented && full_param.mkldnn_param.quantized) {
-        LOG(ERROR) << "AVX512-BW support or MKLDNN v0.18 is required for INT8 "
-                      "fully_connected.";
+        LOG(ERROR) << "AVX512-BW support or MKLDNN v0.18 is required for INT8 fully_connected.";
       } else {
         LOG(ERROR) << e.message;
       }
@@ -202,7 +201,8 @@ void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam& full_param,
   NDArray weight = in_data[fullc::kWeight];
   NDArray data   = in_data[fullc::kData];
 
-  auto data_mem = data.GetMKLDNNDataReorder(fwd->fwd_pd.src_desc());
+  auto fwd_src_desc = fwd->fwd_pd.src_desc();
+  auto data_mem     = static_cast<const mkldnn::memory*>(data.GetMKLDNNDataReorder(&fwd_src_desc));
   const mkldnn::memory* weight_mem;
   if (ctx.is_train) {
     if (weight.IsMKLDNNData()) {
@@ -210,9 +210,10 @@ void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam& full_param,
     }
     weight_mem = GetWeights(weight, fwd->fwd_pd.weights_desc(), 1);
   } else {
-    weight_mem = weight.GetMKLDNNData();
+    weight_mem = static_cast<const mkldnn::memory*>(weight.GetMKLDNNData());
     if (weight_mem->get_desc() != fwd->fwd_pd.weights_desc()) {
-      weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_desc());
+      auto fwd_weight_desc = fwd->fwd_pd.weights_desc();
+      weight.MKLDNNDataReorderAsync(&fwd_weight_desc);
       weight_mem = GetWeights(weight, fwd->fwd_pd.weights_desc(), 1);
     }
   }
@@ -225,7 +226,9 @@ void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam& full_param,
       {MKLDNN_ARG_DST, *out_mem.second},
   };
   if (!full_param.default_param.no_bias) {
-    auto bias_mem         = in_data[fullc::kBias].GetMKLDNNDataReorder(fwd->fwd_pd.bias_desc());
+    auto fwd_bias_desc = fwd->fwd_pd.bias_desc();
+    auto bias_mem      = static_cast<const mkldnn::memory*>(
+        in_data[fullc::kBias].GetMKLDNNDataReorder(&fwd_bias_desc));
     args[MKLDNN_ARG_BIAS] = *bias_mem;
   }
   MKLDNNStream::Get()->RegisterPrimArgs(fwd->GetFwd(), args);
@@ -299,8 +302,14 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs,
   if (req[fullc::kWeight]) {
     mkldnn::inner_product_backward_weights::primitive_desc ipBwdWeights_pd = GetFCBwdWeights(
         data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias], out_grad, fwd_pd);
-    auto out_grad_mem   = out_grad.GetMKLDNNDataReorder(ipBwdWeights_pd.diff_dst_desc());
-    auto data_mem       = data.GetMKLDNNDataReorder(ipBwdWeights_pd.src_desc());
+
+    auto ipBwdWeights_diff_dst_desc = ipBwdWeights_pd.diff_dst_desc();
+    auto out_grad_mem               = static_cast<const mkldnn::memory*>(
+        out_grad.GetMKLDNNDataReorder(&ipBwdWeights_diff_dst_desc));
+
+    auto ipBwdWeights_src_desc = ipBwdWeights_pd.src_desc();
+    auto data_mem =
+        static_cast<const mkldnn::memory*>(data.GetMKLDNNDataReorder(&ipBwdWeights_src_desc));
     auto in_grad_weight = CreateMKLDNNWeightGrad(
         in_grad[fullc::kWeight], ipBwdWeights_pd.diff_weights_desc(), req[fullc::kWeight]);
     mkldnn_args_map_t args = {
@@ -323,8 +332,13 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs,
   if (req[fullc::kData]) {
     mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd =
         GetFCBwdData(data, weight, out_grad, fwd_pd);
-    auto out_grad_mem = out_grad.GetMKLDNNDataReorder(ipBwdData_pd.diff_dst_desc());
-    auto weight_mem   = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_desc());
+    auto ipBwdData_diff_dst_desc = ipBwdData_pd.diff_dst_desc();
+    auto out_grad_mem =
+        static_cast<const mkldnn::memory*>(out_grad.GetMKLDNNDataReorder(&ipBwdData_diff_dst_desc));
+
+    auto ipBwdData_weight_desc = ipBwdData_pd.weights_desc();
+    auto weight_mem =
+        static_cast<const mkldnn::memory*>(weight.GetMKLDNNDataReorder(&ipBwdData_weight_desc));
     auto in_grad_mem =
         CreateMKLDNNMem(in_grad[fullc::kData], ipBwdData_pd.diff_src_desc(), req[fullc::kData]);
     mkldnn_args_map_t args = {{MKLDNN_ARG_DIFF_DST, *out_grad_mem},
diff --git a/src/operator/nn/mkldnn/mkldnn_log_softmax.cc b/src/operator/nn/mkldnn/mkldnn_log_softmax.cc
index a4cb66a..c3f41eb 100644
--- a/src/operator/nn/mkldnn/mkldnn_log_softmax.cc
+++ b/src/operator/nn/mkldnn/mkldnn_log_softmax.cc
@@ -60,8 +60,7 @@ bool SupportMKLDNNLogSoftmax(const SoftmaxParam& param,
   const int axis      = CheckAxis(param.axis, ndim);
   // MKLDNN does not support temperature argument in their log_softmax function
   // now. Need update this once they start to support it.
-  // Currently, MKLDNN shows bad performance when log_softmax is not performed
-  // on the last dimension
+  // Currently, MKLDNN shows bad performance when log_softmax is not performed on the last dimension
   if (param.temperature.has_value() || in_dtype != mshadow::kFloat32 || in_dtype != out_dtype ||
       axis != (ndim - 1)) {
     return false;
@@ -110,7 +109,8 @@ static MKLDNNLogSoftmaxFwd& GetLogSoftmaxFwd(const SoftmaxParam& param,
 
   auto it = fwds.find(key);
   if (it == fwds.end()) {
-    MKLDNNLogSoftmaxFwd fwd(is_train, real_axis, *(data.GetMKLDNNData()));
+    MKLDNNLogSoftmaxFwd fwd(
+        is_train, real_axis, *(static_cast<const mkldnn::memory*>(data.GetMKLDNNData())));
     it = AddToCache(&fwds, key, fwd);
   }
   return it->second;
@@ -123,16 +123,16 @@ void MKLDNNLogSoftmaxForward(const nnvm::NodeAttrs& attrs,
                              const NDArray& out_data) {
   if (req == kNullOp)
     return;
-  // same as the FCompute path, log_softmax only supports kWriteTo and
-  // kWriteInplace for now.
+  // same as the FCompute path, log_softmax only supports kWriteTo and kWriteInplace for now.
   CHECK_NE(req, kAddTo);
 
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
   int axis                  = CheckAxis(param.axis, in_data.shape().ndim());
   auto fwd                  = GetLogSoftmaxFwd(param, axis, ctx.is_train, in_data, out_data);
 
-  auto in_mem          = in_data.GetMKLDNNData();
-  auto out_mem         = out_data.GetMKLDNNData(fwd.pd.dst_desc());
+  auto in_mem          = static_cast<const mkldnn::memory*>(in_data.GetMKLDNNData());
+  auto fwd_desc        = fwd.pd.dst_desc();
+  auto out_mem         = static_cast<const mkldnn::memory*>(out_data.GetMKLDNNData(&fwd_desc));
   MKLDNNStream* stream = MKLDNNStream::Get();
   stream->RegisterPrimArgs(fwd.GetFwd(), {{MKLDNN_ARG_SRC, *in_mem}, {MKLDNN_ARG_DST, *out_mem}});
   stream->Submit();
@@ -176,8 +176,8 @@ static MKLDNNLogSoftmaxBwd& GetLogSoftmaxBwd(const SoftmaxParam& param,
 
   auto it = bwds.find(key);
   if (it == bwds.end()) {
-    auto diff_mem = data[0].GetMKLDNNData();
-    auto data_mem = data[1].GetMKLDNNData();
+    auto diff_mem = static_cast<const mkldnn::memory*>(data[0].GetMKLDNNData());
+    auto data_mem = static_cast<const mkldnn::memory*>(data[1].GetMKLDNNData());
     auto fwd_pd   = GetLogSoftmaxFwdPd(true, real_axis, *data_mem);
     MKLDNNLogSoftmaxBwd bwd(*diff_mem, *data_mem, real_axis, fwd_pd);
     it = AddToCache(&bwds, key, bwd);
@@ -195,8 +195,8 @@ void MKLDNNLogSoftmaxBackward(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_data.size(), 2U);
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
   int axis                  = CheckAxis(param.axis, in_data[1].shape().ndim());
-  auto diff_mem             = in_data[0].GetMKLDNNData();
-  auto data_mem             = in_data[1].GetMKLDNNData();
+  auto diff_mem             = static_cast<const mkldnn::memory*>(in_data[0].GetMKLDNNData());
+  auto data_mem             = static_cast<const mkldnn::memory*>(in_data[1].GetMKLDNNData());
   auto bwd                  = GetLogSoftmaxBwd(param, axis, in_data, out_data);
 
   auto out_mem           = CreateMKLDNNMem(out_data[0], bwd.pd.diff_src_desc(), req[0]);
diff --git a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
index f47804d..91c6d41 100644
--- a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
@@ -106,8 +106,9 @@ class MKLDNNLRNFwd {
 };  // End of LRN Forword Class
 
 void MKLDNNLRNFwd::_Init(const LRNParam& param, bool is_train, const NDArray& in_data) {
-  mkldnn::memory::desc in_data_md = in_data.GetMKLDNNData()->get_desc();
-  this->fwd_pd                    = GetLRNFwdDesc(param, is_train, in_data_md);
+  mkldnn::memory::desc in_data_md =
+      static_cast<const mkldnn::memory*>(in_data.GetMKLDNNData())->get_desc();
+  this->fwd_pd = GetLRNFwdDesc(param, is_train, in_data_md);
 
   this->fwd = std::shared_ptr<mkldnn::lrn_forward>(new mkldnn::lrn_forward(this->fwd_pd));
 }
@@ -119,7 +120,7 @@ void MKLDNNLRNFwd::Execute(const OpContext& ctx,
   auto output_mem_t = CreateMKLDNNMem(out_data, (this->fwd_pd).dst_desc(), req);
 
   mkldnn_args_map_t args = {
-      {MKLDNN_ARG_SRC, *in_data.GetMKLDNNData()},
+      {MKLDNN_ARG_SRC, *static_cast<const mkldnn::memory*>(in_data.GetMKLDNNData())},
       {MKLDNN_ARG_DST, *output_mem_t.second},
   };
   std::shared_ptr<mkldnn::memory> workspace;
@@ -206,10 +207,11 @@ class MKLDNNLRNBwd {
                const mkldnn_output_t& diff_src_mem) {
     auto engine    = CpuEngine::Get()->get_engine();
     auto workspace = std::make_shared<mkldnn::memory>((this->fwd_pd).workspace_desc(), engine);
-    mkldnn_args_map_t args = {{MKLDNN_ARG_SRC, *in_data.GetMKLDNNData()},
-                              {MKLDNN_ARG_DIFF_DST, *out_grad.GetMKLDNNData()},
-                              {MKLDNN_ARG_WORKSPACE, *workspace},
-                              {MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second}};
+    mkldnn_args_map_t args = {
+        {MKLDNN_ARG_SRC, *static_cast<const mkldnn::memory*>(in_data.GetMKLDNNData())},
+        {MKLDNN_ARG_DIFF_DST, *static_cast<const mkldnn::memory*>(out_grad.GetMKLDNNData())},
+        {MKLDNN_ARG_WORKSPACE, *workspace},
+        {MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second}};
     MKLDNNStream::Get()->RegisterPrimArgs(*(this->bwd), args);
     CommitOutput(in_grad, diff_src_mem);
     MKLDNNStream::Get()->Submit();
@@ -232,8 +234,10 @@ static MKLDNNLRNBwd& GetLRNBwd(const LRNParam& param,
 
   auto it = lrn_bwds.find(key);
   if (it == lrn_bwds.end()) {
-    const mkldnn::memory::desc in_data_md = in_data.GetMKLDNNData()->get_desc();
-    const mkldnn::memory::desc diff_md    = out_grad.GetMKLDNNData()->get_desc();
+    const mkldnn::memory::desc in_data_md =
+        static_cast<const mkldnn::memory*>(in_data.GetMKLDNNData())->get_desc();
+    const mkldnn::memory::desc diff_md =
+        static_cast<const mkldnn::memory*>(out_grad.GetMKLDNNData())->get_desc();
     MKLDNNLRNBwd bwd(param, in_data_md, diff_md);
     it = AddToCache(&lrn_bwds, key, bwd);
   }
@@ -252,8 +256,7 @@ void MKLDNNLRNBackward(const nnvm::NodeAttrs& attrs,
   const NDArray& out_grad = inputs[0];
   const NDArray& in_data  = inputs[1];
   const NDArray& in_grad  = outputs[0];
-  // TODO(alex): (MXNET-846) figure out why in_grad output incorrect when
-  // in_data is nchw8c
+  // TODO(alex): (MXNET-846) figure out why in_grad output incorrect when in_data is nchw8c
   const auto in_buffer         = in_data.Reorder2Default();
   MKLDNNLRNBwd& bwd            = GetLRNBwd(param, in_buffer, in_grad, out_grad);
   mkldnn_output_t diff_src_mem = CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_desc(), req[0]);
diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc
index 3b3be5c..9a88b86 100644
--- a/src/operator/nn/mkldnn/mkldnn_pooling.cc
+++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc
@@ -42,8 +42,8 @@ void MKLDNNPoolingFwd::Init(const mxnet::NDArray& input,
                             const mkldnn::memory::dims& pad_r,
                             const bool is_train,
                             const mkldnn::algorithm alg_kind) {
-  const auto src_md           = input.GetMKLDNNData()->get_desc();
-  const auto dst_md           = GetMemDesc(output);
+  const auto src_md = static_cast<const mkldnn::memory*>(input.GetMKLDNNData())->get_desc();
+  const auto dst_md = GetMemDesc(output);
   const mkldnn::engine engine = CpuEngine::Get()->get_engine();
   if (alg_kind != mkldnn::algorithm::pooling_max && alg_kind != mkldnn::algorithm::pooling_avg &&
       alg_kind != mkldnn::algorithm::pooling_avg_include_padding &&
@@ -75,7 +75,7 @@ void MKLDNNPoolingFwd::Execute(const NDArray& in_data,
   if (in_data.IsView() && in_data.IsMKLDNNData())
     in_buffer = in_data.Reorder2Default();
 
-  auto input_mem     = in_buffer.GetMKLDNNData();
+  auto input_mem     = static_cast<const mkldnn::memory*>(in_buffer.GetMKLDNNData());
   auto output_mem_t_ = CreateMKLDNNMem(out_data, this->fwd_pd_->dst_desc(), req);
 
   mkldnn_args_map_t args = {
@@ -91,7 +91,11 @@ void MKLDNNPoolingFwd::Execute(const NDArray& in_data,
     }
 
     auto ws = std::make_shared<mkldnn::memory>(
-        (*(this->fwd_pd_)).workspace_desc(), engine, workspace->GetMKLDNNData()->get_data_handle());
+        (*(this->fwd_pd_)).workspace_desc(),
+        engine,
+        static_cast<const mkldnn::memory*>(
+            static_cast<const mkldnn::memory*>(workspace->GetMKLDNNData()))
+            ->get_data_handle());
     args[MKLDNN_ARG_WORKSPACE] = *ws;
   }
   if (this->fwd_) {
@@ -231,7 +235,7 @@ MKLDNNPoolingFwd& GetPoolingFwd(const PoolingParam& param,
   if (it == pooling_fwds.end()) {
     CHECK(param.kernel.ndim() == 1 || param.kernel.ndim() == 2 || param.kernel.ndim() == 3)
         << "Not Implemented";
-    auto data_md = data.GetMKLDNNData()->get_desc();
+    auto data_md = static_cast<const mkldnn::memory*>(data.GetMKLDNNData())->get_desc();
 
     const auto kernel_ndims = param.kernel.ndim();
     mkldnn::memory::dims kernel(kernel_ndims);
@@ -288,7 +292,7 @@ MKLDNNPoolingBwd& GetPoolingBwd(const PoolingParam& param,
 
   auto it = pooling_bwds.find(key);
   if (it == pooling_bwds.end()) {
-    auto input_mem                     = in_data.GetMKLDNNData();
+    auto input_mem = static_cast<const mkldnn::memory*>(in_data.GetMKLDNNData());
     const mkldnn::memory::desc data_md = input_mem->get_desc();
 
     auto dst_dims = mkldnn::memory::dims(out_grad.shape().begin(), out_grad.shape().end());
@@ -337,14 +341,16 @@ void MKLDNNPoolingGradCompute(const OpContext& ctx,
   TmpMemMgr::Get()->Init(ctx.requested[0]);
 
   auto& bwd              = GetPoolingBwd(param, in_data, in_grad, out_grad);
-  auto diff_dst_mem      = out_grad.GetMKLDNNDataReorder(bwd.pd.diff_dst_desc());
+  auto bwd_diff_dst_desc = bwd.pd.diff_dst_desc();
+  auto diff_dst_mem =
+      static_cast<const mkldnn::memory*>(out_grad.GetMKLDNNDataReorder(&bwd_diff_dst_desc));
   auto diff_src_mem      = CreateMKLDNNMem(in_grad, bwd.pd.diff_src_desc(), req);
   mkldnn_args_map_t args = {
       {MKLDNN_ARG_DIFF_DST, *diff_dst_mem},
       {MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second},
   };
   if (MKLDNNRequireWorkspace(param) && workspace != nullptr) {
-    args[MKLDNN_ARG_WORKSPACE] = *(workspace->GetMKLDNNData());
+    args[MKLDNN_ARG_WORKSPACE] = *(static_cast<const mkldnn::memory*>(workspace->GetMKLDNNData()));
   }
 
   MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), args);
diff --git a/src/operator/nn/mkldnn/mkldnn_reshape.cc b/src/operator/nn/mkldnn/mkldnn_reshape.cc
index ef23a74..b9d26a0 100644
--- a/src/operator/nn/mkldnn/mkldnn_reshape.cc
+++ b/src/operator/nn/mkldnn/mkldnn_reshape.cc
@@ -36,7 +36,7 @@ MKLDNNReshapeFwd::MKLDNNReshapeFwd(const OpReqType& req,
                                    const NDArray& input,
                                    const NDArray& output) {
   const auto engine = CpuEngine::Get()->get_engine();
-  auto in_mem       = input.GetMKLDNNData();
+  auto in_mem       = static_cast<const mkldnn::memory*>(input.GetMKLDNNData());
 
   // Create temp memory
   auto temp_dims = mkldnn::memory::dims(input.shape().begin(), input.shape().end());
@@ -71,18 +71,22 @@ void MKLDNNReshapeFwd::Execute(const NDArray& input,
                                const OpReqType& req,
                                void* workspace) {
   auto stream = MKLDNNStream::Get();
-  auto in_mem = input.GetMKLDNNData();
+  auto in_mem = static_cast<const mkldnn::memory*>(input.GetMKLDNNData());
   // register primitives and arguments
   std::vector<mkldnn_args_map_t> args_map;
   size_t prims_size = prims_.size();
   if (prims_size == 1) {
-    args_map.push_back({{MKLDNN_ARG_FROM, *in_mem}, {MKLDNN_ARG_TO, *output.GetMKLDNNData()}});
+    args_map.push_back(
+        {{MKLDNN_ARG_FROM, *in_mem},
+         {MKLDNN_ARG_TO, *static_cast<const mkldnn::memory*>(output.GetMKLDNNData())}});
   } else if (prims_size == 2) {
     if (workspace) {
       temp_->set_data_handle(workspace);
     }
     args_map.push_back({{MKLDNN_ARG_FROM, *in_mem}, {MKLDNN_ARG_TO, *temp_}});
-    args_map.push_back({{MKLDNN_ARG_FROM, *temp_}, {MKLDNN_ARG_TO, *output.GetMKLDNNData()}});
+    args_map.push_back(
+        {{MKLDNN_ARG_FROM, *temp_},
+         {MKLDNN_ARG_TO, *static_cast<const mkldnn::memory*>(output.GetMKLDNNData())}});
   } else {
     CHECK(prims_size == 0 && req != kWriteTo) << "kWriteTo should never reach here.";
   }
@@ -120,8 +124,8 @@ void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
                           const NDArray& input,
                           const OpReqType& req,
                           const NDArray& output) {
-  // For mkldnn non-supported input, it shouldn't hold mkldnn memory, so let's
-  // simply fallback to naive implement.
+  // For mkldnn non-supported input, it shouldn't hold mkldnn memory, so let's simply fallback to
+  // naive implement.
   const int input_ndims = input.shape().ndim();
   if ((input_ndims < 1 || input_ndims > 4) || !SupportMKLDNNQuantize(input.dtype())) {
     if (req != kWriteInplace) {
diff --git a/src/operator/nn/mkldnn/mkldnn_slice.cc b/src/operator/nn/mkldnn/mkldnn_slice.cc
index 345560d..8a85e60 100644
--- a/src/operator/nn/mkldnn/mkldnn_slice.cc
+++ b/src/operator/nn/mkldnn/mkldnn_slice.cc
@@ -49,8 +49,8 @@ MKLDNNSliceFwd::MKLDNNSliceFwd(const SliceParam& param, const NDArray& in, const
     offsets[i] = s;
   }
 
-  auto in_md  = in.GetMKLDNNData()->get_desc();
-  auto out_md = out.GetMKLDNNData()->get_desc();
+  auto in_md  = static_cast<const mkldnn::memory*>(in.GetMKLDNNData())->get_desc();
+  auto out_md = static_cast<const mkldnn::memory*>(out.GetMKLDNNData())->get_desc();
   auto sub_md = in_md.submemory_desc(dims, offsets);
 
   auto engine = CpuEngine::Get()->get_engine();
@@ -98,8 +98,8 @@ void MKLDNNSlice(const nnvm::NodeAttrs& attrs,
                  const NDArray& out) {
   const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
   MKLDNNSliceFwd& fwd     = GetSliceForward(param, ctx.is_train, in, out);
-  auto in_mem             = in.GetMKLDNNData();
-  auto out_md             = out.GetMKLDNNData()->get_desc();
+  auto in_mem             = static_cast<const mkldnn::memory*>(in.GetMKLDNNData());
+  auto out_md             = static_cast<const mkldnn::memory*>(out.GetMKLDNNData())->get_desc();
   auto out_mem            = CreateMKLDNNMem(out, out_md, req);
   fwd.SetNewMem(*in_mem, *out_mem.second);
   fwd.Register();
diff --git a/src/operator/nn/mkldnn/mkldnn_softmax.cc b/src/operator/nn/mkldnn/mkldnn_softmax.cc
index 7e948b3..67aae1f 100644
--- a/src/operator/nn/mkldnn/mkldnn_softmax.cc
+++ b/src/operator/nn/mkldnn/mkldnn_softmax.cc
@@ -62,8 +62,7 @@ bool SupportMKLDNNSoftmax(const SoftmaxParam& param, const NDArray& data, const
   const int axis      = CheckAxis(param.axis, ndim);
   // MKLDNN does not support temperature argument in their softmax function
   // now. Need update this once they start to support it.
-  // Currently, MKLDNN shows bad performance when softmax is not performed on
-  // the last dimension
+  // Currently, MKLDNN shows bad performance when softmax is not performed on the last dimension
   if (param.temperature.has_value() || in_dtype != mshadow::kFloat32 || in_dtype != out_dtype ||
       axis != (ndim - 1)) {
     return false;
@@ -111,7 +110,8 @@ static MKLDNNSoftmaxFwd& GetSoftmaxFwd(const SoftmaxParam& param,
 
   auto it = fwds.find(key);
   if (it == fwds.end()) {
-    MKLDNNSoftmaxFwd fwd(is_train, real_axis, *(data.GetMKLDNNData()));
+    MKLDNNSoftmaxFwd fwd(
+        is_train, real_axis, *(static_cast<const mkldnn::memory*>(data.GetMKLDNNData())));
     it = AddToCache(&fwds, key, fwd);
   }
   return it->second;
@@ -124,16 +124,16 @@ void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs,
                           const NDArray& out_data) {
   if (req == kNullOp)
     return;
-  // same as the FCompute path, softmax only supports kWriteTo and kWriteInplace
-  // for now.
+  // same as the FCompute path, softmax only supports kWriteTo and kWriteInplace for now.
   CHECK_NE(req, kAddTo);
 
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
   int axis                  = CheckAxis(param.axis, in_data.shape().ndim());
   auto fwd                  = GetSoftmaxFwd(param, axis, ctx.is_train, in_data, out_data);
 
-  auto in_mem          = in_data.GetMKLDNNData();
-  auto out_mem         = out_data.GetMKLDNNData(fwd.pd.dst_desc());
+  auto in_mem          = static_cast<const mkldnn::memory*>(in_data.GetMKLDNNData());
+  auto fwd_desc        = fwd.pd.dst_desc();
+  auto out_mem         = static_cast<const mkldnn::memory*>(out_data.GetMKLDNNData(&fwd_desc));
   MKLDNNStream* stream = MKLDNNStream::Get();
   stream->RegisterPrimArgs(fwd.GetFwd(), {{MKLDNN_ARG_SRC, *in_mem}, {MKLDNN_ARG_DST, *out_mem}});
   stream->Submit();
@@ -176,8 +176,8 @@ static MKLDNNSoftmaxBwd& GetSoftmaxBwd(const SoftmaxParam& param,
 
   auto it = bwds.find(key);
   if (it == bwds.end()) {
-    auto diff_mem = data[0].GetMKLDNNData();
-    auto data_mem = data[1].GetMKLDNNData();
+    auto diff_mem = static_cast<const mkldnn::memory*>(data[0].GetMKLDNNData());
+    auto data_mem = static_cast<const mkldnn::memory*>(data[1].GetMKLDNNData());
     auto fwd_pd   = GetSoftmaxFwdPd(true, real_axis, *data_mem);
     MKLDNNSoftmaxBwd bwd(*diff_mem, *data_mem, real_axis, fwd_pd);
     it = AddToCache(&bwds, key, bwd);
@@ -195,8 +195,8 @@ void MKLDNNSoftmaxBackward(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_data.size(), 2U);
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
   int axis                  = CheckAxis(param.axis, in_data[1].shape().ndim());
-  auto diff_mem             = in_data[0].GetMKLDNNData();
-  auto data_mem             = in_data[1].GetMKLDNNData();
+  auto diff_mem             = static_cast<const mkldnn::memory*>(in_data[0].GetMKLDNNData());
+  auto data_mem             = static_cast<const mkldnn::memory*>(in_data[1].GetMKLDNNData());
   auto bwd                  = GetSoftmaxBwd(param, axis, in_data, out_data);
 
   auto out_mem           = CreateMKLDNNMem(out_data[0], bwd.pd.diff_src_desc(), req[0]);
diff --git a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc
index e632245..93ea12c 100644
--- a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc
+++ b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc
@@ -84,7 +84,7 @@ static MKLDNNSoftmaxOutputFwd& GetSoftmaxOutputForward(const SoftmaxOutputParam&
 
   auto it = fwds.find(key);
   if (it == fwds.end()) {
-    auto in_mem = *(in_data.GetMKLDNNData());
+    auto in_mem = *(static_cast<const mkldnn::memory*>(in_data.GetMKLDNNData()));
     MKLDNNSoftmaxOutputFwd fwd(param, ctx.is_train, axis, in_mem);
     it = AddToCache(&fwds, key, fwd);
   }
@@ -109,7 +109,7 @@ void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs,
     idata = in_data[softmaxout_enum::kData].Reorder2Default();
   }
 
-  auto input_mem = idata.GetMKLDNNData();
+  auto input_mem = static_cast<const mkldnn::memory*>(idata.GetMKLDNNData());
   auto out_mem   = CreateMKLDNNMem(
       out_data[softmaxout_enum::kOut], input_mem->get_desc(), req[softmaxout_enum::kOut]);
 
diff --git a/src/operator/nn/mkldnn/mkldnn_sum.cc b/src/operator/nn/mkldnn/mkldnn_sum.cc
index fc49b1b..c771efa 100644
--- a/src/operator/nn/mkldnn/mkldnn_sum.cc
+++ b/src/operator/nn/mkldnn/mkldnn_sum.cc
@@ -112,7 +112,7 @@ void MKLDNNSumForward(const nnvm::NodeAttrs& attrs,
   data_mem.reserve(num_inputs);
 
   for (int i = 0; i < num_inputs; ++i) {
-    const mkldnn::memory* in_mem = inputs[i].GetMKLDNNData();
+    const mkldnn::memory* in_mem = static_cast<const mkldnn::memory*>(inputs[i].GetMKLDNNData());
     mkldnn::memory::desc tmp_md  = in_mem->get_desc();
     data_md.push_back(tmp_md);
     data_mem.push_back(in_mem);
diff --git a/src/operator/nn/mkldnn/mkldnn_transpose.cc b/src/operator/nn/mkldnn/mkldnn_transpose.cc
index 2f97302..40906b7 100644
--- a/src/operator/nn/mkldnn/mkldnn_transpose.cc
+++ b/src/operator/nn/mkldnn/mkldnn_transpose.cc
@@ -66,7 +66,7 @@ class MKLDNNTransposeForward {
     }
 
     auto engine = CpuEngine::Get()->get_engine();
-    auto in_mem = data.GetMKLDNNData();
+    auto in_mem = static_cast<const mkldnn::memory*>(data.GetMKLDNNData());
     auto src_md = in_mem->get_desc();
     data_       = std::make_shared<mkldnn::memory>(src_md, engine, nullptr);
 
@@ -90,7 +90,8 @@ class MKLDNNTransposeForward {
 
   void SetNewMem(const NDArray& data, const NDArray& output) {
     if (data.IsMKLDNNData()) {
-      this->data_->set_data_handle(data.GetMKLDNNData()->get_data_handle());
+      this->data_->set_data_handle(
+          static_cast<const mkldnn::memory*>(data.GetMKLDNNData())->get_data_handle());
     } else {
       MSHADOW_TYPE_SWITCH(
           data.dtype(), DTYPE, { this->data_->set_data_handle(data.data().dptr<DTYPE>()); });
diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index 808e46c..b93e3bc 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -425,8 +425,7 @@ inline std::vector<nnvm::NodeEntry> MakeGradNode(
   return CreateNodeEntries(p);
 }
 
-// quick helper to make gradient nodes that simply pass back zero. could be used
-// in output ops.
+// quick helper to make gradient nodes that simply pass back zero. could be used in output ops.
 inline std::vector<nnvm::NodeEntry> MakeZeroGradNodes(const nnvm::ObjectPtr& n,
                                                       const std::vector<nnvm::NodeEntry>& ograds) {
   std::vector<nnvm::NodeEntry> ret;
@@ -614,7 +613,8 @@ class OpSignature {
   void AddSign(const NDArray& arr) {
 #if MXNET_USE_MKLDNN == 1
     if (arr.IsMKLDNNData()) {
-      AddSign(*(arr.GetMKLDNNData()));
+      auto arr_data = static_cast<const mkldnn::memory*>(arr.GetMKLDNNData());
+      AddSign(*(arr_data));
     } else {
 #endif
       hash = hash * 2 + arr.dtype();
diff --git a/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h b/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h
index 8fbe393..d9300f7 100644
--- a/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h
+++ b/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h
@@ -62,7 +62,7 @@ void SgMKLDNNDequantizeOperator::Forward(const OpContext& ctx,
   NDArray in_buffer = inputs[0];
   if (inputs[0].IsView() && inputs[0].IsMKLDNNData())
     in_buffer = inputs[0].Reorder2Default();
-  auto i_mem     = in_buffer.GetMKLDNNData();
+  auto i_mem     = static_cast<const mkldnn::memory*>(in_buffer.GetMKLDNNData());
   float data_min = *inputs[1].data().dptr<float>();
   float data_max = *inputs[2].data().dptr<float>();
 
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize-inl.h
index 5443a46..0bad466 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantize-inl.h
+++ b/src/operator/quantization/mkldnn/mkldnn_quantize-inl.h
@@ -70,7 +70,7 @@ static void MKLDNNQuantizeComputeKer(const std::vector<NDArray>& inputs,
   if (inputs[0].IsView() && inputs[0].IsMKLDNNData())
     in_buffer = inputs[0].Reorder2Default();
 
-  auto i_mem    = in_buffer.GetMKLDNNData();
+  auto i_mem    = static_cast<const mkldnn::memory*>(in_buffer.GetMKLDNNData());
   auto i_desc   = i_mem->get_desc();
   size_t i_ndim = in_buffer.shape().ndim();
   mkldnn::memory::desc o_desc;
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
index 61e19bd..9bdc072 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
+++ b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
@@ -81,13 +81,14 @@ void SgMKLDNNQuantizeOperator::Forward(const OpContext& ctx,
       }
     }
     if (req[0] != kWriteInplace) {
-      const_cast<NDArray&>(outputs[0]).CopyFrom(*inputs[0].GetMKLDNNData());
+      const_cast<NDArray&>(outputs[0])
+          .CopyFrom(static_cast<const mkldnn::memory*>(inputs[0].GetMKLDNNData()));
       MKLDNNStream::Get()->Submit();
     }
   } else {
     if (in_buffer.IsView() && in_buffer.IsMKLDNNData())
       in_buffer = inputs[0].Reorder2Default();
-    auto i_mem = in_buffer.GetMKLDNNData();
+    auto i_mem = static_cast<const mkldnn::memory*>(in_buffer.GetMKLDNNData());
 
     if (param_.min_calib_range.has_value() && param_.max_calib_range.has_value()) {
       data_min = param_.min_calib_range.value();
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_batch_norm.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_batch_norm.cc
index 2e75dbc..9037fc7 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_batch_norm.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_batch_norm.cc
@@ -41,7 +41,7 @@ static void MKLDNNQuantizedBatchNormForward(const nnvm::NodeAttrs& attrs,
   TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
   const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
   const NDArray& data         = in_data[quantized_batchnorm::kData];
-  auto data_mem               = data.GetMKLDNNData();
+  auto data_mem          = static_cast<const mkldnn::memory*>(data.GetMKLDNNData());
 
   // reorder if data type = uint8
   if (in_data[quantized_batchnorm::kData].dtype() == mshadow::kUint8) {
@@ -98,8 +98,10 @@ static void MKLDNNQuantizedBatchNormForward(const nnvm::NodeAttrs& attrs,
   float* moving_var_ptr      = moving_var.data().dptr<float>();
 
   // rescale gamma and beta, to make mean=0 and var=1
-  auto rescaled_mean_mem   = TmpMemMgr::Get()->Alloc(moving_mean.GetMKLDNNData()->get_desc());
-  auto rescaled_var_mem    = TmpMemMgr::Get()->Alloc(moving_var.GetMKLDNNData()->get_desc());
+  auto rescaled_mean_mem = TmpMemMgr::Get()->Alloc(
+      static_cast<const mkldnn::memory*>(moving_mean.GetMKLDNNData())->get_desc());
+  auto rescaled_var_mem = TmpMemMgr::Get()->Alloc(
+      static_cast<const mkldnn::memory*>(moving_var.GetMKLDNNData())->get_desc());
   float* rescaled_mean_ptr = reinterpret_cast<float*>(rescaled_mean_mem->get_data_handle());
   float* rescaled_var_ptr  = reinterpret_cast<float*>(rescaled_var_mem->get_data_handle());
 
@@ -115,7 +117,9 @@ static void MKLDNNQuantizedBatchNormForward(const nnvm::NodeAttrs& attrs,
   }
 
   const NDArray& out = outputs[batchnorm::kOut];
-  auto out_mem       = const_cast<NDArray&>(out).CreateMKLDNNData(fwd.GetPd().dst_desc());
+  auto fwd_dst_desc = fwd.GetPd().dst_desc();
+  auto out_mem =
+      static_cast<mkldnn::memory*>(const_cast<NDArray&>(out).CreateMKLDNNData(&fwd_dst_desc));
   mkldnn_args_map_t net_args;
   net_args[MKLDNN_ARG_SRC]         = *data_mem;
   net_args[MKLDNN_ARG_SCALE_SHIFT] = weight_mem;
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_concat.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_concat.cc
index 57149bb..9d62a6b 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_concat.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_concat.cc
@@ -72,11 +72,11 @@ static void MKLDNNQuantizedConcatForward(const nnvm::NodeAttrs& attrs,
     auto i_scale = GetScale(in_data[i], data_min[i], data_max[i]);
     if (i_scale == out_scale) {
       CHECK(in_data[i].dtype() == out_dtype);
-      auto mem = in_data[i].GetMKLDNNData();
+      auto mem = static_cast<const mkldnn::memory*>(in_data[i].GetMKLDNNData());
       data_mem.push_back(mem);
       data_md.push_back(mem->get_desc());
     } else {
-      auto mem      = in_data[i].GetMKLDNNData();
+      auto mem      = static_cast<const mkldnn::memory*>(in_data[i].GetMKLDNNData());
       auto mem_desc = mem->get_desc();
       if (in_data[i].dtype() != out_dtype) {
         mem_desc.data.data_type = static_cast<mkldnn_data_type_t>(get_mkldnn_type(out_dtype));
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc
index 5c3e8da..3c44123 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc
@@ -46,13 +46,15 @@ static void MKLDNNQuantizedConvForward(const nnvm::NodeAttrs& attrs,
   MKLDNNConvFullParam full_param;
   full_param.conv_param = param;
   full_param.mkldnn_param.Init(std::unordered_map<std::string, std::string>());
-  auto& fwd     = GetConvFwd(full_param,
+  auto& fwd         = GetConvFwd(full_param,
                          ctx.is_train,
                          in_data[conv::kData],
                          in_data[conv::kWeight],
                          param.no_bias ? nullptr : &in_data[conv::kBias],
                          out_data[conv::kOut]);
-  auto data_mem = in_data[conv::kData].GetMKLDNNDataReorder(fwd.GetPd().src_desc());
+  auto fwd_src_desc = fwd.GetPd().src_desc();
+  auto data_mem =
+      static_cast<const mkldnn::memory*>(in_data[conv::kData].GetMKLDNNDataReorder(&fwd_src_desc));
   const mkldnn::memory* weight_mem;
   // For inference, we want to reorder the weight array so we don't need to
   // reorder data every time.
@@ -60,16 +62,18 @@ static void MKLDNNQuantizedConvForward(const nnvm::NodeAttrs& attrs,
     // We also need to modify the layout on the original weight array.
     // Don't switch below sequence because naive engine will executes
     // pushAsync synchronously.
-    weight.MKLDNNDataReorderAsync(fwd.GetPd().weights_desc());
+    auto fwd_weight_desc = fwd.GetPd().weights_desc();
+    weight.MKLDNNDataReorderAsync(&fwd_weight_desc);
     weight_mem = GetWeights(weight, fwd.GetPd().weights_desc(), param.num_group);
   } else {
-    weight_mem = weight.GetMKLDNNData();
+    weight_mem = static_cast<const mkldnn::memory*>(weight.GetMKLDNNData());
   }
   auto out_mem = CreateMKLDNNMem(out_data[conv::kOut], fwd.GetPd().dst_desc(), req[conv::kOut]);
   mkldnn_args_map_t net_args;
   if (!param.no_bias) {
-    const mkldnn::memory* bias_mem =
-        in_data[conv::kBias].GetMKLDNNDataReorder(fwd.GetPd().bias_desc());
+    auto fwd_bias_desc             = fwd.GetPd().bias_desc();
+    const mkldnn::memory* bias_mem = static_cast<const mkldnn::memory*>(
+        in_data[conv::kBias].GetMKLDNNDataReorder(&fwd_bias_desc));
     net_args.insert({MKLDNN_ARG_BIAS, *bias_mem});
   }
   net_args.insert({MKLDNN_ARG_SRC, *data_mem});
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_elemwise_add.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_elemwise_add.cc
index adee6f9..10f21bc 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_elemwise_add.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_elemwise_add.cc
@@ -109,8 +109,10 @@ static void MKLDNNQuantizedElemwiseAddForward(const nnvm::NodeAttrs& attrs,
   const float dataA_absmax = MaxAbs(dataA_min, dataA_max);
   const float dataB_absmax = MaxAbs(dataB_min, dataB_max);
 
-  auto dataA_mem = in_data[quantized_elemwise_add_enum::kDataA].GetMKLDNNData();
-  auto dataB_mem = in_data[quantized_elemwise_add_enum::kDataB].GetMKLDNNData();
+  auto dataA_mem = static_cast<const mkldnn::memory*>(
+      in_data[quantized_elemwise_add_enum::kDataA].GetMKLDNNData());
+  auto dataB_mem = static_cast<const mkldnn::memory*>(
+      in_data[quantized_elemwise_add_enum::kDataB].GetMKLDNNData());
   const bool is_dataA_int8 =
       (in_data[quantized_elemwise_add_enum::kDataA].dtype() == mshadow::kInt8);
   const float dataA_range = is_dataA_int8 ? kInt8Range : kUint8Range;
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
index f994e7f..4e703ad 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
@@ -91,17 +91,20 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
   auto& fwd =
       GetFCFwd(param, is_train, data, weight, param.no_bias ? nullptr : &quantized_bias, out_md);
 
-  auto data_mem = in_data[fullc::kData].GetMKLDNNDataReorder(fwd.fwd_pd.src_desc());
+  auto fwd_src_desc = fwd.fwd_pd.src_desc();
+  auto data_mem =
+      static_cast<const mkldnn::memory*>(in_data[fullc::kData].GetMKLDNNDataReorder(&fwd_src_desc));
   const mkldnn::memory* weight_mem = nullptr;
 
   if (weight.IsDefaultData()) {
     // We also need to modify the layout on the original weight array.
     // Don't switch below sequence because naive engine will executes
     // pushAsync synchronously.
-    weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_desc());
+    auto fwd_weight_desc = fwd.fwd_pd.weights_desc();
+    weight.MKLDNNDataReorderAsync(&fwd_weight_desc);
     weight_mem = GetWeights(weight, fwd.fwd_pd.weights_desc(), 1);
   } else {
-    weight_mem = weight.GetMKLDNNData();
+    weight_mem = static_cast<const mkldnn::memory*>(weight.GetMKLDNNData());
     CHECK(weight_mem->get_desc() == fwd.fwd_pd.weights_desc());
   }
   auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut], fwd.fwd_pd.dst_desc(), req[fullc::kOut]);
@@ -114,7 +117,9 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
 
   const mkldnn::memory* bias_mem = nullptr;
   if (!param.no_bias) {
-    bias_mem              = quantized_bias.GetMKLDNNDataReorder(fwd.fwd_pd.bias_desc());
+    auto fwd_bias_desc = fwd.fwd_pd.bias_desc();
+    bias_mem =
+        static_cast<const mkldnn::memory*>(quantized_bias.GetMKLDNNDataReorder(&fwd_bias_desc));
     args[MKLDNN_ARG_BIAS] = *bias_mem;
   }
 
diff --git a/src/operator/quantization/mkldnn/mkldnn_requantize-inl.h b/src/operator/quantization/mkldnn/mkldnn_requantize-inl.h
index f93d87b..9292615 100644
--- a/src/operator/quantization/mkldnn/mkldnn_requantize-inl.h
+++ b/src/operator/quantization/mkldnn/mkldnn_requantize-inl.h
@@ -80,7 +80,7 @@ static void MKLDNNRequantizeForwardKer(const nnvm::NodeAttrs& attrs,
   if (inputs[0].IsView() && inputs[0].IsMKLDNNData())
     in_buffer = inputs[0].Reorder2Default();
 
-  auto i_mem            = in_buffer.GetMKLDNNData();
+  auto i_mem            = static_cast<const mkldnn::memory*>(in_buffer.GetMKLDNNData());
   auto i_desc           = i_mem->get_desc();
   auto o_desc           = i_desc;
   o_desc.data.data_type = get_mkldnn_type_t<DstType>();
diff --git a/src/operator/subgraph/mkldnn/mkldnn_common.h b/src/operator/subgraph/mkldnn/mkldnn_common.h
index c4f2857..e4fe4e0 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_common.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_common.h
@@ -100,8 +100,8 @@ static void ConvertWeightBias2MKLDNN(NDArray* weight,
                                      const std::vector<float>& weight_scales,
                                      const bool submit = true) {
   MKLDNNStream* stream           = MKLDNNStream::Get();
-  const auto new_weight          = NDArray(weight_md);
-  const auto conv_weights_memory = new_weight.GetMKLDNNData();
+  const auto new_weight          = NDArray(&weight_md);
+  const auto conv_weights_memory = static_cast<const mkldnn::memory*>(new_weight.GetMKLDNNData());
   mkldnn::primitive_attr weight_attr;
   if (weight_scales.size()) {
     const int weight_mask = (weight_scales.size()) == 1 ? 0 : 1;
@@ -109,7 +109,7 @@ static void ConvertWeightBias2MKLDNN(NDArray* weight,
   }
   auto default_weights_memory = GetWeights(*weight, num_group);
   if (default_weights_memory == nullptr)
-    default_weights_memory = weight->GetMKLDNNData();
+    default_weights_memory = static_cast<const mkldnn::memory*>(weight->GetMKLDNNData());
   const auto weight_reorder_pd =
       mkldnn::reorder::primitive_desc(*default_weights_memory, *conv_weights_memory, weight_attr);
   MKLDNNStream::Get()->RegisterPrimArgs(
@@ -121,12 +121,12 @@ static void ConvertWeightBias2MKLDNN(NDArray* weight,
     for (size_t c = 0; c < weight_scales.size(); ++c) {
       bias_scales[c] = weight_scales[c] * data_scale;
     }
-    new_bias                    = NDArray(*bias_md);
-    const auto conv_bias_memory = new_bias.GetMKLDNNData();
+    new_bias                    = NDArray(bias_md);
+    const auto conv_bias_memory = static_cast<const mkldnn::memory*>(new_bias.GetMKLDNNData());
     const int bias_mask         = (bias_scales.size()) == 1 ? 0 : 1;
     mkldnn::primitive_attr bias_attr;
     bias_attr.set_output_scales(bias_mask, bias_scales);
-    auto bias_weights_memory = bias->GetMKLDNNData();
+    auto bias_weights_memory = static_cast<const mkldnn::memory*>(bias->GetMKLDNNData());
     const auto bias_reorder_pd =
         mkldnn::reorder::primitive_desc(*bias_weights_memory, *conv_bias_memory, bias_attr);
     MKLDNNStream::Get()->RegisterPrimArgs(
diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc
index 95a997f..1e492b4 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc
@@ -159,17 +159,17 @@ void SgMKLDNNConvOperator::Forward(const OpContext& ctx,
   // Copy inputs[in_sum] into outputs[kOut] in case inplace optimization failed.
   if (mkldnn_param.with_sum) {
     if (!initialized_) {
-      // TODO(zhennan): Currently, mkldnn fallback mechanism will break inplace
-      // option, which make check (req[kOut] == kWriteInplace) useless.
-      auto in_mkl_mem  = inputs[in_sum].GetMKLDNNData();
-      auto out_mkl_mem = outputs[kOut].GetMKLDNNData();
+      // TODO(zhennan): Currently, mkldnn fallback mechanism will break inplace option,
+      // which makes the check (req[kOut] == kWriteInplace) useless.
+      auto in_mkl_mem  = static_cast<const mkldnn::memory*>(inputs[in_sum].GetMKLDNNData());
+      auto out_mkl_mem = static_cast<const mkldnn::memory*>(outputs[kOut].GetMKLDNNData());
       if (in_mkl_mem->get_data_handle() == out_mkl_mem->get_data_handle()) {
         inplace_ = true;
       }
     }
     if (!inplace_) {
-      auto in_mkl_mem  = inputs[in_sum].GetMKLDNNData();
-      auto out_mkl_mem = outputs[kOut].GetMKLDNNData();
+      auto in_mkl_mem  = static_cast<const mkldnn::memory*>(inputs[in_sum].GetMKLDNNData());
+      auto out_mkl_mem = static_cast<const mkldnn::memory*>(outputs[kOut].GetMKLDNNData());
       if (outputs[kOut].dtype() == mshadow::kInt32) {
         const auto& mem_desc  = in_mkl_mem->get_desc();
         const auto this_dtype = get_mkldnn_type(mshadow::kInt32);
@@ -337,20 +337,22 @@ void SgMKLDNNConvOperator::Forward(const OpContext& ctx,
                              full_conv_param.conv_param.num_group,
                              data_scale_,
                              weight_scales_);
-    args_[MKLDNN_ARG_SRC]     = *data.GetMKLDNNData();
-    args_[MKLDNN_ARG_WEIGHTS] = *cached_weight_.GetMKLDNNData();
+    args_[MKLDNN_ARG_SRC]     = *static_cast<const mkldnn::memory*>(data.GetMKLDNNData());
+    args_[MKLDNN_ARG_WEIGHTS] = *static_cast<const mkldnn::memory*>(cached_weight_.GetMKLDNNData());
     if (has_bias)
-      args_[MKLDNN_ARG_BIAS] = *cached_bias_.GetMKLDNNData();
-    args_[MKLDNN_ARG_DST] = *output.GetMKLDNNData();
+      args_[MKLDNN_ARG_BIAS] = *static_cast<const mkldnn::memory*>(cached_bias_.GetMKLDNNData());
+    args_[MKLDNN_ARG_DST] = *static_cast<const mkldnn::memory*>(output.GetMKLDNNData());
     initialized_          = true;
   }
 
   if (mkldnn_param.with_sum) {
-    const auto& output_mem   = output.GetMKLDNNData();
+    const auto& output_mem   = static_cast<const mkldnn::memory*>(output.GetMKLDNNData());
     const auto& out_mem_desc = output_mem->get_desc();
     const auto& dst_mem_desc = fwd_->GetPd().dst_desc();
     if (out_mem_desc != dst_mem_desc) {
-      auto tmp_out_mem       = output.GetMKLDNNDataReorder(fwd_->GetPd().dst_desc());
+      auto fwd_dst_desc = fwd_->GetPd().dst_desc();
+      auto tmp_out_mem =
+          static_cast<const mkldnn::memory*>(output.GetMKLDNNDataReorder(&fwd_dst_desc));
       auto data_md           = dst_mem_desc;
       data_md.data.data_type = static_cast<mkldnn_data_type_t>(out_mem_desc.data.data_type);
       mkldnn_mem_ptr new_out_mem(new mkldnn::memory(
@@ -362,8 +364,10 @@ void SgMKLDNNConvOperator::Forward(const OpContext& ctx,
   }
 
   if (mkldnn_param.quantized) {
-    auto data_mem         = data.GetMKLDNNDataReorder(fwd_->GetPd().src_desc());
-    mkldnn::memory* mem   = output.CreateMKLDNNData(fwd_->GetPd().dst_desc());
+    auto fwd_src_desc = fwd_->GetPd().src_desc();
+    auto data_mem = static_cast<const mkldnn::memory*>(data.GetMKLDNNDataReorder(&fwd_src_desc));
+    auto fwd_pd_dst_desc  = fwd_->GetPd().dst_desc();
+    mkldnn::memory* mem   = static_cast<mkldnn::memory*>(output.CreateMKLDNNData(&fwd_pd_dst_desc));
     args_[MKLDNN_ARG_SRC] = *data_mem;
     args_[MKLDNN_ARG_DST] = *mem;
     MKLDNNStream::Get()->RegisterPrimArgs(fwd_->GetFwd(), args_);
@@ -384,8 +388,9 @@ void SgMKLDNNConvOperator::Forward(const OpContext& ctx,
     *outputs[kMax].data().dptr<float>() = cached_output_max_;
   }
   if (mkldnn_param.with_sum) {
-    auto out = const_cast<NDArray&>(outputs[kOut]);
-    out.UpdateMKLDNNMemDesc(fwd_->GetPd().dst_desc());
+    auto out          = const_cast<NDArray&>(outputs[kOut]);
+    auto fwd_dst_desc = fwd_->GetPd().dst_desc();
+    out.UpdateMKLDNNMemDesc(&fwd_dst_desc);
   }
 }
 
diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc.cc b/src/operator/subgraph/mkldnn/mkldnn_fc.cc
index 5578106..123d491 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_fc.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc.cc
@@ -113,8 +113,7 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
   const int out_index     = index++;
   const int out_min_index = out_quantized ? index++ : 0;
   const int out_max_index = out_quantized ? index++ : 0;
-  CHECK_EQ(out_data.size(),
-           index);  // index is equal to total number of outpits
+  CHECK_EQ(out_data.size(), index);  // index is equal to total number of outputs
 
   float min_data   = 0.0f;
   float max_data   = 0.0f;
@@ -132,10 +131,10 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
 
   if (mkldnn_param.with_sum) {
     if (!initialized_) {
-      // TODO(zhennan): Currently, mkldnn fallback mechanism will break inplace
-      // option, which make check (req[out_index] == kWriteInplace) useless.
-      auto in_mkl_mem  = in_data[idx.sum].GetMKLDNNData();
-      auto out_mkl_mem = out_data[out_index].GetMKLDNNData();
+      // TODO(zhennan): Currently, mkldnn fallback mechanism will break inplace option,
+      // which makes the check (req[out_index] == kWriteInplace) useless.
+      auto in_mkl_mem  = static_cast<const mkldnn::memory*>(in_data[idx.sum].GetMKLDNNData());
+      auto out_mkl_mem = static_cast<const mkldnn::memory*>(out_data[out_index].GetMKLDNNData());
       if (in_mkl_mem->get_data_handle() == out_mkl_mem->get_data_handle()) {
         inplace_ = true;
       }
@@ -144,8 +143,8 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
       output = in_data[idx.sum];
     } else {
       // Not in place: copy in_data[idx.sum] into outputs[out_index].
-      auto in_mkl_mem  = in_data[idx.sum].GetMKLDNNData();
-      auto out_mkl_mem = out_data[out_index].GetMKLDNNData();
+      auto in_mkl_mem  = static_cast<const mkldnn::memory*>(in_data[idx.sum].GetMKLDNNData());
+      auto out_mkl_mem = static_cast<const mkldnn::memory*>(out_data[out_index].GetMKLDNNData());
       if (out_data[out_index].dtype() == mshadow::kInt32) {
         auto mem_desc           = in_mkl_mem->get_desc();
         auto this_dtype         = get_mkldnn_type(mshadow::kInt32);
@@ -263,8 +262,7 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
       data_scale_ = GetQuantizeScale(data.dtype(), cached_min_data_, cached_max_data_);
 
       bool fuse_requantize = false;
-      // Channelwise scaling is only supported when fusion is enabled
-      // (requantize or dequantize).
+      // Channelwise scaling is only supported when fusion is enabled (requantize or dequantize).
       if (mkldnn_param.min_calib_range.has_value() && mkldnn_param.max_calib_range.has_value()) {
         cached_min_output_        = mkldnn_param.min_calib_range.value();
         cached_max_output_        = mkldnn_param.max_calib_range.value();
@@ -279,16 +277,13 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
       // True          False                      Error
       // False         True/False                 False
       if (channel_wise && !support_channelwise_scale) {
-        LOG(FATAL) << "Currently, channel-wise quantization requires fuse requantize "
-                      "or dequantize."
-                   << " Please make sure the `min_calib_range` and `max_calib_range` "
-                      "are set when only"
-                   << " fuse requantize (outputs of FullyConnected are collected "
-                      "during calibration "
-                      "phase),"
-                   << " or the env var of `MXNET_DISABLE_MKLDNN_QFC_FLOAT_OUTPUT` and "
-                   << " `MXNET_DISABLE_MKLDNN_QFC_FUSE_ALL` are not set to true "
-                      "(default is false)";
+        LOG(FATAL)
+            << "Currently, channel-wise quantization requires fuse requantize or dequantize."
+            << " Please make sure the `min_calib_range` and `max_calib_range` are set when only"
+            << " fuse requantize (outputs of FullyConnected are collected during calibration "
+               "phase),"
+            << " or the env var of `MXNET_DISABLE_MKLDNN_QFC_FLOAT_OUTPUT` and "
+            << " `MXNET_DISABLE_MKLDNN_QFC_FUSE_ALL` are not set to true (default is false)";
       }
       support_channelwise_scale = support_channelwise_scale && channel_wise;
 
@@ -423,10 +418,11 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
                                weight_scales_,
                                false);
     } else {
-      const auto def_weight_mem = weight.GetMKLDNNData();
+      const auto def_weight_mem = static_cast<const mkldnn::memory*>(weight.GetMKLDNNData());
       if (def_weight_mem->get_desc() != fwd_->fwd_pd.weights_desc()) {
-        cached_weight_         = NDArray(fwd_->fwd_pd.weights_desc());
-        auto cached_weight_mem = cached_weight_.GetMKLDNNData();
+        auto weight_desc       = fwd_->fwd_pd.weights_desc();
+        cached_weight_         = NDArray(&weight_desc);
+        auto cached_weight_mem = static_cast<const mkldnn::memory*>(cached_weight_.GetMKLDNNData());
         std::unordered_map<int, mkldnn::memory> args(
             {{MKLDNN_ARG_FROM, *def_weight_mem}, {MKLDNN_ARG_TO, *cached_weight_mem}});
         MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::reorder(*def_weight_mem, *cached_weight_mem),
@@ -434,23 +430,24 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
       }
     }
 
-    const auto data_mem = data.GetMKLDNNData();
+    const auto data_mem = static_cast<const mkldnn::memory*>(data.GetMKLDNNData());
     cached_data_mem_    = std::make_shared<mkldnn::memory>(data_mem->get_desc(), engine);
 
     args_[MKLDNN_ARG_SRC]     = *cached_data_mem_;
-    args_[MKLDNN_ARG_WEIGHTS] = *cached_weight_.GetMKLDNNData();
+    args_[MKLDNN_ARG_WEIGHTS] = *static_cast<const mkldnn::memory*>(cached_weight_.GetMKLDNNData());
     if (has_bias)
-      args_[MKLDNN_ARG_BIAS] = *cached_bias_.GetMKLDNNData();
+      args_[MKLDNN_ARG_BIAS] = *static_cast<const mkldnn::memory*>(cached_bias_.GetMKLDNNData());
     args_[MKLDNN_ARG_DST] = *cached_out_mem_;
     initialized_          = true;
   }
 
   if (mkldnn_param.with_sum) {
-    const auto& output_mem   = output.GetMKLDNNData();
+    const auto& output_mem   = static_cast<const mkldnn::memory*>(output.GetMKLDNNData());
     const auto& out_mem_desc = output_mem->get_desc();
     auto dst_mem_desc        = fwd_->fwd_pd.dst_desc();
     if (out_mem_desc != dst_mem_desc) {
-      auto tmp_out_mem            = output.GetMKLDNNDataReorder(dst_mem_desc);
+      auto tmp_out_mem =
+          static_cast<const mkldnn::memory*>(output.GetMKLDNNDataReorder(&dst_mem_desc));
       dst_mem_desc.data.data_type = out_mem_desc.data.data_type;
       mkldnn_mem_ptr new_out_mem(new mkldnn::memory(
           dst_mem_desc, CpuEngine::Get()->get_engine(), output_mem->get_data_handle()));
diff --git a/src/operator/tensor/amp_cast.cc b/src/operator/tensor/amp_cast.cc
index d1e5758..a1a3d9f 100644
--- a/src/operator/tensor/amp_cast.cc
+++ b/src/operator/tensor/amp_cast.cc
@@ -46,7 +46,7 @@ static void AMPCastExCPU(const nnvm::NodeAttrs& attrs,
     mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine();
     if (data.IsView() && data.IsMKLDNNData())
       data = data.Reorder2Default();
-    const auto i_mem            = data.GetMKLDNNData();
+    const auto i_mem            = static_cast<const mkldnn::memory*>(data.GetMKLDNNData());
     const size_t i_ndim         = data.shape().ndim();
     mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim);
     for (size_t i = 0; i < i_ndim; i++) {
@@ -94,7 +94,7 @@ static void AMPMultiCastExCPU(const nnvm::NodeAttrs& attrs,
     auto data = inputs[i];
     if (data.IsView() && data.IsMKLDNNData())
       data = data.Reorder2Default();
-    const auto i_mem            = data.GetMKLDNNData();
+    const auto i_mem            = static_cast<const mkldnn::memory*>(data.GetMKLDNNData());
     const size_t i_ndim         = data.shape().ndim();
     mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim);
     for (size_t j = 0; j < i_ndim; j++) {
@@ -170,8 +170,7 @@ NNVM_REGISTER_OP(_backward_amp_cast)
     .set_attr<FCompute>("FCompute<cpu>", AMPCastCompute<cpu>);
 
 NNVM_REGISTER_OP(amp_multicast)
-    .describe(
-        R"code(Cast function used by AMP, that casts its inputs to the common widest type.
+    .describe(R"code(Cast function used by AMP, that casts its inputs to the common widest type.
 
 It casts only between low precision float/FP32 and does not do anything for other types.
 
diff --git a/src/operator/tensor/cast_storage-inl.h b/src/operator/tensor/cast_storage-inl.h
index e24f203..b4ee576 100644
--- a/src/operator/tensor/cast_storage-inl.h
+++ b/src/operator/tensor/cast_storage-inl.h
@@ -410,8 +410,8 @@ void CastStorageComputeImpl(const OpContext& ctx, const NDArray& input, const ND
       // data first.
       if (input.IsMKLDNNData() && input.IsView())
         tmp_input = input.Reorder2Default();
-      const mkldnn::memory* in_mem = tmp_input.GetMKLDNNData();
-      const_cast<NDArray&>(output).CopyFrom(*in_mem);
+      const mkldnn::memory* in_mem = static_cast<const mkldnn::memory*>(tmp_input.GetMKLDNNData());
+      const_cast<NDArray&>(output).CopyFrom(in_mem);
       MKLDNNStream::Get()->Submit();
     } else {
       mxnet_op::copy(ctx.get_stream<xpu>(), output.data(), input.data());
diff --git a/tests/cpp/include/test_mkldnn.h b/tests/cpp/include/test_mkldnn.h
index 9c32f15..0cda299 100644
--- a/tests/cpp/include/test_mkldnn.h
+++ b/tests/cpp/include/test_mkldnn.h
@@ -86,7 +86,7 @@ inline static void InitMKLDNNArray(NDArray* arr,
                                    bool is_rand = false,
                                    int max      = 50) {
   InitDefaultArray(arr, is_rand, max);
-  arr->MKLDNNDataReorderAsync(desc);
+  arr->MKLDNNDataReorderAsync(&desc);
   arr->WaitToRead();
 }
 
@@ -352,8 +352,8 @@ inline void PrintVerifyMsg(const NDArrayAttrs& arr1, const NDArrayAttrs& arr2) {
  * think we should pass them to all operators. In the inference mode, the MKLDNN
  * memory in the weight array will be reordered to 5 dimensions.
  *
- *  num_inputs / dim arguments used to scale shape (used for concat backwards to
- * enlarge input shapes)
+ *  num_inputs / dim arguments used to scale shape (used for concat backwards to enlarge input
+ * shapes)
  */
 inline std::vector<NDArrayAttrs> GetTestInputArrays(int types                = ArrayTypes::All,
                                                     bool rand                = false,
diff --git a/tests/cpp/operator/mkldnn_operator_test.cc b/tests/cpp/operator/mkldnn_operator_test.cc
index 41a2ffe..43e010a 100644
--- a/tests/cpp/operator/mkldnn_operator_test.cc
+++ b/tests/cpp/operator/mkldnn_operator_test.cc
@@ -1209,7 +1209,8 @@ void TestPoolingOp(const OpAttrs& forward_attrs, const OpAttrs& backwards_attrs)
       continue;
     // cannot pool if ndarray and mkldnn memory have different ndim
     if (in_arr.arr.IsView() ||
-        in_arr.arr.GetMKLDNNData()->get_desc().data.ndims != in_arr.arr.shape().ndim())
+        static_cast<const mkldnn::memory*>(in_arr.arr.GetMKLDNNData())->get_desc().data.ndims !=
+            in_arr.arr.shape().ndim())
       continue;
     std::vector<float> scale_vector(in_arr.arr.shape().ndim());
     for (int i = 0; i < in_arr.arr.shape().ndim(); i++) {
diff --git a/tests/cpp/operator/mkldnn_test.cc b/tests/cpp/operator/mkldnn_test.cc
index 1f8fe5f..ef516b7 100644
--- a/tests/cpp/operator/mkldnn_test.cc
+++ b/tests/cpp/operator/mkldnn_test.cc
@@ -26,25 +26,27 @@
 #if MXNET_USE_MKLDNN == 1
 
 #include <mkldnn_types.h>
-#include <cmath>
+
 #include <climits>
+#include <cmath>
 #include <set>
-#include "gtest/gtest.h"
-#include "mxnet/imperative.h"
-#include "../../src/operator/nn/mkldnn/mkldnn_ops-inl.h"
+
 #include "../../src/operator/nn/mkldnn/mkldnn_base-inl.h"
+#include "../../src/operator/nn/mkldnn/mkldnn_ops-inl.h"
 #include "../include/test_mkldnn.h"
+#include "gtest/gtest.h"
+#include "mxnet/imperative.h"
 
 using namespace mxnet;
 
 #if __GNUC__ >= 5
-bool test_mem_align(void *mem, size_t size, size_t alignment, size_t space) {
+bool test_mem_align(void* mem, size_t size, size_t alignment, size_t space) {
   void *ret1, *ret2;
   size_t space1, space2;
   space1 = space;
   space2 = space;
-  ret1 = mxnet::AlignMem(mem, size, alignment, &space1);
-  ret2 = std::align(alignment, size, mem, space2);
+  ret1   = mxnet::AlignMem(mem, size, alignment, &space1);
+  ret2   = std::align(alignment, size, mem, space2);
   EXPECT_EQ(ret1, ret2);
   EXPECT_EQ(space1, space2);
   return ret1 == ret2;
@@ -54,29 +56,29 @@ bool test_mem_align(void *mem, size_t size, size_t alignment, size_t space) {
 TEST(MKLDNN_UTIL_FUNC, AlignMem) {
 #if __GNUC__ >= 5
   size_t alignment = 4096;
-  void *mem;
+  void* mem;
   size_t size, space;
   // When mem has been aligned.
-  mem = reinterpret_cast<void *>(0x10000);
-  size = 1000;
+  mem   = reinterpret_cast<void*>(0x10000);
+  size  = 1000;
   space = 10000;
   test_mem_align(mem, size, alignment, space);
 
   // When mem isn't aligned and we have enough space for alignment.
-  mem = reinterpret_cast<void *>(0x10010);
-  size = 1000;
+  mem   = reinterpret_cast<void*>(0x10010);
+  size  = 1000;
   space = 10000;
   test_mem_align(mem, size, alignment, space);
 
   // When mem isn't aligned and we don't have enough memory for alignment
-  mem = reinterpret_cast<void *>(0x10010);
-  size = 1000;
+  mem   = reinterpret_cast<void*>(0x10010);
+  size  = 1000;
   space = 1001;
   test_mem_align(mem, size, alignment, space);
 
   for (size_t i = 0; i < 10000; i++) {
-    mem = reinterpret_cast<void *>(random());
-    size = random() % 2000;
+    mem   = reinterpret_cast<void*>(random());
+    size  = random() % 2000;
     space = random() % 2000;
     test_mem_align(mem, size, alignment, space);
   }
@@ -87,12 +89,11 @@ TEST(MKLDNN_UTIL_FUNC, AlignMem) {
 #endif
 }
 
-static void VerifyDefMem(const mkldnn::memory &mem) {
-  mkldnn::memory::desc desc = mem.get_desc();
-  mshadow::default_real_t *data
-      = static_cast<mshadow::default_real_t *>(mem.get_data_handle());
-  size_t size = desc.get_size() / sizeof(mshadow::default_real_t);
-  size_t num_same = 0;
+static void VerifyDefMem(const mkldnn::memory& mem) {
+  mkldnn::memory::desc desc     = mem.get_desc();
+  mshadow::default_real_t* data = static_cast<mshadow::default_real_t*>(mem.get_data_handle());
+  size_t size                   = desc.get_size() / sizeof(mshadow::default_real_t);
+  size_t num_same               = 0;
   for (int i = 0; i < size; i++)
     num_same += data[i] == static_cast<mshadow::default_real_t>(i % 100 - 50);
   EXPECT_EQ(num_same, size);
@@ -105,14 +106,14 @@ TEST(MKLDNN_UTIL_FUNC, MemFormat) {
   CHECK_EQ(mkldnn_oihw, 5);
 }
 
-static void VerifyMem(const mkldnn::memory &mem) {
+static void VerifyMem(const mkldnn::memory& mem) {
   mkldnn::memory::desc desc = mem.get_desc();
   mkldnn::memory::dims dims(desc.data.ndims);
   for (size_t i = 0; i < dims.size(); i++)
     dims[i] = desc.data.dims[i];
   mkldnn::memory::desc new_desc{dims,
-      static_cast<mkldnn::memory::data_type>(desc.data.data_type),
-      static_cast<mkldnn::memory::format_tag>(GetDefaultFormat(desc))};
+                                static_cast<mkldnn::memory::data_type>(desc.data.data_type),
+                                static_cast<mkldnn::memory::format_tag>(GetDefaultFormat(desc))};
 
   if (desc == new_desc) {
     VerifyDefMem(mem);
@@ -121,26 +122,25 @@ static void VerifyMem(const mkldnn::memory &mem) {
     mkldnn::memory new_mem(new_desc, CpuEngine::Get()->get_engine());
 
     mkldnn::stream s(CpuEngine::Get()->get_engine());
-    mkldnn::reorder(*src_mem, new_mem)
-        .execute(s, *src_mem, new_mem);
+    mkldnn::reorder(*src_mem, new_mem).execute(s, *src_mem, new_mem);
 
     VerifyDefMem(new_mem);
   }
 }
 
 TEST(MKLDNN_NDArray, GetDataReorder) {
-  TestArrayShapes tas = GetTestArrayShapes();
-  mxnet::ShapeVector shapes = tas.shapes;
+  TestArrayShapes tas                   = GetTestArrayShapes();
+  mxnet::ShapeVector shapes             = tas.shapes;
   std::vector<mkldnn::memory::desc> mds = tas.mds;
 
-
   // Reorder from the default to any other layout.
   for (auto s : shapes) {
     NDArray arr(s, Context());
     InitDefaultArray(&arr);
     for (auto md : mds) {
       if (s.Size() == md.get_size() / sizeof(mshadow::default_real_t)) {
-        const mkldnn::memory *mem = arr.GetMKLDNNDataReorder(md);
+        const mkldnn::memory* mem =
+            static_cast<const mkldnn::memory*>(arr.GetMKLDNNDataReorder(&md));
         printf("reorder from (");
         for (size_t i = 0; i < s.ndim(); i++)
           printf("%ld, ", s[i]);
@@ -172,7 +172,8 @@ TEST(MKLDNN_NDArray, GetDataReorder) {
         InitMKLDNNArray(&arr, md);
         for (auto to_md : mds) {
           if (to_md.get_size() / sizeof(mshadow::default_real_t) == s.Size()) {
-            const mkldnn::memory *mem = arr.GetMKLDNNDataReorder(to_md);
+            const mkldnn::memory* mem =
+                static_cast<const mkldnn::memory*>(arr.GetMKLDNNDataReorder(&to_md));
             printf("reorder from (");
             for (size_t i = 0; i < s.ndim(); i++)
               printf("%ld, ", s[i]);
@@ -191,13 +192,13 @@ TEST(MKLDNN_NDArray, GetDataReorder) {
 }
 
 TEST(MKLDNN_BASE, MKLDNNSum) {
-  std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays();
-  std::vector<NDArrayAttrs> in_arrs2 = GetTestInputArrays(ArrayTypes::All, true);
-  TestArrayShapes tas = GetTestArrayShapes();
+  std::vector<NDArrayAttrs> in_arrs     = GetTestInputArrays();
+  std::vector<NDArrayAttrs> in_arrs2    = GetTestInputArrays(ArrayTypes::All, true);
+  TestArrayShapes tas                   = GetTestArrayShapes();
   std::vector<mkldnn::memory::desc> mds = tas.mds;
 
   for (int i = 0; i < in_arrs.size(); i++) {
-    auto in_arr = in_arrs[i];
+    auto in_arr  = in_arrs[i];
     auto in_arr2 = in_arrs2[i];
     if (!SupportMKLDNN(in_arr.arr))
       continue;
@@ -205,12 +206,12 @@ TEST(MKLDNN_BASE, MKLDNNSum) {
       continue;
     }
     std::vector<NDArrayAttrs> out_arrs = GetTestOutputArrays(in_arr.arr.shape(), mds);
-    for (auto &out_arr : out_arrs) {
-      auto in_mem1 = in_arr.arr.GetMKLDNNData();
-      auto in_mem2 = in_arr2.arr.GetMKLDNNData();
+    for (auto& out_arr : out_arrs) {
+      auto in_mem1 = static_cast<const mkldnn::memory*>(in_arr.arr.GetMKLDNNData());
+      auto in_mem2 = static_cast<const mkldnn::memory*>(in_arr2.arr.GetMKLDNNData());
       if (out_arr.arr.IsView())
         continue;
-      auto out_mem = out_arr.arr.GetMKLDNNData();
+      auto out_mem = static_cast<const mkldnn::memory*>(out_arr.arr.GetMKLDNNData());
       PrintVerifyMsg(in_arr, in_arr);
       op::MKLDNNSum(*in_mem1, *in_mem2, *out_mem);
       MKLDNNStream::Get()->Submit();
@@ -220,20 +221,20 @@ TEST(MKLDNN_BASE, MKLDNNSum) {
 
   // in place
   for (int i = 0; i < in_arrs.size(); i++) {
-    auto in_arr = in_arrs[i];
+    auto in_arr  = in_arrs[i];
     auto in_arr2 = in_arrs2[i];
     if (!SupportMKLDNN(in_arr.arr))
       continue;
     if (in_arr.arr.IsMKLDNNData() && in_arr.arr.IsView()) {
       continue;
     }
-    auto input_mem = in_arr.arr.GetMKLDNNData();
-    auto input_mem2 = in_arr2.arr.GetMKLDNNData();
+    auto input_mem  = static_cast<const mkldnn::memory*>(in_arr.arr.GetMKLDNNData());
+    auto input_mem2 = static_cast<const mkldnn::memory*>(in_arr2.arr.GetMKLDNNData());
     NDArrayAttrs orig_arr(in_arr.arr.Copy(in_arr.arr.ctx()), "In Place Copy");
     orig_arr.arr.WaitToRead();
     PrintVerifyMsg(orig_arr, in_arr);
     InitMKLDNNArray(&orig_arr.arr, input_mem->get_desc());
-    orig_arr.arr.CopyFrom(*input_mem);
+    orig_arr.arr.CopyFrom(input_mem);
     op::MKLDNNSum(*input_mem, *input_mem2, *input_mem);
     MKLDNNStream::Get()->Submit();
     VerifySumResult({&orig_arr.arr, &in_arr2.arr}, {&in_arr.arr});
@@ -241,15 +242,15 @@ TEST(MKLDNN_BASE, MKLDNNSum) {
 }
 
 TEST(MKLDNN_BASE, CreateMKLDNNMem) {
-  std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays();
-  std::vector<NDArrayAttrs> in_arrs2 = GetTestInputArrays(ArrayTypes::All, true);
-  TestArrayShapes tas = GetTestArrayShapes();
+  std::vector<NDArrayAttrs> in_arrs     = GetTestInputArrays();
+  std::vector<NDArrayAttrs> in_arrs2    = GetTestInputArrays(ArrayTypes::All, true);
+  TestArrayShapes tas                   = GetTestArrayShapes();
   std::vector<mkldnn::memory::desc> mds = tas.mds;
-  MKLDNNStream *stream = MKLDNNStream::Get();
+  MKLDNNStream* stream                  = MKLDNNStream::Get();
 
   // kWriteTo
   for (int i = 0; i < in_arrs.size(); i++) {
-    auto in_arr = in_arrs[i];
+    auto in_arr  = in_arrs[i];
     auto in_arr2 = in_arrs2[i];
     if (!SupportMKLDNN(in_arr.arr))
       continue;
@@ -257,13 +258,13 @@ TEST(MKLDNN_BASE, CreateMKLDNNMem) {
       continue;
     }
     std::vector<NDArrayAttrs> out_arrs = GetTestOutputArrays(in_arr.arr.shape(), mds);
-    for (auto &out_arr : out_arrs) {
-      auto in_mem = in_arr.arr.GetMKLDNNData();
-      auto in_mem2 = in_arr2.arr.GetMKLDNNData();
+    for (auto& out_arr : out_arrs) {
+      auto in_mem         = static_cast<const mkldnn::memory*>(in_arr.arr.GetMKLDNNData());
+      auto in_mem2        = static_cast<const mkldnn::memory*>(in_arr2.arr.GetMKLDNNData());
       NDArray orig_output = out_arr.arr.Copy(out_arr.arr.ctx());
       orig_output.WaitToRead();
       PrintVerifyMsg(in_arr, out_arr);
-      auto out_mem = out_arr.arr.GetMKLDNNData();
+      auto out_mem      = static_cast<const mkldnn::memory*>(out_arr.arr.GetMKLDNNData());
       auto output_mem_t = CreateMKLDNNMem(out_arr.arr, out_mem->get_desc(), kWriteTo);
       op::MKLDNNSum(*in_mem, *in_mem2, *output_mem_t.second);
       CommitOutput(out_arr.arr, output_mem_t);
@@ -274,22 +275,22 @@ TEST(MKLDNN_BASE, CreateMKLDNNMem) {
 
   // kWriteInPlace
   for (int i = 0; i < in_arrs.size(); i++) {
-    auto in_arr = in_arrs[i];
+    auto in_arr  = in_arrs[i];
     auto in_arr2 = in_arrs2[i];
     if (!SupportMKLDNN(in_arr.arr))
       continue;
     if (in_arr.arr.IsMKLDNNData() && in_arr.arr.IsView()) {
       continue;
     }
-    auto input_mem = in_arr.arr.GetMKLDNNData();
-    auto input_mem2 = in_arr2.arr.GetMKLDNNData();
+    auto input_mem  = static_cast<const mkldnn::memory*>(in_arr.arr.GetMKLDNNData());
+    auto input_mem2 = static_cast<const mkldnn::memory*>(in_arr2.arr.GetMKLDNNData());
     NDArrayAttrs orig_arr(in_arr.arr.Copy(in_arr.arr.ctx()), "In Place Copy");
     orig_arr.arr.WaitToRead();
     PrintVerifyMsg(orig_arr, in_arr);
     InitMKLDNNArray(&orig_arr.arr, input_mem->get_desc());
-    orig_arr.arr.CopyFrom(*input_mem);
-    auto output_mem_t = CreateMKLDNNMem(in_arr.arr,
-        input_mem->get_desc(), kWriteInplace, &in_arr.arr);
+    orig_arr.arr.CopyFrom(input_mem);
+    auto output_mem_t =
+        CreateMKLDNNMem(in_arr.arr, input_mem->get_desc(), kWriteInplace, &in_arr.arr);
     op::MKLDNNSum(*input_mem, *input_mem2, *output_mem_t.second);
     CommitOutput(in_arr.arr, output_mem_t);
     stream->Submit();
@@ -298,7 +299,7 @@ TEST(MKLDNN_BASE, CreateMKLDNNMem) {
 
   // kAddTo
   for (int i = 0; i < in_arrs.size(); i++) {
-    auto in_arr = in_arrs[i];
+    auto in_arr  = in_arrs[i];
     auto in_arr2 = in_arrs2[i];
     if (!SupportMKLDNN(in_arr.arr))
       continue;
@@ -306,13 +307,13 @@ TEST(MKLDNN_BASE, CreateMKLDNNMem) {
       continue;
     }
     std::vector<NDArrayAttrs> out_arrs = GetTestOutputArrays(in_arr.arr.shape(), mds);
-    for (auto &out_arr : out_arrs) {
-      auto in_mem = in_arr.arr.GetMKLDNNData();
-      auto in_mem2 = in_arr2.arr.GetMKLDNNData();
+    for (auto& out_arr : out_arrs) {
+      auto in_mem         = static_cast<const mkldnn::memory*>(in_arr.arr.GetMKLDNNData());
+      auto in_mem2        = static_cast<const mkldnn::memory*>(in_arr2.arr.GetMKLDNNData());
       NDArray orig_output = out_arr.arr.Copy(out_arr.arr.ctx());
       orig_output.WaitToRead();
       PrintVerifyMsg(in_arr, out_arr);
-      auto out_mem = out_arr.arr.GetMKLDNNData();
+      auto out_mem      = static_cast<const mkldnn::memory*>(out_arr.arr.GetMKLDNNData());
       auto output_mem_t = CreateMKLDNNMem(out_arr.arr, out_mem->get_desc(), kAddTo);
       op::MKLDNNSum(*in_mem, *in_mem2, *output_mem_t.second);
       CommitOutput(out_arr.arr, output_mem_t);
@@ -324,20 +325,20 @@ TEST(MKLDNN_BASE, CreateMKLDNNMem) {
 
   // kNullOp
   for (int i = 0; i < in_arrs.size(); i++) {
-    auto in_arr = in_arrs[i];
+    auto in_arr  = in_arrs[i];
     auto in_arr2 = in_arrs2[i];
     if (!SupportMKLDNN(in_arr.arr))
       continue;
     if (in_arr.arr.IsMKLDNNData() && in_arr.arr.IsView()) {
       continue;
     }
-    auto input_mem = in_arr.arr.GetMKLDNNData();
-    auto input_mem2 = in_arr2.arr.GetMKLDNNData();
+    auto input_mem  = static_cast<const mkldnn::memory*>(in_arr.arr.GetMKLDNNData());
+    auto input_mem2 = static_cast<const mkldnn::memory*>(in_arr2.arr.GetMKLDNNData());
     NDArrayAttrs orig_arr(in_arr.arr.Copy(in_arr.arr.ctx()), "In Place Copy");
     orig_arr.arr.WaitToRead();
     PrintVerifyMsg(orig_arr, in_arr);
     InitMKLDNNArray(&orig_arr.arr, input_mem->get_desc());
-    orig_arr.arr.CopyFrom(*input_mem);
+    orig_arr.arr.CopyFrom(input_mem);
     auto output_mem_t = CreateMKLDNNMem(in_arr.arr, input_mem->get_desc(), kNullOp);
     op::MKLDNNSum(*input_mem, *input_mem2, *output_mem_t.second);
     CommitOutput(in_arr.arr, output_mem_t);
@@ -355,10 +356,10 @@ TEST(MKLDNN_NDArray, GetTestInputArraysConcat) {
       for (size_t i = 0; i < dim + 1; ++i)
         scale_vector[i] = 1;
       scale_vector[dim] = num_inputs;
-      std::vector<NDArrayAttrs> expanded_arrs = GetTestInputArrays(
-          ArrayTypes::All, false, scale_vector);
+      std::vector<NDArrayAttrs> expanded_arrs =
+          GetTestInputArrays(ArrayTypes::All, false, scale_vector);
       int i = 0;
-      for (auto &arr : in_arrs) {
+      for (auto& arr : in_arrs) {
         if (dim >= arr.arr.shape().ndim())
           continue;
         auto ex_arr = expanded_arrs[i];
@@ -372,22 +373,22 @@ TEST(MKLDNN_NDArray, GetTestInputArraysConcat) {
 }
 
 TEST(MKLDNN_NDArray, GetTestOutputArraysConcat) {
-  auto shapes_pds = GetTestArrayShapes();
-  std::vector<mxnet::TShape> shapes = shapes_pds.shapes;
+  auto shapes_pds                       = GetTestArrayShapes();
+  std::vector<mxnet::TShape> shapes     = shapes_pds.shapes;
   std::vector<mkldnn::memory::desc> mds = shapes_pds.mds;
-  for (auto &shape : shapes) {
+  for (auto& shape : shapes) {
     for (int dim = 0; dim < 5; dim++) {
       for (int num_inputs = 2; num_inputs < 5; num_inputs++) {
         if (shape.ndim() <= dim)
           continue;
-        std::cout << "Extending " << shape << " dim " <<
-                  dim << " and " << num_inputs << "num_inputs\n";
+        std::cout << "Extending " << shape << " dim " << dim << " and " << num_inputs
+                  << "num_inputs\n";
         std::vector<float> scale_vector(shape.ndim());
         for (int i = 0; i < shape.ndim(); i++)
           scale_vector[i] = 1;
         scale_vector[dim] = num_inputs;
-        auto output_arrs = GetTestOutputArrays(shape, mds, scale_vector);
-        for (auto &out_arr : output_arrs) {
+        auto output_arrs  = GetTestOutputArrays(shape, mds, scale_vector);
+        for (auto& out_arr : output_arrs) {
           auto out_shape = out_arr.arr.shape();
           EXPECT_EQ(shape.Size() * num_inputs, out_arr.arr.shape().Size());
           EXPECT_EQ(shape[dim] * num_inputs, out_arr.arr.shape()[dim]);
@@ -398,19 +399,19 @@ TEST(MKLDNN_NDArray, GetTestOutputArraysConcat) {
 }
 
 TEST(MKLDNN_NDArray, CopyFrom) {
-  TestArrayShapes tas = GetTestArrayShapes();
+  TestArrayShapes tas                   = GetTestArrayShapes();
   std::vector<mkldnn::memory::desc> mds = tas.mds;
 
   std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays();
-  for (auto &in_arr : in_arrs) {
+  for (auto& in_arr : in_arrs) {
     if (in_arr.arr.IsMKLDNNData() && in_arr.arr.IsView())
       continue;
     std::vector<NDArrayAttrs> out_arrs = GetTestOutputArrays(in_arr.arr.shape(), mds);
-    for (auto &out_arr : out_arrs) {
-      const mkldnn::memory *mem = in_arr.arr.GetMKLDNNData();
-      out_arr.arr.CopyFrom(*mem);
+    for (auto& out_arr : out_arrs) {
+      const mkldnn::memory* mem = static_cast<const mkldnn::memory*>(in_arr.arr.GetMKLDNNData());
+      out_arr.arr.CopyFrom(mem);
       MKLDNNStream::Get()->Submit();
-      std::vector<NDArray *> inputs(1);
+      std::vector<NDArray*> inputs(1);
       inputs[0] = &in_arr.arr;
       VerifyCopyResult(inputs, {&out_arr.arr});
     }