This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5fc254e369 Fix relax.ccl.scatter_from_worker0 assert (#17688)
5fc254e369 is described below

commit 5fc254e369ac43a4d8c5201cb7e6e038394638a0
Author: oskar-inceptron <[email protected]>
AuthorDate: Fri Feb 28 16:55:11 2025 +0100

    Fix relax.ccl.scatter_from_worker0 assert (#17688)
    
    The current code checks whether floormod(input_shape[0], num_workers)
    cannot be proven to be true, which has two problems:
    - It is unclear what it means to prove floormod(...). Does it prove that
      floormod's return value is truthy, i.e. non-zero? In fact, the `!= 0`
      comparison was applied to the boolean result of CanProve rather than
      to floormod(...) inside the call, so the intended divisibility check
      was never expressed.
    - It always checks dimension 0 of the input shape, but the dimension
      to be sharded is given by the "axis" parameter.
    This commit fixes both of the above.
---
 src/relax/op/ccl/ccl.cc | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc
index 092727cb51..c32cdc3aac 100644
--- a/src/relax/op/ccl/ccl.cc
+++ b/src/relax/op/ccl/ccl.cc
@@ -133,12 +133,12 @@ StructInfo InferStructInfoScatter(const Call& call, const 
BlockBuilder& ctx) {
   auto input_shape = input_sinfo->GetShape();
   CHECK(input_shape.defined()) << "input tensor of scatter_from_worker0 should 
have defined shape.";
 
-  if (analyzer->CanProve(floormod(input_shape.value()[0], 
PrimExpr(num_workers))) != 0) {
+  if (analyzer->CanProve(floormod(input_shape.value()[attrs->axis], 
PrimExpr(num_workers)) != 0)) {
     ctx->ReportFatal(Diagnostic::Error(call)
-                     << "scatter_from_worker0 expects the size of axis 0 of 
input tensor to be "
-                        "divisible by the "
-                        "num_workers. However, the axis 0 of input tensor is "
-                     << input_shape.value() << " while num_workers is " << 
num_workers);
+                     << "scatter_from_worker0 expects the size of axis " << 
attrs->axis
+                     << " of input tensor to be divisible by the num_workers. 
However, axis "
+                     << attrs->axis << " of input tensor is " << 
input_shape.value()
+                     << " while num_workers is " << num_workers);
   }
 
   Array<PrimExpr> output_shape = input_shape.value();

Reply via email to