Posted to commits@mxnet.apache.org by pa...@apache.org on 2019/10/01 09:47:56 UTC
[incubator-mxnet] branch mkldnn-v1.0 updated: [mkldnn-v1.0] Add MKL-DNN deconv (#16259)
This is an automated email from the ASF dual-hosted git repository.
patriczhao pushed a commit to branch mkldnn-v1.0
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/mkldnn-v1.0 by this push:
new 44495dc [mkldnn-v1.0] Add MKL-DNN deconv (#16259)
44495dc is described below
commit 44495dc51539e9faf09b8787a3cd47a153aad717
Author: rongzha1 <ro...@intel.com>
AuthorDate: Tue Oct 1 17:47:07 2019 +0800
[mkldnn-v1.0] Add MKL-DNN deconv (#16259)
* add mkldnn deconv
* coding style
* trigger CI
---
src/operator/nn/convolution.cc | 14 +-
src/operator/nn/deconvolution.cc | 145 ++++----
src/operator/nn/mkldnn/mkldnn_base-inl.h | 20 +-
src/operator/nn/mkldnn/mkldnn_base.cc | 27 ++
src/operator/nn/mkldnn/mkldnn_deconvolution.cc | 464 +++++++++++--------------
src/operator/nn/mkldnn/mkldnn_ops-inl.h | 20 +-
6 files changed, 322 insertions(+), 368 deletions(-)
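In outline, the patch does two things: it ports the deconvolution operator from the MKL-DNN 0.x API (guarded by MXNET_USE_MKLDNN == 1) to the 1.0 API (guarded by MXNET_USE_MKLDNN == 100), and it folds the view-handling boilerplate that convolution and deconvolution both need into a shared MKLDNNRun helper. A minimal standalone model of that dispatch, with a plain struct standing in for mxnet::NDArray (the real helper is in the mkldnn_base.cc hunk below), might look like:

#include <functional>
#include <vector>

struct Array {                       // stand-in for mxnet::NDArray
  bool is_view = false;
  bool is_mkldnn = false;
  Array Reorder2Default() const { return Array(); }  // default layout
};

using ComputeFn = std::function<void(const std::vector<Array>&)>;

// Mirrors MKLDNNRun: if any input is a view backed by an MKL-DNN layout,
// reorder those inputs to the default layout once, then call the operator.
void Run(ComputeFn fn, const std::vector<Array>& inputs) {
  bool has_mkldnn_view = false;
  for (const auto& in : inputs)
    if (in.is_view && in.is_mkldnn) { has_mkldnn_view = true; break; }
  if (!has_mkldnn_view) {
    fn(inputs);
    return;
  }
  std::vector<Array> fixed;
  fixed.reserve(inputs.size());
  for (const auto& in : inputs)
    fixed.push_back(in.is_view && in.is_mkldnn ? in.Reorder2Default() : in);
  fn(fixed);
}

int main() {
  std::vector<Array> inputs(1);
  inputs[0].is_view = true;
  inputs[0].is_mkldnn = true;
  Run([](const std::vector<Array>&) { /* operator body runs here */ }, inputs);
  return 0;
}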
diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc
index ad19128..d51e4e0 100644
--- a/src/operator/nn/convolution.cc
+++ b/src/operator/nn/convolution.cc
@@ -60,12 +60,7 @@ static void ConvolutionComputeExCPU(const nnvm::NodeAttrs& attrs,
const ConvolutionParam& params = nnvm::get<ConvolutionParam>(attrs.parsed);
if (SupportMKLDNNConv(params, inputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
- if (CheckMKLDNNInputArrayIsView(inputs)) {
- const auto mkldnn_inputs = GetMKLDNNInputArray(inputs);
- MKLDNNConvolutionForward(attrs, ctx, mkldnn_inputs, req, outputs);
- } else {
- MKLDNNConvolutionForward(attrs, ctx, inputs, req, outputs);
- }
+ MKLDNNRun(MKLDNNConvolutionForward, attrs, ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(ConvolutionCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
}
@@ -80,12 +75,7 @@ static void ConvolutionGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const ConvolutionParam& params = nnvm::get<ConvolutionParam>(attrs.parsed);
if (SupportMKLDNNConv(params, inputs[0])) {
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
- if (CheckMKLDNNInputArrayIsView(inputs)) {
- const auto mkldnn_inputs = GetMKLDNNInputArray(inputs);
- MKLDNNConvolutionBackward(attrs, ctx, mkldnn_inputs, req, outputs);
- } else {
- MKLDNNConvolutionBackward(attrs, ctx, inputs, req, outputs);
- }
+ MKLDNNRun(MKLDNNConvolutionBackward, attrs, ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(ConvolutionGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
}
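With the helper in place, both convolution entry points shrink to the one-line MKLDNNRun call above; the new deconvolution entry points in the next file follow the same pattern.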
diff --git a/src/operator/nn/deconvolution.cc b/src/operator/nn/deconvolution.cc
index 9f461f4e..d3d3eda 100644
--- a/src/operator/nn/deconvolution.cc
+++ b/src/operator/nn/deconvolution.cc
@@ -27,14 +27,75 @@
#include "./deconvolution-inl.h"
#include "../operator_common.h"
#include "../../common/utils.h"
-#if MXNET_USE_MKLDNN == 1
-#include "./mkldnn/mkldnn_ops-inl.h"
+#if MXNET_USE_MKLDNN == 100
#include "./mkldnn/mkldnn_base-inl.h"
-#endif
+#include "./mkldnn/mkldnn_ops-inl.h"
+#endif // MXNET_USE_MKLDNN
namespace mxnet {
namespace op {
+#if MXNET_USE_MKLDNN == 100
+static void DeconvolutionComputeExCPU(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
+ const DeconvolutionParam& params = nnvm::get<DeconvolutionParam>(attrs.parsed);
+ if (SupportMKLDNNDeconv(params, inputs[0])) {
+ MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
+ MKLDNNRun(MKLDNNDeconvolutionForward, attrs, ctx, inputs, req, outputs);
+ MKLDNN_OPCHECK_RUN(DeconvolutionCompute<cpu>, attrs, ctx, inputs, req, outputs);
+ return;
+ }
+ FallBackCompute(DeconvolutionCompute<cpu>, attrs, ctx, inputs, req, outputs);
+}
+
+static void DeconvolutionGradComputeExCPU(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
+ const DeconvolutionParam& params = nnvm::get<DeconvolutionParam>(attrs.parsed);
+ if (SupportMKLDNNDeconv(params, inputs[0])) {
+ MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
+ MKLDNNRun(MKLDNNDeconvolutionBackward, attrs, ctx, inputs, req, outputs);
+ MKLDNN_OPCHECK_RUN(DeconvolutionGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
+ return;
+ }
+ FallBackCompute(DeconvolutionGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
+}
+
+inline static bool DeconvStorageType(const nnvm::NodeAttrs& attrs,
+ const int dev_mask,
+ DispatchMode* dispatch_mode,
+ std::vector<int>* in_attrs,
+ std::vector<int>* out_attrs) {
+ const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);
+ uint32_t in_expected = param.no_bias ? 2 : 3;
+ CHECK_EQ(in_attrs->size(), in_expected);
+ CHECK_EQ(out_attrs->size(), 1);
+
+ return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs,
+ out_attrs);
+}
+
+inline static bool BackwardDeconvStorageType(const nnvm::NodeAttrs& attrs,
+ const int dev_mask,
+ DispatchMode* dispatch_mode,
+ std::vector<int> *in_attrs,
+ std::vector<int> *out_attrs) {
+ const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);
+ uint32_t in_expected = param.no_bias ? 3 : 4;
+ uint32_t out_expected = param.no_bias ? 2 : 3;
+ CHECK_EQ(in_attrs->size(), in_expected);
+ CHECK_EQ(out_attrs->size(), out_expected);
+
+ return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs,
+ out_attrs);
+}
+#endif
+
static bool DeconvolutionShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_shape,
mxnet::ShapeVector *out_shape) {
@@ -284,70 +345,6 @@ static bool DeconvolutionType(const nnvm::NodeAttrs& attrs,
return true;
}
-#if MXNET_USE_MKLDNN == 1
-inline static bool DeconvStorageType(const nnvm::NodeAttrs& attrs,
- const int dev_mask,
- DispatchMode* dispatch_mode,
- std::vector<int> *in_attrs,
- std::vector<int> *out_attrs) {
- const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);
- uint32_t in_expected = param.no_bias ? 2 : 3;
- CHECK_EQ(in_attrs->size(), in_expected);
- CHECK_EQ(out_attrs->size(), 1);
-
- return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs,
- out_attrs);
-}
-
-inline static bool BackwardDeconvStorageType(const nnvm::NodeAttrs& attrs,
- const int dev_mask,
- DispatchMode* dispatch_mode,
- std::vector<int> *in_attrs,
- std::vector<int> *out_attrs) {
- const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);
- uint32_t out_expected = param.no_bias ? 2 : 3;
- CHECK_EQ(in_attrs->size(), param.no_bias ? 3U : 4U);
- CHECK_EQ(out_attrs->size(), out_expected);
-
- return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs,
- out_attrs);
-}
-
-static void DeconvolutionComputeExCPU(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs) {
- const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);
- if (SupportMKLDNNDeconv(param, inputs[0])) {
- MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
- MKLDNNDeconvolutionForward(attrs, ctx, inputs, req, outputs);
- MKLDNN_OPCHECK_RUN(DeconvolutionCompute<cpu>, attrs, ctx, inputs, req,
- outputs);
- return;
- }
- FallBackCompute(DeconvolutionCompute<cpu>, attrs, ctx, inputs, req,
- outputs);
-}
-
-static void DeconvolutionGradComputeExCPU(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs) {
- const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);
- if (SupportMKLDNNDeconv(param, inputs[0])) {
- MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
- MKLDNNDeconvolutionBackward(attrs, ctx, inputs, req, outputs);
- MKLDNN_OPCHECK_RUN(DeconvolutionGradCompute<cpu>, attrs, ctx, inputs, req,
- outputs);
- return;
- }
- FallBackCompute(DeconvolutionGradCompute<cpu>, attrs, ctx, inputs, req,
- outputs);
-}
-#endif
-
static void DeconvolutionParamParser(nnvm::NodeAttrs* attrs) {
using namespace mshadow;
DeconvolutionParam param_;
@@ -430,18 +427,16 @@ NNVM_REGISTER_OP(Deconvolution)
})
.set_attr<mxnet::FInferShape>("FInferShape", DeconvolutionShape)
.set_attr<nnvm::FInferType>("FInferType", DeconvolutionType)
-#if MXNET_USE_MKLDNN == 1
-.set_attr<FInferStorageType>("FInferStorageType", DeconvStorageType)
-#endif
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", DeconvolutionCompute<cpu>)
-#if MXNET_USE_MKLDNN == 1
+.set_attr<nnvm::FGradient>("FGradient", DeconvolutionGrad{"_backward_Deconvolution"})
+#if MXNET_USE_MKLDNN == 100
.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<FInferStorageType>("FInferStorageType", DeconvStorageType)
.set_attr<FComputeEx>("FComputeEx<cpu>", DeconvolutionComputeExCPU)
#endif
-.set_attr<nnvm::FGradient>("FGradient", DeconvolutionGrad{"_backward_Deconvolution"})
.add_argument("data", "NDArray-or-Symbol", "Input tensor to the deconvolution operation.")
.add_argument("weight", "NDArray-or-Symbol", "Weights representing the kernel.")
.add_argument("bias", "NDArray-or-Symbol", "Bias added to the result after the deconvolution "
@@ -454,15 +449,13 @@ NNVM_REGISTER_OP(_backward_Deconvolution)
return params.no_bias ? 2 : 3;
})
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
-#if MXNET_USE_MKLDNN == 1
-.set_attr<FInferStorageType>("FInferStorageType", BackwardDeconvStorageType)
-#endif
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr_parser(DeconvolutionParamParser)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<FInferStorageType>("FInferStorageType", BackwardDeconvStorageType)
.set_attr<FComputeEx>("FComputeEx<cpu>", DeconvolutionGradComputeExCPU)
#endif
.set_attr<FCompute>("FCompute<cpu>", DeconvolutionGradCompute<cpu>);
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 37563d2..e4c4b98 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -295,19 +295,6 @@ inline static bool CheckMKLDNNInputArrayIsView(const std::vector<NDArray> &input
return false;
}
-inline static const std::vector<NDArray> GetMKLDNNInputArray(const std::vector<NDArray> &inputs) {
- std::vector<NDArray> ret;
- ret.reserve(inputs.size());
- for (const auto &in : inputs) {
- if (in.IsView() && in.IsMKLDNNData()) {
- ret.push_back(in.Reorder2Default());
- } else {
- ret.push_back(in);
- }
- }
- return ret;
-}
-
typedef std::shared_ptr<mkldnn::memory> mkldnn_mem_ptr;
typedef std::shared_ptr<const mkldnn::memory> mkldnn_mem_const_ptr;
@@ -662,6 +649,13 @@ struct MKLDNNPostEltwiseParam {
float beta = 1.f;
};
+void MKLDNNRun(mxnet::FComputeEx fn,
+ const nnvm::NodeAttrs &attrs,
+ const mxnet::OpContext &ctx,
+ const std::vector<mxnet::NDArray> &inputs_,
+ const std::vector<mxnet::OpReqType> &req,
+ const std::vector<mxnet::NDArray> &outputs_);
+
} // namespace mxnet
#endif
#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BASE_INL_H_
diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc
index 240e366..6a6e3ee 100644
--- a/src/operator/nn/mkldnn/mkldnn_base.cc
+++ b/src/operator/nn/mkldnn/mkldnn_base.cc
@@ -570,6 +570,33 @@ bool MKLDNNStorageType(const nnvm::NodeAttrs &attrs,
return dispatched;
}
+inline static const std::vector<NDArray> GetMKLDNNInputArray(const std::vector<NDArray> &inputs) {
+ std::vector<NDArray> ret;
+ ret.reserve(inputs.size());
+ for (const auto &in : inputs) {
+ if (in.IsView() && in.IsMKLDNNData()) {
+ ret.push_back(in.Reorder2Default());
+ } else {
+ ret.push_back(in);
+ }
+ }
+ return ret;
+}
+
+void MKLDNNRun(mxnet::FComputeEx fn,
+ const nnvm::NodeAttrs &attrs,
+ const mxnet::OpContext &ctx,
+ const std::vector<mxnet::NDArray> &inputs,
+ const std::vector<mxnet::OpReqType> &req,
+ const std::vector<mxnet::NDArray> &outputs) {
+ if (CheckMKLDNNInputArrayIsView(inputs)) {
+ const auto mkldnn_inputs = GetMKLDNNInputArray(inputs);
+ fn(attrs, ctx, mkldnn_inputs, req, outputs);
+ } else {
+ fn(attrs, ctx, inputs, req, outputs);
+ }
+}
+
} // namespace mxnet
#endif
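GetMKLDNNInputArray becomes file-local here because MKLDNNRun is now its only caller. The RegisterPrimArgs calls in the deconvolution hunks below reflect the central change of the MKL-DNN 1.0 execution model that this port adopts: primitives are stateless and receive their memory objects through an argument map at execution time. A self-contained illustration using a reorder, the simplest primitive (assuming an MKL-DNN 1.0 installation, linked with -lmkldnn):

#include <mkldnn.hpp>

int main() {
  using namespace mkldnn;
  engine eng(engine::kind::cpu, 0);
  stream s(eng);
  memory::desc md({1, 8, 4, 4}, memory::data_type::f32,
                  memory::format_tag::nchw);
  memory src(md, eng), dst(md, eng);  // library-allocated buffers
  reorder r(src, dst);                // primitive holds no data pointers
  // Memory objects are bound at execution time through the argument map,
  // just like the net_args maps passed to RegisterPrimArgs in this patch.
  r.execute(s, {{MKLDNN_ARG_FROM, src}, {MKLDNN_ARG_TO, dst}});
  s.wait();
  return 0;
}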
diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
index 02a7368..8aaaa5c 100644
--- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
@@ -20,34 +20,33 @@
/*!
* \file mkldnn_deconvolution.cc
* \brief
- * \author Da Zheng, Rong Zhang (rong.a.zhang@intel.com)
-*/
+ */
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
#include "../deconvolution-inl.h"
-#include "./mkldnn_ops-inl.h"
#include "./mkldnn_base-inl.h"
+#include "./mkldnn_ops-inl.h"
namespace mxnet {
namespace op {
-bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input) {
- if (params.kernel.ndim() != 2)
- return false;
+bool SupportMKLDNNDeconv(const DeconvolutionParam &params,
+ const NDArray &input) {
+ if (params.kernel.ndim() != 2) return false;
return input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 4;
}
static inline mkldnn::memory::desc GetBiasDesc(mkldnn::memory::desc md) {
mkldnn::memory::dims dims(1);
- // This is convolution on 4D data. The second dimension is the channel.
+ // This is deconvolution on 4D data. The second dimension is the channel.
dims[0] = md.data.dims[1];
- return mkldnn::memory::desc(dims,
- static_cast<mkldnn::memory::data_type>(md.data.data_type),
- mkldnn::memory::format::any);
+ return mkldnn::memory::desc(
+ dims, static_cast<mkldnn::memory::data_type>(md.data.data_type),
+ mkldnn::memory::format_tag::any);
}
-static mkldnn::convolution_forward::primitive_desc GetDeconvBwd_(
+std::shared_ptr<mkldnn::convolution_forward::primitive_desc> GetDeconvBwd_(
const mkldnn::memory::desc &data_md, const mkldnn::memory::desc &weights_md,
bool has_bias, const mkldnn::memory::desc &out_md,
const mkldnn::engine &engine, const mkldnn::memory::dims &strides,
@@ -58,34 +57,40 @@ static mkldnn::convolution_forward::primitive_desc GetDeconvBwd_(
// memory size may smaller than what MKL-DNN kernels require. So here we need
// select suboptimal kernel for computation according to tensor sizes.
if (!has_bias) {
- mkldnn::convolution_forward::desc desc(mkldnn::prop_kind::forward_training,
- mkldnn::algorithm::convolution_direct, out_md, weights_md, data_md, strides,
- dilates, padding, padding, mkldnn::padding_kind::zero);
- auto deconv_pd = mkldnn::convolution_forward::primitive_desc(desc, engine);
- while (deconv_pd.dst_primitive_desc().get_size() != GetMemDescSize(data_md) ||
- deconv_pd.src_primitive_desc().get_size() != GetMemDescSize(out_md) ||
- deconv_pd.weights_primitive_desc().get_size() != GetMemDescSize(weights_md)) {
- CHECK(deconv_pd.next_impl()) << "No implementation";
+ mkldnn::convolution_forward::desc desc(
+ mkldnn::prop_kind::forward_training,
+ mkldnn::algorithm::convolution_direct, out_md, weights_md, data_md,
+ strides, dilates, padding, padding);
+ auto deconv_pd =
+ std::make_shared<mkldnn::convolution_forward::primitive_desc>(desc,
+ engine);
+ while (deconv_pd->dst_desc().get_size() != GetMemDescSize(data_md) ||
+ deconv_pd->src_desc().get_size() != GetMemDescSize(out_md) ||
+ deconv_pd->weights_desc().get_size() != GetMemDescSize(weights_md)) {
+ CHECK(deconv_pd->next_impl()) << "No implementation";
}
return deconv_pd;
} else {
auto bias_md = GetBiasDesc(data_md);
- mkldnn::convolution_forward::desc desc(mkldnn::prop_kind::forward_training,
+ mkldnn::convolution_forward::desc desc(
+ mkldnn::prop_kind::forward_training,
mkldnn::algorithm::convolution_direct, out_md, weights_md, bias_md,
- data_md, strides, dilates, padding, padding, mkldnn::padding_kind::zero);
- auto deconv_pd = mkldnn::convolution_forward::primitive_desc(desc, engine);
- while (deconv_pd.dst_primitive_desc().get_size() != GetMemDescSize(data_md) ||
- deconv_pd.src_primitive_desc().get_size() != GetMemDescSize(out_md) ||
- deconv_pd.weights_primitive_desc().get_size() != GetMemDescSize(weights_md)) {
- CHECK(deconv_pd.next_impl()) << "No implementation";
+ data_md, strides, dilates, padding, padding);
+ auto deconv_pd =
+ std::make_shared<mkldnn::convolution_forward::primitive_desc>(desc,
+ engine);
+ while (deconv_pd->dst_desc().get_size() != GetMemDescSize(data_md) ||
+ deconv_pd->src_desc().get_size() != GetMemDescSize(out_md) ||
+ deconv_pd->weights_desc().get_size() != GetMemDescSize(weights_md)) {
+ CHECK(deconv_pd->next_impl()) << "No implementation";
}
return deconv_pd;
}
}
-static mkldnn::convolution_backward_data::primitive_desc GetDeconvFwdImpl(
- const DeconvolutionParam& param, const NDArray &data, const NDArray &weights,
- bool has_bias, const NDArray &output) {
+std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc>
+GetDeconvFwdImpl(const DeconvolutionParam &param, const NDArray &data,
+ const NDArray &weights, bool has_bias, const NDArray &output) {
auto data_md = GetMemDesc(data);
auto weight_md = GetWeightDesc(weights, param.num_group);
auto out_md = GetMemDesc(output);
@@ -103,27 +108,30 @@ static mkldnn::convolution_backward_data::primitive_desc GetDeconvFwdImpl(
dilate[0] = param.dilate[0] - 1;
dilate[1] = param.dilate[1] - 1;
auto bwd_pd = GetDeconvBwd_(data_md, weight_md, has_bias, out_md, engine,
- strides, padding, dilate);
- mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct,
- out_md, weight_md, data_md, strides, dilate, padding, padding,
- mkldnn::padding_kind::zero);
- auto deconv_pd = mkldnn::convolution_backward_data::primitive_desc(desc, engine, bwd_pd);
+ strides, padding, dilate);
+ mkldnn::convolution_backward_data::desc desc(
+ mkldnn::algorithm::convolution_direct, out_md, weight_md, data_md,
+ strides, dilate, padding, padding);
+ auto deconv_pd =
+ std::make_shared<mkldnn::convolution_backward_data::primitive_desc>(
+ desc, engine, *bwd_pd);
// MKL-DNN introduced padded formats since 0.15 which require more memory
// for computation compared with the actual tensor size. Currently, MKL-DNN
// operators are still reusing those memory from memory planning and the
// memory size may smaller than what MKL-DNN kernels require. So here we need
// select suboptimal kernel for computation according to tensor sizes.
- while (deconv_pd.diff_dst_primitive_desc().get_size() != GetMemDescSize(data_md) ||
- deconv_pd.diff_src_primitive_desc().get_size() != GetMemDescSize(out_md) ||
- deconv_pd.weights_primitive_desc().get_size() != GetMemDescSize(weight_md)) {
- CHECK(deconv_pd.next_impl()) << "No implementation";
+ while (deconv_pd->diff_dst_desc().get_size() != GetMemDescSize(data_md) ||
+ deconv_pd->diff_src_desc().get_size() != GetMemDescSize(out_md) ||
+ deconv_pd->weights_desc().get_size() != GetMemDescSize(weight_md)) {
+ CHECK(deconv_pd->next_impl()) << "No implementation";
}
return deconv_pd;
}
-static mkldnn::convolution_forward::primitive_desc GetDeconvBwdDataImpl(
- const DeconvolutionParam &param, const NDArray &data,
- const NDArray &weights, bool has_bias, const NDArray &output) {
+std::shared_ptr<mkldnn::convolution_forward::primitive_desc>
+GetDeconvBwdDataImpl(const DeconvolutionParam &param, const NDArray &data,
+ const NDArray &weights, bool has_bias,
+ const NDArray &output) {
auto data_md = GetMemDesc(data);
auto weight_md = GetWeightDesc(weights, param.num_group);
auto out_md = GetMemDesc(output);
@@ -140,11 +148,11 @@ static mkldnn::convolution_forward::primitive_desc GetDeconvBwdDataImpl(
mkldnn::memory::dims dilate{0, 0};
dilate[0] = param.dilate[0] - 1;
dilate[1] = param.dilate[1] - 1;
- return GetDeconvBwd_(data_md, weight_md, has_bias, out_md, engine,
- strides, padding, dilate);
+ return GetDeconvBwd_(data_md, weight_md, has_bias, out_md, engine, strides,
+ padding, dilate);
}
-static mkldnn::convolution_backward_weights::primitive_desc
+std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc>
GetDeconvBwdWeightsImpl(
const DeconvolutionParam &param, const NDArray &data,
const NDArray &weights, bool has_bias, const NDArray &output,
@@ -172,125 +180,64 @@ GetDeconvBwdWeightsImpl(
// memory size may smaller than what MKL-DNN kernels require. So here we need
// select suboptimal kernel for computation according to tensor sizes.
if (!has_bias) {
- mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct,
- out_md, weight_md, data_md, strides, dilate, padding, padding, mkldnn::padding_kind::zero);
- auto deconv_pd = mkldnn::convolution_backward_weights::primitive_desc(desc, engine, fwd_pd);
- while (deconv_pd.diff_dst_primitive_desc().get_size() != GetMemDescSize(data_md) ||
- deconv_pd.src_primitive_desc().get_size() != GetMemDescSize(out_md) ||
- deconv_pd.diff_weights_primitive_desc().get_size() != GetMemDescSize(weight_md)) {
- CHECK(deconv_pd.next_impl()) << "No implementation";
+ mkldnn::convolution_backward_weights::desc desc(
+ mkldnn::algorithm::convolution_direct, out_md, weight_md, data_md,
+ strides, dilate, padding, padding);
+ auto deconv_pd =
+ std::make_shared<mkldnn::convolution_backward_weights::primitive_desc>(
+ desc, engine, fwd_pd);
+ while (deconv_pd->diff_dst_desc().get_size() != GetMemDescSize(data_md) ||
+ deconv_pd->src_desc().get_size() != GetMemDescSize(out_md) ||
+ deconv_pd->diff_weights_desc().get_size() !=
+ GetMemDescSize(weight_md)) {
+ CHECK(deconv_pd->next_impl()) << "No implementation";
}
return deconv_pd;
} else {
auto bias_md = GetBiasDesc(data_md);
- mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct,
- out_md, weight_md, bias_md, data_md, strides, dilate, padding, padding,
- mkldnn::padding_kind::zero);
- auto deconv_pd = mkldnn::convolution_backward_weights::primitive_desc(desc, engine, fwd_pd);
- while (deconv_pd.diff_dst_primitive_desc().get_size() != GetMemDescSize(data_md) ||
- deconv_pd.src_primitive_desc().get_size() != GetMemDescSize(out_md) ||
- deconv_pd.diff_weights_primitive_desc().get_size() != GetMemDescSize(weight_md)) {
- CHECK(deconv_pd.next_impl()) << "No implementation";
+ mkldnn::convolution_backward_weights::desc desc(
+ mkldnn::algorithm::convolution_direct, out_md, weight_md, bias_md,
+ data_md, strides, dilate, padding, padding);
+ auto deconv_pd =
+ std::make_shared<mkldnn::convolution_backward_weights::primitive_desc>(
+ desc, engine, fwd_pd);
+ while (deconv_pd->diff_dst_desc().get_size() != GetMemDescSize(data_md) ||
+ deconv_pd->src_desc().get_size() != GetMemDescSize(out_md) ||
+ deconv_pd->diff_weights_desc().get_size() !=
+ GetMemDescSize(weight_md)) {
+ CHECK(deconv_pd->next_impl()) << "No implementation";
}
return deconv_pd;
}
}
class MKLDNNDeconvForward {
- std::shared_ptr<mkldnn::convolution_backward_data> fwd;
- std::shared_ptr<mkldnn::memory> data;
- std::shared_ptr<mkldnn::memory> weight;
- std::shared_ptr<mkldnn::memory> bias;
- std::shared_ptr<mkldnn::memory> out;
- OutDataOp data_op;
-
public:
- MKLDNNDeconvForward(const DeconvolutionParam& param,
- const NDArray &data,
- const NDArray &weights,
- bool has_bias,
+ MKLDNNDeconvForward(const DeconvolutionParam &param, const NDArray &data,
+ const NDArray &weights, bool has_bias,
const NDArray &output);
- void SetDataHandle(const DeconvolutionParam& param,
- const OpContext &ctx,
- const NDArray &in_data,
- const NDArray &weight,
- const std::vector<OpReqType> &req,
- const std::vector<NDArray> &out_data);
+ const mkldnn::convolution_backward_data &GetFwd() const { return *fwd; }
- void Execute(const std::vector<NDArray> &out_data);
+ const mkldnn::convolution_backward_data::primitive_desc &GetPd() const {
+ return *fwd_pd;
+ }
private:
- mkldnn::convolution_backward_data::primitive_desc fwd_pd;
+ std::shared_ptr<mkldnn::convolution_backward_data> fwd;
+ std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc> fwd_pd;
}; // class MKLDNNDeconvForward
-MKLDNNDeconvForward::MKLDNNDeconvForward(const DeconvolutionParam& param,
- const NDArray &data,
- const NDArray &weights,
- bool has_bias,
- const NDArray &output)
- :fwd_pd(GetDeconvFwdImpl(param, data, weights, has_bias, output)) {
- this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
- fwd_pd.diff_dst_primitive_desc()));
- this->weight = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
- fwd_pd.weights_primitive_desc()));
- this->out = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
- fwd_pd.diff_src_primitive_desc()));
- this->fwd = std::shared_ptr<mkldnn::convolution_backward_data>(
- new mkldnn::convolution_backward_data(fwd_pd,
- mkldnn::primitive::at(*this->data),
- mkldnn::primitive::at(*this->weight),
- *this->out));
+MKLDNNDeconvForward::MKLDNNDeconvForward(const DeconvolutionParam &param,
+ const NDArray &data,
+ const NDArray &weights, bool has_bias,
+ const NDArray &output)
+ : fwd_pd(GetDeconvFwdImpl(param, data, weights, has_bias, output)) {
+ fwd = std::make_shared<mkldnn::convolution_backward_data>(GetPd());
}
-void MKLDNNDeconvForward::SetDataHandle(const DeconvolutionParam& param,
- const OpContext &ctx,
- const NDArray &in_data,
- const NDArray &weight,
- const std::vector<OpReqType> &req,
- const std::vector<NDArray> &out_data) {
- auto data_mem = in_data.GetMKLDNNDataReorder(
- fwd_pd.diff_dst_primitive_desc());
- const mkldnn::memory *weight_mem;
- if (ctx.is_train) {
- // TODO(zhengda) kvstore doesn't handle MKLDNN correctly. Let's reorder it
- // to the default format for now.
- if (weight.IsMKLDNNData())
- // This asks the engine to reorder data after the weight array is used.
- const_cast<NDArray&>(weight).Reorder2DefaultAsync();
- weight_mem = GetWeights(weight, fwd_pd.weights_primitive_desc(), param.num_group);
- } else {
- // For inference, we want to reorder the weight array so we don't need to
- // reorder data every time.
- if (weight.IsDefaultData()) {
- // We also need to modify the layout on the original weight array.
- // Don't switch below sequence because naive engine will executes
- // pushAsync synchronously.
- const_cast<NDArray&>(weight).MKLDNNDataReorderAsync(fwd_pd.weights_primitive_desc());
- weight_mem = GetWeights(weight, fwd_pd.weights_primitive_desc(), param.num_group);
- } else {
- weight_mem = weight.GetMKLDNNData();
- CHECK(weight_mem->get_primitive_desc() == fwd_pd.weights_primitive_desc());
- }
- }
- auto out_mem = CreateMKLDNNMem(out_data[deconv::kOut],
- fwd_pd.diff_src_primitive_desc(), req[deconv::kOut]);
- auto output = out_mem.second;
- this->data->set_data_handle(data_mem->get_data_handle());
- this->weight->set_data_handle(weight_mem->get_data_handle());
- this->out->set_data_handle(output->get_data_handle());
- this->data_op = out_mem.first;
-}
-
-void MKLDNNDeconvForward::Execute(const std::vector<NDArray> &out_data) {
- MKLDNNStream::Get()->RegisterPrim(*fwd);
- CommitOutput(out_data[deconv::kOut], mkldnn_output_t(this->data_op, this->out.get()));
- MKLDNNStream::Get()->Submit();
-}
-
-static void MKLDNNDeconvFwdBiasPostProcess(const DeconvolutionParam& param,
- const OpContext &ctx,
- const NDArray &bias,
- const std::vector<NDArray> &out_data) {
+static void MKLDNNDeconvFwdBiasPostProcess(
+ const DeconvolutionParam &param, const OpContext &ctx, const NDArray &bias,
+ const std::vector<NDArray> &out_data) {
// add bias, broadcast bias to dim 1: channel
if (!param.no_bias) {
// MKLDNN only supports float right now.
@@ -306,18 +253,19 @@ static void MKLDNNDeconvFwdBiasPostProcess(const DeconvolutionParam& param,
}
}
-static inline MKLDNNDeconvForward &GetDeconvFwd(
- const nnvm::NodeAttrs& attrs, const NDArray &data,
- const NDArray &weights, const NDArray *bias,
- const NDArray &output) {
+MKLDNNDeconvForward &GetDeconvFwd(const nnvm::NodeAttrs &attrs,
+ const NDArray &data, const NDArray &weights,
+ const NDArray *bias, const NDArray &output) {
#if DMLC_CXX11_THREAD_LOCAL
- static thread_local
- std::unordered_map<DeconvSignature, MKLDNNDeconvForward, OpHash> fwds;
+ static thread_local std::unordered_map<DeconvSignature, MKLDNNDeconvForward,
+ OpHash>
+ fwds;
#else
static MX_THREAD_LOCAL
- std::unordered_map<DeconvSignature, MKLDNNDeconvForward, OpHash> fwds;
+ std::unordered_map<DeconvSignature, MKLDNNDeconvForward, OpHash>
+ fwds;
#endif
- const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);
+ const DeconvolutionParam &param = nnvm::get<DeconvolutionParam>(attrs.parsed);
DeconvSignature key(param);
// Here we can sign the conv op with NDArray because conv primitive will
// decide the right layout for the, so we only need to get the shape and the
@@ -325,82 +273,95 @@ static inline MKLDNNDeconvForward &GetDeconvFwd(
key.AddSign(data);
key.AddSign(weights);
key.AddSign(output);
- if (bias)
- key.AddSign(*bias);
+ if (bias) key.AddSign(*bias);
auto it = fwds.find(key);
if (it == fwds.end()) {
bool has_bias = (bias != nullptr);
- MKLDNNDeconvForward fwd(param, data, weights, has_bias, output);
+ auto fwd = MKLDNNDeconvForward(param, data, weights, has_bias, output);
it = AddToCache(&fwds, key, fwd);
}
return it->second;
}
-void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
const std::vector<NDArray> &in_data,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data) {
TmpMemMgr::Get()->Init(ctx.requested[deconv::kTempSpace]);
- const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);
+ const DeconvolutionParam &param = nnvm::get<DeconvolutionParam>(attrs.parsed);
- auto data = in_data[deconv::kData];
- if (data.IsView() && data.IsMKLDNNData())
- data = data.Reorder2Default();
+ auto &data = in_data[deconv::kData];
+ auto &weight = in_data[deconv::kWeight];
+ const NDArray *bias = param.no_bias ? nullptr : &in_data[deconv::kBias];
- auto weight = in_data[deconv::kWeight];
- if (weight.IsView() && weight.IsMKLDNNData())
- weight = weight.Reorder2Default();
+ MKLDNNDeconvForward &fwd =
+ GetDeconvFwd(attrs, data, weight, bias, out_data[deconv::kOut]);
- const NDArray* bias = param.no_bias ? nullptr : &in_data[deconv::kBias];
+ auto data_mem = data.GetMKLDNNDataReorder(fwd.GetPd().diff_dst_desc());
+ const mkldnn::memory *weight_mem;
+ if (ctx.is_train) {
+ // TODO(zhengda) kvstore doesn't handle MKLDNN correctly. Let's reorder it
+ // to the default format for now.
+ if (weight.IsMKLDNNData())
+ // This asks the engine to change the layout of the weight array after
+ // it's used.
+ weight.Reorder2DefaultAsync();
+ weight_mem =
+ GetWeights(weight, fwd.GetPd().weights_desc(), param.num_group);
+ } else {
+ // For inference, we want to reorder the weight array so we don't need to
+ // reorder data every time.
+ if (weight.IsDefaultData()) {
+ // We also need to modify the layout on the original weight array. The
+ // data conversion happens after the weight array is used.
+ weight.MKLDNNDataReorderAsync(fwd.GetPd().weights_desc());
+ weight_mem =
+ GetWeights(weight, fwd.GetPd().weights_desc(), param.num_group);
- MKLDNNDeconvForward &deconvFwd = GetDeconvFwd(
- attrs, data, weight, bias, out_data[deconv::kOut]);
+ } else {
+ weight_mem = weight.GetMKLDNNData();
+ CHECK(weight_mem->get_desc() == fwd.GetPd().weights_desc());
+ }
+ }
+ mkldnn_output_t out_mem;
+ out_mem = CreateMKLDNNMem(out_data[deconv::kOut], fwd.GetPd().diff_src_desc(),
+ req[deconv::kOut]);
- deconvFwd.SetDataHandle(param, ctx, data, weight, req, out_data);
+ mkldnn_args_map_t net_args;
- deconvFwd.Execute(out_data);
+ net_args.insert({MKLDNN_ARG_DIFF_DST, *data_mem});
+ net_args.insert({MKLDNN_ARG_WEIGHTS, *weight_mem});
+ net_args.insert({MKLDNN_ARG_DIFF_SRC, *out_mem.second});
+ MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
+ CommitOutput(out_data[deconv::kOut], out_mem);
+ MKLDNNStream::Get()->Submit();
MKLDNNDeconvFwdBiasPostProcess(param, ctx, *bias, out_data);
}
class MKLDNNDeconvBackwardData {
std::shared_ptr<mkldnn::convolution_forward> bwd;
- std::shared_ptr<mkldnn::memory> data;
- std::shared_ptr<mkldnn::memory> weight;
- std::shared_ptr<mkldnn::memory> out;
public:
- const mkldnn::convolution_forward::primitive_desc pd;
-
+ std::shared_ptr<mkldnn::convolution_forward::primitive_desc> bwd_pd;
MKLDNNDeconvBackwardData(const DeconvolutionParam &param, const NDArray &data,
- const NDArray &weights, const NDArray &output)
- : pd(GetDeconvBwdDataImpl(param, data, weights, false, output)) {
- }
-
- void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
- const mkldnn::memory &output) {
- if (bwd == nullptr) {
- this->data = std::shared_ptr<mkldnn::memory>(
- new mkldnn::memory(pd.src_primitive_desc(), data.get_data_handle()));
- this->weight = std::shared_ptr<mkldnn::memory>(
- new mkldnn::memory(pd.weights_primitive_desc(), weight.get_data_handle()));
- this->out = std::shared_ptr<mkldnn::memory>(
- new mkldnn::memory(pd.dst_primitive_desc(), output.get_data_handle()));
- bwd = std::shared_ptr<mkldnn::convolution_forward>(
- new mkldnn::convolution_forward(pd, mkldnn::primitive::at(*this->data),
- mkldnn::primitive::at(*this->weight),
- *this->out));
- } else {
- this->data->set_data_handle(data.get_data_handle());
- this->weight->set_data_handle(weight.get_data_handle());
- this->out->set_data_handle(output.get_data_handle());
- }
- }
+ const NDArray &weights, const NDArray &output);
const mkldnn::convolution_forward &GetBwd() const { return *bwd; }
+ const mkldnn::convolution_forward::primitive_desc &GetDataPd() const {
+ return *bwd_pd;
+ }
};
+MKLDNNDeconvBackwardData::MKLDNNDeconvBackwardData(
+ const DeconvolutionParam &param, const NDArray &data,
+ const NDArray &weights, const NDArray &output)
+ : bwd_pd(GetDeconvBwdDataImpl(param, data, weights, false, output)) {
+ bwd = std::make_shared<mkldnn::convolution_forward>(GetDataPd());
+}
+
typedef ParamOpSign<DeconvolutionParam> MKLDNNDeconvSignature;
static inline MKLDNNDeconvBackwardData &GetDeconvBwdData(
@@ -425,7 +386,7 @@ static inline MKLDNNDeconvBackwardData &GetDeconvBwdData(
auto it = bwds.find(key);
if (it == bwds.end()) {
- MKLDNNDeconvBackwardData bwd(param, data, weights, output);
+ auto bwd = MKLDNNDeconvBackwardData(param, data, weights, output);
it = AddToCache(&bwds, key, bwd);
}
return it->second;
@@ -433,44 +394,30 @@ static inline MKLDNNDeconvBackwardData &GetDeconvBwdData(
class MKLDNNDeconvBackwardWeights {
std::shared_ptr<mkldnn::convolution_backward_weights> bwd;
- std::shared_ptr<mkldnn::memory> data;
- std::shared_ptr<mkldnn::memory> weight;
- std::shared_ptr<mkldnn::memory> out;
public:
- const mkldnn::convolution_backward_weights::primitive_desc pd;
-
+ std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc>
+ bwd_data_pd;
MKLDNNDeconvBackwardWeights(
const DeconvolutionParam &param, const NDArray &data,
const NDArray &weights, const NDArray &output,
- const mkldnn::convolution_forward::primitive_desc &bwd_data_pd)
- : pd(GetDeconvBwdWeightsImpl(param, data, weights, false, output,
- bwd_data_pd)) {}
-
- void SetNewMem(
- const mkldnn::memory &data, const mkldnn::memory &weight,
- const mkldnn::memory &output,
- const mkldnn::convolution_forward::primitive_desc &bwd_data_pd) {
- if (bwd == nullptr) {
- this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
- bwd_data_pd.src_primitive_desc(), data.get_data_handle()));
- this->weight = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
- bwd_data_pd.weights_primitive_desc(), weight.get_data_handle()));
- this->out = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
- bwd_data_pd.dst_primitive_desc(), output.get_data_handle()));
- bwd = std::shared_ptr<mkldnn::convolution_backward_weights>(
- new mkldnn::convolution_backward_weights(pd, *this->data,
- *this->weight, *this->out));
- } else {
- this->data->set_data_handle(data.get_data_handle());
- this->weight->set_data_handle(weight.get_data_handle());
- this->out->set_data_handle(output.get_data_handle());
- }
- }
-
+ const mkldnn::convolution_forward::primitive_desc &bwd_data_pd);
const mkldnn::convolution_backward_weights &GetBwd() const { return *bwd; }
+ const mkldnn::convolution_backward_weights::primitive_desc &GetWeightsPd()
+ const {
+ return *bwd_data_pd;
+ }
};
+MKLDNNDeconvBackwardWeights::MKLDNNDeconvBackwardWeights(
+ const DeconvolutionParam &param, const NDArray &data,
+ const NDArray &weights, const NDArray &output,
+ const mkldnn::convolution_forward::primitive_desc &bwd_data_pd)
+ : bwd_data_pd(GetDeconvBwdWeightsImpl(param, data, weights, false, output,
+ bwd_data_pd)) {
+ bwd = std::make_shared<mkldnn::convolution_backward_weights>(GetWeightsPd());
+}
+
static inline MKLDNNDeconvBackwardWeights &GetDeconvBwdWeights(
const DeconvolutionParam &param, const NDArray &data,
const NDArray &weights, const NDArray &output,
@@ -494,7 +441,8 @@ static inline MKLDNNDeconvBackwardWeights &GetDeconvBwdWeights(
auto it = bwds.find(key);
if (it == bwds.end()) {
- MKLDNNDeconvBackwardWeights bwd(param, data, weights, output, bwd_data_pd);
+ auto bwd =
+ MKLDNNDeconvBackwardWeights(param, data, weights, output, bwd_data_pd);
auto ins_ret = bwds.insert(
std::pair<MKLDNNDeconvSignature, MKLDNNDeconvBackwardWeights>(key,
bwd));
@@ -513,47 +461,50 @@ void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs &attrs,
const std::vector<NDArray> &in_grad = outputs;
const DeconvolutionParam &param = nnvm::get<DeconvolutionParam>(attrs.parsed);
- auto data = inputs[deconv::kData + 1];
- if (data.IsView() && data.IsMKLDNNData())
- data = data.Reorder2Default();
-
- auto weight = inputs[deconv::kWeight + 1];
- if (weight.IsView() && weight.IsMKLDNNData())
- weight = weight.Reorder2Default();
+ auto &data = inputs[deconv::kData + 1];
+ auto &weight = inputs[deconv::kWeight + 1];
+ auto &out_grad = inputs[deconv::kOut];
CHECK_NE(req[deconv::kWeight], kWriteInplace)
<< "cannot write weight inplace";
MKLDNNDeconvBackwardData &bwd_data =
GetDeconvBwdData(param, data, weight, inputs[deconv::kOut]);
- auto out_grad_mem = inputs[deconv::kOut].GetMKLDNNDataReorder(
- bwd_data.pd.src_primitive_desc());
+ auto out_grad_mem =
+ out_grad.GetMKLDNNDataReorder(bwd_data.GetDataPd().src_desc());
if (req[deconv::kData]) {
- auto weight_mem =
- GetWeights(weight, bwd_data.pd.weights_primitive_desc(), param.num_group);
+ auto weight_mem = GetWeights(weight, bwd_data.GetDataPd().weights_desc(),
+ param.num_group);
auto in_grad_mem =
- CreateMKLDNNMem(in_grad[deconv::kData],
- bwd_data.pd.dst_primitive_desc(), req[deconv::kData]);
- bwd_data.SetNewMem(*out_grad_mem, *weight_mem, *in_grad_mem.second);
- MKLDNNStream::Get()->RegisterPrim(bwd_data.GetBwd());
+ CreateMKLDNNMem(in_grad[deconv::kData], bwd_data.GetDataPd().dst_desc(),
+ req[deconv::kData]);
+ mkldnn_args_map_t net_args = {{MKLDNN_ARG_SRC, *out_grad_mem},
+ {MKLDNN_ARG_WEIGHTS, *weight_mem},
+ {MKLDNN_ARG_DST, *in_grad_mem.second}};
+ MKLDNNStream::Get()->RegisterPrimArgs(bwd_data.GetBwd(), net_args);
CommitOutput(in_grad[deconv::kData], in_grad_mem);
}
if (req[deconv::kWeight]) {
MKLDNNDeconvBackwardWeights &bwd_weights = GetDeconvBwdWeights(
- param, data, weight,
- inputs[deconv::kOut], bwd_data.pd);
- if (bwd_data.pd.src_primitive_desc() != bwd_weights.pd.src_primitive_desc())
- out_grad_mem = inputs[deconv::kOut].GetMKLDNNDataReorder(
- bwd_weights.pd.src_primitive_desc());
- auto data_mem = data.GetMKLDNNDataReorder(
- bwd_weights.pd.diff_dst_primitive_desc());
+ param, data, weight, inputs[deconv::kOut], bwd_data.GetDataPd());
+ if (bwd_data.GetDataPd().src_desc() !=
+ bwd_weights.GetWeightsPd().src_desc())
+ out_grad_mem =
+ out_grad.GetMKLDNNDataReorder(bwd_weights.GetWeightsPd().src_desc());
+ auto data_mem =
+ data.GetMKLDNNDataReorder(bwd_weights.GetWeightsPd().diff_dst_desc());
auto in_grad_weight = CreateMKLDNNWeightGrad(
- in_grad[deconv::kWeight], bwd_weights.pd.diff_weights_primitive_desc(),
- req[deconv::kWeight]);
- bwd_weights.SetNewMem(*out_grad_mem, *data_mem, *in_grad_weight.second, bwd_data.pd);
- MKLDNNStream::Get()->RegisterPrim(bwd_weights.GetBwd());
+ in_grad[deconv::kWeight],
+ bwd_weights.GetWeightsPd().diff_weights_desc(), req[deconv::kWeight]);
+
+ mkldnn_args_map_t net_args = {
+ {MKLDNN_ARG_SRC, *out_grad_mem},
+ {MKLDNN_ARG_DIFF_DST, *data_mem},
+ {MKLDNN_ARG_DIFF_WEIGHTS, *in_grad_weight.second}};
+ MKLDNNStream::Get()->RegisterPrimArgs(bwd_weights.GetBwd(), net_args);
CommitOutput(in_grad[deconv::kWeight], in_grad_weight);
}
MKLDNNStream::Get()->Submit();
+
if (!param.no_bias) {
typedef float DType;
Stream<cpu> *s = ctx.get_stream<cpu>();
@@ -573,5 +524,4 @@ void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs &attrs,
} // namespace op
} // namespace mxnet
-
-#endif // MXNET_USE_MKLDNN == 1
+#endif // MXNET_USE_MKLDNN == 100
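This file keeps the operator's long-standing construction: the deconvolution forward pass is computed as convolution_backward_data (a transposed convolution is the gradient of the mirrored convolution), with a convolution_forward primitive descriptor, built with source and destination swapped, serving as the hint, while the while (...->next_impl()) loops skip padded-format implementations whose memory needs exceed what MXNet's memory planning provided. A rough standalone sketch of that construction with illustrative shapes, assuming MKL-DNN 1.0 (a 4x4-kernel, stride-2, pad-1 deconv mapping 1x16x8x8 to 1x8x16x16):

#include <mkldnn.hpp>

int main() {
  using namespace mkldnn;
  engine eng(engine::kind::cpu, 0);
  stream s(eng);
  auto dt = memory::data_type::f32;
  auto any = memory::format_tag::any;
  // The deconv is phrased as the *backward data* pass of the mirrored
  // convolution that maps 1x8x16x16 -> 1x16x8x8.
  memory::desc conv_src({1, 8, 16, 16}, dt, any);   // = deconv output
  memory::desc conv_wei({16, 8, 4, 4}, dt, any);
  memory::desc conv_dst({1, 16, 8, 8}, dt, any);    // = deconv input
  memory::dims strides{2, 2}, padding{1, 1};
  convolution_forward::desc fwd_d(prop_kind::forward_training,
      algorithm::convolution_direct, conv_src, conv_wei, conv_dst,
      strides, padding, padding);
  convolution_forward::primitive_desc hint(fwd_d, eng);
  convolution_backward_data::desc bwd_d(algorithm::convolution_direct,
      conv_src, conv_wei, conv_dst, strides, padding, padding);
  convolution_backward_data::primitive_desc pd(bwd_d, eng, hint);
  convolution_backward_data deconv(pd);
  // diff_dst plays the role of the deconv input, diff_src of its output.
  memory in(pd.diff_dst_desc(), eng);
  memory wei(pd.weights_desc(), eng);
  memory out(pd.diff_src_desc(), eng);
  deconv.execute(s, {{MKLDNN_ARG_DIFF_DST, in},
                     {MKLDNN_ARG_WEIGHTS, wei},
                     {MKLDNN_ARG_DIFF_SRC, out}});
  s.wait();
  return 0;
}

Correspondingly, the data gradient of the deconvolution is the mirrored convolution_forward itself (GetDeconvBwdDataImpl), and the weight gradient reuses the same forward descriptor as its hint (GetDeconvBwdWeightsImpl).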
diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
index 20d80cd..3713098 100644
--- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -44,16 +44,6 @@ namespace mxnet {
namespace op {
#if MXNET_USE_MKLDNN == 1
-/* For deconvolution */
-void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
- const std::vector<NDArray> &in_data,
- const std::vector<OpReqType> &req,
- const std::vector<NDArray> &out_data);
-void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs);
-
/* For softmax_output */
void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray> &in_data,
@@ -114,6 +104,16 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs);
+/* For deconvolution */
+void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+ const std::vector<NDArray> &in_data,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &out_data);
+void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs);
+
/* For activation */
void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const NDArray &in_data, const OpReqType &req,