This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 15cacdbdbf broadcast_like CPU optimization (#21004)
15cacdbdbf is described below

commit 15cacdbdbf4b8609252827382bbf0e83fbe40063
Author: bgawrych <[email protected]>
AuthorDate: Tue Jun 7 09:15:21 2022 +0200

    broadcast_like CPU optimization (#21004)
    
    * working with tmp
    
    * working without tmp
    
    * refactor
    
    * condition
    
    * refactor
    
    * remove onednn broadcast
    
    * remove temporary memory resource
    
    * fix sanity
    
    * Fix tests
    
    * apply review comments
    
    * sanity
    
    * update comment
    
    Co-authored-by: Bartlomiej Gawrych <[email protected]>
---
 src/operator/tensor/broadcast_reduce_op.h | 110 +++++++++++++++++++++++++-----
 1 file changed, 92 insertions(+), 18 deletions(-)

diff --git a/src/operator/tensor/broadcast_reduce_op.h 
b/src/operator/tensor/broadcast_reduce_op.h
index 8265a3f475..3d0eba2d90 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -1354,6 +1354,80 @@ struct direct_copy {
   }
 };
 
+template <typename IType, typename OType>
+void BroadcastCPU(const OpContext& ctx,
+                  const std::vector<TBlob>& inputs,
+                  const std::vector<OpReqType>& req,
+                  const std::vector<TBlob>& outputs,
+                  const mxnet::TShape& src_shape,
+                  const mxnet::TShape& dst_shape,
+                  ShapeAndStride aux_data) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace mxnet_op;
+  constexpr size_t ELEMENTS_THRESHOLD = 256;
+  Stream<cpu>* s                      = ctx.get_stream<cpu>();
+
+  std::vector<size_t> elements_to_copy(aux_data.num_broadcast_axes);
+  std::vector<size_t> preaxis_dims(aux_data.num_broadcast_axes);
+  for (int ax = 0; ax < aux_data.num_broadcast_axes; ax++) {
+    index_t axis = aux_data.axes[ax];
+
+    elements_to_copy[ax] = 1;
+    for (int i = axis + 1; i < dst_shape.ndim(); i++) {
+      elements_to_copy[ax] *= dst_shape[i];
+    }
+
+    preaxis_dims[ax] = src_shape[0];
+    for (int i = 1; i < axis; i++) {
+      preaxis_dims[ax] *= src_shape[i];
+    }
+  }
+
+  // determine if version with memcpy should be used
+  // there is no need to check the remaining axes' element counts, as they are guaranteed to be larger
+  if (elements_to_copy[0] < ELEMENTS_THRESHOLD || !std::is_same<IType, 
OType>::value) {
+    IType* src = static_cast<IType*>(inputs[0].dptr_);
+    OType* dst = static_cast<OType*>(outputs[0].dptr_);
+
+    const int ndim = dst_shape.ndim() == 2 ? 2 : MXNET_SPECIAL_MAX_NDIM;
+    Kernel<broadcast_kernel_cpu<mshadow_op::identity>, cpu>::Launch(
+        s, src_shape.Size(), src, dst, aux_data, req[0], ndim);
+
+  } else {
+    IType* src = static_cast<IType*>(inputs[0].dptr_);
+    IType* dst = static_cast<IType*>(outputs[0].dptr_);
+    // broadcast each axis independently, reusing the previous result
+    const int omp_threads = 
mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+    for (int ax = 0; ax < aux_data.num_broadcast_axes; ax++) {
+      index_t axis     = aux_data.axes[ax];
+      size_t bcast_dim = dst_shape[axis];
+
+#pragma omp parallel num_threads(omp_threads)
+      {
+        // start from the end to avoid overwriting values when src == dst
+        for (int i = preaxis_dims[ax] - 1; i >= 0; i--) {
+#pragma omp for
+          for (int j = bcast_dim - 1; j >= 0; j--) {
+#pragma GCC diagnostic push
+#if __GNUC__ >= 8
+#pragma GCC diagnostic ignored "-Wclass-memaccess"
+#endif
+            std::memcpy(dst + (elements_to_copy[ax] * (j + i * bcast_dim)),
+                        src + (elements_to_copy[ax] * i),
+                        elements_to_copy[ax] * sizeof(IType));
+#pragma GCC diagnostic pop
+          }
+        }
+      }
+      // when first of broadcastable axis is broadcasted,
+      // run same algorithm for next broadcast axis with 'new' input
+      // this is why loops are iterating from the end
+      src = dst;
+    }
+  }
+}
+
 /**
  * When CPU context is used the no. of kernel launches are equal to
  * the no. of input elements, this helps leverage vectorization when possible
@@ -1377,13 +1451,17 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& 
attrs,
   //      -> (12,1,5,1,42) (1,3) (50, 9)
   //      and this is the new input for broadcast_kernel whose total
   //      num of dimensions cannot be greater than 5(throws an error 
otherwise).
+
   BroadcastReduceShapeCompact(outputs[0].shape_, small, &dst_shape, 
&src_shape);
+
   Stream<xpu>* s = ctx.get_stream<xpu>();
   bool isCPU     = std::is_same<xpu, cpu>::value;
+
   MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(inputs[0].type_flag_, IType, {
     MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(outputs[0].type_flag_, OType, {
       mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> in_shape;
       mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> out_shape;
+
       for (int i = 0; i < MXNET_SPECIAL_MAX_NDIM; ++i) {
         if (i < dst_shape.ndim()) {
           in_shape[i]  = src_shape[i];
@@ -1400,26 +1478,22 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& 
attrs,
        // then simply copy input to output.
         Kernel<direct_copy<mshadow_op::identity>, xpu>::Launch(
             s, outputs[0].Size(), inputs[0].dptr<IType>(), 
outputs[0].dptr<OType>(), req[0]);
-      } else if (dst_shape.ndim() == 2) {
-        Tensor<xpu, 2, OType> out = outputs[0].get_with_shape<xpu, 2, 
OType>(dst_shape.get<2>(), s);
-        Tensor<xpu, 2, IType> data = inputs[0].get_with_shape<xpu, 2, 
IType>(src_shape.get<2>(), s);
-        if (isCPU) {
-          Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
-              s, data.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], 
2);
-        } else {
+      } else if (isCPU) {
+        BroadcastCPU<IType, OType>(ctx, inputs, req, outputs, src_shape, 
dst_shape, aux_data);
+      } else {
+        if (dst_shape.ndim() == 2) {
+          Tensor<xpu, 2, OType> out =
+              outputs[0].get_with_shape<xpu, 2, OType>(dst_shape.get<2>(), s);
+          Tensor<xpu, 2, IType> data =
+              inputs[0].get_with_shape<xpu, 2, IType>(src_shape.get<2>(), s);
           Kernel<broadcast_kernel_gpu<mshadow_op::identity>, xpu>::Launch(
               s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], 
2);
-        }
-      } else {
-        const int ndim = MXNET_SPECIAL_MAX_NDIM;
-        Tensor<xpu, ndim, OType> out =
-            outputs[0].get_with_shape<xpu, ndim, OType>(dst_shape.get<ndim>(), 
s);
-        Tensor<xpu, ndim, IType> data =
-            inputs[0].get_with_shape<xpu, ndim, IType>(src_shape.get<ndim>(), 
s);
-        if (isCPU) {
-          Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
-              s, data.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], 
ndim);
         } else {
+          const int ndim = MXNET_SPECIAL_MAX_NDIM;
+          Tensor<xpu, ndim, OType> out =
+              outputs[0].get_with_shape<xpu, ndim, 
OType>(dst_shape.get<ndim>(), s);
+          Tensor<xpu, ndim, IType> data =
+              inputs[0].get_with_shape<xpu, ndim, 
IType>(src_shape.get<ndim>(), s);
           Kernel<broadcast_kernel_gpu<mshadow_op::identity>, xpu>::Launch(
               s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], 
ndim);
         }
@@ -1640,7 +1714,7 @@ void LpNormCompute(const nnvm::NodeAttrs& attrs,
   }
 #else
   const std::string& red = param.ord == 1 ? "red::sum{}" : "red::nrm2{}";
-  const std::string& op  = param.ord == 1 ? "abs" : "identity";
+  const std::string& op = param.ord == 1 ? "abs" : "identity";
   ReduceAxesRTCComputeImpl(ctx, inputs, req, outputs, small, red, nullptr, 
false, op);
 #endif
 }

Reply via email to