This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5fc254e369 Fix relax.ccl.scatter_from_worker0 assert (#17688)
5fc254e369 is described below
commit 5fc254e369ac43a4d8c5201cb7e6e038394638a0
Author: oskar-inceptron <[email protected]>
AuthorDate: Fri Feb 28 16:55:11 2025 +0100
Fix relax.ccl.scatter_from_worker0 assert (#17688)
The current code asserts that floormod(input_dims[0], num_shards)
cannot be proven to be true, which has two problems:
- It is unclear what it means to prove floormod(..): does it mean proving
that floormod's return value is truthy, i.e. non-zero?
- It always checks the 0th dimension of the input shape, but the
dimension index to be sharded is given by the "axis" parameter.
This commit fixes both of the above.
---
src/relax/op/ccl/ccl.cc | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc
index 092727cb51..c32cdc3aac 100644
--- a/src/relax/op/ccl/ccl.cc
+++ b/src/relax/op/ccl/ccl.cc
@@ -133,12 +133,12 @@ StructInfo InferStructInfoScatter(const Call& call, const
BlockBuilder& ctx) {
auto input_shape = input_sinfo->GetShape();
CHECK(input_shape.defined()) << "input tensor of scatter_from_worker0 should
have defined shape.";
- if (analyzer->CanProve(floormod(input_shape.value()[0],
PrimExpr(num_workers))) != 0) {
+ if (analyzer->CanProve(floormod(input_shape.value()[attrs->axis],
PrimExpr(num_workers)) != 0)) {
ctx->ReportFatal(Diagnostic::Error(call)
- << "scatter_from_worker0 expects the size of axis 0 of
input tensor to be "
- "divisible by the "
- "num_workers. However, the axis 0 of input tensor is "
- << input_shape.value() << " while num_workers is " <<
num_workers);
+ << "scatter_from_worker0 expects the size of axis " <<
attrs->axis
+ << " of input tensor to be divisible by the num_workers.
However, axis "
+ << attrs->axis << " of input tensor is " <<
input_shape.value()
+ << " while num_workers is " << num_workers);
}
Array<PrimExpr> output_shape = input_shape.value();