This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 731f13326d [ARITH] Fix canonical simplify for LE with incorrect range 
assumptions (#18025)
731f13326d is described below

commit 731f13326d9ffbbd5b570f06b553890a26e16246
Author: Siyuan Feng <[email protected]>
AuthorDate: Sun Jun 1 13:33:58 2025 +0800

    [ARITH] Fix canonical simplify for LE with incorrect range assumptions 
(#18025)
    
    Fix a bug in canonical simplification of less-than expressions of the
    form `ax + b < c`, where the algorithm incorrectly allowed the remainder
    term `b` to take negative values when dividing the comparison by the gcd.
    
    The previous implementation only checked that `-d < xn < d` before
    simplifying, which is unsound whenever `xn` may be negative. For
    example, consider `2x + y < 8` where `y` is only known to satisfy
    `-2 < y < 2`: the old check passes and the expression is simplified to
    `x < 4`. However, with x = 4 and y = -1 we get 2*4 + (-1) = 7 < 8, so
    the original expression holds while the simplified `x < 4` does not.
    
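    The fix changes the range check to `0 <= xn < d`. Under that bound,
    `d * q + xn < 0` holds exactly when `q < 0` (for integer `q`), so
    dividing the remaining terms by `d` preserves the meaning of the
    comparison. Simplification now only occurs when `xn` is provably
    bounded below by zero.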
    
    Co-authored-by: FeiyangChen 
<[email protected]>
---
 src/arith/canonical_simplify.cc                     | 6 +++---
 src/runtime/pack_args.h                             | 2 +-
 tests/python/arith/test_arith_canonical_simplify.py | 1 -
 tests/python/dlight/test_gpu_low_batch_gemv.py      | 8 ++++----
 tests/python/dlight/test_gpu_matmul.py              | 4 ++--
 5 files changed, 10 insertions(+), 11 deletions(-)

diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc
index 06f181030e..1b82e93eac 100644
--- a/src/arith/canonical_simplify.cc
+++ b/src/arith/canonical_simplify.cc
@@ -1391,7 +1391,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const 
LTNode* op) {
   // First convert a < b into a - b < 0
   PrimExpr expr = this->CanonicalMutate(op->a - op->b);
   // Case: x0 * s0 + x1 * s1 + ... + xn + c < 0, let d = gcd(s0, s1, ..., 
s{n-1}, c)
-  // 1. if can prove -d < xn < d, then we can simplify
+  // 1. if can prove 0 <= xn < d, then we can simplify
   //    the expression to x0 * (s0/d) + x1 * (s1/d) + ... + x{n-1} * 
(s{n-1}/d) < c/d,
   //    e.g. `x * 8 + y < 16` where `y` \in [0, 8), we can simplify it to `x < 
2`
   // 2. if xn is in pattern of yn % m, where m % d == 0, convert it to yn // d 
% (m/d)
@@ -1417,8 +1417,8 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const 
LTNode* op) {
     ICHECK(extra->dtype == dtype);
     PrimExpr normal_extra = extra->Normalize();
     if (this->analyzer_->CanProve(normal_extra < make_const(dtype, gcd)) &&
-        this->analyzer_->CanProve(normal_extra > make_const(dtype, -gcd))) {
-      // Case 1. -d < xn < d
+        this->analyzer_->CanProve(normal_extra >= make_const(dtype, 0))) {
+      // Case 1. 0 <= xn < d
       divisible.CopyOnWrite()->DivideBy(gcd);
       return Rewriter::VisitExpr(divisible->Normalize() < make_zero(dtype));
     } else if (extra->args.size() == 1 &&
diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h
index b77adda4c9..0068db51d5 100644
--- a/src/runtime/pack_args.h
+++ b/src/runtime/pack_args.h
@@ -134,7 +134,7 @@ enum ArgConvertCode {
 };
 
 inline ArgConvertCode GetArgConvertCode(DLDataType t) {
-  ICHECK_EQ(t.lanes, 1U) << "Cannot pass vector type argument to devic 
function for now";
+  ICHECK_EQ(t.lanes, 1U) << "Cannot pass vector type argument to device 
function for now";
   if (t.code == kDLInt) {
     if (t.bits == 64U) return INT64_TO_INT64;
     if (t.bits == 32U) return INT64_TO_INT32;
diff --git a/tests/python/arith/test_arith_canonical_simplify.py 
b/tests/python/arith/test_arith_canonical_simplify.py
index 42f5b0ccd0..733d1d13b3 100644
--- a/tests/python/arith/test_arith_canonical_simplify.py
+++ b/tests/python/arith/test_arith_canonical_simplify.py
@@ -448,7 +448,6 @@ def test_simplify_le():
     ck.verify(x * -8 + z * 4 < 16, ck.analyzer.rewrite_simplify(-2 < x))
 
     ck.verify(x * 8 + y + z < 16, x * 8 + y + z < 16)
-    ck.verify(x * 8 + y - z < 16, x < 2)
 
     n = te.size_var("n")
     ck.verify(x * 8 + y < n, x * 8 + y < n)
diff --git a/tests/python/dlight/test_gpu_low_batch_gemv.py 
b/tests/python/dlight/test_gpu_low_batch_gemv.py
index 6341b7b0ae..ae07a3b731 100644
--- a/tests/python/dlight/test_gpu_low_batch_gemv.py
+++ b/tests/python/dlight/test_gpu_low_batch_gemv.py
@@ -136,7 +136,7 @@ def test_batch_decode_gemv():
                             with T.block("NT_matmul_intermediate_pad"):
                                 v0 = T.axis.spatial(batch_size, ax0_0 * 
T.int64(4) + ax0)
                                 v1 = T.axis.spatial(T.int64(4096), 
u_fused_ax1_fused_fused_0 * T.int64(16) + ax1_fused_0_ax1_fused_1_fused * 
T.int64(2) + ax1_fused_2)
-                                T.where((ax0_0 - (batch_size + T.int64(3)) // 
T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < 
batch_size)
+                                T.where((ax0_0 - (batch_size + T.int64(3)) // 
T.int64(4) < T.int64(0) or ax0_0 * T.int64(4) + ax0 == T.int64(0)) and ax0_0 * 
T.int64(4) + ax0 < batch_size)
                                 T.reads(NT_matmul_intermediate_pad_local[v0, 
T.int64(0), v1])
                                 T.writes(NT_matmul_intermediate[v0, 
T.int64(0), v1])
                                 NT_matmul_intermediate[v0, T.int64(0), v1] = 
NT_matmul_intermediate_pad_local[v0, T.int64(0), v1]
@@ -240,7 +240,7 @@ def test_batch_gemv():
                             with T.block("NT_matmul_pad"):
                                 v0 = T.axis.spatial(batch_size, ax0_0 * 
T.int64(4) + ax0)
                                 v1 = T.axis.spatial(T.int64(4096), 
u_fused_ax1_fused_fused_0 * T.int64(16) + ax1_fused_0_ax1_fused_1_fused * 
T.int64(2) + ax1_fused_2)
-                                T.where((ax0_0 - (batch_size + T.int64(3)) // 
T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < 
batch_size)
+                                T.where((ax0_0 - (batch_size + T.int64(3)) // 
T.int64(4) < T.int64(0) or ax0_0 * T.int64(4) + ax0 == T.int64(0)) and ax0_0 * 
T.int64(4) + ax0 < batch_size)
                                 T.reads(NT_matmul_pad_local[v0, T.int64(0), 
v1])
                                 T.writes(NT_matmul[v0, T.int64(0), v1])
                                 NT_matmul[v0, T.int64(0), v1] = 
NT_matmul_pad_local[v0, T.int64(0), v1]
@@ -369,7 +369,7 @@ def test_small_spatial_axis():
                             with T.block("C_pad"):
                                 v0 = T.axis.spatial(batch_size, ax0_0 * 
T.int64(4) + ax0)
                                 v1 = T.axis.spatial(T.int64(8), 
ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
-                                T.where((ax0_0 - (batch_size + T.int64(3)) // 
T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < 
batch_size and (T.Mul(T.int64(0), T.int64(16)) + ax1_fused_0_ax1_fused_1_fused 
% T.int64(16)) * T.int64(2) + ax1_fused_2 < T.int64(8))
+                                T.where((ax0_0 - (batch_size + T.int64(3)) // 
T.int64(4) < T.int64(0) or ax0_0 * T.int64(4) + ax0 == T.int64(0)) and ax0_0 * 
T.int64(4) + ax0 < batch_size and (T.Mul(T.int64(0), T.int64(16)) + 
ax1_fused_0_ax1_fused_1_fused % T.int64(16)) * T.int64(2) + ax1_fused_2 < 
T.int64(8))
                                 T.reads(C_pad_local[v0, v1])
                                 T.writes(C[v0, v1])
                                 C[v0, v1] = C_pad_local[v0, v1]
@@ -516,7 +516,7 @@ def test_outer_reduction():
                         with T.block("C_pad"):
                             v0 = T.axis.spatial(batch_size, ax0_0 * 4 + ax0)
                             v1 = T.axis.spatial(6144, ax1_fused_0 * 64 + ax1)
-                            T.where((ax0_0 - (batch_size + 3) // 4 < 0 or 
ax0_0 == 0) and ax0_0 * 4 + ax0 < batch_size)
+                            T.where((ax0_0 - (batch_size + 3) // 4 < 0 or 
ax0_0 * 4 + ax0 == 0) and ax0_0 * 4 + ax0 < batch_size)
                             T.reads(C_pad_local[v0, 0, v1])
                             T.writes(C[v0, 0, v1])
                             C[v0, 0, v1] = C_pad_local[v0, 0, v1]
diff --git a/tests/python/dlight/test_gpu_matmul.py 
b/tests/python/dlight/test_gpu_matmul.py
index 2fa61faf40..f27d9d370f 100644
--- a/tests/python/dlight/test_gpu_matmul.py
+++ b/tests/python/dlight/test_gpu_matmul.py
@@ -695,7 +695,7 @@ class TestMatmulAndroid(AndroidBeforeAfter):
                                     v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                     v1 = T.axis.spatial(m, i0_i1_fused_0 * 
T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0)
                                     v2 = T.axis.spatial(T.int64(4096), i2_0 * 
T.int64(256) + i2_1 * T.int64(8) + ax1)
-                                    T.where((i0_i1_fused_0 * T.int64(4) + 
i0_i1_fused_1 - (m + T.int64(15)) // T.int64(16) < T.int64(0) or i0_i1_fused_0 
== T.int64(0)) and i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + 
ax0 < m)
+                                    T.where((i0_i1_fused_0 * T.int64(4) + 
i0_i1_fused_1 - (m + T.int64(15)) // T.int64(16) < T.int64(0) or i0_i1_fused_0 
* T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 == T.int64(0)) and 
i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 < m)
                                     T.reads(matmul_pad_local[v0, v1, v2])
                                     T.writes(matmul[v0, v1, v2])
                                     matmul[v0, v1, v2] = matmul_pad_local[v0, 
v1, v2]
@@ -835,7 +835,7 @@ class TestFusedDequantMatmulAndroid(AndroidBeforeAfter):
                                     v_ax0 = T.axis.spatial(T.int64(1), 
T.int64(0))
                                     v_ax1 = T.axis.spatial(seq_len, 
i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0)
                                     v_ax2 = T.axis.spatial(T.int64(12288), 
i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1)
-                                    T.where((i0_i1_fused_0 * T.int64(4) + 
i0_i1_fused_1 - (seq_len + T.int64(15)) // T.int64(16) < T.int64(0) or 
i0_i1_fused_0 == T.int64(0)) and i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * 
T.int64(16) + ax0 < seq_len)
+                                    T.where((i0_i1_fused_0 * T.int64(4) + 
i0_i1_fused_1 - (seq_len + T.int64(15)) // T.int64(16) < T.int64(0) or 
i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 == T.int64(0)) 
and i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 < seq_len)
                                     
T.reads(matmul_intermediate_pad_local[v_ax0, v_ax1, v_ax2], 
transformer_h_0_attn_c_attn_bias3[v_ax2])
                                     
T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2])
                                     T_add_intermediate_intermediate[v_ax0, 
v_ax1, v_ax2] = matmul_intermediate_pad_local[v_ax0, v_ax1, v_ax2] + 
transformer_h_0_attn_c_attn_bias3[v_ax2]

Reply via email to