You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by pa...@apache.org on 2019/10/15 11:10:33 UTC
[incubator-mxnet] branch mkldnn-v1.0 updated: [mkldnn-v1.0] Add
MKL-DNN slice (#16484)
This is an automated email from the ASF dual-hosted git repository.
patriczhao pushed a commit to branch mkldnn-v1.0
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/mkldnn-v1.0 by this push:
new fbb43a9 [mkldnn-v1.0] Add MKL-DNN slice (#16484)
fbb43a9 is described below
commit fbb43a9d8d51d15f542c268d6888231a88c929d1
Author: Tao Lv <ta...@intel.com>
AuthorDate: Tue Oct 15 19:09:25 2019 +0800
[mkldnn-v1.0] Add MKL-DNN slice (#16484)
* change slice to mkldnn v1.0
* fix lint
---
src/operator/nn/mkldnn/mkldnn_slice-inl.h | 6 +++---
src/operator/nn/mkldnn/mkldnn_slice.cc | 31 +++++++++++++++++--------------
2 files changed, 20 insertions(+), 17 deletions(-)
diff --git a/src/operator/nn/mkldnn/mkldnn_slice-inl.h b/src/operator/nn/mkldnn/mkldnn_slice-inl.h
index f41db01..2233466 100644
--- a/src/operator/nn/mkldnn/mkldnn_slice-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_slice-inl.h
@@ -26,7 +26,7 @@
#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_SLICE_INL_H_
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_SLICE_INL_H_
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
@@ -45,7 +45,7 @@ class MKLDNNSliceFwd {
const NDArray &in,
const NDArray &out);
void SetNewMem(const mkldnn::memory &input, const mkldnn::memory &output);
- const mkldnn::reorder &GetPd() const;
+ void Register();
private:
std::shared_ptr<mkldnn::memory> data_;
@@ -62,5 +62,5 @@ void MKLDNNSlice(const SliceParam &param, const OpContext& ctx,
} // namespace op
} // namespace mxnet
-#endif // MXNET_USE_MKLDNN == 1
+#endif // MXNET_USE_MKLDNN == 100
#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_SLICE_INL_H_
diff --git a/src/operator/nn/mkldnn/mkldnn_slice.cc b/src/operator/nn/mkldnn/mkldnn_slice.cc
index 2a817a2..d9d98b3 100644
--- a/src/operator/nn/mkldnn/mkldnn_slice.cc
+++ b/src/operator/nn/mkldnn/mkldnn_slice.cc
@@ -23,7 +23,7 @@
* \author Zhiyuan Huang
*/
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
#include "./mkldnn_ops-inl.h"
#include "./mkldnn_base-inl.h"
@@ -49,13 +49,15 @@ MKLDNNSliceFwd::MKLDNNSliceFwd(const SliceParam &param,
dims[i] = oshape[i];
offsets[i] = s;
}
- auto in_mem_pd = in.GetMKLDNNData()->get_primitive_desc();
- auto out_mem_pd = out.GetMKLDNNData()->get_primitive_desc();
- auto view_pd = mkldnn::view::primitive_desc(in_mem_pd, dims, offsets);
- auto reorder_pd = reorder::primitive_desc(view_pd.dst_primitive_desc(), out_mem_pd);
- this->data_ = std::make_shared<mkldnn::memory>(view_pd.dst_primitive_desc(), nullptr);
- this->out_ = std::make_shared<mkldnn::memory>(view_pd.dst_primitive_desc(), nullptr);
- this->fwd_ = std::make_shared<mkldnn::reorder>(reorder_pd, *this->data_, *this->out_);
+
+ auto in_md = in.GetMKLDNNData()->get_desc();
+ auto out_md = out.GetMKLDNNData()->get_desc();
+ auto sub_md = in_md.submemory_desc(dims, offsets);
+
+ auto engine = CpuEngine::Get()->get_engine();
+ this->data_ = std::make_shared<mkldnn::memory>(sub_md, engine, nullptr);
+ this->out_ = std::make_shared<mkldnn::memory>(out_md, engine, nullptr);
+ this->fwd_ = std::make_shared<mkldnn::reorder>(*this->data_, *this->out_);
}
void MKLDNNSliceFwd::SetNewMem(const mkldnn::memory &input, const mkldnn::memory &output) {
@@ -63,8 +65,9 @@ void MKLDNNSliceFwd::SetNewMem(const mkldnn::memory &input, const mkldnn::memory
this->out_->set_data_handle(output.get_data_handle());
}
-const mkldnn::reorder &MKLDNNSliceFwd::GetPd() const {
- return *fwd_;
+void MKLDNNSliceFwd::Register() {
+ MKLDNNStream::Get()->RegisterPrimArgs(*fwd_,
+ {{MKLDNN_ARG_FROM, *(this->data_)}, {MKLDNN_ARG_TO, *(this->out_)}});
}
MKLDNNSliceFwd &GetSliceForward(const SliceParam &param, const bool is_train,
@@ -91,14 +94,14 @@ void MKLDNNSlice(const SliceParam &param, const OpContext& ctx,
const NDArray &in, OpReqType req, const NDArray &out) {
MKLDNNSliceFwd &fwd = GetSliceForward(param, ctx.is_train, in, out);
auto in_mem = in.GetMKLDNNData();
- auto out_mem_pd = out.GetMKLDNNData()->get_primitive_desc();
- auto out_mem = CreateMKLDNNMem(out, out_mem_pd, req);
+ auto out_md = out.GetMKLDNNData()->get_desc();
+ auto out_mem = CreateMKLDNNMem(out, out_md, req);
fwd.SetNewMem(*in_mem, *out_mem.second);
- MKLDNNStream::Get()->RegisterPrim(fwd.GetPd());
+ fwd.Register();
CommitOutput(out, out_mem);
MKLDNNStream::Get()->Submit();
}
} // namespace op
} // namespace mxnet
-#endif // MXNET_USE_MKLDNN == 1
+#endif // MXNET_USE_MKLDNN == 100