This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e0105e488d [FIX] fix bug when normalize iter with different lower
bounds (#17360)
e0105e488d is described below
commit e0105e488dd99d5e153428bc1d8c3dec0c324086
Author: Jiaqiang Liu <[email protected]>
AuthorDate: Sat Sep 14 21:16:07 2024 +0800
[FIX] fix bug when normalize iter with different lower bounds (#17360)
If an iter has already been normalized with a lower bound and is then
normalized again with a new lower bound, iter_min needs to be updated only
when the new lower bound is larger than the original one; otherwise
iter_min (and therefore its delta) must stay unchanged.
Co-authored-by: liujiaqiang <[email protected]>
---
src/arith/iter_affine_map.cc | 2 +-
tests/python/arith/test_arith_iter_affine_map.py | 21 +++++++++++++++++++++
2 files changed, 22 insertions(+), 1 deletion(-)
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index 77b20fcdf2..d24c278f10 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -696,7 +696,7 @@ class IterMapRewriter : public ExprMutator {
// the delta of iter_min when it is updated when the lower bound
predicate is present
PrimExpr iter_min_delta = make_const(iter_min.dtype(), 0);
if (predicate_induced_min.defined()) {
- iter_min_delta = predicate_induced_min.value() - iter_min;
+ iter_min_delta = max(predicate_induced_min.value(), iter_min) -
iter_min;
iter_min = max(predicate_induced_min.value(), iter_min);
}
if (predicate_induced_max.defined()) {
diff --git a/tests/python/arith/test_arith_iter_affine_map.py
b/tests/python/arith/test_arith_iter_affine_map.py
index f0e6f05adf..f34dce5c86 100644
--- a/tests/python/arith/test_arith_iter_affine_map.py
+++ b/tests/python/arith/test_arith_iter_affine_map.py
@@ -346,6 +346,27 @@ def test_predicate():
predicate=tvm.tir.all(2 <= j * 2 + k, 0 <= i * 4 + j),
)
+ # constraint with different lower bounds
+ assert_iter_sum_pattern(
+ {
+ (i * 16 + j) // 23 * 8
+ + (i * 16 + j) % 23
+ - 15: (
+ 64,
+ 0,
+ 1,
+ (i * 16 + j) // 23 * 8 + ((i * 16 + j) % 23 +
tvm.tir.IntImm("int32", -15)),
+ )
+ },
+ var_dom([(i, 12), (j, 16)]),
+ predicate=tvm.tir.And(
+ tvm.tir.And(
+ i * 16 + j < 184, tvm.tir.LE(tvm.tir.IntImm("int32", 8), (i *
16 + j) % 23)
+ ),
+ tvm.tir.LE(tvm.tir.IntImm("int32", 15), (i * 16 + j) % 23),
+ ),
+ )
+
# constraint on many disjoint fused iters, case 1
# i4 * 6 + i5 in [3, 9), extent=6 (= scale of i2)
# i2 * 30 + i3 * 15 in [30, 90), extent=60 (= scale of i1)