You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by we...@apache.org on 2022/06/01 20:31:45 UTC
[incubator-mxnet] branch v1.9.x updated: Fix next_impl in deconvolution (#20750) (#21051)
This is an automated email from the ASF dual-hosted git repository.
weichu pushed a commit to branch v1.9.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.9.x by this push:
new ac61ad2b7f Fix next_impl in deconvolution (#20750) (#21051)
ac61ad2b7f is described below
commit ac61ad2b7f38dd35c3f4af3223774a3c427ed135
Author: waytrue17 <52...@users.noreply.github.com>
AuthorDate: Wed Jun 1 13:31:30 2022 -0700
Fix next_impl in deconvolution (#20750) (#21051)
Co-authored-by: Paweł Głomski <pa...@intel.com>
---
src/operator/nn/mkldnn/mkldnn_deconvolution.cc | 52 +++++++++++++++-----------
1 file changed, 31 insertions(+), 21 deletions(-)
diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
index 21608153bd..43423e792d 100644
--- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
@@ -76,9 +76,11 @@ MKLDNNDeconvFwd &MKLDNNDeconvFwd::GetCached(const DeconvolutionParam &param,
std::shared_ptr<deconv_fwd_pd_t> MKLDNNDeconvFwd::CreatePrimitiveDesc(
const DeconvolutionParam &param, const Tensors &tensors) {
DeconvDescCreator ddc(param, tensors.data, tensors.weights, tensors.bias, tensors.out);
- const auto &engine = CpuEngine::Get()->get_engine();
- const auto pd = std::make_shared<deconv_fwd_pd_t>(ddc.CreateFwdDesc(), engine);
- const auto get_data_size = [&pd]() { return pd->src_desc().get_size(); };
+ auto fwd_desc = ddc.CreateFwdDesc(); // `fwd_desc` lifetime must be longer than `pd`
+ // when using next_impl
+ const auto& engine = CpuEngine::Get()->get_engine();
+ const auto pd = std::make_shared<deconv_fwd_pd_t>(fwd_desc, engine);
+ const auto get_data_size = [&pd]() { return pd->src_desc().get_size(); };
const auto get_weights_size = [&pd]() { return pd->weights_desc().get_size(); };
const auto get_out_size = [&pd]() { return pd->dst_desc().get_size(); };
@@ -88,7 +90,8 @@ std::shared_ptr<deconv_fwd_pd_t> MKLDNNDeconvFwd::CreatePrimitiveDesc(
// imposed, meaning there is no implementation with plain formats
CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size()))
<< "No implementation of deconvolution forward propagation";
- *pd = deconv_fwd_pd_t(ddc.CreateFwdDesc(), engine);
+ fwd_desc = ddc.CreateFwdDesc();
+ *pd = deconv_fwd_pd_t(fwd_desc, engine);
}
}
return pd;
@@ -201,13 +204,16 @@ MKLDNNDeconvBwd &MKLDNNDeconvBwd::GetCached(const DeconvolutionParam &param,
}
std::shared_ptr<deconv_bwd_data_pd_t> MKLDNNDeconvBwd::CreateDataPrimitiveDesc(
- const DeconvolutionParam &param, const ReadTensors &read_tensors,
- const deconv_fwd_pd_t &fwd_pd) {
- DeconvDescCreator ddc(param, read_tensors.data, read_tensors.weights, nullptr,
- read_tensors.out_grad);
- const auto &engine = CpuEngine::Get()->get_engine();
- const auto pd = std::make_shared<deconv_bwd_data_pd_t>(ddc.CreateBwdDataDesc(), engine, fwd_pd);
- const auto get_data_size = [&pd]() { return pd->diff_src_desc().get_size(); };
+ const DeconvolutionParam& param,
+ const ReadTensors& read_tensors,
+ const deconv_fwd_pd_t& fwd_pd) {
+ DeconvDescCreator ddc(
+ param, read_tensors.data, read_tensors.weights, nullptr, read_tensors.out_grad);
+ auto bwd_d_desc = ddc.CreateBwdDataDesc(); // `bwd_d_desc` lifetime must be longer than `pd`
+ // when using next_impl
+ const auto& engine = CpuEngine::Get()->get_engine();
+ const auto pd = std::make_shared<deconv_bwd_data_pd_t>(bwd_d_desc, engine, fwd_pd);
+ const auto get_data_size = [&pd]() { return pd->diff_src_desc().get_size(); };
const auto get_weights_size = [&pd]() { return pd->weights_desc().get_size(); };
const auto get_out_size = [&pd]() { return pd->diff_dst_desc().get_size(); };
@@ -217,21 +223,24 @@ std::shared_ptr<deconv_bwd_data_pd_t> MKLDNNDeconvBwd::CreateDataPrimitiveDesc(
// imposed, meaning there is no implementation with plain formats
CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size()))
<< "No implementation of deconvolution backward propagation";
- *pd = deconv_bwd_data_pd_t(ddc.CreateBwdDataDesc(), engine, fwd_pd);
+ bwd_d_desc = ddc.CreateBwdDataDesc();
+ *pd = deconv_bwd_data_pd_t(bwd_d_desc, engine, fwd_pd);
}
}
return pd;
}
std::shared_ptr<deconv_bwd_weights_pd_t> MKLDNNDeconvBwd::CreateWeightsPrimitiveDesc(
- const DeconvolutionParam &param, const ReadTensors &read_tensors,
- const deconv_fwd_pd_t &fwd_pd) {
- DeconvDescCreator ddc(param, read_tensors.data, read_tensors.weights, read_tensors.bias,
- read_tensors.out_grad);
- const auto &engine = CpuEngine::Get()->get_engine();
- const auto pd =
- std::make_shared<deconv_bwd_weights_pd_t>(ddc.CreateBwdWeightsDesc(), engine, fwd_pd);
- const auto get_data_size = [&pd]() { return pd->src_desc().get_size(); };
+ const DeconvolutionParam& param,
+ const ReadTensors& read_tensors,
+ const deconv_fwd_pd_t& fwd_pd) {
+ DeconvDescCreator ddc(
+ param, read_tensors.data, read_tensors.weights, read_tensors.bias, read_tensors.out_grad);
+ auto bwd_w_desc = ddc.CreateBwdWeightsDesc(); // `bwd_w_desc` lifetime must be longer than `pd`
+ // when using next_impl
+ const auto& engine = CpuEngine::Get()->get_engine();
+ const auto pd = std::make_shared<deconv_bwd_weights_pd_t>(bwd_w_desc, engine, fwd_pd);
+ const auto get_data_size = [&pd]() { return pd->src_desc().get_size(); };
const auto get_weights_size = [&pd]() { return pd->diff_weights_desc().get_size(); };
const auto get_out_size = [&pd]() { return pd->diff_dst_desc().get_size(); };
@@ -241,7 +250,8 @@ std::shared_ptr<deconv_bwd_weights_pd_t> MKLDNNDeconvBwd::CreateWeightsPrimitive
// imposed, meaning there is no implementation with plain formats
CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size()))
<< "No implementation of calculating deconvolution weights gradient";
- *pd = deconv_bwd_weights_pd_t(ddc.CreateBwdWeightsDesc(), engine, fwd_pd);
+ bwd_w_desc = ddc.CreateBwdWeightsDesc();
+ *pd = deconv_bwd_weights_pd_t(bwd_w_desc, engine, fwd_pd);
}
}
return pd;