access2rohit commented on a change in pull request #17882:
URL: https://github.com/apache/incubator-mxnet/pull/17882#discussion_r411024747
##########
File path: src/operator/tensor/broadcast_reduce_op.h
##########
@@ -1103,21 +1132,52 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs&
attrs,
out_shape[i] = 1;
}
}
+ uint64_t axes[dst_shape.ndim()], out_stride[dst_shape.ndim()];
+ int iter = dst_shape.ndim() - 1, i = 0;
+ out_stride[iter] = 1;
+ if (in_shape[iter] != dst_shape[iter]) {
+ axes[i] = iter;
+ i++;
+ }
+ --iter;
+ for (; iter >= 0; --iter) {
+ if (in_shape[iter] != dst_shape[iter]) {
+ axes[i] = iter;
+ i++;
+ }
+ out_stride[iter] = out_stride[iter+1] * dst_shape[iter+1];
+ }
if (dst_shape.ndim() == 2) {
Tensor<xpu, 2, OType> out =
outputs[0].get_with_shape<xpu, 2, OType>(dst_shape.get<2>(), s);
Tensor<xpu, 2, IType> data =
inputs[0].get_with_shape<xpu, 2, IType>(src_shape.get<2>(), s);
- Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
- s, out.shape_.Size(), data.dptr_, out.dptr_, in_shape, out_shape,
req[0], 2);
+ if (!enable_lt) {
+ typedef int32_t index_t;
+ Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
+ s, data.shape_.Size(), data.dptr_, out.dptr_, in_shape,
+ out_shape, req[0], 2, axes, out_stride, 1);
+ } else {
+ Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
+ s, data.shape_.Size(), data.dptr_, out.dptr_, in_shape,
+ out_shape, req[0], 2, axes, out_stride, 1);
+ }
} else {
const int ndim = MXNET_SPECIAL_MAX_NDIM;
Tensor<xpu, ndim, OType> out =
outputs[0].get_with_shape<xpu, ndim, OType>(dst_shape.get<ndim>(),
s);
Tensor<xpu, ndim, IType> data =
inputs[0].get_with_shape<xpu, ndim, IType>(src_shape.get<ndim>(), s);
- Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
- s, out.shape_.Size(), data.dptr_, out.dptr_, in_shape, out_shape,
req[0], ndim);
+ if (!enable_lt) {
Review comment:
@apeforest
Here are the results after removing the typedef and the enable_lt conditional
check (I haven't pushed that code to this PR yet). There is a slight
improvement in the results, but it's not that significant.

| code version | cases                                       | avg    | p50    | p90    |
|--------------|---------------------------------------------|--------|--------|--------|
| master LT    | (1000, 1, 100, 1)->(1000, 10, 100, 5)       | 28.89  | 26.91  | 34.51  |
|              | (1000, 1, 1, 100)->(1000, 10, 5, 100)       | 28.70  | 26.45  | 34.03  |
|              | (1, 1, 1, 1)->(1000, 10, 100, 5)            | 123.14 | 121.88 | 125.35 |
|              | (1, 1000, 1, 100, 1)->(2, 1000, 10, 100, 5) | 65.18  | 61.52  | 74.58  |
|              |                                             |        |        |        |
| master no-LT | (1000, 1, 100, 1)->(1000, 10, 100, 5)       | 13.82  | 13.24  | 16.83  |
|              | (1000, 1, 1, 100)->(1000, 10, 5, 100)       | 13.46  | 13.00  | 16.18  |
|              | (1, 1, 1, 1)->(1000, 10, 100, 5)            | 52.38  | 51.36  | 53.51  |
|              | (1, 1000, 1, 100, 1)->(2, 1000, 10, 100, 5) | 35.29  | 38.65  | 40.23  |
|              |                                             |        |        |        |
| new LT       | (1000, 1, 100, 1)->(1000, 10, 100, 5)       | 15.91  | 15.79  | 18.23  |
|              | (1000, 1, 1, 100)->(1000, 10, 5, 100)       | 16.44  | 17.50  | 17.97  |
|              | (1, 1, 1, 1)->(1000, 10, 100, 5)            | 124.17 | 125.65 | 126.42 |
|              | (1, 1000, 1, 100, 1)->(2, 1000, 10, 100, 5) | 37.33  | 36.76  | 38.07  |
|              |                                             |        |        |        |
| new no-LT    | (1000, 1, 100, 1)->(1000, 10, 100, 5)       | 9.76   | 9.91   | 11.98  |
|              | (1000, 1, 1, 100)->(1000, 10, 5, 100)       | 9.04   | 9.54   | 9.96   |
|              | (1, 1, 1, 1)->(1000, 10, 100, 5)            | 58.40  | 60.54  | 61.01  |
|              | (1, 1000, 1, 100, 1)->(2, 1000, 10, 100, 5) | 23.45  | 23.11  | 26.79  |
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]