This is an automated email from the ASF dual-hosted git repository.

patriczhao pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new b56571d  [v1.x] backport #17900 "[MKLDNN] support using any format in pooling backward" (#18067)
b56571d is described below

commit b56571d5f3905f969c6cbfee87068ede75d20f58
Author: YixinBao <[email protected]>
AuthorDate: Thu Apr 16 10:37:05 2020 +0800

    [v1.x] backport #17900 "[MKLDNN] support using any format in pooling backward" (#18067)
    
    * [MKLDNN] support using any format in pooling backward (#17900)
    
    * use any format in pooling backward
    
    * use data_type()
    
    * fix backport
---
 src/operator/nn/mkldnn/mkldnn_pooling.cc | 63 +++++++++++++++-----------------
 1 file changed, 30 insertions(+), 33 deletions(-)

diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc
index d2f7970..f987054 100644
--- a/src/operator/nn/mkldnn/mkldnn_pooling.cc
+++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc
@@ -30,6 +30,10 @@
 namespace mxnet {
 namespace op {
 
+static inline mkldnn::memory::data_type get_data_type(const mkldnn::memory::desc &md) {
+  return static_cast<mkldnn::memory::data_type>(md.data_type());
+}
+
 void MKLDNNPoolingFwd::Init(const mxnet::NDArray &input, const mxnet::NDArray &output,
                             const int kernel_h,  const int kernel_w,
                             const int stride_h,  const int stride_w,
@@ -93,7 +97,7 @@ void MKLDNNPoolingFwd::Execute(const NDArray &in_data,
     auto engine = CpuEngine::Get()->get_engine();
 
     if (workspace == nullptr) {
-        LOG(FATAL) << "MKLDNN Pooling: incorrect workspace input";
+      LOG(FATAL) << "MKLDNN Pooling: incorrect workspace input";
     }
 
     auto ws = std::make_shared<mkldnn::memory>((*(this->fwd_pd_)).workspace_desc(),
@@ -290,30 +294,21 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam &param,
 
   auto it = pooling_bwds.find(key);
   if (it == pooling_bwds.end()) {
-    NDArray diff_dst_buff = out_grad;
-    if (in_data.IsMKLDNNData() == false && diff_dst_buff.IsMKLDNNData() == true) {
-      diff_dst_buff = out_grad.Reorder2Default();
-    }
-    auto diff_dst_mem = diff_dst_buff.GetMKLDNNData();
     auto input_mem = in_data.GetMKLDNNData();
-    const mkldnn::memory::desc data_md = input_mem->get_desc();
-    const mkldnn::memory::dims dims = {data_md.data.dims[0], data_md.data.dims[1],
-                               static_cast<int>(out_grad.shape()[2]),
-                               static_cast<int>(out_grad.shape()[3])};
-    const mkldnn::memory::desc out_md(
-        {dims}, static_cast<mkldnn::memory::data_type>(data_md.data.data_type),
-        mkldnn::memory::format_tag::any);
-    auto fwd_pd = GetPoolingFwdPdesc(param, true, data_md, out_md);
-    const mkldnn::memory::desc diff_md =
-        diff_dst_mem->get_desc();
-    const mkldnn::memory::dims dims1 = {diff_md.data.dims[0], diff_md.data.dims[1],
-                                static_cast<int>(in_grad.shape()[2]),
-                                static_cast<int>(in_grad.shape()[3])};
-    const mkldnn::memory::desc diff_in_md(
-        {dims1}, static_cast<mkldnn::memory::data_type>(diff_md.data.data_type),
-        mkldnn::memory::format_tag::any);
-    const mkldnn::engine cpu_engine = CpuEngine::Get()->get_engine();;
-    const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
+    auto data_md = input_mem->get_desc();
+
+    auto dst_dims = mkldnn::memory::dims(out_grad.shape().begin(), out_grad.shape().end());
+    auto any = mkldnn::memory::format_tag::any;
+    auto dst_md = mkldnn::memory::desc(dst_dims, get_data_type(data_md), any);
+
+    // fwd hint
+    auto fwd_pd = GetPoolingFwdPdesc(param, true, data_md, dst_md);
+
+    // create bwd desc
+    auto diff_src_dims = mkldnn::memory::dims(in_grad.shape().begin(), in_grad.shape().end());
+    auto diff_src_md = mkldnn::memory::desc(diff_src_dims, get_data_type(data_md), any);
+    auto cpu_engine = CpuEngine::Get()->get_engine();
+    auto alg = GetMKLDNNPoolAlgo(param);
 
     int kernel_h_, kernel_w_;
     if (param.global_pool) {
@@ -338,10 +333,12 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam &param,
       stride_h_ = stride_w_ = 1;
     }
 
-    const mkldnn::pooling_backward::desc desc(
-        alg, diff_in_md, diff_md, {stride_h_, stride_w_},
-        {kernel_h_, kernel_w_}, {pad_t_, pad_l_}, {pad_b_, pad_r_});
-    const auto pdesc = mkldnn::pooling_backward::primitive_desc(desc, cpu_engine, fwd_pd);
+    // use dst_md as diff_dst_md with any format
+    auto bwd_desc = mkldnn::pooling_backward::desc(
+      alg, diff_src_md, dst_md, {stride_h_, stride_w_},
+      {kernel_h_, kernel_w_}, {pad_t_, pad_l_}, {pad_b_, pad_r_});
+    auto pdesc = mkldnn::pooling_backward::primitive_desc(bwd_desc, cpu_engine, fwd_pd);
+
     MKLDNNPoolingBwd bwd(pdesc, with_workspace);
     it = AddToCache(&pooling_bwds, key, bwd);
   }
@@ -355,15 +352,15 @@ void MKLDNNPoolingGradCompute(const OpContext &ctx, const PoolingParam &param,
   if (req == kNullOp) {
     return;
   }
+
   TmpMemMgr::Get()->Init(ctx.requested[0]);
 
   auto &bwd = GetPoolingBwd(param, in_data, in_grad, out_grad);
-  auto diff_src_mem =
-      CreateMKLDNNMem(in_grad, bwd.pd.diff_src_desc(), req);
-
+  auto diff_dst_mem = out_grad.GetMKLDNNDataReorder(bwd.pd.diff_dst_desc());
+  auto diff_src_mem = CreateMKLDNNMem(in_grad, bwd.pd.diff_src_desc(), req);
   mkldnn_args_map_t args = {
-    {MKLDNN_ARG_DIFF_DST, *(out_grad.GetMKLDNNData())},
-    {MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second },
+    {MKLDNN_ARG_DIFF_DST, *diff_dst_mem},
+    {MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second},
   };
   if (MKLDNNRequireWorkspace(param) && workspace != nullptr) {
     args[MKLDNN_ARG_WORKSPACE] = *(workspace->GetMKLDNNData());
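
For context on the idiom this patch adopts: with format_tag::any the primitive descriptor, not the framework, chooses the memory layout, and the caller then reorders its buffers to whatever layout was chosen. Below is a minimal standalone sketch of that pattern against the MKL-DNN 1.x C++ API. It is illustrative only and not MXNet code; the shapes, the average-pooling algorithm (which needs no workspace), and the nchw user layout are assumptions, not taken from this patch.

// Sketch: diff_src/diff_dst declared with format_tag::any, the backward
// primitive_desc picks the layouts, and the user gradient is reordered to
// pd.diff_dst_desc() before execution.
#include <vector>
#include <unordered_map>
#include "mkldnn.hpp"

int main() {
  using namespace mkldnn;
  engine eng(engine::kind::cpu, 0);
  stream strm(eng);

  // Shapes (assumed): 1x8x4x4 input, 2x2 kernel, stride 2 -> 1x8x2x2 output.
  memory::dims src_dims = {1, 8, 4, 4}, dst_dims = {1, 8, 2, 2};
  memory::dims kernel = {2, 2}, strides = {2, 2}, pad = {0, 0};
  auto f32 = memory::data_type::f32;
  auto any = memory::format_tag::any;

  // Concrete layout for the source, 'any' for the descriptors derived from it.
  memory::desc src_md(src_dims, f32, memory::format_tag::nchw);
  memory::desc dst_md(dst_dims, f32, any);
  memory::desc diff_src_md(src_dims, f32, any);

  // Forward pd serves as the hint required by the backward pd.
  auto alg = algorithm::pooling_avg_exclude_padding;
  pooling_forward::desc fwd_desc(prop_kind::forward_training, alg, src_md, dst_md,
                                 strides, kernel, pad, pad);
  pooling_forward::primitive_desc fwd_pd(fwd_desc, eng);

  // dst_md (format any) doubles as diff_dst_md, as in the patch above.
  pooling_backward::desc bwd_desc(alg, diff_src_md, dst_md, strides, kernel, pad, pad);
  pooling_backward::primitive_desc bwd_pd(bwd_desc, eng, fwd_pd);

  // The incoming gradient lives in plain nchw; reorder it to whatever layout
  // the backward primitive chose for diff_dst.
  std::vector<float> grad(1 * 8 * 2 * 2, 1.0f);
  memory::desc user_diff_dst_md(dst_dims, f32, memory::format_tag::nchw);
  memory user_diff_dst(user_diff_dst_md, eng, grad.data());
  memory diff_dst(bwd_pd.diff_dst_desc(), eng);
  reorder(user_diff_dst, diff_dst).execute(strm, user_diff_dst, diff_dst);

  memory diff_src(bwd_pd.diff_src_desc(), eng);
  pooling_backward(bwd_pd).execute(
      strm, {{MKLDNN_ARG_DIFF_DST, diff_dst}, {MKLDNN_ARG_DIFF_SRC, diff_src}});
  strm.wait();
  return 0;
}

This mirrors the change in the diff: dst_md is created with format_tag::any and reused as diff_dst_md, and MKLDNNPoolingGradCompute reorders out_grad to bwd.pd.diff_dst_desc() (via GetMKLDNNDataReorder) before executing the primitive.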
