access2rohit commented on a change in pull request #17882:
URL: https://github.com/apache/incubator-mxnet/pull/17882#discussion_r411024747



##########
File path: src/operator/tensor/broadcast_reduce_op.h
##########
@@ -1103,21 +1132,52 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& 
attrs,
           out_shape[i] = 1;
         }
       }
+      uint64_t axes[dst_shape.ndim()], out_stride[dst_shape.ndim()];
+      int iter = dst_shape.ndim() - 1, i = 0;
+      out_stride[iter] = 1;
+      if (in_shape[iter] != dst_shape[iter]) {
+        axes[i] = iter;
+        i++;
+      }
+      --iter;
+      for (; iter >= 0; --iter) {
+        if (in_shape[iter] != dst_shape[iter]) {
+          axes[i] = iter;
+          i++;
+        }
+        out_stride[iter] = out_stride[iter+1] * dst_shape[iter+1];
+      }
       if (dst_shape.ndim() == 2) {
         Tensor<xpu, 2, OType> out =
           outputs[0].get_with_shape<xpu, 2, OType>(dst_shape.get<2>(), s);
         Tensor<xpu, 2, IType> data =
           inputs[0].get_with_shape<xpu, 2, IType>(src_shape.get<2>(), s);
-        Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
-          s, out.shape_.Size(), data.dptr_, out.dptr_, in_shape, out_shape, 
req[0], 2);
+          if (!enable_lt) {
+            typedef int32_t index_t;
+            Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
+              s, data.shape_.Size(), data.dptr_, out.dptr_, in_shape,
+              out_shape, req[0], 2, axes, out_stride, 1);
+          } else {
+            Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
+              s, data.shape_.Size(), data.dptr_, out.dptr_, in_shape,
+              out_shape, req[0], 2, axes, out_stride, 1);
+          }
       } else {
         const int ndim = MXNET_SPECIAL_MAX_NDIM;
         Tensor<xpu, ndim, OType> out =
           outputs[0].get_with_shape<xpu, ndim, OType>(dst_shape.get<ndim>(), 
s);
         Tensor<xpu, ndim, IType> data =
           inputs[0].get_with_shape<xpu, ndim, IType>(src_shape.get<ndim>(), s);
-        Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
-          s, out.shape_.Size(), data.dptr_, out.dptr_, in_shape, out_shape, 
req[0], ndim);
+          if (!enable_lt) {

Review comment:
       Here are the results after removing the `typedef` and the `enable_lt` conditional check:
   
   | code version | cases                                       | avg    | p50    | p90    |
   |--------------|---------------------------------------------|--------|--------|--------|
   | master LT    | (1000, 1, 100, 1)->(1000, 10, 100, 5)       | 28.89  | 26.91  | 34.51  |
   |              | (1000, 1, 1, 100)->(1000, 10, 5, 100)       | 28.70  | 26.45  | 34.03  |
   |              | (1, 1, 1, 1)->(1000, 10, 100, 5)            | 123.14 | 121.88 | 125.35 |
   |              | (1, 1000, 1, 100, 1)->(2, 1000, 10, 100, 5) | 65.18  | 61.52  | 74.58  |
   |              |                                             |        |        |        |
   | master no-LT | (1000, 1, 100, 1)->(1000, 10, 100, 5)       | 13.82  | 13.24  | 16.83  |
   |              | (1000, 1, 1, 100)->(1000, 10, 5, 100)       | 13.46  | 13.00  | 16.18  |
   |              | (1, 1, 1, 1)->(1000, 10, 100, 5)            | 52.38  | 51.36  | 53.51  |
   |              | (1, 1000, 1, 100, 1)->(2, 1000, 10, 100, 5) | 35.29  | 38.65  | 40.23  |
   |              |                                             |        |        |        |
   | new LT       | (1000, 1, 100, 1)->(1000, 10, 100, 5)       | 15.91  | 15.79  | 18.23  |
   |              | (1000, 1, 1, 100)->(1000, 10, 5, 100)       | 16.44  | 17.50  | 17.97  |
   |              | (1, 1, 1, 1)->(1000, 10, 100, 5)            | 124.17 | 125.65 | 126.42 |
   |              | (1, 1000, 1, 100, 1)->(2, 1000, 10, 100, 5) | 37.33  | 36.76  | 38.07  |
   |              |                                             |        |        |        |
   | new no-LT    | (1000, 1, 100, 1)->(1000, 10, 100, 5)       | 9.76   | 9.91   | 11.98  |
   |              | (1000, 1, 1, 100)->(1000, 10, 5, 100)       | 9.04   | 9.54   | 9.96   |
   |              | (1, 1, 1, 1)->(1000, 10, 100, 5)            | 58.40  | 60.54  | 61.01  |
   |              | (1, 1000, 1, 100, 1)->(2, 1000, 10, 100, 5) | 23.45  | 23.11  | 26.79  |




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to