This is an automated email from the ASF dual-hosted git repository. weichu pushed a commit to branch v1.9.1-test in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit 681a130503c52e0b415670d2a7bda5d5b713956b Author: Paweł Głomski <[email protected]> AuthorDate: Tue Dec 14 08:06:44 2021 +0100 Fix next_impl in deconvolution (#20750) --- src/operator/nn/mkldnn/mkldnn_deconvolution.cc | 52 +++++++++++++++----------- 1 file changed, 31 insertions(+), 21 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc index 21608153bd..43423e792d 100644 --- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc @@ -76,9 +76,11 @@ MKLDNNDeconvFwd &MKLDNNDeconvFwd::GetCached(const DeconvolutionParam &param, std::shared_ptr<deconv_fwd_pd_t> MKLDNNDeconvFwd::CreatePrimitiveDesc( const DeconvolutionParam &param, const Tensors &tensors) { DeconvDescCreator ddc(param, tensors.data, tensors.weights, tensors.bias, tensors.out); - const auto &engine = CpuEngine::Get()->get_engine(); - const auto pd = std::make_shared<deconv_fwd_pd_t>(ddc.CreateFwdDesc(), engine); - const auto get_data_size = [&pd]() { return pd->src_desc().get_size(); }; + auto fwd_desc = ddc.CreateFwdDesc(); // `fwd_desc` lifetime must be longer than `pd` + // when using next_impl + const auto& engine = CpuEngine::Get()->get_engine(); + const auto pd = std::make_shared<deconv_fwd_pd_t>(fwd_desc, engine); + const auto get_data_size = [&pd]() { return pd->src_desc().get_size(); }; const auto get_weights_size = [&pd]() { return pd->weights_desc().get_size(); }; const auto get_out_size = [&pd]() { return pd->dst_desc().get_size(); }; @@ -88,7 +90,8 @@ std::shared_ptr<deconv_fwd_pd_t> MKLDNNDeconvFwd::CreatePrimitiveDesc( // imposed, meaning there is no implementation with plain formats CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size())) << "No implementation of deconvolution forward propagation"; - *pd = deconv_fwd_pd_t(ddc.CreateFwdDesc(), engine); + fwd_desc = ddc.CreateFwdDesc(); + *pd = deconv_fwd_pd_t(fwd_desc, engine); } } return pd; @@ 
-201,13 +204,16 @@ MKLDNNDeconvBwd &MKLDNNDeconvBwd::GetCached(const DeconvolutionParam &param, } std::shared_ptr<deconv_bwd_data_pd_t> MKLDNNDeconvBwd::CreateDataPrimitiveDesc( - const DeconvolutionParam &param, const ReadTensors &read_tensors, - const deconv_fwd_pd_t &fwd_pd) { - DeconvDescCreator ddc(param, read_tensors.data, read_tensors.weights, nullptr, - read_tensors.out_grad); - const auto &engine = CpuEngine::Get()->get_engine(); - const auto pd = std::make_shared<deconv_bwd_data_pd_t>(ddc.CreateBwdDataDesc(), engine, fwd_pd); - const auto get_data_size = [&pd]() { return pd->diff_src_desc().get_size(); }; + const DeconvolutionParam& param, + const ReadTensors& read_tensors, + const deconv_fwd_pd_t& fwd_pd) { + DeconvDescCreator ddc( + param, read_tensors.data, read_tensors.weights, nullptr, read_tensors.out_grad); + auto bwd_d_desc = ddc.CreateBwdDataDesc(); // `bwd_d_desc` lifetime must be longer than `pd` + // when using next_impl + const auto& engine = CpuEngine::Get()->get_engine(); + const auto pd = std::make_shared<deconv_bwd_data_pd_t>(bwd_d_desc, engine, fwd_pd); + const auto get_data_size = [&pd]() { return pd->diff_src_desc().get_size(); }; const auto get_weights_size = [&pd]() { return pd->weights_desc().get_size(); }; const auto get_out_size = [&pd]() { return pd->diff_dst_desc().get_size(); }; @@ -217,21 +223,24 @@ std::shared_ptr<deconv_bwd_data_pd_t> MKLDNNDeconvBwd::CreateDataPrimitiveDesc( // imposed, meaning there is no implementation with plain formats CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size())) << "No implementation of deconvolution backward propagation"; - *pd = deconv_bwd_data_pd_t(ddc.CreateBwdDataDesc(), engine, fwd_pd); + bwd_d_desc = ddc.CreateBwdDataDesc(); + *pd = deconv_bwd_data_pd_t(bwd_d_desc, engine, fwd_pd); } } return pd; } std::shared_ptr<deconv_bwd_weights_pd_t> MKLDNNDeconvBwd::CreateWeightsPrimitiveDesc( - const DeconvolutionParam &param, const ReadTensors &read_tensors, - const 
deconv_fwd_pd_t &fwd_pd) { - DeconvDescCreator ddc(param, read_tensors.data, read_tensors.weights, read_tensors.bias, - read_tensors.out_grad); - const auto &engine = CpuEngine::Get()->get_engine(); - const auto pd = - std::make_shared<deconv_bwd_weights_pd_t>(ddc.CreateBwdWeightsDesc(), engine, fwd_pd); - const auto get_data_size = [&pd]() { return pd->src_desc().get_size(); }; + const DeconvolutionParam& param, + const ReadTensors& read_tensors, + const deconv_fwd_pd_t& fwd_pd) { + DeconvDescCreator ddc( + param, read_tensors.data, read_tensors.weights, read_tensors.bias, read_tensors.out_grad); + auto bwd_w_desc = ddc.CreateBwdWeightsDesc(); // `bwd_w_desc` lifetime must be longer than `pd` + // when using next_impl + const auto& engine = CpuEngine::Get()->get_engine(); + const auto pd = std::make_shared<deconv_bwd_weights_pd_t>(bwd_w_desc, engine, fwd_pd); + const auto get_data_size = [&pd]() { return pd->src_desc().get_size(); }; const auto get_weights_size = [&pd]() { return pd->diff_weights_desc().get_size(); }; const auto get_out_size = [&pd]() { return pd->diff_dst_desc().get_size(); }; @@ -241,7 +250,8 @@ std::shared_ptr<deconv_bwd_weights_pd_t> MKLDNNDeconvBwd::CreateWeightsPrimitive // imposed, meaning there is no implementation with plain formats CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size())) << "No implementation of calculating deconvolution weights gradient"; - *pd = deconv_bwd_weights_pd_t(ddc.CreateBwdWeightsDesc(), engine, fwd_pd); + bwd_w_desc = ddc.CreateBwdWeightsDesc(); + *pd = deconv_bwd_weights_pd_t(bwd_w_desc, engine, fwd_pd); } } return pd;
