apeforest commented on a change in pull request #17882: Improve performance of broadcast_axis
URL: https://github.com/apache/incubator-mxnet/pull/17882#discussion_r410367756
##########
File path: src/operator/tensor/broadcast_reduce_op.h
##########
@@ -1058,22 +1058,50 @@ struct broadcast_kernel {
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>
in_shape,
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>
out_shape,
const OpReqType req,
- const uint32_t ndim) {
- size_t in_stride = 1;
- size_t out_stride = 1;
- index_t idx = i;
- index_t in_idx = i;
- for (int iter = ndim - 1; iter >= 0; --iter) {
- size_t dim_idx = idx % out_shape[iter];
- in_idx -= dim_idx * out_stride;
- if (in_shape[iter] != 1) {
- in_idx += dim_idx * in_stride;
- }
- idx /= out_shape[iter];
- in_stride *= in_shape[iter];
- out_stride *= out_shape[iter];
- }
- KERNEL_ASSIGN(output[i], req, OP::Map(input[in_idx]));
+ const uint32_t ndim,
+ const uint64_t *axes,
+ const uint64_t *out_stride,
+ const size_t no_axes) {
+ index_t idx = i;
+ index_t init_off = 0;
+ for (int iter = ndim - 1; idx > 0 && iter >= 0; --iter) {
+ size_t dim_idx = idx % in_shape[iter];
+ init_off += dim_idx * out_stride[iter];
+ idx /= in_shape[iter];
+ }
+ index_t stride_0, stride_1, stride_2;
+ switch (no_axes) {
Review comment:
Are we only supporting a maximum of 3 axes here?
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services