Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/06/02 06:41:03 UTC

[incubator-mxnet] branch master updated: [master] Remove dnnl_ops-inl.h file (#20997)

This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 102388a055 [master] Remove dnnl_ops-inl.h file (#20997)
102388a055 is described below

commit 102388a0557c530741ed8e9b31296416a1c23925
Author: PiotrWolinski - Intel <pi...@intel.com>
AuthorDate: Thu Jun 2 08:40:48 2022 +0200

    [master] Remove dnnl_ops-inl.h file (#20997)
    
    * Removed dnnl_ops-inl.h
    
    * Removed dnnl_ops-inl.h file from tracking
    
    * Formatted new files
    
    * Added missing import
    
    * Applied changes from review
    
    * Linting
    
    * Refactored DNNLSumFwd class
    
    * Linting
    
    * Fixed issue with DNNLSumFwd
    
    * Removed unused variables
    
    * Fixed oneDNN usage in refactored files
    
    * Fix sanity
    
    * Changed GetConcatForward to GetCached and placed it in the DNNLConcatFwd class
    
    * Removed missed comment
    
    * Moved DNNLSum function from dnnl_sum files to dnnl_base and renamed it to DNNLMemorySum (call-site changes are sketched after this list)
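    
    For reviewers skimming the message, a short before/after of the affected call sites, taken
    directly from the hunks further down; the surrounding MXNet code is elided, so this is an
    illustration rather than a compilable unit:
    
        // Accumulating into an existing oneDNN memory (dnnl_base.cc, dnnl_copy.cc):
        //   before: op::DNNLSum(*mem, *res_memory, *mem);
        //   after:  op::DNNLMemorySum(*mem, *res_memory, *mem);
    
        // Fetching the cached concat primitive (dnnl_concat.cc, dnnl_stack.cc):
        //   before: DNNLConcatFwd& fwd = GetConcatForward(concat_dim, in_data, data_md);
        //   after:  DNNLConcatFwd& fwd = DNNLConcatFwd::GetCached(concat_dim, in_data, data_md);
    
        // Operator front-ends include the per-operator header instead of dnnl_ops-inl.h:
        //   before: #include "./dnnl/dnnl_ops-inl.h"
        //   after:  #include "operator/nn/dnnl/dnnl_act-inl.h"   // e.g. in activation.cc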
---
 src/operator/leaky_relu.cc                         |   4 +-
 src/operator/nn/activation.cc                      |   4 +-
 src/operator/nn/concat-inl.h                       |   1 -
 src/operator/nn/concat.cc                          |   6 +-
 src/operator/nn/convolution-inl.h                  |   1 -
 src/operator/nn/convolution.cc                     |   4 +-
 src/operator/nn/deconvolution-inl.h                |   1 -
 src/operator/nn/deconvolution.cc                   |   4 +-
 src/operator/nn/dnnl/dnnl_act-inl.h                |  25 +++
 src/operator/nn/dnnl/dnnl_base-inl.h               |   1 +
 src/operator/nn/dnnl/dnnl_base.cc                  |  33 ++-
 src/operator/nn/dnnl/dnnl_batch_dot-inl.h          |   1 -
 src/operator/nn/dnnl/dnnl_batch_norm-inl.h         |   1 -
 src/operator/nn/dnnl/dnnl_binary-inl.h             |   1 -
 src/operator/nn/dnnl/dnnl_concat-inl.h             |  37 ++--
 src/operator/nn/dnnl/dnnl_concat.cc                |  25 ++-
 src/operator/nn/dnnl/dnnl_convolution-inl.h        |  13 +-
 src/operator/nn/dnnl/dnnl_convolution.cc           |   1 -
 .../nn/dnnl/{dnnl_copy.cc => dnnl_copy-inl.h}      |  36 +---
 src/operator/nn/dnnl/dnnl_copy.cc                  |   3 +-
 src/operator/nn/dnnl/dnnl_deconvolution-inl.h      |  54 ++---
 src/operator/nn/dnnl/dnnl_deconvolution.cc         |  41 ++++
 src/operator/nn/dnnl/dnnl_fully_connected-inl.h    |   6 +
 src/operator/nn/dnnl/dnnl_layer_norm-inl.h         |  13 +-
 src/operator/nn/dnnl/dnnl_log_softmax.cc           |   1 -
 src/operator/nn/dnnl/dnnl_masked_softmax-inl.h     |   1 -
 src/operator/nn/dnnl/dnnl_ops-inl.h                | 230 ---------------------
 src/operator/nn/dnnl/dnnl_power_scalar-inl.h       |   7 +-
 src/operator/nn/dnnl/dnnl_reduce-inl.h             |   1 -
 src/operator/nn/dnnl/dnnl_reshape-inl.h            |   7 +
 src/operator/nn/dnnl/dnnl_reshape.cc               |   1 -
 src/operator/nn/dnnl/dnnl_slice.cc                 |   1 -
 src/operator/nn/dnnl/dnnl_softmax-inl.h            |  31 ++-
 ...nnl_reshape-inl.h => dnnl_softmax_output-inl.h} |  41 ++--
 src/operator/nn/dnnl/dnnl_softmax_output.cc        |   1 -
 src/operator/nn/dnnl/dnnl_split-inl.h              |   7 +-
 .../dnnl/{dnnl_reshape-inl.h => dnnl_stack-inl.h}  |  40 ++--
 src/operator/nn/dnnl/dnnl_stack.cc                 |   5 +-
 src/operator/nn/dnnl/dnnl_sum-inl.h                |  72 +++++++
 src/operator/nn/dnnl/dnnl_sum.cc                   | 106 ++++------
 src/operator/nn/dnnl/dnnl_transpose-inl.h          |   1 -
 src/operator/nn/dnnl/dnnl_where-inl.h              |   7 +-
 src/operator/nn/fully_connected.cc                 |   5 +-
 src/operator/nn/layer_norm.cc                      |   4 +-
 src/operator/nn/log_softmax.cc                     |   4 +-
 src/operator/nn/masked_softmax.cc                  |   2 +-
 src/operator/nn/softmax.cc                         |   4 +-
 src/operator/numpy/np_matrix_op.cc                 |   5 +-
 .../quantization/dnnl/dnnl_quantized_act.cc        |   3 +-
 .../quantization/dnnl/dnnl_quantized_concat.cc     |   2 +-
 .../dnnl/dnnl_quantized_elemwise_add.cc            |   1 -
 .../quantization/dnnl/dnnl_quantized_flatten.cc    |   2 +-
 .../quantization/dnnl/dnnl_quantized_reshape.cc    |   2 +-
 src/operator/quantization/quantized_conv.cc        |   3 -
 src/operator/softmax_output.cc                     |   4 +-
 src/operator/subgraph/dnnl/dnnl_batch_dot.cc       |   1 -
 src/operator/subgraph/dnnl/dnnl_conv.cc            |   1 -
 src/operator/subgraph/dnnl/dnnl_conv_property.h    |   1 -
 src/operator/subgraph/dnnl/dnnl_fc.cc              |   1 -
 src/operator/tensor/elemwise_binary_op_basic.cc    |   5 +-
 src/operator/tensor/elemwise_sum.cc                |   6 +-
 src/operator/tensor/elemwise_unary_op_basic.cc     |   4 +-
 src/operator/tensor/elemwise_unary_op_logexp.cc    |   1 -
 src/operator/tensor/elemwise_unary_op_pow.cc       |   1 -
 src/operator/tensor/matrix_op.cc                   |  12 +-
 tests/cpp/operator/dnnl_operator_test.cc           |   1 -
 tests/cpp/operator/dnnl_test.cc                    |  15 +-
 67 files changed, 444 insertions(+), 522 deletions(-)
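
The new DNNLConcatFwd::GetCached and DNNLSumFwd::GetCached methods in the hunks below both follow
the usual MXNet oneDNN primitive-caching idiom: a thread-local map keyed by an OpSignature,
populated on first use and reused afterwards. The following is a minimal, self-contained sketch of
that idiom for readers unfamiliar with it; every name in it (Signature, SignatureHash,
CachedPrimitive, GetCached) is an illustrative placeholder, not an MXNet or oneDNN symbol.

    #include <cstddef>
    #include <functional>
    #include <unordered_map>
    #include <vector>

    // Stand-in for OpSignature: the values that uniquely describe a primitive.
    struct Signature {
      std::vector<int> values;  // e.g. concat_dim, stack_axis, input shapes
      bool operator==(const Signature& other) const { return values == other.values; }
    };

    struct SignatureHash {
      std::size_t operator()(const Signature& s) const {
        std::size_t h = 0;
        for (int v : s.values)  // simple hash combine over the signature values
          h = h * 31 + std::hash<int>()(v);
        return h;
      }
    };

    // Stand-in for a constructed oneDNN forward primitive (e.g. DNNLConcatFwd).
    struct CachedPrimitive {
      explicit CachedPrimitive(int dim) : dim(dim) {}
      int dim;
    };

    // One cache per thread, so cached primitives are never shared across threads.
    CachedPrimitive& GetCached(const Signature& key, int dim) {
      static thread_local std::unordered_map<Signature, CachedPrimitive, SignatureHash> cache;
      auto it = cache.find(key);
      if (it == cache.end())
        it = cache.emplace(key, CachedPrimitive(dim)).first;  // build once, then reuse
      return it->second;
    }

A primitive built for one signature is reused on every later call with the same signature on the
same thread, which is what the old free function GetConcatForward already did before being folded
into the DNNLConcatFwd class.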

diff --git a/src/operator/leaky_relu.cc b/src/operator/leaky_relu.cc
index 39aa11dfb9..1376efba52 100644
--- a/src/operator/leaky_relu.cc
+++ b/src/operator/leaky_relu.cc
@@ -26,8 +26,8 @@
 #include "./leaky_relu-inl.h"
 #include "../common/alm.h"
 #if MXNET_USE_ONEDNN == 1
-#include "./nn/dnnl/dnnl_base-inl.h"
-#include "./nn/dnnl/dnnl_ops-inl.h"
+#include "operator/nn/dnnl/dnnl_base-inl.h"
+#include "operator/nn/dnnl/dnnl_act-inl.h"
 #endif  // MXNET_USE_ONEDNN == 1
 
 #include <nnvm/op_attr_types.h>
diff --git a/src/operator/nn/activation.cc b/src/operator/nn/activation.cc
index a228bf8a76..46ecdb2b69 100644
--- a/src/operator/nn/activation.cc
+++ b/src/operator/nn/activation.cc
@@ -26,8 +26,8 @@
 #include "../mshadow_op.h"
 #include "../tensor/elemwise_unary_op.h"
 #if MXNET_USE_ONEDNN == 1
-#include "./dnnl/dnnl_base-inl.h"
-#include "./dnnl/dnnl_ops-inl.h"
+#include "operator/nn/dnnl/dnnl_act-inl.h"
+#include "operator/nn/dnnl/dnnl_base-inl.h"
 #endif  // MXNET_USE_ONEDNN == 1
 #include "../operator_common.h"
 #include "../../common/utils.h"
diff --git a/src/operator/nn/concat-inl.h b/src/operator/nn/concat-inl.h
index 01cde26de2..5886d60386 100644
--- a/src/operator/nn/concat-inl.h
+++ b/src/operator/nn/concat-inl.h
@@ -388,7 +388,6 @@ void ConcatCSRImpl(const nnvm::NodeAttrs& attrs,
     });
   });
 }
-
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc
index 2329892048..70b8aeb9f8 100644
--- a/src/operator/nn/concat.cc
+++ b/src/operator/nn/concat.cc
@@ -25,8 +25,10 @@
 
 #include "../../common/utils.h"
 #include "./concat-inl.h"
-#include "./dnnl/dnnl_base-inl.h"
-#include "./dnnl/dnnl_ops-inl.h"
+#if MXNET_USE_ONEDNN == 1
+#include "operator/nn/dnnl/dnnl_base-inl.h"
+#include "operator/nn/dnnl/dnnl_concat-inl.h"
+#endif  // MXNET_USE_ONEDNN == 1
 
 namespace mxnet {
 namespace op {
diff --git a/src/operator/nn/convolution-inl.h b/src/operator/nn/convolution-inl.h
index 9994c7bed7..a034e44082 100644
--- a/src/operator/nn/convolution-inl.h
+++ b/src/operator/nn/convolution-inl.h
@@ -597,7 +597,6 @@ void ConvolutionGradCompute(const nnvm::NodeAttrs& attrs,
     op.Backward(ctx, std::vector<TBlob>{out_grad}, in_data, req, in_grad);
   });
 }
-
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_OPERATOR_NN_CONVOLUTION_INL_H_
diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc
index a39fa3fa45..674ccf6e81 100644
--- a/src/operator/nn/convolution.cc
+++ b/src/operator/nn/convolution.cc
@@ -30,8 +30,8 @@
 #include "../operator_common.h"
 #include "../../common/alm.h"
 #if MXNET_USE_ONEDNN == 1
-#include "./dnnl/dnnl_base-inl.h"
-#include "./dnnl/dnnl_ops-inl.h"
+#include "operator/nn/dnnl/dnnl_base-inl.h"
+#include "operator/nn/dnnl/dnnl_convolution-inl.h"
 #endif  // MXNET_USE_ONEDNN
 
 namespace mxnet {
diff --git a/src/operator/nn/deconvolution-inl.h b/src/operator/nn/deconvolution-inl.h
index 9c9f2c8138..d5fb132d88 100644
--- a/src/operator/nn/deconvolution-inl.h
+++ b/src/operator/nn/deconvolution-inl.h
@@ -428,7 +428,6 @@ void DeconvolutionGradCompute(const nnvm::NodeAttrs& attrs,
   const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);
   _DeconvolutionGradCompute<xpu>(param, ctx, inputs, req, outputs);
 }
-
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_OPERATOR_NN_DECONVOLUTION_INL_H_
diff --git a/src/operator/nn/deconvolution.cc b/src/operator/nn/deconvolution.cc
index 2bef3fc898..e65df297de 100644
--- a/src/operator/nn/deconvolution.cc
+++ b/src/operator/nn/deconvolution.cc
@@ -28,8 +28,8 @@
 #include "../../common/alm.h"
 #include "../../common/utils.h"
 #if MXNET_USE_ONEDNN == 1
-#include "./dnnl/dnnl_base-inl.h"
-#include "./dnnl/dnnl_ops-inl.h"
+#include "operator/nn/dnnl/dnnl_base-inl.h"
+#include "operator/nn/dnnl/dnnl_deconvolution-inl.h"
 #endif  // MXNET_USE_ONEDNN
 
 namespace mxnet {
diff --git a/src/operator/nn/dnnl/dnnl_act-inl.h b/src/operator/nn/dnnl/dnnl_act-inl.h
index 66f229962f..fcc896414b 100644
--- a/src/operator/nn/dnnl/dnnl_act-inl.h
+++ b/src/operator/nn/dnnl/dnnl_act-inl.h
@@ -95,6 +95,31 @@ class DNNLActBackward {
  private:
   std::shared_ptr<dnnl::eltwise_backward> bwd_prim_;
 };
+
+void DNNLActivationForward(const nnvm::NodeAttrs& attrs,
+                           const OpContext& ctx,
+                           const NDArray& in_data,
+                           const OpReqType& req,
+                           const NDArray& out_data);
+
+void DNNLActivationBackward(const nnvm::NodeAttrs& attrs,
+                            const OpContext& ctx,
+                            const std::vector<NDArray>& inputs,
+                            const std::vector<OpReqType>& req,
+                            const std::vector<NDArray>& outputs);
+
+void DNNLLeakyReluForward(const nnvm::NodeAttrs& attrs,
+                          const OpContext& ctx,
+                          const NDArray& in_data,
+                          const OpReqType& req,
+                          const NDArray& out_data);
+
+void DNNLLeakyReluBackward(const nnvm::NodeAttrs& attrs,
+                           const OpContext& ctx,
+                           const std::vector<NDArray>& inputs,
+                           const std::vector<OpReqType>& req,
+                           const std::vector<NDArray>& outputs);
+
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/nn/dnnl/dnnl_base-inl.h b/src/operator/nn/dnnl/dnnl_base-inl.h
index c38895bb1e..8161d62c24 100644
--- a/src/operator/nn/dnnl/dnnl_base-inl.h
+++ b/src/operator/nn/dnnl/dnnl_base-inl.h
@@ -211,6 +211,7 @@ bool SupportDNNLStack(const std::vector<NDArray>& inputs);
 bool SupportDNNLBinary(const std::vector<NDArray>& inputs);
 bool SupportDNNLEltwise(const NDArray& input, const NDArray& output);
 bool SupportDNNLPower(const NDArray& input);
+void DNNLMemorySum(const dnnl::memory& arr1, const dnnl::memory& arr2, const dnnl::memory& out);
 }  // namespace op
 
 static int GetTypeSize(int dtype) {
diff --git a/src/operator/nn/dnnl/dnnl_base.cc b/src/operator/nn/dnnl/dnnl_base.cc
index 216a420137..c7e3a92d21 100644
--- a/src/operator/nn/dnnl/dnnl_base.cc
+++ b/src/operator/nn/dnnl/dnnl_base.cc
@@ -24,7 +24,6 @@
 #include "../../../common/exec_utils.h"
 #include "operator/operator_common.h"
 #include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
 
 namespace mxnet {
 
@@ -37,6 +36,36 @@ DNNLStream* DNNLStream::Get() {
   return &stream;
 }
 
+namespace op {
+void DNNLMemorySum(const dnnl::memory& arr1, const dnnl::memory& arr2, const dnnl::memory& out) {
+  std::vector<dnnl::memory::desc> input_pds(2);
+  std::vector<float> scales(2, 1);
+  input_pds[0] = arr1.get_desc();
+  input_pds[1] = arr2.get_desc();
+  CHECK(input_pds[0] == input_pds[0]);
+  const dnnl::memory* in_mem1 = &arr1;
+  const dnnl::memory* in_mem2 = &arr2;
+  auto output_pd              = out.get_desc();
+  if (input_pds[0] != output_pd) {
+    auto tmp_memory1 = TmpMemMgr::Get()->Alloc(output_pd);
+    auto tmp_memory2 = TmpMemMgr::Get()->Alloc(output_pd);
+    DNNLMemoryCopy(arr1, tmp_memory1);
+    DNNLMemoryCopy(arr2, tmp_memory2);
+    input_pds[0] = tmp_memory1->get_desc();
+    input_pds[1] = tmp_memory2->get_desc();
+    in_mem1      = tmp_memory1;
+    in_mem2      = tmp_memory2;
+  }
+  dnnl::sum::primitive_desc sum_pd(output_pd, scales, input_pds, CpuEngine::Get()->get_engine());
+  dnnl_args_map_t args = {
+      {DNNL_ARG_MULTIPLE_SRC, *in_mem1},
+      {DNNL_ARG_MULTIPLE_SRC + 1, *in_mem2},
+      {DNNL_ARG_DST, out},
+  };
+  DNNLStream::Get()->RegisterPrimArgs(dnnl::sum(sum_pd), args);
+}
+}  // namespace op
+
 void* AlignMem(void* mem, size_t size, size_t alignment, size_t* space) {
   if (size > *space)
     return nullptr;
@@ -222,7 +251,7 @@ void CommitOutput(const NDArray& arr, const dnnl_output_t& res) {
       res_memory = tmp_memory;
       mem        = arr.GetDNNLData();
     }
-    op::DNNLSum(*mem, *res_memory, *mem);
+    op::DNNLMemorySum(*mem, *res_memory, *mem);
   }
 }
 
diff --git a/src/operator/nn/dnnl/dnnl_batch_dot-inl.h b/src/operator/nn/dnnl/dnnl_batch_dot-inl.h
index b48afd1d36..19233828dc 100644
--- a/src/operator/nn/dnnl/dnnl_batch_dot-inl.h
+++ b/src/operator/nn/dnnl/dnnl_batch_dot-inl.h
@@ -33,7 +33,6 @@
 
 #include "operator/tensor/dot-inl.h"
 #include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
 
 namespace mxnet {
 namespace op {
diff --git a/src/operator/nn/dnnl/dnnl_batch_norm-inl.h b/src/operator/nn/dnnl/dnnl_batch_norm-inl.h
index 8152b6c7a5..2780c9685f 100644
--- a/src/operator/nn/dnnl/dnnl_batch_norm-inl.h
+++ b/src/operator/nn/dnnl/dnnl_batch_norm-inl.h
@@ -33,7 +33,6 @@
 
 #include "operator/nn/batch_norm-inl.h"
 #include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
 
 namespace mxnet {
 namespace op {
diff --git a/src/operator/nn/dnnl/dnnl_binary-inl.h b/src/operator/nn/dnnl/dnnl_binary-inl.h
index 2cf63aa9a4..b047f8d365 100644
--- a/src/operator/nn/dnnl/dnnl_binary-inl.h
+++ b/src/operator/nn/dnnl/dnnl_binary-inl.h
@@ -27,7 +27,6 @@
 
 #if MXNET_USE_ONEDNN == 1
 #include "./dnnl_base-inl.h"
-#include "./dnnl_ops-inl.h"
 #include <vector>
 
 #include "../../tensor/elemwise_binary_broadcast_op.h"
diff --git a/src/operator/nn/dnnl/dnnl_concat-inl.h b/src/operator/nn/dnnl/dnnl_concat-inl.h
index 2970908185..8530538544 100644
--- a/src/operator/nn/dnnl/dnnl_concat-inl.h
+++ b/src/operator/nn/dnnl/dnnl_concat-inl.h
@@ -31,7 +31,6 @@
 
 #include "operator/nn/concat-inl.h"
 #include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
 
 namespace mxnet {
 namespace op {
@@ -42,6 +41,11 @@ class DNNLConcatFwd {
 
   DNNLConcatFwd(int concat_dim, const std::vector<dnnl::memory::desc>& data_md);
 
+  static DNNLConcatFwd& GetCached(int concat_dim,
+                                  const std::vector<NDArray>& in_data,
+                                  const std::vector<dnnl::memory::desc>& data_md,
+                                  int stack_axis = -1 /*used only by stack op*/);
+
   const dnnl::concat& GetFwd() const {
     return *fwd_;
   }
@@ -50,28 +54,17 @@ class DNNLConcatFwd {
   std::shared_ptr<dnnl::concat> fwd_;
 };
 
-static DNNLConcatFwd& GetConcatForward(int concat_dim,
-                                       const std::vector<NDArray>& in_data,
-                                       const std::vector<dnnl::memory::desc>& data_md,
-                                       int stack_axis = -1 /*used only by stack op*/) {
-#if DMLC_CXX11_THREAD_LOCAL
-  static thread_local std::unordered_map<OpSignature, DNNLConcatFwd, OpHash> fwds;
-#else
-  static MX_THREAD_LOCAL std::unordered_map<OpSignature, DNNLConcatFwd, OpHash> fwds;
-#endif
-
-  OpSignature key;
-  key.AddSign(concat_dim);
-  key.AddSign(stack_axis);
-  key.AddSign(in_data);
+void DNNLConcatForward(const nnvm::NodeAttrs& attrs,
+                       const OpContext& ctx,
+                       const std::vector<NDArray>& in_data,
+                       const std::vector<OpReqType>& req,
+                       const std::vector<NDArray>& out_data);
 
-  auto it = fwds.find(key);
-  if (it == fwds.end()) {
-    DNNLConcatFwd fwd(concat_dim, data_md);
-    it = AddToCache(&fwds, key, fwd);
-  }
-  return it->second;
-}
+void DNNLConcatBackward(const nnvm::NodeAttrs& attrs,
+                        const OpContext& ctx,
+                        const std::vector<NDArray>& inputs,
+                        const std::vector<OpReqType>& req,
+                        const std::vector<NDArray>& outputs);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/nn/dnnl/dnnl_concat.cc b/src/operator/nn/dnnl/dnnl_concat.cc
index 83ba9df245..1c7e73bd68 100644
--- a/src/operator/nn/dnnl/dnnl_concat.cc
+++ b/src/operator/nn/dnnl/dnnl_concat.cc
@@ -56,6 +56,29 @@ DNNLConcatFwd::DNNLConcatFwd(int concat_dim, const std::vector<dnnl::memory::des
   fwd_ = std::make_shared<dnnl::concat>(fwd_pd);
 }
 
+DNNLConcatFwd& DNNLConcatFwd::GetCached(int concat_dim,
+                                        const std::vector<NDArray>& in_data,
+                                        const std::vector<dnnl::memory::desc>& data_md,
+                                        int stack_axis /*used only by stack op*/) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<OpSignature, DNNLConcatFwd, OpHash> fwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<OpSignature, DNNLConcatFwd, OpHash> fwds;
+#endif
+
+  OpSignature key;
+  key.AddSign(concat_dim);
+  key.AddSign(stack_axis);
+  key.AddSign(in_data);
+
+  auto it = fwds.find(key);
+  if (it == fwds.end()) {
+    DNNLConcatFwd fwd(concat_dim, data_md);
+    it = AddToCache(&fwds, key, fwd);
+  }
+  return it->second;
+}
+
 void DNNLConcatForward(const nnvm::NodeAttrs& attrs,
                        const OpContext& ctx,
                        const std::vector<NDArray>& in_data,
@@ -76,7 +99,7 @@ void DNNLConcatForward(const nnvm::NodeAttrs& attrs,
     data_md.push_back(tmp_md);
     data_mem.push_back(tmp_mem);
   }
-  DNNLConcatFwd& fwd = GetConcatForward(concat_dim, in_data, data_md);
+  DNNLConcatFwd& fwd = DNNLConcatFwd::GetCached(concat_dim, in_data, data_md);
   mxnet::dnnl_output_t out_mem =
       CreateDNNLMem(out_data[concat_enum::kOut], fwd.fwd_pd.dst_desc(), req[concat_enum::kOut]);
   std::unordered_map<int, dnnl::memory> net_args;
diff --git a/src/operator/nn/dnnl/dnnl_convolution-inl.h b/src/operator/nn/dnnl/dnnl_convolution-inl.h
index f41af488b1..738be8214f 100644
--- a/src/operator/nn/dnnl/dnnl_convolution-inl.h
+++ b/src/operator/nn/dnnl/dnnl_convolution-inl.h
@@ -32,7 +32,6 @@
 
 #include "operator/nn/convolution-inl.h"
 #include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
 
 namespace mxnet {
 namespace op {
@@ -172,6 +171,18 @@ class DNNLConvBackward {
   std::shared_ptr<dnnl::convolution_backward_weights> bwd_weight_;
 };
 
+void DNNLConvolutionForward(const nnvm::NodeAttrs& attrs,
+                            const OpContext& ctx,
+                            const std::vector<NDArray>& in_data,
+                            const std::vector<OpReqType>& req,
+                            const std::vector<NDArray>& out_data);
+
+void DNNLConvolutionBackward(const nnvm::NodeAttrs& attrs,
+                             const OpContext& ctx,
+                             const std::vector<NDArray>& inputs,
+                             const std::vector<OpReqType>& req,
+                             const std::vector<NDArray>& outputs);
+
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/nn/dnnl/dnnl_convolution.cc b/src/operator/nn/dnnl/dnnl_convolution.cc
index 60c65ffa47..ca6effb791 100644
--- a/src/operator/nn/dnnl/dnnl_convolution.cc
+++ b/src/operator/nn/dnnl/dnnl_convolution.cc
@@ -30,7 +30,6 @@
 #include "operator/nn/convolution-inl.h"
 #include "dnnl_base-inl.h"
 #include "dnnl_convolution-inl.h"
-#include "dnnl_ops-inl.h"
 
 namespace mxnet {
 namespace op {
diff --git a/src/operator/nn/dnnl/dnnl_copy.cc b/src/operator/nn/dnnl/dnnl_copy-inl.h
similarity index 55%
copy from src/operator/nn/dnnl/dnnl_copy.cc
copy to src/operator/nn/dnnl/dnnl_copy-inl.h
index 16cbabd19a..41362dfcf1 100644
--- a/src/operator/nn/dnnl/dnnl_copy.cc
+++ b/src/operator/nn/dnnl/dnnl_copy-inl.h
@@ -18,15 +18,18 @@
  */
 
 /*!
- * \file dnnl_copy.cc
+ * \file dnnl_copy-inl.h
  * \brief
- * \author
+ * \author Wolinski Piotr piotr.wolinski@intel.com
  */
 
-#include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
+#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_COPY_INL_H_
+#define MXNET_OPERATOR_NN_DNNL_DNNL_COPY_INL_H_
 
 #if MXNET_USE_ONEDNN == 1
+
+#include <dnnl.hpp>
+
 namespace mxnet {
 namespace op {
 
@@ -34,27 +37,10 @@ void DNNLCopy(const nnvm::NodeAttrs& attrs,
               const OpContext& ctx,
               const NDArray& in_data,
               const OpReqType& req,
-              const NDArray& out_data) {
-  if (req == kNullOp || req == kWriteInplace)
-    return;
-  TmpMemMgr::Get()->Init(ctx.requested[0]);
-  auto in_mem = in_data.GetDNNLData();
-  if (req == kAddTo) {
-    TmpMemMgr::Get()->Init(ctx.requested[0]);
-    // We should try and force the input memory has the same format
-    // as the input output. If not, we'll have to reorder memory.
-    auto out_mem      = out_data.GetDNNLData();
-    auto out_mem_desc = out_mem->get_desc();
-    in_mem            = in_data.GetDNNLData(&out_mem_desc);
-    if (in_mem == nullptr)
-      in_mem = in_data.GetDNNLDataReorder(&out_mem_desc);
-    DNNLSum(*out_mem, *in_mem, *out_mem);
-  } else {
-    const_cast<NDArray&>(out_data).CopyFrom(*in_mem);
-  }
-  DNNLStream::Get()->Submit();
-}
+              const NDArray& out_data);
 
 }  // namespace op
 }  // namespace mxnet
-#endif
+
+#endif  // MXNET_USE_ONEDNN == 1
+#endif  // MXNET_OPERATOR_NN_DNNL_DNNL_COPY_INL_H_
diff --git a/src/operator/nn/dnnl/dnnl_copy.cc b/src/operator/nn/dnnl/dnnl_copy.cc
index 16cbabd19a..bdf5009bf5 100644
--- a/src/operator/nn/dnnl/dnnl_copy.cc
+++ b/src/operator/nn/dnnl/dnnl_copy.cc
@@ -24,7 +24,6 @@
  */
 
 #include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
 
 #if MXNET_USE_ONEDNN == 1
 namespace mxnet {
@@ -48,7 +47,7 @@ void DNNLCopy(const nnvm::NodeAttrs& attrs,
     in_mem            = in_data.GetDNNLData(&out_mem_desc);
     if (in_mem == nullptr)
       in_mem = in_data.GetDNNLDataReorder(&out_mem_desc);
-    DNNLSum(*out_mem, *in_mem, *out_mem);
+    DNNLMemorySum(*out_mem, *in_mem, *out_mem);
   } else {
     const_cast<NDArray&>(out_data).CopyFrom(*in_mem);
   }
diff --git a/src/operator/nn/dnnl/dnnl_deconvolution-inl.h b/src/operator/nn/dnnl/dnnl_deconvolution-inl.h
index a1ac5518e1..4e2058a5a6 100644
--- a/src/operator/nn/dnnl/dnnl_deconvolution-inl.h
+++ b/src/operator/nn/dnnl/dnnl_deconvolution-inl.h
@@ -43,7 +43,6 @@
 
 #include "operator/nn/deconvolution-inl.h"
 #include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
 
 namespace mxnet {
 namespace op {
@@ -130,25 +129,6 @@ class DNNLDeconvFwd {
   std::shared_ptr<deconv_fwd_pd_t> fwd_pd;
 };
 
-DNNLDeconvFwd::Tensors::Tensors(const bool no_bias,
-                                const std::vector<NDArray>& inputs,
-                                const std::vector<NDArray>& outputs)
-    : data(inputs[deconv::kData]),
-      weights(inputs[deconv::kWeight]),
-      bias(no_bias ? nullptr : &inputs[deconv::kBias]),
-      out(outputs[deconv::kOut]) {}
-
-DNNLDeconvFwd::Tensors::Tensors(const NDArray& data,
-                                const NDArray& weights,
-                                const NDArray* const bias,
-                                const NDArray& out)
-    : data(data), weights(weights), bias(bias), out(out) {}
-
-DNNLDeconvFwd::DNNLDeconvFwd(const DeconvolutionParam& param, const Tensors& tensors)
-    : fwd_pd(CreatePrimitiveDesc(param, tensors)) {
-  fwd = std::make_shared<deconv_fwd_t>(*fwd_pd);
-}
-
 inline const dnnl::memory* DNNLDeconvFwd::DataMem(const NDArray& data) const {
   auto fwd_src_desc = fwd_pd->src_desc();
   return data.GetDNNLDataReorder(&fwd_src_desc);
@@ -242,28 +222,6 @@ class DNNLDeconvBwd {
   std::shared_ptr<deconv_bwd_weights_t> bwd_weights;
 };
 
-DNNLDeconvBwd::ReadTensors::ReadTensors(const bool no_bias, const std::vector<NDArray>& inputs)
-    : data(inputs[deconv::kData + 1]),
-      weights(inputs[deconv::kWeight + 1]),
-      bias(no_bias ? nullptr : &inputs[deconv::kBias + 1]),
-      out_grad(inputs[deconv::kOut]) {}
-
-DNNLDeconvBwd::WriteTensors::WriteTensors(const bool no_bias, const std::vector<NDArray>& outputs)
-    : data_grad(outputs[deconv::kData]),
-      weights_grad(outputs[deconv::kWeight]),
-      bias_grad(no_bias ? nullptr : &outputs[deconv::kBias]) {}
-
-DNNLDeconvBwd::DNNLDeconvBwd(const DeconvolutionParam& param, const ReadTensors& read_tensors) {
-  const auto& fwd_pd = DNNLDeconvFwd::CreatePrimitiveDesc(
-      param,
-      DNNLDeconvFwd::Tensors(
-          read_tensors.data, read_tensors.weights, read_tensors.bias, read_tensors.out_grad));
-  bwd_data_pd    = CreateDataPrimitiveDesc(param, read_tensors, *fwd_pd);
-  bwd_weights_pd = CreateWeightsPrimitiveDesc(param, read_tensors, *fwd_pd);
-  bwd_data       = std::make_shared<deconv_bwd_data_t>(*bwd_data_pd);
-  bwd_weights    = std::make_shared<deconv_bwd_weights_t>(*bwd_weights_pd);
-}
-
 inline void DNNLDeconvBwd::IOSwapWeightsTensors(const uint32_t num_group,
                                                 const std::vector<OpReqType>& req,
                                                 const NDArray& weights,
@@ -409,6 +367,18 @@ inline deconv_bwd_weights_t::desc DeconvDescCreator::CreateBwdWeightsDesc() cons
                                     padding);
 }
 
+void DNNLDeconvolutionForward(const nnvm::NodeAttrs& attrs,
+                              const OpContext& ctx,
+                              const std::vector<NDArray>& in_data,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<NDArray>& out_data);
+
+void DNNLDeconvolutionBackward(const nnvm::NodeAttrs& attrs,
+                               const OpContext& ctx,
+                               const std::vector<NDArray>& inputs,
+                               const std::vector<OpReqType>& req,
+                               const std::vector<NDArray>& outputs);
+
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_USE_ONEDNN == 1
diff --git a/src/operator/nn/dnnl/dnnl_deconvolution.cc b/src/operator/nn/dnnl/dnnl_deconvolution.cc
index 9487574d47..79e3229963 100644
--- a/src/operator/nn/dnnl/dnnl_deconvolution.cc
+++ b/src/operator/nn/dnnl/dnnl_deconvolution.cc
@@ -35,6 +35,31 @@ bool SupportDNNLDeconv(const DeconvolutionParam& params, const NDArray& input) {
          (input.dtype() == mshadow::kFloat32 || input.dtype() == mshadow::kBfloat16);
 }
 
+DNNLDeconvFwd::Tensors::Tensors(const bool no_bias,
+                                const std::vector<NDArray>& inputs,
+                                const std::vector<NDArray>& outputs)
+    : data(inputs[deconv::kData]),
+      weights(inputs[deconv::kWeight]),
+      bias(no_bias ? nullptr : &inputs[deconv::kBias]),
+      out(outputs[deconv::kOut]) {}
+
+DNNLDeconvFwd::Tensors::Tensors(const NDArray& data,
+                                const NDArray& weights,
+                                const NDArray* const bias,
+                                const NDArray& out)
+    : data(data), weights(weights), bias(bias), out(out) {}
+
+DNNLDeconvBwd::ReadTensors::ReadTensors(const bool no_bias, const std::vector<NDArray>& inputs)
+    : data(inputs[deconv::kData + 1]),
+      weights(inputs[deconv::kWeight + 1]),
+      bias(no_bias ? nullptr : &inputs[deconv::kBias + 1]),
+      out_grad(inputs[deconv::kOut]) {}
+
+DNNLDeconvBwd::WriteTensors::WriteTensors(const bool no_bias, const std::vector<NDArray>& outputs)
+    : data_grad(outputs[deconv::kData]),
+      weights_grad(outputs[deconv::kWeight]),
+      bias_grad(no_bias ? nullptr : &outputs[deconv::kBias]) {}
+
 void DNNLDeconvolutionForward(const nnvm::NodeAttrs& attrs,
                               const OpContext& ctx,
                               const std::vector<NDArray>& inputs,
@@ -49,6 +74,11 @@ void DNNLDeconvolutionForward(const nnvm::NodeAttrs& attrs,
   fwd.Execute(param.num_group, req[deconv::kOut], tensors);
 }
 
+DNNLDeconvFwd::DNNLDeconvFwd(const DeconvolutionParam& param, const Tensors& tensors)
+    : fwd_pd(CreatePrimitiveDesc(param, tensors)) {
+  fwd = std::make_shared<deconv_fwd_t>(*fwd_pd);
+}
+
 DNNLDeconvFwd& DNNLDeconvFwd::GetCached(const DeconvolutionParam& param, const Tensors& tensors) {
   using deconv_fwd_map = std::unordered_map<DeconvSignature, DNNLDeconvFwd, OpHash>;
 #if DMLC_CXX11_THREAD_LOCAL
@@ -180,6 +210,17 @@ void DNNLDeconvolutionBackward(const nnvm::NodeAttrs& attrs,
   bwd.Execute(param.num_group, req, read_tensors, write_tensors);
 }
 
+DNNLDeconvBwd::DNNLDeconvBwd(const DeconvolutionParam& param, const ReadTensors& read_tensors) {
+  const auto& fwd_pd = DNNLDeconvFwd::CreatePrimitiveDesc(
+      param,
+      DNNLDeconvFwd::Tensors(
+          read_tensors.data, read_tensors.weights, read_tensors.bias, read_tensors.out_grad));
+  bwd_data_pd    = CreateDataPrimitiveDesc(param, read_tensors, *fwd_pd);
+  bwd_weights_pd = CreateWeightsPrimitiveDesc(param, read_tensors, *fwd_pd);
+  bwd_data       = std::make_shared<deconv_bwd_data_t>(*bwd_data_pd);
+  bwd_weights    = std::make_shared<deconv_bwd_weights_t>(*bwd_weights_pd);
+}
+
 DNNLDeconvBwd& DNNLDeconvBwd::GetCached(const DeconvolutionParam& param,
                                         const ReadTensors& read_tensors) {
   using deconv_bwd_map = std::unordered_map<DeconvSignature, DNNLDeconvBwd, OpHash>;
diff --git a/src/operator/nn/dnnl/dnnl_fully_connected-inl.h b/src/operator/nn/dnnl/dnnl_fully_connected-inl.h
index 200df850c3..976dc83fe5 100644
--- a/src/operator/nn/dnnl/dnnl_fully_connected-inl.h
+++ b/src/operator/nn/dnnl/dnnl_fully_connected-inl.h
@@ -220,6 +220,12 @@ void DNNLFCForwardFullFeature(const DNNLFCFullParam& param,
                               const std::vector<OpReqType>& req,
                               const std::vector<NDArray>& out_data);
 
+void DNNLFCBackward(const nnvm::NodeAttrs& attrs,
+                    const OpContext& ctx,
+                    const std::vector<NDArray>& inputs,
+                    const std::vector<OpReqType>& req,
+                    const std::vector<NDArray>& outputs);
+
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/nn/dnnl/dnnl_layer_norm-inl.h b/src/operator/nn/dnnl/dnnl_layer_norm-inl.h
index c751930967..9b184a4383 100644
--- a/src/operator/nn/dnnl/dnnl_layer_norm-inl.h
+++ b/src/operator/nn/dnnl/dnnl_layer_norm-inl.h
@@ -31,7 +31,6 @@
 
 #include "operator/nn/layer_norm-inl.h"
 #include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
 
 namespace mxnet {
 namespace op {
@@ -96,6 +95,18 @@ class DNNLLayerNormBwd {
   std::shared_ptr<layernorm_bwd_pd_t> bwd_pd;
 };
 
+void DNNLLayerNormForward(const nnvm::NodeAttrs& attrs,
+                          const OpContext& ctx,
+                          const std::vector<NDArray>& inputs,
+                          const std::vector<OpReqType>& req,
+                          const std::vector<NDArray>& outputs);
+
+void DNNLLayerNormBackward(const nnvm::NodeAttrs& attrs,
+                           const OpContext& ctx,
+                           const std::vector<NDArray>& inputs,
+                           const std::vector<OpReqType>& req,
+                           const std::vector<NDArray>& outputs);
+
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_USE_ONEDNN == 1
diff --git a/src/operator/nn/dnnl/dnnl_log_softmax.cc b/src/operator/nn/dnnl/dnnl_log_softmax.cc
index be1abdb0f1..1559ee347b 100644
--- a/src/operator/nn/dnnl/dnnl_log_softmax.cc
+++ b/src/operator/nn/dnnl/dnnl_log_softmax.cc
@@ -24,7 +24,6 @@
 
 #include "operator/nn/softmax-inl.h"
 #include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
 
 #if MXNET_USE_ONEDNN == 1
 namespace mxnet {
diff --git a/src/operator/nn/dnnl/dnnl_masked_softmax-inl.h b/src/operator/nn/dnnl/dnnl_masked_softmax-inl.h
index 683f743a10..96bcfc1e33 100644
--- a/src/operator/nn/dnnl/dnnl_masked_softmax-inl.h
+++ b/src/operator/nn/dnnl/dnnl_masked_softmax-inl.h
@@ -27,7 +27,6 @@
 #include <vector>
 
 #include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
 #include "operator/nn/softmax-inl.h"
 #include "dnnl_softmax-inl.h"
 
diff --git a/src/operator/nn/dnnl/dnnl_ops-inl.h b/src/operator/nn/dnnl/dnnl_ops-inl.h
deleted file mode 100644
index 8ccdaaceda..0000000000
--- a/src/operator/nn/dnnl/dnnl_ops-inl.h
+++ /dev/null
@@ -1,230 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file dnnl_ops-inl.h
- * \brief
- * \author Da Zheng
- */
-
-#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_OPS_INL_H_
-#define MXNET_OPERATOR_NN_DNNL_DNNL_OPS_INL_H_
-
-#include <dmlc/logging.h>
-#include <dmlc/optional.h>
-#include <mxnet/base.h>
-#include <mxnet/io.h>
-#include <mxnet/ndarray.h>
-#include <mxnet/operator.h>
-#include <mxnet/operator_util.h>
-
-#include <vector>
-
-#if MXNET_USE_ONEDNN == 1
-#include <dnnl.hpp>
-
-namespace mxnet {
-namespace op {
-
-/* For fully connected. */
-void DNNLFCForward(const nnvm::NodeAttrs& attrs,
-                   const OpContext& ctx,
-                   const std::vector<NDArray>& in_data,
-                   const std::vector<OpReqType>& req,
-                   const std::vector<NDArray>& out_data);
-void DNNLFCBackward(const nnvm::NodeAttrs& attrs,
-                    const OpContext& ctx,
-                    const std::vector<NDArray>& inputs,
-                    const std::vector<OpReqType>& req,
-                    const std::vector<NDArray>& outputs);
-
-/* For convolution. */
-void DNNLConvolutionForward(const nnvm::NodeAttrs& attrs,
-                            const OpContext& ctx,
-                            const std::vector<NDArray>& in_data,
-                            const std::vector<OpReqType>& req,
-                            const std::vector<NDArray>& out_data);
-void DNNLConvolutionBackward(const nnvm::NodeAttrs& attrs,
-                             const OpContext& ctx,
-                             const std::vector<NDArray>& inputs,
-                             const std::vector<OpReqType>& req,
-                             const std::vector<NDArray>& outputs);
-
-/* For deconvolution */
-void DNNLDeconvolutionForward(const nnvm::NodeAttrs& attrs,
-                              const OpContext& ctx,
-                              const std::vector<NDArray>& in_data,
-                              const std::vector<OpReqType>& req,
-                              const std::vector<NDArray>& out_data);
-void DNNLDeconvolutionBackward(const nnvm::NodeAttrs& attrs,
-                               const OpContext& ctx,
-                               const std::vector<NDArray>& inputs,
-                               const std::vector<OpReqType>& req,
-                               const std::vector<NDArray>& outputs);
-
-/* For activation */
-void DNNLActivationForward(const nnvm::NodeAttrs& attrs,
-                           const OpContext& ctx,
-                           const NDArray& in_data,
-                           const OpReqType& req,
-                           const NDArray& out_data);
-void DNNLActivationBackward(const nnvm::NodeAttrs& attrs,
-                            const OpContext& ctx,
-                            const std::vector<NDArray>& inputs,
-                            const std::vector<OpReqType>& req,
-                            const std::vector<NDArray>& outputs);
-
-void DNNLLeakyReluForward(const nnvm::NodeAttrs& attrs,
-                          const OpContext& ctx,
-                          const NDArray& in_data,
-                          const OpReqType& req,
-                          const NDArray& out_data);
-void DNNLLeakyReluBackward(const nnvm::NodeAttrs& attrs,
-                           const OpContext& ctx,
-                           const std::vector<NDArray>& inputs,
-                           const std::vector<OpReqType>& req,
-                           const std::vector<NDArray>& outputs);
-
-/* For softmax */
-void DNNLSoftmaxForward(const nnvm::NodeAttrs& attrs,
-                        const OpContext& ctx,
-                        const NDArray& in_data,
-                        const OpReqType& req,
-                        const NDArray& out_data);
-void DNNLSoftmaxBackward(const nnvm::NodeAttrs& attrs,
-                         const OpContext& ctx,
-                         const std::vector<NDArray>& in_data,
-                         const std::vector<OpReqType>& req,
-                         const std::vector<NDArray>& out_data);
-
-/* For log_softmax */
-void DNNLLogSoftmaxForward(const nnvm::NodeAttrs& attrs,
-                           const OpContext& ctx,
-                           const NDArray& in_data,
-                           const OpReqType& req,
-                           const NDArray& out_data);
-void DNNLLogSoftmaxBackward(const nnvm::NodeAttrs& attrs,
-                            const OpContext& ctx,
-                            const std::vector<NDArray>& in_data,
-                            const std::vector<OpReqType>& req,
-                            const std::vector<NDArray>& out_data);
-
-void DNNLMaskedSoftmaxForward(const nnvm::NodeAttrs& attrs,
-                              const OpContext& ctx,
-                              const std::vector<NDArray>& inputs,
-                              const std::vector<OpReqType>& req,
-                              const std::vector<NDArray>& outputs);
-
-/* For softmax_output */
-void DNNLSoftmaxOutputForward(const nnvm::NodeAttrs& attrs,
-                              const OpContext& ctx,
-                              const std::vector<NDArray>& in_data,
-                              const std::vector<OpReqType>& req,
-                              const std::vector<NDArray>& out_data);
-
-void DNNLSplitForward(const nnvm::NodeAttrs& attrs,
-                      const OpContext& ctx,
-                      const std::vector<NDArray>& inputs,
-                      const std::vector<OpReqType>& req,
-                      const std::vector<NDArray>& outputs);
-
-/* For sum */
-void DNNLSumForward(const nnvm::NodeAttrs& attrs,
-                    const OpContext& ctx,
-                    const std::vector<NDArray>& inputs,
-                    const std::vector<OpReqType>& req,
-                    const std::vector<NDArray>& outputs);
-
-/* For copy */
-void DNNLCopy(const nnvm::NodeAttrs& attrs,
-              const OpContext& ctx,
-              const NDArray& in_data,
-              const OpReqType& req,
-              const NDArray& out_data);
-
-/* For concat */
-void DNNLConcatForward(const nnvm::NodeAttrs& attrs,
-                       const OpContext& ctx,
-                       const std::vector<NDArray>& in_data,
-                       const std::vector<OpReqType>& req,
-                       const std::vector<NDArray>& out_data);
-void DNNLConcatBackward(const nnvm::NodeAttrs& attrs,
-                        const OpContext& ctx,
-                        const std::vector<NDArray>& inputs,
-                        const std::vector<OpReqType>& req,
-                        const std::vector<NDArray>& outputs);
-
-/* For batch dot */
-template <bool subgraph>
-void DNNLBatchDotForward(const nnvm::NodeAttrs& attrs,
-                         const OpContext& ctx,
-                         const std::vector<NDArray>& inputs,
-                         const std::vector<OpReqType>& req,
-                         const std::vector<NDArray>& outputs);
-
-/* For layer normalization */
-void DNNLLayerNormForward(const nnvm::NodeAttrs& attrs,
-                          const OpContext& ctx,
-                          const std::vector<NDArray>& inputs,
-                          const std::vector<OpReqType>& req,
-                          const std::vector<NDArray>& outputs);
-void DNNLLayerNormBackward(const nnvm::NodeAttrs& attrs,
-                           const OpContext& ctx,
-                           const std::vector<NDArray>& inputs,
-                           const std::vector<OpReqType>& req,
-                           const std::vector<NDArray>& outputs);
-
-void DNNLSum(const dnnl::memory& arr1, const dnnl::memory& arr2, const dnnl::memory& out);
-
-void DNNLStackForward(const nnvm::NodeAttrs& attrs,
-                      const OpContext& ctx,
-                      const std::vector<NDArray>& in_data,
-                      const std::vector<OpReqType>& req,
-                      const std::vector<NDArray>& out_data);
-
-template <class ParamType>
-void DNNLTransposeForward(const nnvm::NodeAttrs& attrs,
-                          const OpContext& ctx,
-                          const NDArray& data,
-                          const OpReqType& req,
-                          const NDArray& output);
-
-void DNNLReshapeForward(const nnvm::NodeAttrs& attrs,
-                        const OpContext& ctx,
-                        const NDArray& input,
-                        const OpReqType& req,
-                        const NDArray& output);
-
-void DNNLWhereForward(const nnvm::NodeAttrs& attrs,
-                      const OpContext& ctx,
-                      const std::vector<NDArray>& inputs,
-                      const std::vector<OpReqType>& req,
-                      const std::vector<NDArray>& outputs);
-
-void DNNLPowerForward(const nnvm::NodeAttrs& attrs,
-                      const OpContext& ctx,
-                      const NDArray& input,
-                      const OpReqType& req,
-                      const NDArray& output);
-
-}  // namespace op
-}  // namespace mxnet
-
-#endif  // MXNET_USE_ONEDNN == 1
-#endif  // MXNET_OPERATOR_NN_DNNL_DNNL_OPS_INL_H_
diff --git a/src/operator/nn/dnnl/dnnl_power_scalar-inl.h b/src/operator/nn/dnnl/dnnl_power_scalar-inl.h
index eddbffffff..5ece7ef832 100644
--- a/src/operator/nn/dnnl/dnnl_power_scalar-inl.h
+++ b/src/operator/nn/dnnl/dnnl_power_scalar-inl.h
@@ -28,7 +28,6 @@
 #if MXNET_USE_ONEDNN == 1
 
 #include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
 #include "operator/tensor/elemwise_binary_scalar_op.h"
 
 namespace mxnet {
@@ -54,6 +53,12 @@ class DNNLPowerFwd {
 
 typedef OpSignature DNNLPowerSignature;
 
+void DNNLPowerForward(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const NDArray& input,
+                      const OpReqType& req,
+                      const NDArray& output);
+
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/nn/dnnl/dnnl_reduce-inl.h b/src/operator/nn/dnnl/dnnl_reduce-inl.h
index 9e3f0bd2a5..f33206afb9 100644
--- a/src/operator/nn/dnnl/dnnl_reduce-inl.h
+++ b/src/operator/nn/dnnl/dnnl_reduce-inl.h
@@ -28,7 +28,6 @@
 #include <vector>
 
 #include "./dnnl_base-inl.h"
-#include "./dnnl_ops-inl.h"
 
 namespace mxnet {
 namespace op {
diff --git a/src/operator/nn/dnnl/dnnl_reshape-inl.h b/src/operator/nn/dnnl/dnnl_reshape-inl.h
index 04e1fecb01..01f26994f6 100644
--- a/src/operator/nn/dnnl/dnnl_reshape-inl.h
+++ b/src/operator/nn/dnnl/dnnl_reshape-inl.h
@@ -53,6 +53,13 @@ typedef OpSignature DNNLReshapeSignature;
 DNNLReshapeFwd& GetReshapeForward(const OpReqType& req,
                                   const NDArray& input,
                                   const NDArray& output);
+
+void DNNLReshapeForward(const nnvm::NodeAttrs& attrs,
+                        const OpContext& ctx,
+                        const NDArray& input,
+                        const OpReqType& req,
+                        const NDArray& output);
+
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/nn/dnnl/dnnl_reshape.cc b/src/operator/nn/dnnl/dnnl_reshape.cc
index ec96a5eb74..d1270a3a8c 100644
--- a/src/operator/nn/dnnl/dnnl_reshape.cc
+++ b/src/operator/nn/dnnl/dnnl_reshape.cc
@@ -26,7 +26,6 @@
 #if MXNET_USE_ONEDNN == 1
 #include "operator/tensor/elemwise_unary_op.h"
 #include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
 #include "dnnl_reshape-inl.h"
 
 namespace mxnet {
diff --git a/src/operator/nn/dnnl/dnnl_slice.cc b/src/operator/nn/dnnl/dnnl_slice.cc
index 3008133425..102bf684fb 100644
--- a/src/operator/nn/dnnl/dnnl_slice.cc
+++ b/src/operator/nn/dnnl/dnnl_slice.cc
@@ -26,7 +26,6 @@
 #if MXNET_USE_ONEDNN == 1
 
 #include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
 #include "dnnl_slice-inl.h"
 
 namespace mxnet {
diff --git a/src/operator/nn/dnnl/dnnl_softmax-inl.h b/src/operator/nn/dnnl/dnnl_softmax-inl.h
index 42558c6bbf..5b201ad13a 100644
--- a/src/operator/nn/dnnl/dnnl_softmax-inl.h
+++ b/src/operator/nn/dnnl/dnnl_softmax-inl.h
@@ -37,7 +37,6 @@
 #include <vector>
 
 #include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
 
 #include "operator/nn/softmax-inl.h"
 
@@ -106,6 +105,36 @@ class DNNLSoftmaxBwd {
   std::shared_ptr<linear_t> temperature_fwd;
 };
 
+void DNNLSoftmaxForward(const nnvm::NodeAttrs& attrs,
+                        const OpContext& ctx,
+                        const NDArray& in_data,
+                        const OpReqType& req,
+                        const NDArray& out_data);
+
+void DNNLSoftmaxBackward(const nnvm::NodeAttrs& attrs,
+                         const OpContext& ctx,
+                         const std::vector<NDArray>& in_data,
+                         const std::vector<OpReqType>& req,
+                         const std::vector<NDArray>& out_data);
+
+void DNNLLogSoftmaxForward(const nnvm::NodeAttrs& attrs,
+                           const OpContext& ctx,
+                           const NDArray& in_data,
+                           const OpReqType& req,
+                           const NDArray& out_data);
+
+void DNNLLogSoftmaxBackward(const nnvm::NodeAttrs& attrs,
+                            const OpContext& ctx,
+                            const std::vector<NDArray>& in_data,
+                            const std::vector<OpReqType>& req,
+                            const std::vector<NDArray>& out_data);
+
+void DNNLMaskedSoftmaxForward(const nnvm::NodeAttrs& attrs,
+                              const OpContext& ctx,
+                              const std::vector<NDArray>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<NDArray>& outputs);
+
 }  // namespace op
 }  // namespace mxnet
 #endif
diff --git a/src/operator/nn/dnnl/dnnl_reshape-inl.h b/src/operator/nn/dnnl/dnnl_softmax_output-inl.h
similarity index 50%
copy from src/operator/nn/dnnl/dnnl_reshape-inl.h
copy to src/operator/nn/dnnl/dnnl_softmax_output-inl.h
index 04e1fecb01..d0fb1d820d 100644
--- a/src/operator/nn/dnnl/dnnl_reshape-inl.h
+++ b/src/operator/nn/dnnl/dnnl_softmax_output-inl.h
@@ -18,43 +18,30 @@
  */
 
 /*!
- * \file dnnl_reshape-inl.h
- * \brief Function definition of dnnl reshape operator
+ * \file dnnl_softmax_output-inl.h
+ * \brief
+ * \author Wolinski Piotr piotr.wolinski@intel.com
  */
 
-#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_RESHAPE_INL_H_
-#define MXNET_OPERATOR_NN_DNNL_DNNL_RESHAPE_INL_H_
+#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_SOFTMAX_OUTPUT_INL_H_
+#define MXNET_OPERATOR_NN_DNNL_DNNL_SOFTMAX_OUTPUT_INL_H_
 
 #if MXNET_USE_ONEDNN == 1
-#include <vector>
 
-#include "operator/tensor/matrix_op-inl.h"
-#include "dnnl_base-inl.h"
+#include <dnnl.hpp>
+#include <vector>
 
 namespace mxnet {
 namespace op {
 
-class DNNLReshapeFwd {
- protected:
-  std::shared_ptr<dnnl::memory> out_;
-  std::shared_ptr<dnnl::memory> temp_;
-  std::vector<dnnl::primitive> prims_;
-
- public:
-  DNNLReshapeFwd(const OpReqType& req, const NDArray& input, const NDArray& output);
-  int GetWorkspaceSize();
-  void Execute(const NDArray& input,
-               const NDArray& output,
-               const OpReqType& req,
-               void* workspace = nullptr);
-};
-
-typedef OpSignature DNNLReshapeSignature;
-DNNLReshapeFwd& GetReshapeForward(const OpReqType& req,
-                                  const NDArray& input,
-                                  const NDArray& output);
+void DNNLSoftmaxOutputForward(const nnvm::NodeAttrs& attrs,
+                              const OpContext& ctx,
+                              const std::vector<NDArray>& in_data,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<NDArray>& out_data);
+
 }  // namespace op
 }  // namespace mxnet
 
 #endif  // MXNET_USE_ONEDNN == 1
-#endif  // MXNET_OPERATOR_NN_DNNL_DNNL_RESHAPE_INL_H_
+#endif  // MXNET_OPERATOR_NN_DNNL_DNNL_SOFTMAX_OUTPUT_INL_H_
diff --git a/src/operator/nn/dnnl/dnnl_softmax_output.cc b/src/operator/nn/dnnl/dnnl_softmax_output.cc
index 94b0029d9b..ba79effc5b 100644
--- a/src/operator/nn/dnnl/dnnl_softmax_output.cc
+++ b/src/operator/nn/dnnl/dnnl_softmax_output.cc
@@ -27,7 +27,6 @@
 
 #include "operator/softmax_output-inl.h"
 #include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
 
 namespace mxnet {
 namespace op {
diff --git a/src/operator/nn/dnnl/dnnl_split-inl.h b/src/operator/nn/dnnl/dnnl_split-inl.h
index a8cdc4cd93..051bf4793f 100644
--- a/src/operator/nn/dnnl/dnnl_split-inl.h
+++ b/src/operator/nn/dnnl/dnnl_split-inl.h
@@ -28,7 +28,6 @@
 #include <vector>
 
 #include "./dnnl_base-inl.h"
-#include "./dnnl_ops-inl.h"
 
 namespace mxnet {
 namespace op {
@@ -63,6 +62,12 @@ class DNNLSplitFwd {
   dnnl::memory::dims strides;
 };
 
+void DNNLSplitForward(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<NDArray>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<NDArray>& outputs);
+
 }  // namespace op
 }  // namespace mxnet
 #endif
diff --git a/src/operator/nn/dnnl/dnnl_reshape-inl.h b/src/operator/nn/dnnl/dnnl_stack-inl.h
similarity index 50%
copy from src/operator/nn/dnnl/dnnl_reshape-inl.h
copy to src/operator/nn/dnnl/dnnl_stack-inl.h
index 04e1fecb01..ce7522f9c7 100644
--- a/src/operator/nn/dnnl/dnnl_reshape-inl.h
+++ b/src/operator/nn/dnnl/dnnl_stack-inl.h
@@ -18,43 +18,31 @@
  */
 
 /*!
- * \file dnnl_reshape-inl.h
- * \brief Function definition of dnnl reshape operator
+ * \file dnnl_stack-inl.h
+ * \brief
+ * \author Wolinski Piotr piotr.wolinski@intel.com
  */
 
-#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_RESHAPE_INL_H_
-#define MXNET_OPERATOR_NN_DNNL_DNNL_RESHAPE_INL_H_
+#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_STACK_INL_H_
+#define MXNET_OPERATOR_NN_DNNL_DNNL_STACK_INL_H_
 
 #if MXNET_USE_ONEDNN == 1
+
 #include <vector>
 
-#include "operator/tensor/matrix_op-inl.h"
-#include "dnnl_base-inl.h"
+#include <dnnl.hpp>
 
 namespace mxnet {
 namespace op {
 
-class DNNLReshapeFwd {
- protected:
-  std::shared_ptr<dnnl::memory> out_;
-  std::shared_ptr<dnnl::memory> temp_;
-  std::vector<dnnl::primitive> prims_;
-
- public:
-  DNNLReshapeFwd(const OpReqType& req, const NDArray& input, const NDArray& output);
-  int GetWorkspaceSize();
-  void Execute(const NDArray& input,
-               const NDArray& output,
-               const OpReqType& req,
-               void* workspace = nullptr);
-};
-
-typedef OpSignature DNNLReshapeSignature;
-DNNLReshapeFwd& GetReshapeForward(const OpReqType& req,
-                                  const NDArray& input,
-                                  const NDArray& output);
+void DNNLStackForward(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<NDArray>& in_data,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<NDArray>& out_data);
+
 }  // namespace op
 }  // namespace mxnet
 
 #endif  // MXNET_USE_ONEDNN == 1
-#endif  // MXNET_OPERATOR_NN_DNNL_DNNL_RESHAPE_INL_H_
+#endif  // MXNET_OPERATOR_NN_DNNL_DNNL_STACK_INL_H_
diff --git a/src/operator/nn/dnnl/dnnl_stack.cc b/src/operator/nn/dnnl/dnnl_stack.cc
index 4cd45e104d..981a4f87c6 100644
--- a/src/operator/nn/dnnl/dnnl_stack.cc
+++ b/src/operator/nn/dnnl/dnnl_stack.cc
@@ -20,10 +20,11 @@
 /*!
  * \file dnnl_stack.cc
  */
+#include <vector>
 
 #include "dnnl_base-inl.h"
 #include "dnnl_concat-inl.h"
-#include "dnnl_ops-inl.h"
+#include "dnnl_stack-inl.h"
 
 #include "operator/tensor/matrix_op-inl.h"
 
@@ -103,7 +104,7 @@ void DNNLStackForward(const nnvm::NodeAttrs& attrs,
     }
   });
 
-  auto& fwd = GetConcatForward(stacking_dim, in_data, data_md, axis);
+  auto& fwd = DNNLConcatFwd::GetCached(stacking_dim, in_data, data_md, axis);
   mxnet::dnnl_output_t out_mem =
       CreateDNNLMem(out_data[concat_enum::kOut], fwd.fwd_pd.dst_desc(), req[concat_enum::kOut]);
 
diff --git a/src/operator/nn/dnnl/dnnl_sum-inl.h b/src/operator/nn/dnnl/dnnl_sum-inl.h
new file mode 100644
index 0000000000..14b2349b06
--- /dev/null
+++ b/src/operator/nn/dnnl/dnnl_sum-inl.h
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_sum-inl.h
+ * \brief
+ * \author Wolinski Piotr piotr.wolinski@intel.com
+ */
+
+#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_SUM_INL_H_
+#define MXNET_OPERATOR_NN_DNNL_DNNL_SUM_INL_H_
+
+#if MXNET_USE_ONEDNN == 1
+
+#include <vector>
+
+#include <dnnl.hpp>
+#include "operator/nn/dnnl/dnnl_base-inl.h"
+#include "operator/operator_common.h"
+
+namespace mxnet {
+namespace op {
+
+using sum_t    = dnnl::sum;
+using sum_pd_t = dnnl::sum::primitive_desc;
+
+class DNNLSumFwd {
+ public:
+  typedef OpSignature DNNLSumSignature;
+
+  static DNNLSumFwd& GetCached(const std::vector<NDArray>& inputs,
+                               const std::vector<NDArray>& outputs);
+
+  explicit DNNLSumFwd(const std::vector<NDArray>& inputs, const std::vector<NDArray>& outputs);
+
+  void Execute(const OpContext& ctx,
+               const std::vector<NDArray>& inputs,
+               const std::vector<OpReqType>& req,
+               const std::vector<NDArray>& outputs);
+
+ private:
+  std::shared_ptr<sum_t> fwd;
+  std::shared_ptr<sum_pd_t> fwd_pd;
+};
+
+void DNNLSumForward(const nnvm::NodeAttrs& attrs,
+                    const OpContext& ctx,
+                    const std::vector<NDArray>& inputs,
+                    const std::vector<OpReqType>& req,
+                    const std::vector<NDArray>& outputs);
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_ONEDNN == 1
+#endif  // MXNET_OPERATOR_NN_DNNL_DNNL_SUM_INL_H_
diff --git a/src/operator/nn/dnnl/dnnl_sum.cc b/src/operator/nn/dnnl/dnnl_sum.cc
index c2626a26b4..5995778ccd 100644
--- a/src/operator/nn/dnnl/dnnl_sum.cc
+++ b/src/operator/nn/dnnl/dnnl_sum.cc
@@ -23,112 +23,88 @@
  * \author Da Zheng
  */
 #include <iostream>
+#include <vector>
 
 #include "operator/operator_common.h"
 #include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
+#include "dnnl_sum-inl.h"
 
 namespace mxnet {
 namespace op {
 
 #if MXNET_USE_ONEDNN == 1
-void DNNLSum(const dnnl::memory& arr1, const dnnl::memory& arr2, const dnnl::memory& out) {
-  std::vector<dnnl::memory::desc> input_pds(2);
-  std::vector<float> scales(2, 1);
-  input_pds[0] = arr1.get_desc();
-  input_pds[1] = arr2.get_desc();
-  CHECK(input_pds[0] == input_pds[0]);
-  const dnnl::memory* in_mem1 = &arr1;
-  const dnnl::memory* in_mem2 = &arr2;
-  auto output_pd              = out.get_desc();
-  if (input_pds[0] != output_pd) {
-    auto tmp_memory1 = TmpMemMgr::Get()->Alloc(output_pd);
-    auto tmp_memory2 = TmpMemMgr::Get()->Alloc(output_pd);
-    DNNLMemoryCopy(arr1, tmp_memory1);
-    DNNLMemoryCopy(arr2, tmp_memory2);
-    input_pds[0] = tmp_memory1->get_desc();
-    input_pds[1] = tmp_memory2->get_desc();
-    in_mem1      = tmp_memory1;
-    in_mem2      = tmp_memory2;
-  }
-  dnnl::sum::primitive_desc sum_pd(output_pd, scales, input_pds, CpuEngine::Get()->get_engine());
-  dnnl_args_map_t args = {
-      {DNNL_ARG_MULTIPLE_SRC, *in_mem1},
-      {DNNL_ARG_MULTIPLE_SRC + 1, *in_mem2},
-      {DNNL_ARG_DST, out},
-  };
-  DNNLStream::Get()->RegisterPrimArgs(dnnl::sum(sum_pd), args);
-}
-
-class DNNLSumFwd {
- public:
-  dnnl::sum::primitive_desc fwd_pd;
-
-  DNNLSumFwd(const std::vector<float>& scales, const std::vector<dnnl::memory::desc>& data_md)
-      : fwd_pd(scales, data_md, CpuEngine::Get()->get_engine()) {
-    fwd_ = std::make_shared<dnnl::sum>(fwd_pd);
-  }
 
-  const dnnl::sum& GetFwd() const {
-    return *fwd_;
-  }
-
- private:
-  std::shared_ptr<dnnl::sum> fwd_;
-};
-
-static DNNLSumFwd& GetSumForward(const std::vector<float>& scales,
-                                 const std::vector<NDArray>& in_data,
-                                 const std::vector<dnnl::memory::desc>& data_md) {
+DNNLSumFwd& DNNLSumFwd::GetCached(const std::vector<NDArray>& inputs,
+                                  const std::vector<NDArray>& outputs) {
 #if DMLC_CXX11_THREAD_LOCAL
-  static thread_local std::unordered_map<OpSignature, DNNLSumFwd, OpHash> fwds;
+  static thread_local std::unordered_map<DNNLSumSignature, DNNLSumFwd, OpHash> fwds;
 #else
-  static MX_THREAD_LOCAL std::unordered_map<OpSignature, DNNLSumFwd, OpHash> fwds;
+  static MX_THREAD_LOCAL std::unordered_map<DNNLSumSignature, DNNLSumFwd, OpHash> fwds;
 #endif
-  OpSignature key;
-  key.AddSign(in_data);
+  DNNLSumSignature key;
+  key.AddSign(inputs);
 
   auto it = fwds.find(key);
   if (it == fwds.end()) {
-    DNNLSumFwd fwd(scales, data_md);
+    const DNNLSumFwd fwd(inputs, outputs);
     it = AddToCache(&fwds, key, fwd);
   }
   return it->second;
 }
 
-void DNNLSumForward(const nnvm::NodeAttrs& attrs,
-                    const OpContext& ctx,
-                    const std::vector<NDArray>& inputs,
-                    const std::vector<OpReqType>& req,
-                    const std::vector<NDArray>& outputs) {
-  TmpMemMgr::Get()->Init(ctx.requested[0]);
+DNNLSumFwd::DNNLSumFwd(const std::vector<NDArray>& inputs, const std::vector<NDArray>& outputs) {
   const int num_inputs    = inputs.size();
-  const NDArray& out_data = outputs[0];
+
   std::vector<dnnl::memory::desc> data_md;
-  std::vector<const dnnl::memory*> data_mem;
+
   std::vector<float> scales(num_inputs, 1);
 
   data_md.reserve(num_inputs);
-  data_mem.reserve(num_inputs);
 
   for (int i = 0; i < num_inputs; ++i) {
     const dnnl::memory* in_mem = inputs[i].GetDNNLData();
     dnnl::memory::desc tmp_md  = in_mem->get_desc();
     data_md.push_back(tmp_md);
+  }
+
+  fwd_pd = std::make_shared<sum_pd_t>(scales, data_md, CpuEngine::Get()->get_engine());
+  fwd    = std::make_shared<sum_t>(*fwd_pd);
+}
+
+void DNNLSumFwd::Execute(const OpContext& ctx,
+                         const std::vector<NDArray>& inputs,
+                         const std::vector<OpReqType>& req,
+                         const std::vector<NDArray>& outputs) {
+  const NDArray& out_data = outputs[0];
+  const int num_inputs    = inputs.size();
+  std::vector<const dnnl::memory*> data_mem;
+
+  data_mem.reserve(num_inputs);
+
+  for (int i = 0; i < num_inputs; ++i) {
+    const dnnl::memory* in_mem = inputs[i].GetDNNLData();
     data_mem.push_back(in_mem);
   }
 
-  DNNLSumFwd& fwd              = GetSumForward(scales, inputs, data_md);
-  mxnet::dnnl_output_t out_mem = CreateDNNLMem(out_data, fwd.fwd_pd.dst_desc(), req[0], &inputs[0]);
+  mxnet::dnnl_output_t out_mem = CreateDNNLMem(out_data, fwd_pd->dst_desc(), req[0], &inputs[0]);
   dnnl_args_map_t net_args;
   net_args.insert({DNNL_ARG_DST, *out_mem.second});
   for (int i = 0; i < num_inputs; ++i) {
     net_args.insert({DNNL_ARG_MULTIPLE_SRC + i, *data_mem[i]});
   }
-  DNNLStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
+  DNNLStream::Get()->RegisterPrimArgs(*fwd, net_args);
   CommitOutput(out_data, out_mem);
   DNNLStream::Get()->Submit();
 }
+
+void DNNLSumForward(const nnvm::NodeAttrs& attrs,
+                    const OpContext& ctx,
+                    const std::vector<NDArray>& inputs,
+                    const std::vector<OpReqType>& req,
+                    const std::vector<NDArray>& outputs) {
+  DNNLSumFwd& fwd = DNNLSumFwd::GetCached(inputs, outputs);
+  fwd.Execute(ctx, inputs, req, outputs);
+}
 #endif
 
 }  // namespace op
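
[editor's note] The refactored DNNLSumFwd above is a thin, signature-keyed cache around the plain
oneDNN sum primitive. The standalone sketch below shows the equivalent raw oneDNN v2 C++ API calls
that the class builds and registers through DNNLStream; it is illustrative only, and the shapes,
buffer contents, and main() scaffolding are assumptions for the example, not part of this patch.

    // Minimal sketch of the oneDNN sum primitive that DNNLSumFwd wraps (oneDNN v2 API).
    #include <vector>
    #include <dnnl.hpp>

    int main() {
      dnnl::engine eng(dnnl::engine::kind::cpu, 0);
      dnnl::stream strm(eng);

      // Two 2x3 f32 inputs and one output, all in plain NC layout (example shapes).
      dnnl::memory::desc md({2, 3}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc);
      std::vector<float> a(6, 1.f), b(6, 2.f), c(6, 0.f);
      dnnl::memory src0(md, eng, a.data());
      dnnl::memory src1(md, eng, b.data());
      dnnl::memory dst(md, eng, c.data());

      // Scales of 1.0 make this a plain elementwise addition: dst = src0 + src1.
      std::vector<float> scales(2, 1.f);
      dnnl::sum::primitive_desc sum_pd(scales, {md, md}, eng);
      dnnl::sum(sum_pd).execute(strm, {{DNNL_ARG_MULTIPLE_SRC, src0},
                                       {DNNL_ARG_MULTIPLE_SRC + 1, src1},
                                       {DNNL_ARG_DST, dst}});
      strm.wait();  // every element of c is now 3.f
      return 0;
    }

Caching the primitive_desc/primitive pair in a thread-local map keyed by the input signature, as
DNNLSumFwd::GetCached does, avoids rebuilding them on every forward call, which is the point of
moving this logic out of the free GetSumForward helper and into the class.
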
diff --git a/src/operator/nn/dnnl/dnnl_transpose-inl.h b/src/operator/nn/dnnl/dnnl_transpose-inl.h
index a829dc2315..aa6e071da7 100644
--- a/src/operator/nn/dnnl/dnnl_transpose-inl.h
+++ b/src/operator/nn/dnnl/dnnl_transpose-inl.h
@@ -27,7 +27,6 @@
 #if MXNET_USE_ONEDNN == 1
 
 #include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
 
 #include "operator/numpy/np_matrix_op-inl.h"
 
diff --git a/src/operator/nn/dnnl/dnnl_where-inl.h b/src/operator/nn/dnnl/dnnl_where-inl.h
index bfda684668..cd3c643fe7 100644
--- a/src/operator/nn/dnnl/dnnl_where-inl.h
+++ b/src/operator/nn/dnnl/dnnl_where-inl.h
@@ -29,7 +29,6 @@
 #include <unordered_map>
 #include <vector>
 #include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
 
 namespace mxnet {
 namespace op {
@@ -67,6 +66,12 @@ class DNNLWhereFwd {
 
 bool SupportDNNLWhere(const std::vector<NDArray>& inputs);
 
+void DNNLWhereForward(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<NDArray>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<NDArray>& outputs);
+
 }  // namespace op
 }  // namespace mxnet
 #endif
diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc
index f5d6c2c966..61e7ca5ea9 100644
--- a/src/operator/nn/fully_connected.cc
+++ b/src/operator/nn/fully_connected.cc
@@ -22,8 +22,11 @@
  * \brief fully connect operator
  */
 #include "./dnnl/dnnl_base-inl.h"
-#include "./dnnl/dnnl_ops-inl.h"
 #include "./fully_connected-inl.h"
+#if MXNET_USE_ONEDNN == 1
+#include "operator/nn/dnnl/dnnl_base-inl.h"
+#include "operator/nn/dnnl/dnnl_fully_connected-inl.h"
+#endif  // MXNET_USE_ONEDNN == 1
 
 namespace mxnet {
 namespace op {
diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc
index f0b989fda1..ae8c429ec5 100644
--- a/src/operator/nn/layer_norm.cc
+++ b/src/operator/nn/layer_norm.cc
@@ -27,8 +27,8 @@
 #include "../elemwise_op_common.h"
 #include "layer_norm_cpu.h"
 #if MXNET_USE_ONEDNN == 1
-#include "./dnnl/dnnl_base-inl.h"
-#include "./dnnl/dnnl_ops-inl.h"
+#include "operator/nn/dnnl/dnnl_base-inl.h"
+#include "operator/nn/dnnl/dnnl_layer_norm-inl.h"
 #endif  // MXNET_USE_ONEDNN
 
 #if MSHADOW_USE_MKL == 1
diff --git a/src/operator/nn/log_softmax.cc b/src/operator/nn/log_softmax.cc
index f56e7acda7..e63f09c71f 100644
--- a/src/operator/nn/log_softmax.cc
+++ b/src/operator/nn/log_softmax.cc
@@ -26,8 +26,8 @@
 #include "../tensor/elemwise_binary_op.h"
 #include "../operator_common.h"
 #if MXNET_USE_ONEDNN == 1
-#include "dnnl/dnnl_base-inl.h"
-#include "dnnl/dnnl_ops-inl.h"
+#include "operator/nn/dnnl/dnnl_base-inl.h"
+#include "operator/nn/dnnl/dnnl_softmax-inl.h"
 #endif
 
 namespace mxnet {
diff --git a/src/operator/nn/masked_softmax.cc b/src/operator/nn/masked_softmax.cc
index 386b53e8d9..ee8c902833 100644
--- a/src/operator/nn/masked_softmax.cc
+++ b/src/operator/nn/masked_softmax.cc
@@ -27,7 +27,7 @@
 #include "operator/operator_common.h"
 #if MXNET_USE_ONEDNN == 1
 #include "operator/nn/dnnl/dnnl_base-inl.h"
-#include "operator/nn/dnnl/dnnl_ops-inl.h"
+#include "operator/nn/dnnl/dnnl_softmax-inl.h"
 #endif
 
 namespace mxnet {
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index 004e5a2bb4..0fb4e338cb 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -26,8 +26,8 @@
 #include "../tensor/elemwise_binary_op.h"
 #include "../operator_common.h"
 #if MXNET_USE_ONEDNN == 1
-#include "dnnl/dnnl_base-inl.h"
-#include "dnnl/dnnl_ops-inl.h"
+#include "operator/nn/dnnl/dnnl_base-inl.h"
+#include "operator/nn/dnnl/dnnl_softmax-inl.h"
 #endif
 
 namespace mxnet {
diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc
index 8b556b5ef3..946527562a 100644
--- a/src/operator/numpy/np_matrix_op.cc
+++ b/src/operator/numpy/np_matrix_op.cc
@@ -27,9 +27,8 @@
 #include "./np_matrix_op-inl.h"
 #include "../nn/concat-inl.h"
 #if MXNET_USE_ONEDNN == 1
-#include "../nn/dnnl/dnnl_ops-inl.h"
-#include "../nn/dnnl/dnnl_base-inl.h"
-#include "../nn/dnnl/dnnl_transpose-inl.h"
+#include "operator/nn/dnnl/dnnl_base-inl.h"
+#include "operator/nn/dnnl/dnnl_transpose-inl.h"
 #endif
 namespace mxnet {
 namespace op {
diff --git a/src/operator/quantization/dnnl/dnnl_quantized_act.cc b/src/operator/quantization/dnnl/dnnl_quantized_act.cc
index 6cf31a0ee2..ec780b91d5 100644
--- a/src/operator/quantization/dnnl/dnnl_quantized_act.cc
+++ b/src/operator/quantization/dnnl/dnnl_quantized_act.cc
@@ -23,7 +23,8 @@
  */
 #if MXNET_USE_ONEDNN == 1
 
-#include "operator/nn/dnnl/dnnl_ops-inl.h"
+#include "operator/nn/activation-inl.h"
+#include "operator/nn/dnnl/dnnl_act-inl.h"
 #include "operator/quantization/quantization_utils.h"
 
 namespace mxnet {
diff --git a/src/operator/quantization/dnnl/dnnl_quantized_concat.cc b/src/operator/quantization/dnnl/dnnl_quantized_concat.cc
index a6f9e85427..ee00e6b61d 100644
--- a/src/operator/quantization/dnnl/dnnl_quantized_concat.cc
+++ b/src/operator/quantization/dnnl/dnnl_quantized_concat.cc
@@ -97,7 +97,7 @@ static void DNNLQuantizedConcatForward(const nnvm::NodeAttrs& attrs,
   }
   int param_dim                = param_.dim.has_value() ? param_.dim.value() : 0;
   param_dim                    = CheckAxis(param_dim, in_data[concat_enum::kData0].shape().ndim());
-  DNNLConcatFwd& fwd           = GetConcatForward(param_dim, in_data, data_md);
+  DNNLConcatFwd& fwd           = DNNLConcatFwd::GetCached(param_dim, in_data, data_md);
   mxnet::dnnl_output_t out_mem = CreateDNNLMem(
       out_data[quantized_concat_enum::kOut], fwd.fwd_pd.dst_desc(), req[concat_enum::kOut]);
   dnnl_args_map_t net_args;
diff --git a/src/operator/quantization/dnnl/dnnl_quantized_elemwise_add.cc b/src/operator/quantization/dnnl/dnnl_quantized_elemwise_add.cc
index 518dc0748e..82e66a1ed2 100644
--- a/src/operator/quantization/dnnl/dnnl_quantized_elemwise_add.cc
+++ b/src/operator/quantization/dnnl/dnnl_quantized_elemwise_add.cc
@@ -24,7 +24,6 @@
 
 #if MXNET_USE_ONEDNN == 1
 #include "operator/nn/dnnl/dnnl_base-inl.h"
-#include "operator/nn/dnnl/dnnl_ops-inl.h"
 #include "operator/quantization/quantization_utils.h"
 #include "operator/quantization/quantized_elemwise_add-inl.h"
 
diff --git a/src/operator/quantization/dnnl/dnnl_quantized_flatten.cc b/src/operator/quantization/dnnl/dnnl_quantized_flatten.cc
index a605a16813..0b14a96777 100644
--- a/src/operator/quantization/dnnl/dnnl_quantized_flatten.cc
+++ b/src/operator/quantization/dnnl/dnnl_quantized_flatten.cc
@@ -23,7 +23,7 @@
  */
 
 #if MXNET_USE_ONEDNN == 1
-#include "operator/nn/dnnl/dnnl_ops-inl.h"
+#include "operator/nn/dnnl/dnnl_reshape-inl.h"
 #include "operator/quantization/quantization_utils.h"
 
 namespace mxnet {
diff --git a/src/operator/quantization/dnnl/dnnl_quantized_reshape.cc b/src/operator/quantization/dnnl/dnnl_quantized_reshape.cc
index 0d468fde7a..344bda3f5e 100644
--- a/src/operator/quantization/dnnl/dnnl_quantized_reshape.cc
+++ b/src/operator/quantization/dnnl/dnnl_quantized_reshape.cc
@@ -24,7 +24,7 @@
 
 #if MXNET_USE_ONEDNN == 1
 #include "operator/quantization/quantized_reshape-inl.h"
-#include "operator/nn/dnnl/dnnl_ops-inl.h"
+#include "operator/nn/dnnl/dnnl_reshape-inl.h"
 
 namespace mxnet {
 namespace op {
diff --git a/src/operator/quantization/quantized_conv.cc b/src/operator/quantization/quantized_conv.cc
index 95fbd3bba2..f42a550573 100644
--- a/src/operator/quantization/quantized_conv.cc
+++ b/src/operator/quantization/quantized_conv.cc
@@ -23,9 +23,6 @@
  * \author Ziheng Jiang, Jun Wu
  */
 #include "../nn/convolution-inl.h"
-#if MXNET_USE_ONEDNN == 1
-#include "../nn/dnnl/dnnl_ops-inl.h"
-#endif
 
 namespace mxnet {
 namespace op {
diff --git a/src/operator/softmax_output.cc b/src/operator/softmax_output.cc
index 401968f2d2..52a9888ea8 100644
--- a/src/operator/softmax_output.cc
+++ b/src/operator/softmax_output.cc
@@ -24,8 +24,8 @@
  */
 #include "./softmax_output-inl.h"
 #if MXNET_USE_ONEDNN == 1
-#include "./nn/dnnl/dnnl_base-inl.h"
-#include "./nn/dnnl/dnnl_ops-inl.h"
+#include "operator/nn/dnnl/dnnl_base-inl.h"
+#include "operator/nn/dnnl/dnnl_softmax_output-inl.h"
 #endif
 namespace mxnet {
 namespace op {
diff --git a/src/operator/subgraph/dnnl/dnnl_batch_dot.cc b/src/operator/subgraph/dnnl/dnnl_batch_dot.cc
index c9beffc90c..6905b118ba 100644
--- a/src/operator/subgraph/dnnl/dnnl_batch_dot.cc
+++ b/src/operator/subgraph/dnnl/dnnl_batch_dot.cc
@@ -30,7 +30,6 @@
 
 #include "operator/nn/dnnl/dnnl_base-inl.h"
 #include "operator/nn/dnnl/dnnl_batch_dot-inl.h"
-#include "operator/nn/dnnl/dnnl_ops-inl.h"
 #include "operator/quantization/quantization_utils.h"
 #include "operator/tensor/matrix_op-inl.h"
 #include "operator/subgraph/common.h"
diff --git a/src/operator/subgraph/dnnl/dnnl_conv.cc b/src/operator/subgraph/dnnl/dnnl_conv.cc
index 262746068f..5b2c9ad3e0 100644
--- a/src/operator/subgraph/dnnl/dnnl_conv.cc
+++ b/src/operator/subgraph/dnnl/dnnl_conv.cc
@@ -25,7 +25,6 @@
 
 #include "operator/nn/dnnl/dnnl_act-inl.h"
 #include "operator/nn/dnnl/dnnl_base-inl.h"
-#include "operator/nn/dnnl/dnnl_ops-inl.h"
 #include "operator/quantization/quantization_utils.h"
 #include "operator/tensor/matrix_op-inl.h"
 #include "operator/subgraph/common.h"
diff --git a/src/operator/subgraph/dnnl/dnnl_conv_property.h b/src/operator/subgraph/dnnl/dnnl_conv_property.h
index c5e027d5b1..3d814c0598 100644
--- a/src/operator/subgraph/dnnl/dnnl_conv_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_conv_property.h
@@ -28,7 +28,6 @@
 #include "operator/leaky_relu-inl.h"
 #include "operator/nn/activation-inl.h"
 #include "operator/nn/convolution-inl.h"
-#include "operator/nn/dnnl/dnnl_ops-inl.h"
 #include "operator/tensor/matrix_op-inl.h"
 #include "operator/subgraph/common.h"
 #include "dnnl_subgraph_base-inl.h"
diff --git a/src/operator/subgraph/dnnl/dnnl_fc.cc b/src/operator/subgraph/dnnl/dnnl_fc.cc
index a5c199f964..24b7ec6883 100644
--- a/src/operator/subgraph/dnnl/dnnl_fc.cc
+++ b/src/operator/subgraph/dnnl/dnnl_fc.cc
@@ -35,7 +35,6 @@
 #include "operator/nn/dnnl/dnnl_act-inl.h"
 #include "operator/nn/dnnl/dnnl_base-inl.h"
 #include "operator/nn/dnnl/dnnl_fully_connected-inl.h"
-#include "operator/nn/dnnl/dnnl_ops-inl.h"
 #include "operator/quantization/quantization_utils.h"
 #include "operator/tensor/matrix_op-inl.h"
 #include "operator/subgraph/common.h"
diff --git a/src/operator/tensor/elemwise_binary_op_basic.cc b/src/operator/tensor/elemwise_binary_op_basic.cc
index e1b881c2b7..ad29575cb3 100644
--- a/src/operator/tensor/elemwise_binary_op_basic.cc
+++ b/src/operator/tensor/elemwise_binary_op_basic.cc
@@ -21,8 +21,9 @@
  * \file elemwise_binary_op_basic.cc
  * \brief CPU Implementation of basic elementwise binary broadcast operators
  */
-#include "../nn/dnnl/dnnl_base-inl.h"
-#include "../nn/dnnl/dnnl_ops-inl.h"
+#include "operator/nn/dnnl/dnnl_base-inl.h"
+#include "operator/nn/dnnl/dnnl_copy-inl.h"
+#include "operator/nn/dnnl/dnnl_sum-inl.h"
 #include "./elemwise_binary_op-inl.h"
 #include "./elemwise_unary_op.h"
 
diff --git a/src/operator/tensor/elemwise_sum.cc b/src/operator/tensor/elemwise_sum.cc
index 67842cd25f..8e6e2ad11c 100644
--- a/src/operator/tensor/elemwise_sum.cc
+++ b/src/operator/tensor/elemwise_sum.cc
@@ -25,8 +25,10 @@
 
 #include "../../common/utils.h"
 #include "../../ndarray/ndarray_function.h"
-#include "../nn/dnnl/dnnl_base-inl.h"
-#include "../nn/dnnl/dnnl_ops-inl.h"
+#if MXNET_USE_ONEDNN == 1
+#include "operator/nn/dnnl/dnnl_base-inl.h"
+#include "operator/nn/dnnl/dnnl_sum-inl.h"
+#endif  // MXNET_USE_ONEDNN == 1
 
 namespace mxnet {
 namespace op {
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index 3cc930b0d8..9c317fd3a7 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -23,8 +23,8 @@
  */
 #include <mxnet/base.h>
 
-#include "../nn/dnnl/dnnl_ops-inl.h"
-#include "./elemwise_binary_op-inl.h"
+#include "elemwise_binary_op-inl.h"
+#include "operator/nn/dnnl/dnnl_copy-inl.h"
 #include "elemwise_unary_op.h"
 
 namespace mxnet {
diff --git a/src/operator/tensor/elemwise_unary_op_logexp.cc b/src/operator/tensor/elemwise_unary_op_logexp.cc
index 65bc767312..e94862d7dc 100644
--- a/src/operator/tensor/elemwise_unary_op_logexp.cc
+++ b/src/operator/tensor/elemwise_unary_op_logexp.cc
@@ -24,7 +24,6 @@
 #include <mxnet/base.h>
 
 #include "../../nnvm/node_op_util.h"
-#include "../nn/dnnl/dnnl_ops-inl.h"
 #include "./elemwise_binary_op-inl.h"
 #include "elemwise_unary_op.h"
 
diff --git a/src/operator/tensor/elemwise_unary_op_pow.cc b/src/operator/tensor/elemwise_unary_op_pow.cc
index b4e35c4c26..51f65553ce 100644
--- a/src/operator/tensor/elemwise_unary_op_pow.cc
+++ b/src/operator/tensor/elemwise_unary_op_pow.cc
@@ -24,7 +24,6 @@
 #include <mxnet/base.h>
 
 #include "../../nnvm/node_op_util.h"
-#include "../nn/dnnl/dnnl_ops-inl.h"
 #include "./elemwise_binary_op-inl.h"
 #include "elemwise_unary_op.h"
 
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index bc97aa4a53..15f131cfcd 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -25,12 +25,12 @@
 #include "./matrix_op-inl.h"
 #include "./elemwise_unary_op.h"
 #if MXNET_USE_ONEDNN == 1
-#include "../nn/dnnl/dnnl_base-inl.h"
-#include "../nn/dnnl/dnnl_ops-inl.h"
-#include "../nn/dnnl/dnnl_reshape-inl.h"
-#include "../nn/dnnl/dnnl_slice-inl.h"
-#include "../nn/dnnl/dnnl_transpose-inl.h"
-#include "../nn/dnnl/dnnl_split-inl.h"
+#include "operator/nn/dnnl/dnnl_base-inl.h"
+#include "operator/nn/dnnl/dnnl_reshape-inl.h"
+#include "operator/nn/dnnl/dnnl_slice-inl.h"
+#include "operator/nn/dnnl/dnnl_transpose-inl.h"
+#include "operator/nn/dnnl/dnnl_split-inl.h"
+#include "operator/nn/dnnl/dnnl_stack-inl.h"
 #endif
 
 namespace mxnet {
diff --git a/tests/cpp/operator/dnnl_operator_test.cc b/tests/cpp/operator/dnnl_operator_test.cc
index e66fc56bab..dd02656897 100644
--- a/tests/cpp/operator/dnnl_operator_test.cc
+++ b/tests/cpp/operator/dnnl_operator_test.cc
@@ -34,7 +34,6 @@
 #include "../../src/operator/nn/convolution-inl.h"
 #include "../../src/operator/nn/deconvolution-inl.h"
 #include "../../src/operator/nn/dnnl/dnnl_base-inl.h"
-#include "../../src/operator/nn/dnnl/dnnl_ops-inl.h"
 #include "../../src/operator/nn/dnnl/dnnl_pooling-inl.h"
 #include "../../src/operator/nn/pooling-inl.h"
 #include "../include/test_dnnl.h"
diff --git a/tests/cpp/operator/dnnl_test.cc b/tests/cpp/operator/dnnl_test.cc
index a420c2a7b9..565b551479 100644
--- a/tests/cpp/operator/dnnl_test.cc
+++ b/tests/cpp/operator/dnnl_test.cc
@@ -32,7 +32,6 @@
 #include <set>
 
 #include "../../src/operator/nn/dnnl/dnnl_base-inl.h"
-#include "../../src/operator/nn/dnnl/dnnl_ops-inl.h"
 #include "../include/test_dnnl.h"
 #include "gtest/gtest.h"
 #include "mxnet/imperative.h"
@@ -189,7 +188,7 @@ TEST(DNNL_NDArray, GetDataReorder) {
   }
 }
 
-TEST(DNNL_BASE, DNNLSum) {
+TEST(DNNL_BASE, DNNLMemorySum) {
   std::vector<NDArrayAttrs> in_arrs   = GetTestInputArrays();
   std::vector<NDArrayAttrs> in_arrs2  = GetTestInputArrays(ArrayTypes::All, true);
   TestArrayShapes tas                 = GetTestArrayShapes();
@@ -211,7 +210,7 @@ TEST(DNNL_BASE, DNNLSum) {
         continue;
       auto out_mem = out_arr.arr.GetDNNLData();
       PrintVerifyMsg(in_arr, in_arr);
-      op::DNNLSum(*in_mem1, *in_mem2, *out_mem);
+      op::DNNLMemorySum(*in_mem1, *in_mem2, *out_mem);
       DNNLStream::Get()->Submit();
       VerifySumResult({&in_arr.arr, &in_arr2.arr}, {&out_arr.arr});
     }
@@ -233,7 +232,7 @@ TEST(DNNL_BASE, DNNLSum) {
     PrintVerifyMsg(orig_arr, in_arr);
     InitDNNLArray(&orig_arr.arr, input_mem->get_desc());
     orig_arr.arr.CopyFrom(*input_mem);
-    op::DNNLSum(*input_mem, *input_mem2, *input_mem);
+    op::DNNLMemorySum(*input_mem, *input_mem2, *input_mem);
     DNNLStream::Get()->Submit();
     VerifySumResult({&orig_arr.arr, &in_arr2.arr}, {&in_arr.arr});
   }
@@ -264,7 +263,7 @@ TEST(DNNL_BASE, CreateDNNLMem) {
       PrintVerifyMsg(in_arr, out_arr);
       auto out_mem      = out_arr.arr.GetDNNLData();
       auto output_mem_t = CreateDNNLMem(out_arr.arr, out_mem->get_desc(), kWriteTo);
-      op::DNNLSum(*in_mem, *in_mem2, *output_mem_t.second);
+      op::DNNLMemorySum(*in_mem, *in_mem2, *output_mem_t.second);
       CommitOutput(out_arr.arr, output_mem_t);
       stream->Submit();
       VerifySumResult({&in_arr.arr, &in_arr2.arr}, {&out_arr.arr});
@@ -289,7 +288,7 @@ TEST(DNNL_BASE, CreateDNNLMem) {
     orig_arr.arr.CopyFrom(*input_mem);
     auto output_mem_t =
         CreateDNNLMem(in_arr.arr, input_mem->get_desc(), kWriteInplace, &in_arr.arr);
-    op::DNNLSum(*input_mem, *input_mem2, *output_mem_t.second);
+    op::DNNLMemorySum(*input_mem, *input_mem2, *output_mem_t.second);
     CommitOutput(in_arr.arr, output_mem_t);
     stream->Submit();
     VerifySumResult({&orig_arr.arr, &in_arr2.arr}, {&in_arr.arr});
@@ -313,7 +312,7 @@ TEST(DNNL_BASE, CreateDNNLMem) {
       PrintVerifyMsg(in_arr, out_arr);
       auto out_mem      = out_arr.arr.GetDNNLData();
       auto output_mem_t = CreateDNNLMem(out_arr.arr, out_mem->get_desc(), kAddTo);
-      op::DNNLSum(*in_mem, *in_mem2, *output_mem_t.second);
+      op::DNNLMemorySum(*in_mem, *in_mem2, *output_mem_t.second);
       CommitOutput(out_arr.arr, output_mem_t);
       stream->Submit();
       VerifyAddRequest(
@@ -338,7 +337,7 @@ TEST(DNNL_BASE, CreateDNNLMem) {
     InitDNNLArray(&orig_arr.arr, input_mem->get_desc());
     orig_arr.arr.CopyFrom(*input_mem);
     auto output_mem_t = CreateDNNLMem(in_arr.arr, input_mem->get_desc(), kNullOp);
-    op::DNNLSum(*input_mem, *input_mem2, *output_mem_t.second);
+    op::DNNLMemorySum(*input_mem, *input_mem2, *output_mem_t.second);
     CommitOutput(in_arr.arr, output_mem_t);
     stream->Submit();
     // original and input should be the same since noop