This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b5dae98ebf [Arith] Fix detect linear equation with uint var (#15558)
b5dae98ebf is described below

commit b5dae98ebf0f5987f1ea548a78a712b2e7896145
Author: Yuchao Zhang <[email protected]>
AuthorDate: Thu Aug 24 20:10:18 2023 +0800

    [Arith] Fix detect linear equation with uint var (#15558)
    
    fix detect linear equation with uint var
---
 src/arith/detect_linear_equation.cc                        | 3 ++-
 tests/python/unittest/test_arith_detect_linear_equation.py | 4 ++++
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc
index 576ac1716e..4d3164cbd3 100644
--- a/src/arith/detect_linear_equation.cc
+++ b/src/arith/detect_linear_equation.cc
@@ -100,7 +100,8 @@ class LinearEqDetector : public ExprFunctor<LinearEqEntry(const PrimExpr&, const
   LinearEqEntry VisitExpr_(const VarNode* op, const PrimExpr& e) final {
     LinearEqEntry ret;
     if (op == var_.get()) {
-      ret.coeff = make_const(op->dtype, 1);
+      auto dtype = op->dtype;
+      ret.coeff = make_const(DataType::Int(dtype.bits(), dtype.lanes()), 1);
     } else {
       ret.base = e;
     }
diff --git a/tests/python/unittest/test_arith_detect_linear_equation.py b/tests/python/unittest/test_arith_detect_linear_equation.py
index cedb557829..829b101af3 100644
--- a/tests/python/unittest/test_arith_detect_linear_equation.py
+++ b/tests/python/unittest/test_arith_detect_linear_equation.py
@@ -43,6 +43,10 @@ def test_basic():
     assert len(m) == 1
     tvm.testing.assert_prim_expr_equal(m[0], b * 7)
 
+    c = te.var("c", "uint32")
+    m = tvm.arith.detect_linear_equation(128 - c, [c])
+    assert m[0].value == -1
+
 
 def test_multivariate():
     v = [te.var("v%d" % i) for i in range(4)]

Reply via email to