This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e0105e488d [FIX] fix bug when normalize iter with different lower bounds (#17360)
e0105e488d is described below

commit e0105e488dd99d5e153428bc1d8c3dec0c324086
Author: Jiaqiang Liu <[email protected]>
AuthorDate: Sat Sep 14 21:16:07 2024 +0800

    [FIX] fix bug when normalize iter with different lower bounds (#17360)
    
    If an iter has already been normalized with a lower bound and is then
    normalized again with a new lower bound, iter_min (and the corresponding
    iter_min_delta) needs to be updated only when the new lower bound is
    larger than the current one; otherwise iter_min is kept and the delta
    is zero.
    
    Co-authored-by: liujiaqiang <[email protected]>
---
 src/arith/iter_affine_map.cc                     |  2 +-
 tests/python/arith/test_arith_iter_affine_map.py | 21 +++++++++++++++++++++
 2 files changed, 22 insertions(+), 1 deletion(-)

diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index 77b20fcdf2..d24c278f10 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -696,7 +696,7 @@ class IterMapRewriter : public ExprMutator {
       // the delta of iter_min when it is updated when the lower bound 
predicate is present
       PrimExpr iter_min_delta = make_const(iter_min.dtype(), 0);
       if (predicate_induced_min.defined()) {
-        iter_min_delta = predicate_induced_min.value() - iter_min;
+        iter_min_delta = max(predicate_induced_min.value(), iter_min) - 
iter_min;
         iter_min = max(predicate_induced_min.value(), iter_min);
       }
       if (predicate_induced_max.defined()) {
diff --git a/tests/python/arith/test_arith_iter_affine_map.py b/tests/python/arith/test_arith_iter_affine_map.py
index f0e6f05adf..f34dce5c86 100644
--- a/tests/python/arith/test_arith_iter_affine_map.py
+++ b/tests/python/arith/test_arith_iter_affine_map.py
@@ -346,6 +346,27 @@ def test_predicate():
         predicate=tvm.tir.all(2 <= j * 2 + k, 0 <= i * 4 + j),
     )
 
+    # constraint with different lower bounds
+    assert_iter_sum_pattern(
+        {
+            (i * 16 + j) // 23 * 8
+            + (i * 16 + j) % 23
+            - 15: (
+                64,
+                0,
+                1,
+                (i * 16 + j) // 23 * 8 + ((i * 16 + j) % 23 + 
tvm.tir.IntImm("int32", -15)),
+            )
+        },
+        var_dom([(i, 12), (j, 16)]),
+        predicate=tvm.tir.And(
+            tvm.tir.And(
+                i * 16 + j < 184, tvm.tir.LE(tvm.tir.IntImm("int32", 8), (i * 
16 + j) % 23)
+            ),
+            tvm.tir.LE(tvm.tir.IntImm("int32", 15), (i * 16 + j) % 23),
+        ),
+    )
+
     # constraint on many disjoint fused iters, case 1
     # i4 * 6 + i5 in [3, 9), extent=6 (= scale of i2)
     # i2 * 30 + i3 * 15 in [30, 90), extent=60 (= scale of i1)

Reply via email to