You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this copy.
Posted to commits@mxnet.apache.org by ta...@apache.org on 2019/01/21 09:40:40 UTC
[incubator-mxnet] branch v1.4.x updated: api change (#13905)
This is an automated email from the ASF dual-hosted git repository.
taolv pushed a commit to branch v1.4.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.4.x by this push:
new 45a1554 api change (#13905)
45a1554 is described below
commit 45a1554242e0cf3e9a3683d04ad99bc653d2256d
Author: Zhennan Qin <zh...@intel.com>
AuthorDate: Mon Jan 21 17:40:12 2019 +0800
api change (#13905)
---
include/mxnet/ndarray.h | 10 +++++---
src/ndarray/ndarray.cc | 39 ++++++++++++++++++-----------
src/operator/subgraph/mkldnn/mkldnn_conv.cc | 12 ++++++---
3 files changed, 40 insertions(+), 21 deletions(-)
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index 4ba13ca..5de42e1 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -694,9 +694,13 @@ class NDArray {
/*
* Create NDArray from mkldnn memory.
* mkldnn_mem The mkldnn memory to be managed.
- * static_data If true, mkldnn memory won't be freed on destruction.
*/
- explicit NDArray(const mkldnn::memory *mkldnn_mem, bool static_data = true);
+ explicit NDArray(const std::shared_ptr<mkldnn::memory> &mkldnn_mem);
+ /*
+ * Create NDArray from mkldnn memory descriptor.
+ * mem_pd The mkldnn memory descriptor to be created.
+ */
+ explicit NDArray(mkldnn::memory::primitive_desc mem_pd);
/*
* Test if the data is stored in one of special MKLDNN format.
*/
@@ -776,7 +780,7 @@ class NDArray {
/*!
* \ Fix mkldnn memory descriptor mismatch from NDArray.
*/
- void UpdateMKLDNNMemDesc();
+ void UpdateMKLDNNMemDesc(mkldnn::memory::format format);
#endif
/*!
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 081d4e7..a1c3497 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -168,16 +168,28 @@ nnvm::Symbol NDArray::get_autograd_symbol() const {
#if MXNET_USE_MKLDNN == 1
-NDArray::NDArray(const mkldnn::memory *mkldnn_mem, bool static_data)
+NDArray::NDArray(mkldnn::memory::primitive_desc mem_pd)
: storage_type_(kDefaultStorage), entry_({nullptr, 0, 0}) {
- auto mem_pd = mkldnn_mem->get_primitive_desc();
auto mem_desc = mem_pd.desc();
shape_ = TShape(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims);
dtype_ = get_mxnet_type(mem_desc.data.data_type);
- auto data = TBlob(mkldnn_mem->get_data_handle(), shape_, cpu::kDevMask, dtype_);
- ptr_ = std::make_shared<Chunk>(data, 0);
+ ptr_ = std::make_shared<Chunk>(shape_, Context::CPU(), true, dtype_);
+ ptr_->CheckAndAlloc(mem_pd.get_size());
ptr_->mkl_mem_ = std::make_shared<MKLDNNMemory>(mem_pd, ptr_->shandle.dptr);
- ptr_->static_data = static_data;
+}
+
+NDArray::NDArray(const std::shared_ptr<mkldnn::memory> &mkldnn_mem)
+ : storage_type_(kDefaultStorage), entry_({nullptr, 0, 0}) {
+ auto mem_pd = mkldnn_mem->get_primitive_desc();
+ auto mem_desc = mem_pd.desc();
+ shape_ = TShape(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims);
+ dtype_ = get_mxnet_type(mem_desc.data.data_type);
+ ptr_ = std::make_shared<Chunk>(shape_, Context::CPU(), true, dtype_);
+ ptr_->shandle.dptr = mkldnn_mem->get_data_handle();
+ ptr_->shandle.size = mem_pd.get_size();
+ ptr_->delay_alloc = false;
+ ptr_->mkl_mem_ = std::make_shared<MKLDNNMemory>(mkldnn_mem);
+ ptr_->static_data = true;
}
NDArray NDArray::MKLDNNDataReshape(const TShape &shape) const {
@@ -717,19 +729,16 @@ mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc &
return ptr_->mkl_mem_->GetRaw();
}
-void NDArray::UpdateMKLDNNMemDesc() {
+void NDArray::UpdateMKLDNNMemDesc(mkldnn::memory::format format) {
const mkldnn::memory *mem = GetMKLDNNData();
auto mem_desc = mem->get_primitive_desc().desc();
auto this_dtype = get_mkldnn_type(dtype());
- if (this_dtype != mem_desc.data.data_type) {
- mkldnn::memory::desc data_md(
- mkldnn::memory::dims(mem_desc.data.dims,
- mem_desc.data.dims + mem_desc.data.ndims),
- this_dtype, static_cast<mkldnn::memory::format>(mem_desc.data.format));
- mkldnn::memory::primitive_desc pd(data_md, CpuEngine::Get()->get_engine());
- ptr_->mkl_mem_.reset(new MKLDNNMemory(pd, ptr_->shandle.dptr));
- MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
- }
+ mkldnn::memory::desc data_md(
+ mkldnn::memory::dims(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims),
+ this_dtype, format);
+ mkldnn::memory::primitive_desc pd(data_md, CpuEngine::Get()->get_engine());
+ ptr_->mkl_mem_.reset(new MKLDNNMemory(pd, ptr_->shandle.dptr));
+ MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
}
#endif
diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc
index dfa98d1..8f81380 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc
@@ -261,8 +261,12 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
}
if (!inplace_) {
auto in_mkl_mem = inputs[in_sum].GetMKLDNNData();
- const_cast<NDArray &>(outputs[kOut]).CopyFrom(*in_mkl_mem);
- output = NDArray(outputs[kOut].GetMKLDNNData());
+ auto out_mkl_mem = outputs[kOut].GetMKLDNNData();
+ mkldnn_mem_ptr tmp_mem(
+ new mkldnn::memory(in_mkl_mem->get_primitive_desc(), out_mkl_mem->get_data_handle()));
+ MKLDNNStream::Get()->RegisterMem(tmp_mem);
+ mxnet::MKLDNNCopy(*in_mkl_mem, tmp_mem.get());
+ output = NDArray(tmp_mem);
}
}
@@ -388,7 +392,9 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
if (mkldnn_param.with_sum) {
auto out = const_cast<NDArray &>(outputs[kOut]);
- out.UpdateMKLDNNMemDesc();
+ auto format = static_cast<mkldnn::memory::format>(
+ fwd_->fwd_pd.dst_primitive_desc().desc().data.format);
+ out.UpdateMKLDNNMemDesc(format);
}
}