This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new 731f13326d [ARITH] Fix canonical simplify for LE with incorrect range assumptions (#18025)
731f13326d is described below
commit 731f13326d9ffbbd5b570f06b553890a26e16246
Author: Siyuan Feng <[email protected]>
AuthorDate: Sun Jun 1 13:33:58 2025 +0800
[ARITH] Fix canonical simplify for LE with incorrect range assumptions (#18025)
Fix a bug in canonical simplification of less-than expressions of the
form `ax + b < c`, where the range guard wrongly admitted negative
values of the residual term.

The previous implementation checked `-d < xn < d` before simplifying,
which is too weak because it allows `xn` to be negative. For example,
given the constraint `-2 < y < 2` and the expression `2x + y < 8`, the
check passes and the expression is simplified to `x < 4`. However,
with x = 4 and y = -1 we get 2*4 + (-1) = 7 < 8, which satisfies the
original expression but violates the simplified `x < 4`, so the
rewrite is unsound.

The fix changes the range check to `0 <= xn < d`, ensuring that the
simplification only fires when the residual term is provably
non-negative.
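
A minimal reproduction sketch (not part of this commit), using the
Python `tvm.arith.Analyzer` API; the bound below stands in for whatever
the analyzer can prove at the call site:

    import tvm
    from tvm import te

    x, y = te.var("x"), te.var("y")
    analyzer = tvm.arith.Analyzer()
    # y in [-1, 1]: the old guard -2 < y < 2 holds, but y may be negative.
    analyzer.update(y, tvm.arith.ConstIntBound(-1, 1))
    # Before this fix: unsoundly rewritten to x < 4 (x = 4, y = -1
    # satisfies the original, since 2*4 - 1 = 7 < 8, but not x < 4).
    # After this fix: left as 2*x + y < 8, because 0 <= y cannot be proven.
    print(analyzer.canonical_simplify(2 * x + y < 8))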
Co-authored-by: FeiyangChen <[email protected]>
---
src/arith/canonical_simplify.cc | 6 +++---
src/runtime/pack_args.h | 2 +-
tests/python/arith/test_arith_canonical_simplify.py | 1 -
tests/python/dlight/test_gpu_low_batch_gemv.py | 8 ++++----
tests/python/dlight/test_gpu_matmul.py | 4 ++--
5 files changed, 10 insertions(+), 11 deletions(-)
diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc
index 06f181030e..1b82e93eac 100644
--- a/src/arith/canonical_simplify.cc
+++ b/src/arith/canonical_simplify.cc
@@ -1391,7 +1391,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const LTNode* op) {
   // First convert a < b into a - b < 0
   PrimExpr expr = this->CanonicalMutate(op->a - op->b);
   // Case: x0 * s0 + x1 * s1 + ... + xn + c < 0, let d = gcd(s0, s1, ..., s{n-1}, c)
-  // 1. if can prove -d < xn < d, then we can simplify
+  // 1. if can prove 0 <= xn < d, then we can simplify
   //    the expression to x0 * (s0/d) + x1 * (s1/d) + ... + x{n-1} * (s{n-1}/d) < c/d,
   //    e.g. `x * 8 + y < 16` where `y` \in [0, 8), we can simplify it to `x < 2`
   // 2. if xn is in pattern of yn % m, where m % d == 0, convert it to yn // d % (m/d)
@@ -1417,8 +1417,8 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const LTNode* op) {
     ICHECK(extra->dtype == dtype);
     PrimExpr normal_extra = extra->Normalize();
     if (this->analyzer_->CanProve(normal_extra < make_const(dtype, gcd)) &&
-        this->analyzer_->CanProve(normal_extra > make_const(dtype, -gcd))) {
-      // Case 1. -d < xn < d
+        this->analyzer_->CanProve(normal_extra >= make_const(dtype, 0))) {
+      // Case 1. 0 <= xn < d
       divisible.CopyOnWrite()->DivideBy(gcd);
       return Rewriter::VisitExpr(divisible->Normalize() < make_zero(dtype));
     } else if (extra->args.size() == 1 &&
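
Why `0 <= xn < d` is the right guard, shown as a standalone brute-force
check (plain Python, illustrative only, not part of this commit).
Writing the divisible part as D = x0*s0 + ... + c, a multiple of d, the
rewrite divides everything through by d, so it is sound exactly when
`D + xn < 0` agrees with `D/d < 0`:

    d = 2
    for D in range(-10, 10, d):       # D ranges over multiples of d
        for xn in range(0, d):        # new guard: 0 <= xn < d
            assert (D + xn < 0) == (D // d < 0)
    # The old guard -d < xn < d also admits xn = -1; there D = 0 gives
    # D + xn = -1 < 0 while D // d = 0, so dividing through is unsound.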
diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h
index b77adda4c9..0068db51d5 100644
--- a/src/runtime/pack_args.h
+++ b/src/runtime/pack_args.h
@@ -134,7 +134,7 @@ enum ArgConvertCode {
 };
 
 inline ArgConvertCode GetArgConvertCode(DLDataType t) {
-  ICHECK_EQ(t.lanes, 1U) << "Cannot pass vector type argument to devic function for now";
+  ICHECK_EQ(t.lanes, 1U) << "Cannot pass vector type argument to device function for now";
   if (t.code == kDLInt) {
     if (t.bits == 64U) return INT64_TO_INT64;
     if (t.bits == 32U) return INT64_TO_INT32;
diff --git a/tests/python/arith/test_arith_canonical_simplify.py b/tests/python/arith/test_arith_canonical_simplify.py
index 42f5b0ccd0..733d1d13b3 100644
--- a/tests/python/arith/test_arith_canonical_simplify.py
+++ b/tests/python/arith/test_arith_canonical_simplify.py
@@ -448,7 +448,6 @@ def test_simplify_le():
     ck.verify(x * -8 + z * 4 < 16, ck.analyzer.rewrite_simplify(-2 < x))
     ck.verify(x * 8 + y + z < 16, x * 8 + y + z < 16)
-    ck.verify(x * 8 + y - z < 16, x < 2)
 
     n = te.size_var("n")
     ck.verify(x * 8 + y < n, x * 8 + y < n)
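
The dropped case is exactly the unsound pattern: the trailing term
`y - z` can be negative, so reducing `x * 8 + y - z < 16` to `x < 2`
no longer fires. A sketch of the post-fix behavior (the bounds here
are illustrative, not the test harness's actual ones):

    import tvm
    from tvm import te

    x, y, z = te.var("x"), te.var("y"), te.var("z")
    analyzer = tvm.arith.Analyzer()
    analyzer.update(y, tvm.arith.ConstIntBound(0, 7))
    analyzer.update(z, tvm.arith.ConstIntBound(0, 7))
    # y - z spans [-7, 7]; 0 <= y - z cannot be proven, so the
    # expression is left unsimplified.
    print(analyzer.canonical_simplify(x * 8 + y - z < 16))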
diff --git a/tests/python/dlight/test_gpu_low_batch_gemv.py b/tests/python/dlight/test_gpu_low_batch_gemv.py
index 6341b7b0ae..ae07a3b731 100644
--- a/tests/python/dlight/test_gpu_low_batch_gemv.py
+++ b/tests/python/dlight/test_gpu_low_batch_gemv.py
@@ -136,7 +136,7 @@ def test_batch_decode_gemv():
                     with T.block("NT_matmul_intermediate_pad"):
                         v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0)
                         v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
-                        T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size)
+                        T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 * T.int64(4) + ax0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size)
                         T.reads(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1])
                         T.writes(NT_matmul_intermediate[v0, T.int64(0), v1])
                         NT_matmul_intermediate[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1]
@@ -240,7 +240,7 @@ def test_batch_gemv():
                     with T.block("NT_matmul_pad"):
                         v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0)
                         v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
-                        T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size)
+                        T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 * T.int64(4) + ax0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size)
                         T.reads(NT_matmul_pad_local[v0, T.int64(0), v1])
                         T.writes(NT_matmul[v0, T.int64(0), v1])
                         NT_matmul[v0, T.int64(0), v1] = NT_matmul_pad_local[v0, T.int64(0), v1]
@@ -369,7 +369,7 @@ def test_small_spatial_axis():
                     with T.block("C_pad"):
                         v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0)
                         v1 = T.axis.spatial(T.int64(8), ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
-                        T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size and (T.Mul(T.int64(0), T.int64(16)) + ax1_fused_0_ax1_fused_1_fused % T.int64(16)) * T.int64(2) + ax1_fused_2 < T.int64(8))
+                        T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 * T.int64(4) + ax0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size and (T.Mul(T.int64(0), T.int64(16)) + ax1_fused_0_ax1_fused_1_fused % T.int64(16)) * T.int64(2) + ax1_fused_2 < T.int64(8))
                         T.reads(C_pad_local[v0, v1])
                         T.writes(C[v0, v1])
                         C[v0, v1] = C_pad_local[v0, v1]
@@ -516,7 +516,7 @@ def test_outer_reduction():
                 with T.block("C_pad"):
                     v0 = T.axis.spatial(batch_size, ax0_0 * 4 + ax0)
                     v1 = T.axis.spatial(6144, ax1_fused_0 * 64 + ax1)
-                    T.where((ax0_0 - (batch_size + 3) // 4 < 0 or ax0_0 == 0) and ax0_0 * 4 + ax0 < batch_size)
+                    T.where((ax0_0 - (batch_size + 3) // 4 < 0 or ax0_0 * 4 + ax0 == 0) and ax0_0 * 4 + ax0 < batch_size)
                     T.reads(C_pad_local[v0, 0, v1])
                     T.writes(C[v0, 0, v1])
                     C[v0, 0, v1] = C_pad_local[v0, 0, v1]
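
These expected schedules, like the matmul tests below, track the
stronger guard: the padding predicate now keeps the full index
expression `ax0_0 * 4 + ax0 == 0` instead of the previously folded
`ax0_0 == 0`. The two differ exactly when `ax0_0 == 0` but `ax0 != 0`,
which a plain-Python check (illustrative only) makes explicit:

    for ax0_0 in range(3):
        for ax0 in range(4):
            old = ax0_0 == 0
            new = ax0_0 * 4 + ax0 == 0
            if old != new:
                print(ax0_0, ax0)   # prints (0, 1), (0, 2), (0, 3)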
diff --git a/tests/python/dlight/test_gpu_matmul.py b/tests/python/dlight/test_gpu_matmul.py
index 2fa61faf40..f27d9d370f 100644
--- a/tests/python/dlight/test_gpu_matmul.py
+++ b/tests/python/dlight/test_gpu_matmul.py
@@ -695,7 +695,7 @@ class TestMatmulAndroid(AndroidBeforeAfter):
                         v0 = T.axis.spatial(T.int64(1), T.int64(0))
                         v1 = T.axis.spatial(m, i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0)
                         v2 = T.axis.spatial(T.int64(4096), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1)
-                        T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 - (m + T.int64(15)) // T.int64(16) < T.int64(0) or i0_i1_fused_0 == T.int64(0)) and i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 < m)
+                        T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 - (m + T.int64(15)) // T.int64(16) < T.int64(0) or i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 == T.int64(0)) and i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 < m)
                         T.reads(matmul_pad_local[v0, v1, v2])
                         T.writes(matmul[v0, v1, v2])
                         matmul[v0, v1, v2] = matmul_pad_local[v0, v1, v2]
@@ -835,7 +835,7 @@ class TestFusedDequantMatmulAndroid(AndroidBeforeAfter):
                         v_ax0 = T.axis.spatial(T.int64(1), T.int64(0))
                         v_ax1 = T.axis.spatial(seq_len, i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0)
                         v_ax2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1)
-                        T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 - (seq_len + T.int64(15)) // T.int64(16) < T.int64(0) or i0_i1_fused_0 == T.int64(0)) and i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 < seq_len)
+                        T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 - (seq_len + T.int64(15)) // T.int64(16) < T.int64(0) or i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 == T.int64(0)) and i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 < seq_len)
                         T.reads(matmul_intermediate_pad_local[v_ax0, v_ax1, v_ax2], transformer_h_0_attn_c_attn_bias3[v_ax2])
                         T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2])
                         T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2] = matmul_intermediate_pad_local[v_ax0, v_ax1, v_ax2] + transformer_h_0_attn_c_attn_bias3[v_ax2]