This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 694bf49  use correct prop_kind for mkl-dnn FC layer (#10442)
694bf49 is described below

commit 694bf490c8912153692e1b5930dac65e4512d6f1
Author: Ashok Emani <ashok.em...@intel.com>
AuthorDate: Mon Apr 9 10:21:57 2018 -0700

    use correct prop_kind for mkl-dnn FC layer (#10442)
    
    * use correct prop_kind for mkl-dnn FC layer
    
    * fix clang-format issue
    
    * fix clang-format issue
---
 src/operator/nn/mkldnn/mkldnn_fully_connected.cc | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
index eb379f2..f86f8db 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -32,17 +32,19 @@ namespace op {
 
 inline static mkldnn::inner_product_forward::primitive_desc GetIPFwd(
     const NDArray &data, const NDArray &weight, const NDArray *bias,
-    const mkldnn::memory::desc &out_md) {
+    const mkldnn::memory::desc &out_md, const bool is_train) {
   auto data_md = GetMemDesc(data);
   auto weight_md = GetMemDesc(weight);
   auto engine = CpuEngine::Get()->get_engine();
+  auto propagation =
+    is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
   if (bias) {
     auto bias_md = GetMemDesc(*bias);
-    mkldnn::inner_product_forward::desc ipFwd_desc(mkldnn::prop_kind::forward_training,
+    mkldnn::inner_product_forward::desc ipFwd_desc(propagation,
         data_md, weight_md, bias_md, out_md);
     return mkldnn::inner_product_forward::primitive_desc(ipFwd_desc, engine);
   } else {
-    mkldnn::inner_product_forward::desc ipFwd_desc(mkldnn::prop_kind::forward_training,
+    mkldnn::inner_product_forward::desc ipFwd_desc(propagation,
         data_md, weight_md, out_md);
     return mkldnn::inner_product_forward::primitive_desc(ipFwd_desc, engine);
   }
@@ -112,7 +114,7 @@ void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   }
 
   mkldnn::inner_product_forward::primitive_desc ipFwd_pd = GetIPFwd(data, weight,
-      param.no_bias ? nullptr : &in_data[fullc::kBias], out_md);
+      param.no_bias ? nullptr : &in_data[fullc::kBias], out_md, ctx.is_train);
   auto data_mem = data.GetMKLDNNDataReorder(ipFwd_pd.src_primitive_desc());
   auto weight_mem = weight.GetMKLDNNDataReorder(ipFwd_pd.weights_primitive_desc());
   auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut],
@@ -156,7 +158,7 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                                              oshape.ProdShape(1, 
oshape.ndim())));
 
   mkldnn::inner_product_forward::primitive_desc ipFwd_pd = GetIPFwd(data, weight,
-      param.no_bias ? nullptr : &in_grad[fullc::kBias], GetMemDesc(out_grad));
+      param.no_bias ? nullptr : &in_grad[fullc::kBias], GetMemDesc(out_grad), ctx.is_train);
 
   CHECK_NE(req[fullc::kWeight], kWriteInplace) << "cannot write weight inplace";
   if (req[fullc::kData]) {

-- 
To stop receiving notification emails like this one, please contact
j...@apache.org.

Reply via email to