This is an automated email from the ASF dual-hosted git repository.

akarbown pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 42210a1  Remove next_impl calls (#20589)
42210a1 is described below

commit 42210a12f23a6f6ffe0b22d66b4a43710bc222ce
Author: Paweł Głomski <[email protected]>
AuthorDate: Thu Sep 30 11:35:22 2021 +0200

    Remove next_impl calls (#20589)
---
 src/operator/nn/mkldnn/mkldnn_deconvolution.cc | 42 +++++++++++---------------
 1 file changed, 18 insertions(+), 24 deletions(-)

diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
index 8428549..0776ee4 100644
--- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
@@ -83,14 +83,12 @@ std::shared_ptr<deconv_fwd_pd_t> MKLDNNDeconvFwd::CreatePrimitiveDesc(
   const auto get_out_size     = [&pd]() { return pd->dst_desc().get_size(); };
 
   while (!ddc.CheckImplSizeReq(get_data_size(), get_weights_size(), 
get_out_size())) {
-    if (!pd->next_impl()) {
-      // ImposePlainWherePadding fails when all memory descriptors already have
-      // plain formats imposed, meaning there is no implementation with plain
-      // formats
-      CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), 
get_out_size()))
-          << "No implementation of deconvolution forward propagation";
-      *pd = deconv_fwd_pd_t(ddc.CreateFwdDesc(), engine);
-    }
+    // ImposePlainWherePadding fails when all memory descriptors already have
+    // plain formats imposed, meaning there is no implementation with plain
+    // formats
+    CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), 
get_out_size()))
+        << "No implementation of deconvolution forward propagation";
+    *pd = deconv_fwd_pd_t(ddc.CreateFwdDesc(), engine);
   }
   return pd;
 }
@@ -222,14 +220,12 @@ std::shared_ptr<deconv_bwd_data_pd_t> MKLDNNDeconvBwd::CreateDataPrimitiveDesc(
   const auto get_out_size     = [&pd]() { return 
pd->diff_dst_desc().get_size(); };
 
   while (!ddc.CheckImplSizeReq(get_data_size(), get_weights_size(), 
get_out_size())) {
-    if (!pd->next_impl()) {
-      // ImposePlainWherePadding fails when all memory descriptors already have
-      // plain formats imposed, meaning there is no implementation with plain
-      // formats
-      CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), 
get_out_size()))
-          << "No implementation of deconvolution backward propagation";
-      *pd = deconv_bwd_data_pd_t(ddc.CreateBwdDataDesc(), engine, fwd_pd);
-    }
+    // ImposePlainWherePadding fails when all memory descriptors already have
+    // plain formats imposed, meaning there is no implementation with plain
+    // formats
+    CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), 
get_out_size()))
+        << "No implementation of deconvolution backward propagation";
+    *pd = deconv_bwd_data_pd_t(ddc.CreateBwdDataDesc(), engine, fwd_pd);
   }
   return pd;
 }
@@ -248,14 +244,12 @@ std::shared_ptr<deconv_bwd_weights_pd_t> MKLDNNDeconvBwd::CreateWeightsPrimitive
   const auto get_out_size     = [&pd]() { return 
pd->diff_dst_desc().get_size(); };
 
   while (!ddc.CheckImplSizeReq(get_data_size(), get_weights_size(), 
get_out_size())) {
-    if (!pd->next_impl()) {
-      // ImposePlainWherePadding fails when all memory descriptors already have
-      // plain formats imposed, meaning there is no implementation with plain
-      // formats
-      CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), 
get_out_size()))
-          << "No implementation of calculating deconvolution weights gradient";
-      *pd = deconv_bwd_weights_pd_t(ddc.CreateBwdWeightsDesc(), engine, 
fwd_pd);
-    }
+    // ImposePlainWherePadding fails when all memory descriptors already have
+    // plain formats imposed, meaning there is no implementation with plain
+    // formats
+    CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), 
get_out_size()))
+        << "No implementation of calculating deconvolution weights gradient";
+    *pd = deconv_bwd_weights_pd_t(ddc.CreateBwdWeightsDesc(), engine, fwd_pd);
   }
   return pd;
 }

Reply via email to