This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 44cc2b9de3 [Arith] Fix detect non-divisible iteration form like (x % 255) // 16 (#15665)
44cc2b9de3 is described below

commit 44cc2b9de38b76f576d4329c967782f4b6ce918c
Author: wrongtest <[email protected]>
AuthorDate: Mon Sep 18 16:17:48 2023 +0800

    [Arith] Fix detect non-divisible iteration form like (x % 255) // 16 (#15665)
    
    * fix detect non-divisible iteration form like (x % 255) // 16
    
    * add required rule to prove divisibility of dynamic shape
---
 src/arith/iter_affine_map.cc                            | 16 ++++++++++++----
 src/arith/rewrite_simplify.cc                           | 10 ++++++++++
 tests/python/unittest/test_arith_iter_affine_map.py     |  2 ++
 tests/python/unittest/test_arith_rewrite_simplify.py    | 17 +++++++++++++++++
 .../test_tir_transform_compact_buffer_region.py         |  2 +-
 5 files changed, 42 insertions(+), 5 deletions(-)

diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index 607be0a83d..1c782f4546 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -1898,10 +1898,18 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, P
   // = floormod(sc2+t, c2)
   // = floormod(floordiv(y, c1), c2)
   // = floormod(floordiv(iter, lower_factor*c1), c2), where c1=rhs, c2=extent/rhs
-  IterSplitExpr new_split(padded->source,
-                          /* lower_factor = */ padded->lower_factor * rhs,
-                          /* extent = */ analyzer_->Simplify(floordiv(padded->extent, rhs)),
-                          /* scale = */ padded->scale);
+  IterSplitExpr new_split;
+  if (CanProveDivisible(padded->extent, rhs)) {
+    new_split = IterSplitExpr(padded->source,
+                              /* lower_factor = */ padded->lower_factor * rhs,
+                              /* extent = */ analyzer_->Simplify(floordiv(padded->extent, rhs)),
+                              /* scale = */ padded->scale);
+  } else {
+    new_split = IterSplitExpr(IterMark(padded, padded->extent),
+                              /* lower_factor = */ rhs,
+                              /* extent = */ analyzer_->Simplify(ceildiv(padded->extent, rhs)),
+                              /* scale = */ make_const(rhs->dtype, 1));
+  }
 
   auto new_base = analyzer_->Simplify(floordiv(base - left_pad, rhs), 6);
   if (is_zero(new_base)) {
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index 63becf8eb7..d5f946fca0 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -1053,6 +1053,16 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
 
     TVM_TRY_REWRITE(matches_one_of(floormod(x * y, y), floormod(y * x, y)), ZeroWithTypeLike(y));
 
+    // x = ay + b, then (ay + b + (ny - ay - b) % y) % y -> (b + (-b) % y) % y -> 0
+    TVM_TRY_REWRITE_IF(
+        matches_one_of(floormod(x + floormod(z, y), y), floormod(floormod(z, y) + x, y)),
+        ZeroWithTypeLike(x), CanProveEqual(floormod(x.Eval() + z.Eval(), y.Eval()), 0));
+    // x = ay + b, then (ay + b - (ay + b) % +-y) % y -> (b - b % +-y) % y -> 0
+    TVM_TRY_REWRITE_IF(
+        matches_one_of(floormod(x - floormod(x, z), y), floormod(floormod(x, z) - x, y)),
+        ZeroWithTypeLike(x),
+        CanProveEqual(y.Eval() - z.Eval(), 0) || CanProveEqual(y.Eval() + z.Eval(), 0));
+
     if (floormod(x, c1).Match(ret)) {
       int64_t c1val = c1.Eval()->value;
       if (c1val > 0) {
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py
index 1676855b31..912edcbced 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -1051,6 +1051,8 @@ class TestPadding:
         # original extent is smaller than the divident
         # it is not surjective wrt to the region [0, 16)
         ({x: 3}, {flm(x, 16)}),
+        # (x % c1) // c2 is not proved as surjective if c1 % c2 != 0
+        ({x: 255}, {fld(flm(x, 255), 16)}),
     )
 
     def test_padding(self, positive_test_case):
diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py 
b/tests/python/unittest/test_arith_rewrite_simplify.py
index 0b0a43a7d3..5b06275422 100644
--- a/tests/python/unittest/test_arith_rewrite_simplify.py
+++ b/tests/python/unittest/test_arith_rewrite_simplify.py
@@ -605,6 +605,23 @@ class TestFloorModTwo(BaseCompare):
     )
 
 
+class TestFloorModPadded(BaseCompare):
+    """Special-case simplifications for divisibility proof
+    such that (x - x % k) must be divisible by k
+    """
+
+    x, y = te.var("x"), te.var("y")
+    test_case = tvm.testing.parameter(
+        TestCase(flm(x - flm(x, 9), 9), 0),
+        TestCase(flm(x - flm(x, -9), 9), 0),
+        TestCase(flm(x + flm(-x, 9), 9), 0),
+        TestCase(flm(x + flm(8 * x, 9), 9), 0),
+        TestCase(flm(x - flm(x, y), y), 0),
+        TestCase(flm(x - flm(x, -y), y), 0),
+        TestCase(flm(x + flm(-x, y), y), 0),
+    )
+
+
 class TestMinIndex(BaseCompare):
     x, y, z = te.var("x"), te.var("y"), te.var("z")
     test_case = tvm.testing.parameter(
diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
index d268403c1b..d5d5e0634e 100644
--- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py
+++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
@@ -999,7 +999,7 @@ class TestDependentBufferIndicesOfPackedMatmul(BaseCompactTest):
     ) -> None:
         for i0, i1 in T.grid(4, 1):
             with T.block():
-                C_local2 = T.alloc_buffer([1, 1, 15, 1000, 16], dtype="float32", scope="local")
+                C_local2 = T.alloc_buffer([1, 1, 16, 1000, 16], dtype="float32", scope="local")
                 C_local1 = T.alloc_buffer([255, 1000], dtype="float32", scope="local")
                 for ax0, ax1, ax2 in T.grid(255, 1000, 64):
                     with T.block("matmul"):

Reply via email to