This is an automated email from the ASF dual-hosted git repository. jxie pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 694bf49 use correct prop_kind for mkl-dnn FC layer (#10442) 694bf49 is described below commit 694bf490c8912153692e1b5930dac65e4512d6f1 Author: Ashok Emani <ashok.em...@intel.com> AuthorDate: Mon Apr 9 10:21:57 2018 -0700 use correct prop_kind for mkl-dnn FC layer (#10442) * use correct prop_kind for mkl-dnn FC layer * fix clang-format issue * fix clang-format issue --- src/operator/nn/mkldnn/mkldnn_fully_connected.cc | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc index eb379f2..f86f8db 100644 --- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc +++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc @@ -32,17 +32,19 @@ namespace op { inline static mkldnn::inner_product_forward::primitive_desc GetIPFwd( const NDArray &data, const NDArray &weight, const NDArray *bias, - const mkldnn::memory::desc &out_md) { + const mkldnn::memory::desc &out_md, const bool is_train) { auto data_md = GetMemDesc(data); auto weight_md = GetMemDesc(weight); auto engine = CpuEngine::Get()->get_engine(); + auto propagation = + is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring; if (bias) { auto bias_md = GetMemDesc(*bias); - mkldnn::inner_product_forward::desc ipFwd_desc(mkldnn::prop_kind::forward_training, + mkldnn::inner_product_forward::desc ipFwd_desc(propagation, data_md, weight_md, bias_md, out_md); return mkldnn::inner_product_forward::primitive_desc(ipFwd_desc, engine); } else { - mkldnn::inner_product_forward::desc ipFwd_desc(mkldnn::prop_kind::forward_training, + mkldnn::inner_product_forward::desc ipFwd_desc(propagation, data_md, weight_md, out_md); return mkldnn::inner_product_forward::primitive_desc(ipFwd_desc, engine); } @@ -112,7 +114,7 @@ void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, } mkldnn::inner_product_forward::primitive_desc ipFwd_pd = GetIPFwd(data, weight, - param.no_bias ? nullptr : &in_data[fullc::kBias], out_md); + param.no_bias ? nullptr : &in_data[fullc::kBias], out_md, ctx.is_train); auto data_mem = data.GetMKLDNNDataReorder(ipFwd_pd.src_primitive_desc()); auto weight_mem = weight.GetMKLDNNDataReorder(ipFwd_pd.weights_primitive_desc()); auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut], @@ -156,7 +158,7 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, oshape.ProdShape(1, oshape.ndim()))); mkldnn::inner_product_forward::primitive_desc ipFwd_pd = GetIPFwd(data, weight, - param.no_bias ? nullptr : &in_grad[fullc::kBias], GetMemDesc(out_grad)); + param.no_bias ? nullptr : &in_grad[fullc::kBias], GetMemDesc(out_grad), ctx.is_train); CHECK_NE(req[fullc::kWeight], kWriteInplace) << "cannot write weight inplace"; if (req[fullc::kData]) { -- To stop receiving notification emails like this one, please contact j...@apache.org.