You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/15 21:39:19 UTC

[GitHub] azai91 closed pull request #11291: [MXNET-550] Create MKLDNNCopy helper for copying mkldnn memory

azai91 closed pull request #11291: [MXNET-550] Create MKLDNNCopy helper for copying mkldnn memory
URL: https://github.com/apache/incubator-mxnet/pull/11291
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 94d3d90413a..6592f600e48 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -638,7 +638,6 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) {
 
   CHECK(mem.get_primitive_desc().get_size() == shape().Size() * GetTypeSize(dtype_))
       << "The size of NDArray doesn't match the requested MKLDNN memory desc";
-  MKLDNNStream *stream = MKLDNNStream::Get();
   // If this array uses MKLDNN layout, we have to make sure it's not a view.
   // Otherwise, we'll have to change the layout inside the array.
 
@@ -646,74 +645,7 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) {
     ptr_->Reorder2Default();
 
   const mkldnn::memory *this_mem = GetMKLDNNData();
-  mkldnn::memory::primitive_desc from_pd = mem.get_primitive_desc();
-  mkldnn::memory::desc from_desc = from_pd.desc();
-  mkldnn::memory::primitive_desc this_pd = this_mem->get_primitive_desc();
-  mkldnn::memory::desc this_desc = this_pd.desc();
-  mkldnn_memory_format_t from_def_format = GetDefaultFormat(from_desc);
-  mkldnn_memory_format_t this_def_format = GetDefaultFormat(this_desc);
-  if (IsView()) {
-    // Sliced array must use the default layout.
-    CHECK_EQ(GetDefaultFormat(this_desc), this_desc.data.format);
-  }
-  // It's possible that the memory and the NDArray don't have the same shape.
-  if (!same_shape(this_desc, from_desc)
-      // If the source memory uses the default layout, we can reshape directly.
-      && from_def_format == from_desc.data.format) {
-    // In this case, we can simply create a new MKLDNN memory for the required
-    // shape.
-    mkldnn::memory::dims dims(this_desc.data.dims,
-                              this_desc.data.dims + this_desc.data.ndims);
-    auto this_dtype = static_cast<mkldnn::memory::data_type>(this_desc.data.data_type);
-    auto this_format = static_cast<mkldnn::memory::format>(GetDefaultFormat(this_desc));
-    mkldnn::memory::desc data_md(dims, this_dtype, this_format);
-    mkldnn::memory::primitive_desc pd(data_md, from_pd.get_engine());
-    mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle()));
-    stream->RegisterMem(tmp_mem);
-    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem));
-  } else if (!same_shape(this_desc, from_desc)) {
-    // In this case, the source memory stores data in a customized layout. We
-    // need to reorganize the data in memory before we can reshape.
-    mkldnn::memory::primitive_desc def_pd = GetPrimitiveDesc(from_pd, from_def_format);
-    mkldnn::memory *def_mem = TmpMemMgr::Get()->Alloc(def_pd);
-    stream->RegisterPrim(mkldnn::reorder(mem, *def_mem));
-    // Now we can reshape it
-    mkldnn::memory::dims dims(this_desc.data.dims,
-                              this_desc.data.dims + this_desc.data.ndims);
-    auto this_dtype = static_cast<mkldnn::memory::data_type>(this_desc.data.data_type);
-    auto this_format = static_cast<mkldnn::memory::format>(GetDefaultFormat(this_desc));
-    mkldnn::memory::desc data_md(dims, this_dtype, this_format);
-    mkldnn::memory::primitive_desc pd(data_md, from_pd.get_engine());
-    mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, def_mem->get_data_handle()));
-    stream->RegisterMem(tmp_mem);
-    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem));
-  } else if (from_pd == this_pd) {
-    // If the layout is the same, we can just copy data.
-    stream->RegisterPrim(mkldnn::reorder(mem, *this_mem));
-  } else {
-    // If both are not using the default layouts. There isn't much we can do,
-    // other than reorder data layout directly.
-    if (this_def_format != this_desc.data.format
-        && from_def_format != from_desc.data.format) {
-      stream->RegisterPrim(mkldnn::reorder(mem, *this_mem));
-    } else if (this_def_format == this_desc.data.format) {
-      // If the dest mem uses the default memory layout, we can simply use
-      // the default format of the source memory to improve perf of reorder.
-      mkldnn::memory::primitive_desc pd = GetPrimitiveDesc(from_pd,
-                                                           from_def_format);
-      mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, this_mem->get_data_handle()));
-      stream->RegisterMem(tmp_mem);
-      stream->RegisterPrim(mkldnn::reorder(mem, *tmp_mem));
-    } else {
-      // If the src mem uses the default memory layout, we can use
-      // the default format of the source memory to improve perf.
-      mkldnn::memory::primitive_desc pd = GetPrimitiveDesc(this_pd,
-                                                           this_def_format);
-      mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle()));
-      stream->RegisterMem(tmp_mem);
-      stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem));
-    }
-  }
+  MKLDNNCopy(mem, this_mem);
 }
 
 mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc &desc) {
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index bd2faf5775a..e6fe5e4525e 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -318,6 +318,7 @@ enum OutDataOp {
 };
 
 typedef std::pair<OutDataOp, mkldnn::memory *> mkldnn_output_t;
+void MKLDNNCopy(const mkldnn::memory &mem, const mkldnn::memory* this_mem);
 
 /*
  * These two functions try to create MKLDNN memory in an NDArray based on `req'.
diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc
index 1bd1581dbc2..22b97409376 100644
--- a/src/operator/nn/mkldnn/mkldnn_base.cc
+++ b/src/operator/nn/mkldnn/mkldnn_base.cc
@@ -77,6 +77,75 @@ mkldnn::memory *TmpMemMgr::Alloc(const mkldnn::memory::primitive_desc &pd) {
   }
 }
 
+void MKLDNNCopy(const mkldnn::memory &mem, const mkldnn::memory* this_mem) {
+  MKLDNNStream *stream = MKLDNNStream::Get();
+
+  mkldnn::memory::primitive_desc from_pd = mem.get_primitive_desc();
+  mkldnn::memory::desc from_desc = from_pd.desc();
+  mkldnn::memory::primitive_desc this_pd = this_mem->get_primitive_desc();
+  mkldnn::memory::desc this_desc = this_pd.desc();
+  mkldnn_memory_format_t from_def_format = GetDefaultFormat(from_desc);
+  mkldnn_memory_format_t this_def_format = GetDefaultFormat(this_desc);
+  // It's possible that the source and destination memory don't have the same shape.
+  if (!same_shape(this_desc, from_desc)
+      // If the source memory uses the default layout, we can reshape directly.
+      && from_def_format == from_desc.data.format) {
+    // In this case, we can simply create a new MKLDNN memory for the required
+    // shape.
+    mkldnn::memory::dims dims(this_desc.data.dims,
+                              this_desc.data.dims + this_desc.data.ndims);
+    auto this_dtype = static_cast<mkldnn::memory::data_type>(this_desc.data.data_type);
+    auto this_format = static_cast<mkldnn::memory::format>(GetDefaultFormat(this_desc));
+    mkldnn::memory::desc data_md(dims, this_dtype, this_format);
+    mkldnn::memory::primitive_desc pd(data_md, from_pd.get_engine());
+    mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle()));
+    stream->RegisterMem(tmp_mem);
+    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem));
+  } else if (!same_shape(this_desc, from_desc)) {
+    // In this case, the source memory stores data in a customized layout. We
+    // need to reorganize the data in memory before we can reshape.
+    mkldnn::memory::primitive_desc def_pd = GetPrimitiveDesc(from_pd, from_def_format);
+    mkldnn::memory *def_mem = TmpMemMgr::Get()->Alloc(def_pd);
+    stream->RegisterPrim(mkldnn::reorder(mem, *def_mem));
+    // Now we can reshape it
+    mkldnn::memory::dims dims(this_desc.data.dims,
+                              this_desc.data.dims + this_desc.data.ndims);
+    auto this_dtype = static_cast<mkldnn::memory::data_type>(this_desc.data.data_type);
+    auto this_format = static_cast<mkldnn::memory::format>(GetDefaultFormat(this_desc));
+    mkldnn::memory::desc data_md(dims, this_dtype, this_format);
+    mkldnn::memory::primitive_desc pd(data_md, from_pd.get_engine());
+    mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, def_mem->get_data_handle()));
+    stream->RegisterMem(tmp_mem);
+    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem));
+  } else if (from_pd == this_pd) {
+    // If the layout is the same, we can just copy data.
+    stream->RegisterPrim(mkldnn::reorder(mem, *this_mem));
+  } else {
+    // If neither side uses its default layout, there isn't much we can do
+    // other than reorder the data layout directly.
+    if (this_def_format != this_desc.data.format
+        && from_def_format != from_desc.data.format) {
+      stream->RegisterPrim(mkldnn::reorder(mem, *this_mem));
+    } else if (this_def_format == this_desc.data.format) {
+      // If the dest mem uses the default memory layout, we can simply use
+      // the default format of the source memory to improve perf of reorder.
+      mkldnn::memory::primitive_desc pd = GetPrimitiveDesc(from_pd,
+                                                           from_def_format);
+      mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, this_mem->get_data_handle()));
+      stream->RegisterMem(tmp_mem);
+      stream->RegisterPrim(mkldnn::reorder(mem, *tmp_mem));
+    } else {
+      // Otherwise the src mem uses the default memory layout, so we can view
+      // it with the default format of the dest memory to improve perf.
+      mkldnn::memory::primitive_desc pd = GetPrimitiveDesc(this_pd,
+                                                           this_def_format);
+      mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle()));
+      stream->RegisterMem(tmp_mem);
+      stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem));
+    }
+  }
+}
+
 mkldnn_output_t CreateMKLDNNMem(const NDArray &arr,
                                 const mkldnn::memory::primitive_desc &desc,
                                 OpReqType req) {


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services