You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@mxnet.apache.org by an...@apache.org on 2018/06/13 21:32:39 UTC

[incubator-mxnet] 09/12: handle NDArray slice properly for mkldnn layout

This is an automated email from the ASF dual-hosted git repository.

anirudh2290 pushed a commit to branch v1.2.0
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 22966a0e6c9759cb83b723bbd05eb01946a00ce7
Author: Ashok Emani <as...@intel.com>
AuthorDate: Thu Apr 19 14:07:46 2018 -0700

    handle NDArray slice properly for mkldnn layout
---
 src/ndarray/ndarray.cc | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 4b45969..0175c5c 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -542,11 +542,18 @@ NDArray NDArray::Reorder2Default() const {
   if (format == ptr_->mkl_mem_->GetFormat())
     return *this;
 
-  NDArray ret(shape(), ctx(), false, dtype());
+  // create new ndarray from  mkldnn layout
+  mkldnn::memory::desc from_desc = ptr_->mkl_mem_->GetPrimitiveDesc().desc();
+  TShape tshape(from_desc.data.ndims);
+  for (int i = 0; i < from_desc.data.ndims; i++) tshape[i] = from_desc.data.dims[i];
+  NDArray ret(tshape, ctx(), false, dtype());
   mkldnn::memory::primitive_desc def_pd = ptr_->mkl_mem_->GetPrimitiveDesc(format);
   CHECK(ret.ptr_->shandle.size >= def_pd.get_size());
   mkldnn::memory def_mem(def_pd, ret.ptr_->shandle.dptr);
   ptr_->mkl_mem_->ReorderTo(&def_mem);
+  // reshape as needed
+  ret.shape_ = shape_;
+  ret.byte_offset_ = byte_offset_;
   return ret;
 }
 

-- 
To stop receiving notification emails like this one, please contact
anirudh2290@apache.org.