access2rohit commented on a change in pull request #17882:
URL: https://github.com/apache/incubator-mxnet/pull/17882#discussion_r444736360



##########
File path: src/operator/tensor/broadcast_reduce_op.h
##########
@@ -1058,22 +1065,58 @@ struct broadcast_kernel {
                                   mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> in_shape,
                                   mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> out_shape,
                                   const OpReqType req,
-                                  const uint32_t ndim) {
-    size_t in_stride = 1;
-    size_t out_stride = 1;
-    index_t idx = i;
-    index_t in_idx = i;
-    for (int iter = ndim - 1; iter >= 0; --iter) {
-      size_t dim_idx = idx % out_shape[iter];
-      in_idx -= dim_idx * out_stride;
-      if (in_shape[iter] != 1) {
-        in_idx += dim_idx * in_stride;
-      }
-      idx /= out_shape[iter];
-      in_stride *= in_shape[iter];
-      out_stride *= out_shape[iter];
-    }
-    KERNEL_ASSIGN(output[i], req, OP::Map(input[in_idx]));
+                                  const uint32_t ndim,
+                                  const size_t *axes,
+                                  const size_t *out_stride,
+                                  const int num_broadcast_axes) {
+        index_t idx = i;
+        index_t init_off = 0;
+        for (int iter = ndim - 1; idx > 0 && iter >= 0; --iter) {
+          size_t dim_idx = idx % in_shape[iter];
+          init_off += dim_idx * out_stride[iter];
+          idx /= in_shape[iter];
+        }
+        index_t stride_0, stride_1, stride_2;
+        // Each case is based on the number of axes to be broadcast
+        // (1, 2 or 3) after merging axes.
+        switch (num_broadcast_axes) {
+          // when the input shape is one of the forms
+          // [(x,1), (x,1,x), (1,x)]
+          // x can be any value >= 0 and the values need not be equal to each other
+          case 1 :
+            stride_0 = out_stride[axes[0]];
+            for (int l=0; l < out_shape[axes[0]]; l++) {
+              KERNEL_ASSIGN(output[init_off + l*stride_0],
+                  req, OP::Map(input[i]));
+            }
+            break;
+          // when the input shape is one of the forms
+          // [(x,1,x,1), (1,x,1,x), (x,1,x,1,x)]
+          // x can be any value > 1 (or 0) and the values need not be equal to each other
+          case 2:
+            stride_1 = out_stride[axes[1]], stride_0 = out_stride[axes[0]];
+            for (int k=0; k < out_shape[axes[1]]; k++) {
+              for (int l=0; l < out_shape[axes[0]]; l++) {
+                KERNEL_ASSIGN(output[init_off + k*stride_1 + l*stride_0],
+                    req, OP::Map(input[i]));
+              }
+            }
+            break;
+          // when the input shape is of the form [(1,x,1,x,1)]
+          // x can be any value >= 0 and the values need not be equal to each other
+          case 3:
+            stride_2 = out_stride[axes[2]], stride_1 = out_stride[axes[1]];
+            stride_0 = out_stride[axes[0]];
+            for (int j=0; j < out_shape[axes[2]]; j++) {
+              for (int k=0; k < out_shape[axes[1]]; k++) {
+                for (int l=0; l < out_shape[axes[0]]; l++) {
+                  KERNEL_ASSIGN(output[init_off + j*stride_2 + k*stride_1 + l*stride_0],
+                      req, OP::Map(input[i]));
+                }
+              }
+            }
+            break;
+        }
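
For context, here is a minimal standalone sketch (plain C++, not part of the PR) of the offset arithmetic the new kernel relies on: `init_off` is the row-major output offset of input element `i`, and each merged broadcast axis is then swept with its output stride. The shapes (2,1) broadcast to (2,3) below are illustrative only.

```c++
// Standalone illustration (no mshadow/CUDA): broadcast an input of shape
// (2, 1) to an output of shape (2, 3) using the same idea as the kernel:
// compute init_off from i, then sweep the broadcast axis with its stride.
#include <cstdio>

int main() {
  const int in_shape[2]   = {2, 1};
  const int out_shape[2]  = {2, 3};
  const int out_stride[2] = {3, 1};      // row-major output strides
  const float input[2] = {10.f, 20.f};   // flattened (2, 1) input
  float output[6] = {0.f};

  for (int i = 0; i < 2; ++i) {          // one "thread" per input element
    int idx = i, init_off = 0;
    for (int iter = 1; idx > 0 && iter >= 0; --iter) {
      int dim_idx = idx % in_shape[iter];
      init_off += dim_idx * out_stride[iter];
      idx /= in_shape[iter];
    }
    // Single broadcast axis (axis 1): replicate along its output extent.
    for (int l = 0; l < out_shape[1]; ++l)
      output[init_off + l * out_stride[1]] = input[i];
  }

  for (int k = 0; k < 6; ++k) printf("%.0f ", output[k]);  // 10 10 10 20 20 20
  return 0;
}
```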

Review comment:
       A default case with `LOG(FATAL)` will cause a build failure, since `cout` cannot be used inside device code. This is due to the way kernel launches are written.
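
One possible alternative (a minimal sketch, not the PR's actual code): keep the device-side switch as it is and validate the merged axis count on the host, before the kernel launch, where dmlc's logging macros are safe to use. `ValidateBroadcastAxes` is a hypothetical helper name introduced only for illustration.

```c++
// Hypothetical host-side guard (sketch only): call this before launching
// broadcast_kernel so no LOG(FATAL)/cout is needed inside device code.
#include <dmlc/logging.h>

inline void ValidateBroadcastAxes(const int num_broadcast_axes) {
  // CHECK_* expands to host-side logging, so it is safe here but not in Map().
  CHECK_GE(num_broadcast_axes, 1)
      << "broadcast_kernel expects at least one merged broadcast axis";
  CHECK_LE(num_broadcast_axes, 3)
      << "broadcast_kernel handles at most three merged broadcast axes, got "
      << num_broadcast_axes;
}
```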



