This is an automated email from the ASF dual-hosted git repository.
akarbown pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new 42210a1 Remove next_impl calls (#20589)
42210a1 is described below
commit 42210a12f23a6f6ffe0b22d66b4a43710bc222ce
Author: Paweł Głomski <[email protected]>
AuthorDate: Thu Sep 30 11:35:22 2021 +0200
Remove next_impl calls (#20589)
---
src/operator/nn/mkldnn/mkldnn_deconvolution.cc | 42 +++++++++++---------------
1 file changed, 18 insertions(+), 24 deletions(-)
diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
index 8428549..0776ee4 100644
--- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
@@ -83,14 +83,12 @@ std::shared_ptr<deconv_fwd_pd_t> MKLDNNDeconvFwd::CreatePrimitiveDesc(
const auto get_out_size = [&pd]() { return pd->dst_desc().get_size(); };
while (!ddc.CheckImplSizeReq(get_data_size(), get_weights_size(),
get_out_size())) {
- if (!pd->next_impl()) {
- // ImposePlainWherePadding fails when all memory descriptors already have
- // plain formats imposed, meaning there is no implementation with plain
- // formats
- CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(),
get_out_size()))
- << "No implementation of deconvolution forward propagation";
- *pd = deconv_fwd_pd_t(ddc.CreateFwdDesc(), engine);
- }
+ // ImposePlainWherePadding fails when all memory descriptors already have
+ // plain formats imposed, meaning there is no implementation with plain
+ // formats
+ CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(),
get_out_size()))
+ << "No implementation of deconvolution forward propagation";
+ *pd = deconv_fwd_pd_t(ddc.CreateFwdDesc(), engine);
}
return pd;
}
@@ -222,14 +220,12 @@ std::shared_ptr<deconv_bwd_data_pd_t> MKLDNNDeconvBwd::CreateDataPrimitiveDesc(
const auto get_out_size = [&pd]() { return
pd->diff_dst_desc().get_size(); };
while (!ddc.CheckImplSizeReq(get_data_size(), get_weights_size(),
get_out_size())) {
- if (!pd->next_impl()) {
- // ImposePlainWherePadding fails when all memory descriptors already have
- // plain formats imposed, meaning there is no implementation with plain
- // formats
- CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(),
get_out_size()))
- << "No implementation of deconvolution backward propagation";
- *pd = deconv_bwd_data_pd_t(ddc.CreateBwdDataDesc(), engine, fwd_pd);
- }
+ // ImposePlainWherePadding fails when all memory descriptors already have
+ // plain formats imposed, meaning there is no implementation with plain
+ // formats
+ CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(),
get_out_size()))
+ << "No implementation of deconvolution backward propagation";
+ *pd = deconv_bwd_data_pd_t(ddc.CreateBwdDataDesc(), engine, fwd_pd);
}
return pd;
}
@@ -248,14 +244,12 @@ std::shared_ptr<deconv_bwd_weights_pd_t> MKLDNNDeconvBwd::CreateWeightsPrimitive
const auto get_out_size = [&pd]() { return
pd->diff_dst_desc().get_size(); };
while (!ddc.CheckImplSizeReq(get_data_size(), get_weights_size(),
get_out_size())) {
- if (!pd->next_impl()) {
- // ImposePlainWherePadding fails when all memory descriptors already have
- // plain formats imposed, meaning there is no implementation with plain
- // formats
- CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(),
get_out_size()))
- << "No implementation of calculating deconvolution weights gradient";
- *pd = deconv_bwd_weights_pd_t(ddc.CreateBwdWeightsDesc(), engine,
fwd_pd);
- }
+ // ImposePlainWherePadding fails when all memory descriptors already have
+ // plain formats imposed, meaning there is no implementation with plain
+ // formats
+ CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(),
get_out_size()))
+ << "No implementation of calculating deconvolution weights gradient";
+ *pd = deconv_bwd_weights_pd_t(ddc.CreateBwdWeightsDesc(), engine, fwd_pd);
}
return pd;
}