This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 930e140  Reintroduce next_impl in onednn deconvolution (#20663)
930e140 is described below

commit 930e14047ba42bf9519f308069e07eee7ef7a687
Author: Paweł Głomski <[email protected]>
AuthorDate: Fri Nov 19 17:39:54 2021 +0100

    Reintroduce next_impl in onednn deconvolution (#20663)
---
 src/operator/nn/dnnl/dnnl_deconvolution.cc | 58 ++++++++++++++++++------------
 1 file changed, 36 insertions(+), 22 deletions(-)

diff --git a/src/operator/nn/dnnl/dnnl_deconvolution.cc b/src/operator/nn/dnnl/dnnl_deconvolution.cc
index f4766a1..b853d1a 100644
--- a/src/operator/nn/dnnl/dnnl_deconvolution.cc
+++ b/src/operator/nn/dnnl/dnnl_deconvolution.cc
@@ -75,18 +75,23 @@ DNNLDeconvFwd& DNNLDeconvFwd::GetCached(const DeconvolutionParam& param, const T
 std::shared_ptr<deconv_fwd_pd_t> DNNLDeconvFwd::CreatePrimitiveDesc(const 
DeconvolutionParam& param,
                                                                     const 
Tensors& tensors) {
   DeconvDescCreator ddc(param, tensors.data, tensors.weights, tensors.bias, 
tensors.out);
+  auto fwd_desc = ddc.CreateFwdDesc();  // `fwd_desc` lifetime must be longer 
than `pd`
+                                        // when using next_impl
   const auto& engine          = CpuEngine::Get()->get_engine();
-  const auto pd               = 
std::make_shared<deconv_fwd_pd_t>(ddc.CreateFwdDesc(), engine);
+  const auto pd               = std::make_shared<deconv_fwd_pd_t>(fwd_desc, 
engine);
   const auto get_data_size    = [&pd]() { return pd->src_desc().get_size(); };
   const auto get_weights_size = [&pd]() { return 
pd->weights_desc().get_size(); };
   const auto get_out_size     = [&pd]() { return pd->dst_desc().get_size(); };
 
   while (!ddc.CheckImplSizeReq(get_data_size(), get_weights_size(), 
get_out_size())) {
-    // ImposePlainWherePadding fails when all memory descriptors already have 
plain formats
-    // imposed, meaning there is no implementation with plain formats
-    CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), 
get_out_size()))
-        << "No implementation of deconvolution forward propagation";
-    *pd = deconv_fwd_pd_t(ddc.CreateFwdDesc(), engine);
+    if (!pd->next_impl()) {
+      // ImposePlainWherePadding fails when all memory descriptors already 
have plain formats
+      // imposed, meaning there is no implementation with plain formats
+      CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), 
get_out_size()))
+          << "No implementation of deconvolution forward propagation";
+      fwd_desc = ddc.CreateFwdDesc();
+      *pd      = deconv_fwd_pd_t(fwd_desc, engine);
+    }
   }
   return pd;
 }
@@ -204,18 +209,23 @@ std::shared_ptr<deconv_bwd_data_pd_t> DNNLDeconvBwd::CreateDataPrimitiveDesc(
     const deconv_fwd_pd_t& fwd_pd) {
   DeconvDescCreator ddc(
       param, read_tensors.data, read_tensors.weights, nullptr, 
read_tensors.out_grad);
-  const auto& engine = CpuEngine::Get()->get_engine();
-  const auto pd = 
std::make_shared<deconv_bwd_data_pd_t>(ddc.CreateBwdDataDesc(), engine, fwd_pd);
+  auto bwd_d_desc = ddc.CreateBwdDataDesc();  // `bwd_d_desc` lifetime must be 
longer than `pd`
+                                              // when using next_impl
+  const auto& engine          = CpuEngine::Get()->get_engine();
+  const auto pd               = 
std::make_shared<deconv_bwd_data_pd_t>(bwd_d_desc, engine, fwd_pd);
   const auto get_data_size    = [&pd]() { return 
pd->diff_src_desc().get_size(); };
   const auto get_weights_size = [&pd]() { return 
pd->weights_desc().get_size(); };
   const auto get_out_size     = [&pd]() { return 
pd->diff_dst_desc().get_size(); };
 
   while (!ddc.CheckImplSizeReq(get_data_size(), get_weights_size(), 
get_out_size())) {
-    // ImposePlainWherePadding fails when all memory descriptors already have 
plain formats
-    // imposed, meaning there is no implementation with plain formats
-    CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), 
get_out_size()))
-        << "No implementation of deconvolution backward propagation";
-    *pd = deconv_bwd_data_pd_t(ddc.CreateBwdDataDesc(), engine, fwd_pd);
+    if (!pd->next_impl()) {
+      // ImposePlainWherePadding fails when all memory descriptors already 
have plain formats
+      // imposed, meaning there is no implementation with plain formats
+      CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), 
get_out_size()))
+          << "No implementation of deconvolution backward propagation";
+      bwd_d_desc = ddc.CreateBwdDataDesc();
+      *pd        = deconv_bwd_data_pd_t(bwd_d_desc, engine, fwd_pd);
+    }
   }
   return pd;
 }
@@ -226,19 +236,23 @@ std::shared_ptr<deconv_bwd_weights_pd_t> DNNLDeconvBwd::CreateWeightsPrimitiveDe
     const deconv_fwd_pd_t& fwd_pd) {
   DeconvDescCreator ddc(
       param, read_tensors.data, read_tensors.weights, read_tensors.bias, 
read_tensors.out_grad);
-  const auto& engine = CpuEngine::Get()->get_engine();
-  const auto pd =
-      std::make_shared<deconv_bwd_weights_pd_t>(ddc.CreateBwdWeightsDesc(), 
engine, fwd_pd);
-  const auto get_data_size    = [&pd]() { return pd->src_desc().get_size(); };
+  auto bwd_w_desc = ddc.CreateBwdWeightsDesc();  // `bwd_w_desc` lifetime must 
be longer than `pd`
+                                                 // when using next_impl
+  const auto& engine       = CpuEngine::Get()->get_engine();
+  const auto pd            = 
std::make_shared<deconv_bwd_weights_pd_t>(bwd_w_desc, engine, fwd_pd);
+  const auto get_data_size = [&pd]() { return pd->src_desc().get_size(); };
   const auto get_weights_size = [&pd]() { return 
pd->diff_weights_desc().get_size(); };
   const auto get_out_size     = [&pd]() { return 
pd->diff_dst_desc().get_size(); };
 
   while (!ddc.CheckImplSizeReq(get_data_size(), get_weights_size(), 
get_out_size())) {
-    // ImposePlainWherePadding fails when all memory descriptors already have 
plain formats
-    // imposed, meaning there is no implementation with plain formats
-    CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), 
get_out_size()))
-        << "No implementation of calculating deconvolution weights gradient";
-    *pd = deconv_bwd_weights_pd_t(ddc.CreateBwdWeightsDesc(), engine, fwd_pd);
+    if (!pd->next_impl()) {
+      // ImposePlainWherePadding fails when all memory descriptors already 
have plain formats
+      // imposed, meaning there is no implementation with plain formats
+      CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), 
get_out_size()))
+          << "No implementation of calculating deconvolution weights gradient";
+      bwd_w_desc = ddc.CreateBwdWeightsDesc();
+      *pd        = deconv_bwd_weights_pd_t(bwd_w_desc, engine, fwd_pd);
+    }
   }
   return pd;
 }

Reply via email to