You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text copy).
Posted to commits@mxnet.apache.org by we...@apache.org on 2022/05/11 22:57:39 UTC

[incubator-mxnet] 01/01: Fix next_impl in deconvolution (#20750)

This is an automated email from the ASF dual-hosted git repository.

weichu pushed a commit to branch v1.9.1-test
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 681a130503c52e0b415670d2a7bda5d5b713956b
Author: Paweł Głomski <pa...@intel.com>
AuthorDate: Tue Dec 14 08:06:44 2021 +0100

    Fix next_impl in deconvolution (#20750)
---
 src/operator/nn/mkldnn/mkldnn_deconvolution.cc | 52 +++++++++++++++-----------
 1 file changed, 31 insertions(+), 21 deletions(-)

diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
index 21608153bd..43423e792d 100644
--- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
@@ -76,9 +76,11 @@ MKLDNNDeconvFwd &MKLDNNDeconvFwd::GetCached(const DeconvolutionParam &param,
 std::shared_ptr<deconv_fwd_pd_t> MKLDNNDeconvFwd::CreatePrimitiveDesc(
     const DeconvolutionParam &param, const Tensors &tensors) {
   DeconvDescCreator ddc(param, tensors.data, tensors.weights, tensors.bias, tensors.out);
-  const auto &engine = CpuEngine::Get()->get_engine();
-  const auto pd = std::make_shared<deconv_fwd_pd_t>(ddc.CreateFwdDesc(), engine);
-  const auto get_data_size = [&pd]() { return pd->src_desc().get_size(); };
+  auto fwd_desc = ddc.CreateFwdDesc();  // `fwd_desc` lifetime must be longer than `pd`
+                                        // when using next_impl
+  const auto& engine          = CpuEngine::Get()->get_engine();
+  const auto pd               = std::make_shared<deconv_fwd_pd_t>(fwd_desc, engine);
+  const auto get_data_size    = [&pd]() { return pd->src_desc().get_size(); };
   const auto get_weights_size = [&pd]() { return pd->weights_desc().get_size(); };
   const auto get_out_size = [&pd]() { return pd->dst_desc().get_size(); };
 
@@ -88,7 +90,8 @@ std::shared_ptr<deconv_fwd_pd_t> MKLDNNDeconvFwd::CreatePrimitiveDesc(
       // imposed, meaning there is no implementation with plain formats
       CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size()))
           << "No implementation of deconvolution forward propagation";
-      *pd = deconv_fwd_pd_t(ddc.CreateFwdDesc(), engine);
+      fwd_desc = ddc.CreateFwdDesc();
+      *pd      = deconv_fwd_pd_t(fwd_desc, engine);
     }
   }
   return pd;
@@ -201,13 +204,16 @@ MKLDNNDeconvBwd &MKLDNNDeconvBwd::GetCached(const DeconvolutionParam &param,
 }
 
 std::shared_ptr<deconv_bwd_data_pd_t> MKLDNNDeconvBwd::CreateDataPrimitiveDesc(
-    const DeconvolutionParam &param, const ReadTensors &read_tensors,
-    const deconv_fwd_pd_t &fwd_pd) {
-  DeconvDescCreator ddc(param, read_tensors.data, read_tensors.weights, nullptr,
-                        read_tensors.out_grad);
-  const auto &engine = CpuEngine::Get()->get_engine();
-  const auto pd = std::make_shared<deconv_bwd_data_pd_t>(ddc.CreateBwdDataDesc(), engine, fwd_pd);
-  const auto get_data_size = [&pd]() { return pd->diff_src_desc().get_size(); };
+    const DeconvolutionParam& param,
+    const ReadTensors& read_tensors,
+    const deconv_fwd_pd_t& fwd_pd) {
+  DeconvDescCreator ddc(
+      param, read_tensors.data, read_tensors.weights, nullptr, read_tensors.out_grad);
+  auto bwd_d_desc = ddc.CreateBwdDataDesc();  // `bwd_d_desc` lifetime must be longer than `pd`
+                                              // when using next_impl
+  const auto& engine = CpuEngine::Get()->get_engine();
+  const auto pd = std::make_shared<deconv_bwd_data_pd_t>(bwd_d_desc, engine, fwd_pd);
+  const auto get_data_size    = [&pd]() { return pd->diff_src_desc().get_size(); };
   const auto get_weights_size = [&pd]() { return pd->weights_desc().get_size(); };
   const auto get_out_size = [&pd]() { return pd->diff_dst_desc().get_size(); };
 
@@ -217,21 +223,24 @@ std::shared_ptr<deconv_bwd_data_pd_t> MKLDNNDeconvBwd::CreateDataPrimitiveDesc(
       // imposed, meaning there is no implementation with plain formats
       CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size()))
           << "No implementation of deconvolution backward propagation";
-      *pd = deconv_bwd_data_pd_t(ddc.CreateBwdDataDesc(), engine, fwd_pd);
+      bwd_d_desc = ddc.CreateBwdDataDesc();
+      *pd = deconv_bwd_data_pd_t(bwd_d_desc, engine, fwd_pd);
     }
   }
   return pd;
 }
 
 std::shared_ptr<deconv_bwd_weights_pd_t> MKLDNNDeconvBwd::CreateWeightsPrimitiveDesc(
-    const DeconvolutionParam &param, const ReadTensors &read_tensors,
-    const deconv_fwd_pd_t &fwd_pd) {
-  DeconvDescCreator ddc(param, read_tensors.data, read_tensors.weights, read_tensors.bias,
-                        read_tensors.out_grad);
-  const auto &engine = CpuEngine::Get()->get_engine();
-  const auto pd =
-      std::make_shared<deconv_bwd_weights_pd_t>(ddc.CreateBwdWeightsDesc(), engine, fwd_pd);
-  const auto get_data_size = [&pd]() { return pd->src_desc().get_size(); };
+    const DeconvolutionParam& param,
+    const ReadTensors& read_tensors,
+    const deconv_fwd_pd_t& fwd_pd) {
+  DeconvDescCreator ddc(
+      param, read_tensors.data, read_tensors.weights, read_tensors.bias, read_tensors.out_grad);
+  auto bwd_w_desc = ddc.CreateBwdWeightsDesc();  // `bwd_w_desc` lifetime must be longer than `pd`
+                                                 // when using next_impl
+  const auto& engine = CpuEngine::Get()->get_engine();
+  const auto pd = std::make_shared<deconv_bwd_weights_pd_t>(bwd_w_desc, engine, fwd_pd);
+  const auto get_data_size    = [&pd]() { return pd->src_desc().get_size(); };
   const auto get_weights_size = [&pd]() { return pd->diff_weights_desc().get_size(); };
   const auto get_out_size = [&pd]() { return pd->diff_dst_desc().get_size(); };
 
@@ -241,7 +250,8 @@ std::shared_ptr<deconv_bwd_weights_pd_t> MKLDNNDeconvBwd::CreateWeightsPrimitive
       // imposed, meaning there is no implementation with plain formats
       CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size()))
           << "No implementation of calculating deconvolution weights gradient";
-      *pd = deconv_bwd_weights_pd_t(ddc.CreateBwdWeightsDesc(), engine, fwd_pd);
+      bwd_w_desc = ddc.CreateBwdWeightsDesc();
+      *pd        = deconv_bwd_weights_pd_t(bwd_w_desc, engine, fwd_pd);
     }
   }
   return pd;