This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new b4aca83e31 Use requested mem in dot op to reduce memory usage (#21067)
b4aca83e31 is described below

commit b4aca83e31455fa0bec306f3753846c9e9c16141
Author: bartekkuncer <[email protected]>
AuthorDate: Thu Jun 23 13:37:47 2022 +0200

    Use requested mem in dot op to reduce memory usage (#21067)
    
    * Use requested mem in dot op to reduce memory usage
    
    * Fix oneDNN still allocating the memory
    
    * Assign pointer to a variable for simplification
---
 src/operator/nn/dnnl/dnnl_dot-inl.h |  5 +++--
 src/operator/nn/dnnl/dnnl_dot.cc    | 21 +++++++++++++++------
 2 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/src/operator/nn/dnnl/dnnl_dot-inl.h 
b/src/operator/nn/dnnl/dnnl_dot-inl.h
index b375872fc2..5a9f32b8a2 100644
--- a/src/operator/nn/dnnl/dnnl_dot-inl.h
+++ b/src/operator/nn/dnnl/dnnl_dot-inl.h
@@ -52,7 +52,8 @@ class DNNLDotFwd {
              const std::vector<NDArray>& outputs,
              const bool isNumpy);
 
-  void Execute(const std::vector<NDArray>& inputs,
+  void Execute(const OpContext& ctx,
+               const std::vector<NDArray>& inputs,
                const std::vector<OpReqType>& req,
                const std::vector<NDArray>& outputs,
                const bool isNumpy);
@@ -78,7 +79,7 @@ void DNNLDotForward(const nnvm::NodeAttrs& attrs,
     param = nnvm::get<DotParam>(attrs.parsed);
   }
   DNNLDotFwd& fwd = DNNLDotFwd::GetCached(param, inputs, outputs, isNumpy);
-  fwd.Execute(inputs, req, outputs, isNumpy);
+  fwd.Execute(ctx, inputs, req, outputs, isNumpy);
 }
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/nn/dnnl/dnnl_dot.cc b/src/operator/nn/dnnl/dnnl_dot.cc
index 30468ce22a..e70d4e5a97 100644
--- a/src/operator/nn/dnnl/dnnl_dot.cc
+++ b/src/operator/nn/dnnl/dnnl_dot.cc
@@ -105,16 +105,27 @@ DNNLDotFwd::DNNLDotFwd(const DotParam& param,
   fwd    = std::make_shared<dot_fwd_t>(*fwd_pd);
 }
 
-void DNNLDotFwd::Execute(const std::vector<NDArray>& inputs,
+void DNNLDotFwd::Execute(const OpContext& ctx,
+                         const std::vector<NDArray>& inputs,
                          const std::vector<OpReqType>& req,
                          const std::vector<NDArray>& outputs,
                          const bool isNumpy) {
   auto engine = mxnet::CpuEngine::Get()->get_engine();
   auto lhs    = dnnl::memory(
       fwd_pd->src_desc(), engine, 
reinterpret_cast<void*>(inputs[DotIn::lhs].data().dptr_));
-  auto rhs     = dnnl::memory(fwd_pd->weights_desc(), engine);
-  auto ndimRhs = inputs[DotIn::rhs].shape().ndim();
-  if (isNumpy && ndimRhs > 2) {
+  auto ndimRhs                = inputs[DotIn::rhs].shape().ndim();
+  const bool specialNumpyCase = isNumpy && ndimRhs > 2;
+  auto rhsMemPointer =
+      specialNumpyCase ?
+          reinterpret_cast<void*>(
+              ctx.requested[0]
+                  
.get_space<cpu>(mshadow::Shape1(inputs[DotIn::rhs].shape().Size() *
+                                                  
GetTypeSize(inputs[DotIn::rhs].dtype())),
+                                  ctx.get_stream<cpu>())
+                  .dptr_) :
+          reinterpret_cast<void*>(inputs[DotIn::rhs].data().dptr_);
+  dnnl::memory rhs(fwd_pd->weights_desc(), engine, rhsMemPointer);
+  if (specialNumpyCase) {
     // Necessity of this reorder is described in DNNLDotFwd constructor.
     auto tmp_rhs = inputs[DotIn::rhs].GetDNNLData();
     dnnl::memory::desc rhs_md(
@@ -125,8 +136,6 @@ void DNNLDotFwd::Execute(const std::vector<NDArray>& inputs,
     const auto rhs_reorder_pd = dnnl::reorder::primitive_desc(*tmp_rhs, 
tmp_rhs_dst);
     DNNLStream::Get()->RegisterPrimArgs(dnnl::reorder(rhs_reorder_pd),
                                         {{DNNL_ARG_FROM, *tmp_rhs}, 
{DNNL_ARG_TO, tmp_rhs_dst}});
-  } else {
-    
rhs.set_data_handle(reinterpret_cast<void*>(inputs[DotIn::rhs].data().dptr_));
   }
   dnnl_output_t out_mem = CreateDNNLMem(
       outputs[DotOut::out], fwd_pd->dst_desc(), req[DotOut::out], 
&inputs[DotIn::lhs]);

Reply via email to