access2rohit commented on a change in pull request #17882:
URL: https://github.com/apache/incubator-mxnet/pull/17882#discussion_r411024747



##########
File path: src/operator/tensor/broadcast_reduce_op.h
##########
@@ -1103,21 +1132,52 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& 
attrs,
           out_shape[i] = 1;
         }
       }
+      uint64_t axes[dst_shape.ndim()], out_stride[dst_shape.ndim()];
+      int iter = dst_shape.ndim() - 1, i = 0;
+      out_stride[iter] = 1;
+      if (in_shape[iter] != dst_shape[iter]) {
+        axes[i] = iter;
+        i++;
+      }
+      --iter;
+      for (; iter >= 0; --iter) {
+        if (in_shape[iter] != dst_shape[iter]) {
+          axes[i] = iter;
+          i++;
+        }
+        out_stride[iter] = out_stride[iter+1] * dst_shape[iter+1];
+      }
       if (dst_shape.ndim() == 2) {
         Tensor<xpu, 2, OType> out =
           outputs[0].get_with_shape<xpu, 2, OType>(dst_shape.get<2>(), s);
         Tensor<xpu, 2, IType> data =
           inputs[0].get_with_shape<xpu, 2, IType>(src_shape.get<2>(), s);
-        Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
-          s, out.shape_.Size(), data.dptr_, out.dptr_, in_shape, out_shape, 
req[0], 2);
+          if (!enable_lt) {
+            typedef int32_t index_t;
+            Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
+              s, data.shape_.Size(), data.dptr_, out.dptr_, in_shape,
+              out_shape, req[0], 2, axes, out_stride, 1);
+          } else {
+            Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
+              s, data.shape_.Size(), data.dptr_, out.dptr_, in_shape,
+              out_shape, req[0], 2, axes, out_stride, 1);
+          }
       } else {
         const int ndim = MXNET_SPECIAL_MAX_NDIM;
         Tensor<xpu, ndim, OType> out =
           outputs[0].get_with_shape<xpu, ndim, OType>(dst_shape.get<ndim>(), 
s);
         Tensor<xpu, ndim, IType> data =
           inputs[0].get_with_shape<xpu, ndim, IType>(src_shape.get<ndim>(), s);
-        Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
-          s, out.shape_.Size(), data.dptr_, out.dptr_, in_shape, out_shape, 
req[0], ndim);
+          if (!enable_lt) {

Review comment:
       @apeforest 
   Here are the results after removing the typedef and the enable_lt conditional check. There is a slight improvement in the results, but it's not that significant.
   
   | code version | cases                                       | avg    | p50    | p90    |
   |--------------|---------------------------------------------|--------|--------|--------|
   | master LT    | (1000, 1, 100, 1)->(1000, 10, 100, 5)       | 28.89  | 26.91  | 34.51  |
   |              | (1000, 1, 1, 100)->(1000, 10, 5, 100)       | 28.70  | 26.45  | 34.03  |
   |              | (1, 1, 1, 1)->(1000, 10, 100, 5)            | 123.14 | 121.88 | 125.35 |
   |              | (1, 1000, 1, 100, 1)->(2, 1000, 10, 100, 5) | 65.18  | 61.52  | 74.58  |
   |              |                                             |        |        |        |
   | master no-LT | (1000, 1, 100, 1)->(1000, 10, 100, 5)       | 13.82  | 13.24  | 16.83  |
   |              | (1000, 1, 1, 100)->(1000, 10, 5, 100)       | 13.46  | 13.00  | 16.18  |
   |              | (1, 1, 1, 1)->(1000, 10, 100, 5)            | 52.38  | 51.36  | 53.51  |
   |              | (1, 1000, 1, 100, 1)->(2, 1000, 10, 100, 5) | 35.29  | 38.65  | 40.23  |
   |              |                                             |        |        |        |
   | new LT       | (1000, 1, 100, 1)->(1000, 10, 100, 5)       | 15.91  | 15.79  | 18.23  |
   |              | (1000, 1, 1, 100)->(1000, 10, 5, 100)       | 16.44  | 17.50  | 17.97  |
   |              | (1, 1, 1, 1)->(1000, 10, 100, 5)            | 124.17 | 125.65 | 126.42 |
   |              | (1, 1000, 1, 100, 1)->(2, 1000, 10, 100, 5) | 37.33  | 36.76  | 38.07  |
   |              |                                             |        |        |        |
   | new no-LT    | (1000, 1, 100, 1)->(1000, 10, 100, 5)       | 9.76   | 9.91   | 11.98  |
   |              | (1000, 1, 1, 100)->(1000, 10, 5, 100)       | 9.04   | 9.54   | 9.96   |
   |              | (1, 1, 1, 1)->(1000, 10, 100, 5)            | 58.40  | 60.54  | 61.01  |
   |              | (1, 1000, 1, 100, 1)->(2, 1000, 10, 100, 5) | 23.45  | 23.11  | 26.79  |




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to