piiswrong closed pull request #10442: use correct prop_kind for mkl-dnn FC layer
URL: https://github.com/apache/incubator-mxnet/pull/10442
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic):

diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
index eb379f2c12f..f86f8dbefa2 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -32,17 +32,19 @@ namespace op {
 inline static mkldnn::inner_product_forward::primitive_desc GetIPFwd(
     const NDArray &data, const NDArray &weight, const NDArray *bias,
-    const mkldnn::memory::desc &out_md) {
+    const mkldnn::memory::desc &out_md, const bool is_train) {
   auto data_md = GetMemDesc(data);
   auto weight_md = GetMemDesc(weight);
   auto engine = CpuEngine::Get()->get_engine();
+  auto propagation =
+    is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
   if (bias) {
     auto bias_md = GetMemDesc(*bias);
-    mkldnn::inner_product_forward::desc ipFwd_desc(mkldnn::prop_kind::forward_training,
+    mkldnn::inner_product_forward::desc ipFwd_desc(propagation,
         data_md, weight_md, bias_md, out_md);
     return mkldnn::inner_product_forward::primitive_desc(ipFwd_desc, engine);
   } else {
-    mkldnn::inner_product_forward::desc ipFwd_desc(mkldnn::prop_kind::forward_training,
+    mkldnn::inner_product_forward::desc ipFwd_desc(propagation,
         data_md, weight_md, out_md);
     return mkldnn::inner_product_forward::primitive_desc(ipFwd_desc, engine);
   }
@@ -112,7 +114,7 @@ void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   }
   mkldnn::inner_product_forward::primitive_desc ipFwd_pd = GetIPFwd(data, weight,
-      param.no_bias ? nullptr : &in_data[fullc::kBias], out_md);
+      param.no_bias ? nullptr : &in_data[fullc::kBias], out_md, ctx.is_train);
   auto data_mem = data.GetMKLDNNDataReorder(ipFwd_pd.src_primitive_desc());
   auto weight_mem = weight.GetMKLDNNDataReorder(ipFwd_pd.weights_primitive_desc());
   auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut],
@@ -156,7 +158,7 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
       oshape.ProdShape(1, oshape.ndim())));
   mkldnn::inner_product_forward::primitive_desc ipFwd_pd = GetIPFwd(data, weight,
-      param.no_bias ? nullptr : &in_grad[fullc::kBias], GetMemDesc(out_grad));
+      param.no_bias ? nullptr : &in_grad[fullc::kBias], GetMemDesc(out_grad), ctx.is_train);
   CHECK_NE(req[fullc::kWeight], kWriteInplace) << "cannot write weight inplace";
   if (req[fullc::kData]) {

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

With regards,
Apache Git Services