This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new b4aca83e31 Use requested mem in dot op to reduce memory usage (#21067)
b4aca83e31 is described below
commit b4aca83e31455fa0bec306f3753846c9e9c16141
Author: bartekkuncer <[email protected]>
AuthorDate: Thu Jun 23 13:37:47 2022 +0200
Use requested mem in dot op to reduce memory usage (#21067)
* Use requested mem in dot op to reduce memory usage
* Fix oneDNN still allocating the memory
* Assign pointer to a variable for simplification
---
src/operator/nn/dnnl/dnnl_dot-inl.h | 5 +++--
src/operator/nn/dnnl/dnnl_dot.cc | 21 +++++++++++++++------
2 files changed, 18 insertions(+), 8 deletions(-)
diff --git a/src/operator/nn/dnnl/dnnl_dot-inl.h b/src/operator/nn/dnnl/dnnl_dot-inl.h
index b375872fc2..5a9f32b8a2 100644
--- a/src/operator/nn/dnnl/dnnl_dot-inl.h
+++ b/src/operator/nn/dnnl/dnnl_dot-inl.h
@@ -52,7 +52,8 @@ class DNNLDotFwd {
const std::vector<NDArray>& outputs,
const bool isNumpy);
- void Execute(const std::vector<NDArray>& inputs,
+ void Execute(const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs,
const bool isNumpy);
@@ -78,7 +79,7 @@ void DNNLDotForward(const nnvm::NodeAttrs& attrs,
param = nnvm::get<DotParam>(attrs.parsed);
}
DNNLDotFwd& fwd = DNNLDotFwd::GetCached(param, inputs, outputs, isNumpy);
- fwd.Execute(inputs, req, outputs, isNumpy);
+ fwd.Execute(ctx, inputs, req, outputs, isNumpy);
}
} // namespace op
} // namespace mxnet
diff --git a/src/operator/nn/dnnl/dnnl_dot.cc b/src/operator/nn/dnnl/dnnl_dot.cc
index 30468ce22a..e70d4e5a97 100644
--- a/src/operator/nn/dnnl/dnnl_dot.cc
+++ b/src/operator/nn/dnnl/dnnl_dot.cc
@@ -105,16 +105,27 @@ DNNLDotFwd::DNNLDotFwd(const DotParam& param,
fwd = std::make_shared<dot_fwd_t>(*fwd_pd);
}
-void DNNLDotFwd::Execute(const std::vector<NDArray>& inputs,
+void DNNLDotFwd::Execute(const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs,
const bool isNumpy) {
auto engine = mxnet::CpuEngine::Get()->get_engine();
auto lhs = dnnl::memory(
fwd_pd->src_desc(), engine,
reinterpret_cast<void*>(inputs[DotIn::lhs].data().dptr_));
- auto rhs = dnnl::memory(fwd_pd->weights_desc(), engine);
- auto ndimRhs = inputs[DotIn::rhs].shape().ndim();
- if (isNumpy && ndimRhs > 2) {
+  auto ndimRhs                = inputs[DotIn::rhs].shape().ndim();
+  const bool specialNumpyCase = isNumpy && ndimRhs > 2;
+  auto rhsMemPointer =
+      specialNumpyCase ?
+          reinterpret_cast<void*>(
+              ctx.requested[0]
+                  .get_space<cpu>(mshadow::Shape1(inputs[DotIn::rhs].shape().Size() *
+                                                  GetTypeSize(inputs[DotIn::rhs].dtype())),
+                                  ctx.get_stream<cpu>())
+                  .dptr_) :
+          reinterpret_cast<void*>(inputs[DotIn::rhs].data().dptr_);
+  dnnl::memory rhs(fwd_pd->weights_desc(), engine, rhsMemPointer);
+  if (specialNumpyCase) {
// Necessity of this reorder is described in DNNLDotFwd constructor.
auto tmp_rhs = inputs[DotIn::rhs].GetDNNLData();
dnnl::memory::desc rhs_md(
@@ -125,8 +136,6 @@ void DNNLDotFwd::Execute(const std::vector<NDArray>& inputs,
     const auto rhs_reorder_pd = dnnl::reorder::primitive_desc(*tmp_rhs, tmp_rhs_dst);
     DNNLStream::Get()->RegisterPrimArgs(dnnl::reorder(rhs_reorder_pd),
                                         {{DNNL_ARG_FROM, *tmp_rhs}, {DNNL_ARG_TO, tmp_rhs_dst}});
-  } else {
-    rhs.set_data_handle(reinterpret_cast<void*>(inputs[DotIn::rhs].data().dptr_));
   }
   dnnl_output_t out_mem = CreateDNNLMem(
       outputs[DotOut::out], fwd_pd->dst_desc(), req[DotOut::out], &inputs[DotIn::lhs]);