You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ak...@apache.org on 2021/09/24 09:45:56 UTC
[incubator-mxnet] branch v1.x updated: Decouple OneDNN data
structures in MXNet C++ API (#20349)
This is an automated email from the ASF dual-hosted git repository.
akarbown pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new 868d06b Decouple OneDNN data structures in MXNet C++ API (#20349)
868d06b is described below
commit 868d06b236ee1005d40b43b7529171ce432605f8
Author: mozga <ma...@intel.com>
AuthorDate: Fri Sep 24 11:43:53 2021 +0200
Decouple OneDNN data structures in MXNet C++ API (#20349)
* NDArray file has been modified, there are a few changes:
1. OneDNN header was moved into *cc file
2. OneDNN objects are created on-the-fly: static_cast is needed
* Sanity check
* Adaptive pooling: fix
---
include/mxnet/ndarray.h | 24 +-
src/ndarray/ndarray.cc | 112 ++++-----
src/operator/nn/concat.cc | 259 +++++++++++----------
src/operator/nn/mkldnn/mkldnn_act.cc | 25 +-
src/operator/nn/mkldnn/mkldnn_adaptive_pooling.cc | 9 +-
src/operator/nn/mkldnn/mkldnn_base-inl.h | 9 +-
src/operator/nn/mkldnn/mkldnn_base.cc | 95 ++++----
src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h | 50 ++--
src/operator/nn/mkldnn/mkldnn_concat.cc | 6 +-
src/operator/nn/mkldnn/mkldnn_convolution.cc | 41 ++--
src/operator/nn/mkldnn/mkldnn_copy.cc | 15 +-
src/operator/nn/mkldnn/mkldnn_deconvolution-inl.h | 59 ++---
src/operator/nn/mkldnn/mkldnn_deconvolution.cc | 12 +-
src/operator/nn/mkldnn/mkldnn_fully_connected.cc | 34 ++-
src/operator/nn/mkldnn/mkldnn_log_softmax.cc | 22 +-
src/operator/nn/mkldnn/mkldnn_lrn-inl.h | 25 +-
src/operator/nn/mkldnn/mkldnn_pooling.cc | 22 +-
src/operator/nn/mkldnn/mkldnn_reshape.cc | 16 +-
src/operator/nn/mkldnn/mkldnn_slice.cc | 8 +-
src/operator/nn/mkldnn/mkldnn_softmax.cc | 22 +-
src/operator/nn/mkldnn/mkldnn_softmax_output.cc | 4 +-
src/operator/nn/mkldnn/mkldnn_sum.cc | 2 +-
src/operator/nn/mkldnn/mkldnn_transpose.cc | 5 +-
src/operator/operator_common.h | 6 +-
.../quantization/mkldnn/mkldnn_dequantize-inl.h | 2 +-
.../quantization/mkldnn/mkldnn_quantize-inl.h | 2 +-
.../quantization/mkldnn/mkldnn_quantize_v2-inl.h | 5 +-
.../mkldnn/mkldnn_quantized_batch_norm.cc | 12 +-
.../quantization/mkldnn/mkldnn_quantized_concat.cc | 4 +-
.../quantization/mkldnn/mkldnn_quantized_conv.cc | 16 +-
.../mkldnn/mkldnn_quantized_elemwise_add.cc | 6 +-
.../mkldnn/mkldnn_quantized_fully_connected.cc | 13 +-
.../quantization/mkldnn/mkldnn_requantize-inl.h | 2 +-
src/operator/subgraph/mkldnn/mkldnn_common.h | 12 +-
src/operator/subgraph/mkldnn/mkldnn_conv.cc | 37 +--
src/operator/subgraph/mkldnn/mkldnn_fc.cc | 53 ++---
src/operator/tensor/amp_cast.cc | 7 +-
src/operator/tensor/cast_storage-inl.h | 4 +-
tests/cpp/include/test_mkldnn.h | 6 +-
tests/cpp/operator/mkldnn_operator_test.cc | 3 +-
tests/cpp/operator/mkldnn_test.cc | 169 +++++++-------
41 files changed, 665 insertions(+), 570 deletions(-)
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index 0febc65..92aefd6 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -37,9 +37,7 @@
#include <memory>
#include <string>
#include <vector>
-#if MXNET_USE_MKLDNN == 1
-#include <mkldnn.hpp>
-#endif
+
#include "./base.h"
#include "./engine.h"
#include "./storage.h"
@@ -711,12 +709,12 @@ class NDArray {
* Create NDArray from mkldnn memory.
* mkldnn_mem The mkldnn memory to be managed.
*/
- explicit NDArray(const std::shared_ptr<mkldnn::memory>& mkldnn_mem);
+ explicit NDArray(const std::shared_ptr<void>& mkldnn_mem);
/*
* Create NDArray from mkldnn memory descriptor.
* mem_pd The mkldnn memory descriptor to be created.
*/
- explicit NDArray(const mkldnn::memory::desc& md);
+ explicit NDArray(const void* md);
/*
* Test if the data is stored in one of special MKLDNN format.
*/
@@ -739,29 +737,29 @@ class NDArray {
/*
* This function returns mkldnn::memory with the default primitive_desc.
*/
- const mkldnn::memory* GetMKLDNNData() const;
+ const void* GetMKLDNNData() const;
/*
* This function returns mkldnn::memory with the given primitive_desc
* as long as the array size meets the required size in the given
* primitive_desc.
*/
- const mkldnn::memory* GetMKLDNNData(const mkldnn::memory::desc& md) const;
+ const void* GetMKLDNNData(const void* md) const;
/*
* This function returns mkldnn::memory with the given primitive_desc.
* The returned mkldnn::memory will have the same physical layout as
* the given primitive_desc.
*/
- const mkldnn::memory* GetMKLDNNDataReorder(const mkldnn::memory::desc& md) const;
+ const void* GetMKLDNNDataReorder(const void* md) const;
/*
* This function copies data from mkldnn memory.
*/
- void CopyFrom(const mkldnn::memory& mem);
+ void CopyFrom(const void* mem);
/*
* This function allocates memory for array and creates mkldnn memory
* with the specified format.
*/
- mkldnn::memory* CreateMKLDNNData(const mkldnn::memory::desc& md);
+ void* CreateMKLDNNData(const void* md);
/*
* These are the async version of the methods above.
@@ -769,7 +767,7 @@ class NDArray {
* the array are complete.
*/
void Reorder2DefaultAsync() const;
- void MKLDNNDataReorderAsync(const mkldnn::memory::desc& md) const;
+ void MKLDNNDataReorderAsync(const void* md) const;
/*
* This creates a new NDArray with the reordered data.
@@ -800,7 +798,7 @@ class NDArray {
/*!
* \ Fix mkldnn memory descriptor mismatch from NDArray.
*/
- void UpdateMKLDNNMemDesc(const mkldnn::memory::desc& desc);
+ void UpdateMKLDNNMemDesc(const void* desc);
#endif
/*!
@@ -1087,7 +1085,7 @@ class NDArray {
// save the result in shandle.
void Reorder2Default();
// Reroder data to a specified layout.
- void MKLDNNDataReorder(const mkldnn::memory::desc& md);
+ void MKLDNNDataReorder(const void* md);
bool IsMKLDNN() const;
bool IsDefault() const;
#endif
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index c7188e3..46262ed 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -37,7 +37,9 @@
#include "../operator/tensor/init_op.h"
#include "../operator/tensor/matrix_op-inl.h"
#include "./ndarray_function.h"
-
+#if MXNET_USE_MKLDNN == 1
+#include <mkldnn.hpp>
+#endif
#if MXNET_USE_OPENCV
#include <opencv2/opencv.hpp>
#endif // MXNET_USE_OPENCV
@@ -147,8 +149,7 @@ void NDArray::Chunk::CheckAndAllocData(const mxnet::TShape& shape, int dtype) {
if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
CHECK_LT(shape.Size(), (int64_t{1} << 31) - 1)
<< "[CheckAndAllocData] Size of tensor you are trying to allocate is "
- "larger than "
- "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
+ "larger than 2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
}
if (shandle.size < dbytes) {
// free storage
@@ -187,16 +188,19 @@ nnvm::Symbol NDArray::get_autograd_symbol() const {
#if MXNET_USE_MKLDNN == 1
-NDArray::NDArray(const mkldnn::memory::desc& md) : storage_type_(kDefaultStorage), entry_(nullptr) {
- shape_ = mxnet::TShape(md.data.dims, md.data.dims + md.data.ndims);
- dtype_ = get_mxnet_type(md.data.data_type);
- ptr_ = std::make_shared<Chunk>(shape_, Context::CPU(), true, dtype_);
+NDArray::NDArray(const void* md_desc) : storage_type_(kDefaultStorage), entry_(nullptr) {
+ mkldnn::memory::desc md = *static_cast<const mkldnn::memory::desc*>(md_desc);
+ shape_ = mxnet::TShape(md.data.dims, md.data.dims + md.data.ndims);
+ dtype_ = get_mxnet_type(md.data.data_type);
+ ptr_ = std::make_shared<Chunk>(shape_, Context::CPU(), true, dtype_);
ptr_->CheckAndAlloc(md.get_size());
ptr_->mkl_mem_ = std::make_shared<MKLDNNMemory>(md, ptr_->shandle.dptr);
}
-NDArray::NDArray(const std::shared_ptr<mkldnn::memory>& mkldnn_mem)
+NDArray::NDArray(const std::shared_ptr<void>& mkldnn_mem_ptr)
: storage_type_(kDefaultStorage), entry_(nullptr) {
+ std::shared_ptr<mkldnn::memory> mkldnn_mem =
+ std::static_pointer_cast<mkldnn::memory>(mkldnn_mem_ptr);
auto mem_desc = mkldnn_mem->get_desc();
shape_ = mxnet::TShape(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims);
dtype_ = get_mxnet_type(mem_desc.data.data_type);
@@ -399,8 +403,7 @@ bool NDArray::fresh_out_grad() const {
void NDArray::set_fresh_out_grad(bool state) const {
CHECK(!Imperative::AGInfo::IsNone(*this))
- << "NDArray has not been marked as a variable and does not have gradient "
- "state";
+ << "NDArray has not been marked as a variable and does not have gradient state";
Imperative::AGInfo& info = Imperative::AGInfo::Get(entry_.node);
info.fresh_out_grad = state;
}
@@ -444,7 +447,8 @@ void NDArray::Chunk::Reorder2Default() {
mkl_mem_ = nullptr;
}
-void NDArray::Chunk::MKLDNNDataReorder(const mkldnn::memory::desc& md) {
+void NDArray::Chunk::MKLDNNDataReorder(const void* mem_desc) {
+ const mkldnn::memory::desc md = *static_cast<const mkldnn::memory::desc*>(mem_desc);
// If the memory already uses the specified layout, don't do anything.
if (mkl_mem_ != nullptr && mkl_mem_->SameFormat(md))
return;
@@ -530,12 +534,13 @@ void NDArray::Chunk::SetMKLMem(const mxnet::TShape& shape, int dtype) {
mkl_mem_.reset(new MKLDNNMemory(data_md, shandle.dptr));
}
-const mkldnn::memory* NDArray::GetMKLDNNData(const mkldnn::memory::desc& desc) const {
+const void* NDArray::GetMKLDNNData(const void* mem_desc) const {
+ const mkldnn::memory::desc desc = *static_cast<const mkldnn::memory::desc*>(mem_desc);
if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) {
LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc";
return nullptr;
}
- const mkldnn::memory* mem = GetMKLDNNData();
+ const mkldnn::memory* mem = static_cast<const mkldnn::memory*>(GetMKLDNNData());
mkldnn::memory::desc desc1 = mem->get_desc();
// The MKL memory has the same format and shape as required,
// or both use the default format, we can return the MKL memory.
@@ -546,10 +551,11 @@ const mkldnn::memory* NDArray::GetMKLDNNData(const mkldnn::memory::desc& desc) c
}
}
-const mkldnn::memory* NDArray::GetMKLDNNDataReorder(const mkldnn::memory::desc& new_desc) const {
+const void* NDArray::GetMKLDNNDataReorder(const void* mem_desc) const {
+ mkldnn::memory::desc new_desc = *static_cast<const mkldnn::memory::desc*>(mem_desc);
CHECK(storage_type() == kDefaultStorage);
- const mkldnn::memory* mem = GetMKLDNNData();
+ const mkldnn::memory* mem = static_cast<const mkldnn::memory*>(GetMKLDNNData());
// If the memory descriptor matches, it's easy.
MKLDNNStream* stream = MKLDNNStream::Get();
if (mem->get_desc() == new_desc) {
@@ -578,7 +584,7 @@ const mkldnn::memory* NDArray::GetMKLDNNDataReorder(const mkldnn::memory::desc&
for (int i = 0; i < new_desc.data.ndims; i++)
required_shape[i] = new_desc.data.dims[i];
NDArray reshaped = MKLDNNDataReshape(required_shape);
- const mkldnn::memory* ret = reshaped.GetMKLDNNData();
+ const mkldnn::memory* ret = static_cast<const mkldnn::memory*>(reshaped.GetMKLDNNData());
if (ret->get_desc() == new_desc) {
return GetMKLDNNExact(ret, new_desc);
} else {
@@ -641,24 +647,25 @@ NDArray NDArray::Reorder2DefaultFloatFormat() const {
return Reorder2Default();
}
NDArray ret(shape(), ctx(), false, mshadow::DataType<float>::kFlag);
- auto src_mem = GetMKLDNNData();
- auto dst_mem = ret.GetMKLDNNData();
+ auto src_mem = static_cast<const mkldnn::memory*>(GetMKLDNNData());
+ auto dst_mem = static_cast<const mkldnn::memory*>(ret.GetMKLDNNData());
ReorderTo(src_mem, dst_mem);
return ret;
}
-void NDArray::MKLDNNDataReorderAsync(const mkldnn::memory::desc& desc) const {
+void NDArray::MKLDNNDataReorderAsync(const void* mem_desc) const {
+ mkldnn::memory::desc desc = *static_cast<const mkldnn::memory::desc*>(mem_desc);
std::vector<Engine::VarHandle> const_vars;
std::vector<Engine::VarHandle> mutable_vars(1, this->var());
NDArray tmp = *this;
const auto version = this->version();
Engine::Get()->PushAsync(
[tmp, version, desc](RunContext ctx, Engine::CallbackOnComplete on_complete) {
- // MXNet will try to reuse NDArray from memory planning, so we need to
- // ensure the NDArray is still holding the original trunk data.
+ // MXNet will try to reuse NDArray from memory planning, so we need to ensure
+ // the NDArray is still holding the original trunk data.
if (tmp.version() == version) {
- tmp.ptr_->MKLDNNDataReorder(desc);
+ tmp.ptr_->MKLDNNDataReorder(&desc);
}
on_complete();
},
@@ -670,7 +677,7 @@ void NDArray::MKLDNNDataReorderAsync(const mkldnn::memory::desc& desc) const {
"Reorder");
}
-const mkldnn::memory* NDArray::GetMKLDNNData() const {
+const void* NDArray::GetMKLDNNData() const {
CHECK(storage_type() == kDefaultStorage);
bool is_view = IsView();
if (IsMKLDNNData()) {
@@ -715,7 +722,8 @@ void NDArray::InvalidateMKLDNNData() {
ptr_->mkl_mem_ = nullptr;
}
-void NDArray::CopyFrom(const mkldnn::memory& mem) {
+void NDArray::CopyFrom(const void* memory) {
+ auto mem = *static_cast<const mkldnn::memory*>(memory);
CHECK(ptr_ != nullptr) << "The NDArray hasn't been initialized";
if (ptr_->mkl_mem_ && ptr_->mkl_mem_->GetRaw() == &mem)
return;
@@ -728,14 +736,14 @@ void NDArray::CopyFrom(const mkldnn::memory& mem) {
if (IsMKLDNNData() && IsView())
ptr_->Reorder2Default();
- const mkldnn::memory* this_mem = GetMKLDNNData();
+ const mkldnn::memory* this_mem = static_cast<const mkldnn::memory*>(GetMKLDNNData());
MKLDNNMemoryCopy(mem, this_mem);
}
-mkldnn::memory* NDArray::CreateMKLDNNData(const mkldnn::memory::desc& desc) {
+void* NDArray::CreateMKLDNNData(const void* mem_desc) {
+ mkldnn::memory::desc desc = *static_cast<const mkldnn::memory::desc*>(mem_desc);
if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) {
- LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN "
- "memory desc. "
+ LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc. "
<< "MKLDNN memory requests for " << desc.get_size() << " bytes, but got "
<< shape().Size() * GetTypeSize(dtype_) << " bytes from NDArray";
return nullptr;
@@ -779,10 +787,11 @@ mkldnn::memory* NDArray::CreateMKLDNNData(const mkldnn::memory::desc& desc) {
return ptr_->mkl_mem_->GetRaw();
}
-void NDArray::UpdateMKLDNNMemDesc(const mkldnn::memory::desc& desc) {
- auto new_desc = desc;
- auto this_dtype = get_mkldnn_type(dtype());
- new_desc.data.data_type = static_cast<mkldnn_data_type_t>(this_dtype);
+void NDArray::UpdateMKLDNNMemDesc(const void* mem_desc) {
+ mkldnn::memory::desc desc = *static_cast<const mkldnn::memory::desc*>(mem_desc);
+ auto new_desc = desc;
+ auto this_dtype = get_mkldnn_type(dtype());
+ new_desc.data.data_type = static_cast<mkldnn_data_type_t>(this_dtype);
ptr_->mkl_mem_.reset(new MKLDNNMemory(new_desc, ptr_->shandle.dptr));
MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
}
@@ -1203,15 +1212,15 @@ inline void CopyFromToDnsImpl(const NDArray& from, const NDArray& to, RunContext
#if MXNET_USE_MKLDNN == 1
} else if (SupportMKLDNN(from.dtype(), from.shape()) && SupportMKLDNN(to.dtype(), to.shape()) &&
from.ctx().dev_mask() == cpu::kDevMask && to.ctx().dev_mask() == cpu::kDevMask) {
- // If we copy data directly, we need to make sure both NDArrays are
- // supported by MKLDNN.
- auto from_mem = from.GetMKLDNNData();
- auto to_mem = to.GetMKLDNNData();
+ // If we copy data directly, we need to make sure both NDArrays are supported
+ // by MKLDNN.
+ auto from_mem = static_cast<const mkldnn::memory*>(from.GetMKLDNNData());
+ auto to_mem = static_cast<const mkldnn::memory*>(to.GetMKLDNNData());
if (from_mem->get_desc() == to_mem->get_desc()) {
size_t size = std::min(from_mem->get_desc().get_size(), to_mem->get_desc().get_size());
memcpy(to_mem->get_data_handle(), from_mem->get_data_handle(), size);
} else {
- const_cast<NDArray&>(to).CopyFrom(*from_mem);
+ const_cast<NDArray&>(to).CopyFrom(from_mem);
MKLDNNStream::Get()->Submit();
}
} else {
@@ -1222,8 +1231,8 @@ inline void CopyFromToDnsImpl(const NDArray& from, const NDArray& to, RunContext
if (tmp_from.IsMKLDNNData()) {
// TODO(zhengda) tmp_from should be cached.
tmp_from = NDArray(from.shape(), from.ctx(), false, from.dtype());
- auto tmp_mem = from.GetMKLDNNData();
- tmp_from.CopyFrom(*tmp_mem);
+ auto tmp_mem = static_cast<const mkldnn::memory*>(from.GetMKLDNNData());
+ tmp_from.CopyFrom(tmp_mem);
MKLDNNStream::Get()->Submit();
}
CHECK(tmp_from.IsDefaultData());
@@ -2053,8 +2062,7 @@ void NDArray::SyncCopyFromCPU(const void* data, size_t size) const {
mxnet::TShape dshape = this->shape();
if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
CHECK_LT(size, (int64_t{1} << 31) - 1)
- << "[SyncCopyFromCPU] Size of tensor you are trying to allocate is "
- "larger than "
+ << "[SyncCopyFromCPU] Size of tensor you are trying to allocate is larger than "
"2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
}
CHECK_EQ(dshape.Size(), size) << "Memory size do not match";
@@ -2217,8 +2225,7 @@ void NDArray::SyncCopyToCPU(void* data, size_t size) const {
mxnet::TShape dshape = this->shape();
if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
CHECK_LT(size, (int64_t{1} << 31) - 1)
- << "[SyncCopyToCPU] Size of tensor you are trying to allocate is "
- "larger than "
+ << "[SyncCopyToCPU] Size of tensor you are trying to allocate is larger than "
"2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
}
CHECK_EQ(dshape.Size(), size) << "Memory size do not match";
@@ -2290,12 +2297,12 @@ void NDArray::SyncCheckFormat(const bool full_check) const {
}
this->WaitToWrite();
CHECK_NE(err, kCSRShapeErr) << "Shape mismatch of this csr NDArray";
- CHECK_NE(err, kCSRIndPtrErr) << "IndPtr of csr NDArray should be non-negative, in non-decreasing "
- "order, "
- << "start with 0, and end with value equal with size of indices.";
- CHECK_NE(err, kCSRIdxErr) << "Indices of csr NDArray should be non-negative, "
- "in ascending order per row "
- << " and less than the number of columns.";
+ CHECK_NE(err, kCSRIndPtrErr)
+ << "IndPtr of csr NDArray should be non-negative, in non-decreasing order, "
+ << "start with 0, and end with value equal with size of indices.";
+ CHECK_NE(err, kCSRIdxErr)
+ << "Indices of csr NDArray should be non-negative, in ascending order per row "
+ << " and less than the number of columns.";
CHECK_NE(err, kRSPShapeErr) << "Shape mismatch of this row_sparse NDArray";
CHECK_NE(err, kRSPIdxErr) << "Indices of row_sparse NDArray should be non-negative, "
<< "less than the size of first dimension and in ascending order";
@@ -2313,8 +2320,7 @@ MXNET_REGISTER_NDARRAY_FUN(fill_element_0index)
.set_function(TernaryOp<ndarray::MatFillRowElem>)
.describe(
"Fill one element of each line(row for python, column for R/Julia)"
- " in lhs according to index indicated by rhs and values indicated by "
- "mhs."
+ " in lhs according to index indicated by rhs and values indicated by mhs."
" This function assume rhs uses 0-based index.");
// register API function
@@ -2464,9 +2470,7 @@ MXNET_REGISTER_NDARRAY_FUN(_imdecode)
.set_num_use_vars(1)
.set_num_scalars(7)
.set_num_mutate_vars(1)
- .describe(
- "Decode an image, clip to (x0, y0, x1, y1), subtract mean, and write "
- "to buffer")
+ .describe("Decode an image, clip to (x0, y0, x1, y1), subtract mean, and write to buffer")
.add_argument("mean", "NDArray-or-Symbol", "image mean")
.add_argument("index", "int", "buffer position for output")
.add_argument("x0", "int", "x0")
diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc
index 5a7ece1..de02bd6 100644
--- a/src/operator/nn/concat.cc
+++ b/src/operator/nn/concat.cc
@@ -22,30 +22,30 @@
* \file concat.cc
* \brief
* \author Bing Xu
-*/
+ */
+#include "../../common/utils.h"
#include "./concat-inl.h"
-#include "./mkldnn/mkldnn_ops-inl.h"
#include "./mkldnn/mkldnn_base-inl.h"
-#include "../../common/utils.h"
+#include "./mkldnn/mkldnn_ops-inl.h"
namespace mxnet {
namespace op {
bool ConcatShape(const nnvm::NodeAttrs& attrs,
- mxnet::ShapeVector *in_shape,
- mxnet::ShapeVector *out_shape) {
+ mxnet::ShapeVector* in_shape,
+ mxnet::ShapeVector* out_shape) {
using namespace mshadow;
const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
mxnet::TShape dshape;
- dim_t size = 0;
+ dim_t size = 0;
bool has_unknown_dim_size = false;
- int axis = -1;
+ int axis = -1;
for (int i = 0; i < param_.num_args; ++i) {
mxnet::TShape tmp = (*in_shape)[i];
if (tmp.ndim() > 0) {
- axis = CheckAxis(param_.dim, tmp.ndim());
+ axis = CheckAxis(param_.dim, tmp.ndim());
has_unknown_dim_size = !mxnet::dim_size_is_known(tmp, axis) || has_unknown_dim_size;
size += tmp[axis];
tmp[axis] = -1;
@@ -55,12 +55,13 @@ bool ConcatShape(const nnvm::NodeAttrs& attrs,
mxnet::TShape tmp = (*out_shape)[0];
if (tmp.ndim() > 0) {
- axis = CheckAxis(param_.dim, tmp.ndim());
+ axis = CheckAxis(param_.dim, tmp.ndim());
tmp[axis] = -1;
shape_assign(&dshape, tmp);
}
- if (dshape.ndim() == -1) return false;
+ if (dshape.ndim() == -1)
+ return false;
CHECK_NE(dshape.ndim(), 0) << "zero-dimensional arrays cannot be concatenated";
for (int i = 0; i < param_.num_args; ++i) {
@@ -68,7 +69,8 @@ bool ConcatShape(const nnvm::NodeAttrs& attrs,
<< "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i];
}
- if (!has_unknown_dim_size) dshape[axis] = size;
+ if (!has_unknown_dim_size)
+ dshape[axis] = size;
CHECK(shape_assign(&(*out_shape)[0], dshape))
<< "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0];
@@ -80,8 +82,8 @@ bool ConcatShape(const nnvm::NodeAttrs& attrs,
// The first (and sometimes the second) input may be unknown on the target axis.
// If the two inputs are unknown, they always have the same shape.
static bool RNNParamConcatShape(const nnvm::NodeAttrs& attrs,
- mxnet::ShapeVector *in_shape,
- mxnet::ShapeVector *out_shape) {
+ mxnet::ShapeVector* in_shape,
+ mxnet::ShapeVector* out_shape) {
using namespace mshadow;
const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
@@ -106,31 +108,32 @@ static bool RNNParamConcatShape(const nnvm::NodeAttrs& attrs,
mxnet::TShape tmp = (*out_shape)[0];
if (tmp.ndim() > 0) {
- axis = CheckAxis(param_.dim, tmp.ndim());
+ axis = CheckAxis(param_.dim, tmp.ndim());
tmp[axis] = -1;
shape_assign(&dshape, tmp);
}
- if (!mxnet::ndim_is_known(dshape)) return false;
+ if (!mxnet::ndim_is_known(dshape))
+ return false;
for (int i = 0; i < param_.num_args; ++i) {
CHECK(shape_assign(&(*in_shape)[i], dshape))
<< "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i];
}
- if (zero_indices.empty()) dshape[axis] = size;
+ if (zero_indices.empty())
+ dshape[axis] = size;
CHECK(shape_assign(&(*out_shape)[0], dshape))
<< "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0];
if ((*out_shape)[0][axis] != -1 && !zero_indices.empty()) {
int residual = (*out_shape)[0][axis] - size;
- CHECK_GE(residual, 0)
- << "Input size already exceeds output size. Residual: " << residual;
+ CHECK_GE(residual, 0) << "Input size already exceeds output size. Residual: " << residual;
CHECK(zero_indices.size() <= 2 && zero_indices.size() > 0)
<< "Expecting 1 or 2 inputs that need shape inference. Got: " << zero_indices.size();
bool need_infer = !shape_is_known((*out_shape)[0]);
for (int i : zero_indices) {
(*in_shape)[i][axis] = residual / zero_indices.size();
- need_infer = need_infer || !shape_is_known((*in_shape)[i]);
+ need_infer = need_infer || !shape_is_known((*in_shape)[i]);
}
return !need_infer;
}
@@ -139,21 +142,20 @@ static bool RNNParamConcatShape(const nnvm::NodeAttrs& attrs,
}
bool ConcatType(const nnvm::NodeAttrs& attrs,
- std::vector<int> *in_type,
- std::vector<int> *out_type) {
+ std::vector<int>* in_type,
+ std::vector<int>* out_type) {
const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
- int dtype = -1;
+ int dtype = -1;
// checks uniformity of input
- for (size_t i =0; i < in_type->size(); ++i) {
+ for (size_t i = 0; i < in_type->size(); ++i) {
if (dtype == -1) {
dtype = in_type->at(i);
} else {
CHECK(in_type->at(i) == dtype || in_type->at(i) == -1)
- << "Non-uniform data type in " << attrs.op->name
- << ", expected data type " << mxnet::op::type_string(dtype)
- << ", got data type " << mxnet::op::type_string(in_type->at(i))
- << " for input " << i;
+ << "Non-uniform data type in " << attrs.op->name << ", expected data type "
+ << mxnet::op::type_string(dtype) << ", got data type "
+ << mxnet::op::type_string(in_type->at(i)) << " for input " << i;
}
}
@@ -166,18 +168,17 @@ bool ConcatType(const nnvm::NodeAttrs& attrs,
for (size_t i = 0; i < nin; ++i) {
in_type->push_back(dtype);
}
- // if out types are known in types are unknown
+ // if out types are known in types are unknown
} else if ((*out_type)[0] != -1 && dtype == -1) {
in_type->clear();
for (size_t i = 0; i < nin; ++i) {
in_type->push_back((*out_type)[0]);
}
- // if both out_types and in_types are known, and different
+ // if both out_types and in_types are known, and different
} else if ((*out_type)[0] != -1 && dtype != -1 && ((*out_type)[0] != dtype)) {
std::ostringstream os;
- os << "Type inconsistent, Provided output type = "
- << mxnet::op::type_string((*out_type)[0]) << ','
- << " inferred type = " << mxnet::op::type_string(dtype);
+ os << "Type inconsistent, Provided output type = " << mxnet::op::type_string((*out_type)[0])
+ << ',' << " inferred type = " << mxnet::op::type_string(dtype);
throw mxnet::op::InferTypeError(os.str(), 0);
}
return true;
@@ -186,29 +187,27 @@ bool ConcatType(const nnvm::NodeAttrs& attrs,
inline static bool ConcatForwardInferStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
- std::vector<int> *in_attrs,
- std::vector<int> *out_attrs) {
+ std::vector<int>* in_attrs,
+ std::vector<int>* out_attrs) {
CHECK(!in_attrs->empty());
CHECK_EQ(out_attrs->size(), 1U);
- auto& out_stype = out_attrs->at(0);
- bool dispatched = false;
+ auto& out_stype = out_attrs->at(0);
+ bool dispatched = false;
const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
- if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kCSRStorage)
- && param.dim == 0) {
- dispatched = storage_type_assign(&out_stype, kCSRStorage,
- dispatch_mode, DispatchMode::kFComputeEx);
+ if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kCSRStorage) && param.dim == 0) {
+ dispatched =
+ storage_type_assign(&out_stype, kCSRStorage, dispatch_mode, DispatchMode::kFComputeEx);
}
#if MXNET_USE_MKLDNN == 1
- if (!dispatched && dev_mask == mshadow::cpu::kDevMask
- && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)
- && param.dim > 0) {
- dispatched = storage_type_assign(&out_stype, kDefaultStorage,
- dispatch_mode, DispatchMode::kFComputeEx);
+ if (!dispatched && dev_mask == mshadow::cpu::kDevMask &&
+ common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) && param.dim > 0) {
+ dispatched =
+ storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFComputeEx);
}
#endif // MXNET_USE_MKLDNN == 1
if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
- dispatched = storage_type_assign(&out_stype, kDefaultStorage,
- dispatch_mode, DispatchMode::kFCompute);
+ dispatched =
+ storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute);
}
if (!dispatched) {
dispatched = dispatch_fallback(out_attrs, dispatch_mode);
@@ -223,15 +222,14 @@ inline static bool ConcatForwardInferStorageType(const nnvm::NodeAttrs& attrs,
inline static bool BackwardConcatStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
- std::vector<int> *in_attrs,
- std::vector<int> *out_attrs) {
+ std::vector<int>* in_attrs,
+ std::vector<int>* out_attrs) {
DispatchMode wanted_mode;
#if MXNET_USE_MKLDNN == 1
const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
CHECK_EQ(out_attrs->size(), in_attrs->size() - 1);
- if (dev_mask == mshadow::cpu::kDevMask
- && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)
- && param.dim > 0)
+ if (dev_mask == mshadow::cpu::kDevMask &&
+ common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) && param.dim > 0)
wanted_mode = DispatchMode::kFComputeEx;
else
#endif // MXNET_USE_MKLDNN == 1
@@ -240,19 +238,23 @@ inline static bool BackwardConcatStorageType(const nnvm::NodeAttrs& attrs,
if (!MKLDNNEnvSet())
wanted_mode = DispatchMode::kFComputeFallback;
#endif // MXNET_USE_MKLDNN == 1
- return storage_type_assign(out_attrs, mxnet::kDefaultStorage,
- dispatch_mode, wanted_mode);
+ return storage_type_assign(out_attrs, mxnet::kDefaultStorage, dispatch_mode, wanted_mode);
}
#if MXNET_USE_MKLDNN == 1
-bool SupportMKLDNNConcat(const std::vector<NDArray> &arrs) {
- for (auto &arr : arrs) {
- if (arr.IsView()) return false;
- if (!(arr.dtype() == mshadow::kFloat32 || arr.dtype() == mshadow::kBfloat16)) return false;
+bool SupportMKLDNNConcat(const std::vector<NDArray>& arrs) {
+ for (auto& arr : arrs) {
+ if (arr.IsView())
+ return false;
+ if (!(arr.dtype() == mshadow::kFloat32 || arr.dtype() == mshadow::kBfloat16))
+ return false;
// DO not support zero-size tensors.
- if (arr.shape().Size() == 0) return false;
+ if (arr.shape().Size() == 0)
+ return false;
int ndim = arr.shape().ndim();
- const int mkldnn_ndims = arr.GetMKLDNNData()->get_desc().data.ndims;
- if (!(ndim == 2 || ndim == 4) || ndim != mkldnn_ndims) return false;
+ const int mkldnn_ndims =
+ static_cast<const mkldnn::memory*>(arr.GetMKLDNNData())->get_desc().data.ndims;
+ if (!(ndim == 2 || ndim == 4) || ndim != mkldnn_ndims)
+ return false;
}
return true;
}
@@ -265,7 +267,8 @@ static void ConcatComputeExCPU(const nnvm::NodeAttrs& attrs,
CHECK(!inputs.empty());
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
- if (req[0] == kNullOp) return;
+ if (req[0] == kNullOp)
+ return;
if (common::ContainsOnlyStorage(inputs, kCSRStorage) &&
outputs[0].storage_type() == kCSRStorage) {
ConcatCSRImpl<cpu>(attrs, op_ctx, inputs, req, outputs);
@@ -299,7 +302,7 @@ static void ConcatGradComputeExCPU(const nnvm::NodeAttrs& attrs,
#endif // MXNET_USE_MKLDNN == 1
struct ConcatGrad {
- const char *op_name;
+ const char* op_name;
std::vector<nnvm::NodeEntry> operator()(const nnvm::ObjectPtr& n,
const std::vector<nnvm::NodeEntry>& ograds) const {
CHECK_EQ(ograds.size(), 1);
@@ -315,38 +318,37 @@ struct ConcatGrad {
DMLC_REGISTER_PARAMETER(ConcatParam);
-#define CONCAT_FORWARD_ATTRS \
-.set_num_inputs([](const NodeAttrs& attrs) { \
- const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed); \
- return params.num_args; \
-}) \
-.set_num_outputs(1) \
-.set_attr_parser(ParamParser<ConcatParam>) \
-.set_attr<nnvm::FListInputNames>("FListInputNames", \
- [](const NodeAttrs& attrs) { \
- const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed); \
- std::vector<std::string> ret; \
- for (int i = 0; i < params.num_args; ++i) { \
- ret.push_back(std::string("arg") + std::to_string(i)); \
- } \
- return ret; \
-}) \
-.set_attr<nnvm::FListOutputNames>("FListOutputNames", \
- [](const NodeAttrs& attrs) { \
- return std::vector<std::string>{"output"}; \
-}) \
-.set_attr<nnvm::FInferType>("FInferType", ConcatType) \
-.set_attr<FInferStorageType>("FInferStorageType", ConcatForwardInferStorageType) \
-.set_attr<FCompute>("FCompute<cpu>", ConcatCompute<cpu>) \
-.set_attr<FComputeEx>("FComputeEx<cpu>", ConcatComputeExCPU) \
-.set_attr<nnvm::FGradient>("FGradient", ConcatGrad{"_backward_Concat"}) \
-.set_attr<std::string>("key_var_num_args", "num_args")
-
+#define CONCAT_FORWARD_ATTRS \
+ .set_num_inputs([](const NodeAttrs& attrs) { \
+ const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed); \
+ return params.num_args; \
+ }) \
+ .set_num_outputs(1) \
+ .set_attr_parser(ParamParser<ConcatParam>) \
+ .set_attr<nnvm::FListInputNames>("FListInputNames", \
+ [](const NodeAttrs& attrs) { \
+ const ConcatParam& params = \
+ nnvm::get<ConcatParam>(attrs.parsed); \
+ std::vector<std::string> ret; \
+ for (int i = 0; i < params.num_args; ++i) { \
+ ret.push_back(std::string("arg") + std::to_string(i)); \
+ } \
+ return ret; \
+ }) \
+ .set_attr<nnvm::FListOutputNames>( \
+ "FListOutputNames", \
+ [](const NodeAttrs& attrs) { return std::vector<std::string>{"output"}; }) \
+ .set_attr<nnvm::FInferType>("FInferType", ConcatType) \
+ .set_attr<FInferStorageType>("FInferStorageType", ConcatForwardInferStorageType) \
+ .set_attr<FCompute>("FCompute<cpu>", ConcatCompute<cpu>) \
+ .set_attr<FComputeEx>("FComputeEx<cpu>", ConcatComputeExCPU) \
+ .set_attr<nnvm::FGradient>("FGradient", ConcatGrad{"_backward_Concat"}) \
+ .set_attr<std::string>("key_var_num_args", "num_args")
NNVM_REGISTER_OP(Concat)
MXNET_ADD_SPARSE_OP_ALIAS(concat)
-.add_alias("concat")
-.describe(R"code(Joins input arrays along a given axis.
+ .add_alias("concat")
+ .describe(R"code(Joins input arrays along a given axis.
.. note:: `Concat` is deprecated. Use `concat` instead.
@@ -384,59 +386,60 @@ Example::
)code" ADD_FILELINE)
#if MXNET_USE_MKLDNN == 1
-.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
- return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
-})
-.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
-.set_attr<bool>("TIsMKLDNN", true)
+ .set_attr<FResourceRequest>("FResourceRequest",
+ [](const NodeAttrs& n) {
+ return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+ })
+ .set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
+ .set_attr<bool>("TIsMKLDNN", true)
#endif // MXNET_USE_MKLDNN == 1
-CONCAT_FORWARD_ATTRS
-.set_attr<mxnet::FInferShape>("FInferShape", ConcatShape)
-.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
-.add_arguments(ConcatParam::__FIELDS__());
+ CONCAT_FORWARD_ATTRS.set_attr<mxnet::FInferShape>("FInferShape", ConcatShape)
+ .add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
+ .add_arguments(ConcatParam::__FIELDS__());
NNVM_REGISTER_OP(_backward_Concat)
-.set_num_inputs([](const NodeAttrs& attrs) {
+ .set_num_inputs([](const NodeAttrs& attrs) {
#if MXNET_USE_MKLDNN == 1
- const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
- return 1 + params.num_args;
+ const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
+ return 1 + params.num_args;
#else
- return 1;
+ return 1;
#endif
-})
-.set_num_outputs([](const NodeAttrs& attrs) {
- const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
- return params.num_args;
-})
-.set_attr_parser(ParamParser<ConcatParam>)
+ })
+ .set_num_outputs([](const NodeAttrs& attrs) {
+ const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
+ return params.num_args;
+ })
+ .set_attr_parser(ParamParser<ConcatParam>)
#if MXNET_USE_MKLDNN == 1
-.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
- return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
-})
+ .set_attr<FResourceRequest>("FResourceRequest",
+ [](const NodeAttrs& n) {
+ return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+ })
#endif // MXNET_USE_MKLDNN == 1
-.set_attr<nnvm::TIsBackward>("TIsBackward", true)
-.set_attr<FInferStorageType>("FInferStorageType", BackwardConcatStorageType)
+ .set_attr<nnvm::TIsBackward>("TIsBackward", true)
+ .set_attr<FInferStorageType>("FInferStorageType", BackwardConcatStorageType)
#if MXNET_USE_MKLDNN == 1
-.set_attr<bool>("TIsMKLDNN", true)
-.set_attr<FComputeEx>("FComputeEx<cpu>", ConcatGradComputeExCPU)
+ .set_attr<bool>("TIsMKLDNN", true)
+ .set_attr<FComputeEx>("FComputeEx<cpu>", ConcatGradComputeExCPU)
#endif // MXNET_USE_MKLDNN == 1
-.set_attr<FCompute>("FCompute<cpu>", ConcatGradCompute<cpu>);
+ .set_attr<FCompute>("FCompute<cpu>", ConcatGradCompute<cpu>);
// _rnn_param_concat is a custom concat op with specialized infer_shape,
// which handles the case where the first one or two inputs may have
// unknown shape that can be inferred from output shape.
NNVM_REGISTER_OP(_rnn_param_concat)
-.add_alias("_npi_rnn_param_concat")
+ .add_alias("_npi_rnn_param_concat")
#if MXNET_USE_MKLDNN == 1
-.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
- return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
-})
+ .set_attr<FResourceRequest>("FResourceRequest",
+ [](const NodeAttrs& n) {
+ return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+ })
#endif // MXNET_USE_MKLDNN == 1
-CONCAT_FORWARD_ATTRS
-.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
-.set_attr<mxnet::FInferShape>("FInferShape", RNNParamConcatShape)
-.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
-.add_arguments(ConcatParam::__FIELDS__());
+ CONCAT_FORWARD_ATTRS.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
+ .set_attr<mxnet::FInferShape>("FInferShape", RNNParamConcatShape)
+ .add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
+ .add_arguments(ConcatParam::__FIELDS__());
} // namespace op
} // namespace mxnet
diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc
index 6f4ac3d..3f73310 100644
--- a/src/operator/nn/mkldnn/mkldnn_act.cc
+++ b/src/operator/nn/mkldnn/mkldnn_act.cc
@@ -152,7 +152,7 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs,
param_.alg = GetMKLDNNActAlgo(param);
const NDArray& in_buffer = in_data;
MKLDNNStream* stream = MKLDNNStream::Get();
- auto input_mem = in_buffer.GetMKLDNNData();
+ auto input_mem = static_cast<const mkldnn::memory*>(in_buffer.GetMKLDNNData());
MKLDNNActForward& fwd = GetActForward(param_, ctx, in_buffer, *input_mem);
auto out_mem_t = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_desc(), req, &in_buffer);
stream->RegisterPrimArgs(fwd.GetFwd(),
@@ -177,7 +177,7 @@ void MKLDNNLeakyReluForward(const nnvm::NodeAttrs& attrs,
if (in_data.IsView() && in_data.IsMKLDNNData())
in_buffer = in_data.Reorder2Default();
- auto input_mem = in_buffer.GetMKLDNNData();
+ auto input_mem = static_cast<const mkldnn::memory*>(in_buffer.GetMKLDNNData());
MKLDNNActForward& fwd = GetActForward(param_, ctx, in_buffer, *input_mem);
auto out_mem_t = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_desc(), req, &in_buffer);
stream->RegisterPrimArgs(fwd.GetFwd(),
@@ -222,14 +222,15 @@ static inline MKLDNNActBackward& GetActBackward(const MKLDNNActParam& param,
auto it = bwds.find(key);
if (it == bwds.end()) {
- MKLDNNActBackward bwd(param, in_data, in_mem, *out_grad.GetMKLDNNData());
+ MKLDNNActBackward bwd(
+ param, in_data, in_mem, *static_cast<const mkldnn::memory*>(out_grad.GetMKLDNNData()));
it = AddToCache(&bwds, key, bwd);
}
return it->second;
}
-// For backward relu activation, it's okay to pass "out_data" as "in_data" to
-// this function, since the computation only involes non-zeros.
+// For backward relu activation, it's okay to pass "out_data" as "in_data" to this
+// function, since the computation only involves non-zeros.
void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
@@ -247,12 +248,13 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs,
MKLDNNActParam param_;
param_.alg = GetMKLDNNActAlgo(param);
TmpMemMgr::Get()->Init(ctx.requested[activation::kTempSpace]);
- auto diff_dst_memory = out_buffer.GetMKLDNNData();
- auto input_mem = in_buffer.GetMKLDNNData();
+ auto diff_dst_memory = static_cast<const mkldnn::memory*>(out_buffer.GetMKLDNNData());
+ auto input_mem = static_cast<const mkldnn::memory*>(in_buffer.GetMKLDNNData());
// We need to make sure the two inputs to eltwise_backward have the same memory
// descriptor. Otherwise, the perf will suffer.
+ auto diff_dst_desc = diff_dst_memory->get_desc();
if (input_mem->get_desc() != diff_dst_memory->get_desc())
- input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_desc());
+ input_mem = static_cast<const mkldnn::memory*>(in_buffer.GetMKLDNNDataReorder(&diff_dst_desc));
MKLDNNActBackward& bwd = GetActBackward(param_, ctx, in_buffer, out_buffer, *input_mem);
MKLDNNStream* stream = MKLDNNStream::Get();
mkldnn_output_t diff_src_memory = CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_desc(), req[0]);
@@ -286,12 +288,13 @@ void MKLDNNLeakyReluBackward(const nnvm::NodeAttrs& attrs,
param_.slope = param.slope;
TmpMemMgr::Get()->Init(ctx.requested[leakyrelu::kRandom]);
- auto diff_dst_memory = out_buffer.GetMKLDNNData();
- auto input_mem = in_buffer.GetMKLDNNData();
+ auto diff_dst_memory = static_cast<const mkldnn::memory*>(out_buffer.GetMKLDNNData());
+ auto input_mem = static_cast<const mkldnn::memory*>(in_buffer.GetMKLDNNData());
// We need to make sure the two inputs to eltwise_backward have the same memory
// descriptor. Otherwise, the perf will suffer.
+ auto diff_dst_desc = diff_dst_memory->get_desc();
if (input_mem->get_desc() != diff_dst_memory->get_desc())
- input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_desc());
+ input_mem = static_cast<const mkldnn::memory*>(in_buffer.GetMKLDNNDataReorder(&diff_dst_desc));
MKLDNNActBackward& bwd = GetActBackward(param_, ctx, in_buffer, out_buffer, *input_mem);
MKLDNNStream* stream = MKLDNNStream::Get();
mkldnn_output_t diff_src_memory = CreateMKLDNNMem(output, bwd.bwd_pd.diff_src_desc(), req[0]);
diff --git a/src/operator/nn/mkldnn/mkldnn_adaptive_pooling.cc b/src/operator/nn/mkldnn/mkldnn_adaptive_pooling.cc
index 3b8f234..13c842e 100644
--- a/src/operator/nn/mkldnn/mkldnn_adaptive_pooling.cc
+++ b/src/operator/nn/mkldnn/mkldnn_adaptive_pooling.cc
@@ -36,7 +36,8 @@ void MKLDNNAdaptivePoolingFwd::Init(const mxnet::NDArray& input,
const mkldnn::memory::dims& pad_r,
const bool is_train,
const mkldnn::algorithm alg_kind) {
- const auto src_md = input.GetMKLDNNData()->get_desc();
+ const mkldnn::memory* mem = static_cast<const mkldnn::memory*>(input.GetMKLDNNData());
+ const auto src_md = mem->get_desc();
const auto dst_md = GetMemDesc(output);
const mkldnn::engine engine = CpuEngine::Get()->get_engine();
@@ -69,7 +70,7 @@ void MKLDNNAdaptivePoolingFwd::Execute(const NDArray& input,
in_buffer = input.Reorder2Default();
}
- auto input_mem = in_buffer.GetMKLDNNData();
+ auto input_mem = static_cast<const mkldnn::memory*>(in_buffer.GetMKLDNNData());
auto output_mem_t = CreateMKLDNNMem(output, this->fwd_pd_->dst_desc(), req);
mkldnn_args_map_t args = {{MKLDNN_ARG_SRC, *input_mem}, {MKLDNN_ARG_DST, *(output_mem_t.second)}};
@@ -80,7 +81,9 @@ void MKLDNNAdaptivePoolingFwd::Execute(const NDArray& input,
LOG(FATAL) << "MKLDNN Average Pooling: incorrect worskapce input";
}
auto ws = std::make_shared<mkldnn::memory>(
- this->fwd_pd_->workspace_desc(), engine, workspace->GetMKLDNNData()->get_data_handle());
+ this->fwd_pd_->workspace_desc(),
+ engine,
+ static_cast<const mkldnn::memory*>(workspace->GetMKLDNNData())->get_data_handle());
args[MKLDNN_ARG_WORKSPACE] = *ws;
}
if (this->fwd_) {
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index e206e9e..698cd21 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -156,8 +156,7 @@ static inline int GetMKLDNNCacheSize() {
return mkldnn_cache_size;
}
-// TODO(alex): (MXNET-1075) Will remove env variable and calculate cache size
-// during runtime
+// TODO(alex): (MXNET-1075) Will remove env variable and calculate cache size during runtime
template <typename S, typename I, typename H>
static typename std::unordered_map<S, I, H>::iterator AddToCache(std::unordered_map<S, I, H>* cache,
const S& key,
@@ -208,7 +207,8 @@ static int GetTypeSize(int dtype) {
static inline size_t GetArraySize(const NDArray& arr) {
if (arr.IsMKLDNNData()) {
- return arr.GetMKLDNNData()->get_desc().get_size();
+ auto arr_data = static_cast<const mkldnn::memory*>(arr.GetMKLDNNData());
+ return arr_data->get_desc().get_size();
}
return arr.shape().Size() * GetTypeSize(arr.dtype());
}
@@ -532,8 +532,7 @@ static inline void InvalidateOutputs(const std::vector<NDArray>& arrs,
}
}
-// TODO(alexzai): (MXNET-856) Remove helper function after subgraph feature
-// added
+// TODO(alexzai): (MXNET-856) Remove helper function after subgraph feature added
static inline void CreateDefaultInputs(const std::vector<NDArray>& arrs,
std::vector<NDArray>* out_arrs) {
out_arrs->clear();
diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc
index 5a65c94..8e1dd74 100644
--- a/src/operator/nn/mkldnn/mkldnn_base.cc
+++ b/src/operator/nn/mkldnn/mkldnn_base.cc
@@ -159,9 +159,12 @@ void MKLDNNMemoryCopy(const mkldnn::memory& mem, const mkldnn::memory* this_mem)
}
bool CanWriteTo(const NDArray& out_arr, const NDArray& in_arr, const mkldnn::memory::desc& desc) {
- auto in_mem = in_arr.GetMKLDNNData();
- bool add_same = in_mem->get_data_handle() == out_arr.GetMKLDNNData()->get_data_handle();
- bool pdesc_same = out_arr.GetMKLDNNData()->get_desc() == desc && in_mem->get_desc() == desc;
+ auto in_mem = static_cast<const mkldnn::memory*>(in_arr.GetMKLDNNData());
+ bool add_same = in_mem->get_data_handle() ==
+ static_cast<const mkldnn::memory*>(out_arr.GetMKLDNNData())->get_data_handle();
+ bool pdesc_same =
+ static_cast<const mkldnn::memory*>(out_arr.GetMKLDNNData())->get_desc() == desc &&
+ in_mem->get_desc() == desc;
return add_same && pdesc_same;
}
@@ -173,7 +176,8 @@ mkldnn_output_t CreateMKLDNNMem(const NDArray& out_arr,
auto tmp = TmpMemMgr::Get()->Alloc(desc);
return mkldnn_output_t(OutDataOp::AddBack, tmp);
} else if (kWriteInplace == req && in_arr != nullptr && CanWriteTo(out_arr, *in_arr, desc)) {
- mkldnn::memory* mem = const_cast<NDArray&>(out_arr).CreateMKLDNNData(desc);
+ mkldnn::memory* mem =
+ static_cast<mkldnn::memory*>(const_cast<NDArray&>(out_arr).CreateMKLDNNData(&desc));
// mem is nullptr if out_arr is view and desc is MKLDNN format.
// need to Reorder2Default before calling CreateMKLDNNMem
CHECK(mem != nullptr);
@@ -182,7 +186,8 @@ mkldnn_output_t CreateMKLDNNMem(const NDArray& out_arr,
auto tmp = TmpMemMgr::Get()->Alloc(desc);
return mkldnn_output_t(OutDataOp::CopyBack, tmp);
} else if (kWriteTo == req) {
- mkldnn::memory* mem = const_cast<NDArray&>(out_arr).CreateMKLDNNData(desc);
+ mkldnn::memory* mem =
+ static_cast<mkldnn::memory*>(const_cast<NDArray&>(out_arr).CreateMKLDNNData(&desc));
if (nullptr == mem) {
auto tmp = TmpMemMgr::Get()->Alloc(desc);
return mkldnn_output_t(OutDataOp::CopyBack, tmp);
@@ -205,7 +210,7 @@ mkldnn_output_t CreateMKLDNNWeightGrad(const NDArray& out_arr,
} else {
mkldnn::memory* mem = nullptr;
if (IsDefaultFormat(desc)) {
- mem = const_cast<NDArray&>(out_arr).CreateMKLDNNData(desc);
+ mem = static_cast<mkldnn::memory*>(const_cast<NDArray&>(out_arr).CreateMKLDNNData(&desc));
}
if (mem == nullptr) {
auto tmp = TmpMemMgr::Get()->Alloc(desc);
@@ -218,16 +223,17 @@ mkldnn_output_t CreateMKLDNNWeightGrad(const NDArray& out_arr,
void CommitOutput(const NDArray& arr, const mkldnn_output_t& res) {
if (res.first == CopyBack) {
- const_cast<NDArray&>(arr).CopyFrom(*res.second);
+ const_cast<NDArray&>(arr).CopyFrom(res.second);
} else if (res.first == AddBack) {
auto res_memory = res.second;
- auto target_pd = arr.GetMKLDNNData()->get_desc();
- auto mem = arr.GetMKLDNNData(res.second->get_desc());
+ auto target_pd = static_cast<const mkldnn::memory*>(arr.GetMKLDNNData())->get_desc();
+ auto res_desc = res.second->get_desc();
+ auto mem = static_cast<const mkldnn::memory*>(arr.GetMKLDNNData(&res_desc));
if (mem == nullptr) {
auto tmp_memory = TmpMemMgr::Get()->Alloc(target_pd);
MKLDNNMemoryCopy(*res_memory, tmp_memory);
res_memory = tmp_memory;
- mem = arr.GetMKLDNNData();
+ mem = static_cast<const mkldnn::memory*>(arr.GetMKLDNNData());
}
op::MKLDNNSum(*mem, *res_memory, *mem);
}
@@ -283,22 +289,24 @@ const mkldnn::memory* GetWeights(const NDArray& arr, int num_groups) {
LOG(FATAL) << "The weight array has an unsupported number of dimensions";
}
const auto md = mkldnn::memory::desc{tz, type, format_tag};
- return arr.GetMKLDNNData(md);
+ return static_cast<const mkldnn::memory*>(arr.GetMKLDNNData(&md));
}
const mkldnn::memory* GetWeights(const NDArray& arr,
const mkldnn::memory::desc& target_desc,
int num_groups) {
- const mkldnn::memory* mem = arr.GetMKLDNNData(target_desc);
- // If the weight array already uses the target layout, simply return it
- // directly.
- if (mem)
+ const mkldnn::memory* mem = static_cast<const mkldnn::memory*>(arr.GetMKLDNNData(&target_desc));
+ // If the weight array already uses the target layout, simply return it directly.
+ if (mem) {
return mem;
+ }
mem = GetWeights(arr, num_groups);
- if (mem == nullptr)
- mem = arr.GetMKLDNNDataReorder(target_desc);
- if (mem->get_desc() == target_desc)
+ if (mem == nullptr) {
+ mem = static_cast<const mkldnn::memory*>(arr.GetMKLDNNDataReorder(&target_desc));
+ }
+ if (mem->get_desc() == target_desc) {
return mem;
+ }
auto ret = TmpMemMgr::Get()->Alloc(target_desc);
std::unordered_map<int, mkldnn::memory> args({{MKLDNN_ARG_FROM, *mem}, {MKLDNN_ARG_TO, *ret}});
@@ -307,8 +315,7 @@ const mkldnn::memory* GetWeights(const NDArray& arr,
}
// default: block and dims' stride increase monotonically
-// mkldnn: 1.winograd 2.rnn packed 3. block and dims'stride is not increase
-// monotonically
+// mkldnn: 1. winograd 2. rnn packed 3. block and dims' stride does not increase monotonically
bool IsMKLDNN(const mkldnn::memory::desc& desc) {
bool rslt = true;
if (desc.data.format_kind == mkldnn_blocked) {
@@ -454,8 +461,8 @@ void FallBackCompute(Compute fn,
fn(attrs_states, ctx, in_blobs, new_req, out_blobs);
for (size_t i = 0, bf16_pos = 0; i < out_blobs.size(); i++) {
if (outputs[i].dtype() == mshadow::kBfloat16) {
- auto src_mem = temp_bf16_src[bf16_pos].GetMKLDNNData();
- auto dst_mem = temp_bf16_dst[bf16_pos].GetMKLDNNData();
+ auto src_mem = static_cast<const mkldnn::memory*>(temp_bf16_src[bf16_pos].GetMKLDNNData());
+ auto dst_mem = static_cast<const mkldnn::memory*>(temp_bf16_dst[bf16_pos].GetMKLDNNData());
bf16_pos++;
ReorderTo(src_mem, dst_mem);
} else if (req[i] == kAddTo && outputs[i].IsMKLDNNData()) {
@@ -489,13 +496,13 @@ static bool SimilarArray(const mxnet::NDArray& arr1,
NDArray buf1, buf2;
if (arr1.IsMKLDNNData()) {
buf1 = NDArray(arr1.shape(), arr1.ctx(), false, arr1.dtype());
- auto mem = arr1.GetMKLDNNData();
- buf1.CopyFrom(*mem);
+ auto mem = static_cast<const mkldnn::memory*>(arr1.GetMKLDNNData());
+ buf1.CopyFrom(mem);
}
if (arr2.IsMKLDNNData()) {
buf2 = NDArray(arr2.shape(), arr2.ctx(), false, arr2.dtype());
- auto mem = arr2.GetMKLDNNData();
- buf2.CopyFrom(*mem);
+ auto mem = static_cast<const mkldnn::memory*>(arr2.GetMKLDNNData());
+ buf2.CopyFrom(mem);
}
MKLDNNStream::Get()->Submit();
@@ -519,25 +526,25 @@ static bool SimilarArray(const mxnet::NDArray& arr1,
template void FallBackCompute(void (*)(nnvm::NodeAttrs const&,
OpContext const&,
- std::vector<TBlob, std::allocator<TBlob>> const&,
- std::vector<OpReqType, std::allocator<OpReqType>> const&,
- std::vector<TBlob, std::allocator<TBlob>> const&),
+ std::vector<TBlob, std::allocator<TBlob> > const&,
+ std::vector<OpReqType, std::allocator<OpReqType> > const&,
+ std::vector<TBlob, std::allocator<TBlob> > const&),
nnvm::NodeAttrs const&,
OpContext const&,
- std::vector<NDArray, std::allocator<NDArray>> const&,
- std::vector<OpReqType, std::allocator<OpReqType>> const&,
- std::vector<NDArray, std::allocator<NDArray>> const&);
+ std::vector<NDArray, std::allocator<NDArray> > const&,
+ std::vector<OpReqType, std::allocator<OpReqType> > const&,
+ std::vector<NDArray, std::allocator<NDArray> > const&);
template void FallBackCompute(void (*)(OpStatePtr const&,
OpContext const&,
- std::vector<TBlob, std::allocator<TBlob>> const&,
- std::vector<OpReqType, std::allocator<OpReqType>> const&,
- std::vector<TBlob, std::allocator<TBlob>> const&),
+ std::vector<TBlob, std::allocator<TBlob> > const&,
+ std::vector<OpReqType, std::allocator<OpReqType> > const&,
+ std::vector<TBlob, std::allocator<TBlob> > const&),
OpStatePtr const&,
OpContext const&,
- std::vector<NDArray, std::allocator<NDArray>> const&,
- std::vector<OpReqType, std::allocator<OpReqType>> const&,
- std::vector<NDArray, std::allocator<NDArray>> const&);
+ std::vector<NDArray, std::allocator<NDArray> > const&,
+ std::vector<OpReqType, std::allocator<OpReqType> > const&,
+ std::vector<NDArray, std::allocator<NDArray> > const&);
void OpCheck::Init(const std::vector<mxnet::NDArray>& inputs_,
const std::vector<mxnet::NDArray>& outputs_) {
@@ -548,14 +555,14 @@ void OpCheck::Init(const std::vector<mxnet::NDArray>& inputs_,
inputs.emplace_back(data.shape(), ctx, false, data.dtype());
if (data.IsMKLDNNData() && data.IsView())
data = data.Reorder2Default();
- auto mem = data.GetMKLDNNData();
- inputs[i].CopyFrom(*mem);
+ auto mem = static_cast<const mkldnn::memory*>(data.GetMKLDNNData());
+ inputs[i].CopyFrom(mem);
}
for (size_t i = 0; i < outputs_.size(); i++) {
outputs.emplace_back(outputs_[i].shape(), ctx, false, outputs_[i].dtype());
if (backward) {
- auto mem = outputs_[i].GetMKLDNNData();
- outputs[i].CopyFrom(*mem);
+ auto mem = static_cast<const mkldnn::memory*>(outputs_[i].GetMKLDNNData());
+ outputs[i].CopyFrom(mem);
}
}
MKLDNNStream::Get()->Submit();
@@ -602,8 +609,8 @@ void OpCheck::CopyResult(const std::vector<mxnet::NDArray>& outputs_,
CHECK(!MKLDNNStream::Get()->HasOps());
auto non_const_outputs_ = const_cast<std::vector<mxnet::NDArray>&>(outputs_);
for (auto i = indice.begin(); i != indice.end(); ++i) {
- auto mem = outputs[*i].GetMKLDNNData();
- non_const_outputs_[*i].CopyFrom(*mem);
+ auto mem = static_cast<const mkldnn::memory*>(outputs[*i].GetMKLDNNData());
+ non_const_outputs_[*i].CopyFrom(mem);
}
MKLDNNStream::Get()->Submit();
}
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 9b25b13..56eed58 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -176,11 +176,13 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs& attrs,
NDArray& data = in_data[batchnorm::kData];
if (data.IsMKLDNNData() && data.IsView())
data = data.Reorder2Default();
- auto data_mem = data.GetMKLDNNData();
+ auto data_mem = static_cast<const mkldnn::memory*>(data.GetMKLDNNData());
auto& fwd = GetBNForward<DType>(param, ctx, data_mem, flags);
// for output memory
- auto out_mem = const_cast<NDArray&>(out).CreateMKLDNNData(fwd.GetPd().dst_desc());
+ auto fwd_dst_desc = fwd.GetPd().dst_desc();
+ auto out_mem =
+ static_cast<mkldnn::memory*>(const_cast<NDArray&>(out).CreateMKLDNNData(&fwd_dst_desc));
// mxnet will always use scale shift.
// But if fix_gamma is true, then all scale elements will be set to 1.0f
@@ -226,7 +228,9 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs& attrs,
LOG(FATAL) << "MKLDNN BatchNorm: incorrect workspace input";
}
auto ws = std::make_shared<mkldnn::memory>(
- fwd.GetPd().workspace_desc(), engine, workspace->GetMKLDNNData()->get_data_handle());
+ fwd.GetPd().workspace_desc(),
+ engine,
+ static_cast<const mkldnn::memory*>(workspace->GetMKLDNNData())->get_data_handle());
net_args[MKLDNN_ARG_WORKSPACE] = *ws;
}
if (!ctx.is_train || param.use_global_stats) {
@@ -239,15 +243,17 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs& attrs,
omean[i] = inmean[i];
ovar[i] = VARIANCE_TO_INVSTD(invar[i], param.eps);
}
- net_args[MKLDNN_ARG_MEAN] = *(aux_states[batchnorm::kMovingMean].GetMKLDNNData());
- net_args[MKLDNN_ARG_VARIANCE] = *(aux_states[batchnorm::kMovingVar].GetMKLDNNData());
+ net_args[MKLDNN_ARG_MEAN] =
+ *(static_cast<const mkldnn::memory*>(aux_states[batchnorm::kMovingMean].GetMKLDNNData()));
+ net_args[MKLDNN_ARG_VARIANCE] =
+ *(static_cast<const mkldnn::memory*>(aux_states[batchnorm::kMovingVar].GetMKLDNNData()));
MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
MKLDNNStream::Get()->Submit();
} else { // training
- const NDArray& outMean = outputs[batchnorm::kMean];
- const NDArray& outVar = outputs[batchnorm::kVar];
- net_args[MKLDNN_ARG_MEAN] = *(outMean.GetMKLDNNData());
- net_args[MKLDNN_ARG_VARIANCE] = *(outVar.GetMKLDNNData());
+ const NDArray& outMean = outputs[batchnorm::kMean];
+ const NDArray& outVar = outputs[batchnorm::kVar];
+ net_args[MKLDNN_ARG_MEAN] = *(static_cast<const mkldnn::memory*>(outMean.GetMKLDNNData()));
+ net_args[MKLDNN_ARG_VARIANCE] = *(static_cast<const mkldnn::memory*>(outVar.GetMKLDNNData()));
MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
MKLDNNStream::Get()->Submit();
@@ -374,14 +380,17 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs& attrs,
gradIn = gradIn.Reshape(new_shape);
}
- auto data_mem = data.GetMKLDNNData();
- auto diff_mem = diff.GetMKLDNNData();
+ auto data_mem = static_cast<const mkldnn::memory*>(data.GetMKLDNNData());
+ auto diff_mem = static_cast<const mkldnn::memory*>(diff.GetMKLDNNData());
// MKLDNN batchnorm should run on special layouts. If one of them isn't, we
// should reorder them.
- if (data.IsDefaultData())
- data_mem = data.GetMKLDNNDataReorder(diff_mem->get_desc());
- else if (diff.IsDefaultData())
- diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_desc());
+ if (data.IsDefaultData()) {
+ auto diff_desc = diff_mem->get_desc();
+ data_mem = static_cast<const mkldnn::memory*>(data.GetMKLDNNDataReorder(&diff_desc));
+ } else if (diff.IsDefaultData()) {
+ auto data_mem_desc = data_mem->get_desc();
+ diff_mem = static_cast<const mkldnn::memory*>(diff.GetMKLDNNDataReorder(&data_mem_desc));
+ }
auto& bwd = GetBNBackward<DType>(param, ctx, data, *data_mem, diff, *diff_mem, flags);
auto gradi_mem =
CreateMKLDNNMem(const_cast<NDArray&>(gradIn), bwd.pd.diff_src_desc(), req[batchnorm::kData]);
@@ -414,7 +423,8 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs& attrs,
const NDArray* workspace = nullptr;
workspace = &inputs[8];
if (workspace != nullptr) {
- net_args[MKLDNN_ARG_WORKSPACE] = *(workspace->GetMKLDNNData());
+ net_args[MKLDNN_ARG_WORKSPACE] =
+ *(static_cast<const mkldnn::memory*>(workspace->GetMKLDNNData()));
}
}
@@ -434,11 +444,13 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs& attrs,
tmp_var_ptr[i] = variance;
moving_var_ptr[i] = moving_var_ptr[i] * param.momentum + variance * minus_mom;
}
- net_args[MKLDNN_ARG_MEAN] = *(out_mean.GetMKLDNNData());
+ net_args[MKLDNN_ARG_MEAN] = *(static_cast<const mkldnn::memory*>(out_mean.GetMKLDNNData()));
net_args[MKLDNN_ARG_VARIANCE] = var_mem;
} else {
- net_args[MKLDNN_ARG_MEAN] = *(moving_mean.GetMKLDNNData());
- net_args[MKLDNN_ARG_VARIANCE] = *(moving_var.GetMKLDNNData());
+ net_args[MKLDNN_ARG_MEAN] =
+ *(static_cast<const mkldnn::memory*>(moving_mean.GetMKLDNNData()));
+ net_args[MKLDNN_ARG_VARIANCE] =
+ *(static_cast<const mkldnn::memory*>(moving_var.GetMKLDNNData()));
}
MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
CommitOutput(gradIn, gradi_mem);
diff --git a/src/operator/nn/mkldnn/mkldnn_concat.cc b/src/operator/nn/mkldnn/mkldnn_concat.cc
index 689888a..4904ad3 100644
--- a/src/operator/nn/mkldnn/mkldnn_concat.cc
+++ b/src/operator/nn/mkldnn/mkldnn_concat.cc
@@ -71,7 +71,7 @@ void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs,
data_md.reserve(num_in_data);
data_mem.reserve(num_in_data);
for (int i = 0; i < num_in_data; i++) {
- const mkldnn::memory* tmp_mem = in_data[i].GetMKLDNNData();
+ const mkldnn::memory* tmp_mem = static_cast<const mkldnn::memory*>(in_data[i].GetMKLDNNData());
mkldnn::memory::desc tmp_md = tmp_mem->get_desc();
data_md.push_back(tmp_md);
data_mem.push_back(tmp_mem);
@@ -98,7 +98,7 @@ void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs,
const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
const int num_in_data = param.num_args;
const int axis = param.dim;
- const auto gradz_mem = inputs[0].GetMKLDNNData();
+ const auto gradz_mem = static_cast<const mkldnn::memory*>(inputs[0].GetMKLDNNData());
/* init the offset */
mkldnn::memory::dims offsets(outputs[0].shape().ndim());
for (auto& v : offsets) {
@@ -107,7 +107,7 @@ void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs,
for (int i = 0; i < num_in_data; i++) {
mkldnn::memory::dims diff_src_tz(outputs[i].shape().begin(), outputs[i].shape().end());
- auto diff_src_md = outputs[i].GetMKLDNNData()->get_desc();
+ auto diff_src_md = static_cast<const mkldnn::memory*>(outputs[i].GetMKLDNNData())->get_desc();
auto gradi_mem = CreateMKLDNNMem(outputs[i], diff_src_md, req[i]);
auto from_md = gradz_mem->get_desc().submemory_desc(diff_src_tz, offsets);
diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc
index 966ba21..dcacf24 100644
--- a/src/operator/nn/mkldnn/mkldnn_convolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -116,16 +116,14 @@ std::shared_ptr<mkldnn::convolution_forward::primitive_desc> GetConvFwdImpl(
// MKL-DNN introduced padded formats since 0.15 which require more memory
// compared to the actual size of the tensor. Currently, MKL-DNN operators
// still reuse memory from memory planning, so here we need to select a
- // suboptimal kernel for computation that has the expected memory size
- // requirements
+ // suboptimal kernel for computation that has the expected memory size requirements
auto conv_pd =
std::make_shared<mkldnn::convolution_forward::primitive_desc>(desc, attr, engine);
while (conv_pd->dst_desc().get_size() != GetArraySize(output) ||
conv_pd->src_desc().get_size() != GetArraySize(data) ||
(!param.mkldnn_param.quantized &&
conv_pd->weights_desc().get_size() != GetArraySize(weights))) {
- // next_impl() will visit desc and engine, please make sure they are
- // still alive here.
+ // next_impl() will visit desc and engine, please make sure they are still alive here.
CHECK(conv_pd->next_impl()) << "No convolution implementation for this request.";
}
return conv_pd;
@@ -488,7 +486,8 @@ void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam& param,
auto& weight = in_data[conv::kWeight];
bool no_bias = param.conv_param.no_bias && !param.mkldnn_param.with_bn;
- auto data_mem = data.GetMKLDNNDataReorder(fwd->GetPd().src_desc());
+ auto fwd_src_desc = fwd->GetPd().src_desc();
+ auto data_mem = static_cast<const mkldnn::memory*>(data.GetMKLDNNDataReorder(&fwd_src_desc));
const mkldnn::memory* weight_mem;
if (ctx.is_train) {
// TODO(zhengda) kvstore doesn't handle MKLDNN correctly. Let's reorder it
@@ -502,25 +501,30 @@ void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam& param,
// For inference, we want to reorder the weight array so we don't need to
// reorder data every time.
if (weight.IsDefaultData()) {
- // We also need to modify the layout on the original weight array. The
- // data conversion happens after the weight array is used.
- weight.MKLDNNDataReorderAsync(fwd->GetPd().weights_desc());
+ // We also need to modify the layout on the original weight array. The data conversion happens
+ // after the weight array is used.
+ auto fwd_weight_desc = fwd->GetPd().weights_desc();
+ weight.MKLDNNDataReorderAsync(&fwd_weight_desc);
weight_mem = GetWeights(weight, fwd->GetPd().weights_desc(), param.conv_param.num_group);
} else {
- weight_mem = weight.GetMKLDNNDataReorder(fwd->GetPd().weights_desc());
+ auto fwd_weight_desc = fwd->GetPd().weights_desc();
+ weight_mem =
+ static_cast<const mkldnn::memory*>(weight.GetMKLDNNDataReorder(&fwd_weight_desc));
}
}
mkldnn_output_t out_mem;
if (param.mkldnn_param.with_sum) {
out_mem = mkldnn_output_t(OutDataOp::Noop,
- const_cast<mkldnn::memory*>(out_data[conv::kOut].GetMKLDNNData()));
+ const_cast<mkldnn::memory*>(static_cast<const mkldnn::memory*>(
+ out_data[conv::kOut].GetMKLDNNData())));
} else {
out_mem = CreateMKLDNNMem(out_data[conv::kOut], fwd->GetPd().dst_desc(), req[conv::kOut]);
}
mkldnn_args_map_t net_args;
if (!no_bias) {
- const mkldnn::memory* bias_mem = in_data[conv::kBias].GetMKLDNNData();
+ const mkldnn::memory* bias_mem =
+ static_cast<const mkldnn::memory*>(in_data[conv::kBias].GetMKLDNNData());
net_args.insert({MKLDNN_ARG_BIAS, *bias_mem});
}
@@ -612,7 +616,9 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs,
CHECK_NE(req[conv::kWeight], kWriteInplace) << "cannot write weight inplace";
MKLDNNConvBackward& convBwd = GetConvBwd(full_param, data, weight, bias, out_grad);
- auto out_grad_mem = out_grad.GetMKLDNNDataReorder(convBwd.GetDataPd().diff_dst_desc());
+ auto convBwd_data_diff_desc = convBwd.GetDataPd().diff_dst_desc();
+ auto out_grad_mem =
+ static_cast<const mkldnn::memory*>(out_grad.GetMKLDNNDataReorder(&convBwd_data_diff_desc));
if (req[conv::kData]) {
auto weight_mem = GetWeights(weight, convBwd.GetDataPd().weights_desc(), param.num_group);
auto in_grad_mem = CreateMKLDNNMem(
@@ -624,9 +630,14 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs,
CommitOutput(in_grad[conv::kData], in_grad_mem);
}
if (req[conv::kWeight] || req[conv::kBias]) {
- if (convBwd.GetDataPd().diff_dst_desc() != convBwd.GetWeightsPd().diff_dst_desc())
- out_grad_mem = out_grad.GetMKLDNNDataReorder(convBwd.GetWeightsPd().diff_dst_desc());
- auto data_mem = data.GetMKLDNNDataReorder(convBwd.GetWeightsPd().src_desc());
+ if (convBwd.GetDataPd().diff_dst_desc() != convBwd.GetWeightsPd().diff_dst_desc()) {
+ auto convBwd_weight_diff_desc = convBwd.GetWeightsPd().diff_dst_desc();
+ out_grad_mem = static_cast<const mkldnn::memory*>(
+ out_grad.GetMKLDNNDataReorder(&convBwd_weight_diff_desc));
+ }
+ auto convBwd_weight_src_desc = convBwd.GetWeightsPd().src_desc();
+ auto data_mem =
+ static_cast<const mkldnn::memory*>(data.GetMKLDNNDataReorder(&convBwd_weight_src_desc));
auto in_grad_weight = CreateMKLDNNWeightGrad(
in_grad[conv::kWeight], convBwd.GetWeightsPd().diff_weights_desc(), req[conv::kWeight]);
diff --git a/src/operator/nn/mkldnn/mkldnn_copy.cc b/src/operator/nn/mkldnn/mkldnn_copy.cc
index 601df3c..4051144 100644
--- a/src/operator/nn/mkldnn/mkldnn_copy.cc
+++ b/src/operator/nn/mkldnn/mkldnn_copy.cc
@@ -38,18 +38,21 @@ void MKLDNNCopy(const nnvm::NodeAttrs& attrs,
if (req == kNullOp || req == kWriteInplace)
return;
TmpMemMgr::Get()->Init(ctx.requested[0]);
- auto in_mem = in_data.GetMKLDNNData();
+ auto in_mem = static_cast<const mkldnn::memory*>(in_data.GetMKLDNNData());
if (req == kAddTo) {
TmpMemMgr::Get()->Init(ctx.requested[0]);
// We should try and force the input memory has the same format
// as the input output. If not, we'll have to reorder memory.
- auto out_mem = out_data.GetMKLDNNData();
- in_mem = in_data.GetMKLDNNData(out_mem->get_desc());
- if (in_mem == nullptr)
- in_mem = in_data.GetMKLDNNDataReorder(out_mem->get_desc());
+ auto out_mem = static_cast<const mkldnn::memory*>(out_data.GetMKLDNNData());
+ auto out_mem_desc = out_mem->get_desc();
+ in_mem = static_cast<const mkldnn::memory*>(in_data.GetMKLDNNData(&out_mem_desc));
+ if (in_mem == nullptr) {
+ auto out_mem_desc = out_mem->get_desc();
+ in_mem = static_cast<const mkldnn::memory*>(in_data.GetMKLDNNDataReorder(&out_mem_desc));
+ }
MKLDNNSum(*out_mem, *in_mem, *out_mem);
} else {
- const_cast<NDArray&>(out_data).CopyFrom(*in_mem);
+ const_cast<NDArray&>(out_data).CopyFrom(in_mem);
}
MKLDNNStream::Get()->Submit();
}
diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution-inl.h b/src/operator/nn/mkldnn/mkldnn_deconvolution-inl.h
index b048c13..f1cd322 100644
--- a/src/operator/nn/mkldnn/mkldnn_deconvolution-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_deconvolution-inl.h
@@ -58,9 +58,8 @@ using deconv_bwd_data_pd_t = mkldnn::deconvolution_backward_data::primitive_desc
using deconv_bwd_weights_t = mkldnn::deconvolution_backward_weights;
using deconv_bwd_weights_pd_t = mkldnn::deconvolution_backward_weights::primitive_desc;
-// Swaps the logical order of dimensions that in plain format would correspond
-// to input and output channels (for example: oihw => iohw, iohw => oihw, goihw
-// => giohw).
+// Swaps the logical order of dimensions that in plain format would correspond to input and output
+// channels (for example: oihw => iohw, iohw => oihw, goihw => giohw).
inline mkldnn::memory::desc IOLogicalSwapDesc(const mkldnn::memory::desc& desc,
const uint32_t num_group) {
std::vector<int> order(desc.data.ndims);
@@ -74,18 +73,18 @@ inline mkldnn::memory::desc IOLogicalSwapDesc(const mkldnn::memory::desc& desc,
inline void IOLogicalSwapMKLDNNMem(const NDArray& arr, const uint32_t num_group) {
mkldnn::memory::desc desc;
if (arr.IsMKLDNNData()) {
- desc = arr.GetMKLDNNData()->get_desc();
+ desc = static_cast<const mkldnn::memory*>(arr.GetMKLDNNData())->get_desc();
} else {
- // GetMKLDNNData won't take groups into account when creating
- // mkldnn::memory, we need to use descriptor from GetWeightDesc but with
- // default format
+ // GetMKLDNNData won't take groups into account when creating mkldnn::memory, we need to use
+ // descriptor from GetWeightDesc but with default format
const auto& temp = GetWeightDesc(arr, num_group);
desc = mkldnn::memory::desc(
temp.dims(),
temp.data_type(),
static_cast<mkldnn::memory::format_tag>(GetDefaultFormat(temp.data.ndims)));
}
- const_cast<NDArray&>(arr).UpdateMKLDNNMemDesc(IOLogicalSwapDesc(desc, num_group));
+ auto logical_swap = IOLogicalSwapDesc(desc, num_group);
+ const_cast<NDArray&>(arr).UpdateMKLDNNMemDesc(&logical_swap);
}
// Version of GetWeightsDesc for deconvolution (with swap)
@@ -152,7 +151,8 @@ MKLDNNDeconvFwd::MKLDNNDeconvFwd(const DeconvolutionParam& param, const Tensors&
}
inline const mkldnn::memory* MKLDNNDeconvFwd::DataMem(const NDArray& data) const {
- return data.GetMKLDNNDataReorder(fwd_pd->src_desc());
+ auto fwd_src_desc = fwd_pd->src_desc();
+ return static_cast<const mkldnn::memory*>(data.GetMKLDNNDataReorder(&fwd_src_desc));
}
inline const mkldnn::memory* MKLDNNDeconvFwd::WeightsMem(const uint32_t num_group,
@@ -161,7 +161,7 @@ inline const mkldnn::memory* MKLDNNDeconvFwd::WeightsMem(const uint32_t num_grou
}
inline const mkldnn::memory* MKLDNNDeconvFwd::BiasMem(const NDArray& bias) const {
- return bias.GetMKLDNNData();
+ return static_cast<const mkldnn::memory*>(bias.GetMKLDNNData());
}
inline mkldnn_output_t MKLDNNDeconvFwd::OutMem(const OpReqType req, const NDArray& out) const {
@@ -210,8 +210,8 @@ class MKLDNNDeconvBwd {
const NDArray& weights,
const NDArray& weights_grad) const;
- // returns the output gradient memory used to calculate the data (input)
- // gradient, which might be reused when calculating the gradient of weights
+ // returns the output gradient memory used to calculate the data (input) gradient,
+ // which might be reused when calculating the gradient of weights
const mkldnn::memory* ScheduleBwdData(const uint32_t num_group,
const OpReqType req,
const ReadTensors& read_tensors,
@@ -279,7 +279,8 @@ inline void MKLDNNDeconvBwd::IOSwapWeightsTensors(const uint32_t num_group,
}
inline const mkldnn::memory* MKLDNNDeconvBwd::DataMem(const NDArray& data) const {
- return data.GetMKLDNNDataReorder(bwd_weights_pd->src_desc());
+ auto bwd_weight_src_desc = bwd_weights_pd->src_desc();
+ return static_cast<const mkldnn::memory*>(data.GetMKLDNNDataReorder(&bwd_weight_src_desc));
}
inline const mkldnn::memory* MKLDNNDeconvBwd::WeightsMem(const uint32_t num_group,
@@ -288,15 +289,18 @@ inline const mkldnn::memory* MKLDNNDeconvBwd::WeightsMem(const uint32_t num_grou
}
inline const mkldnn::memory* MKLDNNDeconvBwd::OutGradMem(const NDArray& out_grad) const {
- return out_grad.GetMKLDNNDataReorder(bwd_data_pd->diff_dst_desc());
+ auto bwd_data_diff_desc = bwd_data_pd->diff_dst_desc();
+ return static_cast<const mkldnn::memory*>(out_grad.GetMKLDNNDataReorder(&bwd_data_diff_desc));
}
inline const mkldnn::memory* MKLDNNDeconvBwd::OutGradMem(
const NDArray& out_grad,
const mkldnn::memory* const out_grad_mem) const {
+ auto bwd_weight_diff_desc = bwd_weights_pd->diff_dst_desc();
return (out_grad_mem && out_grad_mem->get_desc() == bwd_weights_pd->diff_dst_desc())
? out_grad_mem
- : out_grad.GetMKLDNNDataReorder(bwd_weights_pd->diff_dst_desc());
+ : static_cast<const mkldnn::memory*>(
+ out_grad.GetMKLDNNDataReorder(&bwd_weight_diff_desc));
}
inline mkldnn_output_t MKLDNNDeconvBwd::DataGradMem(const OpReqType req,
@@ -307,14 +311,15 @@ inline mkldnn_output_t MKLDNNDeconvBwd::DataGradMem(const OpReqType req,
inline mkldnn_output_t MKLDNNDeconvBwd::WeightsGradMem(const uint32_t num_group,
const OpReqType req,
const NDArray& weights_grad) const {
- // CreateMKLDNNWeightGrad always creates a new tensor as IsDefaultFormat
- // always fails (because of the logical swap - explained in
- // MKLDNNDeconvFwd::Execute). We try to reuse weights_grad memory (which, when
- // not swapped, is always in default format), so here we check if after a
+ // CreateMKLDNNWeightGrad always creates a new tensor as IsDefaultFormat always fails (because
+ // of the logical swap - explained in MKLDNNDeconvFwd::Execute). We try to reuse weights_grad
+ // memory (which, when not swapped, is always in default format), so here we check if after a
// swap, weights_md will have a default format
const auto& weights_md = bwd_weights_pd->diff_weights_desc();
if (req == OpReqType::kWriteTo && IsDefaultFormat(IOLogicalSwapDesc(weights_md, num_group))) {
- return {OutDataOp::Noop, const_cast<NDArray&>(weights_grad).CreateMKLDNNData(weights_md)};
+ return {OutDataOp::Noop,
+ static_cast<mkldnn::memory*>(
+ const_cast<NDArray&>(weights_grad).CreateMKLDNNData(&weights_md))};
}
return CreateMKLDNNWeightGrad(weights_grad, weights_md, req);
}
@@ -334,13 +339,13 @@ class DeconvDescCreator {
const NDArray* const bias,
const NDArray& out);
- // Imposes plain formats on memory descriptors with padding (so the next
- // selected implementation will pass CheckImplSizeReq). After calling this
- // method, new primitive descriptor (with new operator descriptor) should be
- // created, which should select an implementation with matching size
- // requirements. data_size, weights_size, out_size - size requirements of
- // current implementation Returns whether successfully imposed a plain format
- // on any of the data, weights, and output memory descriptors.
+ // Imposes plain formats on memory descriptors with padding (so the next selected implementation
+ // will pass CheckImplSizeReq). After calling this method, new primitive descriptor (with new
+ // operator descriptor) should be created, which should select an implementation with matching
+ // size requirements.
+ // data_size, weights_size, out_size - size requirements of current implementation
+ // Returns whether successfully imposed a plain format on any of the data, weights, and output
+ // memory descriptors.
bool ImposePlainWherePadding(const size_t data_size,
const size_t weights_size,
const size_t out_size);
diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
index 211ccd6..8428549 100644
--- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
@@ -112,9 +112,10 @@ void MKLDNNDeconvFwd::ControlWeightsFormat(const uint32_t num_group,
if (weights.IsDefaultData()) {
// We also need to modify the layout on the original weights array.
// The data conversion happens after the weights array is used.
- weights.MKLDNNDataReorderAsync(IOLogicalSwapDesc(fwd_pd->weights_desc(), num_group));
+ auto logical_swap_desc = IOLogicalSwapDesc(fwd_pd->weights_desc(), num_group);
+ weights.MKLDNNDataReorderAsync(&logical_swap_desc);
} else {
- CHECK(weights.GetMKLDNNData()->get_desc() ==
+ CHECK(static_cast<const mkldnn::memory*>(weights.GetMKLDNNData())->get_desc() ==
IOLogicalSwapDesc(fwd_pd->weights_desc(), num_group));
}
}
@@ -123,10 +124,9 @@ void MKLDNNDeconvFwd::ControlWeightsFormat(const uint32_t num_group,
void MKLDNNDeconvFwd::Execute(const uint32_t num_group,
const OpReqType req,
const Tensors& tensors) const {
- // MXNet (correctly) assumes that deconvolution is implemented using
- // convolution primitives. For that, we would pass input tensor in place of
- // output and output tensor in place of input (for appropriate convolution
- // primitives: deconvolution forward = convolution backward data,
+ // MXNet (correctly) assumes that deconvolution is implemented using convolution primitives.
+ // For that, we would pass input tensor in place of output and output tensor in place of input
+ // (for appropriate convolution primitives: deconvolution forward = convolution backward data,
// deconvolution backward data = convolution forward).
// The convolution primitive expects weights tensor with the shape of
// (primitive_out_channels, primitive_in_channels, h, w), but with swapped
diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
index ccd1287..a3ed6c8 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -80,8 +80,7 @@ mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(const MKLDNNFCFullPar
return mkldnn::inner_product_forward::primitive_desc(desc, attr, engine);
} catch (mkldnn::error& e) {
if (e.status == mkldnn_unimplemented && full_param.mkldnn_param.quantized) {
- LOG(ERROR) << "AVX512-BW support or MKLDNN v0.18 is required for INT8 "
- "fully_connected.";
+ LOG(ERROR) << "AVX512-BW support or MKLDNN v0.18 is required for INT8 fully_connected.";
} else {
LOG(ERROR) << e.message;
}
@@ -202,7 +201,8 @@ void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam& full_param,
NDArray weight = in_data[fullc::kWeight];
NDArray data = in_data[fullc::kData];
- auto data_mem = data.GetMKLDNNDataReorder(fwd->fwd_pd.src_desc());
+ auto fwd_src_desc = fwd->fwd_pd.src_desc();
+ auto data_mem = static_cast<const mkldnn::memory*>(data.GetMKLDNNDataReorder(&fwd_src_desc));
const mkldnn::memory* weight_mem;
if (ctx.is_train) {
if (weight.IsMKLDNNData()) {
@@ -210,9 +210,10 @@ void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam& full_param,
}
weight_mem = GetWeights(weight, fwd->fwd_pd.weights_desc(), 1);
} else {
- weight_mem = weight.GetMKLDNNData();
+ weight_mem = static_cast<const mkldnn::memory*>(weight.GetMKLDNNData());
if (weight_mem->get_desc() != fwd->fwd_pd.weights_desc()) {
- weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_desc());
+ auto fwd_weight_desc = fwd->fwd_pd.weights_desc();
+ weight.MKLDNNDataReorderAsync(&fwd_weight_desc);
weight_mem = GetWeights(weight, fwd->fwd_pd.weights_desc(), 1);
}
}
@@ -225,7 +226,9 @@ void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam& full_param,
{MKLDNN_ARG_DST, *out_mem.second},
};
if (!full_param.default_param.no_bias) {
- auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(fwd->fwd_pd.bias_desc());
+ auto fwd_bias_desc = fwd->fwd_pd.bias_desc();
+ auto bias_mem = static_cast<const mkldnn::memory*>(
+ in_data[fullc::kBias].GetMKLDNNDataReorder(&fwd_bias_desc));
args[MKLDNN_ARG_BIAS] = *bias_mem;
}
MKLDNNStream::Get()->RegisterPrimArgs(fwd->GetFwd(), args);
@@ -299,8 +302,14 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs,
if (req[fullc::kWeight]) {
mkldnn::inner_product_backward_weights::primitive_desc ipBwdWeights_pd = GetFCBwdWeights(
data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias], out_grad, fwd_pd);
- auto out_grad_mem = out_grad.GetMKLDNNDataReorder(ipBwdWeights_pd.diff_dst_desc());
- auto data_mem = data.GetMKLDNNDataReorder(ipBwdWeights_pd.src_desc());
+
+ auto ipBwdWeights_diff_dst_desc = ipBwdWeights_pd.diff_dst_desc();
+ auto out_grad_mem = static_cast<const mkldnn::memory*>(
+ out_grad.GetMKLDNNDataReorder(&ipBwdWeights_diff_dst_desc));
+
+ auto ipBwdWeights_src_desc = ipBwdWeights_pd.src_desc();
+ auto data_mem =
+ static_cast<const mkldnn::memory*>(data.GetMKLDNNDataReorder(&ipBwdWeights_src_desc));
auto in_grad_weight = CreateMKLDNNWeightGrad(
in_grad[fullc::kWeight], ipBwdWeights_pd.diff_weights_desc(), req[fullc::kWeight]);
mkldnn_args_map_t args = {
@@ -323,8 +332,13 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs,
if (req[fullc::kData]) {
mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd =
GetFCBwdData(data, weight, out_grad, fwd_pd);
- auto out_grad_mem = out_grad.GetMKLDNNDataReorder(ipBwdData_pd.diff_dst_desc());
- auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_desc());
+ auto ipBwdData_diff_dst_desc = ipBwdData_pd.diff_dst_desc();
+ auto out_grad_mem =
+ static_cast<const mkldnn::memory*>(out_grad.GetMKLDNNDataReorder(&ipBwdData_diff_dst_desc));
+
+ auto ipBwdData_weight_desc = ipBwdData_pd.weights_desc();
+ auto weight_mem =
+ static_cast<const mkldnn::memory*>(weight.GetMKLDNNDataReorder(&ipBwdData_weight_desc));
auto in_grad_mem =
CreateMKLDNNMem(in_grad[fullc::kData], ipBwdData_pd.diff_src_desc(), req[fullc::kData]);
mkldnn_args_map_t args = {{MKLDNN_ARG_DIFF_DST, *out_grad_mem},
diff --git a/src/operator/nn/mkldnn/mkldnn_log_softmax.cc b/src/operator/nn/mkldnn/mkldnn_log_softmax.cc
index a4cb66a..c3f41eb 100644
--- a/src/operator/nn/mkldnn/mkldnn_log_softmax.cc
+++ b/src/operator/nn/mkldnn/mkldnn_log_softmax.cc
@@ -60,8 +60,7 @@ bool SupportMKLDNNLogSoftmax(const SoftmaxParam& param,
const int axis = CheckAxis(param.axis, ndim);
// MKLDNN does not support temperature argument in their log_softmax function
// now. Need update this once they start to support it.
- // Currently, MKLDNN shows bad performance when log_softmax is not performed
- // on the last dimension
+ // Currently, MKLDNN shows bad performance when log_softmax is not performed on the last dimension
if (param.temperature.has_value() || in_dtype != mshadow::kFloat32 || in_dtype != out_dtype ||
axis != (ndim - 1)) {
return false;
@@ -110,7 +109,8 @@ static MKLDNNLogSoftmaxFwd& GetLogSoftmaxFwd(const SoftmaxParam& param,
auto it = fwds.find(key);
if (it == fwds.end()) {
- MKLDNNLogSoftmaxFwd fwd(is_train, real_axis, *(data.GetMKLDNNData()));
+ MKLDNNLogSoftmaxFwd fwd(
+ is_train, real_axis, *(static_cast<const mkldnn::memory*>(data.GetMKLDNNData())));
it = AddToCache(&fwds, key, fwd);
}
return it->second;
@@ -123,16 +123,16 @@ void MKLDNNLogSoftmaxForward(const nnvm::NodeAttrs& attrs,
const NDArray& out_data) {
if (req == kNullOp)
return;
- // same as the FCompute path, log_softmax only supports kWriteTo and
- // kWriteInplace for now.
+ // same as the FCompute path, log_softmax only supports kWriteTo and kWriteInplace for now.
CHECK_NE(req, kAddTo);
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
int axis = CheckAxis(param.axis, in_data.shape().ndim());
auto fwd = GetLogSoftmaxFwd(param, axis, ctx.is_train, in_data, out_data);
- auto in_mem = in_data.GetMKLDNNData();
- auto out_mem = out_data.GetMKLDNNData(fwd.pd.dst_desc());
+ auto in_mem = static_cast<const mkldnn::memory*>(in_data.GetMKLDNNData());
+ auto fwd_desc = fwd.pd.dst_desc();
+ auto out_mem = static_cast<const mkldnn::memory*>(out_data.GetMKLDNNData(&fwd_desc));
MKLDNNStream* stream = MKLDNNStream::Get();
stream->RegisterPrimArgs(fwd.GetFwd(), {{MKLDNN_ARG_SRC, *in_mem}, {MKLDNN_ARG_DST, *out_mem}});
stream->Submit();
@@ -176,8 +176,8 @@ static MKLDNNLogSoftmaxBwd& GetLogSoftmaxBwd(const SoftmaxParam& param,
auto it = bwds.find(key);
if (it == bwds.end()) {
- auto diff_mem = data[0].GetMKLDNNData();
- auto data_mem = data[1].GetMKLDNNData();
+ auto diff_mem = static_cast<const mkldnn::memory*>(data[0].GetMKLDNNData());
+ auto data_mem = static_cast<const mkldnn::memory*>(data[1].GetMKLDNNData());
auto fwd_pd = GetLogSoftmaxFwdPd(true, real_axis, *data_mem);
MKLDNNLogSoftmaxBwd bwd(*diff_mem, *data_mem, real_axis, fwd_pd);
it = AddToCache(&bwds, key, bwd);
@@ -195,8 +195,8 @@ void MKLDNNLogSoftmaxBackward(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_data.size(), 2U);
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
int axis = CheckAxis(param.axis, in_data[1].shape().ndim());
- auto diff_mem = in_data[0].GetMKLDNNData();
- auto data_mem = in_data[1].GetMKLDNNData();
+ auto diff_mem = static_cast<const mkldnn::memory*>(in_data[0].GetMKLDNNData());
+ auto data_mem = static_cast<const mkldnn::memory*>(in_data[1].GetMKLDNNData());
auto bwd = GetLogSoftmaxBwd(param, axis, in_data, out_data);
auto out_mem = CreateMKLDNNMem(out_data[0], bwd.pd.diff_src_desc(), req[0]);
diff --git a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
index f47804d..91c6d41 100644
--- a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
@@ -106,8 +106,9 @@ class MKLDNNLRNFwd {
}; // End of LRN Forword Class
void MKLDNNLRNFwd::_Init(const LRNParam& param, bool is_train, const NDArray& in_data) {
- mkldnn::memory::desc in_data_md = in_data.GetMKLDNNData()->get_desc();
- this->fwd_pd = GetLRNFwdDesc(param, is_train, in_data_md);
+ mkldnn::memory::desc in_data_md =
+ static_cast<const mkldnn::memory*>(in_data.GetMKLDNNData())->get_desc();
+ this->fwd_pd = GetLRNFwdDesc(param, is_train, in_data_md);
this->fwd = std::shared_ptr<mkldnn::lrn_forward>(new mkldnn::lrn_forward(this->fwd_pd));
}
@@ -119,7 +120,7 @@ void MKLDNNLRNFwd::Execute(const OpContext& ctx,
auto output_mem_t = CreateMKLDNNMem(out_data, (this->fwd_pd).dst_desc(), req);
mkldnn_args_map_t args = {
- {MKLDNN_ARG_SRC, *in_data.GetMKLDNNData()},
+ {MKLDNN_ARG_SRC, *static_cast<const mkldnn::memory*>(in_data.GetMKLDNNData())},
{MKLDNN_ARG_DST, *output_mem_t.second},
};
std::shared_ptr<mkldnn::memory> workspace;
@@ -206,10 +207,11 @@ class MKLDNNLRNBwd {
const mkldnn_output_t& diff_src_mem) {
auto engine = CpuEngine::Get()->get_engine();
auto workspace = std::make_shared<mkldnn::memory>((this->fwd_pd).workspace_desc(), engine);
- mkldnn_args_map_t args = {{MKLDNN_ARG_SRC, *in_data.GetMKLDNNData()},
- {MKLDNN_ARG_DIFF_DST, *out_grad.GetMKLDNNData()},
- {MKLDNN_ARG_WORKSPACE, *workspace},
- {MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second}};
+ mkldnn_args_map_t args = {
+ {MKLDNN_ARG_SRC, *static_cast<const mkldnn::memory*>(in_data.GetMKLDNNData())},
+ {MKLDNN_ARG_DIFF_DST, *static_cast<const mkldnn::memory*>(out_grad.GetMKLDNNData())},
+ {MKLDNN_ARG_WORKSPACE, *workspace},
+ {MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second}};
MKLDNNStream::Get()->RegisterPrimArgs(*(this->bwd), args);
CommitOutput(in_grad, diff_src_mem);
MKLDNNStream::Get()->Submit();
@@ -232,8 +234,10 @@ static MKLDNNLRNBwd& GetLRNBwd(const LRNParam& param,
auto it = lrn_bwds.find(key);
if (it == lrn_bwds.end()) {
- const mkldnn::memory::desc in_data_md = in_data.GetMKLDNNData()->get_desc();
- const mkldnn::memory::desc diff_md = out_grad.GetMKLDNNData()->get_desc();
+ const mkldnn::memory::desc in_data_md =
+ static_cast<const mkldnn::memory*>(in_data.GetMKLDNNData())->get_desc();
+ const mkldnn::memory::desc diff_md =
+ static_cast<const mkldnn::memory*>(out_grad.GetMKLDNNData())->get_desc();
MKLDNNLRNBwd bwd(param, in_data_md, diff_md);
it = AddToCache(&lrn_bwds, key, bwd);
}
@@ -252,8 +256,7 @@ void MKLDNNLRNBackward(const nnvm::NodeAttrs& attrs,
const NDArray& out_grad = inputs[0];
const NDArray& in_data = inputs[1];
const NDArray& in_grad = outputs[0];
- // TODO(alex): (MXNET-846) figure out why in_grad output incorrect when
- // in_data is nchw8c
+ // TODO(alex): (MXNET-846) figure out why in_grad output incorrect when in_data is nchw8c
const auto in_buffer = in_data.Reorder2Default();
MKLDNNLRNBwd& bwd = GetLRNBwd(param, in_buffer, in_grad, out_grad);
mkldnn_output_t diff_src_mem = CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_desc(), req[0]);
diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc
index 3b3be5c..9a88b86 100644
--- a/src/operator/nn/mkldnn/mkldnn_pooling.cc
+++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc
@@ -42,8 +42,8 @@ void MKLDNNPoolingFwd::Init(const mxnet::NDArray& input,
const mkldnn::memory::dims& pad_r,
const bool is_train,
const mkldnn::algorithm alg_kind) {
- const auto src_md = input.GetMKLDNNData()->get_desc();
- const auto dst_md = GetMemDesc(output);
+ const auto src_md = static_cast<const mkldnn::memory*>(input.GetMKLDNNData())->get_desc();
+ const auto dst_md = GetMemDesc(output);
const mkldnn::engine engine = CpuEngine::Get()->get_engine();
if (alg_kind != mkldnn::algorithm::pooling_max && alg_kind != mkldnn::algorithm::pooling_avg &&
alg_kind != mkldnn::algorithm::pooling_avg_include_padding &&
@@ -75,7 +75,7 @@ void MKLDNNPoolingFwd::Execute(const NDArray& in_data,
if (in_data.IsView() && in_data.IsMKLDNNData())
in_buffer = in_data.Reorder2Default();
- auto input_mem = in_buffer.GetMKLDNNData();
+ auto input_mem = static_cast<const mkldnn::memory*>(in_buffer.GetMKLDNNData());
auto output_mem_t_ = CreateMKLDNNMem(out_data, this->fwd_pd_->dst_desc(), req);
mkldnn_args_map_t args = {
@@ -91,7 +91,11 @@ void MKLDNNPoolingFwd::Execute(const NDArray& in_data,
}
auto ws = std::make_shared<mkldnn::memory>(
- (*(this->fwd_pd_)).workspace_desc(), engine, workspace->GetMKLDNNData()->get_data_handle());
+ (*(this->fwd_pd_)).workspace_desc(),
+ engine,
+ static_cast<const mkldnn::memory*>(
+ static_cast<const mkldnn::memory*>(workspace->GetMKLDNNData()))
+ ->get_data_handle());
args[MKLDNN_ARG_WORKSPACE] = *ws;
}
if (this->fwd_) {
@@ -231,7 +235,7 @@ MKLDNNPoolingFwd& GetPoolingFwd(const PoolingParam& param,
if (it == pooling_fwds.end()) {
CHECK(param.kernel.ndim() == 1 || param.kernel.ndim() == 2 || param.kernel.ndim() == 3)
<< "Not Implemented";
- auto data_md = data.GetMKLDNNData()->get_desc();
+ auto data_md = static_cast<const mkldnn::memory*>(data.GetMKLDNNData())->get_desc();
const auto kernel_ndims = param.kernel.ndim();
mkldnn::memory::dims kernel(kernel_ndims);
@@ -288,7 +292,7 @@ MKLDNNPoolingBwd& GetPoolingBwd(const PoolingParam& param,
auto it = pooling_bwds.find(key);
if (it == pooling_bwds.end()) {
- auto input_mem = in_data.GetMKLDNNData();
+ auto input_mem = static_cast<const mkldnn::memory*>(in_data.GetMKLDNNData());
const mkldnn::memory::desc data_md = input_mem->get_desc();
auto dst_dims = mkldnn::memory::dims(out_grad.shape().begin(), out_grad.shape().end());
@@ -337,14 +341,16 @@ void MKLDNNPoolingGradCompute(const OpContext& ctx,
TmpMemMgr::Get()->Init(ctx.requested[0]);
auto& bwd = GetPoolingBwd(param, in_data, in_grad, out_grad);
- auto diff_dst_mem = out_grad.GetMKLDNNDataReorder(bwd.pd.diff_dst_desc());
+ auto bwd_diff_dst_desc = bwd.pd.diff_dst_desc();
+ auto diff_dst_mem =
+ static_cast<const mkldnn::memory*>(out_grad.GetMKLDNNDataReorder(&bwd_diff_dst_desc));
auto diff_src_mem = CreateMKLDNNMem(in_grad, bwd.pd.diff_src_desc(), req);
mkldnn_args_map_t args = {
{MKLDNN_ARG_DIFF_DST, *diff_dst_mem},
{MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second},
};
if (MKLDNNRequireWorkspace(param) && workspace != nullptr) {
- args[MKLDNN_ARG_WORKSPACE] = *(workspace->GetMKLDNNData());
+ args[MKLDNN_ARG_WORKSPACE] = *(static_cast<const mkldnn::memory*>(workspace->GetMKLDNNData()));
}
MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), args);
diff --git a/src/operator/nn/mkldnn/mkldnn_reshape.cc b/src/operator/nn/mkldnn/mkldnn_reshape.cc
index ef23a74..b9d26a0 100644
--- a/src/operator/nn/mkldnn/mkldnn_reshape.cc
+++ b/src/operator/nn/mkldnn/mkldnn_reshape.cc
@@ -36,7 +36,7 @@ MKLDNNReshapeFwd::MKLDNNReshapeFwd(const OpReqType& req,
const NDArray& input,
const NDArray& output) {
const auto engine = CpuEngine::Get()->get_engine();
- auto in_mem = input.GetMKLDNNData();
+ auto in_mem = static_cast<const mkldnn::memory*>(input.GetMKLDNNData());
// Create temp memory
auto temp_dims = mkldnn::memory::dims(input.shape().begin(), input.shape().end());
@@ -71,18 +71,22 @@ void MKLDNNReshapeFwd::Execute(const NDArray& input,
const OpReqType& req,
void* workspace) {
auto stream = MKLDNNStream::Get();
- auto in_mem = input.GetMKLDNNData();
+ auto in_mem = static_cast<const mkldnn::memory*>(input.GetMKLDNNData());
// register primitives and arguments
std::vector<mkldnn_args_map_t> args_map;
size_t prims_size = prims_.size();
if (prims_size == 1) {
- args_map.push_back({{MKLDNN_ARG_FROM, *in_mem}, {MKLDNN_ARG_TO, *output.GetMKLDNNData()}});
+ args_map.push_back(
+ {{MKLDNN_ARG_FROM, *in_mem},
+ {MKLDNN_ARG_TO, *static_cast<const mkldnn::memory*>(output.GetMKLDNNData())}});
} else if (prims_size == 2) {
if (workspace) {
temp_->set_data_handle(workspace);
}
args_map.push_back({{MKLDNN_ARG_FROM, *in_mem}, {MKLDNN_ARG_TO, *temp_}});
- args_map.push_back({{MKLDNN_ARG_FROM, *temp_}, {MKLDNN_ARG_TO, *output.GetMKLDNNData()}});
+ args_map.push_back(
+ {{MKLDNN_ARG_FROM, *temp_},
+ {MKLDNN_ARG_TO, *static_cast<const mkldnn::memory*>(output.GetMKLDNNData())}});
} else {
CHECK(prims_size == 0 && req != kWriteTo) << "kWriteTo should never reach here.";
}
@@ -120,8 +124,8 @@ void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
const NDArray& input,
const OpReqType& req,
const NDArray& output) {
- // For mkldnn non-supported input, it shouldn't hold mkldnn memory, so let's
- // simply fallback to naive implement.
+ // For mkldnn non-supported input, it shouldn't hold mkldnn memory, so let's simply fallback to
+ // naive implement.
const int input_ndims = input.shape().ndim();
if ((input_ndims < 1 || input_ndims > 4) || !SupportMKLDNNQuantize(input.dtype())) {
if (req != kWriteInplace) {
diff --git a/src/operator/nn/mkldnn/mkldnn_slice.cc b/src/operator/nn/mkldnn/mkldnn_slice.cc
index 345560d..8a85e60 100644
--- a/src/operator/nn/mkldnn/mkldnn_slice.cc
+++ b/src/operator/nn/mkldnn/mkldnn_slice.cc
@@ -49,8 +49,8 @@ MKLDNNSliceFwd::MKLDNNSliceFwd(const SliceParam& param, const NDArray& in, const
offsets[i] = s;
}
- auto in_md = in.GetMKLDNNData()->get_desc();
- auto out_md = out.GetMKLDNNData()->get_desc();
+ auto in_md = static_cast<const mkldnn::memory*>(in.GetMKLDNNData())->get_desc();
+ auto out_md = static_cast<const mkldnn::memory*>(out.GetMKLDNNData())->get_desc();
auto sub_md = in_md.submemory_desc(dims, offsets);
auto engine = CpuEngine::Get()->get_engine();
@@ -98,8 +98,8 @@ void MKLDNNSlice(const nnvm::NodeAttrs& attrs,
const NDArray& out) {
const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
MKLDNNSliceFwd& fwd = GetSliceForward(param, ctx.is_train, in, out);
- auto in_mem = in.GetMKLDNNData();
- auto out_md = out.GetMKLDNNData()->get_desc();
+ auto in_mem = static_cast<const mkldnn::memory*>(in.GetMKLDNNData());
+ auto out_md = static_cast<const mkldnn::memory*>(out.GetMKLDNNData())->get_desc();
auto out_mem = CreateMKLDNNMem(out, out_md, req);
fwd.SetNewMem(*in_mem, *out_mem.second);
fwd.Register();
diff --git a/src/operator/nn/mkldnn/mkldnn_softmax.cc b/src/operator/nn/mkldnn/mkldnn_softmax.cc
index 7e948b3..67aae1f 100644
--- a/src/operator/nn/mkldnn/mkldnn_softmax.cc
+++ b/src/operator/nn/mkldnn/mkldnn_softmax.cc
@@ -62,8 +62,7 @@ bool SupportMKLDNNSoftmax(const SoftmaxParam& param, const NDArray& data, const
const int axis = CheckAxis(param.axis, ndim);
// MKLDNN does not support temperature argument in their softmax function
// now. Need update this once they start to support it.
- // Currently, MKLDNN shows bad performance when softmax is not performed on
- // the last dimension
+ // Currently, MKLDNN shows bad performance when softmax is not performed on the last dimension
if (param.temperature.has_value() || in_dtype != mshadow::kFloat32 || in_dtype != out_dtype ||
axis != (ndim - 1)) {
return false;
@@ -111,7 +110,8 @@ static MKLDNNSoftmaxFwd& GetSoftmaxFwd(const SoftmaxParam& param,
auto it = fwds.find(key);
if (it == fwds.end()) {
- MKLDNNSoftmaxFwd fwd(is_train, real_axis, *(data.GetMKLDNNData()));
+ MKLDNNSoftmaxFwd fwd(
+ is_train, real_axis, *(static_cast<const mkldnn::memory*>(data.GetMKLDNNData())));
it = AddToCache(&fwds, key, fwd);
}
return it->second;
@@ -124,16 +124,16 @@ void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs,
const NDArray& out_data) {
if (req == kNullOp)
return;
- // same as the FCompute path, softmax only supports kWriteTo and kWriteInplace
- // for now.
+ // same as the FCompute path, softmax only supports kWriteTo and kWriteInplace for now.
CHECK_NE(req, kAddTo);
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
int axis = CheckAxis(param.axis, in_data.shape().ndim());
auto fwd = GetSoftmaxFwd(param, axis, ctx.is_train, in_data, out_data);
- auto in_mem = in_data.GetMKLDNNData();
- auto out_mem = out_data.GetMKLDNNData(fwd.pd.dst_desc());
+ auto in_mem = static_cast<const mkldnn::memory*>(in_data.GetMKLDNNData());
+ auto fwd_desc = fwd.pd.dst_desc();
+ auto out_mem = static_cast<const mkldnn::memory*>(out_data.GetMKLDNNData(&fwd_desc));
MKLDNNStream* stream = MKLDNNStream::Get();
stream->RegisterPrimArgs(fwd.GetFwd(), {{MKLDNN_ARG_SRC, *in_mem}, {MKLDNN_ARG_DST, *out_mem}});
stream->Submit();
@@ -176,8 +176,8 @@ static MKLDNNSoftmaxBwd& GetSoftmaxBwd(const SoftmaxParam& param,
auto it = bwds.find(key);
if (it == bwds.end()) {
- auto diff_mem = data[0].GetMKLDNNData();
- auto data_mem = data[1].GetMKLDNNData();
+ auto diff_mem = static_cast<const mkldnn::memory*>(data[0].GetMKLDNNData());
+ auto data_mem = static_cast<const mkldnn::memory*>(data[1].GetMKLDNNData());
auto fwd_pd = GetSoftmaxFwdPd(true, real_axis, *data_mem);
MKLDNNSoftmaxBwd bwd(*diff_mem, *data_mem, real_axis, fwd_pd);
it = AddToCache(&bwds, key, bwd);
@@ -195,8 +195,8 @@ void MKLDNNSoftmaxBackward(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_data.size(), 2U);
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
int axis = CheckAxis(param.axis, in_data[1].shape().ndim());
- auto diff_mem = in_data[0].GetMKLDNNData();
- auto data_mem = in_data[1].GetMKLDNNData();
+ auto diff_mem = static_cast<const mkldnn::memory*>(in_data[0].GetMKLDNNData());
+ auto data_mem = static_cast<const mkldnn::memory*>(in_data[1].GetMKLDNNData());
auto bwd = GetSoftmaxBwd(param, axis, in_data, out_data);
auto out_mem = CreateMKLDNNMem(out_data[0], bwd.pd.diff_src_desc(), req[0]);
diff --git a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc
index e632245..93ea12c 100644
--- a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc
+++ b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc
@@ -84,7 +84,7 @@ static MKLDNNSoftmaxOutputFwd& GetSoftmaxOutputForward(const SoftmaxOutputParam&
auto it = fwds.find(key);
if (it == fwds.end()) {
- auto in_mem = *(in_data.GetMKLDNNData());
+ auto in_mem = *(static_cast<const mkldnn::memory*>(in_data.GetMKLDNNData()));
MKLDNNSoftmaxOutputFwd fwd(param, ctx.is_train, axis, in_mem);
it = AddToCache(&fwds, key, fwd);
}
@@ -109,7 +109,7 @@ void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs,
idata = in_data[softmaxout_enum::kData].Reorder2Default();
}
- auto input_mem = idata.GetMKLDNNData();
+ auto input_mem = static_cast<const mkldnn::memory*>(idata.GetMKLDNNData());
auto out_mem = CreateMKLDNNMem(
out_data[softmaxout_enum::kOut], input_mem->get_desc(), req[softmaxout_enum::kOut]);
diff --git a/src/operator/nn/mkldnn/mkldnn_sum.cc b/src/operator/nn/mkldnn/mkldnn_sum.cc
index fc49b1b..c771efa 100644
--- a/src/operator/nn/mkldnn/mkldnn_sum.cc
+++ b/src/operator/nn/mkldnn/mkldnn_sum.cc
@@ -112,7 +112,7 @@ void MKLDNNSumForward(const nnvm::NodeAttrs& attrs,
data_mem.reserve(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
- const mkldnn::memory* in_mem = inputs[i].GetMKLDNNData();
+ const mkldnn::memory* in_mem = static_cast<const mkldnn::memory*>(inputs[i].GetMKLDNNData());
mkldnn::memory::desc tmp_md = in_mem->get_desc();
data_md.push_back(tmp_md);
data_mem.push_back(in_mem);
diff --git a/src/operator/nn/mkldnn/mkldnn_transpose.cc b/src/operator/nn/mkldnn/mkldnn_transpose.cc
index 2f97302..40906b7 100644
--- a/src/operator/nn/mkldnn/mkldnn_transpose.cc
+++ b/src/operator/nn/mkldnn/mkldnn_transpose.cc
@@ -66,7 +66,7 @@ class MKLDNNTransposeForward {
}
auto engine = CpuEngine::Get()->get_engine();
- auto in_mem = data.GetMKLDNNData();
+ auto in_mem = static_cast<const mkldnn::memory*>(data.GetMKLDNNData());
auto src_md = in_mem->get_desc();
data_ = std::make_shared<mkldnn::memory>(src_md, engine, nullptr);
@@ -90,7 +90,8 @@ class MKLDNNTransposeForward {
void SetNewMem(const NDArray& data, const NDArray& output) {
if (data.IsMKLDNNData()) {
- this->data_->set_data_handle(data.GetMKLDNNData()->get_data_handle());
+ this->data_->set_data_handle(
+ static_cast<const mkldnn::memory*>(data.GetMKLDNNData())->get_data_handle());
} else {
MSHADOW_TYPE_SWITCH(
data.dtype(), DTYPE, { this->data_->set_data_handle(data.data().dptr<DTYPE>()); });
diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index 808e46c..b93e3bc 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -425,8 +425,7 @@ inline std::vector<nnvm::NodeEntry> MakeGradNode(
return CreateNodeEntries(p);
}
-// quick helper to make gradient nodes that simply pass back zero. could be used
-// in output ops.
+// quick helper to make gradient nodes that simply pass back zero. could be used in output ops.
inline std::vector<nnvm::NodeEntry> MakeZeroGradNodes(const nnvm::ObjectPtr& n,
const std::vector<nnvm::NodeEntry>& ograds) {
std::vector<nnvm::NodeEntry> ret;
@@ -614,7 +613,8 @@ class OpSignature {
void AddSign(const NDArray& arr) {
#if MXNET_USE_MKLDNN == 1
if (arr.IsMKLDNNData()) {
- AddSign(*(arr.GetMKLDNNData()));
+ auto arr_data = static_cast<const mkldnn::memory*>(arr.GetMKLDNNData());
+ AddSign(*(arr_data));
} else {
#endif
hash = hash * 2 + arr.dtype();
diff --git a/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h b/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h
index 8fbe393..d9300f7 100644
--- a/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h
+++ b/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h
@@ -62,7 +62,7 @@ void SgMKLDNNDequantizeOperator::Forward(const OpContext& ctx,
NDArray in_buffer = inputs[0];
if (inputs[0].IsView() && inputs[0].IsMKLDNNData())
in_buffer = inputs[0].Reorder2Default();
- auto i_mem = in_buffer.GetMKLDNNData();
+ auto i_mem = static_cast<const mkldnn::memory*>(in_buffer.GetMKLDNNData());
float data_min = *inputs[1].data().dptr<float>();
float data_max = *inputs[2].data().dptr<float>();
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize-inl.h
index 5443a46..0bad466 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantize-inl.h
+++ b/src/operator/quantization/mkldnn/mkldnn_quantize-inl.h
@@ -70,7 +70,7 @@ static void MKLDNNQuantizeComputeKer(const std::vector<NDArray>& inputs,
if (inputs[0].IsView() && inputs[0].IsMKLDNNData())
in_buffer = inputs[0].Reorder2Default();
- auto i_mem = in_buffer.GetMKLDNNData();
+ auto i_mem = static_cast<const mkldnn::memory*>(in_buffer.GetMKLDNNData());
auto i_desc = i_mem->get_desc();
size_t i_ndim = in_buffer.shape().ndim();
mkldnn::memory::desc o_desc;
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
index 61e19bd..9bdc072 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
+++ b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
@@ -81,13 +81,14 @@ void SgMKLDNNQuantizeOperator::Forward(const OpContext& ctx,
}
}
if (req[0] != kWriteInplace) {
- const_cast<NDArray&>(outputs[0]).CopyFrom(*inputs[0].GetMKLDNNData());
+ const_cast<NDArray&>(outputs[0])
+ .CopyFrom(static_cast<const mkldnn::memory*>(inputs[0].GetMKLDNNData()));
MKLDNNStream::Get()->Submit();
}
} else {
if (in_buffer.IsView() && in_buffer.IsMKLDNNData())
in_buffer = inputs[0].Reorder2Default();
- auto i_mem = in_buffer.GetMKLDNNData();
+ auto i_mem = static_cast<const mkldnn::memory*>(in_buffer.GetMKLDNNData());
if (param_.min_calib_range.has_value() && param_.max_calib_range.has_value()) {
data_min = param_.min_calib_range.value();
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_batch_norm.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_batch_norm.cc
index 2e75dbc..9037fc7 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_batch_norm.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_batch_norm.cc
@@ -41,7 +41,7 @@ static void MKLDNNQuantizedBatchNormForward(const nnvm::NodeAttrs& attrs,
TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
const NDArray& data = in_data[quantized_batchnorm::kData];
- auto data_mem = data.GetMKLDNNData();
+ auto data_mem = static_cast<const mkldnn::memory*>(data.GetMKLDNNData());
// reorder if data type = uint8
if (in_data[quantized_batchnorm::kData].dtype() == mshadow::kUint8) {
@@ -98,8 +98,10 @@ static void MKLDNNQuantizedBatchNormForward(const nnvm::NodeAttrs& attrs,
float* moving_var_ptr = moving_var.data().dptr<float>();
// rescale gamma and beta, to make mean=0 and var=1
- auto rescaled_mean_mem = TmpMemMgr::Get()->Alloc(moving_mean.GetMKLDNNData()->get_desc());
- auto rescaled_var_mem = TmpMemMgr::Get()->Alloc(moving_var.GetMKLDNNData()->get_desc());
+ auto rescaled_mean_mem = TmpMemMgr::Get()->Alloc(
+ static_cast<const mkldnn::memory*>(moving_mean.GetMKLDNNData())->get_desc());
+ auto rescaled_var_mem = TmpMemMgr::Get()->Alloc(
+ static_cast<const mkldnn::memory*>(moving_var.GetMKLDNNData())->get_desc());
float* rescaled_mean_ptr = reinterpret_cast<float*>(rescaled_mean_mem->get_data_handle());
float* rescaled_var_ptr = reinterpret_cast<float*>(rescaled_var_mem->get_data_handle());
@@ -115,7 +117,9 @@ static void MKLDNNQuantizedBatchNormForward(const nnvm::NodeAttrs& attrs,
}
const NDArray& out = outputs[batchnorm::kOut];
- auto out_mem = const_cast<NDArray&>(out).CreateMKLDNNData(fwd.GetPd().dst_desc());
+ auto fwd_dst_desc = fwd.GetPd().dst_desc();
+ auto out_mem =
+ static_cast<mkldnn::memory*>(const_cast<NDArray&>(out).CreateMKLDNNData(&fwd_dst_desc));
mkldnn_args_map_t net_args;
net_args[MKLDNN_ARG_SRC] = *data_mem;
net_args[MKLDNN_ARG_SCALE_SHIFT] = weight_mem;
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_concat.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_concat.cc
index 57149bb..9d62a6b 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_concat.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_concat.cc
@@ -72,11 +72,11 @@ static void MKLDNNQuantizedConcatForward(const nnvm::NodeAttrs& attrs,
auto i_scale = GetScale(in_data[i], data_min[i], data_max[i]);
if (i_scale == out_scale) {
CHECK(in_data[i].dtype() == out_dtype);
- auto mem = in_data[i].GetMKLDNNData();
+ auto mem = static_cast<const mkldnn::memory*>(in_data[i].GetMKLDNNData());
data_mem.push_back(mem);
data_md.push_back(mem->get_desc());
} else {
- auto mem = in_data[i].GetMKLDNNData();
+ auto mem = static_cast<const mkldnn::memory*>(in_data[i].GetMKLDNNData());
auto mem_desc = mem->get_desc();
if (in_data[i].dtype() != out_dtype) {
mem_desc.data.data_type = static_cast<mkldnn_data_type_t>(get_mkldnn_type(out_dtype));
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc
index 5c3e8da..3c44123 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc
@@ -46,13 +46,15 @@ static void MKLDNNQuantizedConvForward(const nnvm::NodeAttrs& attrs,
MKLDNNConvFullParam full_param;
full_param.conv_param = param;
full_param.mkldnn_param.Init(std::unordered_map<std::string, std::string>());
- auto& fwd = GetConvFwd(full_param,
+ auto& fwd = GetConvFwd(full_param,
ctx.is_train,
in_data[conv::kData],
in_data[conv::kWeight],
param.no_bias ? nullptr : &in_data[conv::kBias],
out_data[conv::kOut]);
- auto data_mem = in_data[conv::kData].GetMKLDNNDataReorder(fwd.GetPd().src_desc());
+ auto fwd_src_desc = fwd.GetPd().src_desc();
+ auto data_mem =
+ static_cast<const mkldnn::memory*>(in_data[conv::kData].GetMKLDNNDataReorder(&fwd_src_desc));
const mkldnn::memory* weight_mem;
// For inference, we want to reorder the weight array so we don't need to
// reorder data every time.
@@ -60,16 +62,18 @@ static void MKLDNNQuantizedConvForward(const nnvm::NodeAttrs& attrs,
// We also need to modify the layout on the original weight array.
// Don't switch the sequence below because the naive engine executes
// pushAsync synchronously.
- weight.MKLDNNDataReorderAsync(fwd.GetPd().weights_desc());
+ auto fwd_weight_desc = fwd.GetPd().weights_desc();
+ weight.MKLDNNDataReorderAsync(&fwd_weight_desc);
weight_mem = GetWeights(weight, fwd.GetPd().weights_desc(), param.num_group);
} else {
- weight_mem = weight.GetMKLDNNData();
+ weight_mem = static_cast<const mkldnn::memory*>(weight.GetMKLDNNData());
}
auto out_mem = CreateMKLDNNMem(out_data[conv::kOut], fwd.GetPd().dst_desc(), req[conv::kOut]);
mkldnn_args_map_t net_args;
if (!param.no_bias) {
- const mkldnn::memory* bias_mem =
- in_data[conv::kBias].GetMKLDNNDataReorder(fwd.GetPd().bias_desc());
+ auto fwd_bias_desc = fwd.GetPd().bias_desc();
+ const mkldnn::memory* bias_mem = static_cast<const mkldnn::memory*>(
+ in_data[conv::kBias].GetMKLDNNDataReorder(&fwd_bias_desc));
net_args.insert({MKLDNN_ARG_BIAS, *bias_mem});
}
net_args.insert({MKLDNN_ARG_SRC, *data_mem});
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_elemwise_add.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_elemwise_add.cc
index adee6f9..10f21bc 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_elemwise_add.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_elemwise_add.cc
@@ -109,8 +109,10 @@ static void MKLDNNQuantizedElemwiseAddForward(const nnvm::NodeAttrs& attrs,
const float dataA_absmax = MaxAbs(dataA_min, dataA_max);
const float dataB_absmax = MaxAbs(dataB_min, dataB_max);
- auto dataA_mem = in_data[quantized_elemwise_add_enum::kDataA].GetMKLDNNData();
- auto dataB_mem = in_data[quantized_elemwise_add_enum::kDataB].GetMKLDNNData();
+ auto dataA_mem = static_cast<const mkldnn::memory*>(
+ in_data[quantized_elemwise_add_enum::kDataA].GetMKLDNNData());
+ auto dataB_mem = static_cast<const mkldnn::memory*>(
+ in_data[quantized_elemwise_add_enum::kDataB].GetMKLDNNData());
const bool is_dataA_int8 =
(in_data[quantized_elemwise_add_enum::kDataA].dtype() == mshadow::kInt8);
const float dataA_range = is_dataA_int8 ? kInt8Range : kUint8Range;
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
index f994e7f..4e703ad 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
@@ -91,17 +91,20 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
auto& fwd =
GetFCFwd(param, is_train, data, weight, param.no_bias ? nullptr : &quantized_bias, out_md);
- auto data_mem = in_data[fullc::kData].GetMKLDNNDataReorder(fwd.fwd_pd.src_desc());
+ auto fwd_src_desc = fwd.fwd_pd.src_desc();
+ auto data_mem =
+ static_cast<const mkldnn::memory*>(in_data[fullc::kData].GetMKLDNNDataReorder(&fwd_src_desc));
const mkldnn::memory* weight_mem = nullptr;
if (weight.IsDefaultData()) {
// We also need to modify the layout on the original weight array.
// Don't switch the sequence below because the naive engine executes
// pushAsync synchronously.
- weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_desc());
+ auto fwd_weight_desc = fwd.fwd_pd.weights_desc();
+ weight.MKLDNNDataReorderAsync(&fwd_weight_desc);
weight_mem = GetWeights(weight, fwd.fwd_pd.weights_desc(), 1);
} else {
- weight_mem = weight.GetMKLDNNData();
+ weight_mem = static_cast<const mkldnn::memory*>(weight.GetMKLDNNData());
CHECK(weight_mem->get_desc() == fwd.fwd_pd.weights_desc());
}
auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut], fwd.fwd_pd.dst_desc(), req[fullc::kOut]);
@@ -114,7 +117,9 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
const mkldnn::memory* bias_mem = nullptr;
if (!param.no_bias) {
- bias_mem = quantized_bias.GetMKLDNNDataReorder(fwd.fwd_pd.bias_desc());
+ auto fwd_bias_desc = fwd.fwd_pd.bias_desc();
+ bias_mem =
+ static_cast<const mkldnn::memory*>(quantized_bias.GetMKLDNNDataReorder(&fwd_bias_desc));
args[MKLDNN_ARG_BIAS] = *bias_mem;
}
diff --git a/src/operator/quantization/mkldnn/mkldnn_requantize-inl.h b/src/operator/quantization/mkldnn/mkldnn_requantize-inl.h
index f93d87b..9292615 100644
--- a/src/operator/quantization/mkldnn/mkldnn_requantize-inl.h
+++ b/src/operator/quantization/mkldnn/mkldnn_requantize-inl.h
@@ -80,7 +80,7 @@ static void MKLDNNRequantizeForwardKer(const nnvm::NodeAttrs& attrs,
if (inputs[0].IsView() && inputs[0].IsMKLDNNData())
in_buffer = inputs[0].Reorder2Default();
- auto i_mem = in_buffer.GetMKLDNNData();
+ auto i_mem = static_cast<const mkldnn::memory*>(in_buffer.GetMKLDNNData());
auto i_desc = i_mem->get_desc();
auto o_desc = i_desc;
o_desc.data.data_type = get_mkldnn_type_t<DstType>();
diff --git a/src/operator/subgraph/mkldnn/mkldnn_common.h b/src/operator/subgraph/mkldnn/mkldnn_common.h
index c4f2857..e4fe4e0 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_common.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_common.h
@@ -100,8 +100,8 @@ static void ConvertWeightBias2MKLDNN(NDArray* weight,
const std::vector<float>& weight_scales,
const bool submit = true) {
MKLDNNStream* stream = MKLDNNStream::Get();
- const auto new_weight = NDArray(weight_md);
- const auto conv_weights_memory = new_weight.GetMKLDNNData();
+ const auto new_weight = NDArray(&weight_md);
+ const auto conv_weights_memory = static_cast<const mkldnn::memory*>(new_weight.GetMKLDNNData());
mkldnn::primitive_attr weight_attr;
if (weight_scales.size()) {
const int weight_mask = (weight_scales.size()) == 1 ? 0 : 1;
@@ -109,7 +109,7 @@ static void ConvertWeightBias2MKLDNN(NDArray* weight,
}
auto default_weights_memory = GetWeights(*weight, num_group);
if (default_weights_memory == nullptr)
- default_weights_memory = weight->GetMKLDNNData();
+ default_weights_memory = static_cast<const mkldnn::memory*>(weight->GetMKLDNNData());
const auto weight_reorder_pd =
mkldnn::reorder::primitive_desc(*default_weights_memory, *conv_weights_memory, weight_attr);
MKLDNNStream::Get()->RegisterPrimArgs(
@@ -121,12 +121,12 @@ static void ConvertWeightBias2MKLDNN(NDArray* weight,
for (size_t c = 0; c < weight_scales.size(); ++c) {
bias_scales[c] = weight_scales[c] * data_scale;
}
- new_bias = NDArray(*bias_md);
- const auto conv_bias_memory = new_bias.GetMKLDNNData();
+ new_bias = NDArray(bias_md);
+ const auto conv_bias_memory = static_cast<const mkldnn::memory*>(new_bias.GetMKLDNNData());
const int bias_mask = (bias_scales.size()) == 1 ? 0 : 1;
mkldnn::primitive_attr bias_attr;
bias_attr.set_output_scales(bias_mask, bias_scales);
- auto bias_weights_memory = bias->GetMKLDNNData();
+ auto bias_weights_memory = static_cast<const mkldnn::memory*>(bias->GetMKLDNNData());
const auto bias_reorder_pd =
mkldnn::reorder::primitive_desc(*bias_weights_memory, *conv_bias_memory, bias_attr);
MKLDNNStream::Get()->RegisterPrimArgs(
diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc
index 95a997f..1e492b4 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc
@@ -159,17 +159,17 @@ void SgMKLDNNConvOperator::Forward(const OpContext& ctx,
// Copy inputs[in_sum] into outputs[kOut] in case inplace optimization failed.
if (mkldnn_param.with_sum) {
if (!initialized_) {
- // TODO(zhennan): Currently, mkldnn fallback mechanism will break inplace
- // option, which make check (req[kOut] == kWriteInplace) useless.
- auto in_mkl_mem = inputs[in_sum].GetMKLDNNData();
- auto out_mkl_mem = outputs[kOut].GetMKLDNNData();
+ // TODO(zhennan): Currently, mkldnn fallback mechanism will break inplace option,
+ // which makes the check (req[kOut] == kWriteInplace) useless.
+ auto in_mkl_mem = static_cast<const mkldnn::memory*>(inputs[in_sum].GetMKLDNNData());
+ auto out_mkl_mem = static_cast<const mkldnn::memory*>(outputs[kOut].GetMKLDNNData());
if (in_mkl_mem->get_data_handle() == out_mkl_mem->get_data_handle()) {
inplace_ = true;
}
}
if (!inplace_) {
- auto in_mkl_mem = inputs[in_sum].GetMKLDNNData();
- auto out_mkl_mem = outputs[kOut].GetMKLDNNData();
+ auto in_mkl_mem = static_cast<const mkldnn::memory*>(inputs[in_sum].GetMKLDNNData());
+ auto out_mkl_mem = static_cast<const mkldnn::memory*>(outputs[kOut].GetMKLDNNData());
if (outputs[kOut].dtype() == mshadow::kInt32) {
const auto& mem_desc = in_mkl_mem->get_desc();
const auto this_dtype = get_mkldnn_type(mshadow::kInt32);
@@ -337,20 +337,22 @@ void SgMKLDNNConvOperator::Forward(const OpContext& ctx,
full_conv_param.conv_param.num_group,
data_scale_,
weight_scales_);
- args_[MKLDNN_ARG_SRC] = *data.GetMKLDNNData();
- args_[MKLDNN_ARG_WEIGHTS] = *cached_weight_.GetMKLDNNData();
+ args_[MKLDNN_ARG_SRC] = *static_cast<const mkldnn::memory*>(data.GetMKLDNNData());
+ args_[MKLDNN_ARG_WEIGHTS] = *static_cast<const mkldnn::memory*>(cached_weight_.GetMKLDNNData());
if (has_bias)
- args_[MKLDNN_ARG_BIAS] = *cached_bias_.GetMKLDNNData();
- args_[MKLDNN_ARG_DST] = *output.GetMKLDNNData();
+ args_[MKLDNN_ARG_BIAS] = *static_cast<const mkldnn::memory*>(cached_bias_.GetMKLDNNData());
+ args_[MKLDNN_ARG_DST] = *static_cast<const mkldnn::memory*>(output.GetMKLDNNData());
initialized_ = true;
}
if (mkldnn_param.with_sum) {
- const auto& output_mem = output.GetMKLDNNData();
+ const auto& output_mem = static_cast<const mkldnn::memory*>(output.GetMKLDNNData());
const auto& out_mem_desc = output_mem->get_desc();
const auto& dst_mem_desc = fwd_->GetPd().dst_desc();
if (out_mem_desc != dst_mem_desc) {
- auto tmp_out_mem = output.GetMKLDNNDataReorder(fwd_->GetPd().dst_desc());
+ auto fwd_dst_desc = fwd_->GetPd().dst_desc();
+ auto tmp_out_mem =
+ static_cast<const mkldnn::memory*>(output.GetMKLDNNDataReorder(&fwd_dst_desc));
auto data_md = dst_mem_desc;
data_md.data.data_type = static_cast<mkldnn_data_type_t>(out_mem_desc.data.data_type);
mkldnn_mem_ptr new_out_mem(new mkldnn::memory(
@@ -362,8 +364,10 @@ void SgMKLDNNConvOperator::Forward(const OpContext& ctx,
}
if (mkldnn_param.quantized) {
- auto data_mem = data.GetMKLDNNDataReorder(fwd_->GetPd().src_desc());
- mkldnn::memory* mem = output.CreateMKLDNNData(fwd_->GetPd().dst_desc());
+ auto fwd_src_desc = fwd_->GetPd().src_desc();
+ auto data_mem = static_cast<const mkldnn::memory*>(data.GetMKLDNNDataReorder(&fwd_src_desc));
+ auto fwd_pd_dst_desc = fwd_->GetPd().dst_desc();
+ mkldnn::memory* mem = static_cast<mkldnn::memory*>(output.CreateMKLDNNData(&fwd_pd_dst_desc));
args_[MKLDNN_ARG_SRC] = *data_mem;
args_[MKLDNN_ARG_DST] = *mem;
MKLDNNStream::Get()->RegisterPrimArgs(fwd_->GetFwd(), args_);
@@ -384,8 +388,9 @@ void SgMKLDNNConvOperator::Forward(const OpContext& ctx,
*outputs[kMax].data().dptr<float>() = cached_output_max_;
}
if (mkldnn_param.with_sum) {
- auto out = const_cast<NDArray&>(outputs[kOut]);
- out.UpdateMKLDNNMemDesc(fwd_->GetPd().dst_desc());
+ auto out = const_cast<NDArray&>(outputs[kOut]);
+ auto fwd_dst_desc = fwd_->GetPd().dst_desc();
+ out.UpdateMKLDNNMemDesc(&fwd_dst_desc);
}
}
diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc.cc b/src/operator/subgraph/mkldnn/mkldnn_fc.cc
index 5578106..123d491 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_fc.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc.cc
@@ -113,8 +113,7 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
const int out_index = index++;
const int out_min_index = out_quantized ? index++ : 0;
const int out_max_index = out_quantized ? index++ : 0;
- CHECK_EQ(out_data.size(),
- index); // index is equal to total number of outpits
+ CHECK_EQ(out_data.size(), index); // index is equal to total number of outputs
float min_data = 0.0f;
float max_data = 0.0f;
@@ -132,10 +131,10 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
if (mkldnn_param.with_sum) {
if (!initialized_) {
- // TODO(zhennan): Currently, mkldnn fallback mechanism will break inplace
- // option, which make check (req[out_index] == kWriteInplace) useless.
- auto in_mkl_mem = in_data[idx.sum].GetMKLDNNData();
- auto out_mkl_mem = out_data[out_index].GetMKLDNNData();
+ // TODO(zhennan): Currently, mkldnn fallback mechanism will break inplace option,
+ // which makes the check (req[out_index] == kWriteInplace) useless.
+ auto in_mkl_mem = static_cast<const mkldnn::memory*>(in_data[idx.sum].GetMKLDNNData());
+ auto out_mkl_mem = static_cast<const mkldnn::memory*>(out_data[out_index].GetMKLDNNData());
if (in_mkl_mem->get_data_handle() == out_mkl_mem->get_data_handle()) {
inplace_ = true;
}
@@ -144,8 +143,8 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
output = in_data[idx.sum];
} else {
// Not in place: copy in_data[idx.sum] into outputs[out_index].
- auto in_mkl_mem = in_data[idx.sum].GetMKLDNNData();
- auto out_mkl_mem = out_data[out_index].GetMKLDNNData();
+ auto in_mkl_mem = static_cast<const mkldnn::memory*>(in_data[idx.sum].GetMKLDNNData());
+ auto out_mkl_mem = static_cast<const mkldnn::memory*>(out_data[out_index].GetMKLDNNData());
if (out_data[out_index].dtype() == mshadow::kInt32) {
auto mem_desc = in_mkl_mem->get_desc();
auto this_dtype = get_mkldnn_type(mshadow::kInt32);
@@ -263,8 +262,7 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
data_scale_ = GetQuantizeScale(data.dtype(), cached_min_data_, cached_max_data_);
bool fuse_requantize = false;
- // Channelwise scaling is only supported when fusion is enabled
- // (requantize or dequantize).
+ // Channelwise scaling is only supported when fusion is enabled (requantize or dequantize).
if (mkldnn_param.min_calib_range.has_value() && mkldnn_param.max_calib_range.has_value()) {
cached_min_output_ = mkldnn_param.min_calib_range.value();
cached_max_output_ = mkldnn_param.max_calib_range.value();
@@ -279,16 +277,13 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
// True False Error
// False True/False False
if (channel_wise && !support_channelwise_scale) {
- LOG(FATAL) << "Currently, channel-wise quantization requires fuse requantize "
- "or dequantize."
- << " Please make sure the `min_calib_range` and `max_calib_range` "
- "are set when only"
- << " fuse requantize (outputs of FullyConnected are collected "
- "during calibration "
- "phase),"
- << " or the env var of `MXNET_DISABLE_MKLDNN_QFC_FLOAT_OUTPUT` and "
- << " `MXNET_DISABLE_MKLDNN_QFC_FUSE_ALL` are not set to true "
- "(default is false)";
+ LOG(FATAL)
+ << "Currently, channel-wise quantization requires fuse requantize or dequantize."
+ << " Please make sure the `min_calib_range` and `max_calib_range` are set when only"
+ << " fuse requantize (outputs of FullyConnected are collected during calibration "
+ "phase),"
+ << " or the env var of `MXNET_DISABLE_MKLDNN_QFC_FLOAT_OUTPUT` and "
+ << " `MXNET_DISABLE_MKLDNN_QFC_FUSE_ALL` are not set to true (default is false)";
}
support_channelwise_scale = support_channelwise_scale && channel_wise;
@@ -423,10 +418,11 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
weight_scales_,
false);
} else {
- const auto def_weight_mem = weight.GetMKLDNNData();
+ const auto def_weight_mem = static_cast<const mkldnn::memory*>(weight.GetMKLDNNData());
if (def_weight_mem->get_desc() != fwd_->fwd_pd.weights_desc()) {
- cached_weight_ = NDArray(fwd_->fwd_pd.weights_desc());
- auto cached_weight_mem = cached_weight_.GetMKLDNNData();
+ auto weight_desc = fwd_->fwd_pd.weights_desc();
+ cached_weight_ = NDArray(&weight_desc);
+ auto cached_weight_mem = static_cast<const mkldnn::memory*>(cached_weight_.GetMKLDNNData());
std::unordered_map<int, mkldnn::memory> args(
{{MKLDNN_ARG_FROM, *def_weight_mem}, {MKLDNN_ARG_TO, *cached_weight_mem}});
MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::reorder(*def_weight_mem, *cached_weight_mem),
@@ -434,23 +430,24 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
}
}
- const auto data_mem = data.GetMKLDNNData();
+ const auto data_mem = static_cast<const mkldnn::memory*>(data.GetMKLDNNData());
cached_data_mem_ = std::make_shared<mkldnn::memory>(data_mem->get_desc(), engine);
args_[MKLDNN_ARG_SRC] = *cached_data_mem_;
- args_[MKLDNN_ARG_WEIGHTS] = *cached_weight_.GetMKLDNNData();
+ args_[MKLDNN_ARG_WEIGHTS] = *static_cast<const mkldnn::memory*>(cached_weight_.GetMKLDNNData());
if (has_bias)
- args_[MKLDNN_ARG_BIAS] = *cached_bias_.GetMKLDNNData();
+ args_[MKLDNN_ARG_BIAS] = *static_cast<const mkldnn::memory*>(cached_bias_.GetMKLDNNData());
args_[MKLDNN_ARG_DST] = *cached_out_mem_;
initialized_ = true;
}
if (mkldnn_param.with_sum) {
- const auto& output_mem = output.GetMKLDNNData();
+ const auto& output_mem = static_cast<const mkldnn::memory*>(output.GetMKLDNNData());
const auto& out_mem_desc = output_mem->get_desc();
auto dst_mem_desc = fwd_->fwd_pd.dst_desc();
if (out_mem_desc != dst_mem_desc) {
- auto tmp_out_mem = output.GetMKLDNNDataReorder(dst_mem_desc);
+ auto tmp_out_mem =
+ static_cast<const mkldnn::memory*>(output.GetMKLDNNDataReorder(&dst_mem_desc));
dst_mem_desc.data.data_type = out_mem_desc.data.data_type;
mkldnn_mem_ptr new_out_mem(new mkldnn::memory(
dst_mem_desc, CpuEngine::Get()->get_engine(), output_mem->get_data_handle()));
diff --git a/src/operator/tensor/amp_cast.cc b/src/operator/tensor/amp_cast.cc
index d1e5758..a1a3d9f 100644
--- a/src/operator/tensor/amp_cast.cc
+++ b/src/operator/tensor/amp_cast.cc
@@ -46,7 +46,7 @@ static void AMPCastExCPU(const nnvm::NodeAttrs& attrs,
mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine();
if (data.IsView() && data.IsMKLDNNData())
data = data.Reorder2Default();
- const auto i_mem = data.GetMKLDNNData();
+ const auto i_mem = static_cast<const mkldnn::memory*>(data.GetMKLDNNData());
const size_t i_ndim = data.shape().ndim();
mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim);
for (size_t i = 0; i < i_ndim; i++) {
@@ -94,7 +94,7 @@ static void AMPMultiCastExCPU(const nnvm::NodeAttrs& attrs,
auto data = inputs[i];
if (data.IsView() && data.IsMKLDNNData())
data = data.Reorder2Default();
- const auto i_mem = data.GetMKLDNNData();
+ const auto i_mem = static_cast<const mkldnn::memory*>(data.GetMKLDNNData());
const size_t i_ndim = data.shape().ndim();
mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim);
for (size_t j = 0; j < i_ndim; j++) {
@@ -170,8 +170,7 @@ NNVM_REGISTER_OP(_backward_amp_cast)
.set_attr<FCompute>("FCompute<cpu>", AMPCastCompute<cpu>);
NNVM_REGISTER_OP(amp_multicast)
- .describe(
- R"code(Cast function used by AMP, that casts its inputs to the common widest type.
+ .describe(R"code(Cast function used by AMP, that casts its inputs to the common widest type.
It casts only between low precision float/FP32 and does not do anything for other types.
diff --git a/src/operator/tensor/cast_storage-inl.h b/src/operator/tensor/cast_storage-inl.h
index e24f203..b4ee576 100644
--- a/src/operator/tensor/cast_storage-inl.h
+++ b/src/operator/tensor/cast_storage-inl.h
@@ -410,8 +410,8 @@ void CastStorageComputeImpl(const OpContext& ctx, const NDArray& input, const ND
// data first.
if (input.IsMKLDNNData() && input.IsView())
tmp_input = input.Reorder2Default();
- const mkldnn::memory* in_mem = tmp_input.GetMKLDNNData();
- const_cast<NDArray&>(output).CopyFrom(*in_mem);
+ const mkldnn::memory* in_mem = static_cast<const mkldnn::memory*>(tmp_input.GetMKLDNNData());
+ const_cast<NDArray&>(output).CopyFrom(in_mem);
MKLDNNStream::Get()->Submit();
} else {
mxnet_op::copy(ctx.get_stream<xpu>(), output.data(), input.data());
diff --git a/tests/cpp/include/test_mkldnn.h b/tests/cpp/include/test_mkldnn.h
index 9c32f15..0cda299 100644
--- a/tests/cpp/include/test_mkldnn.h
+++ b/tests/cpp/include/test_mkldnn.h
@@ -86,7 +86,7 @@ inline static void InitMKLDNNArray(NDArray* arr,
bool is_rand = false,
int max = 50) {
InitDefaultArray(arr, is_rand, max);
- arr->MKLDNNDataReorderAsync(desc);
+ arr->MKLDNNDataReorderAsync(&desc);
arr->WaitToRead();
}
@@ -352,8 +352,8 @@ inline void PrintVerifyMsg(const NDArrayAttrs& arr1, const NDArrayAttrs& arr2) {
* think we should pass them to all operators. In the inference mode, the MKLDNN
* memory in the weight array will be reordered to 5 dimensions.
*
- * num_inputs / dim arguments used to scale shape (used for concat backwards to
- * enlarge input shapes)
+ * num_inputs / dim arguments used to scale shape (used for concat backwards to enlarge input
+ * shapes)
*/
inline std::vector<NDArrayAttrs> GetTestInputArrays(int types = ArrayTypes::All,
bool rand = false,
diff --git a/tests/cpp/operator/mkldnn_operator_test.cc b/tests/cpp/operator/mkldnn_operator_test.cc
index 41a2ffe..43e010a 100644
--- a/tests/cpp/operator/mkldnn_operator_test.cc
+++ b/tests/cpp/operator/mkldnn_operator_test.cc
@@ -1209,7 +1209,8 @@ void TestPoolingOp(const OpAttrs& forward_attrs, const OpAttrs& backwards_attrs)
continue;
// cannot pool if ndarray and mkldnn memory have different ndim
if (in_arr.arr.IsView() ||
- in_arr.arr.GetMKLDNNData()->get_desc().data.ndims != in_arr.arr.shape().ndim())
+ static_cast<const mkldnn::memory*>(in_arr.arr.GetMKLDNNData())->get_desc().data.ndims !=
+ in_arr.arr.shape().ndim())
continue;
std::vector<float> scale_vector(in_arr.arr.shape().ndim());
for (int i = 0; i < in_arr.arr.shape().ndim(); i++) {
diff --git a/tests/cpp/operator/mkldnn_test.cc b/tests/cpp/operator/mkldnn_test.cc
index 1f8fe5f..ef516b7 100644
--- a/tests/cpp/operator/mkldnn_test.cc
+++ b/tests/cpp/operator/mkldnn_test.cc
@@ -26,25 +26,27 @@
#if MXNET_USE_MKLDNN == 1
#include <mkldnn_types.h>
-#include <cmath>
+
#include <climits>
+#include <cmath>
#include <set>
-#include "gtest/gtest.h"
-#include "mxnet/imperative.h"
-#include "../../src/operator/nn/mkldnn/mkldnn_ops-inl.h"
+
#include "../../src/operator/nn/mkldnn/mkldnn_base-inl.h"
+#include "../../src/operator/nn/mkldnn/mkldnn_ops-inl.h"
#include "../include/test_mkldnn.h"
+#include "gtest/gtest.h"
+#include "mxnet/imperative.h"
using namespace mxnet;
#if __GNUC__ >= 5
-bool test_mem_align(void *mem, size_t size, size_t alignment, size_t space) {
+bool test_mem_align(void* mem, size_t size, size_t alignment, size_t space) {
void *ret1, *ret2;
size_t space1, space2;
space1 = space;
space2 = space;
- ret1 = mxnet::AlignMem(mem, size, alignment, &space1);
- ret2 = std::align(alignment, size, mem, space2);
+ ret1 = mxnet::AlignMem(mem, size, alignment, &space1);
+ ret2 = std::align(alignment, size, mem, space2);
EXPECT_EQ(ret1, ret2);
EXPECT_EQ(space1, space2);
return ret1 == ret2;
@@ -54,29 +56,29 @@ bool test_mem_align(void *mem, size_t size, size_t alignment, size_t space) {
TEST(MKLDNN_UTIL_FUNC, AlignMem) {
#if __GNUC__ >= 5
size_t alignment = 4096;
- void *mem;
+ void* mem;
size_t size, space;
// When mem has been aligned.
- mem = reinterpret_cast<void *>(0x10000);
- size = 1000;
+ mem = reinterpret_cast<void*>(0x10000);
+ size = 1000;
space = 10000;
test_mem_align(mem, size, alignment, space);
// When mem isn't aligned and we have enough space for alignment.
- mem = reinterpret_cast<void *>(0x10010);
- size = 1000;
+ mem = reinterpret_cast<void*>(0x10010);
+ size = 1000;
space = 10000;
test_mem_align(mem, size, alignment, space);
// When mem isn't aligned and we don't have enough memory for alignment
- mem = reinterpret_cast<void *>(0x10010);
- size = 1000;
+ mem = reinterpret_cast<void*>(0x10010);
+ size = 1000;
space = 1001;
test_mem_align(mem, size, alignment, space);
for (size_t i = 0; i < 10000; i++) {
- mem = reinterpret_cast<void *>(random());
- size = random() % 2000;
+ mem = reinterpret_cast<void*>(random());
+ size = random() % 2000;
space = random() % 2000;
test_mem_align(mem, size, alignment, space);
}
@@ -87,12 +89,11 @@ TEST(MKLDNN_UTIL_FUNC, AlignMem) {
#endif
}
-static void VerifyDefMem(const mkldnn::memory &mem) {
- mkldnn::memory::desc desc = mem.get_desc();
- mshadow::default_real_t *data
- = static_cast<mshadow::default_real_t *>(mem.get_data_handle());
- size_t size = desc.get_size() / sizeof(mshadow::default_real_t);
- size_t num_same = 0;
+static void VerifyDefMem(const mkldnn::memory& mem) {
+ mkldnn::memory::desc desc = mem.get_desc();
+ mshadow::default_real_t* data = static_cast<mshadow::default_real_t*>(mem.get_data_handle());
+ size_t size = desc.get_size() / sizeof(mshadow::default_real_t);
+ size_t num_same = 0;
for (int i = 0; i < size; i++)
num_same += data[i] == static_cast<mshadow::default_real_t>(i % 100 - 50);
EXPECT_EQ(num_same, size);
@@ -105,14 +106,14 @@ TEST(MKLDNN_UTIL_FUNC, MemFormat) {
CHECK_EQ(mkldnn_oihw, 5);
}
-static void VerifyMem(const mkldnn::memory &mem) {
+static void VerifyMem(const mkldnn::memory& mem) {
mkldnn::memory::desc desc = mem.get_desc();
mkldnn::memory::dims dims(desc.data.ndims);
for (size_t i = 0; i < dims.size(); i++)
dims[i] = desc.data.dims[i];
mkldnn::memory::desc new_desc{dims,
- static_cast<mkldnn::memory::data_type>(desc.data.data_type),
- static_cast<mkldnn::memory::format_tag>(GetDefaultFormat(desc))};
+ static_cast<mkldnn::memory::data_type>(desc.data.data_type),
+ static_cast<mkldnn::memory::format_tag>(GetDefaultFormat(desc))};
if (desc == new_desc) {
VerifyDefMem(mem);
@@ -121,26 +122,25 @@ static void VerifyMem(const mkldnn::memory &mem) {
mkldnn::memory new_mem(new_desc, CpuEngine::Get()->get_engine());
mkldnn::stream s(CpuEngine::Get()->get_engine());
- mkldnn::reorder(*src_mem, new_mem)
- .execute(s, *src_mem, new_mem);
+ mkldnn::reorder(*src_mem, new_mem).execute(s, *src_mem, new_mem);
VerifyDefMem(new_mem);
}
}
TEST(MKLDNN_NDArray, GetDataReorder) {
- TestArrayShapes tas = GetTestArrayShapes();
- mxnet::ShapeVector shapes = tas.shapes;
+ TestArrayShapes tas = GetTestArrayShapes();
+ mxnet::ShapeVector shapes = tas.shapes;
std::vector<mkldnn::memory::desc> mds = tas.mds;
-
// Reorder from the default to any other layout.
for (auto s : shapes) {
NDArray arr(s, Context());
InitDefaultArray(&arr);
for (auto md : mds) {
if (s.Size() == md.get_size() / sizeof(mshadow::default_real_t)) {
- const mkldnn::memory *mem = arr.GetMKLDNNDataReorder(md);
+ const mkldnn::memory* mem =
+ static_cast<const mkldnn::memory*>(arr.GetMKLDNNDataReorder(&md));
printf("reorder from (");
for (size_t i = 0; i < s.ndim(); i++)
printf("%ld, ", s[i]);
@@ -172,7 +172,8 @@ TEST(MKLDNN_NDArray, GetDataReorder) {
InitMKLDNNArray(&arr, md);
for (auto to_md : mds) {
if (to_md.get_size() / sizeof(mshadow::default_real_t) == s.Size()) {
- const mkldnn::memory *mem = arr.GetMKLDNNDataReorder(to_md);
+ const mkldnn::memory* mem =
+ static_cast<const mkldnn::memory*>(arr.GetMKLDNNDataReorder(&to_md));
printf("reorder from (");
for (size_t i = 0; i < s.ndim(); i++)
printf("%ld, ", s[i]);
@@ -191,13 +192,13 @@ TEST(MKLDNN_NDArray, GetDataReorder) {
}
TEST(MKLDNN_BASE, MKLDNNSum) {
- std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays();
- std::vector<NDArrayAttrs> in_arrs2 = GetTestInputArrays(ArrayTypes::All, true);
- TestArrayShapes tas = GetTestArrayShapes();
+ std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays();
+ std::vector<NDArrayAttrs> in_arrs2 = GetTestInputArrays(ArrayTypes::All, true);
+ TestArrayShapes tas = GetTestArrayShapes();
std::vector<mkldnn::memory::desc> mds = tas.mds;
for (int i = 0; i < in_arrs.size(); i++) {
- auto in_arr = in_arrs[i];
+ auto in_arr = in_arrs[i];
auto in_arr2 = in_arrs2[i];
if (!SupportMKLDNN(in_arr.arr))
continue;
@@ -205,12 +206,12 @@ TEST(MKLDNN_BASE, MKLDNNSum) {
continue;
}
std::vector<NDArrayAttrs> out_arrs = GetTestOutputArrays(in_arr.arr.shape(), mds);
- for (auto &out_arr : out_arrs) {
- auto in_mem1 = in_arr.arr.GetMKLDNNData();
- auto in_mem2 = in_arr2.arr.GetMKLDNNData();
+ for (auto& out_arr : out_arrs) {
+ auto in_mem1 = static_cast<const mkldnn::memory*>(in_arr.arr.GetMKLDNNData());
+ auto in_mem2 = static_cast<const mkldnn::memory*>(in_arr2.arr.GetMKLDNNData());
if (out_arr.arr.IsView())
continue;
- auto out_mem = out_arr.arr.GetMKLDNNData();
+ auto out_mem = static_cast<const mkldnn::memory*>(out_arr.arr.GetMKLDNNData());
PrintVerifyMsg(in_arr, in_arr);
op::MKLDNNSum(*in_mem1, *in_mem2, *out_mem);
MKLDNNStream::Get()->Submit();
@@ -220,20 +221,20 @@ TEST(MKLDNN_BASE, MKLDNNSum) {
// in place
for (int i = 0; i < in_arrs.size(); i++) {
- auto in_arr = in_arrs[i];
+ auto in_arr = in_arrs[i];
auto in_arr2 = in_arrs2[i];
if (!SupportMKLDNN(in_arr.arr))
continue;
if (in_arr.arr.IsMKLDNNData() && in_arr.arr.IsView()) {
continue;
}
- auto input_mem = in_arr.arr.GetMKLDNNData();
- auto input_mem2 = in_arr2.arr.GetMKLDNNData();
+ auto input_mem = static_cast<const mkldnn::memory*>(in_arr.arr.GetMKLDNNData());
+ auto input_mem2 = static_cast<const mkldnn::memory*>(in_arr2.arr.GetMKLDNNData());
NDArrayAttrs orig_arr(in_arr.arr.Copy(in_arr.arr.ctx()), "In Place Copy");
orig_arr.arr.WaitToRead();
PrintVerifyMsg(orig_arr, in_arr);
InitMKLDNNArray(&orig_arr.arr, input_mem->get_desc());
- orig_arr.arr.CopyFrom(*input_mem);
+ orig_arr.arr.CopyFrom(input_mem);
op::MKLDNNSum(*input_mem, *input_mem2, *input_mem);
MKLDNNStream::Get()->Submit();
VerifySumResult({&orig_arr.arr, &in_arr2.arr}, {&in_arr.arr});
@@ -241,15 +242,15 @@ TEST(MKLDNN_BASE, MKLDNNSum) {
}
TEST(MKLDNN_BASE, CreateMKLDNNMem) {
- std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays();
- std::vector<NDArrayAttrs> in_arrs2 = GetTestInputArrays(ArrayTypes::All, true);
- TestArrayShapes tas = GetTestArrayShapes();
+ std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays();
+ std::vector<NDArrayAttrs> in_arrs2 = GetTestInputArrays(ArrayTypes::All, true);
+ TestArrayShapes tas = GetTestArrayShapes();
std::vector<mkldnn::memory::desc> mds = tas.mds;
- MKLDNNStream *stream = MKLDNNStream::Get();
+ MKLDNNStream* stream = MKLDNNStream::Get();
// kWriteTo
for (int i = 0; i < in_arrs.size(); i++) {
- auto in_arr = in_arrs[i];
+ auto in_arr = in_arrs[i];
auto in_arr2 = in_arrs2[i];
if (!SupportMKLDNN(in_arr.arr))
continue;
@@ -257,13 +258,13 @@ TEST(MKLDNN_BASE, CreateMKLDNNMem) {
continue;
}
std::vector<NDArrayAttrs> out_arrs = GetTestOutputArrays(in_arr.arr.shape(), mds);
- for (auto &out_arr : out_arrs) {
- auto in_mem = in_arr.arr.GetMKLDNNData();
- auto in_mem2 = in_arr2.arr.GetMKLDNNData();
+ for (auto& out_arr : out_arrs) {
+ auto in_mem = static_cast<const mkldnn::memory*>(in_arr.arr.GetMKLDNNData());
+ auto in_mem2 = static_cast<const mkldnn::memory*>(in_arr2.arr.GetMKLDNNData());
NDArray orig_output = out_arr.arr.Copy(out_arr.arr.ctx());
orig_output.WaitToRead();
PrintVerifyMsg(in_arr, out_arr);
- auto out_mem = out_arr.arr.GetMKLDNNData();
+ auto out_mem = static_cast<const mkldnn::memory*>(out_arr.arr.GetMKLDNNData());
auto output_mem_t = CreateMKLDNNMem(out_arr.arr, out_mem->get_desc(), kWriteTo);
op::MKLDNNSum(*in_mem, *in_mem2, *output_mem_t.second);
CommitOutput(out_arr.arr, output_mem_t);
@@ -274,22 +275,22 @@ TEST(MKLDNN_BASE, CreateMKLDNNMem) {
// kWriteInPlace
for (int i = 0; i < in_arrs.size(); i++) {
- auto in_arr = in_arrs[i];
+ auto in_arr = in_arrs[i];
auto in_arr2 = in_arrs2[i];
if (!SupportMKLDNN(in_arr.arr))
continue;
if (in_arr.arr.IsMKLDNNData() && in_arr.arr.IsView()) {
continue;
}
- auto input_mem = in_arr.arr.GetMKLDNNData();
- auto input_mem2 = in_arr2.arr.GetMKLDNNData();
+ auto input_mem = static_cast<const mkldnn::memory*>(in_arr.arr.GetMKLDNNData());
+ auto input_mem2 = static_cast<const mkldnn::memory*>(in_arr2.arr.GetMKLDNNData());
NDArrayAttrs orig_arr(in_arr.arr.Copy(in_arr.arr.ctx()), "In Place Copy");
orig_arr.arr.WaitToRead();
PrintVerifyMsg(orig_arr, in_arr);
InitMKLDNNArray(&orig_arr.arr, input_mem->get_desc());
- orig_arr.arr.CopyFrom(*input_mem);
- auto output_mem_t = CreateMKLDNNMem(in_arr.arr,
- input_mem->get_desc(), kWriteInplace, &in_arr.arr);
+ orig_arr.arr.CopyFrom(input_mem);
+ auto output_mem_t =
+ CreateMKLDNNMem(in_arr.arr, input_mem->get_desc(), kWriteInplace, &in_arr.arr);
op::MKLDNNSum(*input_mem, *input_mem2, *output_mem_t.second);
CommitOutput(in_arr.arr, output_mem_t);
stream->Submit();
@@ -298,7 +299,7 @@ TEST(MKLDNN_BASE, CreateMKLDNNMem) {
// kAddTo
for (int i = 0; i < in_arrs.size(); i++) {
- auto in_arr = in_arrs[i];
+ auto in_arr = in_arrs[i];
auto in_arr2 = in_arrs2[i];
if (!SupportMKLDNN(in_arr.arr))
continue;
@@ -306,13 +307,13 @@ TEST(MKLDNN_BASE, CreateMKLDNNMem) {
continue;
}
std::vector<NDArrayAttrs> out_arrs = GetTestOutputArrays(in_arr.arr.shape(), mds);
- for (auto &out_arr : out_arrs) {
- auto in_mem = in_arr.arr.GetMKLDNNData();
- auto in_mem2 = in_arr2.arr.GetMKLDNNData();
+ for (auto& out_arr : out_arrs) {
+ auto in_mem = static_cast<const mkldnn::memory*>(in_arr.arr.GetMKLDNNData());
+ auto in_mem2 = static_cast<const mkldnn::memory*>(in_arr2.arr.GetMKLDNNData());
NDArray orig_output = out_arr.arr.Copy(out_arr.arr.ctx());
orig_output.WaitToRead();
PrintVerifyMsg(in_arr, out_arr);
- auto out_mem = out_arr.arr.GetMKLDNNData();
+ auto out_mem = static_cast<const mkldnn::memory*>(out_arr.arr.GetMKLDNNData());
auto output_mem_t = CreateMKLDNNMem(out_arr.arr, out_mem->get_desc(), kAddTo);
op::MKLDNNSum(*in_mem, *in_mem2, *output_mem_t.second);
CommitOutput(out_arr.arr, output_mem_t);
@@ -324,20 +325,20 @@ TEST(MKLDNN_BASE, CreateMKLDNNMem) {
// kNullOp
for (int i = 0; i < in_arrs.size(); i++) {
- auto in_arr = in_arrs[i];
+ auto in_arr = in_arrs[i];
auto in_arr2 = in_arrs2[i];
if (!SupportMKLDNN(in_arr.arr))
continue;
if (in_arr.arr.IsMKLDNNData() && in_arr.arr.IsView()) {
continue;
}
- auto input_mem = in_arr.arr.GetMKLDNNData();
- auto input_mem2 = in_arr2.arr.GetMKLDNNData();
+ auto input_mem = static_cast<const mkldnn::memory*>(in_arr.arr.GetMKLDNNData());
+ auto input_mem2 = static_cast<const mkldnn::memory*>(in_arr2.arr.GetMKLDNNData());
NDArrayAttrs orig_arr(in_arr.arr.Copy(in_arr.arr.ctx()), "In Place Copy");
orig_arr.arr.WaitToRead();
PrintVerifyMsg(orig_arr, in_arr);
InitMKLDNNArray(&orig_arr.arr, input_mem->get_desc());
- orig_arr.arr.CopyFrom(*input_mem);
+ orig_arr.arr.CopyFrom(input_mem);
auto output_mem_t = CreateMKLDNNMem(in_arr.arr, input_mem->get_desc(), kNullOp);
op::MKLDNNSum(*input_mem, *input_mem2, *output_mem_t.second);
CommitOutput(in_arr.arr, output_mem_t);
@@ -355,10 +356,10 @@ TEST(MKLDNN_NDArray, GetTestInputArraysConcat) {
for (size_t i = 0; i < dim + 1; ++i)
scale_vector[i] = 1;
scale_vector[dim] = num_inputs;
- std::vector<NDArrayAttrs> expanded_arrs = GetTestInputArrays(
- ArrayTypes::All, false, scale_vector);
+ std::vector<NDArrayAttrs> expanded_arrs =
+ GetTestInputArrays(ArrayTypes::All, false, scale_vector);
int i = 0;
- for (auto &arr : in_arrs) {
+ for (auto& arr : in_arrs) {
if (dim >= arr.arr.shape().ndim())
continue;
auto ex_arr = expanded_arrs[i];
@@ -372,22 +373,22 @@ TEST(MKLDNN_NDArray, GetTestInputArraysConcat) {
}
TEST(MKLDNN_NDArray, GetTestOutputArraysConcat) {
- auto shapes_pds = GetTestArrayShapes();
- std::vector<mxnet::TShape> shapes = shapes_pds.shapes;
+ auto shapes_pds = GetTestArrayShapes();
+ std::vector<mxnet::TShape> shapes = shapes_pds.shapes;
std::vector<mkldnn::memory::desc> mds = shapes_pds.mds;
- for (auto &shape : shapes) {
+ for (auto& shape : shapes) {
for (int dim = 0; dim < 5; dim++) {
for (int num_inputs = 2; num_inputs < 5; num_inputs++) {
if (shape.ndim() <= dim)
continue;
- std::cout << "Extending " << shape << " dim " <<
- dim << " and " << num_inputs << "num_inputs\n";
+ std::cout << "Extending " << shape << " dim " << dim << " and " << num_inputs
+ << "num_inputs\n";
std::vector<float> scale_vector(shape.ndim());
for (int i = 0; i < shape.ndim(); i++)
scale_vector[i] = 1;
scale_vector[dim] = num_inputs;
- auto output_arrs = GetTestOutputArrays(shape, mds, scale_vector);
- for (auto &out_arr : output_arrs) {
+ auto output_arrs = GetTestOutputArrays(shape, mds, scale_vector);
+ for (auto& out_arr : output_arrs) {
auto out_shape = out_arr.arr.shape();
EXPECT_EQ(shape.Size() * num_inputs, out_arr.arr.shape().Size());
EXPECT_EQ(shape[dim] * num_inputs, out_arr.arr.shape()[dim]);
@@ -398,19 +399,19 @@ TEST(MKLDNN_NDArray, GetTestOutputArraysConcat) {
}
TEST(MKLDNN_NDArray, CopyFrom) {
- TestArrayShapes tas = GetTestArrayShapes();
+ TestArrayShapes tas = GetTestArrayShapes();
std::vector<mkldnn::memory::desc> mds = tas.mds;
std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays();
- for (auto &in_arr : in_arrs) {
+ for (auto& in_arr : in_arrs) {
if (in_arr.arr.IsMKLDNNData() && in_arr.arr.IsView())
continue;
std::vector<NDArrayAttrs> out_arrs = GetTestOutputArrays(in_arr.arr.shape(), mds);
- for (auto &out_arr : out_arrs) {
- const mkldnn::memory *mem = in_arr.arr.GetMKLDNNData();
- out_arr.arr.CopyFrom(*mem);
+ for (auto& out_arr : out_arrs) {
+ const mkldnn::memory* mem = static_cast<const mkldnn::memory*>(in_arr.arr.GetMKLDNNData());
+ out_arr.arr.CopyFrom(mem);
MKLDNNStream::Get()->Submit();
- std::vector<NDArray *> inputs(1);
+ std::vector<NDArray*> inputs(1);
inputs[0] = &in_arr.arr;
VerifyCopyResult(inputs, {&out_arr.arr});
}