This is an automated email from the ASF dual-hosted git repository.
patriczhao pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new b56571d [v1.x] backport #17900 "[MKLDNN] support using any format in
pooling backward" (#18067)
b56571d is described below
commit b56571d5f3905f969c6cbfee87068ede75d20f58
Author: YixinBao <[email protected]>
AuthorDate: Thu Apr 16 10:37:05 2020 +0800
[v1.x] backport #17900 "[MKLDNN] support using any format in pooling
backward" (#18067)
* [MKLDNN] support using any format in pooling backward (#17900)
* use any format in pooling backward
* use data_type()
* fix backport
---
src/operator/nn/mkldnn/mkldnn_pooling.cc | 63 +++++++++++++++-----------------
1 file changed, 30 insertions(+), 33 deletions(-)
diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc
b/src/operator/nn/mkldnn/mkldnn_pooling.cc
index d2f7970..f987054 100644
--- a/src/operator/nn/mkldnn/mkldnn_pooling.cc
+++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc
@@ -30,6 +30,10 @@
namespace mxnet {
namespace op {
+static inline mkldnn::memory::data_type get_data_type(const
mkldnn::memory::desc &md) {
+ return static_cast<mkldnn::memory::data_type>(md.data_type());
+}
+
void MKLDNNPoolingFwd::Init(const mxnet::NDArray &input, const mxnet::NDArray
&output,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
@@ -93,7 +97,7 @@ void MKLDNNPoolingFwd::Execute(const NDArray &in_data,
auto engine = CpuEngine::Get()->get_engine();
if (workspace == nullptr) {
- LOG(FATAL) << "MKLDNN Pooling: incorrect workspace input";
+ LOG(FATAL) << "MKLDNN Pooling: incorrect workspace input";
}
auto ws =
std::make_shared<mkldnn::memory>((*(this->fwd_pd_)).workspace_desc(),
@@ -290,30 +294,21 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam ¶m,
auto it = pooling_bwds.find(key);
if (it == pooling_bwds.end()) {
- NDArray diff_dst_buff = out_grad;
- if (in_data.IsMKLDNNData() == false && diff_dst_buff.IsMKLDNNData() ==
true) {
- diff_dst_buff = out_grad.Reorder2Default();
- }
- auto diff_dst_mem = diff_dst_buff.GetMKLDNNData();
auto input_mem = in_data.GetMKLDNNData();
- const mkldnn::memory::desc data_md = input_mem->get_desc();
- const mkldnn::memory::dims dims = {data_md.data.dims[0],
data_md.data.dims[1],
- static_cast<int>(out_grad.shape()[2]),
- static_cast<int>(out_grad.shape()[3])};
- const mkldnn::memory::desc out_md(
- {dims}, static_cast<mkldnn::memory::data_type>(data_md.data.data_type),
- mkldnn::memory::format_tag::any);
- auto fwd_pd = GetPoolingFwdPdesc(param, true, data_md, out_md);
- const mkldnn::memory::desc diff_md =
- diff_dst_mem->get_desc();
- const mkldnn::memory::dims dims1 = {diff_md.data.dims[0],
diff_md.data.dims[1],
- static_cast<int>(in_grad.shape()[2]),
- static_cast<int>(in_grad.shape()[3])};
- const mkldnn::memory::desc diff_in_md(
- {dims1},
static_cast<mkldnn::memory::data_type>(diff_md.data.data_type),
- mkldnn::memory::format_tag::any);
- const mkldnn::engine cpu_engine = CpuEngine::Get()->get_engine();;
- const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
+ auto data_md = input_mem->get_desc();
+
+ auto dst_dims = mkldnn::memory::dims(out_grad.shape().begin(),
out_grad.shape().end());
+ auto any = mkldnn::memory::format_tag::any;
+ auto dst_md = mkldnn::memory::desc(dst_dims, get_data_type(data_md), any);
+
+ // fwd hint
+ auto fwd_pd = GetPoolingFwdPdesc(param, true, data_md, dst_md);
+
+ // create bwd desc
+ auto diff_src_dims = mkldnn::memory::dims(in_grad.shape().begin(),
in_grad.shape().end());
+ auto diff_src_md = mkldnn::memory::desc(diff_src_dims,
get_data_type(data_md), any);
+ auto cpu_engine = CpuEngine::Get()->get_engine();;
+ auto alg = GetMKLDNNPoolAlgo(param);
int kernel_h_, kernel_w_;
if (param.global_pool) {
@@ -338,10 +333,12 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam ¶m,
stride_h_ = stride_w_ = 1;
}
- const mkldnn::pooling_backward::desc desc(
- alg, diff_in_md, diff_md, {stride_h_, stride_w_},
- {kernel_h_, kernel_w_}, {pad_t_, pad_l_}, {pad_b_, pad_r_});
- const auto pdesc = mkldnn::pooling_backward::primitive_desc(desc,
cpu_engine, fwd_pd);
+ // use dst_md as diff_dst_md with any format
+ auto bwd_desc = mkldnn::pooling_backward::desc(
+ alg, diff_src_md, dst_md, {stride_h_, stride_w_},
+ {kernel_h_, kernel_w_}, {pad_t_, pad_l_}, {pad_b_, pad_r_});
+ auto pdesc = mkldnn::pooling_backward::primitive_desc(bwd_desc,
cpu_engine, fwd_pd);
+
MKLDNNPoolingBwd bwd(pdesc, with_workspace);
it = AddToCache(&pooling_bwds, key, bwd);
}
@@ -355,15 +352,15 @@ void MKLDNNPoolingGradCompute(const OpContext &ctx, const
PoolingParam ¶m,
if (req == kNullOp) {
return;
}
+
TmpMemMgr::Get()->Init(ctx.requested[0]);
auto &bwd = GetPoolingBwd(param, in_data, in_grad, out_grad);
- auto diff_src_mem =
- CreateMKLDNNMem(in_grad, bwd.pd.diff_src_desc(), req);
-
+ auto diff_dst_mem = out_grad.GetMKLDNNDataReorder(bwd.pd.diff_dst_desc());
+ auto diff_src_mem = CreateMKLDNNMem(in_grad, bwd.pd.diff_src_desc(), req);
mkldnn_args_map_t args = {
- {MKLDNN_ARG_DIFF_DST, *(out_grad.GetMKLDNNData())},
- {MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second },
+ {MKLDNN_ARG_DIFF_DST, *diff_dst_mem},
+ {MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second},
};
if (MKLDNNRequireWorkspace(param) && workspace != nullptr) {
args[MKLDNN_ARG_WORKSPACE] = *(workspace->GetMKLDNNData());