t-vi commented on a change in pull request #6759:
URL: https://github.com/apache/incubator-tvm/pull/6759#discussion_r511757403



##########
File path: include/tvm/topi/broadcast.h
##########
@@ -69,6 +69,46 @@ inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t,
   return tvm::te::compute(oshape, l, name, tag);
 }
 
+inline tvm::te::Tensor broadcast_shape_tensors(const tvm::te::Tensor& shape_tensor1,
+                                               const tvm::te::Tensor& shape_tensor2,
+                                               std::string name = "T_broadcast_shape_tensors",
+                                               std::string tag = kBroadcast) {
+  const auto rank1 = detail::GetConstInt(shape_tensor1->shape[0]);
+  const auto rank2 = detail::GetConstInt(shape_tensor2->shape[0]);
+  const auto out_rank = std::max<int32_t>(rank1, rank2);
+  const tvm::PrimExpr one = tvm::cast(shape_tensor1->dtype, PrimExpr(1));
+
+  auto select_dim = [&](const tvm::te::Tensor& shape_tensor, int rank,
+                        tvm::tir::Var index) -> PrimExpr {
+    if (rank < out_rank) {
+      // if the rank is smaller, dimension 1 is prepended according to
+      // the numpy broadcasting semantics.
+      return tvm::tir::Select(rank - (out_rank - index) < 0, one,
+                              shape_tensor[rank - (out_rank - index)]);
+    } else {
+      // rank == out_rank, safe to index directly
+      return shape_tensor[index];
+    }
+  };
+
+  auto func = [&](tvm::Array<tvm::tir::Var> ovars) {
+    auto index = ovars[0];
+    PrimExpr dim1 = select_dim(shape_tensor1, rank1, index);
+    PrimExpr dim2 = select_dim(shape_tensor2, rank2, index);
+    if (topi::detail::EqualCheck(one, dim1)) {
+      return dim2;
+    } else if (topi::detail::EqualCheck(one, dim2)) {
+      return dim1;
+    }
+    return tvm::max(dim1, dim2);

Review comment:
   Two comments (I'd lean towards them not needing to be addressed in this PR):
   - Does this (with EqualCheck and a C++ if) work as expected on dynamic shapes? Should it?
   - While this seems to work for valid broadcasting (save the potential dynamic caveat), it fails to reject invalid broadcasting, i.e. when neither dim is 1 but they differ. Of course, this might be intentional to cover dynamic shapes (if they aren't covered by the two cases above), but it might lead to confusing error messages etc., so I think it would be neat to have a comment explaining what each code path handles.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
