piiswrong closed pull request #10442: use correct prop_kind for mkl-dnn FC layer
URL: https://github.com/apache/incubator-mxnet/pull/10442
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would not display it otherwise:

diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
index eb379f2c12f..f86f8dbefa2 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -32,17 +32,19 @@ namespace op {
 
 inline static mkldnn::inner_product_forward::primitive_desc GetIPFwd(
     const NDArray &data, const NDArray &weight, const NDArray *bias,
-    const mkldnn::memory::desc &out_md) {
+    const mkldnn::memory::desc &out_md, const bool is_train) {
   auto data_md = GetMemDesc(data);
   auto weight_md = GetMemDesc(weight);
   auto engine = CpuEngine::Get()->get_engine();
+  auto propagation =
+    is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
   if (bias) {
     auto bias_md = GetMemDesc(*bias);
-    mkldnn::inner_product_forward::desc ipFwd_desc(mkldnn::prop_kind::forward_training,
+    mkldnn::inner_product_forward::desc ipFwd_desc(propagation,
         data_md, weight_md, bias_md, out_md);
     return mkldnn::inner_product_forward::primitive_desc(ipFwd_desc, engine);
   } else {
-    mkldnn::inner_product_forward::desc ipFwd_desc(mkldnn::prop_kind::forward_training,
+    mkldnn::inner_product_forward::desc ipFwd_desc(propagation,
         data_md, weight_md, out_md);
     return mkldnn::inner_product_forward::primitive_desc(ipFwd_desc, engine);
   }
@@ -112,7 +114,7 @@ void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   }
 
   mkldnn::inner_product_forward::primitive_desc ipFwd_pd = GetIPFwd(data, weight,
-      param.no_bias ? nullptr : &in_data[fullc::kBias], out_md);
+      param.no_bias ? nullptr : &in_data[fullc::kBias], out_md, ctx.is_train);
   auto data_mem = data.GetMKLDNNDataReorder(ipFwd_pd.src_primitive_desc());
   auto weight_mem = weight.GetMKLDNNDataReorder(ipFwd_pd.weights_primitive_desc());
   auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut],
@@ -156,7 +158,7 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                                              oshape.ProdShape(1, oshape.ndim())));
 
   mkldnn::inner_product_forward::primitive_desc ipFwd_pd = GetIPFwd(data, weight,
-      param.no_bias ? nullptr : &in_grad[fullc::kBias], GetMemDesc(out_grad));
+      param.no_bias ? nullptr : &in_grad[fullc::kBias], GetMemDesc(out_grad), ctx.is_train);
 
   CHECK_NE(req[fullc::kWeight], kWriteInplace) << "cannot write weight inplace";
   if (req[fullc::kData]) {


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to