This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 930e140 Reintroduce next_impl in onednn deconvolution (#20663)
930e140 is described below
commit 930e14047ba42bf9519f308069e07eee7ef7a687
Author: Paweł Głomski <[email protected]>
AuthorDate: Fri Nov 19 17:39:54 2021 +0100
Reintroduce next_impl in onednn deconvolution (#20663)
---
src/operator/nn/dnnl/dnnl_deconvolution.cc | 58 ++++++++++++++++++------------
1 file changed, 36 insertions(+), 22 deletions(-)
diff --git a/src/operator/nn/dnnl/dnnl_deconvolution.cc b/src/operator/nn/dnnl/dnnl_deconvolution.cc
index f4766a1..b853d1a 100644
--- a/src/operator/nn/dnnl/dnnl_deconvolution.cc
+++ b/src/operator/nn/dnnl/dnnl_deconvolution.cc
@@ -75,18 +75,23 @@ DNNLDeconvFwd& DNNLDeconvFwd::GetCached(const DeconvolutionParam& param, const T
 std::shared_ptr<deconv_fwd_pd_t> DNNLDeconvFwd::CreatePrimitiveDesc(const DeconvolutionParam& param,
                                                                     const Tensors& tensors) {
   DeconvDescCreator ddc(param, tensors.data, tensors.weights, tensors.bias, tensors.out);
+  auto fwd_desc = ddc.CreateFwdDesc();  // `fwd_desc` lifetime must be longer than `pd`
+                                        // when using next_impl
   const auto& engine = CpuEngine::Get()->get_engine();
-  const auto pd = std::make_shared<deconv_fwd_pd_t>(ddc.CreateFwdDesc(), engine);
+  const auto pd = std::make_shared<deconv_fwd_pd_t>(fwd_desc, engine);
   const auto get_data_size    = [&pd]() { return pd->src_desc().get_size(); };
   const auto get_weights_size = [&pd]() { return pd->weights_desc().get_size(); };
   const auto get_out_size     = [&pd]() { return pd->dst_desc().get_size(); };
   while (!ddc.CheckImplSizeReq(get_data_size(), get_weights_size(), get_out_size())) {
-    // ImposePlainWherePadding fails when all memory descriptors already have plain formats
-    // imposed, meaning there is no implementation with plain formats
-    CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size()))
-        << "No implementation of deconvolution forward propagation";
-    *pd = deconv_fwd_pd_t(ddc.CreateFwdDesc(), engine);
+    if (!pd->next_impl()) {
+      // ImposePlainWherePadding fails when all memory descriptors already have plain formats
+      // imposed, meaning there is no implementation with plain formats
+      CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size()))
+          << "No implementation of deconvolution forward propagation";
+      fwd_desc = ddc.CreateFwdDesc();
+      *pd      = deconv_fwd_pd_t(fwd_desc, engine);
+    }
   }
   return pd;
 }
@@ -204,18 +209,23 @@ std::shared_ptr<deconv_bwd_data_pd_t> DNNLDeconvBwd::CreateDataPrimitiveDesc(
     const deconv_fwd_pd_t& fwd_pd) {
   DeconvDescCreator ddc(
       param, read_tensors.data, read_tensors.weights, nullptr, read_tensors.out_grad);
-  const auto& engine = CpuEngine::Get()->get_engine();
-  const auto pd = std::make_shared<deconv_bwd_data_pd_t>(ddc.CreateBwdDataDesc(), engine, fwd_pd);
+  auto bwd_d_desc = ddc.CreateBwdDataDesc();  // `bwd_d_desc` lifetime must be longer than `pd`
+                                              // when using next_impl
+  const auto& engine = CpuEngine::Get()->get_engine();
+  const auto pd      = std::make_shared<deconv_bwd_data_pd_t>(bwd_d_desc, engine, fwd_pd);
   const auto get_data_size    = [&pd]() { return pd->diff_src_desc().get_size(); };
   const auto get_weights_size = [&pd]() { return pd->weights_desc().get_size(); };
   const auto get_out_size     = [&pd]() { return pd->diff_dst_desc().get_size(); };
   while (!ddc.CheckImplSizeReq(get_data_size(), get_weights_size(), get_out_size())) {
-    // ImposePlainWherePadding fails when all memory descriptors already have plain formats
-    // imposed, meaning there is no implementation with plain formats
-    CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size()))
-        << "No implementation of deconvolution backward propagation";
-    *pd = deconv_bwd_data_pd_t(ddc.CreateBwdDataDesc(), engine, fwd_pd);
+    if (!pd->next_impl()) {
+      // ImposePlainWherePadding fails when all memory descriptors already have plain formats
+      // imposed, meaning there is no implementation with plain formats
+      CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size()))
+          << "No implementation of deconvolution backward propagation";
+      bwd_d_desc = ddc.CreateBwdDataDesc();
+      *pd        = deconv_bwd_data_pd_t(bwd_d_desc, engine, fwd_pd);
+    }
   }
   return pd;
 }
@@ -226,19 +236,23 @@ std::shared_ptr<deconv_bwd_weights_pd_t> DNNLDeconvBwd::CreateWeightsPrimitiveDe
     const deconv_fwd_pd_t& fwd_pd) {
   DeconvDescCreator ddc(
       param, read_tensors.data, read_tensors.weights, read_tensors.bias, read_tensors.out_grad);
-  const auto& engine = CpuEngine::Get()->get_engine();
-  const auto pd =
-      std::make_shared<deconv_bwd_weights_pd_t>(ddc.CreateBwdWeightsDesc(), engine, fwd_pd);
-  const auto get_data_size = [&pd]() { return pd->src_desc().get_size(); };
+  auto bwd_w_desc = ddc.CreateBwdWeightsDesc();  // `bwd_w_desc` lifetime must be longer than `pd`
+                                                 // when using next_impl
+  const auto& engine = CpuEngine::Get()->get_engine();
+  const auto pd      = std::make_shared<deconv_bwd_weights_pd_t>(bwd_w_desc, engine, fwd_pd);
+  const auto get_data_size    = [&pd]() { return pd->src_desc().get_size(); };
   const auto get_weights_size = [&pd]() { return pd->diff_weights_desc().get_size(); };
   const auto get_out_size     = [&pd]() { return pd->diff_dst_desc().get_size(); };
   while (!ddc.CheckImplSizeReq(get_data_size(), get_weights_size(), get_out_size())) {
-    // ImposePlainWherePadding fails when all memory descriptors already have plain formats
-    // imposed, meaning there is no implementation with plain formats
-    CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size()))
-        << "No implementation of calculating deconvolution weights gradient";
-    *pd = deconv_bwd_weights_pd_t(ddc.CreateBwdWeightsDesc(), engine, fwd_pd);
+    if (!pd->next_impl()) {
+      // ImposePlainWherePadding fails when all memory descriptors already have plain formats
+      // imposed, meaning there is no implementation with plain formats
+      CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size()))
+          << "No implementation of calculating deconvolution weights gradient";
+      bwd_w_desc = ddc.CreateBwdWeightsDesc();
+      *pd        = deconv_bwd_weights_pd_t(bwd_w_desc, engine, fwd_pd);
+    }
   }
   return pd;
 }