You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by we...@apache.org on 2022/06/01 20:31:45 UTC

[incubator-mxnet] branch v1.9.x updated: Fix next_impl in deconvolution (#20750) (#21051)

This is an automated email from the ASF dual-hosted git repository.

weichu pushed a commit to branch v1.9.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.9.x by this push:
     new ac61ad2b7f Fix next_impl in deconvolution (#20750) (#21051)
ac61ad2b7f is described below

commit ac61ad2b7f38dd35c3f4af3223774a3c427ed135
Author: waytrue17 <52...@users.noreply.github.com>
AuthorDate: Wed Jun 1 13:31:30 2022 -0700

    Fix next_impl in deconvolution (#20750) (#21051)
    
    Co-authored-by: Paweł Głomski <pa...@intel.com>
---
 src/operator/nn/mkldnn/mkldnn_deconvolution.cc | 52 +++++++++++++++-----------
 1 file changed, 31 insertions(+), 21 deletions(-)

diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
index 21608153bd..43423e792d 100644
--- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
@@ -76,9 +76,11 @@ MKLDNNDeconvFwd &MKLDNNDeconvFwd::GetCached(const DeconvolutionParam &param,
 std::shared_ptr<deconv_fwd_pd_t> MKLDNNDeconvFwd::CreatePrimitiveDesc(
     const DeconvolutionParam &param, const Tensors &tensors) {
   DeconvDescCreator ddc(param, tensors.data, tensors.weights, tensors.bias, tensors.out);
-  const auto &engine = CpuEngine::Get()->get_engine();
-  const auto pd = std::make_shared<deconv_fwd_pd_t>(ddc.CreateFwdDesc(), engine);
-  const auto get_data_size = [&pd]() { return pd->src_desc().get_size(); };
+  auto fwd_desc = ddc.CreateFwdDesc();  // `fwd_desc` lifetime must be longer than `pd`
+                                        // when using next_impl
+  const auto& engine          = CpuEngine::Get()->get_engine();
+  const auto pd               = std::make_shared<deconv_fwd_pd_t>(fwd_desc, engine);
+  const auto get_data_size    = [&pd]() { return pd->src_desc().get_size(); };
   const auto get_weights_size = [&pd]() { return pd->weights_desc().get_size(); };
   const auto get_out_size = [&pd]() { return pd->dst_desc().get_size(); };
 
@@ -88,7 +90,8 @@ std::shared_ptr<deconv_fwd_pd_t> MKLDNNDeconvFwd::CreatePrimitiveDesc(
       // imposed, meaning there is no implementation with plain formats
       CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size()))
           << "No implementation of deconvolution forward propagation";
-      *pd = deconv_fwd_pd_t(ddc.CreateFwdDesc(), engine);
+      fwd_desc = ddc.CreateFwdDesc();
+      *pd      = deconv_fwd_pd_t(fwd_desc, engine);
     }
   }
   return pd;
@@ -201,13 +204,16 @@ MKLDNNDeconvBwd &MKLDNNDeconvBwd::GetCached(const DeconvolutionParam &param,
 }
 
 std::shared_ptr<deconv_bwd_data_pd_t> MKLDNNDeconvBwd::CreateDataPrimitiveDesc(
-    const DeconvolutionParam &param, const ReadTensors &read_tensors,
-    const deconv_fwd_pd_t &fwd_pd) {
-  DeconvDescCreator ddc(param, read_tensors.data, read_tensors.weights, nullptr,
-                        read_tensors.out_grad);
-  const auto &engine = CpuEngine::Get()->get_engine();
-  const auto pd = std::make_shared<deconv_bwd_data_pd_t>(ddc.CreateBwdDataDesc(), engine, fwd_pd);
-  const auto get_data_size = [&pd]() { return pd->diff_src_desc().get_size(); };
+    const DeconvolutionParam& param,
+    const ReadTensors& read_tensors,
+    const deconv_fwd_pd_t& fwd_pd) {
+  DeconvDescCreator ddc(
+      param, read_tensors.data, read_tensors.weights, nullptr, read_tensors.out_grad);
+  auto bwd_d_desc = ddc.CreateBwdDataDesc();  // `bwd_d_desc` lifetime must be longer than `pd`
+                                              // when using next_impl
+  const auto& engine = CpuEngine::Get()->get_engine();
+  const auto pd = std::make_shared<deconv_bwd_data_pd_t>(bwd_d_desc, engine, fwd_pd);
+  const auto get_data_size    = [&pd]() { return pd->diff_src_desc().get_size(); };
   const auto get_weights_size = [&pd]() { return pd->weights_desc().get_size(); };
   const auto get_out_size = [&pd]() { return pd->diff_dst_desc().get_size(); };
 
@@ -217,21 +223,24 @@ std::shared_ptr<deconv_bwd_data_pd_t> MKLDNNDeconvBwd::CreateDataPrimitiveDesc(
       // imposed, meaning there is no implementation with plain formats
       CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size()))
           << "No implementation of deconvolution backward propagation";
-      *pd = deconv_bwd_data_pd_t(ddc.CreateBwdDataDesc(), engine, fwd_pd);
+      bwd_d_desc = ddc.CreateBwdDataDesc();
+      *pd = deconv_bwd_data_pd_t(bwd_d_desc, engine, fwd_pd);
     }
   }
   return pd;
 }
 
 std::shared_ptr<deconv_bwd_weights_pd_t> MKLDNNDeconvBwd::CreateWeightsPrimitiveDesc(
-    const DeconvolutionParam &param, const ReadTensors &read_tensors,
-    const deconv_fwd_pd_t &fwd_pd) {
-  DeconvDescCreator ddc(param, read_tensors.data, read_tensors.weights, read_tensors.bias,
-                        read_tensors.out_grad);
-  const auto &engine = CpuEngine::Get()->get_engine();
-  const auto pd =
-      std::make_shared<deconv_bwd_weights_pd_t>(ddc.CreateBwdWeightsDesc(), engine, fwd_pd);
-  const auto get_data_size = [&pd]() { return pd->src_desc().get_size(); };
+    const DeconvolutionParam& param,
+    const ReadTensors& read_tensors,
+    const deconv_fwd_pd_t& fwd_pd) {
+  DeconvDescCreator ddc(
+      param, read_tensors.data, read_tensors.weights, read_tensors.bias, read_tensors.out_grad);
+  auto bwd_w_desc = ddc.CreateBwdWeightsDesc();  // `bwd_w_desc` lifetime must be longer than `pd`
+                                                 // when using next_impl
+  const auto& engine = CpuEngine::Get()->get_engine();
+  const auto pd = std::make_shared<deconv_bwd_weights_pd_t>(bwd_w_desc, engine, fwd_pd);
+  const auto get_data_size    = [&pd]() { return pd->src_desc().get_size(); };
   const auto get_weights_size = [&pd]() { return pd->diff_weights_desc().get_size(); };
   const auto get_out_size = [&pd]() { return pd->diff_dst_desc().get_size(); };
 
@@ -241,7 +250,8 @@ std::shared_ptr<deconv_bwd_weights_pd_t> MKLDNNDeconvBwd::CreateWeightsPrimitive
       // imposed, meaning there is no implementation with plain formats
       CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size()))
           << "No implementation of calculating deconvolution weights gradient";
-      *pd = deconv_bwd_weights_pd_t(ddc.CreateBwdWeightsDesc(), engine, fwd_pd);
+      bwd_w_desc = ddc.CreateBwdWeightsDesc();
+      *pd        = deconv_bwd_weights_pd_t(bwd_w_desc, engine, fwd_pd);
     }
   }
   return pd;