This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new db6d2059a2 [FIX][topi.scatter_nd] fixed shape equality assert by using analyzer to prove equality (#17537)
db6d2059a2 is described below

commit db6d2059a2d45d71c8c2cf91ff15dab4b8c30e74
Author: PatrikPerssonInceptron <[email protected]>
AuthorDate: Fri Nov 22 13:57:52 2024 +0100

    [FIX][topi.scatter_nd] fixed shape equality assert by using analyzer to prove equality (#17537)
    
    * fixed assert by using the analyzer to prove equality
    
    * updated docs in Analyzer class
---
 python/tvm/arith/analyzer.py | 2 +-
 python/tvm/topi/scatter.py   | 5 ++++-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py
index 22555e0fb3..f8069a717d 100644
--- a/python/tvm/arith/analyzer.py
+++ b/python/tvm/arith/analyzer.py
@@ -218,7 +218,7 @@ class Analyzer:
         expr : PrimExpr
             The expression.
 
-        dom_map : Dict[Var, tvm.arith.IntSet]
+        dom_map : Dict[tvm.tir.Var, tvm.arith.IntSet]
             The domain for variables to be relaxed.
 
         Returns
diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py
index 799b3d1673..9cf19e2e61 100644
--- a/python/tvm/topi/scatter.py
+++ b/python/tvm/topi/scatter.py
@@ -18,9 +18,11 @@
 """ScatterND operator"""
 from tvm import te, tir  # hide redefinition of min and max
 from tvm.tir import expr
+from tvm.arith.analyzer import Analyzer
 
 
 def _verify_scatter_nd_inputs(data, indices, updates):
+    analyzer = Analyzer()
     mdim = int(indices.shape[0])
     assert mdim <= len(data.shape), (
         f"The first dimension of the indices ({mdim}) must be less than or equal to "
@@ -29,7 +31,8 @@ def _verify_scatter_nd_inputs(data, indices, updates):
     for i in range(len(indices.shape) - 1):
         if isinstance(indices.shape[i + 1], expr.Var) or isinstance(updates.shape[i], expr.Var):
             continue
-        assert indices.shape[i + 1] == updates.shape[i], (
+
+        assert analyzer.can_prove_equal(indices.shape[i + 1], updates.shape[i]), (
             f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of "
             f"updates[{i}] ({updates.shape[i]})."
         )

Reply via email to