This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new f88b03e  Fix next_impl in deconvolution (#20750)
f88b03e is described below

commit f88b03e49c0d807c61f7d6ff5f623284fcbb30ce
Author: Paweł Głomski <[email protected]>
AuthorDate: Tue Dec 14 08:06:44 2021 +0100

    Fix next_impl in deconvolution (#20750)
---
 src/operator/nn/mkldnn/mkldnn_deconvolution.cc | 55 +++++++++++++++-----------
 1 file changed, 33 insertions(+), 22 deletions(-)

diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
index 0776ee4..9783c4b 100644
--- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
@@ -76,19 +76,23 @@ std::shared_ptr<deconv_fwd_pd_t> MKLDNNDeconvFwd::CreatePrimitiveDesc(
     const DeconvolutionParam& param,
     const Tensors& tensors) {
   DeconvDescCreator ddc(param, tensors.data, tensors.weights, tensors.bias, 
tensors.out);
+  auto fwd_desc = ddc.CreateFwdDesc();  // `fwd_desc` lifetime must be longer 
than `pd`
+                                        // when using next_impl
   const auto& engine          = CpuEngine::Get()->get_engine();
-  const auto pd               = 
std::make_shared<deconv_fwd_pd_t>(ddc.CreateFwdDesc(), engine);
+  const auto pd               = std::make_shared<deconv_fwd_pd_t>(fwd_desc, 
engine);
   const auto get_data_size    = [&pd]() { return pd->src_desc().get_size(); };
   const auto get_weights_size = [&pd]() { return 
pd->weights_desc().get_size(); };
   const auto get_out_size     = [&pd]() { return pd->dst_desc().get_size(); };
 
   while (!ddc.CheckImplSizeReq(get_data_size(), get_weights_size(), 
get_out_size())) {
-    // ImposePlainWherePadding fails when all memory descriptors already have
-    // plain formats imposed, meaning there is no implementation with plain
-    // formats
-    CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), 
get_out_size()))
-        << "No implementation of deconvolution forward propagation";
-    *pd = deconv_fwd_pd_t(ddc.CreateFwdDesc(), engine);
+    if (!pd->next_impl()) {
+      // ImposePlainWherePadding fails when all memory descriptors already 
have plain formats
+      // imposed, meaning there is no implementation with plain formats
+      CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), 
get_out_size()))
+          << "No implementation of deconvolution forward propagation";
+      fwd_desc = ddc.CreateFwdDesc();
+      *pd      = deconv_fwd_pd_t(fwd_desc, engine);
+    }
   }
   return pd;
 }
@@ -213,19 +217,23 @@ std::shared_ptr<deconv_bwd_data_pd_t> MKLDNNDeconvBwd::CreateDataPrimitiveDesc(
     const deconv_fwd_pd_t& fwd_pd) {
   DeconvDescCreator ddc(
       param, read_tensors.data, read_tensors.weights, nullptr, 
read_tensors.out_grad);
+  auto bwd_d_desc = ddc.CreateBwdDataDesc();  // `bwd_d_desc` lifetime must be 
longer than `pd`
+                                              // when using next_impl
   const auto& engine = CpuEngine::Get()->get_engine();
-  const auto pd = 
std::make_shared<deconv_bwd_data_pd_t>(ddc.CreateBwdDataDesc(), engine, fwd_pd);
+  const auto pd = std::make_shared<deconv_bwd_data_pd_t>(bwd_d_desc, engine, 
fwd_pd);
   const auto get_data_size    = [&pd]() { return 
pd->diff_src_desc().get_size(); };
   const auto get_weights_size = [&pd]() { return 
pd->weights_desc().get_size(); };
   const auto get_out_size     = [&pd]() { return 
pd->diff_dst_desc().get_size(); };
 
   while (!ddc.CheckImplSizeReq(get_data_size(), get_weights_size(), 
get_out_size())) {
-    // ImposePlainWherePadding fails when all memory descriptors already have
-    // plain formats imposed, meaning there is no implementation with plain
-    // formats
-    CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), 
get_out_size()))
-        << "No implementation of deconvolution backward propagation";
-    *pd = deconv_bwd_data_pd_t(ddc.CreateBwdDataDesc(), engine, fwd_pd);
+    if (!pd->next_impl()) {
+      // ImposePlainWherePadding fails when all memory descriptors already 
have plain formats
+      // imposed, meaning there is no implementation with plain formats
+      CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), 
get_out_size()))
+          << "No implementation of deconvolution backward propagation";
+      bwd_d_desc = ddc.CreateBwdDataDesc();
+      *pd = deconv_bwd_data_pd_t(bwd_d_desc, engine, fwd_pd);
+    }
   }
   return pd;
 }
@@ -236,20 +244,23 @@ std::shared_ptr<deconv_bwd_weights_pd_t> MKLDNNDeconvBwd::CreateWeightsPrimitive
     const deconv_fwd_pd_t& fwd_pd) {
   DeconvDescCreator ddc(
       param, read_tensors.data, read_tensors.weights, read_tensors.bias, 
read_tensors.out_grad);
+  auto bwd_w_desc = ddc.CreateBwdWeightsDesc();  // `bwd_w_desc` lifetime must 
be longer than `pd`
+                                                 // when using next_impl
   const auto& engine = CpuEngine::Get()->get_engine();
-  const auto pd =
-      std::make_shared<deconv_bwd_weights_pd_t>(ddc.CreateBwdWeightsDesc(), 
engine, fwd_pd);
+  const auto pd = std::make_shared<deconv_bwd_weights_pd_t>(bwd_w_desc, 
engine, fwd_pd);
   const auto get_data_size    = [&pd]() { return pd->src_desc().get_size(); };
   const auto get_weights_size = [&pd]() { return 
pd->diff_weights_desc().get_size(); };
   const auto get_out_size     = [&pd]() { return 
pd->diff_dst_desc().get_size(); };
 
   while (!ddc.CheckImplSizeReq(get_data_size(), get_weights_size(), 
get_out_size())) {
-    // ImposePlainWherePadding fails when all memory descriptors already have
-    // plain formats imposed, meaning there is no implementation with plain
-    // formats
-    CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), 
get_out_size()))
-        << "No implementation of calculating deconvolution weights gradient";
-    *pd = deconv_bwd_weights_pd_t(ddc.CreateBwdWeightsDesc(), engine, fwd_pd);
+    if (!pd->next_impl()) {
+      // ImposePlainWherePadding fails when all memory descriptors already 
have plain formats
+      // imposed, meaning there is no implementation with plain formats
+      CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), 
get_out_size()))
+          << "No implementation of calculating deconvolution weights gradient";
+      bwd_w_desc = ddc.CreateBwdWeightsDesc();
+      *pd        = deconv_bwd_weights_pd_t(bwd_w_desc, engine, fwd_pd);
+    }
   }
   return pd;
 }

Reply via email to