This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b5dae98ebf [Arith] Fix detect linear equation with uint var (#15558)
b5dae98ebf is described below
commit b5dae98ebf0f5987f1ea548a78a712b2e7896145
Author: Yuchao Zhang <[email protected]>
AuthorDate: Thu Aug 24 20:10:18 2023 +0800
[Arith] Fix detect linear equation with uint var (#15558)
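Fix detect_linear_equation for unsigned-integer variables: the coefficient of the matched variable is now created with a signed integer dtype of the same bits and lanes, so that negative coefficients (e.g. -1 for `128 - c` with a uint32 `c`) are represented correctly.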
---
src/arith/detect_linear_equation.cc | 3 ++-
tests/python/unittest/test_arith_detect_linear_equation.py | 4 ++++
2 files changed, 6 insertions(+), 1 deletion(-)
diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc
index 576ac1716e..4d3164cbd3 100644
--- a/src/arith/detect_linear_equation.cc
+++ b/src/arith/detect_linear_equation.cc
@@ -100,7 +100,8 @@ class LinearEqDetector : public ExprFunctor<LinearEqEntry(const PrimExpr&, const
LinearEqEntry VisitExpr_(const VarNode* op, const PrimExpr& e) final {
LinearEqEntry ret;
if (op == var_.get()) {
- ret.coeff = make_const(op->dtype, 1);
+ auto dtype = op->dtype;
+ ret.coeff = make_const(DataType::Int(dtype.bits(), dtype.lanes()), 1);
} else {
ret.base = e;
}
diff --git a/tests/python/unittest/test_arith_detect_linear_equation.py b/tests/python/unittest/test_arith_detect_linear_equation.py
index cedb557829..829b101af3 100644
--- a/tests/python/unittest/test_arith_detect_linear_equation.py
+++ b/tests/python/unittest/test_arith_detect_linear_equation.py
@@ -43,6 +43,10 @@ def test_basic():
assert len(m) == 1
tvm.testing.assert_prim_expr_equal(m[0], b * 7)
+ c = te.var("c", "uint32")
+ m = tvm.arith.detect_linear_equation(128 - c, [c])
+ assert m[0].value == -1
+
def test_multivariate():
v = [te.var("v%d" % i) for i in range(4)]