masahi commented on a change in pull request #6759:
URL: https://github.com/apache/incubator-tvm/pull/6759#discussion_r511781645
##########
File path: include/tvm/topi/broadcast.h
##########
@@ -69,6 +69,46 @@ inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t,
return tvm::te::compute(oshape, l, name, tag);
}
+inline tvm::te::Tensor broadcast_shape_tensors(const tvm::te::Tensor&
shape_tensor1,
+ const tvm::te::Tensor&
shape_tensor2,
+ std::string name =
"T_broadcast_shape_tensors",
+ std::string tag = kBroadcast) {
+ const auto rank1 = detail::GetConstInt(shape_tensor1->shape[0]);
+ const auto rank2 = detail::GetConstInt(shape_tensor2->shape[0]);
+ const auto out_rank = std::max<int32_t>(rank1, rank2);
+ const tvm::PrimExpr one = tvm::cast(shape_tensor1->dtype, PrimExpr(1));
+
+ auto select_dim = [&](const tvm::te::Tensor& shape_tensor, int rank,
+ tvm::tir::Var index) -> PrimExpr {
+ if (rank < out_rank) {
+ // if the rank is smaller, dimension 1 is prepended according to
+ // the numpy broadcasting semantics.
+ return tvm::tir::Select(rank - (out_rank - index) < 0, one,
+ shape_tensor[rank - (out_rank - index)]);
+ } else {
+ // rank == out_rank, safe to index directly
+ return shape_tensor[index];
+ }
+ };
+
+ auto func = [&](tvm::Array<tvm::tir::Var> ovars) {
+ auto index = ovars[0];
+ PrimExpr dim1 = select_dim(shape_tensor1, rank1, index);
+ PrimExpr dim2 = select_dim(shape_tensor2, rank2, index);
+ if (topi::detail::EqualCheck(one, dim1)) {
+ return dim2;
+ } else if (topi::detail::EqualCheck(one, dim2)) {
+ return dim1;
+ }
+ return tvm::max(dim1, dim2);
Review comment:
@tqchen Is it possible to generate an `AssertStmt` here? I want to add a
runtime assert that checks that the two dimensions are broadcast-compatible:
```
dim1 == dim2 or dim1 == 1 or dim2 == 1
```
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org