You are viewing a plain-text version of this content. The canonical link to it was a hyperlink ("here") that is not preserved in this plain-text copy.
Posted to commits@mxnet.apache.org by pa...@apache.org on 2019/10/15 11:10:33 UTC

[incubator-mxnet] branch mkldnn-v1.0 updated: [mkldnn-v1.0] Add MKL-DNN slice (#16484)

This is an automated email from the ASF dual-hosted git repository.

patriczhao pushed a commit to branch mkldnn-v1.0
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/mkldnn-v1.0 by this push:
     new fbb43a9  [mkldnn-v1.0] Add MKL-DNN slice (#16484)
fbb43a9 is described below

commit fbb43a9d8d51d15f542c268d6888231a88c929d1
Author: Tao Lv <ta...@intel.com>
AuthorDate: Tue Oct 15 19:09:25 2019 +0800

    [mkldnn-v1.0] Add MKL-DNN slice (#16484)
    
    * change slice to mkldnn v1.0
    
    * fix lint
---
 src/operator/nn/mkldnn/mkldnn_slice-inl.h |  6 +++---
 src/operator/nn/mkldnn/mkldnn_slice.cc    | 31 +++++++++++++++++--------------
 2 files changed, 20 insertions(+), 17 deletions(-)

diff --git a/src/operator/nn/mkldnn/mkldnn_slice-inl.h b/src/operator/nn/mkldnn/mkldnn_slice-inl.h
index f41db01..2233466 100644
--- a/src/operator/nn/mkldnn/mkldnn_slice-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_slice-inl.h
@@ -26,7 +26,7 @@
 #ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_SLICE_INL_H_
 #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_SLICE_INL_H_
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 
 #include <dmlc/logging.h>
 #include <dmlc/parameter.h>
@@ -45,7 +45,7 @@ class MKLDNNSliceFwd {
                  const NDArray &in,
                  const NDArray &out);
   void SetNewMem(const mkldnn::memory &input, const mkldnn::memory &output);
-  const mkldnn::reorder &GetPd() const;
+  void Register();
 
  private:
   std::shared_ptr<mkldnn::memory> data_;
@@ -62,5 +62,5 @@ void MKLDNNSlice(const SliceParam &param, const OpContext& ctx,
 
 }  // namespace op
 }  // namespace mxnet
-#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_USE_MKLDNN == 100
 #endif  // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_SLICE_INL_H_
diff --git a/src/operator/nn/mkldnn/mkldnn_slice.cc b/src/operator/nn/mkldnn/mkldnn_slice.cc
index 2a817a2..d9d98b3 100644
--- a/src/operator/nn/mkldnn/mkldnn_slice.cc
+++ b/src/operator/nn/mkldnn/mkldnn_slice.cc
@@ -23,7 +23,7 @@
  * \author Zhiyuan Huang
 */
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 
 #include "./mkldnn_ops-inl.h"
 #include "./mkldnn_base-inl.h"
@@ -49,13 +49,15 @@ MKLDNNSliceFwd::MKLDNNSliceFwd(const SliceParam &param,
     dims[i] = oshape[i];
     offsets[i] = s;
   }
-  auto in_mem_pd = in.GetMKLDNNData()->get_primitive_desc();
-  auto out_mem_pd = out.GetMKLDNNData()->get_primitive_desc();
-  auto view_pd = mkldnn::view::primitive_desc(in_mem_pd, dims, offsets);
-  auto reorder_pd = reorder::primitive_desc(view_pd.dst_primitive_desc(), out_mem_pd);
-  this->data_ = std::make_shared<mkldnn::memory>(view_pd.dst_primitive_desc(), nullptr);
-  this->out_ = std::make_shared<mkldnn::memory>(view_pd.dst_primitive_desc(), nullptr);
-  this->fwd_ = std::make_shared<mkldnn::reorder>(reorder_pd, *this->data_, *this->out_);
+
+  auto in_md = in.GetMKLDNNData()->get_desc();
+  auto out_md = out.GetMKLDNNData()->get_desc();
+  auto sub_md = in_md.submemory_desc(dims, offsets);
+
+  auto engine = CpuEngine::Get()->get_engine();
+  this->data_ = std::make_shared<mkldnn::memory>(sub_md, engine, nullptr);
+  this->out_ = std::make_shared<mkldnn::memory>(out_md, engine, nullptr);
+  this->fwd_ = std::make_shared<mkldnn::reorder>(*this->data_, *this->out_);
 }
 
 void MKLDNNSliceFwd::SetNewMem(const mkldnn::memory &input, const mkldnn::memory &output) {
@@ -63,8 +65,9 @@ void MKLDNNSliceFwd::SetNewMem(const mkldnn::memory &input, const mkldnn::memory
   this->out_->set_data_handle(output.get_data_handle());
 }
 
-const mkldnn::reorder &MKLDNNSliceFwd::GetPd() const {
-  return *fwd_;
+void MKLDNNSliceFwd::Register() {
+  MKLDNNStream::Get()->RegisterPrimArgs(*fwd_,
+      {{MKLDNN_ARG_FROM, *(this->data_)}, {MKLDNN_ARG_TO, *(this->out_)}});
 }
 
 MKLDNNSliceFwd &GetSliceForward(const SliceParam &param, const bool is_train,
@@ -91,14 +94,14 @@ void MKLDNNSlice(const SliceParam &param, const OpContext& ctx,
                  const NDArray &in, OpReqType req, const NDArray &out) {
   MKLDNNSliceFwd &fwd = GetSliceForward(param, ctx.is_train, in, out);
   auto in_mem = in.GetMKLDNNData();
-  auto out_mem_pd = out.GetMKLDNNData()->get_primitive_desc();
-  auto out_mem = CreateMKLDNNMem(out, out_mem_pd, req);
+  auto out_md = out.GetMKLDNNData()->get_desc();
+  auto out_mem = CreateMKLDNNMem(out, out_md, req);
   fwd.SetNewMem(*in_mem, *out_mem.second);
-  MKLDNNStream::Get()->RegisterPrim(fwd.GetPd());
+  fwd.Register();
   CommitOutput(out, out_mem);
   MKLDNNStream::Get()->Submit();
 }
 
 }  // namespace op
 }  // namespace mxnet
-#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_USE_MKLDNN == 100