Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/09/13 05:50:36 UTC
[incubator-mxnet] branch master updated: MKLDNN Backward op cache (#11301)
This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 741635a MKLDNN Backward op cache (#11301)
741635a is described below
commit 741635abdb51daad7d87e386423dd22fdd1eceda
Author: Zhennan Qin <zh...@intel.com>
AuthorDate: Thu Sep 13 13:50:24 2018 +0800
MKLDNN Backward op cache (#11301)
* Enable primitive allocation cache for _backward_Activation.
Change-Id: I545628ff68a54cb01b7fef323dc3de9bd47b1a19
* Enable primitive allocation cache for _backward_Deconvolution.
Change-Id: I1e9bf1b9b44bae52068a9c564dff037851e896e5
* Enable primitive allocation cache for _backward_Pooling.
Change-Id: Idbe94e21f1e2ddf711523767194b95beda19b120
* Enable primitive allocation cache for _backward_LRN.
Change-Id: Iefe9f720de719ec2e2f5d24a006602425136711b
* Enable primitive allocation cache for _backward_BatchNorm.
Change-Id: I9e52651bd830b8cb5d2f193076ef51606c9056f9
* Enable primitive allocation cache for _backward_Convolution
Change-Id: I0496fa2394ee036d05c58f3abc1d74af544c7bca
* Enable primitive allocation cache for _backward_Fully_Connected
Change-Id: I8347527ec1271b1518921a74e3581d7d84187429
* remove fc forward and fix indent problem
* remove fc forward and fix convolution indent problem
* Change log level to FATAL for unreachable code in mkldnn_act.cc
* remove fc forward and fix convolution indent problem
* remove useless hint in fc
* Merge branch 'master' into backward_op_cache
* Empty commit to retrigger the CI.
* Change LOG(INFO) to LOG(FATAL) for unreachable code in mkldnn_act.cc
* Fix build issue after code merge.
* Fix lint after merge
* Fix mkldnn act.
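
In short, every backward operator now keeps a per-thread cache of constructed
MKLDNN primitives, keyed by the operator parameters plus the shapes and dtypes
of the arrays involved; a cache hit only swaps the data handles before the
primitive is re-submitted. Below is a minimal, self-contained C++ sketch of
that pattern; the names (Signature, SignatureHash, CachedBwd, GetCachedBwd)
are illustrative stand-ins, not the MXNet identifiers used in the diff.

#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>

// Stand-in for ParamOpSign<...> plus the AddSign(...) calls, which fold the
// op parameters and each array's shape/dtype into the key.
struct Signature {
  std::string key;
  bool operator==(const Signature &o) const { return key == o.key; }
};

// Stand-in for OpHash.
struct SignatureHash {
  std::size_t operator()(const Signature &s) const {
    return std::hash<std::string>()(s.key);
  }
};

// Stand-in for MKLDNNActBackward / MKLDNNConvBackward / etc.: the expensive
// primitive is built once; later calls only refresh the data handles.
struct CachedBwd {
  void *data_handle = nullptr;
  void SetNewMem(void *handle) { data_handle = handle; }
};

static CachedBwd &GetCachedBwd(const Signature &key) {
  // One cache per thread, as the diff does via DMLC_CXX11_THREAD_LOCAL /
  // MX_THREAD_LOCAL.
  static thread_local std::unordered_map<Signature, CachedBwd, SignatureHash>
      bwds;
  auto it = bwds.find(key);
  if (it == bwds.end())
    it = bwds.emplace(key, CachedBwd{}).first;  // construct once per signature
  return it->second;  // cache hit: reuse, then SetNewMem() with fresh handles
}
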
---
src/operator/nn/fully_connected.cc | 3 +-
src/operator/nn/lrn.cc | 2 +
src/operator/nn/mkldnn/mkldnn_act.cc | 125 +++++++++++---
src/operator/nn/mkldnn/mkldnn_base-inl.h | 5 +
src/operator/nn/mkldnn/mkldnn_base.cc | 11 ++
src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h | 129 +++++++++-----
src/operator/nn/mkldnn/mkldnn_convolution.cc | 205 +++++++++++++++++++---
src/operator/nn/mkldnn/mkldnn_deconvolution.cc | 225 ++++++++++++++++++++-----
src/operator/nn/mkldnn/mkldnn_lrn-inl.h | 180 ++++++++++++++------
src/operator/nn/mkldnn/mkldnn_pooling-inl.h | 21 +++
src/operator/nn/mkldnn/mkldnn_pooling.cc | 186 +++++++++++++-------
11 files changed, 844 insertions(+), 248 deletions(-)
diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc
index d8a32f0..a178b27 100644
--- a/src/operator/nn/fully_connected.cc
+++ b/src/operator/nn/fully_connected.cc
@@ -213,10 +213,9 @@ inline static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs,
uint32_t out_expected = param.no_bias ? 2 : 3;
CHECK_EQ(in_attrs->size(), 3U);
CHECK_EQ(out_attrs->size(), out_expected);
-
- bool dispatched = false;
// TODO(zhengda) let's disable MKLDNN for FullyConnected for now.
// It seems there is a bug.
+ bool dispatched = false;
if (!dispatched && common::ContainsOnlyStorage(*in_attrs, mxnet::kDefaultStorage)) {
dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
dispatch_mode, DispatchMode::kFCompute);
diff --git a/src/operator/nn/lrn.cc b/src/operator/nn/lrn.cc
index a428eb1..99f0dc4 100644
--- a/src/operator/nn/lrn.cc
+++ b/src/operator/nn/lrn.cc
@@ -116,6 +116,8 @@ void LRNComputeExCPU(const nnvm::NodeAttrs &attrs,
MKLDNN_OPCHECK_INIT(false, 1, inputs, outputs);
MKLDNNLRNForward(ctx, param, inputs[0], req[0], outputs[0]);
MKLDNN_OPCHECK_RUN(LRNCompute<cpu>, attrs, ctx, inputs, req, outputs);
+ // Copy outputs[1] from opcheck reference as the backward check needs it.
+ MKLDNN_OPCHECK_COPY_RESULT(outputs, std::vector<size_t>{1});
return;
}
FallBackCompute(LRNCompute<cpu>, attrs, ctx, inputs, req, outputs);
diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc
index 744fed2..c914b38 100644
--- a/src/operator/nn/mkldnn/mkldnn_act.cc
+++ b/src/operator/nn/mkldnn/mkldnn_act.cc
@@ -84,7 +84,7 @@ static mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl(
alg, data_md, alpha);
return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine);
});
- LOG(INFO) << "Unsupported data type for MKLDNN activation";
+ LOG(FATAL) << "Unsupported data type for MKLDNN activation";
mkldnn::eltwise_forward::desc desc = mkldnn::eltwise_forward::desc(
mkldnn::prop_kind::forward_training, alg, data_md, 0.0);
return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine);
@@ -175,6 +175,100 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
stream->Submit();
}
+static mkldnn::eltwise_backward::primitive_desc GetActBwdDescImpl(
+    const ActivationParam &param, const mkldnn::memory &input_mem,
+ const mkldnn::memory &diff_dst_memory, int dtype) {
+ mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc();
+ mkldnn::memory::desc data_md = data_mpd.desc();
+ mkldnn::memory::desc diff_md = diff_dst_memory.get_primitive_desc().desc();
+ auto cpu_engine = data_mpd.get_engine();
+ auto alg = GetMKLDNNActAlgo(param);
+
+ MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+ DType alpha = 0;
+ mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
+ alg, data_md, alpha);
+ mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine);
+ mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, alpha);
+ mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine,
+ fw_pdesc);
+ return bw_pdesc;
+ });
+ LOG(FATAL) << "Unsupported data type for MKLDNN activation";
+ mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
+ alg, data_md, 0.0);
+ mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine);
+ mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, 0.0);
+ mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine,
+ fw_pdesc);
+ return bw_pdesc;
+}
+
+class MKLDNNActBackward {
+ std::shared_ptr<mkldnn::eltwise_backward> bwd;
+ std::shared_ptr<mkldnn::memory> data;
+ std::shared_ptr<mkldnn::memory> diff_dst_memory;
+ std::shared_ptr<mkldnn::memory> diff_src_memory;
+
+ public:
+ const mkldnn::eltwise_backward::primitive_desc pd;
+
+  explicit MKLDNNActBackward(const ActivationParam &param, const NDArray &data,
+ const mkldnn::memory &mem,
+ const mkldnn::memory &diff_dst_memory)
+ : pd(GetActBwdDescImpl(param, mem, diff_dst_memory, data.dtype())) {}
+
+ void SetNewMem(const mkldnn::memory &data,
+ const mkldnn::memory &diff_dst_memory,
+ const mkldnn::memory &diff_src_memory) {
+ if (this->bwd != nullptr) {
+ this->data->set_data_handle(data.get_data_handle());
+ this->diff_dst_memory->set_data_handle(diff_dst_memory.get_data_handle());
+ this->diff_src_memory->set_data_handle(diff_src_memory.get_data_handle());
+ } else {
+ this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+ data.get_primitive_desc(), data.get_data_handle()));
+ this->diff_dst_memory = std::shared_ptr<mkldnn::memory>(
+ new mkldnn::memory(diff_dst_memory.get_primitive_desc(),
+ diff_dst_memory.get_data_handle()));
+ this->diff_src_memory = std::shared_ptr<mkldnn::memory>(
+ new mkldnn::memory(diff_src_memory.get_primitive_desc(),
+ diff_src_memory.get_data_handle()));
+ this->bwd = std::shared_ptr<mkldnn::eltwise_backward>(
+ new mkldnn::eltwise_backward(
+ this->pd, mkldnn::primitive::at(*this->data),
+ *this->diff_dst_memory, *this->diff_src_memory));
+ }
+ }
+
+ const inline mkldnn::eltwise_backward &GetBwd() const { return *bwd; }
+};
+
+static inline MKLDNNActBackward &GetActBackward(const ActivationParam &param,
+ const OpContext &ctx,
+ const NDArray &in_data,
+ const NDArray &out_grad,
+ const mkldnn::memory &in_mem) {
+#if DMLC_CXX11_THREAD_LOCAL
+ static thread_local std::unordered_map<MKLDNNActSignature, MKLDNNActBackward, OpHash> bwds;
+#else
+ static MX_THREAD_LOCAL std::unordered_map<MKLDNNActSignature, MKLDNNActBackward, OpHash> bwds;
+#endif
+ MKLDNNActSignature key(param);
+ key.AddSign(in_data);
+ key.AddSign(out_grad);
+
+ auto it = bwds.find(key);
+ if (it == bwds.end()) {
+ MKLDNNActBackward bwd(param, in_data, in_mem, *out_grad.GetMKLDNNData());
+ auto ins_ret =
+ bwds.insert(std::pair<MKLDNNActSignature, MKLDNNActBackward>(key, bwd));
+ CHECK(ins_ret.second);
+ it = ins_ret.first;
+ }
+ return it->second;
+}
+
// For backward relu activation, it's okay to pass "out_data" as "in_data" to this
// function, since the computation only involves non-zeros.
void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
@@ -200,30 +294,13 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx
// descriptor. Otherwise, the perf will suffer.
if (input_mem->get_primitive_desc() != diff_dst_memory->get_primitive_desc())
input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_primitive_desc());
- mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc();
- mkldnn::memory::desc data_md = data_mpd.desc();
- mkldnn::memory::desc diff_md = diff_dst_memory->get_primitive_desc().desc();
- auto cpu_engine = data_mpd.get_engine();
-
+ MKLDNNActBackward &bwd =
+ GetActBackward(param, ctx, in_buffer, out_buffer, *input_mem);
MKLDNNStream *stream = MKLDNNStream::Get();
- auto alg = GetMKLDNNActAlgo(param);
- mkldnn_output_t diff_src_memory;
-
- MSHADOW_REAL_TYPE_SWITCH(in_buffer.dtype(), DType, {
- DType alpha = 0;
- mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
- alg, data_md, alpha);
- mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine);
- mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, alpha);
- mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine,
- fw_pdesc);
-
- diff_src_memory = CreateMKLDNNMem(in_grad,
- bw_pdesc.diff_src_primitive_desc(), req);
- stream->RegisterPrim(mkldnn::eltwise_backward(bw_pdesc, *input_mem,
- *diff_dst_memory,
- *diff_src_memory.second));
- });
+ mkldnn_output_t diff_src_memory =
+ CreateMKLDNNMem(in_grad, bwd.pd.diff_src_primitive_desc(), req);
+ bwd.SetNewMem(*input_mem, *diff_dst_memory, *diff_src_memory.second);
+ stream->RegisterPrim(bwd.GetBwd());
CommitOutput(in_grad, diff_src_memory);
stream->Submit();
}
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 6eb90f8..c4a4e52 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -509,6 +509,9 @@ class OpCheck {
const std::vector<mxnet::NDArray> &inputs_,
const std::vector<mxnet::OpReqType> &req,
const std::vector<mxnet::NDArray> &outputs_);
+
+ void CopyResult(const std::vector<mxnet::NDArray> &outputs_,
+ const std::vector<size_t>& indice);
};
bool MKLDNNStorageType(const nnvm::NodeAttrs &attrs,
@@ -525,6 +528,8 @@ bool MKLDNNStorageType(const nnvm::NodeAttrs &attrs,
#define MKLDNN_OPCHECK_RUN(fn, attrs, ctx, inputs, req, outputs) \
if (debug) check.Run(fn, attrs, ctx, inputs, req, outputs);
+#define MKLDNN_OPCHECK_COPY_RESULT(outputs, indice) \
+ if (debug) check.CopyResult(outputs, indice);
} // namespace mxnet
#endif
diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc
index f3facd9..029f23b 100644
--- a/src/operator/nn/mkldnn/mkldnn_base.cc
+++ b/src/operator/nn/mkldnn/mkldnn_base.cc
@@ -525,6 +525,17 @@ void OpCheck::Run(mxnet::FCompute fn, const nnvm::NodeAttrs &attrs,
}
}
+void OpCheck::CopyResult(const std::vector<mxnet::NDArray> &outputs_,
+ const std::vector<size_t> &indice) {
+ CHECK(!MKLDNNStream::Get()->HasOps());
+ auto non_const_outputs_ = const_cast<std::vector<mxnet::NDArray> &>(outputs_);
+ for (auto i = indice.begin(); i != indice.end(); ++i) {
+ auto mem = outputs[*i].GetMKLDNNData();
+ non_const_outputs_[*i].CopyFrom(*mem);
+ }
+ MKLDNNStream::Get()->Submit();
+}
+
bool MKLDNNStorageType(const nnvm::NodeAttrs &attrs,
const int dev_mask,
bool support_mkldnn,
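
For context: OpCheck runs the reference FCompute next to the MKLDNN op and
compares the results when the MXNET_MKLDNN_DEBUG environment variable is set.
The new CopyResult hook additionally lets an op overwrite selected outputs
with the reference results; LRN passes index 1 (see the lrn.cc hunk above) so
the later backward check consumes a workspace consistent with the reference
path. A minimal model of the hook, with plain float buffers standing in for
NDArray:

#include <cstddef>
#include <vector>

// Only the listed indices are overwritten with the reference results;
// e.g. LRN's {1} replaces just the workspace output.
void CopyResult(const std::vector<std::vector<float>> &reference_outputs,
                std::vector<std::vector<float>> *actual_outputs,
                const std::vector<std::size_t> &indices) {
  for (std::size_t i : indices)
    (*actual_outputs)[i] = reference_outputs[i];
}
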
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 9046836..496ff99 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -290,6 +290,84 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
}
}
+class MKLDNNBNBackward {
+ std::shared_ptr<mkldnn::batch_normalization_backward> bwd;
+ std::shared_ptr<mkldnn::memory> data_m;
+ std::shared_ptr<mkldnn::memory> diff_m;
+ std::shared_ptr<mkldnn::memory> gradi_m;
+ std::shared_ptr<mkldnn::memory> mean_m;
+ std::shared_ptr<mkldnn::memory> var_m;
+ const std::shared_ptr<mkldnn::memory> weight_m;
+ const std::shared_ptr<mkldnn::memory> gradw_m;
+
+ public:
+ const t_bn_b_pdesc pd;
+
+ explicit MKLDNNBNBackward(const t_bn_b_pdesc &_pd)
+ : weight_m(new mkldnn::memory(_pd.weights_primitive_desc())),
+ gradw_m(new mkldnn::memory(_pd.diff_weights_primitive_desc())),
+ pd(_pd) {}
+
+ const mkldnn::memory &GetWeight() const { return *weight_m; }
+
+ const mkldnn::memory &GetGradw() const { return *gradw_m; }
+
+ void SetDataHandle(const mkldnn::memory &data, const mkldnn::memory &diff,
+ const NDArray &mean, const mkldnn::memory &var,
+ const mkldnn::memory &gradi) {
+ auto mean_ptr = mean.data().dptr_;
+ if (bwd == nullptr) {
+ data_m.reset(new mkldnn::memory(data.get_primitive_desc(),
+ data.get_data_handle()));
+ diff_m.reset(new mkldnn::memory(diff.get_primitive_desc(),
+ diff.get_data_handle()));
+ gradi_m.reset(new mkldnn::memory(gradi.get_primitive_desc(),
+ gradi.get_data_handle()));
+ mean_m.reset(new mkldnn::memory(pd.mean_primitive_desc(), mean_ptr));
+ var_m.reset(new mkldnn::memory(pd.variance_primitive_desc(),
+ var.get_data_handle()));
+ bwd.reset(new mkldnn::batch_normalization_backward(
+ pd, *data_m, mkldnn::primitive::at(*mean_m),
+ mkldnn::primitive::at(*var_m), *diff_m, *weight_m, *gradi_m,
+ *gradw_m));
+ } else {
+ data_m->set_data_handle(data.get_data_handle());
+ diff_m->set_data_handle(diff.get_data_handle());
+ gradi_m->set_data_handle(gradi.get_data_handle());
+ mean_m->set_data_handle(mean_ptr);
+ var_m->set_data_handle(var.get_data_handle());
+ }
+ }
+
+ const mkldnn::batch_normalization_backward &GetBwd() const { return *bwd; }
+};
+
+template <typename DType>
+static MKLDNNBNBackward &GetBNBackward(
+    const BatchNormParam &param, const OpContext &ctx, const NDArray &in_data,
+ const mkldnn::memory &in_mem, const NDArray &diff_data,
+ const mkldnn::memory &diff_mem, unsigned flags) {
+#if DMLC_CXX11_THREAD_LOCAL
+ static thread_local std::unordered_map<MKLDNNBNSignature, MKLDNNBNBackward, OpHash> bwds;
+#else
+ static MX_THREAD_LOCAL std::unordered_map<MKLDNNBNSignature, MKLDNNBNBackward, OpHash> bwds;
+#endif
+ MKLDNNBNSignature key(param);
+ key.AddSign(in_data);
+ key.AddSign(diff_data);
+
+ auto it = bwds.find(key);
+ if (it == bwds.end()) {
+ auto bwd_pd = _GetBwd(in_mem, diff_mem, param.eps, flags);
+ MKLDNNBNBackward bwd(bwd_pd);
+ auto ins_ret =
+ bwds.insert(std::pair<MKLDNNBNSignature, MKLDNNBNBackward>(key, bwd));
+ CHECK(ins_ret.second);
+ it = ins_ret.first;
+ }
+ return it->second;
+}
+
template <typename DType>
void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
const std::vector<NDArray> &out_grad,
@@ -326,17 +404,13 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
data_mem = data.GetMKLDNNDataReorder(diff_mem->get_primitive_desc());
else if (diff.IsDefaultData())
diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_primitive_desc());
- auto bwd_pd = _GetBwd(*data_mem, *diff_mem, param.eps, flags);
+ auto &bwd = GetBNBackward<DType>(param, ctx, data, *data_mem, diff, *diff_mem, flags);
auto gradi_mem = const_cast<NDArray &>(gradIn).CreateMKLDNNData(data_mem->get_primitive_desc());
if (flags & use_scale_shift) {
const NDArray &gamma = in_data[batchnorm::kGamma];
const NDArray &beta = in_data[batchnorm::kBeta];
- // TODO(tao): how to reuse this memory?
- std::shared_ptr<const mkldnn::memory> weight_mem(
- new mkldnn::memory(bwd_pd.weights_primitive_desc()));
-
- DType* weight_buf = reinterpret_cast<DType *>(weight_mem->get_data_handle());
+ DType *weight_buf = reinterpret_cast<DType *>(bwd.GetWeight().get_data_handle());
nnvm::dim_t channels_ = data.shape()[1];
for (int i = 0; i < channels_; i++) {
if (!param.fix_gamma)
@@ -349,15 +423,13 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
weight_buf[channels_ + i] = (beta.data().dptr<DType>())[i]; // bias
}
- std::shared_ptr<const mkldnn::memory> gradw_mem(
- new mkldnn::memory(bwd_pd.diff_weights_primitive_desc()));
// training but no input mean and variance
if (ctx.is_train && !param.use_global_stats) {
DType* moving_mean_ptr = reinterpret_cast<DType *>(moving_mean.data().dptr<DType>());
DType* moving_var_ptr = reinterpret_cast<DType *>(moving_var.data().dptr<DType>());
DType* out_mean_ptr = reinterpret_cast<DType *>(out_mean.data().dptr<DType>());
DType* out_var_ptr = reinterpret_cast<DType *>(out_var.data().dptr<DType>());
- mkldnn::memory var_mem(bwd_pd.variance_primitive_desc());
+ mkldnn::memory var_mem(bwd.pd.variance_primitive_desc());
DType *tmp_var_ptr = reinterpret_cast<DType *>(var_mem.get_data_handle());
DType minus_mom = (1.0f - param.momentum);
@@ -369,45 +441,18 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
moving_var_ptr[i] = moving_var_ptr[i] * param.momentum +
variance * minus_mom;
}
-
- std::shared_ptr<const mkldnn::memory> out_mean_mem(
- new mkldnn::memory(bwd_pd.mean_primitive_desc(), out_mean_ptr));
- std::shared_ptr<const mkldnn::memory> out_var_mem(
- new mkldnn::memory(bwd_pd.variance_primitive_desc(), out_var_ptr));
-
- auto bn_bwd = mkldnn::batch_normalization_backward(bwd_pd,
- *data_mem,
- mkldnn::primitive::at(*out_mean_mem),
- mkldnn::primitive::at(var_mem),
- *diff_mem,
- *weight_mem,
- *gradi_mem,
- *gradw_mem);
-
- MKLDNNStream::Get()->RegisterPrim(bn_bwd);
+ bwd.SetDataHandle(*data_mem, *diff_mem, out_mean, var_mem, *gradi_mem);
+ MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd());
MKLDNNStream::Get()->Submit();
} else {
- std::shared_ptr<const mkldnn::memory> imean_mem(
- new mkldnn::memory(bwd_pd.mean_primitive_desc(),
- moving_mean.data().dptr<DType>()));
- std::shared_ptr<const mkldnn::memory> ivar_mem(
- new mkldnn::memory(bwd_pd.variance_primitive_desc(),
- moving_var.data().dptr<DType>()));
- auto bn_bwd = mkldnn::batch_normalization_backward(bwd_pd,
- *data_mem,
- mkldnn::primitive::at(*imean_mem),
- mkldnn::primitive::at(*ivar_mem),
- *diff_mem,
- *weight_mem,
- *gradi_mem,
- *gradw_mem);
-
- MKLDNNStream::Get()->RegisterPrim(bn_bwd);
+ bwd.SetDataHandle(*data_mem, *diff_mem, moving_mean,
+ *moving_var.GetMKLDNNData(), *gradi_mem);
+ MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd());
MKLDNNStream::Get()->Submit();
}
// copy data from gradw_mem to in_grad[1] and in_grad[2]
- DType* gw_buf = reinterpret_cast<DType *>(gradw_mem->get_data_handle());
+ DType *gw_buf = reinterpret_cast<DType *>(bwd.GetGradw().get_data_handle());
for (int i = 0; i < channels_; i++) {
if (!param.fix_gamma)
(in_grad[1].data().dptr<DType>())[i] = gw_buf[i];
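
One detail worth noting in the batch-norm hunks: MKL-DNN consumes gamma and
beta through a single 2 x channels weights buffer, which the code fills
channel by channel and reads back the same way for the gradients. A small
sketch of that packing, with plain vectors instead of the bwd.GetWeight()
memory, assuming fix_gamma pins the scale to 1 as the surrounding code does:

#include <cstddef>
#include <vector>

// Layout expected by the batch-norm primitive: gamma (scale) first,
// then beta (shift).
std::vector<float> PackScaleShift(const std::vector<float> &gamma,
                                  const std::vector<float> &beta,
                                  bool fix_gamma) {
  const std::size_t c = gamma.size();
  std::vector<float> weights(2 * c);
  for (std::size_t i = 0; i < c; ++i) {
    weights[i] = fix_gamma ? 1.0f : gamma[i];  // scale
    weights[c + i] = beta[i];                  // shift (bias)
  }
  return weights;
}
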
diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc
index cf04ea8..2e19d32 100644
--- a/src/operator/nn/mkldnn/mkldnn_convolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -283,6 +283,157 @@ void MKLDNNConvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx
MKLDNNStream::Get()->Submit();
}
+class MKLDNNConvBackward {
+ std::shared_ptr<mkldnn::convolution_backward_data> bwd_data;
+ std::shared_ptr<mkldnn::convolution_backward_weights> bwd_weight;
+ // conv::kData
+ std::shared_ptr<mkldnn::memory> out_grad;
+ std::shared_ptr<mkldnn::memory> in_grad;
+ std::shared_ptr<mkldnn::memory> weight;
+ // conv::kWeight
+ std::shared_ptr<mkldnn::memory> data;
+ std::shared_ptr<mkldnn::memory> output;
+ std::shared_ptr<mkldnn::memory> in_grad_weight;
+ std::shared_ptr<mkldnn::memory> in_grad_bias;
+
+ public:
+ mkldnn::convolution_backward_data::primitive_desc bwdData_pd;
+ mkldnn::convolution_backward_weights::primitive_desc bwdWeights_pd;
+
+ MKLDNNConvBackward(
+      const ConvolutionParam &param, const NDArray &data,
+ const NDArray &weights, const NDArray *bias, const NDArray &output,
+ const mkldnn::convolution_forward::primitive_desc &fwd_pd):
+ bwdData_pd(GetConvBwdData(param, data, weights, output, fwd_pd)),
+ bwdWeights_pd(GetConvBwdWeights(param, data, weights, bias, output, fwd_pd)) {
+ }
+
+ void SetDataNewMem(const mkldnn::memory &out_grad, const mkldnn::memory &weight,
+ const mkldnn::memory &in_grad) {
+ if (this->out_grad == nullptr)
+ this->out_grad = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+ bwdData_pd.diff_dst_primitive_desc(), out_grad.get_data_handle()));
+ else
+ this->out_grad->set_data_handle(out_grad.get_data_handle());
+ if (this->in_grad == nullptr)
+ this->in_grad = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+ bwdData_pd.diff_src_primitive_desc(), in_grad.get_data_handle()));
+ else
+ this->in_grad->set_data_handle(in_grad.get_data_handle());
+ if (this->weight == nullptr)
+ this->weight = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+ bwdData_pd.weights_primitive_desc(), weight.get_data_handle()));
+ else
+ this->weight->set_data_handle(weight.get_data_handle());
+ if (this->bwd_data == nullptr)
+ this->bwd_data = std::shared_ptr<mkldnn::convolution_backward_data>(
+ new mkldnn::convolution_backward_data(
+ this->bwdData_pd, mkldnn::primitive::at(*this->out_grad),
+ mkldnn::primitive::at(*this->weight), *this->in_grad));
+ }
+
+  void SetWeightNewMem(const mkldnn::memory &data,
+                       const mkldnn::memory &out_grad,
+                       const mkldnn::memory &in_grad_weight) {
+    if (this->data == nullptr)
+      this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          bwdWeights_pd.src_primitive_desc(), data.get_data_handle()));
+    else
+      this->data->set_data_handle(data.get_data_handle());
+    if (this->output == nullptr)
+      this->output = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          bwdWeights_pd.diff_dst_primitive_desc(), out_grad.get_data_handle()));
+    else
+      this->output->set_data_handle(out_grad.get_data_handle());
+    if (this->in_grad_weight == nullptr)
+      this->in_grad_weight = std::shared_ptr<mkldnn::memory>(
+          new mkldnn::memory(bwdWeights_pd.diff_weights_primitive_desc(),
+                             in_grad_weight.get_data_handle()));
+    else
+      this->in_grad_weight->set_data_handle(in_grad_weight.get_data_handle());
+
+    if (this->bwd_weight == nullptr)
+      this->bwd_weight = std::shared_ptr<mkldnn::convolution_backward_weights>(
+          new mkldnn::convolution_backward_weights(
+              this->bwdWeights_pd, mkldnn::primitive::at(*this->data),
+              mkldnn::primitive::at(*this->output), *this->in_grad_weight));
+  }
+
+ void SetWeightNewMem(const mkldnn::memory &data,
+ const mkldnn::memory &out_grad,
+ const mkldnn::memory &in_grad_weight,
+ const mkldnn::memory &in_grad_bias) {
+ if (this->data == nullptr)
+ this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+ bwdWeights_pd.src_primitive_desc(), data.get_data_handle()));
+ else
+ this->data->set_data_handle(data.get_data_handle());
+ if (this->output == nullptr)
+ this->output = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+ bwdWeights_pd.diff_dst_primitive_desc(), out_grad.get_data_handle()));
+ else
+ this->output->set_data_handle(out_grad.get_data_handle());
+ if (this->in_grad_weight == nullptr)
+ this->in_grad_weight = std::shared_ptr<mkldnn::memory>(
+ new mkldnn::memory(bwdWeights_pd.diff_weights_primitive_desc(),
+ in_grad_weight.get_data_handle()));
+ else
+ this->in_grad_weight->set_data_handle(in_grad_weight.get_data_handle());
+
+ if (this->in_grad_bias == nullptr)
+ this->in_grad_bias = std::shared_ptr<mkldnn::memory>(
+ new mkldnn::memory(bwdWeights_pd.diff_bias_primitive_desc(),
+ in_grad_bias.get_data_handle()));
+ else
+ this->in_grad_bias->set_data_handle(in_grad_bias.get_data_handle());
+ if (this->bwd_weight == nullptr)
+ this->bwd_weight = std::shared_ptr<mkldnn::convolution_backward_weights>(
+ new mkldnn::convolution_backward_weights(
+ this->bwdWeights_pd, mkldnn::primitive::at(*this->data),
+ mkldnn::primitive::at(*this->output), *this->in_grad_weight,
+ *this->in_grad_bias));
+ }
+
+ const mkldnn::convolution_backward_data &GetBwdData() const {
+ return *bwd_data;
+ }
+
+ const mkldnn::convolution_backward_weights &GetBwdWeights() const {
+ return *bwd_weight;
+ }
+};
+
+static inline MKLDNNConvBackward &GetConvBwd(
+ const nnvm::NodeAttrs &attrs, const NDArray &data, const NDArray &weights,
+ const NDArray *bias, const NDArray &output,
+ const mkldnn::convolution_forward::primitive_desc &fwd_pd) {
+#if DMLC_CXX11_THREAD_LOCAL
+ static thread_local std::unordered_map<MKLDNNConvSignature, MKLDNNConvBackward, OpHash> bwds;
+#else
+ static MX_THREAD_LOCAL std::unordered_map<MKLDNNConvSignature, MKLDNNConvBackward, OpHash> bwds;
+#endif
+ const ConvolutionParam& param = nnvm::get<ConvolutionParam>(attrs.parsed);
+ MKLDNNConvSignature key(param);
+ // Here we can sign the conv op with NDArray because conv primitive will
+  // decide the right layout for them, so we only need to get the shape and the
+ // data type of the arrays.
+ key.AddSign(data);
+ key.AddSign(weights);
+ key.AddSign(output);
+ if (bias)
+ key.AddSign(*bias);
+
+ auto it = bwds.find(key);
+ if (it == bwds.end()) {
+ MKLDNNConvBackward bwd(param, data, weights, bias, output, fwd_pd);
+ auto ins_ret = bwds.insert(
+ std::pair<MKLDNNConvSignature, MKLDNNConvBackward>(key, bwd));
+ CHECK(ins_ret.second);
+ it = ins_ret.first;
+ }
+ return it->second;
+}
+
void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
@@ -295,44 +446,45 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct
param.no_bias ? nullptr : &inputs[conv::kBias + 1], inputs[conv::kOut]);
CHECK_NE(req[conv::kWeight], kWriteInplace) << "cannot write weight inplace";
- mkldnn::convolution_backward_data::primitive_desc bwdData_pd
- = GetConvBwdData(param, inputs[conv::kData + 1], inputs[conv::kWeight + 1],
- inputs[conv::kOut], fwd_pd);
+ MKLDNNConvBackward &convBwd = GetConvBwd(attrs, inputs[conv::kData + 1],
+ inputs[conv::kWeight + 1], nullptr, inputs[conv::kOut], fwd_pd);
auto out_grad_mem = inputs[conv::kOut].GetMKLDNNDataReorder(
- bwdData_pd.diff_dst_primitive_desc());
+ convBwd.bwdData_pd.diff_dst_primitive_desc());
if (req[conv::kData]) {
auto weight_mem = GetWeights(inputs[conv::kWeight + 1],
- bwdData_pd.weights_primitive_desc(), param.num_group);
+ convBwd.bwdData_pd.weights_primitive_desc(), param.num_group);
auto in_grad_mem = CreateMKLDNNMem(in_grad[conv::kData],
- bwdData_pd.diff_src_primitive_desc(), req[conv::kData]);
- MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_data(bwdData_pd,
- *out_grad_mem, *weight_mem, *in_grad_mem.second));
+ convBwd.bwdData_pd.diff_src_primitive_desc(), req[conv::kData]);
+ convBwd.SetDataNewMem(*out_grad_mem, *weight_mem, *in_grad_mem.second);
+ MKLDNNStream::Get()->RegisterPrim(convBwd.GetBwdData());
CommitOutput(in_grad[conv::kData], in_grad_mem);
}
if (req[conv::kWeight]) {
- mkldnn::convolution_backward_weights::primitive_desc bwdWeights_pd
- = GetConvBwdWeights(param, inputs[conv::kData + 1], inputs[conv::kWeight + 1],
- param.no_bias ? nullptr : &inputs[conv::kBias + 1],
- inputs[conv::kOut], fwd_pd);
- if (bwdData_pd.diff_dst_primitive_desc() != bwdWeights_pd.diff_dst_primitive_desc())
+ MKLDNNConvBackward &convBwdWeight = GetConvBwd(attrs, inputs[conv::kData + 1],
+ inputs[conv::kWeight + 1], param.no_bias ? nullptr : &inputs[conv::kBias + 1],
+ inputs[conv::kOut], fwd_pd);
+ if (convBwdWeight.bwdData_pd.diff_dst_primitive_desc() !=
+ convBwdWeight.bwdWeights_pd.diff_dst_primitive_desc())
out_grad_mem = inputs[conv::kOut].GetMKLDNNDataReorder(
- bwdWeights_pd.diff_dst_primitive_desc());
+ convBwdWeight.bwdWeights_pd.diff_dst_primitive_desc());
auto data_mem = inputs[conv::kData + 1].GetMKLDNNDataReorder(
- bwdWeights_pd.src_primitive_desc());
- auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[conv::kWeight],
- bwdWeights_pd.diff_weights_primitive_desc(),
- req[conv::kWeight]);
+ convBwdWeight.bwdWeights_pd.src_primitive_desc());
+ auto in_grad_weight = CreateMKLDNNWeightGrad(
+ in_grad[conv::kWeight],
+ convBwdWeight.bwdWeights_pd.diff_weights_primitive_desc(),
+ req[conv::kWeight]);
mkldnn_output_t in_grad_bias;
if (param.no_bias) {
- MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_weights(
- bwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second));
+ convBwdWeight.SetWeightNewMem(*data_mem, *out_grad_mem,
+ *in_grad_weight.second);
+ MKLDNNStream::Get()->RegisterPrim(convBwdWeight.GetBwdWeights());
} else {
- in_grad_bias = CreateMKLDNNMem(in_grad[conv::kBias],
- bwdWeights_pd.diff_bias_primitive_desc(),
- req[conv::kBias]);
- MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_weights(
- bwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second,
- *in_grad_bias.second));
+ in_grad_bias = CreateMKLDNNMem(
+ in_grad[conv::kBias],
+ convBwdWeight.bwdWeights_pd.diff_bias_primitive_desc(), req[conv::kBias]);
+ convBwdWeight.SetWeightNewMem(*data_mem, *out_grad_mem,
+ *in_grad_weight.second, *in_grad_bias.second);
+ MKLDNNStream::Get()->RegisterPrim(convBwdWeight.GetBwdWeights());
CommitOutput(in_grad[conv::kBias], in_grad_bias);
}
CommitOutput(in_grad[conv::kWeight], in_grad_weight);
@@ -342,5 +494,4 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct
} // namespace op
} // namespace mxnet
-
#endif // MXNET_USE_MKLDNN == 1
diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
index 7f3676a..54d4f67 100644
--- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
@@ -93,9 +93,9 @@ static mkldnn::convolution_backward_data::primitive_desc GetDeconvFwdImpl(
return mkldnn::convolution_backward_data::primitive_desc(desc, engine, bwd_pd);
}
-static mkldnn::convolution_forward::primitive_desc GetDeconvBwdData(
-    const DeconvolutionParam &param, const NDArray &data, const NDArray &weights,
- bool has_bias, const NDArray &output) {
+static mkldnn::convolution_forward::primitive_desc GetDeconvBwdDataImpl(
+    const DeconvolutionParam &param, const NDArray &data,
+ const NDArray &weights, bool has_bias, const NDArray &output) {
auto data_md = GetMemDesc(data);
auto weight_md = GetWeightDesc(weights, param.num_group);
auto out_md = GetMemDesc(output);
@@ -116,9 +116,10 @@ static mkldnn::convolution_forward::primitive_desc GetDeconvBwdData(
strides, padding, dilate);
}
-static mkldnn::convolution_backward_weights::primitive_desc GetDeconvBwdWeights(
- const DeconvolutionParam& param, const NDArray &data, const NDArray &weights,
- bool has_bias, const NDArray &output,
+static mkldnn::convolution_backward_weights::primitive_desc
+GetDeconvBwdWeightsImpl(
+    const DeconvolutionParam &param, const NDArray &data,
+ const NDArray &weights, bool has_bias, const NDArray &output,
const mkldnn::convolution_forward::primitive_desc &fwd_pd) {
auto data_md = GetMemDesc(data);
auto weight_md = GetWeightDesc(weights, param.num_group);
@@ -308,55 +309,203 @@ void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &c
MKLDNNDeconvFwdBiasPostProcess(param, ctx, in_data, out_data);
}
-void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs) {
+class MKLDNNDeconvBackwardData {
+ std::shared_ptr<mkldnn::convolution_forward> bwd;
+ std::shared_ptr<mkldnn::memory> data;
+ std::shared_ptr<mkldnn::memory> weight;
+ std::shared_ptr<mkldnn::memory> out;
+
+ public:
+ const mkldnn::convolution_forward::primitive_desc pd;
+
+  MKLDNNDeconvBackwardData(const DeconvolutionParam &param, const NDArray &data,
+ const NDArray &weights, const NDArray &output)
+ : pd(GetDeconvBwdDataImpl(param, data, weights, false, output)) {
+ }
+
+ void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
+ const mkldnn::memory &output) {
+ if (bwd == nullptr) {
+ this->data = std::shared_ptr<mkldnn::memory>(
+ new mkldnn::memory(pd.src_primitive_desc(), data.get_data_handle()));
+ this->weight = std::shared_ptr<mkldnn::memory>(
+ new mkldnn::memory(pd.weights_primitive_desc(), weight.get_data_handle()));
+ this->out = std::shared_ptr<mkldnn::memory>(
+ new mkldnn::memory(pd.dst_primitive_desc(), output.get_data_handle()));
+ bwd = std::shared_ptr<mkldnn::convolution_forward>(
+ new mkldnn::convolution_forward(pd, mkldnn::primitive::at(*this->data),
+ mkldnn::primitive::at(*this->weight),
+ *this->out));
+ } else {
+ this->data->set_data_handle(data.get_data_handle());
+ this->weight->set_data_handle(weight.get_data_handle());
+ this->out->set_data_handle(output.get_data_handle());
+ }
+ }
+
+ const mkldnn::convolution_forward &GetBwd() const { return *bwd; }
+};
+
+typedef ParamOpSign<DeconvolutionParam> MKLDNNDeconvSignature;
+
+static inline MKLDNNDeconvBackwardData &GetDeconvBwdData(
+    const DeconvolutionParam &param, const NDArray &data,
+ const NDArray &weights, const NDArray &output) {
+#if DMLC_CXX11_THREAD_LOCAL
+ static thread_local std::unordered_map<MKLDNNDeconvSignature,
+ MKLDNNDeconvBackwardData, OpHash>
+ bwds;
+#else
+ static MX_THREAD_LOCAL std::unordered_map<MKLDNNDeconvSignature,
+ MKLDNNDeconvBackwardData, OpHash>
+ bwds;
+#endif
+ MKLDNNDeconvSignature key(param);
+ // Here we can sign the conv op with NDArray because conv primitive will
+  // decide the right layout for them, so we only need to get the shape and the
+ // data type of the arrays.
+ key.AddSign(data);
+ key.AddSign(weights);
+ key.AddSign(output);
+
+ auto it = bwds.find(key);
+ if (it == bwds.end()) {
+ MKLDNNDeconvBackwardData bwd(param, data, weights, output);
+ auto ins_ret = bwds.insert(
+ std::pair<MKLDNNDeconvSignature, MKLDNNDeconvBackwardData>(key, bwd));
+ CHECK(ins_ret.second);
+ it = ins_ret.first;
+ }
+ return it->second;
+}
+
+class MKLDNNDeconvBackwardWeights {
+ std::shared_ptr<mkldnn::convolution_backward_weights> bwd;
+ std::shared_ptr<mkldnn::memory> data;
+ std::shared_ptr<mkldnn::memory> weight;
+ std::shared_ptr<mkldnn::memory> out;
+
+ public:
+ const mkldnn::convolution_backward_weights::primitive_desc pd;
+
+ MKLDNNDeconvBackwardWeights(
+      const DeconvolutionParam &param, const NDArray &data,
+ const NDArray &weights, const NDArray &output,
+ const mkldnn::convolution_forward::primitive_desc &bwd_data_pd)
+ : pd(GetDeconvBwdWeightsImpl(param, data, weights, false, output,
+ bwd_data_pd)) {}
+
+ void SetNewMem(
+ const mkldnn::memory &data, const mkldnn::memory &weight,
+ const mkldnn::memory &output,
+ const mkldnn::convolution_forward::primitive_desc &bwd_data_pd) {
+ if (bwd == nullptr) {
+ this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+ bwd_data_pd.src_primitive_desc(), data.get_data_handle()));
+ this->weight = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+ bwd_data_pd.weights_primitive_desc(), weight.get_data_handle()));
+ this->out = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+ bwd_data_pd.dst_primitive_desc(), output.get_data_handle()));
+ bwd = std::shared_ptr<mkldnn::convolution_backward_weights>(
+ new mkldnn::convolution_backward_weights(pd, *this->data,
+ *this->weight, *this->out));
+ } else {
+ this->data->set_data_handle(data.get_data_handle());
+ this->weight->set_data_handle(weight.get_data_handle());
+ this->out->set_data_handle(output.get_data_handle());
+ }
+ }
+
+ const mkldnn::convolution_backward_weights &GetBwd() const { return *bwd; }
+};
+
+static inline MKLDNNDeconvBackwardWeights &GetDeconvBwdWeights(
+    const DeconvolutionParam &param, const NDArray &data,
+ const NDArray &weights, const NDArray &output,
+ const mkldnn::convolution_forward::primitive_desc &bwd_data_pd) {
+#if DMLC_CXX11_THREAD_LOCAL
+ static thread_local std::unordered_map<MKLDNNDeconvSignature,
+ MKLDNNDeconvBackwardWeights, OpHash>
+ bwds;
+#else
+ static MX_THREAD_LOCAL std::unordered_map<MKLDNNDeconvSignature,
+ MKLDNNDeconvBackwardWeights, OpHash>
+ bwds;
+#endif
+ MKLDNNDeconvSignature key(param);
+ // Here we can sign the conv op with NDArray because conv primitive will
+  // decide the right layout for them, so we only need to get the shape and the
+ // data type of the arrays.
+ key.AddSign(data);
+ key.AddSign(weights);
+ key.AddSign(output);
+
+ auto it = bwds.find(key);
+ if (it == bwds.end()) {
+ MKLDNNDeconvBackwardWeights bwd(param, data, weights, output, bwd_data_pd);
+ auto ins_ret = bwds.insert(
+ std::pair<MKLDNNDeconvSignature, MKLDNNDeconvBackwardWeights>(key,
+ bwd));
+ CHECK(ins_ret.second);
+ it = ins_ret.first;
+ }
+ return it->second;
+}
+
+void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &outputs) {
TmpMemMgr::Get()->Init(ctx.requested[deconv::kTempSpace]);
const std::vector<NDArray> &in_grad = outputs;
- const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);
- CHECK_NE(req[deconv::kWeight], kWriteInplace) << "cannot write weight inplace";
- mkldnn::convolution_forward::primitive_desc bwdData_pd = GetDeconvBwdData(
- param, inputs[deconv::kData + 1], inputs[deconv::kWeight + 1], false,
- inputs[deconv::kOut]);
+  const DeconvolutionParam &param = nnvm::get<DeconvolutionParam>(attrs.parsed);
+ CHECK_NE(req[deconv::kWeight], kWriteInplace)
+ << "cannot write weight inplace";
+ MKLDNNDeconvBackwardData &bwd_data =
+ GetDeconvBwdData(param, inputs[deconv::kData + 1],
+ inputs[deconv::kWeight + 1], inputs[deconv::kOut]);
auto out_grad_mem = inputs[deconv::kOut].GetMKLDNNDataReorder(
- bwdData_pd.src_primitive_desc());
+ bwd_data.pd.src_primitive_desc());
if (req[deconv::kData]) {
- auto weight_mem = GetWeights(inputs[deconv::kWeight + 1],
- bwdData_pd.weights_primitive_desc(),
- param.num_group);
- auto in_grad_mem = CreateMKLDNNMem(in_grad[deconv::kData],
- bwdData_pd.dst_primitive_desc(),
- req[deconv::kData]);
- MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_forward(bwdData_pd,
- *out_grad_mem, *weight_mem, *in_grad_mem.second));
+ auto weight_mem =
+ GetWeights(inputs[deconv::kWeight + 1],
+ bwd_data.pd.weights_primitive_desc(), param.num_group);
+ auto in_grad_mem =
+ CreateMKLDNNMem(in_grad[deconv::kData],
+ bwd_data.pd.dst_primitive_desc(), req[deconv::kData]);
+ bwd_data.SetNewMem(*out_grad_mem, *weight_mem, *in_grad_mem.second);
+ MKLDNNStream::Get()->RegisterPrim(bwd_data.GetBwd());
CommitOutput(in_grad[deconv::kData], in_grad_mem);
}
if (req[deconv::kWeight]) {
- mkldnn::convolution_backward_weights::primitive_desc bwdWeights_pd
- = GetDeconvBwdWeights(param, inputs[deconv::kData + 1],
- inputs[deconv::kWeight + 1], false, inputs[deconv::kOut], bwdData_pd);
- if (bwdData_pd.src_primitive_desc() != bwdWeights_pd.src_primitive_desc())
+ MKLDNNDeconvBackwardWeights &bwd_weights = GetDeconvBwdWeights(
+ param, inputs[deconv::kData + 1], inputs[deconv::kWeight + 1],
+ inputs[deconv::kOut], bwd_data.pd);
+ if (bwd_data.pd.src_primitive_desc() != bwd_weights.pd.src_primitive_desc())
out_grad_mem = inputs[deconv::kOut].GetMKLDNNDataReorder(
- bwdWeights_pd.src_primitive_desc());
+ bwd_weights.pd.src_primitive_desc());
auto data_mem = inputs[deconv::kData + 1].GetMKLDNNDataReorder(
- bwdWeights_pd.diff_dst_primitive_desc());
- auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[deconv::kWeight],
- bwdWeights_pd.diff_weights_primitive_desc(),
- req[deconv::kWeight]);
- MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_weights(
- bwdWeights_pd, *out_grad_mem, *data_mem, *in_grad_weight.second));
+ bwd_weights.pd.diff_dst_primitive_desc());
+ auto in_grad_weight = CreateMKLDNNWeightGrad(
+ in_grad[deconv::kWeight], bwd_weights.pd.diff_weights_primitive_desc(),
+ req[deconv::kWeight]);
+ bwd_weights.SetNewMem(*out_grad_mem, *data_mem, *in_grad_weight.second, bwd_data.pd);
+ MKLDNNStream::Get()->RegisterPrim(bwd_weights.GetBwd());
CommitOutput(in_grad[deconv::kWeight], in_grad_weight);
}
MKLDNNStream::Get()->Submit();
if (!param.no_bias) {
typedef float DType;
Stream<cpu> *s = ctx.get_stream<cpu>();
- Tensor<cpu, 1, DType> gbias = in_grad[deconv::kBias].data().get<cpu, 1, DType>(s);
+ Tensor<cpu, 1, DType> gbias =
+ in_grad[deconv::kBias].data().get<cpu, 1, DType>(s);
// If there is bias, the out grad has already been converted to the default
// format, so this shouldn't cause any performance issues.
- Tensor<cpu, 4, DType> grad = inputs[deconv::kOut].data().get<cpu, 4, DType>(s);
- Assign(gbias, req[deconv::kBias], mshadow::expr::sumall_except_dim<1>(grad));
+ Tensor<cpu, 4, DType> grad =
+ inputs[deconv::kOut].data().get<cpu, 4, DType>(s);
+ Assign(gbias, req[deconv::kBias],
+ mshadow::expr::sumall_except_dim<1>(grad));
}
}
diff --git a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
index 4b179a7..bc386be 100644
--- a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
@@ -18,7 +18,7 @@
*/
/*!
- * \file mkldnn_lrn-inl.h
+ * \file mkldnn_lrn-inl.h
* \brief
* \Author: Patric Zhao, patric.zhao@intel.com
*/
@@ -40,9 +40,8 @@ inline algorithm GetMKLDNNLRNAlgo(const LRNParam &param) {
return algorithm::lrn_across_channels;
}
-inline lrn_forward::primitive_desc GetLRNFwdDesc(const LRNParam &param,
- const bool is_train,
- const memory::desc &src_md) {
+inline mkldnn::lrn_forward::primitive_desc GetLRNFwdDesc(
+    const LRNParam &param, const bool is_train, const memory::desc &src_md) {
mkldnn::engine &engine = CpuEngine::Get()->get_engine();
const algorithm alg = GetMKLDNNLRNAlgo(param);
const float alpha = param.alpha;
@@ -59,11 +58,10 @@ inline lrn_forward::primitive_desc GetLRNFwdDesc(const LRNParam &param,
return mkldnn::lrn_forward::primitive_desc(fwd_desc, engine);
}
-inline mkldnn::lrn_backward::primitive_desc
-GetLRNBwd(const LRNParam &param,
- const mkldnn::memory::desc &data_in_md,
- const mkldnn::memory::desc &diff_md,
- const lrn_forward::primitive_desc &lrnFwd_desc) {
+inline mkldnn::lrn_backward::primitive_desc GetLRNBwdDesc(
+    const LRNParam &param, const mkldnn::memory::desc &data_in_md,
+ const mkldnn::memory::desc &diff_md,
+ const mkldnn::lrn_forward::primitive_desc &lrnFwd_desc) {
mkldnn::engine &engine = CpuEngine::Get()->get_engine();
const algorithm alg = GetMKLDNNLRNAlgo(param);
const float alpha = param.alpha;
@@ -96,8 +94,15 @@ class MKLDNNLRNFwd {
const NDArray &output,
const OpReqType req);
+ void SetNewMem(const NDArray &in_data,
+ const mkldnn::memory *out_mem);
+
void Execute(const NDArray &out_data);
+ mkldnn::lrn_forward &GetFwd();
+
+ const mkldnn::memory *GetWs();
+
private:
std::shared_ptr<mkldnn::lrn_forward> fwd;
std::shared_ptr<mkldnn::memory> in_mem;
@@ -113,15 +118,17 @@ class MKLDNNLRNFwd {
void MKLDNNLRNFwd::_Init(const LRNParam &param,
bool is_train,
const NDArray &in_data) {
- mkldnn::memory::desc in_data_md = in_data.GetMKLDNNData()->get_primitive_desc().desc();
- lrn_forward::primitive_desc fwd_pd = GetLRNFwdDesc(param, is_train, in_data_md);
+ mkldnn::memory::desc in_data_md =
+ in_data.GetMKLDNNData()->get_primitive_desc().desc();
+ mkldnn::lrn_forward::primitive_desc fwd_pd =
+ GetLRNFwdDesc(param, is_train, in_data_md);
this->in_mem.reset(new mkldnn::memory(in_data.GetMKLDNNData()
->get_primitive_desc()));
this->out_mem.reset(new mkldnn::memory(fwd_pd.dst_primitive_desc()));
if (is_train) {
- // If it's training, we have to create a workspace memory. Otherwise, MKLDNN
- // will have segmentation fault.
+ // If it's training, we have to create a workspace memory. Otherwise,
+ // MKLDNN will have segmentation fault.
ws_mem.reset(new mkldnn::memory(fwd_pd.workspace_primitive_desc()));
this->fwd = std::shared_ptr<mkldnn::lrn_forward>(
new mkldnn::lrn_forward(fwd_pd, mkldnn::primitive::at(*this->in_mem),
@@ -142,11 +149,22 @@ void MKLDNNLRNFwd::SetNewMem(const NDArray &in_data,
this->out_mem->set_data_handle(output_mem_t.second->get_data_handle());
}
+void MKLDNNLRNFwd::SetNewMem(const NDArray &in_data,
+ const mkldnn::memory *out_mem) {
+ const mkldnn::memory *in_data_mem = in_data.GetMKLDNNData();
+ this->in_mem->set_data_handle(in_data_mem->get_data_handle());
+ this->out_mem->set_data_handle(out_mem->get_data_handle());
+}
+
void MKLDNNLRNFwd::Execute(const NDArray &out_data) {
MKLDNNStream::Get()->RegisterPrim(*(this->fwd));
CommitOutput(out_data, output_mem_t);
MKLDNNStream::Get()->Submit();
}
+
+mkldnn::lrn_forward &MKLDNNLRNFwd::GetFwd() { return *this->fwd; }
+
+const mkldnn::memory *MKLDNNLRNFwd::GetWs() { return this->ws_mem.get(); }
// End of LRN Class and its functions
static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param,
@@ -161,16 +179,10 @@ static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param,
MKLDNNLRNFwd,
OpHash> lrn_fwds;
#endif
- auto alg_ = algorithm::lrn_across_channels;
- auto kind_ = prop_kind::forward_training;
- if (ctx.is_train) {
- kind_ = prop_kind::forward_training;
- } else {
- kind_ = prop_kind::forward_scoring;
- }
+ auto kind_ =
+ ctx.is_train ? prop_kind::forward_training : prop_kind::forward_scoring;
MKLDNNLRNSignature key(param);
- key.AddSign(alg_);
key.AddSign(kind_);
key.AddSign(in_data);
@@ -185,10 +197,8 @@ static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param,
return it->second;
}
-void MKLDNNLRNForward(const OpContext &ctx,
-                      const LRNParam &param,
- const NDArray &in_data,
- const OpReqType req,
+void MKLDNNLRNForward(const OpContext &ctx, const LRNParam &param,
+ const NDArray &in_data, const OpReqType req,
const NDArray &out_data) {
auto in_buffer = in_data;
if (in_buffer.IsView() && in_buffer.IsMKLDNNData())
@@ -198,6 +208,90 @@ void MKLDNNLRNForward(const OpContext &ctx,
fwd.Execute(out_data);
}
+// LRN Backward Class
+class MKLDNNLRNBwd {
+ std::shared_ptr<mkldnn::lrn_backward> bwd;
+ std::shared_ptr<mkldnn::memory> in_data_mem;
+ std::shared_ptr<mkldnn::memory> diff_dst_mem;
+ std::shared_ptr<mkldnn::memory> ws_mem;
+ std::shared_ptr<mkldnn::memory> diff_src_mem;
+
+ public:
+ const mkldnn::lrn_forward::primitive_desc fwd_pd;
+ const mkldnn::lrn_backward::primitive_desc bwd_pd;
+
+ ~MKLDNNLRNBwd() {}
+
+  MKLDNNLRNBwd(const LRNParam &param, const mkldnn::memory::desc in_data_md,
+ const mkldnn::memory::desc diff_md)
+ : fwd_pd(GetLRNFwdDesc(param, true, in_data_md)),
+ bwd_pd(GetLRNBwdDesc(param, in_data_md, diff_md, this->fwd_pd)) {}
+
+ void SetNewMem(const NDArray &in_data, const NDArray &out_grad,
+ const mkldnn::memory *ws, const mkldnn::memory *diff_src_mem) {
+ if (bwd == nullptr) {
+ this->in_data_mem.reset(
+ new mkldnn::memory(this->fwd_pd.src_primitive_desc(),
+ in_data.GetMKLDNNData()->get_data_handle()));
+ this->diff_dst_mem.reset(
+ new mkldnn::memory(this->fwd_pd.dst_primitive_desc(),
+ out_grad.GetMKLDNNData()->get_data_handle()));
+ this->ws_mem.reset(
+ new mkldnn::memory(this->fwd_pd.workspace_primitive_desc(),
+ ws->get_data_handle()));
+ this->diff_src_mem.reset(
+ new mkldnn::memory(this->bwd_pd.diff_src_primitive_desc(),
+ diff_src_mem->get_data_handle()));
+ this->bwd.reset(new mkldnn::lrn_backward(
+ this->bwd_pd, mkldnn::primitive::at(*this->in_data_mem),
+ mkldnn::primitive::at(*this->diff_dst_mem), *this->ws_mem,
+ *this->diff_src_mem));
+ } else {
+ this->in_data_mem->set_data_handle(
+ in_data.GetMKLDNNData()->get_data_handle());
+ this->diff_dst_mem->set_data_handle(
+ out_grad.GetMKLDNNData()->get_data_handle());
+ this->ws_mem->set_data_handle(ws->get_data_handle());
+ this->diff_src_mem->set_data_handle(diff_src_mem->get_data_handle());
+ }
+ }
+
+ void Execute(const NDArray &in_grad, const mkldnn_output_t &diff_src_mem_) {
+ MKLDNNStream::Get()->RegisterPrim(*(this->bwd));
+ CommitOutput(in_grad, diff_src_mem_);
+ MKLDNNStream::Get()->Submit();
+ }
+}; // End of LRN Class
+
+static MKLDNNLRNBwd &GetLRNBwd(const LRNParam &param, const NDArray &in_data,
+ const NDArray &in_grad, const NDArray &out_grad) {
+#if DMLC_CXX11_THREAD_LOCAL
+ static thread_local
+ std::unordered_map<MKLDNNLRNSignature, MKLDNNLRNBwd, OpHash> lrn_bwds;
+#else
+ static MX_THREAD_LOCAL
+ std::unordered_map<MKLDNNLRNSignature, MKLDNNLRNBwd, OpHash> lrn_bwds;
+#endif
+ MKLDNNLRNSignature key(param);
+ key.AddSign(in_data);
+ key.AddSign(in_grad);
+ key.AddSign(out_grad);
+
+ auto it = lrn_bwds.find(key);
+ if (it == lrn_bwds.end()) {
+ const mkldnn::memory::desc in_data_md =
+ in_data.GetMKLDNNData()->get_primitive_desc().desc();
+ const mkldnn::memory::desc diff_md =
+ out_grad.GetMKLDNNData()->get_primitive_desc().desc();
+ MKLDNNLRNBwd bwd(param, in_data_md, diff_md);
+ auto ins_ret =
+ lrn_bwds.insert(std::pair<MKLDNNLRNSignature, MKLDNNLRNBwd>(key, bwd));
+ CHECK(ins_ret.second);
+ it = ins_ret.first;
+ }
+ return it->second;
+}
+
void MKLDNNLRNBackward(const OpContext &ctx, const LRNParam &param,
const NDArray &out_grad,
const NDArray &in_data,
@@ -206,43 +300,27 @@ void MKLDNNLRNBackward(const OpContext &ctx, const LRNParam &param,
if (req == kNullOp) {
return;
}
-
// TODO(alex): (MXNET-846) figure out why the in_grad output is incorrect when in_data is nchw8c
auto in_buffer = in_data;
if (in_buffer.IsMKLDNNData()) {
in_buffer = in_data.Reorder2Default();
}
-
+ MKLDNNLRNBwd &bwd = GetLRNBwd(param, in_buffer, in_grad, out_grad);
// Repeat FW for getting workspace
- const mkldnn::memory *data_mem = in_buffer.GetMKLDNNData();
- const mkldnn::memory::desc data_md = data_mem->get_primitive_desc().desc();
- const lrn_forward::primitive_desc pdesc_fwd = GetLRNFwdDesc(param, ctx.is_train,
- data_md);
-
// TODO(Patric): To keep the function stateless, we can't pass workspace
// from LRN forward to backward. We have to re-compute
// LRN forward to get the workspace.
// Will refine this code later.
- std::shared_ptr<const mkldnn::memory> ws_mem(
- new mkldnn::memory(pdesc_fwd.workspace_primitive_desc()));
+ MKLDNNLRNFwd fwd = GetLRNFwd(param, ctx, in_buffer);
std::shared_ptr<const mkldnn::memory> dst_temp(
- new mkldnn::memory(pdesc_fwd.dst_primitive_desc()));
- MKLDNNStream::Get()->RegisterPrim(
- lrn_forward(pdesc_fwd, mkldnn::primitive::at(*data_mem),
- *ws_mem, *dst_temp));
-
- const mkldnn::memory *diff_mem = out_grad.GetMKLDNNData();
- const mkldnn::memory::desc diff_md = diff_mem->get_primitive_desc().desc();
- const mkldnn::lrn_backward::primitive_desc pdesc_bwd = GetLRNBwd(param, data_md,
- diff_md, pdesc_fwd);
- mkldnn_output_t diff_src_mem = CreateMKLDNNMem(in_grad,
- pdesc_bwd.diff_src_primitive_desc(), req);
-
- MKLDNNStream::Get()->RegisterPrim(
- lrn_backward(pdesc_bwd, mkldnn::primitive::at(*data_mem),
- mkldnn::primitive::at(*diff_mem), *ws_mem, *diff_src_mem.second));
- CommitOutput(in_grad, diff_src_mem);
- MKLDNNStream::Get()->Submit();
+ new mkldnn::memory(bwd.fwd_pd.dst_primitive_desc()));
+ fwd.SetNewMem(in_buffer, dst_temp.get());
+ MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
+
+ mkldnn_output_t diff_src_mem =
+ CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_primitive_desc(), req);
+ bwd.SetNewMem(in_buffer, out_grad, fwd.GetWs(), diff_src_mem.second);
+ bwd.Execute(in_grad, diff_src_mem);
}
} // namespace op
} // namespace mxnet
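
The LRN hunks also show the one wrinkle in this scheme: the backward pass
needs the forward workspace, but the operator is stateless (see the
TODO(Patric) note above), so the cached forward is re-run into a throwaway
dst purely to regenerate the workspace before the cached backward consumes
it. A simplified, self-contained model of that flow, with dummy types and
arithmetic in place of the MKL-DNN ones:

#include <cstddef>
#include <vector>

struct Workspace { std::vector<float> buf; };

struct CachedLRNFwd {                              // stand-in for MKLDNNLRNFwd
  Workspace ws;
  void Run(const std::vector<float> &in, std::vector<float> *dst) {
    dst->assign(in.begin(), in.end());             // stand-in computation
    ws.buf = in;                                   // refills the workspace
  }
  const Workspace *GetWs() const { return &ws; }
};

struct CachedLRNBwd {                              // stand-in for MKLDNNLRNBwd
  void Run(const std::vector<float> &out_grad, const Workspace *ws,
           std::vector<float> *in_grad) {
    in_grad->resize(out_grad.size());
    for (std::size_t i = 0; i < out_grad.size(); ++i)
      (*in_grad)[i] = out_grad[i] * ws->buf[i];    // stand-in computation
  }
};

void LRNBackward(CachedLRNFwd &fwd, CachedLRNBwd &bwd,
                 const std::vector<float> &in_data,
                 const std::vector<float> &out_grad,
                 std::vector<float> *in_grad) {
  std::vector<float> dst_temp;                     // result is discarded
  fwd.Run(in_data, &dst_temp);                     // only to rebuild the ws
  bwd.Run(out_grad, fwd.GetWs(), in_grad);
}
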
diff --git a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h
index 5d349d3..6667961 100644
--- a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h
@@ -80,6 +80,27 @@ class MKLDNNPoolingFwd {
const int padding_l, const int padding_r);
};
+class MKLDNNPoolingBwd {
+ std::shared_ptr<const mkldnn::pooling_backward> bwd;
+ std::shared_ptr<mkldnn::memory> diff_dst;
+ std::shared_ptr<mkldnn::memory> diff_src;
+ std::shared_ptr<mkldnn::memory> ws;
+ bool with_workspace;
+
+ public:
+ const mkldnn::pooling_backward::primitive_desc pd;
+
+ MKLDNNPoolingBwd(const pooling_backward::primitive_desc &pdesc,
+ bool with_ws);
+
+ ~MKLDNNPoolingBwd() {}
+ void SetNewMem(const mxnet::NDArray *workspace,
+ const mxnet::NDArray &out_grad,
+ const mkldnn::memory *diff_src_mem);
+ const mkldnn::pooling_backward &GetBwd();
+ const mkldnn::pooling_backward::primitive_desc &GetPd();
+};
+
inline bool SupportMKLDNNPooling(const PoolingParam &param) {
return param.kernel.ndim() == 2 &&
(param.pool_type == pool_enum::kMaxPooling ||
diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc
index d8d65ba..1610944 100644
--- a/src/operator/nn/mkldnn/mkldnn_pooling.cc
+++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc
@@ -134,10 +134,9 @@ mkldnn::algorithm GetMKLDNNPoolAlgo(const PoolingParam &param) {
}
}
-mkldnn::pooling_forward::primitive_desc GetPoolingFwd(const PoolingParam &param,
- const bool is_train,
- const memory::desc &data_md,
- const memory::desc &out_md) {
+mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc(
+    const PoolingParam &param, const bool is_train, const memory::desc &data_md,
+ const memory::desc &out_md) {
CHECK_EQ(param.kernel.ndim(), 2) << "Not Implemented";
int kernel_h_, kernel_w_;
if (param.global_pool) {
@@ -255,11 +254,124 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam &param,
void MKLDNNPoolingCompute(const OpContext &ctx, const PoolingParam &param,
const NDArray &in_data, const OpReqType req,
const NDArray &out_data, const NDArray *workspace) {
- auto fwd = GetPoolingFwd(param, ctx.is_train, in_data, out_data);
+ auto &fwd = GetPoolingFwd(param, ctx.is_train, in_data, out_data);
fwd.SetNewMem(in_data, out_data, req, workspace);
fwd.Execute(out_data);
}
+MKLDNNPoolingBwd::MKLDNNPoolingBwd(
+ const pooling_backward::primitive_desc &pdesc, bool with_ws)
+ : with_workspace(with_ws), pd(pdesc) {}
+
+void MKLDNNPoolingBwd::SetNewMem(const mxnet::NDArray *workspace,
+ const mxnet::NDArray &out_grad,
+ const mkldnn::memory *diff_src_mem) {
+ if (bwd == nullptr) {
+ diff_dst.reset(
+ new mkldnn::memory(out_grad.GetMKLDNNData()->get_primitive_desc(),
+ out_grad.GetMKLDNNData()->get_data_handle()));
+ diff_src.reset(new mkldnn::memory(pd.diff_src_primitive_desc(),
+ diff_src_mem->get_data_handle()));
+ if (with_workspace) {
+ CHECK(workspace != nullptr);
+ ws.reset(
+ new mkldnn::memory(workspace->GetMKLDNNData()->get_primitive_desc(),
+ workspace->GetMKLDNNData()->get_data_handle()));
+ bwd.reset(
+ new pooling_backward(pd, *diff_dst, primitive::at(*ws), *diff_src));
+ } else {
+ bwd.reset(new pooling_backward(pd, *diff_dst, *diff_src));
+ }
+ } else {
+ diff_dst->set_data_handle(out_grad.GetMKLDNNData()->get_data_handle());
+ diff_src->set_data_handle(diff_src_mem->get_data_handle());
+ if (with_workspace) {
+ CHECK(workspace != nullptr);
+ ws->set_data_handle(workspace->GetMKLDNNData()->get_data_handle());
+ }
+ }
+}
+
+const mkldnn::pooling_backward &MKLDNNPoolingBwd::GetBwd() {
+ return *this->bwd;
+}
+
+MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam &param,
+ const NDArray &in_data,
+ const NDArray &in_grad,
+ const NDArray &out_grad) {
+#if DMLC_CXX11_THREAD_LOCAL
+ static thread_local
+ std::unordered_map<MKLDNNPoolingSignature,
+ MKLDNNPoolingBwd, OpHash> pooling_bwds;
+#else
+ static MX_THREAD_LOCAL
+ std::unordered_map<MKLDNNPoolingSignature,
+ MKLDNNPoolingBwd, OpHash> pooling_bwds;
+#endif
+
+ bool with_workspace = MKLDNNRequireWorkspace(param);
+ MKLDNNPoolingSignature key(param);
+ key.AddSign(in_data);
+ key.AddSign(in_grad);
+ key.AddSign(out_grad);
+
+ auto it = pooling_bwds.find(key);
+ if (it == pooling_bwds.end()) {
+ auto diff_dst_mem = out_grad.GetMKLDNNData();
+ auto input_mem = in_data.GetMKLDNNData();
+ mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc();
+ const mkldnn::memory::desc data_md = data_mpd.desc();
+ const memory::dims dims = {data_md.data.dims[0], data_md.data.dims[1],
+ static_cast<int>(out_grad.shape()[2]),
+ static_cast<int>(out_grad.shape()[3])};
+ const memory::desc out_md(
+ {dims}, static_cast<memory::data_type>(data_md.data.data_type),
+ static_cast<memory::format>(data_md.data.format));
+ auto fwd_pd = GetPoolingFwdPdesc(param, true, data_md, out_md);
+
+ const mkldnn::memory::desc diff_md =
+ diff_dst_mem->get_primitive_desc().desc();
+ const memory::dims dims1 = {diff_md.data.dims[0], diff_md.data.dims[1],
+ static_cast<int>(in_grad.shape()[2]),
+ static_cast<int>(in_grad.shape()[3])};
+ const memory::desc diff_in_md(
+ {dims1}, static_cast<memory::data_type>(diff_md.data.data_type),
+ static_cast<memory::format>(diff_md.data.format));
+ const mkldnn::engine cpu_engine = data_mpd.get_engine();
+ const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
+
+ int kernel_h_, kernel_w_;
+ if (param.global_pool) {
+ kernel_h_ = data_md.data.dims[2];
+ kernel_w_ = data_md.data.dims[3];
+ } else {
+ kernel_h_ = param.kernel[0];
+ kernel_w_ = param.kernel[1];
+ }
+
+ int pad_t_ = param.pad[0], pad_b_ = param.pad[0];
+ int pad_l_ = param.pad[1], pad_r_ = param.pad[1];
+ int stride_h_ = param.stride[0], stride_w_ = param.stride[1];
+ if (param.global_pool) {
+ pad_t_ = pad_b_ = pad_l_ = pad_r_ = 0;
+ stride_h_ = stride_w_ = 1;
+ }
+
+ const pooling_backward::desc desc(
+ alg, diff_in_md, diff_md, {stride_h_, stride_w_},
+ {kernel_h_, kernel_w_}, {pad_t_, pad_l_}, {pad_b_, pad_r_},
+ mkldnn::padding_kind::zero);
+ const auto pdesc = pooling_backward::primitive_desc(desc, cpu_engine, fwd_pd);
+ MKLDNNPoolingBwd bwd(pdesc, with_workspace);
+ auto ins_ret = pooling_bwds.insert(
+ std::pair<MKLDNNPoolingSignature, MKLDNNPoolingBwd>(key, bwd));
+ CHECK(ins_ret.second);
+ it = ins_ret.first;
+ }
+ return it->second;
+}
+
void MKLDNNPoolingGradCompute(const OpContext &ctx, const PoolingParam &param,
const NDArray &out_grad, const NDArray &in_data,
const NDArray *workspace, const OpReqType req,
@@ -267,68 +379,14 @@ void MKLDNNPoolingGradCompute(const OpContext &ctx, const PoolingParam &param,
if (req == kNullOp) {
return;
}
-
TmpMemMgr::Get()->Init(ctx.requested[0]);
- // mkldnn::memory
- auto diff_dst_mem = out_grad.GetMKLDNNData();
- auto input_mem = in_data.GetMKLDNNData();
- mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc();
- const mkldnn::memory::desc data_md = data_mpd.desc();
- const memory::dims dims = {data_md.data.dims[0], data_md.data.dims[1],
- static_cast<int>(out_grad.shape()[2]),
- static_cast<int>(out_grad.shape()[3])};
- const memory::desc out_md({dims},
- static_cast<memory::data_type>(data_md.data.data_type),
- static_cast<memory::format>(data_md.data.format));
- auto pdesc_fwd = GetPoolingFwd(param, ctx.is_train, data_md, out_md);
-
- const mkldnn::memory::desc diff_md = diff_dst_mem->get_primitive_desc().desc();
- const memory::dims dims1 = {diff_md.data.dims[0], diff_md.data.dims[1],
- static_cast<int>(in_grad.shape()[2]),
- static_cast<int>(in_grad.shape()[3])};
- const memory::desc diff_in_md(
- {dims1}, static_cast<memory::data_type>(diff_md.data.data_type),
- static_cast<memory::format>(diff_md.data.format));
- const mkldnn::engine cpu_engine = data_mpd.get_engine();
- const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
-
- int kernel_h_, kernel_w_;
- if (param.global_pool) {
- kernel_h_ = data_md.data.dims[2];
- kernel_w_ = data_md.data.dims[3];
- } else {
- kernel_h_ = param.kernel[0];
- kernel_w_ = param.kernel[1];
- }
-
- int pad_t_ = param.pad[0], pad_b_ = param.pad[0];
- int pad_l_ = param.pad[1], pad_r_ = param.pad[1];
- int stride_h_ = param.stride[0], stride_w_ = param.stride[1];
- if (param.global_pool) {
- pad_t_ = pad_b_ = pad_l_ = pad_r_ = 0;
- stride_h_ = stride_w_ = 1;
- }
-
- const pooling_backward::desc desc(alg, diff_in_md, diff_md,
- {stride_h_, stride_w_},
- {kernel_h_, kernel_w_},
- {pad_t_, pad_l_}, {pad_b_, pad_r_},
- mkldnn::padding_kind::zero);
- const pooling_backward::primitive_desc pdesc(desc, cpu_engine, pdesc_fwd);
+ auto &bwd = GetPoolingBwd(param, in_data, in_grad, out_grad);
auto diff_src_mem =
- CreateMKLDNNMem(in_grad, pdesc.diff_src_primitive_desc(), req);
-
- if (MKLDNNRequireWorkspace(param)) {
- CHECK(workspace != nullptr);
- auto workspace_mem = workspace->GetMKLDNNData();
- MKLDNNStream::Get()->RegisterPrim(
- pooling_backward(pdesc, *diff_dst_mem, primitive::at(*workspace_mem),
- *diff_src_mem.second));
- } else {
- MKLDNNStream::Get()->RegisterPrim(
- pooling_backward(pdesc, *diff_dst_mem, *diff_src_mem.second));
- }
+ CreateMKLDNNMem(in_grad, bwd.pd.diff_src_primitive_desc(), req);
+
+ bwd.SetNewMem(workspace, out_grad, diff_src_mem.second);
+ MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd());
CommitOutput(in_grad, diff_src_mem);
MKLDNNStream::Get()->Submit();
}