You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/06/02 06:41:03 UTC
[incubator-mxnet] branch master updated: [master] Remove dnnl_ops-inl.h file (#20997)
This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 102388a055 [master] Remove dnnl_ops-inl.h file (#20997)
102388a055 is described below
commit 102388a0557c530741ed8e9b31296416a1c23925
Author: PiotrWolinski - Intel <pi...@intel.com>
AuthorDate: Thu Jun 2 08:40:48 2022 +0200
[master] Remove dnnl_ops-inl.h file (#20997)
* Removed dnnl_ops-inl.h
* Removed dnnl_ops-inl.h file from tracking
* Formatted new files
* Added missing import
* Applied changes from review
* Linting
* Refactored DNNLSumFwd class
* Linting
* Fixed issue with DNNLSumFwd
* Removed unused variables
* Fixed oneDNN usage in refactored files
* Fix sanity
* Changed GetConcatForward to GetCached and placed it in the DNNLConcatFwd class
* Removed missed comment
* Moved DNNLSum function from dnnl_sum files to dnnl_base and renamed it to DNNLMemorySum
---
src/operator/leaky_relu.cc | 4 +-
src/operator/nn/activation.cc | 4 +-
src/operator/nn/concat-inl.h | 1 -
src/operator/nn/concat.cc | 6 +-
src/operator/nn/convolution-inl.h | 1 -
src/operator/nn/convolution.cc | 4 +-
src/operator/nn/deconvolution-inl.h | 1 -
src/operator/nn/deconvolution.cc | 4 +-
src/operator/nn/dnnl/dnnl_act-inl.h | 25 +++
src/operator/nn/dnnl/dnnl_base-inl.h | 1 +
src/operator/nn/dnnl/dnnl_base.cc | 33 ++-
src/operator/nn/dnnl/dnnl_batch_dot-inl.h | 1 -
src/operator/nn/dnnl/dnnl_batch_norm-inl.h | 1 -
src/operator/nn/dnnl/dnnl_binary-inl.h | 1 -
src/operator/nn/dnnl/dnnl_concat-inl.h | 37 ++--
src/operator/nn/dnnl/dnnl_concat.cc | 25 ++-
src/operator/nn/dnnl/dnnl_convolution-inl.h | 13 +-
src/operator/nn/dnnl/dnnl_convolution.cc | 1 -
.../nn/dnnl/{dnnl_copy.cc => dnnl_copy-inl.h} | 36 +---
src/operator/nn/dnnl/dnnl_copy.cc | 3 +-
src/operator/nn/dnnl/dnnl_deconvolution-inl.h | 54 ++---
src/operator/nn/dnnl/dnnl_deconvolution.cc | 41 ++++
src/operator/nn/dnnl/dnnl_fully_connected-inl.h | 6 +
src/operator/nn/dnnl/dnnl_layer_norm-inl.h | 13 +-
src/operator/nn/dnnl/dnnl_log_softmax.cc | 1 -
src/operator/nn/dnnl/dnnl_masked_softmax-inl.h | 1 -
src/operator/nn/dnnl/dnnl_ops-inl.h | 230 ---------------------
src/operator/nn/dnnl/dnnl_power_scalar-inl.h | 7 +-
src/operator/nn/dnnl/dnnl_reduce-inl.h | 1 -
src/operator/nn/dnnl/dnnl_reshape-inl.h | 7 +
src/operator/nn/dnnl/dnnl_reshape.cc | 1 -
src/operator/nn/dnnl/dnnl_slice.cc | 1 -
src/operator/nn/dnnl/dnnl_softmax-inl.h | 31 ++-
...nnl_reshape-inl.h => dnnl_softmax_output-inl.h} | 41 ++--
src/operator/nn/dnnl/dnnl_softmax_output.cc | 1 -
src/operator/nn/dnnl/dnnl_split-inl.h | 7 +-
.../dnnl/{dnnl_reshape-inl.h => dnnl_stack-inl.h} | 40 ++--
src/operator/nn/dnnl/dnnl_stack.cc | 5 +-
src/operator/nn/dnnl/dnnl_sum-inl.h | 72 +++++++
src/operator/nn/dnnl/dnnl_sum.cc | 106 ++++------
src/operator/nn/dnnl/dnnl_transpose-inl.h | 1 -
src/operator/nn/dnnl/dnnl_where-inl.h | 7 +-
src/operator/nn/fully_connected.cc | 5 +-
src/operator/nn/layer_norm.cc | 4 +-
src/operator/nn/log_softmax.cc | 4 +-
src/operator/nn/masked_softmax.cc | 2 +-
src/operator/nn/softmax.cc | 4 +-
src/operator/numpy/np_matrix_op.cc | 5 +-
.../quantization/dnnl/dnnl_quantized_act.cc | 3 +-
.../quantization/dnnl/dnnl_quantized_concat.cc | 2 +-
.../dnnl/dnnl_quantized_elemwise_add.cc | 1 -
.../quantization/dnnl/dnnl_quantized_flatten.cc | 2 +-
.../quantization/dnnl/dnnl_quantized_reshape.cc | 2 +-
src/operator/quantization/quantized_conv.cc | 3 -
src/operator/softmax_output.cc | 4 +-
src/operator/subgraph/dnnl/dnnl_batch_dot.cc | 1 -
src/operator/subgraph/dnnl/dnnl_conv.cc | 1 -
src/operator/subgraph/dnnl/dnnl_conv_property.h | 1 -
src/operator/subgraph/dnnl/dnnl_fc.cc | 1 -
src/operator/tensor/elemwise_binary_op_basic.cc | 5 +-
src/operator/tensor/elemwise_sum.cc | 6 +-
src/operator/tensor/elemwise_unary_op_basic.cc | 4 +-
src/operator/tensor/elemwise_unary_op_logexp.cc | 1 -
src/operator/tensor/elemwise_unary_op_pow.cc | 1 -
src/operator/tensor/matrix_op.cc | 12 +-
tests/cpp/operator/dnnl_operator_test.cc | 1 -
tests/cpp/operator/dnnl_test.cc | 15 +-
67 files changed, 444 insertions(+), 522 deletions(-)
diff --git a/src/operator/leaky_relu.cc b/src/operator/leaky_relu.cc
index 39aa11dfb9..1376efba52 100644
--- a/src/operator/leaky_relu.cc
+++ b/src/operator/leaky_relu.cc
@@ -26,8 +26,8 @@
#include "./leaky_relu-inl.h"
#include "../common/alm.h"
#if MXNET_USE_ONEDNN == 1
-#include "./nn/dnnl/dnnl_base-inl.h"
-#include "./nn/dnnl/dnnl_ops-inl.h"
+#include "operator/nn/dnnl/dnnl_base-inl.h"
+#include "operator/nn/dnnl/dnnl_act-inl.h"
#endif // MXNET_USE_ONEDNN == 1
#include <nnvm/op_attr_types.h>
diff --git a/src/operator/nn/activation.cc b/src/operator/nn/activation.cc
index a228bf8a76..46ecdb2b69 100644
--- a/src/operator/nn/activation.cc
+++ b/src/operator/nn/activation.cc
@@ -26,8 +26,8 @@
#include "../mshadow_op.h"
#include "../tensor/elemwise_unary_op.h"
#if MXNET_USE_ONEDNN == 1
-#include "./dnnl/dnnl_base-inl.h"
-#include "./dnnl/dnnl_ops-inl.h"
+#include "operator/nn/dnnl/dnnl_act-inl.h"
+#include "operator/nn/dnnl/dnnl_base-inl.h"
#endif // MXNET_USE_ONEDNN == 1
#include "../operator_common.h"
#include "../../common/utils.h"
diff --git a/src/operator/nn/concat-inl.h b/src/operator/nn/concat-inl.h
index 01cde26de2..5886d60386 100644
--- a/src/operator/nn/concat-inl.h
+++ b/src/operator/nn/concat-inl.h
@@ -388,7 +388,6 @@ void ConcatCSRImpl(const nnvm::NodeAttrs& attrs,
});
});
}
-
} // namespace op
} // namespace mxnet
diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc
index 2329892048..70b8aeb9f8 100644
--- a/src/operator/nn/concat.cc
+++ b/src/operator/nn/concat.cc
@@ -25,8 +25,10 @@
#include "../../common/utils.h"
#include "./concat-inl.h"
-#include "./dnnl/dnnl_base-inl.h"
-#include "./dnnl/dnnl_ops-inl.h"
+#if MXNET_USE_ONEDNN == 1
+#include "operator/nn/dnnl/dnnl_base-inl.h"
+#include "operator/nn/dnnl/dnnl_concat-inl.h"
+#endif // MXNET_USE_ONEDNN == 1
namespace mxnet {
namespace op {
diff --git a/src/operator/nn/convolution-inl.h b/src/operator/nn/convolution-inl.h
index 9994c7bed7..a034e44082 100644
--- a/src/operator/nn/convolution-inl.h
+++ b/src/operator/nn/convolution-inl.h
@@ -597,7 +597,6 @@ void ConvolutionGradCompute(const nnvm::NodeAttrs& attrs,
op.Backward(ctx, std::vector<TBlob>{out_grad}, in_data, req, in_grad);
});
}
-
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_NN_CONVOLUTION_INL_H_
diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc
index a39fa3fa45..674ccf6e81 100644
--- a/src/operator/nn/convolution.cc
+++ b/src/operator/nn/convolution.cc
@@ -30,8 +30,8 @@
#include "../operator_common.h"
#include "../../common/alm.h"
#if MXNET_USE_ONEDNN == 1
-#include "./dnnl/dnnl_base-inl.h"
-#include "./dnnl/dnnl_ops-inl.h"
+#include "operator/nn/dnnl/dnnl_base-inl.h"
+#include "operator/nn/dnnl/dnnl_convolution-inl.h"
#endif // MXNET_USE_ONEDNN
namespace mxnet {
diff --git a/src/operator/nn/deconvolution-inl.h b/src/operator/nn/deconvolution-inl.h
index 9c9f2c8138..d5fb132d88 100644
--- a/src/operator/nn/deconvolution-inl.h
+++ b/src/operator/nn/deconvolution-inl.h
@@ -428,7 +428,6 @@ void DeconvolutionGradCompute(const nnvm::NodeAttrs& attrs,
const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);
_DeconvolutionGradCompute<xpu>(param, ctx, inputs, req, outputs);
}
-
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_NN_DECONVOLUTION_INL_H_
diff --git a/src/operator/nn/deconvolution.cc b/src/operator/nn/deconvolution.cc
index 2bef3fc898..e65df297de 100644
--- a/src/operator/nn/deconvolution.cc
+++ b/src/operator/nn/deconvolution.cc
@@ -28,8 +28,8 @@
#include "../../common/alm.h"
#include "../../common/utils.h"
#if MXNET_USE_ONEDNN == 1
-#include "./dnnl/dnnl_base-inl.h"
-#include "./dnnl/dnnl_ops-inl.h"
+#include "operator/nn/dnnl/dnnl_base-inl.h"
+#include "operator/nn/dnnl/dnnl_deconvolution-inl.h"
#endif // MXNET_USE_ONEDNN
namespace mxnet {
diff --git a/src/operator/nn/dnnl/dnnl_act-inl.h b/src/operator/nn/dnnl/dnnl_act-inl.h
index 66f229962f..fcc896414b 100644
--- a/src/operator/nn/dnnl/dnnl_act-inl.h
+++ b/src/operator/nn/dnnl/dnnl_act-inl.h
@@ -95,6 +95,31 @@ class DNNLActBackward {
private:
std::shared_ptr<dnnl::eltwise_backward> bwd_prim_;
};
+
+void DNNLActivationForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const NDArray& in_data,
+ const OpReqType& req,
+ const NDArray& out_data);
+
+void DNNLActivationBackward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs);
+
+void DNNLLeakyReluForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const NDArray& in_data,
+ const OpReqType& req,
+ const NDArray& out_data);
+
+void DNNLLeakyReluBackward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs);
+
} // namespace op
} // namespace mxnet
diff --git a/src/operator/nn/dnnl/dnnl_base-inl.h b/src/operator/nn/dnnl/dnnl_base-inl.h
index c38895bb1e..8161d62c24 100644
--- a/src/operator/nn/dnnl/dnnl_base-inl.h
+++ b/src/operator/nn/dnnl/dnnl_base-inl.h
@@ -211,6 +211,7 @@ bool SupportDNNLStack(const std::vector<NDArray>& inputs);
bool SupportDNNLBinary(const std::vector<NDArray>& inputs);
bool SupportDNNLEltwise(const NDArray& input, const NDArray& output);
bool SupportDNNLPower(const NDArray& input);
+void DNNLMemorySum(const dnnl::memory& arr1, const dnnl::memory& arr2, const dnnl::memory& out);
} // namespace op
static int GetTypeSize(int dtype) {
diff --git a/src/operator/nn/dnnl/dnnl_base.cc b/src/operator/nn/dnnl/dnnl_base.cc
index 216a420137..c7e3a92d21 100644
--- a/src/operator/nn/dnnl/dnnl_base.cc
+++ b/src/operator/nn/dnnl/dnnl_base.cc
@@ -24,7 +24,6 @@
#include "../../../common/exec_utils.h"
#include "operator/operator_common.h"
#include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
namespace mxnet {
@@ -37,6 +36,36 @@ DNNLStream* DNNLStream::Get() {
return &stream;
}
+namespace op {
+void DNNLMemorySum(const dnnl::memory& arr1, const dnnl::memory& arr2, const dnnl::memory& out) {
+ std::vector<dnnl::memory::desc> input_pds(2);
+ std::vector<float> scales(2, 1);
+ input_pds[0] = arr1.get_desc();
+ input_pds[1] = arr2.get_desc();
+ CHECK(input_pds[0] == input_pds[0]);
+ const dnnl::memory* in_mem1 = &arr1;
+ const dnnl::memory* in_mem2 = &arr2;
+ auto output_pd = out.get_desc();
+ if (input_pds[0] != output_pd) {
+ auto tmp_memory1 = TmpMemMgr::Get()->Alloc(output_pd);
+ auto tmp_memory2 = TmpMemMgr::Get()->Alloc(output_pd);
+ DNNLMemoryCopy(arr1, tmp_memory1);
+ DNNLMemoryCopy(arr2, tmp_memory2);
+ input_pds[0] = tmp_memory1->get_desc();
+ input_pds[1] = tmp_memory2->get_desc();
+ in_mem1 = tmp_memory1;
+ in_mem2 = tmp_memory2;
+ }
+ dnnl::sum::primitive_desc sum_pd(output_pd, scales, input_pds, CpuEngine::Get()->get_engine());
+ dnnl_args_map_t args = {
+ {DNNL_ARG_MULTIPLE_SRC, *in_mem1},
+ {DNNL_ARG_MULTIPLE_SRC + 1, *in_mem2},
+ {DNNL_ARG_DST, out},
+ };
+ DNNLStream::Get()->RegisterPrimArgs(dnnl::sum(sum_pd), args);
+}
+} // namespace op
+
void* AlignMem(void* mem, size_t size, size_t alignment, size_t* space) {
if (size > *space)
return nullptr;
@@ -222,7 +251,7 @@ void CommitOutput(const NDArray& arr, const dnnl_output_t& res) {
res_memory = tmp_memory;
mem = arr.GetDNNLData();
}
- op::DNNLSum(*mem, *res_memory, *mem);
+ op::DNNLMemorySum(*mem, *res_memory, *mem);
}
}
diff --git a/src/operator/nn/dnnl/dnnl_batch_dot-inl.h b/src/operator/nn/dnnl/dnnl_batch_dot-inl.h
index b48afd1d36..19233828dc 100644
--- a/src/operator/nn/dnnl/dnnl_batch_dot-inl.h
+++ b/src/operator/nn/dnnl/dnnl_batch_dot-inl.h
@@ -33,7 +33,6 @@
#include "operator/tensor/dot-inl.h"
#include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
namespace mxnet {
namespace op {
diff --git a/src/operator/nn/dnnl/dnnl_batch_norm-inl.h b/src/operator/nn/dnnl/dnnl_batch_norm-inl.h
index 8152b6c7a5..2780c9685f 100644
--- a/src/operator/nn/dnnl/dnnl_batch_norm-inl.h
+++ b/src/operator/nn/dnnl/dnnl_batch_norm-inl.h
@@ -33,7 +33,6 @@
#include "operator/nn/batch_norm-inl.h"
#include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
namespace mxnet {
namespace op {
diff --git a/src/operator/nn/dnnl/dnnl_binary-inl.h b/src/operator/nn/dnnl/dnnl_binary-inl.h
index 2cf63aa9a4..b047f8d365 100644
--- a/src/operator/nn/dnnl/dnnl_binary-inl.h
+++ b/src/operator/nn/dnnl/dnnl_binary-inl.h
@@ -27,7 +27,6 @@
#if MXNET_USE_ONEDNN == 1
#include "./dnnl_base-inl.h"
-#include "./dnnl_ops-inl.h"
#include <vector>
#include "../../tensor/elemwise_binary_broadcast_op.h"
diff --git a/src/operator/nn/dnnl/dnnl_concat-inl.h b/src/operator/nn/dnnl/dnnl_concat-inl.h
index 2970908185..8530538544 100644
--- a/src/operator/nn/dnnl/dnnl_concat-inl.h
+++ b/src/operator/nn/dnnl/dnnl_concat-inl.h
@@ -31,7 +31,6 @@
#include "operator/nn/concat-inl.h"
#include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
namespace mxnet {
namespace op {
@@ -42,6 +41,11 @@ class DNNLConcatFwd {
DNNLConcatFwd(int concat_dim, const std::vector<dnnl::memory::desc>& data_md);
+ static DNNLConcatFwd& GetCached(int concat_dim,
+ const std::vector<NDArray>& in_data,
+ const std::vector<dnnl::memory::desc>& data_md,
+ int stack_axis = -1 /*used only by stack op*/);
+
const dnnl::concat& GetFwd() const {
return *fwd_;
}
@@ -50,28 +54,17 @@ class DNNLConcatFwd {
std::shared_ptr<dnnl::concat> fwd_;
};
-static DNNLConcatFwd& GetConcatForward(int concat_dim,
- const std::vector<NDArray>& in_data,
- const std::vector<dnnl::memory::desc>& data_md,
- int stack_axis = -1 /*used only by stack op*/) {
-#if DMLC_CXX11_THREAD_LOCAL
- static thread_local std::unordered_map<OpSignature, DNNLConcatFwd, OpHash> fwds;
-#else
- static MX_THREAD_LOCAL std::unordered_map<OpSignature, DNNLConcatFwd, OpHash> fwds;
-#endif
-
- OpSignature key;
- key.AddSign(concat_dim);
- key.AddSign(stack_axis);
- key.AddSign(in_data);
+void DNNLConcatForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& in_data,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& out_data);
- auto it = fwds.find(key);
- if (it == fwds.end()) {
- DNNLConcatFwd fwd(concat_dim, data_md);
- it = AddToCache(&fwds, key, fwd);
- }
- return it->second;
-}
+void DNNLConcatBackward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs);
} // namespace op
} // namespace mxnet
diff --git a/src/operator/nn/dnnl/dnnl_concat.cc b/src/operator/nn/dnnl/dnnl_concat.cc
index 83ba9df245..1c7e73bd68 100644
--- a/src/operator/nn/dnnl/dnnl_concat.cc
+++ b/src/operator/nn/dnnl/dnnl_concat.cc
@@ -56,6 +56,29 @@ DNNLConcatFwd::DNNLConcatFwd(int concat_dim, const std::vector<dnnl::memory::des
fwd_ = std::make_shared<dnnl::concat>(fwd_pd);
}
+DNNLConcatFwd& DNNLConcatFwd::GetCached(int concat_dim,
+ const std::vector<NDArray>& in_data,
+ const std::vector<dnnl::memory::desc>& data_md,
+ int stack_axis /*used only by stack op*/) {
+#if DMLC_CXX11_THREAD_LOCAL
+ static thread_local std::unordered_map<OpSignature, DNNLConcatFwd, OpHash> fwds;
+#else
+ static MX_THREAD_LOCAL std::unordered_map<OpSignature, DNNLConcatFwd, OpHash> fwds;
+#endif
+
+ OpSignature key;
+ key.AddSign(concat_dim);
+ key.AddSign(stack_axis);
+ key.AddSign(in_data);
+
+ auto it = fwds.find(key);
+ if (it == fwds.end()) {
+ DNNLConcatFwd fwd(concat_dim, data_md);
+ it = AddToCache(&fwds, key, fwd);
+ }
+ return it->second;
+}
+
void DNNLConcatForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& in_data,
@@ -76,7 +99,7 @@ void DNNLConcatForward(const nnvm::NodeAttrs& attrs,
data_md.push_back(tmp_md);
data_mem.push_back(tmp_mem);
}
- DNNLConcatFwd& fwd = GetConcatForward(concat_dim, in_data, data_md);
+ DNNLConcatFwd& fwd = DNNLConcatFwd::GetCached(concat_dim, in_data, data_md);
mxnet::dnnl_output_t out_mem =
CreateDNNLMem(out_data[concat_enum::kOut], fwd.fwd_pd.dst_desc(), req[concat_enum::kOut]);
std::unordered_map<int, dnnl::memory> net_args;
diff --git a/src/operator/nn/dnnl/dnnl_convolution-inl.h b/src/operator/nn/dnnl/dnnl_convolution-inl.h
index f41af488b1..738be8214f 100644
--- a/src/operator/nn/dnnl/dnnl_convolution-inl.h
+++ b/src/operator/nn/dnnl/dnnl_convolution-inl.h
@@ -32,7 +32,6 @@
#include "operator/nn/convolution-inl.h"
#include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
namespace mxnet {
namespace op {
@@ -172,6 +171,18 @@ class DNNLConvBackward {
std::shared_ptr<dnnl::convolution_backward_weights> bwd_weight_;
};
+void DNNLConvolutionForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& in_data,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& out_data);
+
+void DNNLConvolutionBackward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs);
+
} // namespace op
} // namespace mxnet
diff --git a/src/operator/nn/dnnl/dnnl_convolution.cc b/src/operator/nn/dnnl/dnnl_convolution.cc
index 60c65ffa47..ca6effb791 100644
--- a/src/operator/nn/dnnl/dnnl_convolution.cc
+++ b/src/operator/nn/dnnl/dnnl_convolution.cc
@@ -30,7 +30,6 @@
#include "operator/nn/convolution-inl.h"
#include "dnnl_base-inl.h"
#include "dnnl_convolution-inl.h"
-#include "dnnl_ops-inl.h"
namespace mxnet {
namespace op {
diff --git a/src/operator/nn/dnnl/dnnl_copy.cc b/src/operator/nn/dnnl/dnnl_copy-inl.h
similarity index 55%
copy from src/operator/nn/dnnl/dnnl_copy.cc
copy to src/operator/nn/dnnl/dnnl_copy-inl.h
index 16cbabd19a..41362dfcf1 100644
--- a/src/operator/nn/dnnl/dnnl_copy.cc
+++ b/src/operator/nn/dnnl/dnnl_copy-inl.h
@@ -18,15 +18,18 @@
*/
/*!
- * \file dnnl_copy.cc
+ * \file dnnl_copy-inl.h
* \brief
- * \author
+ * \author Wolinski Piotr piotr.wolinski@intel.com
*/
-#include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
+#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_COPY_INL_H_
+#define MXNET_OPERATOR_NN_DNNL_DNNL_COPY_INL_H_
#if MXNET_USE_ONEDNN == 1
+
+#include <dnnl.hpp>
+
namespace mxnet {
namespace op {
@@ -34,27 +37,10 @@ void DNNLCopy(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const NDArray& in_data,
const OpReqType& req,
- const NDArray& out_data) {
- if (req == kNullOp || req == kWriteInplace)
- return;
- TmpMemMgr::Get()->Init(ctx.requested[0]);
- auto in_mem = in_data.GetDNNLData();
- if (req == kAddTo) {
- TmpMemMgr::Get()->Init(ctx.requested[0]);
- // We should try and force the input memory has the same format
- // as the input output. If not, we'll have to reorder memory.
- auto out_mem = out_data.GetDNNLData();
- auto out_mem_desc = out_mem->get_desc();
- in_mem = in_data.GetDNNLData(&out_mem_desc);
- if (in_mem == nullptr)
- in_mem = in_data.GetDNNLDataReorder(&out_mem_desc);
- DNNLSum(*out_mem, *in_mem, *out_mem);
- } else {
- const_cast<NDArray&>(out_data).CopyFrom(*in_mem);
- }
- DNNLStream::Get()->Submit();
-}
+ const NDArray& out_data);
} // namespace op
} // namespace mxnet
-#endif
+
+#endif // MXNET_USE_ONEDNN == 1
+#endif // MXNET_OPERATOR_NN_DNNL_DNNL_COPY_INL_H_
diff --git a/src/operator/nn/dnnl/dnnl_copy.cc b/src/operator/nn/dnnl/dnnl_copy.cc
index 16cbabd19a..bdf5009bf5 100644
--- a/src/operator/nn/dnnl/dnnl_copy.cc
+++ b/src/operator/nn/dnnl/dnnl_copy.cc
@@ -24,7 +24,6 @@
*/
#include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
#if MXNET_USE_ONEDNN == 1
namespace mxnet {
@@ -48,7 +47,7 @@ void DNNLCopy(const nnvm::NodeAttrs& attrs,
in_mem = in_data.GetDNNLData(&out_mem_desc);
if (in_mem == nullptr)
in_mem = in_data.GetDNNLDataReorder(&out_mem_desc);
- DNNLSum(*out_mem, *in_mem, *out_mem);
+ DNNLMemorySum(*out_mem, *in_mem, *out_mem);
} else {
const_cast<NDArray&>(out_data).CopyFrom(*in_mem);
}
diff --git a/src/operator/nn/dnnl/dnnl_deconvolution-inl.h b/src/operator/nn/dnnl/dnnl_deconvolution-inl.h
index a1ac5518e1..4e2058a5a6 100644
--- a/src/operator/nn/dnnl/dnnl_deconvolution-inl.h
+++ b/src/operator/nn/dnnl/dnnl_deconvolution-inl.h
@@ -43,7 +43,6 @@
#include "operator/nn/deconvolution-inl.h"
#include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
namespace mxnet {
namespace op {
@@ -130,25 +129,6 @@ class DNNLDeconvFwd {
std::shared_ptr<deconv_fwd_pd_t> fwd_pd;
};
-DNNLDeconvFwd::Tensors::Tensors(const bool no_bias,
- const std::vector<NDArray>& inputs,
- const std::vector<NDArray>& outputs)
- : data(inputs[deconv::kData]),
- weights(inputs[deconv::kWeight]),
- bias(no_bias ? nullptr : &inputs[deconv::kBias]),
- out(outputs[deconv::kOut]) {}
-
-DNNLDeconvFwd::Tensors::Tensors(const NDArray& data,
- const NDArray& weights,
- const NDArray* const bias,
- const NDArray& out)
- : data(data), weights(weights), bias(bias), out(out) {}
-
-DNNLDeconvFwd::DNNLDeconvFwd(const DeconvolutionParam& param, const Tensors& tensors)
- : fwd_pd(CreatePrimitiveDesc(param, tensors)) {
- fwd = std::make_shared<deconv_fwd_t>(*fwd_pd);
-}
-
inline const dnnl::memory* DNNLDeconvFwd::DataMem(const NDArray& data) const {
auto fwd_src_desc = fwd_pd->src_desc();
return data.GetDNNLDataReorder(&fwd_src_desc);
@@ -242,28 +222,6 @@ class DNNLDeconvBwd {
std::shared_ptr<deconv_bwd_weights_t> bwd_weights;
};
-DNNLDeconvBwd::ReadTensors::ReadTensors(const bool no_bias, const std::vector<NDArray>& inputs)
- : data(inputs[deconv::kData + 1]),
- weights(inputs[deconv::kWeight + 1]),
- bias(no_bias ? nullptr : &inputs[deconv::kBias + 1]),
- out_grad(inputs[deconv::kOut]) {}
-
-DNNLDeconvBwd::WriteTensors::WriteTensors(const bool no_bias, const std::vector<NDArray>& outputs)
- : data_grad(outputs[deconv::kData]),
- weights_grad(outputs[deconv::kWeight]),
- bias_grad(no_bias ? nullptr : &outputs[deconv::kBias]) {}
-
-DNNLDeconvBwd::DNNLDeconvBwd(const DeconvolutionParam& param, const ReadTensors& read_tensors) {
- const auto& fwd_pd = DNNLDeconvFwd::CreatePrimitiveDesc(
- param,
- DNNLDeconvFwd::Tensors(
- read_tensors.data, read_tensors.weights, read_tensors.bias, read_tensors.out_grad));
- bwd_data_pd = CreateDataPrimitiveDesc(param, read_tensors, *fwd_pd);
- bwd_weights_pd = CreateWeightsPrimitiveDesc(param, read_tensors, *fwd_pd);
- bwd_data = std::make_shared<deconv_bwd_data_t>(*bwd_data_pd);
- bwd_weights = std::make_shared<deconv_bwd_weights_t>(*bwd_weights_pd);
-}
-
inline void DNNLDeconvBwd::IOSwapWeightsTensors(const uint32_t num_group,
const std::vector<OpReqType>& req,
const NDArray& weights,
@@ -409,6 +367,18 @@ inline deconv_bwd_weights_t::desc DeconvDescCreator::CreateBwdWeightsDesc() cons
padding);
}
+void DNNLDeconvolutionForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& in_data,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& out_data);
+
+void DNNLDeconvolutionBackward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs);
+
} // namespace op
} // namespace mxnet
#endif // MXNET_USE_ONEDNN == 1
diff --git a/src/operator/nn/dnnl/dnnl_deconvolution.cc b/src/operator/nn/dnnl/dnnl_deconvolution.cc
index 9487574d47..79e3229963 100644
--- a/src/operator/nn/dnnl/dnnl_deconvolution.cc
+++ b/src/operator/nn/dnnl/dnnl_deconvolution.cc
@@ -35,6 +35,31 @@ bool SupportDNNLDeconv(const DeconvolutionParam& params, const NDArray& input) {
(input.dtype() == mshadow::kFloat32 || input.dtype() == mshadow::kBfloat16);
}
+DNNLDeconvFwd::Tensors::Tensors(const bool no_bias,
+ const std::vector<NDArray>& inputs,
+ const std::vector<NDArray>& outputs)
+ : data(inputs[deconv::kData]),
+ weights(inputs[deconv::kWeight]),
+ bias(no_bias ? nullptr : &inputs[deconv::kBias]),
+ out(outputs[deconv::kOut]) {}
+
+DNNLDeconvFwd::Tensors::Tensors(const NDArray& data,
+ const NDArray& weights,
+ const NDArray* const bias,
+ const NDArray& out)
+ : data(data), weights(weights), bias(bias), out(out) {}
+
+DNNLDeconvBwd::ReadTensors::ReadTensors(const bool no_bias, const std::vector<NDArray>& inputs)
+ : data(inputs[deconv::kData + 1]),
+ weights(inputs[deconv::kWeight + 1]),
+ bias(no_bias ? nullptr : &inputs[deconv::kBias + 1]),
+ out_grad(inputs[deconv::kOut]) {}
+
+DNNLDeconvBwd::WriteTensors::WriteTensors(const bool no_bias, const std::vector<NDArray>& outputs)
+ : data_grad(outputs[deconv::kData]),
+ weights_grad(outputs[deconv::kWeight]),
+ bias_grad(no_bias ? nullptr : &outputs[deconv::kBias]) {}
+
void DNNLDeconvolutionForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
@@ -49,6 +74,11 @@ void DNNLDeconvolutionForward(const nnvm::NodeAttrs& attrs,
fwd.Execute(param.num_group, req[deconv::kOut], tensors);
}
+DNNLDeconvFwd::DNNLDeconvFwd(const DeconvolutionParam& param, const Tensors& tensors)
+ : fwd_pd(CreatePrimitiveDesc(param, tensors)) {
+ fwd = std::make_shared<deconv_fwd_t>(*fwd_pd);
+}
+
DNNLDeconvFwd& DNNLDeconvFwd::GetCached(const DeconvolutionParam& param, const Tensors& tensors) {
using deconv_fwd_map = std::unordered_map<DeconvSignature, DNNLDeconvFwd, OpHash>;
#if DMLC_CXX11_THREAD_LOCAL
@@ -180,6 +210,17 @@ void DNNLDeconvolutionBackward(const nnvm::NodeAttrs& attrs,
bwd.Execute(param.num_group, req, read_tensors, write_tensors);
}
+DNNLDeconvBwd::DNNLDeconvBwd(const DeconvolutionParam& param, const ReadTensors& read_tensors) {
+ const auto& fwd_pd = DNNLDeconvFwd::CreatePrimitiveDesc(
+ param,
+ DNNLDeconvFwd::Tensors(
+ read_tensors.data, read_tensors.weights, read_tensors.bias, read_tensors.out_grad));
+ bwd_data_pd = CreateDataPrimitiveDesc(param, read_tensors, *fwd_pd);
+ bwd_weights_pd = CreateWeightsPrimitiveDesc(param, read_tensors, *fwd_pd);
+ bwd_data = std::make_shared<deconv_bwd_data_t>(*bwd_data_pd);
+ bwd_weights = std::make_shared<deconv_bwd_weights_t>(*bwd_weights_pd);
+}
+
DNNLDeconvBwd& DNNLDeconvBwd::GetCached(const DeconvolutionParam& param,
const ReadTensors& read_tensors) {
using deconv_bwd_map = std::unordered_map<DeconvSignature, DNNLDeconvBwd, OpHash>;
diff --git a/src/operator/nn/dnnl/dnnl_fully_connected-inl.h b/src/operator/nn/dnnl/dnnl_fully_connected-inl.h
index 200df850c3..976dc83fe5 100644
--- a/src/operator/nn/dnnl/dnnl_fully_connected-inl.h
+++ b/src/operator/nn/dnnl/dnnl_fully_connected-inl.h
@@ -220,6 +220,12 @@ void DNNLFCForwardFullFeature(const DNNLFCFullParam& param,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& out_data);
+void DNNLFCBackward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs);
+
} // namespace op
} // namespace mxnet
diff --git a/src/operator/nn/dnnl/dnnl_layer_norm-inl.h b/src/operator/nn/dnnl/dnnl_layer_norm-inl.h
index c751930967..9b184a4383 100644
--- a/src/operator/nn/dnnl/dnnl_layer_norm-inl.h
+++ b/src/operator/nn/dnnl/dnnl_layer_norm-inl.h
@@ -31,7 +31,6 @@
#include "operator/nn/layer_norm-inl.h"
#include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
namespace mxnet {
namespace op {
@@ -96,6 +95,18 @@ class DNNLLayerNormBwd {
std::shared_ptr<layernorm_bwd_pd_t> bwd_pd;
};
+void DNNLLayerNormForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs);
+
+void DNNLLayerNormBackward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs);
+
} // namespace op
} // namespace mxnet
#endif // MXNET_USE_ONEDNN == 1
diff --git a/src/operator/nn/dnnl/dnnl_log_softmax.cc b/src/operator/nn/dnnl/dnnl_log_softmax.cc
index be1abdb0f1..1559ee347b 100644
--- a/src/operator/nn/dnnl/dnnl_log_softmax.cc
+++ b/src/operator/nn/dnnl/dnnl_log_softmax.cc
@@ -24,7 +24,6 @@
#include "operator/nn/softmax-inl.h"
#include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
#if MXNET_USE_ONEDNN == 1
namespace mxnet {
diff --git a/src/operator/nn/dnnl/dnnl_masked_softmax-inl.h b/src/operator/nn/dnnl/dnnl_masked_softmax-inl.h
index 683f743a10..96bcfc1e33 100644
--- a/src/operator/nn/dnnl/dnnl_masked_softmax-inl.h
+++ b/src/operator/nn/dnnl/dnnl_masked_softmax-inl.h
@@ -27,7 +27,6 @@
#include <vector>
#include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
#include "operator/nn/softmax-inl.h"
#include "dnnl_softmax-inl.h"
diff --git a/src/operator/nn/dnnl/dnnl_ops-inl.h b/src/operator/nn/dnnl/dnnl_ops-inl.h
deleted file mode 100644
index 8ccdaaceda..0000000000
--- a/src/operator/nn/dnnl/dnnl_ops-inl.h
+++ /dev/null
@@ -1,230 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file dnnl_ops-inl.h
- * \brief
- * \author Da Zheng
- */
-
-#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_OPS_INL_H_
-#define MXNET_OPERATOR_NN_DNNL_DNNL_OPS_INL_H_
-
-#include <dmlc/logging.h>
-#include <dmlc/optional.h>
-#include <mxnet/base.h>
-#include <mxnet/io.h>
-#include <mxnet/ndarray.h>
-#include <mxnet/operator.h>
-#include <mxnet/operator_util.h>
-
-#include <vector>
-
-#if MXNET_USE_ONEDNN == 1
-#include <dnnl.hpp>
-
-namespace mxnet {
-namespace op {
-
-/* For fully connected. */
-void DNNLFCForward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& in_data,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& out_data);
-void DNNLFCBackward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs);
-
-/* For convolution. */
-void DNNLConvolutionForward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& in_data,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& out_data);
-void DNNLConvolutionBackward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs);
-
-/* For deconvolution */
-void DNNLDeconvolutionForward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& in_data,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& out_data);
-void DNNLDeconvolutionBackward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs);
-
-/* For activation */
-void DNNLActivationForward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const NDArray& in_data,
- const OpReqType& req,
- const NDArray& out_data);
-void DNNLActivationBackward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs);
-
-void DNNLLeakyReluForward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const NDArray& in_data,
- const OpReqType& req,
- const NDArray& out_data);
-void DNNLLeakyReluBackward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs);
-
-/* For softmax */
-void DNNLSoftmaxForward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const NDArray& in_data,
- const OpReqType& req,
- const NDArray& out_data);
-void DNNLSoftmaxBackward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& in_data,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& out_data);
-
-/* For log_softmax */
-void DNNLLogSoftmaxForward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const NDArray& in_data,
- const OpReqType& req,
- const NDArray& out_data);
-void DNNLLogSoftmaxBackward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& in_data,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& out_data);
-
-void DNNLMaskedSoftmaxForward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs);
-
-/* For softmax_output */
-void DNNLSoftmaxOutputForward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& in_data,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& out_data);
-
-void DNNLSplitForward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs);
-
-/* For sum */
-void DNNLSumForward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs);
-
-/* For copy */
-void DNNLCopy(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const NDArray& in_data,
- const OpReqType& req,
- const NDArray& out_data);
-
-/* For concat */
-void DNNLConcatForward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& in_data,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& out_data);
-void DNNLConcatBackward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs);
-
-/* For batch dot */
-template <bool subgraph>
-void DNNLBatchDotForward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs);
-
-/* For layer normalization */
-void DNNLLayerNormForward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs);
-void DNNLLayerNormBackward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs);
-
-void DNNLSum(const dnnl::memory& arr1, const dnnl::memory& arr2, const dnnl::memory& out);
-
-void DNNLStackForward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& in_data,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& out_data);
-
-template <class ParamType>
-void DNNLTransposeForward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const NDArray& data,
- const OpReqType& req,
- const NDArray& output);
-
-void DNNLReshapeForward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const NDArray& input,
- const OpReqType& req,
- const NDArray& output);
-
-void DNNLWhereForward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs);
-
-void DNNLPowerForward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const NDArray& input,
- const OpReqType& req,
- const NDArray& output);
-
-} // namespace op
-} // namespace mxnet
-
-#endif // MXNET_USE_ONEDNN == 1
-#endif // MXNET_OPERATOR_NN_DNNL_DNNL_OPS_INL_H_
diff --git a/src/operator/nn/dnnl/dnnl_power_scalar-inl.h b/src/operator/nn/dnnl/dnnl_power_scalar-inl.h
index eddbffffff..5ece7ef832 100644
--- a/src/operator/nn/dnnl/dnnl_power_scalar-inl.h
+++ b/src/operator/nn/dnnl/dnnl_power_scalar-inl.h
@@ -28,7 +28,6 @@
#if MXNET_USE_ONEDNN == 1
#include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
#include "operator/tensor/elemwise_binary_scalar_op.h"
namespace mxnet {
@@ -54,6 +53,12 @@ class DNNLPowerFwd {
typedef OpSignature DNNLPowerSignature;
+void DNNLPowerForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const NDArray& input,
+ const OpReqType& req,
+ const NDArray& output);
+
} // namespace op
} // namespace mxnet
diff --git a/src/operator/nn/dnnl/dnnl_reduce-inl.h b/src/operator/nn/dnnl/dnnl_reduce-inl.h
index 9e3f0bd2a5..f33206afb9 100644
--- a/src/operator/nn/dnnl/dnnl_reduce-inl.h
+++ b/src/operator/nn/dnnl/dnnl_reduce-inl.h
@@ -28,7 +28,6 @@
#include <vector>
#include "./dnnl_base-inl.h"
-#include "./dnnl_ops-inl.h"
namespace mxnet {
namespace op {
diff --git a/src/operator/nn/dnnl/dnnl_reshape-inl.h b/src/operator/nn/dnnl/dnnl_reshape-inl.h
index 04e1fecb01..01f26994f6 100644
--- a/src/operator/nn/dnnl/dnnl_reshape-inl.h
+++ b/src/operator/nn/dnnl/dnnl_reshape-inl.h
@@ -53,6 +53,13 @@ typedef OpSignature DNNLReshapeSignature;
DNNLReshapeFwd& GetReshapeForward(const OpReqType& req,
const NDArray& input,
const NDArray& output);
+
+void DNNLReshapeForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const NDArray& input,
+ const OpReqType& req,
+ const NDArray& output);
+
} // namespace op
} // namespace mxnet
diff --git a/src/operator/nn/dnnl/dnnl_reshape.cc b/src/operator/nn/dnnl/dnnl_reshape.cc
index ec96a5eb74..d1270a3a8c 100644
--- a/src/operator/nn/dnnl/dnnl_reshape.cc
+++ b/src/operator/nn/dnnl/dnnl_reshape.cc
@@ -26,7 +26,6 @@
#if MXNET_USE_ONEDNN == 1
#include "operator/tensor/elemwise_unary_op.h"
#include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
#include "dnnl_reshape-inl.h"
namespace mxnet {
diff --git a/src/operator/nn/dnnl/dnnl_slice.cc b/src/operator/nn/dnnl/dnnl_slice.cc
index 3008133425..102bf684fb 100644
--- a/src/operator/nn/dnnl/dnnl_slice.cc
+++ b/src/operator/nn/dnnl/dnnl_slice.cc
@@ -26,7 +26,6 @@
#if MXNET_USE_ONEDNN == 1
#include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
#include "dnnl_slice-inl.h"
namespace mxnet {
diff --git a/src/operator/nn/dnnl/dnnl_softmax-inl.h b/src/operator/nn/dnnl/dnnl_softmax-inl.h
index 42558c6bbf..5b201ad13a 100644
--- a/src/operator/nn/dnnl/dnnl_softmax-inl.h
+++ b/src/operator/nn/dnnl/dnnl_softmax-inl.h
@@ -37,7 +37,6 @@
#include <vector>
#include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
#include "operator/nn/softmax-inl.h"
@@ -106,6 +105,36 @@ class DNNLSoftmaxBwd {
std::shared_ptr<linear_t> temperature_fwd;
};
+void DNNLSoftmaxForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const NDArray& in_data,
+ const OpReqType& req,
+ const NDArray& out_data);
+
+void DNNLSoftmaxBackward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& in_data,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& out_data);
+
+void DNNLLogSoftmaxForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const NDArray& in_data,
+ const OpReqType& req,
+ const NDArray& out_data);
+
+void DNNLLogSoftmaxBackward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& in_data,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& out_data);
+
+void DNNLMaskedSoftmaxForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs);
+
} // namespace op
} // namespace mxnet
#endif
diff --git a/src/operator/nn/dnnl/dnnl_reshape-inl.h b/src/operator/nn/dnnl/dnnl_softmax_output-inl.h
similarity index 50%
copy from src/operator/nn/dnnl/dnnl_reshape-inl.h
copy to src/operator/nn/dnnl/dnnl_softmax_output-inl.h
index 04e1fecb01..d0fb1d820d 100644
--- a/src/operator/nn/dnnl/dnnl_reshape-inl.h
+++ b/src/operator/nn/dnnl/dnnl_softmax_output-inl.h
@@ -18,43 +18,30 @@
*/
/*!
- * \file dnnl_reshape-inl.h
- * \brief Function definition of dnnl reshape operator
+ * \file dnnl_softmax_output-inl.h
+ * \brief
+ * \author Wolinski Piotr piotr.wolinski@intel.com
*/
-#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_RESHAPE_INL_H_
-#define MXNET_OPERATOR_NN_DNNL_DNNL_RESHAPE_INL_H_
+#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_SOFTMAX_OUTPUT_INL_H_
+#define MXNET_OPERATOR_NN_DNNL_DNNL_SOFTMAX_OUTPUT_INL_H_
#if MXNET_USE_ONEDNN == 1
-#include <vector>
-#include "operator/tensor/matrix_op-inl.h"
-#include "dnnl_base-inl.h"
+#include <dnnl.hpp>
+#include <vector>
namespace mxnet {
namespace op {
-class DNNLReshapeFwd {
- protected:
- std::shared_ptr<dnnl::memory> out_;
- std::shared_ptr<dnnl::memory> temp_;
- std::vector<dnnl::primitive> prims_;
-
- public:
- DNNLReshapeFwd(const OpReqType& req, const NDArray& input, const NDArray& output);
- int GetWorkspaceSize();
- void Execute(const NDArray& input,
- const NDArray& output,
- const OpReqType& req,
- void* workspace = nullptr);
-};
-
-typedef OpSignature DNNLReshapeSignature;
-DNNLReshapeFwd& GetReshapeForward(const OpReqType& req,
- const NDArray& input,
- const NDArray& output);
+void DNNLSoftmaxOutputForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& in_data,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& out_data);
+
} // namespace op
} // namespace mxnet
#endif // MXNET_USE_ONEDNN == 1
-#endif // MXNET_OPERATOR_NN_DNNL_DNNL_RESHAPE_INL_H_
+#endif // MXNET_OPERATOR_NN_DNNL_DNNL_SOFTMAX_OUTPUT_INL_H_
diff --git a/src/operator/nn/dnnl/dnnl_softmax_output.cc b/src/operator/nn/dnnl/dnnl_softmax_output.cc
index 94b0029d9b..ba79effc5b 100644
--- a/src/operator/nn/dnnl/dnnl_softmax_output.cc
+++ b/src/operator/nn/dnnl/dnnl_softmax_output.cc
@@ -27,7 +27,6 @@
#include "operator/softmax_output-inl.h"
#include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
namespace mxnet {
namespace op {
diff --git a/src/operator/nn/dnnl/dnnl_split-inl.h b/src/operator/nn/dnnl/dnnl_split-inl.h
index a8cdc4cd93..051bf4793f 100644
--- a/src/operator/nn/dnnl/dnnl_split-inl.h
+++ b/src/operator/nn/dnnl/dnnl_split-inl.h
@@ -28,7 +28,6 @@
#include <vector>
#include "./dnnl_base-inl.h"
-#include "./dnnl_ops-inl.h"
namespace mxnet {
namespace op {
@@ -63,6 +62,12 @@ class DNNLSplitFwd {
dnnl::memory::dims strides;
};
+void DNNLSplitForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs);
+
} // namespace op
} // namespace mxnet
#endif
diff --git a/src/operator/nn/dnnl/dnnl_reshape-inl.h b/src/operator/nn/dnnl/dnnl_stack-inl.h
similarity index 50%
copy from src/operator/nn/dnnl/dnnl_reshape-inl.h
copy to src/operator/nn/dnnl/dnnl_stack-inl.h
index 04e1fecb01..ce7522f9c7 100644
--- a/src/operator/nn/dnnl/dnnl_reshape-inl.h
+++ b/src/operator/nn/dnnl/dnnl_stack-inl.h
@@ -18,43 +18,31 @@
*/
/*!
- * \file dnnl_reshape-inl.h
- * \brief Function definition of dnnl reshape operator
+ * \file dnnl_stack-inl.h
+ * \brief
+ * \author Wolinski Piotr piotr.wolinski@intel.com
*/
-#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_RESHAPE_INL_H_
-#define MXNET_OPERATOR_NN_DNNL_DNNL_RESHAPE_INL_H_
+#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_STACK_INL_H_
+#define MXNET_OPERATOR_NN_DNNL_DNNL_STACK_INL_H_
#if MXNET_USE_ONEDNN == 1
+
#include <vector>
-#include "operator/tensor/matrix_op-inl.h"
-#include "dnnl_base-inl.h"
+#include <dnnl.hpp>
namespace mxnet {
namespace op {
-class DNNLReshapeFwd {
- protected:
- std::shared_ptr<dnnl::memory> out_;
- std::shared_ptr<dnnl::memory> temp_;
- std::vector<dnnl::primitive> prims_;
-
- public:
- DNNLReshapeFwd(const OpReqType& req, const NDArray& input, const NDArray& output);
- int GetWorkspaceSize();
- void Execute(const NDArray& input,
- const NDArray& output,
- const OpReqType& req,
- void* workspace = nullptr);
-};
-
-typedef OpSignature DNNLReshapeSignature;
-DNNLReshapeFwd& GetReshapeForward(const OpReqType& req,
- const NDArray& input,
- const NDArray& output);
+void DNNLStackForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& in_data,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& out_data);
+
} // namespace op
} // namespace mxnet
#endif // MXNET_USE_ONEDNN == 1
-#endif // MXNET_OPERATOR_NN_DNNL_DNNL_RESHAPE_INL_H_
+#endif // MXNET_OPERATOR_NN_DNNL_DNNL_STACK_INL_H_
diff --git a/src/operator/nn/dnnl/dnnl_stack.cc b/src/operator/nn/dnnl/dnnl_stack.cc
index 4cd45e104d..981a4f87c6 100644
--- a/src/operator/nn/dnnl/dnnl_stack.cc
+++ b/src/operator/nn/dnnl/dnnl_stack.cc
@@ -20,10 +20,11 @@
/*!
* \file dnnl_stack.cc
*/
+#include <vector>
#include "dnnl_base-inl.h"
#include "dnnl_concat-inl.h"
-#include "dnnl_ops-inl.h"
+#include "dnnl_stack-inl.h"
#include "operator/tensor/matrix_op-inl.h"
@@ -103,7 +104,7 @@ void DNNLStackForward(const nnvm::NodeAttrs& attrs,
}
});
- auto& fwd = GetConcatForward(stacking_dim, in_data, data_md, axis);
+ auto& fwd = DNNLConcatFwd::GetCached(stacking_dim, in_data, data_md, axis);
mxnet::dnnl_output_t out_mem =
CreateDNNLMem(out_data[concat_enum::kOut], fwd.fwd_pd.dst_desc(), req[concat_enum::kOut]);
diff --git a/src/operator/nn/dnnl/dnnl_sum-inl.h b/src/operator/nn/dnnl/dnnl_sum-inl.h
new file mode 100644
index 0000000000..14b2349b06
--- /dev/null
+++ b/src/operator/nn/dnnl/dnnl_sum-inl.h
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_sum-inl.h
+ * \brief
+ * \author Wolinski Piotr piotr.wolinski@intel.com
+ */
+
+#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_SUM_INL_H_
+#define MXNET_OPERATOR_NN_DNNL_DNNL_SUM_INL_H_
+
+#if MXNET_USE_ONEDNN == 1
+
+#include <vector>
+
+#include <dnnl.hpp>
+#include "operator/nn/dnnl/dnnl_base-inl.h"
+#include "operator/operator_common.h"
+
+namespace mxnet {
+namespace op {
+
+using sum_t = dnnl::sum;
+using sum_pd_t = dnnl::sum::primitive_desc;
+
+class DNNLSumFwd {
+ public:
+ typedef OpSignature DNNLSumSignature;
+
+ static DNNLSumFwd& GetCached(const std::vector<NDArray>& inputs,
+ const std::vector<NDArray>& outputs);
+
+ explicit DNNLSumFwd(const std::vector<NDArray>& inputs, const std::vector<NDArray>& outputs);
+
+ void Execute(const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs);
+
+ private:
+ std::shared_ptr<sum_t> fwd;
+ std::shared_ptr<sum_pd_t> fwd_pd;
+};
+
+void DNNLSumForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs);
+
+} // namespace op
+} // namespace mxnet
+
+#endif // MXNET_USE_ONEDNN == 1
+#endif // MXNET_OPERATOR_NN_DNNL_DNNL_SUM_INL_H_
diff --git a/src/operator/nn/dnnl/dnnl_sum.cc b/src/operator/nn/dnnl/dnnl_sum.cc
index c2626a26b4..5995778ccd 100644
--- a/src/operator/nn/dnnl/dnnl_sum.cc
+++ b/src/operator/nn/dnnl/dnnl_sum.cc
@@ -23,112 +23,88 @@
* \author Da Zheng
*/
#include <iostream>
+#include <vector>
#include "operator/operator_common.h"
#include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
+#include "dnnl_sum-inl.h"
namespace mxnet {
namespace op {
#if MXNET_USE_ONEDNN == 1
-void DNNLSum(const dnnl::memory& arr1, const dnnl::memory& arr2, const dnnl::memory& out) {
- std::vector<dnnl::memory::desc> input_pds(2);
- std::vector<float> scales(2, 1);
- input_pds[0] = arr1.get_desc();
- input_pds[1] = arr2.get_desc();
- CHECK(input_pds[0] == input_pds[0]);
- const dnnl::memory* in_mem1 = &arr1;
- const dnnl::memory* in_mem2 = &arr2;
- auto output_pd = out.get_desc();
- if (input_pds[0] != output_pd) {
- auto tmp_memory1 = TmpMemMgr::Get()->Alloc(output_pd);
- auto tmp_memory2 = TmpMemMgr::Get()->Alloc(output_pd);
- DNNLMemoryCopy(arr1, tmp_memory1);
- DNNLMemoryCopy(arr2, tmp_memory2);
- input_pds[0] = tmp_memory1->get_desc();
- input_pds[1] = tmp_memory2->get_desc();
- in_mem1 = tmp_memory1;
- in_mem2 = tmp_memory2;
- }
- dnnl::sum::primitive_desc sum_pd(output_pd, scales, input_pds, CpuEngine::Get()->get_engine());
- dnnl_args_map_t args = {
- {DNNL_ARG_MULTIPLE_SRC, *in_mem1},
- {DNNL_ARG_MULTIPLE_SRC + 1, *in_mem2},
- {DNNL_ARG_DST, out},
- };
- DNNLStream::Get()->RegisterPrimArgs(dnnl::sum(sum_pd), args);
-}
-
-class DNNLSumFwd {
- public:
- dnnl::sum::primitive_desc fwd_pd;
-
- DNNLSumFwd(const std::vector<float>& scales, const std::vector<dnnl::memory::desc>& data_md)
- : fwd_pd(scales, data_md, CpuEngine::Get()->get_engine()) {
- fwd_ = std::make_shared<dnnl::sum>(fwd_pd);
- }
- const dnnl::sum& GetFwd() const {
- return *fwd_;
- }
-
- private:
- std::shared_ptr<dnnl::sum> fwd_;
-};
-
-static DNNLSumFwd& GetSumForward(const std::vector<float>& scales,
- const std::vector<NDArray>& in_data,
- const std::vector<dnnl::memory::desc>& data_md) {
+DNNLSumFwd& DNNLSumFwd::GetCached(const std::vector<NDArray>& inputs,
+ const std::vector<NDArray>& outputs) {
#if DMLC_CXX11_THREAD_LOCAL
- static thread_local std::unordered_map<OpSignature, DNNLSumFwd, OpHash> fwds;
+ static thread_local std::unordered_map<DNNLSumSignature, DNNLSumFwd, OpHash> fwds;
#else
- static MX_THREAD_LOCAL std::unordered_map<OpSignature, DNNLSumFwd, OpHash> fwds;
+ static MX_THREAD_LOCAL std::unordered_map<DNNLSumSignature, DNNLSumFwd, OpHash> fwds;
#endif
- OpSignature key;
- key.AddSign(in_data);
+ DNNLSumSignature key;
+ key.AddSign(inputs);
auto it = fwds.find(key);
if (it == fwds.end()) {
- DNNLSumFwd fwd(scales, data_md);
+ const DNNLSumFwd fwd(inputs, outputs);
it = AddToCache(&fwds, key, fwd);
}
return it->second;
}
-void DNNLSumForward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs) {
- TmpMemMgr::Get()->Init(ctx.requested[0]);
+DNNLSumFwd::DNNLSumFwd(const std::vector<NDArray>& inputs, const std::vector<NDArray>& outputs) {
const int num_inputs = inputs.size();
- const NDArray& out_data = outputs[0];
+
std::vector<dnnl::memory::desc> data_md;
- std::vector<const dnnl::memory*> data_mem;
+
std::vector<float> scales(num_inputs, 1);
data_md.reserve(num_inputs);
- data_mem.reserve(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
const dnnl::memory* in_mem = inputs[i].GetDNNLData();
dnnl::memory::desc tmp_md = in_mem->get_desc();
data_md.push_back(tmp_md);
+ }
+
+ fwd_pd = std::make_shared<sum_pd_t>(scales, data_md, CpuEngine::Get()->get_engine());
+ fwd = std::make_shared<sum_t>(*fwd_pd);
+}
+
+void DNNLSumFwd::Execute(const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
+ const NDArray& out_data = outputs[0];
+ const int num_inputs = inputs.size();
+ std::vector<const dnnl::memory*> data_mem;
+
+ data_mem.reserve(num_inputs);
+
+ for (int i = 0; i < num_inputs; ++i) {
+ const dnnl::memory* in_mem = inputs[i].GetDNNLData();
data_mem.push_back(in_mem);
}
- DNNLSumFwd& fwd = GetSumForward(scales, inputs, data_md);
- mxnet::dnnl_output_t out_mem = CreateDNNLMem(out_data, fwd.fwd_pd.dst_desc(), req[0], &inputs[0]);
+ mxnet::dnnl_output_t out_mem = CreateDNNLMem(out_data, fwd_pd->dst_desc(), req[0], &inputs[0]);
dnnl_args_map_t net_args;
net_args.insert({DNNL_ARG_DST, *out_mem.second});
for (int i = 0; i < num_inputs; ++i) {
net_args.insert({DNNL_ARG_MULTIPLE_SRC + i, *data_mem[i]});
}
- DNNLStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
+ DNNLStream::Get()->RegisterPrimArgs(*fwd, net_args);
CommitOutput(out_data, out_mem);
DNNLStream::Get()->Submit();
}
+
+void DNNLSumForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
+ DNNLSumFwd& fwd = DNNLSumFwd::GetCached(inputs, outputs);
+ fwd.Execute(ctx, inputs, req, outputs);
+}
#endif
} // namespace op
diff --git a/src/operator/nn/dnnl/dnnl_transpose-inl.h b/src/operator/nn/dnnl/dnnl_transpose-inl.h
index a829dc2315..aa6e071da7 100644
--- a/src/operator/nn/dnnl/dnnl_transpose-inl.h
+++ b/src/operator/nn/dnnl/dnnl_transpose-inl.h
@@ -27,7 +27,6 @@
#if MXNET_USE_ONEDNN == 1
#include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
#include "operator/numpy/np_matrix_op-inl.h"
diff --git a/src/operator/nn/dnnl/dnnl_where-inl.h b/src/operator/nn/dnnl/dnnl_where-inl.h
index bfda684668..cd3c643fe7 100644
--- a/src/operator/nn/dnnl/dnnl_where-inl.h
+++ b/src/operator/nn/dnnl/dnnl_where-inl.h
@@ -29,7 +29,6 @@
#include <unordered_map>
#include <vector>
#include "dnnl_base-inl.h"
-#include "dnnl_ops-inl.h"
namespace mxnet {
namespace op {
@@ -67,6 +66,12 @@ class DNNLWhereFwd {
bool SupportDNNLWhere(const std::vector<NDArray>& inputs);
+void DNNLWhereForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs);
+
} // namespace op
} // namespace mxnet
#endif
diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc
index f5d6c2c966..61e7ca5ea9 100644
--- a/src/operator/nn/fully_connected.cc
+++ b/src/operator/nn/fully_connected.cc
@@ -22,8 +22,11 @@
* \brief fully connect operator
*/
#include "./dnnl/dnnl_base-inl.h"
-#include "./dnnl/dnnl_ops-inl.h"
#include "./fully_connected-inl.h"
+#if MXNET_USE_ONEDNN == 1
+#include "operator/nn/dnnl/dnnl_base-inl.h"
+#include "operator/nn/dnnl/dnnl_fully_connected-inl.h"
+#endif // MXNET_USE_ONEDNN == 1
namespace mxnet {
namespace op {
diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc
index f0b989fda1..ae8c429ec5 100644
--- a/src/operator/nn/layer_norm.cc
+++ b/src/operator/nn/layer_norm.cc
@@ -27,8 +27,8 @@
#include "../elemwise_op_common.h"
#include "layer_norm_cpu.h"
#if MXNET_USE_ONEDNN == 1
-#include "./dnnl/dnnl_base-inl.h"
-#include "./dnnl/dnnl_ops-inl.h"
+#include "operator/nn/dnnl/dnnl_base-inl.h"
+#include "operator/nn/dnnl/dnnl_layer_norm-inl.h"
#endif // MXNET_USE_ONEDNN
#if MSHADOW_USE_MKL == 1
diff --git a/src/operator/nn/log_softmax.cc b/src/operator/nn/log_softmax.cc
index f56e7acda7..e63f09c71f 100644
--- a/src/operator/nn/log_softmax.cc
+++ b/src/operator/nn/log_softmax.cc
@@ -26,8 +26,8 @@
#include "../tensor/elemwise_binary_op.h"
#include "../operator_common.h"
#if MXNET_USE_ONEDNN == 1
-#include "dnnl/dnnl_base-inl.h"
-#include "dnnl/dnnl_ops-inl.h"
+#include "operator/nn/dnnl/dnnl_base-inl.h"
+#include "operator/nn/dnnl/dnnl_softmax-inl.h"
#endif
namespace mxnet {
diff --git a/src/operator/nn/masked_softmax.cc b/src/operator/nn/masked_softmax.cc
index 386b53e8d9..ee8c902833 100644
--- a/src/operator/nn/masked_softmax.cc
+++ b/src/operator/nn/masked_softmax.cc
@@ -27,7 +27,7 @@
#include "operator/operator_common.h"
#if MXNET_USE_ONEDNN == 1
#include "operator/nn/dnnl/dnnl_base-inl.h"
-#include "operator/nn/dnnl/dnnl_ops-inl.h"
+#include "operator/nn/dnnl/dnnl_softmax-inl.h"
#endif
namespace mxnet {
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index 004e5a2bb4..0fb4e338cb 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -26,8 +26,8 @@
#include "../tensor/elemwise_binary_op.h"
#include "../operator_common.h"
#if MXNET_USE_ONEDNN == 1
-#include "dnnl/dnnl_base-inl.h"
-#include "dnnl/dnnl_ops-inl.h"
+#include "operator/nn/dnnl/dnnl_base-inl.h"
+#include "operator/nn/dnnl/dnnl_softmax-inl.h"
#endif
namespace mxnet {
diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc
index 8b556b5ef3..946527562a 100644
--- a/src/operator/numpy/np_matrix_op.cc
+++ b/src/operator/numpy/np_matrix_op.cc
@@ -27,9 +27,8 @@
#include "./np_matrix_op-inl.h"
#include "../nn/concat-inl.h"
#if MXNET_USE_ONEDNN == 1
-#include "../nn/dnnl/dnnl_ops-inl.h"
-#include "../nn/dnnl/dnnl_base-inl.h"
-#include "../nn/dnnl/dnnl_transpose-inl.h"
+#include "operator/nn/dnnl/dnnl_base-inl.h"
+#include "operator/nn/dnnl/dnnl_transpose-inl.h"
#endif
namespace mxnet {
namespace op {
diff --git a/src/operator/quantization/dnnl/dnnl_quantized_act.cc b/src/operator/quantization/dnnl/dnnl_quantized_act.cc
index 6cf31a0ee2..ec780b91d5 100644
--- a/src/operator/quantization/dnnl/dnnl_quantized_act.cc
+++ b/src/operator/quantization/dnnl/dnnl_quantized_act.cc
@@ -23,7 +23,8 @@
*/
#if MXNET_USE_ONEDNN == 1
-#include "operator/nn/dnnl/dnnl_ops-inl.h"
+#include "operator/nn/activation-inl.h"
+#include "operator/nn/dnnl/dnnl_act-inl.h"
#include "operator/quantization/quantization_utils.h"
namespace mxnet {
diff --git a/src/operator/quantization/dnnl/dnnl_quantized_concat.cc b/src/operator/quantization/dnnl/dnnl_quantized_concat.cc
index a6f9e85427..ee00e6b61d 100644
--- a/src/operator/quantization/dnnl/dnnl_quantized_concat.cc
+++ b/src/operator/quantization/dnnl/dnnl_quantized_concat.cc
@@ -97,7 +97,7 @@ static void DNNLQuantizedConcatForward(const nnvm::NodeAttrs& attrs,
}
int param_dim = param_.dim.has_value() ? param_.dim.value() : 0;
param_dim = CheckAxis(param_dim, in_data[concat_enum::kData0].shape().ndim());
- DNNLConcatFwd& fwd = GetConcatForward(param_dim, in_data, data_md);
+ DNNLConcatFwd& fwd = DNNLConcatFwd::GetCached(param_dim, in_data, data_md);
mxnet::dnnl_output_t out_mem = CreateDNNLMem(
out_data[quantized_concat_enum::kOut], fwd.fwd_pd.dst_desc(), req[concat_enum::kOut]);
dnnl_args_map_t net_args;
diff --git a/src/operator/quantization/dnnl/dnnl_quantized_elemwise_add.cc b/src/operator/quantization/dnnl/dnnl_quantized_elemwise_add.cc
index 518dc0748e..82e66a1ed2 100644
--- a/src/operator/quantization/dnnl/dnnl_quantized_elemwise_add.cc
+++ b/src/operator/quantization/dnnl/dnnl_quantized_elemwise_add.cc
@@ -24,7 +24,6 @@
#if MXNET_USE_ONEDNN == 1
#include "operator/nn/dnnl/dnnl_base-inl.h"
-#include "operator/nn/dnnl/dnnl_ops-inl.h"
#include "operator/quantization/quantization_utils.h"
#include "operator/quantization/quantized_elemwise_add-inl.h"
diff --git a/src/operator/quantization/dnnl/dnnl_quantized_flatten.cc b/src/operator/quantization/dnnl/dnnl_quantized_flatten.cc
index a605a16813..0b14a96777 100644
--- a/src/operator/quantization/dnnl/dnnl_quantized_flatten.cc
+++ b/src/operator/quantization/dnnl/dnnl_quantized_flatten.cc
@@ -23,7 +23,7 @@
*/
#if MXNET_USE_ONEDNN == 1
-#include "operator/nn/dnnl/dnnl_ops-inl.h"
+#include "operator/nn/dnnl/dnnl_reshape-inl.h"
#include "operator/quantization/quantization_utils.h"
namespace mxnet {
diff --git a/src/operator/quantization/dnnl/dnnl_quantized_reshape.cc b/src/operator/quantization/dnnl/dnnl_quantized_reshape.cc
index 0d468fde7a..344bda3f5e 100644
--- a/src/operator/quantization/dnnl/dnnl_quantized_reshape.cc
+++ b/src/operator/quantization/dnnl/dnnl_quantized_reshape.cc
@@ -24,7 +24,7 @@
#if MXNET_USE_ONEDNN == 1
#include "operator/quantization/quantized_reshape-inl.h"
-#include "operator/nn/dnnl/dnnl_ops-inl.h"
+#include "operator/nn/dnnl/dnnl_reshape-inl.h"
namespace mxnet {
namespace op {
diff --git a/src/operator/quantization/quantized_conv.cc b/src/operator/quantization/quantized_conv.cc
index 95fbd3bba2..f42a550573 100644
--- a/src/operator/quantization/quantized_conv.cc
+++ b/src/operator/quantization/quantized_conv.cc
@@ -23,9 +23,6 @@
* \author Ziheng Jiang, Jun Wu
*/
#include "../nn/convolution-inl.h"
-#if MXNET_USE_ONEDNN == 1
-#include "../nn/dnnl/dnnl_ops-inl.h"
-#endif
namespace mxnet {
namespace op {
diff --git a/src/operator/softmax_output.cc b/src/operator/softmax_output.cc
index 401968f2d2..52a9888ea8 100644
--- a/src/operator/softmax_output.cc
+++ b/src/operator/softmax_output.cc
@@ -24,8 +24,8 @@
*/
#include "./softmax_output-inl.h"
#if MXNET_USE_ONEDNN == 1
-#include "./nn/dnnl/dnnl_base-inl.h"
-#include "./nn/dnnl/dnnl_ops-inl.h"
+#include "operator/nn/dnnl/dnnl_base-inl.h"
+#include "operator/nn/dnnl/dnnl_softmax_output-inl.h"
#endif
namespace mxnet {
namespace op {
diff --git a/src/operator/subgraph/dnnl/dnnl_batch_dot.cc b/src/operator/subgraph/dnnl/dnnl_batch_dot.cc
index c9beffc90c..6905b118ba 100644
--- a/src/operator/subgraph/dnnl/dnnl_batch_dot.cc
+++ b/src/operator/subgraph/dnnl/dnnl_batch_dot.cc
@@ -30,7 +30,6 @@
#include "operator/nn/dnnl/dnnl_base-inl.h"
#include "operator/nn/dnnl/dnnl_batch_dot-inl.h"
-#include "operator/nn/dnnl/dnnl_ops-inl.h"
#include "operator/quantization/quantization_utils.h"
#include "operator/tensor/matrix_op-inl.h"
#include "operator/subgraph/common.h"
diff --git a/src/operator/subgraph/dnnl/dnnl_conv.cc b/src/operator/subgraph/dnnl/dnnl_conv.cc
index 262746068f..5b2c9ad3e0 100644
--- a/src/operator/subgraph/dnnl/dnnl_conv.cc
+++ b/src/operator/subgraph/dnnl/dnnl_conv.cc
@@ -25,7 +25,6 @@
#include "operator/nn/dnnl/dnnl_act-inl.h"
#include "operator/nn/dnnl/dnnl_base-inl.h"
-#include "operator/nn/dnnl/dnnl_ops-inl.h"
#include "operator/quantization/quantization_utils.h"
#include "operator/tensor/matrix_op-inl.h"
#include "operator/subgraph/common.h"
diff --git a/src/operator/subgraph/dnnl/dnnl_conv_property.h b/src/operator/subgraph/dnnl/dnnl_conv_property.h
index c5e027d5b1..3d814c0598 100644
--- a/src/operator/subgraph/dnnl/dnnl_conv_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_conv_property.h
@@ -28,7 +28,6 @@
#include "operator/leaky_relu-inl.h"
#include "operator/nn/activation-inl.h"
#include "operator/nn/convolution-inl.h"
-#include "operator/nn/dnnl/dnnl_ops-inl.h"
#include "operator/tensor/matrix_op-inl.h"
#include "operator/subgraph/common.h"
#include "dnnl_subgraph_base-inl.h"
diff --git a/src/operator/subgraph/dnnl/dnnl_fc.cc b/src/operator/subgraph/dnnl/dnnl_fc.cc
index a5c199f964..24b7ec6883 100644
--- a/src/operator/subgraph/dnnl/dnnl_fc.cc
+++ b/src/operator/subgraph/dnnl/dnnl_fc.cc
@@ -35,7 +35,6 @@
#include "operator/nn/dnnl/dnnl_act-inl.h"
#include "operator/nn/dnnl/dnnl_base-inl.h"
#include "operator/nn/dnnl/dnnl_fully_connected-inl.h"
-#include "operator/nn/dnnl/dnnl_ops-inl.h"
#include "operator/quantization/quantization_utils.h"
#include "operator/tensor/matrix_op-inl.h"
#include "operator/subgraph/common.h"
diff --git a/src/operator/tensor/elemwise_binary_op_basic.cc b/src/operator/tensor/elemwise_binary_op_basic.cc
index e1b881c2b7..ad29575cb3 100644
--- a/src/operator/tensor/elemwise_binary_op_basic.cc
+++ b/src/operator/tensor/elemwise_binary_op_basic.cc
@@ -21,8 +21,9 @@
* \file elemwise_binary_op_basic.cc
* \brief CPU Implementation of basic elementwise binary broadcast operators
*/
-#include "../nn/dnnl/dnnl_base-inl.h"
-#include "../nn/dnnl/dnnl_ops-inl.h"
+#include "operator/nn/dnnl/dnnl_base-inl.h"
+#include "operator/nn/dnnl/dnnl_copy-inl.h"
+#include "operator/nn/dnnl/dnnl_sum-inl.h"
#include "./elemwise_binary_op-inl.h"
#include "./elemwise_unary_op.h"
diff --git a/src/operator/tensor/elemwise_sum.cc b/src/operator/tensor/elemwise_sum.cc
index 67842cd25f..8e6e2ad11c 100644
--- a/src/operator/tensor/elemwise_sum.cc
+++ b/src/operator/tensor/elemwise_sum.cc
@@ -25,8 +25,10 @@
#include "../../common/utils.h"
#include "../../ndarray/ndarray_function.h"
-#include "../nn/dnnl/dnnl_base-inl.h"
-#include "../nn/dnnl/dnnl_ops-inl.h"
+#if MXNET_USE_ONEDNN == 1
+#include "operator/nn/dnnl/dnnl_base-inl.h"
+#include "operator/nn/dnnl/dnnl_sum-inl.h"
+#endif // MXNET_USE_ONEDNN == 1
namespace mxnet {
namespace op {
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index 3cc930b0d8..9c317fd3a7 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -23,8 +23,8 @@
*/
#include <mxnet/base.h>
-#include "../nn/dnnl/dnnl_ops-inl.h"
-#include "./elemwise_binary_op-inl.h"
+#include "elemwise_binary_op-inl.h"
+#include "operator/nn/dnnl/dnnl_copy-inl.h"
#include "elemwise_unary_op.h"
namespace mxnet {
diff --git a/src/operator/tensor/elemwise_unary_op_logexp.cc b/src/operator/tensor/elemwise_unary_op_logexp.cc
index 65bc767312..e94862d7dc 100644
--- a/src/operator/tensor/elemwise_unary_op_logexp.cc
+++ b/src/operator/tensor/elemwise_unary_op_logexp.cc
@@ -24,7 +24,6 @@
#include <mxnet/base.h>
#include "../../nnvm/node_op_util.h"
-#include "../nn/dnnl/dnnl_ops-inl.h"
#include "./elemwise_binary_op-inl.h"
#include "elemwise_unary_op.h"
diff --git a/src/operator/tensor/elemwise_unary_op_pow.cc b/src/operator/tensor/elemwise_unary_op_pow.cc
index b4e35c4c26..51f65553ce 100644
--- a/src/operator/tensor/elemwise_unary_op_pow.cc
+++ b/src/operator/tensor/elemwise_unary_op_pow.cc
@@ -24,7 +24,6 @@
#include <mxnet/base.h>
#include "../../nnvm/node_op_util.h"
-#include "../nn/dnnl/dnnl_ops-inl.h"
#include "./elemwise_binary_op-inl.h"
#include "elemwise_unary_op.h"
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index bc97aa4a53..15f131cfcd 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -25,12 +25,12 @@
#include "./matrix_op-inl.h"
#include "./elemwise_unary_op.h"
#if MXNET_USE_ONEDNN == 1
-#include "../nn/dnnl/dnnl_base-inl.h"
-#include "../nn/dnnl/dnnl_ops-inl.h"
-#include "../nn/dnnl/dnnl_reshape-inl.h"
-#include "../nn/dnnl/dnnl_slice-inl.h"
-#include "../nn/dnnl/dnnl_transpose-inl.h"
-#include "../nn/dnnl/dnnl_split-inl.h"
+#include "operator/nn/dnnl/dnnl_base-inl.h"
+#include "operator/nn/dnnl/dnnl_reshape-inl.h"
+#include "operator/nn/dnnl/dnnl_slice-inl.h"
+#include "operator/nn/dnnl/dnnl_transpose-inl.h"
+#include "operator/nn/dnnl/dnnl_split-inl.h"
+#include "operator/nn/dnnl/dnnl_stack-inl.h"
#endif
namespace mxnet {
diff --git a/tests/cpp/operator/dnnl_operator_test.cc b/tests/cpp/operator/dnnl_operator_test.cc
index e66fc56bab..dd02656897 100644
--- a/tests/cpp/operator/dnnl_operator_test.cc
+++ b/tests/cpp/operator/dnnl_operator_test.cc
@@ -34,7 +34,6 @@
#include "../../src/operator/nn/convolution-inl.h"
#include "../../src/operator/nn/deconvolution-inl.h"
#include "../../src/operator/nn/dnnl/dnnl_base-inl.h"
-#include "../../src/operator/nn/dnnl/dnnl_ops-inl.h"
#include "../../src/operator/nn/dnnl/dnnl_pooling-inl.h"
#include "../../src/operator/nn/pooling-inl.h"
#include "../include/test_dnnl.h"
diff --git a/tests/cpp/operator/dnnl_test.cc b/tests/cpp/operator/dnnl_test.cc
index a420c2a7b9..565b551479 100644
--- a/tests/cpp/operator/dnnl_test.cc
+++ b/tests/cpp/operator/dnnl_test.cc
@@ -32,7 +32,6 @@
#include <set>
#include "../../src/operator/nn/dnnl/dnnl_base-inl.h"
-#include "../../src/operator/nn/dnnl/dnnl_ops-inl.h"
#include "../include/test_dnnl.h"
#include "gtest/gtest.h"
#include "mxnet/imperative.h"
@@ -189,7 +188,7 @@ TEST(DNNL_NDArray, GetDataReorder) {
}
}
-TEST(DNNL_BASE, DNNLSum) {
+TEST(DNNL_BASE, DNNLMemorySum) {
std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays();
std::vector<NDArrayAttrs> in_arrs2 = GetTestInputArrays(ArrayTypes::All, true);
TestArrayShapes tas = GetTestArrayShapes();
@@ -211,7 +210,7 @@ TEST(DNNL_BASE, DNNLSum) {
continue;
auto out_mem = out_arr.arr.GetDNNLData();
PrintVerifyMsg(in_arr, in_arr);
- op::DNNLSum(*in_mem1, *in_mem2, *out_mem);
+ op::DNNLMemorySum(*in_mem1, *in_mem2, *out_mem);
DNNLStream::Get()->Submit();
VerifySumResult({&in_arr.arr, &in_arr2.arr}, {&out_arr.arr});
}
@@ -233,7 +232,7 @@ TEST(DNNL_BASE, DNNLSum) {
PrintVerifyMsg(orig_arr, in_arr);
InitDNNLArray(&orig_arr.arr, input_mem->get_desc());
orig_arr.arr.CopyFrom(*input_mem);
- op::DNNLSum(*input_mem, *input_mem2, *input_mem);
+ op::DNNLMemorySum(*input_mem, *input_mem2, *input_mem);
DNNLStream::Get()->Submit();
VerifySumResult({&orig_arr.arr, &in_arr2.arr}, {&in_arr.arr});
}
@@ -264,7 +263,7 @@ TEST(DNNL_BASE, CreateDNNLMem) {
PrintVerifyMsg(in_arr, out_arr);
auto out_mem = out_arr.arr.GetDNNLData();
auto output_mem_t = CreateDNNLMem(out_arr.arr, out_mem->get_desc(), kWriteTo);
- op::DNNLSum(*in_mem, *in_mem2, *output_mem_t.second);
+ op::DNNLMemorySum(*in_mem, *in_mem2, *output_mem_t.second);
CommitOutput(out_arr.arr, output_mem_t);
stream->Submit();
VerifySumResult({&in_arr.arr, &in_arr2.arr}, {&out_arr.arr});
@@ -289,7 +288,7 @@ TEST(DNNL_BASE, CreateDNNLMem) {
orig_arr.arr.CopyFrom(*input_mem);
auto output_mem_t =
CreateDNNLMem(in_arr.arr, input_mem->get_desc(), kWriteInplace, &in_arr.arr);
- op::DNNLSum(*input_mem, *input_mem2, *output_mem_t.second);
+ op::DNNLMemorySum(*input_mem, *input_mem2, *output_mem_t.second);
CommitOutput(in_arr.arr, output_mem_t);
stream->Submit();
VerifySumResult({&orig_arr.arr, &in_arr2.arr}, {&in_arr.arr});
@@ -313,7 +312,7 @@ TEST(DNNL_BASE, CreateDNNLMem) {
PrintVerifyMsg(in_arr, out_arr);
auto out_mem = out_arr.arr.GetDNNLData();
auto output_mem_t = CreateDNNLMem(out_arr.arr, out_mem->get_desc(), kAddTo);
- op::DNNLSum(*in_mem, *in_mem2, *output_mem_t.second);
+ op::DNNLMemorySum(*in_mem, *in_mem2, *output_mem_t.second);
CommitOutput(out_arr.arr, output_mem_t);
stream->Submit();
VerifyAddRequest(
@@ -338,7 +337,7 @@ TEST(DNNL_BASE, CreateDNNLMem) {
InitDNNLArray(&orig_arr.arr, input_mem->get_desc());
orig_arr.arr.CopyFrom(*input_mem);
auto output_mem_t = CreateDNNLMem(in_arr.arr, input_mem->get_desc(), kNullOp);
- op::DNNLSum(*input_mem, *input_mem2, *output_mem_t.second);
+ op::DNNLMemorySum(*input_mem, *input_mem2, *output_mem_t.second);
CommitOutput(in_arr.arr, output_mem_t);
stream->Submit();
// original and input should be the same since noop