You are viewing a plain-text version of this message; the canonical version is available in the mailing-list archive (the original hyperlink was lost in this text copy).
Posted to commits@mxnet.apache.org by ta...@apache.org on 2019/01/21 09:40:40 UTC

[incubator-mxnet] branch v1.4.x updated: api change (#13905)

This is an automated email from the ASF dual-hosted git repository.

taolv pushed a commit to branch v1.4.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.4.x by this push:
     new 45a1554  api change (#13905)
45a1554 is described below

commit 45a1554242e0cf3e9a3683d04ad99bc653d2256d
Author: Zhennan Qin <zh...@intel.com>
AuthorDate: Mon Jan 21 17:40:12 2019 +0800

    api change (#13905)
---
 include/mxnet/ndarray.h                     | 10 +++++---
 src/ndarray/ndarray.cc                      | 39 ++++++++++++++++++-----------
 src/operator/subgraph/mkldnn/mkldnn_conv.cc | 12 ++++++---
 3 files changed, 40 insertions(+), 21 deletions(-)

diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index 4ba13ca..5de42e1 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -694,9 +694,13 @@ class NDArray {
   /*
    * Create NDArray from mkldnn memory.
    * mkldnn_mem The mkldnn memory to be managed.
-   * static_data If true, mkldnn memory won't be freed on destruction.
    */
-  explicit NDArray(const mkldnn::memory *mkldnn_mem, bool static_data = true);
+  explicit NDArray(const std::shared_ptr<mkldnn::memory> &mkldnn_mem);
+  /*
+   * Create NDArray from mkldnn memory descriptor.
+   * mem_pd The mkldnn memory descriptor to be created.
+   */
+  explicit NDArray(mkldnn::memory::primitive_desc mem_pd);
   /*
    * Test if the data is stored in one of special MKLDNN format.
    */
@@ -776,7 +780,7 @@ class NDArray {
    /*!
    * \ Fix mkldnn memory descriptor mismatch from NDArray.
    */
-  void UpdateMKLDNNMemDesc();
+  void UpdateMKLDNNMemDesc(mkldnn::memory::format format);
 #endif
 
   /*!
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 081d4e7..a1c3497 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -168,16 +168,28 @@ nnvm::Symbol NDArray::get_autograd_symbol() const {
 
 #if MXNET_USE_MKLDNN == 1
 
-NDArray::NDArray(const mkldnn::memory *mkldnn_mem, bool static_data)
+NDArray::NDArray(mkldnn::memory::primitive_desc mem_pd)
     : storage_type_(kDefaultStorage), entry_({nullptr, 0, 0}) {
-  auto mem_pd = mkldnn_mem->get_primitive_desc();
   auto mem_desc = mem_pd.desc();
   shape_ = TShape(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims);
   dtype_ = get_mxnet_type(mem_desc.data.data_type);
-  auto data = TBlob(mkldnn_mem->get_data_handle(), shape_, cpu::kDevMask, dtype_);
-  ptr_ = std::make_shared<Chunk>(data, 0);
+  ptr_ = std::make_shared<Chunk>(shape_, Context::CPU(), true, dtype_);
+  ptr_->CheckAndAlloc(mem_pd.get_size());
   ptr_->mkl_mem_ = std::make_shared<MKLDNNMemory>(mem_pd, ptr_->shandle.dptr);
-  ptr_->static_data = static_data;
+}
+
+NDArray::NDArray(const std::shared_ptr<mkldnn::memory> &mkldnn_mem)
+    : storage_type_(kDefaultStorage), entry_({nullptr, 0, 0}) {
+  auto mem_pd = mkldnn_mem->get_primitive_desc();
+  auto mem_desc = mem_pd.desc();
+  shape_ = TShape(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims);
+  dtype_ = get_mxnet_type(mem_desc.data.data_type);
+  ptr_ = std::make_shared<Chunk>(shape_, Context::CPU(), true, dtype_);
+  ptr_->shandle.dptr = mkldnn_mem->get_data_handle();
+  ptr_->shandle.size = mem_pd.get_size();
+  ptr_->delay_alloc = false;
+  ptr_->mkl_mem_ = std::make_shared<MKLDNNMemory>(mkldnn_mem);
+  ptr_->static_data = true;
 }
 
 NDArray NDArray::MKLDNNDataReshape(const TShape &shape) const {
@@ -717,19 +729,16 @@ mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc &
   return ptr_->mkl_mem_->GetRaw();
 }
 
-void NDArray::UpdateMKLDNNMemDesc() {
+void NDArray::UpdateMKLDNNMemDesc(mkldnn::memory::format format) {
   const mkldnn::memory *mem = GetMKLDNNData();
   auto mem_desc = mem->get_primitive_desc().desc();
   auto this_dtype = get_mkldnn_type(dtype());
-  if (this_dtype != mem_desc.data.data_type) {
-    mkldnn::memory::desc data_md(
-        mkldnn::memory::dims(mem_desc.data.dims,
-                             mem_desc.data.dims + mem_desc.data.ndims),
-        this_dtype, static_cast<mkldnn::memory::format>(mem_desc.data.format));
-    mkldnn::memory::primitive_desc pd(data_md, CpuEngine::Get()->get_engine());
-    ptr_->mkl_mem_.reset(new MKLDNNMemory(pd, ptr_->shandle.dptr));
-    MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
-  }
+  mkldnn::memory::desc data_md(
+      mkldnn::memory::dims(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims),
+      this_dtype, format);
+  mkldnn::memory::primitive_desc pd(data_md, CpuEngine::Get()->get_engine());
+  ptr_->mkl_mem_.reset(new MKLDNNMemory(pd, ptr_->shandle.dptr));
+  MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
 }
 #endif
 
diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc
index dfa98d1..8f81380 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc
@@ -261,8 +261,12 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
     }
     if (!inplace_) {
       auto in_mkl_mem = inputs[in_sum].GetMKLDNNData();
-      const_cast<NDArray &>(outputs[kOut]).CopyFrom(*in_mkl_mem);
-      output = NDArray(outputs[kOut].GetMKLDNNData());
+      auto out_mkl_mem = outputs[kOut].GetMKLDNNData();
+      mkldnn_mem_ptr tmp_mem(
+          new mkldnn::memory(in_mkl_mem->get_primitive_desc(), out_mkl_mem->get_data_handle()));
+      MKLDNNStream::Get()->RegisterMem(tmp_mem);
+      mxnet::MKLDNNCopy(*in_mkl_mem, tmp_mem.get());
+      output = NDArray(tmp_mem);
     }
   }
 
@@ -388,7 +392,9 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
 
   if (mkldnn_param.with_sum) {
     auto out = const_cast<NDArray &>(outputs[kOut]);
-    out.UpdateMKLDNNMemDesc();
+    auto format = static_cast<mkldnn::memory::format>(
+        fwd_->fwd_pd.dst_primitive_desc().desc().data.format);
+    out.UpdateMKLDNNMemDesc(format);
   }
 }