Posted to commits@mxnet.apache.org by ta...@apache.org on 2019/10/11 07:08:18 UTC
[incubator-mxnet] branch mkldnn-v1.0 updated: [mkldnn-v1.0] Add MKL-DNN reshape&flatten&expand_dims (#16258)
This is an automated email from the ASF dual-hosted git repository.
taolv pushed a commit to branch mkldnn-v1.0
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/mkldnn-v1.0 by this push:
new 922b616 [mkldnn-v1.0] Add MKL-DNN reshape&flatten&expand_dims (#16258)
922b616 is described below
commit 922b6162e8b62af98fe211d1e24f1aa10716a0e6
Author: Wuxun Zhang <wu...@intel.com>
AuthorDate: Fri Oct 11 15:07:24 2019 +0800
[mkldnn-v1.0] Add MKL-DNN reshape&flatten&expand_dims (#16258)
* Add mkldnn 1.0 support for reshape/flatten/expanddims ops
* improve log & modify definition location of args_map_
* fix comments
* rebase code
* trigger CI
* trigger CI
* trigger CI
* trigger CI
---
src/operator/nn/mkldnn/mkldnn_base-inl.h | 2 +-
.../{mkldnn_flatten.cc => mkldnn_expand_dims.cc} | 149 ++++++++++-----------
src/operator/nn/mkldnn/mkldnn_flatten-inl.h | 2 +-
src/operator/nn/mkldnn/mkldnn_flatten.cc | 6 +-
src/operator/nn/mkldnn/mkldnn_ops-inl.h | 28 ++--
src/operator/nn/mkldnn/mkldnn_reshape-inl.h | 33 ++++-
src/operator/nn/mkldnn/mkldnn_reshape.cc | 124 +++++++----------
src/operator/tensor/matrix_op-inl.h | 14 ++
src/operator/tensor/matrix_op.cc | 54 ++++++--
9 files changed, 225 insertions(+), 187 deletions(-)
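The core of this port is MKL-DNN 1.0's changed execution model: primitives no longer bind memory objects at construction; each execution instead takes an explicit argument map, which is what the reworked Execute() below builds. A minimal standalone sketch of that pattern with a plain reorder, assuming an MKL-DNN 1.0 build (illustrative only, not part of the commit):

    #include <mkldnn.hpp>

    int main() {
      using namespace mkldnn;
      engine eng(engine::kind::cpu, 0);
      stream s(eng);

      // One 4-d fp32 tensor described in two layouts.
      memory::dims shape = {2, 3, 4, 5};
      memory src({shape, memory::data_type::f32, memory::format_tag::nhwc}, eng);
      memory dst({shape, memory::data_type::f32, memory::format_tag::nchw}, eng);

      // 1.0 style: construct the primitive without binding memory, then pass
      // the memory objects as an argument map at execution time, the same
      // MKLDNN_ARG_FROM/MKLDNN_ARG_TO pairing the patched Execute() registers.
      reorder r(src, dst);
      r.execute(s, {{MKLDNN_ARG_FROM, src}, {MKLDNN_ARG_TO, dst}});
      s.wait();
      return 0;
    }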
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index e4c4b98..c93cdb4 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -189,7 +189,7 @@ bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input)
bool SupportMKLDNNSoftmax(const SoftmaxParam& param, const NDArray &input, const NDArray &output);
bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param);
bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data);
-bool SupportMKLDNNReshape(const ReshapeParam &param, const NDArray &data);
+bool SupportMKLDNNReshape(const NDArray &in_data, const NDArray &out_data);
} // namespace op
static int GetTypeSize(int dtype) {
diff --git a/src/operator/nn/mkldnn/mkldnn_flatten.cc b/src/operator/nn/mkldnn/mkldnn_expand_dims.cc
similarity index 50%
copy from src/operator/nn/mkldnn/mkldnn_flatten.cc
copy to src/operator/nn/mkldnn/mkldnn_expand_dims.cc
index 4090eb0..dcd85f1 100644
--- a/src/operator/nn/mkldnn/mkldnn_flatten.cc
+++ b/src/operator/nn/mkldnn/mkldnn_expand_dims.cc
@@ -1,79 +1,70 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file mkldnn_flatten.cc
- * \brief Implement flatten operator by using mkldnn reorder primitive
- * \author Wuxun Zhang
-*/
-
-#if MXNET_USE_MKLDNN == 1
-
-#include "mkldnn_flatten-inl.h"
-
-namespace mxnet {
-namespace op {
-
-static MKLDNNFlattenFwd &GetFlattenForward(const OpReqType &req,
- const NDArray &input,
- const NDArray &output) {
-#if DMLC_CXX11_THREAD_LOCAL
- static thread_local std::unordered_map<OpSignature,
- MKLDNNFlattenFwd, OpHash> fwds;
-#else
- static MX_THREAD_LOCAL std::unordered_map<OpSignature,
- MKLDNNFlattenFwd, OpHash> fwds;
-#endif
- OpSignature key;
- key.AddSign(req);
- key.AddSign(input);
-
- auto it = fwds.find(key);
- if (it == fwds.end()) {
- MKLDNNFlattenFwd fwd(req, input, output);
- it = AddToCache(&fwds, key, fwd);
- }
- return it->second;
-}
-
-void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
- const OpContext &ctx,
- const NDArray &input,
- const OpReqType &req,
- const NDArray &output) {
- if (req == kNullOp) return;
- CHECK_NE(req, kAddTo) << "kAddTo is not supported yet";
-
- auto fwd = GetFlattenForward(req, input, output);
- auto ws_size = fwd.GetWorkspaceSize();
- void* ws_ptr = nullptr;
- if (ws_size) {
- mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
- mshadow::Tensor<cpu, 1, char> ws = ctx.requested[0]
- .get_space_typed<cpu, 1, char>(mshadow::Shape1(ws_size), s);
- ws_ptr = reinterpret_cast<void*>(ws.dptr_);
- }
-
- fwd.Execute(input, output, ws_ptr);
-}
-
-} // namespace op
-} // namespace mxnet
-
-#endif
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_expand_dims.cc
+ * \brief Implement expand_dims operator via MKL-DNN reorder primitive
+ * \author Wuxun Zhang
+*/
+
+#if MXNET_USE_MKLDNN == 100
+
+#include "mkldnn_reshape-inl.h"
+
+namespace mxnet {
+namespace op {
+
+class MKLDNNExpandDimsFwd : public MKLDNNReshapeFwd {
+ public:
+ explicit MKLDNNExpandDimsFwd(const OpReqType &req,
+ const NDArray &input,
+ const NDArray &output)
+ : MKLDNNReshapeFwd(req, input, output) {}
+};
+
+typedef ParamOpSign<ExpandDimParam> MKLDNNExpandDimsSignature;
+
+void MKLDNNExpandDimsForward(const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const NDArray &input,
+ const OpReqType &req,
+ const NDArray &output) {
+ const ExpandDimParam& param = nnvm::get<ExpandDimParam>(attrs.parsed);
+ if (req == kNullOp) return;
+ CHECK_NE(req, kAddTo) << "kAddTo is not supported yet";
+
+ auto fwd = GetCachedForward<MKLDNNExpandDimsFwd, ExpandDimParam,
+ MKLDNNExpandDimsSignature>(param, req, input, output);
+
+ auto ws_size = fwd.GetWorkspaceSize();
+ void* ws_ptr = nullptr;
+ if (ws_size) {
+ mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+ mshadow::Tensor<cpu, 1, char> ws = ctx.requested[0]
+ .get_space_typed<cpu, 1, char>(mshadow::Shape1(ws_size), s);
+ ws_ptr = reinterpret_cast<void*>(ws.dptr_);
+ }
+
+ fwd.Execute(input, output, req, ws_ptr);
+}
+
+} // namespace op
+} // namespace mxnet
+
+#endif
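expand_dims can reuse the reshape implementation wholesale because inserting a size-1 axis never moves data: in row-major order a (2,3) buffer and a (1,2,3) buffer hold the same elements at the same flat offsets, so the operator is at most a layout reorder plus a copy. A tiny illustration (plain C++, not part of the commit):

    #include <cassert>

    int main() {
      float buf[6] = {0, 1, 2, 3, 4, 5};
      int i = 1, j = 2;
      // (2,3) view: element (i,j) sits at offset i*3 + j.
      // (1,2,3) view after expand_dims: element (0,i,j) sits at 0*6 + i*3 + j.
      assert(buf[i * 3 + j] == buf[0 * 6 + i * 3 + j]);  // same offset, same value
      return 0;
    }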
diff --git a/src/operator/nn/mkldnn/mkldnn_flatten-inl.h b/src/operator/nn/mkldnn/mkldnn_flatten-inl.h
index ae890d8..89e52cc 100644
--- a/src/operator/nn/mkldnn/mkldnn_flatten-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_flatten-inl.h
@@ -25,7 +25,7 @@
#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FLATTEN_INL_H_
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FLATTEN_INL_H_
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
#include "mkldnn_reshape-inl.h"
diff --git a/src/operator/nn/mkldnn/mkldnn_flatten.cc b/src/operator/nn/mkldnn/mkldnn_flatten.cc
index 4090eb0..4058399 100644
--- a/src/operator/nn/mkldnn/mkldnn_flatten.cc
+++ b/src/operator/nn/mkldnn/mkldnn_flatten.cc
@@ -19,11 +19,11 @@
/*!
* \file mkldnn_flatten.cc
- * \brief Implement flatten operator by using mkldnn reorder primitive
+ * \brief Implement flatten operator via MKL-DNN reorder primitive
* \author Wuxun Zhang
*/
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
#include "mkldnn_flatten-inl.h"
@@ -70,7 +70,7 @@ void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
ws_ptr = reinterpret_cast<void*>(ws.dptr_);
}
- fwd.Execute(input, output, ws_ptr);
+ fwd.Execute(input, output, req, ws_ptr);
}
} // namespace op
diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
index 793aad7..ec97c93 100644
--- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -63,18 +63,6 @@ void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs);
-
-void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
- const OpContext &ctx,
- const NDArray &input,
- const OpReqType &req,
- const NDArray &output);
-
-void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
- const OpContext &ctx,
- const NDArray &input,
- const OpReqType &req,
- const NDArray &output);
#endif
#if MXNET_USE_MKLDNN == 100
@@ -142,6 +130,22 @@ void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs,
const NDArray &data,
const OpReqType &req,
const NDArray &output);
+
+void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
+ const OpContext &ctx,
+ const NDArray &input,
+ const OpReqType &req,
+ const NDArray &output);
+void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const NDArray &input,
+ const OpReqType &req,
+ const NDArray &output);
+void MKLDNNExpandDimsForward(const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const NDArray &input,
+ const OpReqType &req,
+ const NDArray &output);
#endif
} // namespace op
diff --git a/src/operator/nn/mkldnn/mkldnn_reshape-inl.h b/src/operator/nn/mkldnn/mkldnn_reshape-inl.h
index 63e367b..aa0f11c 100644
--- a/src/operator/nn/mkldnn/mkldnn_reshape-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_reshape-inl.h
@@ -26,7 +26,7 @@
#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
#include <vector>
#include "mkldnn_base-inl.h"
#include "../../tensor/matrix_op-inl.h"
@@ -36,7 +36,6 @@ namespace op {
class MKLDNNReshapeFwd {
protected:
- std::shared_ptr<mkldnn::memory> data_;
std::shared_ptr<mkldnn::memory> out_;
std::shared_ptr<mkldnn::memory> temp_;
std::vector<mkldnn::primitive> prims_;
@@ -47,15 +46,39 @@ class MKLDNNReshapeFwd {
const NDArray &input,
const NDArray &output);
int GetWorkspaceSize();
- void SetNewMem(const NDArray &input,
- const NDArray &output,
- void* workspace = nullptr);
void Execute(const NDArray &input,
const NDArray &output,
+ const OpReqType &req,
void* workspace = nullptr);
};
typedef ParamOpSign<ReshapeParam> MKLDNNReshapeSignature;
+
+template<typename MKLDNNOpFwdType, typename ParamType, typename MKLDNNSignatureType>
+MKLDNNOpFwdType &GetCachedForward(const ParamType& param,
+ const OpReqType &req,
+ const NDArray &input,
+ const NDArray &output) {
+#if DMLC_CXX11_THREAD_LOCAL
+ static thread_local std::unordered_map<MKLDNNSignatureType,
+ MKLDNNOpFwdType, OpHash> fwds;
+#else
+ static MX_THREAD_LOCAL std::unordered_map<MKLDNNSignatureType,
+ MKLDNNOpFwdType, OpHash> fwds;
+#endif
+ MKLDNNSignatureType key(param);
+ key.AddSign(req);
+ key.AddSign(input);
+ key.AddSign(output);
+
+ auto it = fwds.find(key);
+ if (it == fwds.end()) {
+ MKLDNNOpFwdType fwd(req, input, output);
+ it = AddToCache(&fwds, key, fwd);
+ }
+ return it->second;
+}
+
MKLDNNReshapeFwd &GetReshapeForward(const ReshapeParam& param,
const OpReqType &req,
const NDArray &input,
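The GetCachedForward template above generalizes the thread-local primitive-cache pattern previously written out per operator (the old GetFlattenForward above, and the GetReshapeForward removed from mkldnn_reshape.cc below). A self-contained toy version of the pattern, where ToyParam, ToyHash and ToyFwd are hypothetical stand-ins for the real param, hash and forward classes:

    #include <functional>
    #include <unordered_map>

    // Hypothetical stand-ins for the real param/hash/forward types.
    struct ToyParam {
      int axis;
      bool operator==(const ToyParam &other) const { return axis == other.axis; }
    };
    struct ToyHash {
      size_t operator()(const ToyParam &p) const { return std::hash<int>()(p.axis); }
    };
    struct ToyFwd {
      explicit ToyFwd(const ToyParam &p) : axis(p.axis) {}
      int axis;
    };

    template <typename FwdType, typename ParamType, typename HashType>
    FwdType &GetCached(const ParamType &param) {
      // One cache per thread: cached primitives are never shared across threads.
      static thread_local std::unordered_map<ParamType, FwdType, HashType> fwds;
      auto it = fwds.find(param);
      if (it == fwds.end())
        it = fwds.emplace(param, FwdType(param)).first;
      return it->second;
    }

    int main() {
      ToyParam p{1};
      auto &fwd1 = GetCached<ToyFwd, ToyParam, ToyHash>(p);
      auto &fwd2 = GetCached<ToyFwd, ToyParam, ToyHash>(p);
      return &fwd1 == &fwd2 ? 0 : 1;  // same key -> same cached object
    }

Keying the cache through the param type is also why the matrix_op-inl.h hunk further down adds operator== and a std::hash specialization for ExpandDimParam: the parameters must be usable inside an unordered_map key.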
diff --git a/src/operator/nn/mkldnn/mkldnn_reshape.cc b/src/operator/nn/mkldnn/mkldnn_reshape.cc
index 063c85d..d180125 100644
--- a/src/operator/nn/mkldnn/mkldnn_reshape.cc
+++ b/src/operator/nn/mkldnn/mkldnn_reshape.cc
@@ -23,7 +23,7 @@
* \author Tao Lv
*/
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
#include <mkldnn.hpp>
#include "mkldnn_reshape-inl.h"
@@ -31,13 +31,14 @@
namespace mxnet {
namespace op {
-bool SupportMKLDNNReshape(const ReshapeParam &param,
- const NDArray &data) {
- auto data_ndim = data.shape().ndim();
+bool SupportMKLDNNReshape(const NDArray &in_data,
+ const NDArray &out_data) {
+ auto in_ndim = in_data.shape().ndim();
+ auto out_ndim = out_data.shape().ndim();
- if (data_ndim > 4 ||
- data.dtype() != mshadow::kFloat32 ||
- param.shape.ndim() > 4)
+ if (in_ndim > 4 ||
+ in_data.dtype() != mshadow::kFloat32 ||
+ out_ndim > 4)
return false;
return true;
@@ -48,21 +49,16 @@ MKLDNNReshapeFwd::MKLDNNReshapeFwd(const OpReqType &req,
const NDArray &output) {
auto engine = CpuEngine::Get()->get_engine();
- // data_
+ // source
auto in_mem = input.GetMKLDNNData();
- auto in_pd = in_mem->get_primitive_desc();
- data_ = std::make_shared<mkldnn::memory>(in_pd, nullptr);
+ auto in_md = in_mem->get_desc();
// temp_
- auto temp_dims = mkldnn::memory::dims(input.shape().begin(), input.shape().end());
- auto temp_type = static_cast<mkldnn::memory::data_type>(in_pd.desc().data.data_type);
- auto temp_fmt = static_cast<mkldnn::memory::format>(GetDefaultFormat(in_pd.desc()));
- auto temp_desc = mkldnn::memory::desc(temp_dims, temp_type, temp_fmt);
- auto temp_pd = mkldnn::memory::primitive_desc(temp_desc, engine);
- temp_ = std::make_shared<mkldnn::memory>(temp_pd, nullptr);
+ auto temp_md = GetDesc(in_md, GetDefaultFormat(in_md));
+ temp_ = std::make_shared<mkldnn::memory>(temp_md, engine, nullptr);
// destination
- out_ = std::make_shared<mkldnn::memory>(temp_pd, nullptr);
+ out_ = std::make_shared<mkldnn::memory>(temp_md, engine, nullptr);
if (req == kWriteInplace) {
// If the input has an MKL-DNN internal layout, we need to reorder it to a temporary buffer with
@@ -70,17 +66,17 @@ MKLDNNReshapeFwd::MKLDNNReshapeFwd(const OpReqType &req,
// the default layout, then copy it back to the output buffer, which shares its address with the input buffer.
// If the input has the default layout, nothing needs to be done.
if (input.IsMKLDNNData()) {
- prims_.push_back(mkldnn::reorder(*data_, *temp_)); // reorder to default
- prims_.push_back(mkldnn::reorder(*temp_, *out_)); // copy back
+ prims_.push_back(mkldnn::reorder(*in_mem, *temp_)); // reorder to default
+ prims_.push_back(mkldnn::reorder(*temp_, *out_)); // copy back
needInvalidateInput = true;
}
} else if (req == kWriteTo) {
if (input.IsMKLDNNData()) {
- prims_.push_back(mkldnn::reorder(*data_, *temp_)); // reorder to default
- prims_.push_back(mkldnn::reorder(*temp_, *out_)); // copy to the output buffer
+ prims_.push_back(mkldnn::reorder(*in_mem, *temp_)); // reorder to default
+ prims_.push_back(mkldnn::reorder(*temp_, *out_)); // copy to the output buffer
needInvalidateInput = false;
} else {
- prims_.push_back(mkldnn::reorder(*data_, *out_)); // copy directly from input to output
+ prims_.push_back(mkldnn::reorder(*in_mem, *out_)); // copy directly from input to output
needInvalidateInput = false;
}
} else {
@@ -89,42 +85,36 @@ MKLDNNReshapeFwd::MKLDNNReshapeFwd(const OpReqType &req,
}
int MKLDNNReshapeFwd::GetWorkspaceSize() {
- return temp_ ? temp_->get_primitive_desc().get_size() : 0;
-}
-
-void MKLDNNReshapeFwd::SetNewMem(const NDArray &input,
- const NDArray &output,
- void* workspace) {
- if (input.IsMKLDNNData()) {
- this->data_->set_data_handle(input.GetMKLDNNData()->get_data_handle());
- } else {
- MSHADOW_TYPE_SWITCH(input.dtype(), DTYPE, {
- this->data_->set_data_handle(input.data().dptr<DTYPE>());
- })
- }
-
- if (output.IsMKLDNNData()) {
- this->out_->set_data_handle(output.GetMKLDNNData()->get_data_handle());
- } else {
- MSHADOW_TYPE_SWITCH(output.dtype(), DTYPE, {
- this->out_->set_data_handle(output.data().dptr<DTYPE>());
- })
- }
-
- if (workspace) {
- this->temp_->set_data_handle(workspace);
- }
+ return temp_ ? temp_->get_desc().get_size() : 0;
}
void MKLDNNReshapeFwd::Execute(const NDArray &input,
const NDArray &output,
+ const OpReqType &req,
void* workspace) {
- // set memory handles
- SetNewMem(input, output, workspace);
- // register primitives
auto stream = MKLDNNStream::Get();
- for (auto &v : this->prims_) {
- stream->RegisterPrim(v);
+ auto in_mem = input.GetMKLDNNData();
+ // register primitives and arguments
+ std::vector<mkldnn_args_map_t> args_map;
+ size_t prims_size = prims_.size();
+ if (prims_size == 1) {
+ args_map.push_back({{MKLDNN_ARG_FROM, *in_mem},
+ {MKLDNN_ARG_TO, *output.GetMKLDNNData()}});
+ } else if (prims_size == 2) {
+ if (workspace) {
+ temp_->set_data_handle(workspace);
+ }
+ args_map.push_back({{MKLDNN_ARG_FROM, *in_mem},
+ {MKLDNN_ARG_TO, *temp_}});
+ args_map.push_back({{MKLDNN_ARG_FROM, *temp_},
+ {MKLDNN_ARG_TO, *output.GetMKLDNNData()}});
+ } else {
+ CHECK(prims_size == 0 && req != kWriteTo)
+ << "kWriteTo should never reach here.";
+ }
+
+ for (size_t i = 0; i < prims_size; i++) {
+ stream->RegisterPrimArgs(prims_[i], args_map[i]);
}
stream->Submit();
// invalidate mkldnn memory in input
@@ -133,30 +123,6 @@ void MKLDNNReshapeFwd::Execute(const NDArray &input,
}
}
-MKLDNNReshapeFwd &GetReshapeForward(const ReshapeParam& param,
- const OpReqType &req,
- const NDArray &input,
- const NDArray &output) {
-#if DMLC_CXX11_THREAD_LOCAL
- static thread_local std::unordered_map<MKLDNNReshapeSignature,
- MKLDNNReshapeFwd, OpHash> fwds;
-#else
- static MX_THREAD_LOCAL std::unordered_map<MKLDNNReshapeSignature,
- MKLDNNReshapeFwd, OpHash> fwds;
-#endif
- MKLDNNReshapeSignature key(param);
- key.AddSign(req);
- key.AddSign(input);
- key.AddSign(output);
-
- auto it = fwds.find(key);
- if (it == fwds.end()) {
- MKLDNNReshapeFwd fwd(req, input, output);
- it = AddToCache(&fwds, key, fwd);
- }
- return it->second;
-}
-
void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const NDArray &input,
@@ -166,7 +132,9 @@ void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
if (req == kNullOp) return;
CHECK_NE(req, kAddTo) << "kAddTo is not supported yet";
- auto fwd = GetReshapeForward(param, req, input, output);
+ auto fwd = GetCachedForward<MKLDNNReshapeFwd, ReshapeParam,
+ MKLDNNReshapeSignature>(param, req, input, output);
+
auto ws_size = fwd.GetWorkspaceSize();
void* ws_ptr = nullptr;
if (ws_size) {
@@ -176,7 +144,7 @@ void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
ws_ptr = reinterpret_cast<void*>(ws.dptr_);
}
- fwd.Execute(input, output, ws_ptr);
+ fwd.Execute(input, output, req, ws_ptr);
}
} // namespace op
} // namespace mxnet
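When the input carries an MKL-DNN internal layout, the reworked Execute() registers two reorders: internal layout to a default-layout scratch buffer, then scratch buffer to the output. A standalone sketch of that chain in plain MKL-DNN 1.0, with a local vector standing in for temp_ backed by the requested kTempSpace workspace (illustrative, not part of the commit):

    #include <mkldnn.hpp>
    #include <vector>

    // Reorder an nChw16c (internal-layout) tensor to nchw via a scratch
    // buffer, mirroring the two-primitive path in MKLDNNReshapeFwd::Execute.
    int main() {
      using namespace mkldnn;
      engine eng(engine::kind::cpu, 0);
      stream s(eng);

      memory::dims shape = {1, 16, 4, 4};
      memory src({shape, memory::data_type::f32, memory::format_tag::nChw16c}, eng);
      memory dst({shape, memory::data_type::f32, memory::format_tag::nchw}, eng);

      // Default-layout scratch buffer backed by externally provided storage,
      // like temp_ pointing at the requested workspace.
      std::vector<float> workspace(1 * 16 * 4 * 4);
      memory temp({shape, memory::data_type::f32, memory::format_tag::nchw},
                  eng, workspace.data());

      reorder(src, temp).execute(s, {{MKLDNN_ARG_FROM, src}, {MKLDNN_ARG_TO, temp}});
      reorder(temp, dst).execute(s, {{MKLDNN_ARG_FROM, temp}, {MKLDNN_ARG_TO, dst}});
      s.wait();
      return 0;
    }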
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 5a2bd03..3f1a5f8 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -394,6 +394,10 @@ struct ExpandDimParam : public dmlc::Parameter<ExpandDimParam> {
"the input `NDArray`'s dimension is `ndim`, the range of "
"the inserted axis is `[-ndim, ndim]`");
}
+
+ bool operator==(const ExpandDimParam &other) const {
+ return this->axis == other.axis;
+ }
};
@@ -2936,6 +2940,16 @@ struct hash<mxnet::op::ReshapeParam> {
return ret;
}
};
+
+template<>
+struct hash<mxnet::op::ExpandDimParam> {
+ size_t operator()(const mxnet::op::ExpandDimParam& val) {
+ size_t ret = 0;
+ ret = dmlc::HashCombine(ret, val.axis);
+ return ret;
+ }
+};
+
} // namespace std
#endif // MXNET_OPERATOR_TENSOR_MATRIX_OP_INL_H_
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index a4f0db0..6bf1ec0 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -25,9 +25,11 @@
// this will be invoked by gcc and compile CPU version
#include "./matrix_op-inl.h"
#include "./elemwise_unary_op.h"
+#if MXNET_USE_MKLDNN == 100
#include "../nn/mkldnn/mkldnn_ops-inl.h"
#include "../nn/mkldnn/mkldnn_base-inl.h"
#include "../nn/mkldnn/mkldnn_slice-inl.h"
+#endif
namespace mxnet {
namespace op {
@@ -105,19 +107,18 @@ DMLC_REGISTER_PARAMETER(SqueezeParam);
DMLC_REGISTER_PARAMETER(DepthToSpaceParam);
DMLC_REGISTER_PARAMETER(SplitParam);
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
static void ReshapeComputeExCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
- const ReshapeParam& param = nnvm::get<ReshapeParam>(attrs.parsed);
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
// If the input is in MKL-DNN format and MKL-DNN supports
// the data type and the shape, then convert it to the
// output format and shape.
- if (SupportMKLDNNReshape(param, inputs[0])) {
+ if (SupportMKLDNNReshape(inputs[0], outputs[0])) {
MKLDNNReshapeForward(attrs, ctx, inputs[0], req[0], outputs[0]);
return;
}
@@ -207,7 +208,7 @@ If the argument `reverse` is set to 1, then the special values are inferred from
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_reshape"})
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", ReshapeComputeExCPU)
.set_attr<FInferStorageType>("FInferStorageType", ReshapeStorageType)
@@ -233,7 +234,7 @@ static void FlattenEx(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& outputs) {
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
auto data_ndim = inputs[0].shape().ndim();
if (data_ndim <= 4 && inputs[0].dtype() == mshadow::kFloat32) {
MKLDNNFlattenForward(attrs, ctx, inputs[0], req[0], outputs[0]);
@@ -248,7 +249,7 @@ static void FlattenEx(const nnvm::NodeAttrs& attrs,
#endif
}
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
static inline bool FlattenStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
@@ -294,13 +295,13 @@ Example::
.set_num_outputs(1)
.set_attr<mxnet::FInferShape>("FInferShape", FlattenShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
.set_attr<FInferStorageType>("FInferStorageType", FlattenStorageType)
#endif
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_backward_copy" })
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", FlattenEx)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
@@ -411,6 +412,33 @@ Examples::
.add_arguments(TransposeParam::__FIELDS__());
+#if MXNET_USE_MKLDNN == 100
+static void ExpandDimEx(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
+ CHECK_EQ(inputs.size(), 1U);
+ CHECK_EQ(outputs.size(), 1U);
+ auto data_ndim = inputs[0].shape().ndim();
+ if (data_ndim <= 3 && inputs[0].dtype() == mshadow::kFloat32) {
+ MKLDNNExpandDimsForward(attrs, ctx, inputs[0], req[0], outputs[0]);
+ return;
+ }
+ FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
+}
+
+inline static bool ExpandDimStorageType(const nnvm::NodeAttrs& attrs,
+ const int dev_mask,
+ DispatchMode* dispatch_mode,
+ std::vector<int>* in_attrs,
+ std::vector<int>* out_attrs) {
+ CHECK_EQ(in_attrs->size(), 1U);
+ CHECK_EQ(out_attrs->size(), 1U);
+ return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
+}
+#endif
+
NNVM_REGISTER_OP(expand_dims)
.add_alias("_npi_expand_dims")
.describe(R"code(Inserts a new axis of size 1 into the array shape
@@ -424,6 +452,9 @@ will return a new array with shape ``(2,1,3,4)``.
.set_attr_parser(ParamParser<ExpandDimParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ExpandDimShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+#if MXNET_USE_MKLDNN == 100
+.set_attr<FInferStorageType>("FInferStorageType", ExpandDimStorageType)
+#endif
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 0}};
@@ -434,6 +465,13 @@ will return a new array with shape ``(2,1,3,4)``.
})
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_reshape"})
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
+#if MXNET_USE_MKLDNN == 100
+.set_attr<FComputeEx>("FComputeEx<cpu>", ExpandDimEx)
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
+ return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
+#endif
.add_argument("data", "NDArray-or-Symbol", "Source input")
.add_arguments(ExpandDimParam::__FIELDS__());