This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new db6d2059a2 [FIX][topi.scatter_nd] fixed shape equality assert by using
analyzer to prove equality (#17537)
db6d2059a2 is described below
commit db6d2059a2d45d71c8c2cf91ff15dab4b8c30e74
Author: PatrikPerssonInceptron
<[email protected]>
AuthorDate: Fri Nov 22 13:57:52 2024 +0100
[FIX][topi.scatter_nd] fixed shape equality assert by using analyzer to
prove equality (#17537)
* fixed assert by using analyzer to prove the equality
* updated docs in Analyzer class
---
python/tvm/arith/analyzer.py | 2 +-
python/tvm/topi/scatter.py | 5 ++++-
2 files changed, 5 insertions(+), 2 deletions(-)
diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py
index 22555e0fb3..f8069a717d 100644
--- a/python/tvm/arith/analyzer.py
+++ b/python/tvm/arith/analyzer.py
@@ -218,7 +218,7 @@ class Analyzer:
expr : PrimExpr
The expression.
- dom_map : Dict[Var, tvm.arith.IntSet]
+ dom_map : Dict[tvm.tir.Var, tvm.arith.IntSet]
The domain for variables to be relaxed.
Returns
diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py
index 799b3d1673..9cf19e2e61 100644
--- a/python/tvm/topi/scatter.py
+++ b/python/tvm/topi/scatter.py
@@ -18,9 +18,11 @@
"""ScatterND operator"""
from tvm import te, tir # hide redefinition of min and max
from tvm.tir import expr
+from tvm.arith.analyzer import Analyzer
def _verify_scatter_nd_inputs(data, indices, updates):
+ analyzer = Analyzer()
mdim = int(indices.shape[0])
assert mdim <= len(data.shape), (
f"The first dimension of the indices ({mdim}) must be less than or
equal to "
@@ -29,7 +31,8 @@ def _verify_scatter_nd_inputs(data, indices, updates):
for i in range(len(indices.shape) - 1):
if isinstance(indices.shape[i + 1], expr.Var) or
isinstance(updates.shape[i], expr.Var):
continue
- assert indices.shape[i + 1] == updates.shape[i], (
+
+ assert analyzer.can_prove_equal(indices.shape[i + 1],
updates.shape[i]), (
f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal
dimension of "
f"updates[{i}] ({updates.shape[i]})."
)