This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new f88b03e Fix next_impl in deconvolution (#20750)
f88b03e is described below
commit f88b03e49c0d807c61f7d6ff5f623284fcbb30ce
Author: Paweł Głomski <[email protected]>
AuthorDate: Tue Dec 14 08:06:44 2021 +0100
Fix next_impl in deconvolution (#20750)
---
src/operator/nn/mkldnn/mkldnn_deconvolution.cc | 55 +++++++++++++++-----------
1 file changed, 33 insertions(+), 22 deletions(-)
diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
index 0776ee4..9783c4b 100644
--- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
@@ -76,19 +76,23 @@ std::shared_ptr<deconv_fwd_pd_t> MKLDNNDeconvFwd::CreatePrimitiveDesc(
     const DeconvolutionParam& param,
     const Tensors& tensors) {
   DeconvDescCreator ddc(param, tensors.data, tensors.weights, tensors.bias, tensors.out);
+  auto fwd_desc = ddc.CreateFwdDesc();  // `fwd_desc` lifetime must be longer than `pd`
+                                        // when using next_impl
   const auto& engine = CpuEngine::Get()->get_engine();
-  const auto pd = std::make_shared<deconv_fwd_pd_t>(ddc.CreateFwdDesc(), engine);
+  const auto pd = std::make_shared<deconv_fwd_pd_t>(fwd_desc, engine);
   const auto get_data_size = [&pd]() { return pd->src_desc().get_size(); };
   const auto get_weights_size = [&pd]() { return pd->weights_desc().get_size(); };
   const auto get_out_size = [&pd]() { return pd->dst_desc().get_size(); };
   while (!ddc.CheckImplSizeReq(get_data_size(), get_weights_size(), get_out_size())) {
-    // ImposePlainWherePadding fails when all memory descriptors already have
-    // plain formats imposed, meaning there is no implementation with plain
-    // formats
-    CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size()))
-        << "No implementation of deconvolution forward propagation";
-    *pd = deconv_fwd_pd_t(ddc.CreateFwdDesc(), engine);
+    if (!pd->next_impl()) {
+      // ImposePlainWherePadding fails when all memory descriptors already have plain formats
+      // imposed, meaning there is no implementation with plain formats
+      CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size()))
+          << "No implementation of deconvolution forward propagation";
+      fwd_desc = ddc.CreateFwdDesc();
+      *pd = deconv_fwd_pd_t(fwd_desc, engine);
+    }
   }
   return pd;
 }
@@ -213,19 +217,23 @@ std::shared_ptr<deconv_bwd_data_pd_t> MKLDNNDeconvBwd::CreateDataPrimitiveDesc(
     const deconv_fwd_pd_t& fwd_pd) {
   DeconvDescCreator ddc(
       param, read_tensors.data, read_tensors.weights, nullptr, read_tensors.out_grad);
+  auto bwd_d_desc = ddc.CreateBwdDataDesc();  // `bwd_d_desc` lifetime must be longer than `pd`
+                                              // when using next_impl
   const auto& engine = CpuEngine::Get()->get_engine();
-  const auto pd = std::make_shared<deconv_bwd_data_pd_t>(ddc.CreateBwdDataDesc(), engine, fwd_pd);
+  const auto pd = std::make_shared<deconv_bwd_data_pd_t>(bwd_d_desc, engine, fwd_pd);
   const auto get_data_size = [&pd]() { return pd->diff_src_desc().get_size(); };
   const auto get_weights_size = [&pd]() { return pd->weights_desc().get_size(); };
   const auto get_out_size = [&pd]() { return pd->diff_dst_desc().get_size(); };
   while (!ddc.CheckImplSizeReq(get_data_size(), get_weights_size(), get_out_size())) {
-    // ImposePlainWherePadding fails when all memory descriptors already have
-    // plain formats imposed, meaning there is no implementation with plain
-    // formats
-    CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size()))
-        << "No implementation of deconvolution backward propagation";
-    *pd = deconv_bwd_data_pd_t(ddc.CreateBwdDataDesc(), engine, fwd_pd);
+    if (!pd->next_impl()) {
+      // ImposePlainWherePadding fails when all memory descriptors already have plain formats
+      // imposed, meaning there is no implementation with plain formats
+      CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size()))
+          << "No implementation of deconvolution backward propagation";
+      bwd_d_desc = ddc.CreateBwdDataDesc();
+      *pd = deconv_bwd_data_pd_t(bwd_d_desc, engine, fwd_pd);
+    }
   }
   return pd;
 }
@@ -236,20 +244,23 @@ std::shared_ptr<deconv_bwd_weights_pd_t> MKLDNNDeconvBwd::CreateWeightsPrimitive
     const deconv_fwd_pd_t& fwd_pd) {
   DeconvDescCreator ddc(
       param, read_tensors.data, read_tensors.weights, read_tensors.bias, read_tensors.out_grad);
+  auto bwd_w_desc = ddc.CreateBwdWeightsDesc();  // `bwd_w_desc` lifetime must be longer than `pd`
+                                                 // when using next_impl
   const auto& engine = CpuEngine::Get()->get_engine();
-  const auto pd =
-      std::make_shared<deconv_bwd_weights_pd_t>(ddc.CreateBwdWeightsDesc(), engine, fwd_pd);
+  const auto pd = std::make_shared<deconv_bwd_weights_pd_t>(bwd_w_desc, engine, fwd_pd);
   const auto get_data_size = [&pd]() { return pd->src_desc().get_size(); };
   const auto get_weights_size = [&pd]() { return pd->diff_weights_desc().get_size(); };
   const auto get_out_size = [&pd]() { return pd->diff_dst_desc().get_size(); };
   while (!ddc.CheckImplSizeReq(get_data_size(), get_weights_size(), get_out_size())) {
-    // ImposePlainWherePadding fails when all memory descriptors already have
-    // plain formats imposed, meaning there is no implementation with plain
-    // formats
-    CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size()))
-        << "No implementation of calculating deconvolution weights gradient";
-    *pd = deconv_bwd_weights_pd_t(ddc.CreateBwdWeightsDesc(), engine, fwd_pd);
+    if (!pd->next_impl()) {
+      // ImposePlainWherePadding fails when all memory descriptors already have plain formats
+      // imposed, meaning there is no implementation with plain formats
+      CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size()))
+          << "No implementation of calculating deconvolution weights gradient";
+      bwd_w_desc = ddc.CreateBwdWeightsDesc();
+      *pd = deconv_bwd_weights_pd_t(bwd_w_desc, engine, fwd_pd);
+    }
   }
   return pd;
 }