This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 44cc2b9de3 [Arith] Fix detect non-divisible iteration form like (x % 255) // 16 (#15665)
44cc2b9de3 is described below
commit 44cc2b9de38b76f576d4329c967782f4b6ce918c
Author: wrongtest <[email protected]>
AuthorDate: Mon Sep 18 16:17:48 2023 +0800
[Arith] Fix detect non-divisible iteration form like (x % 255) // 16 (#15665)
* fix detect non-divisible iteration form like (x % 255) // 16
* add required rule to prove divisibility of dynamic shape
---
src/arith/iter_affine_map.cc | 16 ++++++++++++----
src/arith/rewrite_simplify.cc | 10 ++++++++++
tests/python/unittest/test_arith_iter_affine_map.py | 2 ++
tests/python/unittest/test_arith_rewrite_simplify.py | 17 +++++++++++++++++
.../test_tir_transform_compact_buffer_region.py | 2 +-
5 files changed, 42 insertions(+), 5 deletions(-)
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index 607be0a83d..1c782f4546 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -1898,10 +1898,18 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, P
// = floormod(sc2+t, c2)
// = floormod(floordiv(y, c1), c2)
// = floormod(floordiv(iter, lower_factor*c1), c2), where c1=rhs, c2=extent/rhs
- IterSplitExpr new_split(padded->source,
- /* lower_factor = */ padded->lower_factor * rhs,
- /* extent = */ analyzer_->Simplify(floordiv(padded->extent, rhs)),
- /* scale = */ padded->scale);
+ IterSplitExpr new_split;
+ if (CanProveDivisible(padded->extent, rhs)) {
+ new_split = IterSplitExpr(padded->source,
+ /* lower_factor = */ padded->lower_factor * rhs,
+ /* extent = */ analyzer_->Simplify(floordiv(padded->extent, rhs)),
+ /* scale = */ padded->scale);
+ } else {
+ new_split = IterSplitExpr(IterMark(padded, padded->extent),
+ /* lower_factor = */ rhs,
+ /* extent = */ analyzer_->Simplify(ceildiv(padded->extent, rhs)),
+ /* scale = */ make_const(rhs->dtype, 1));
+ }
auto new_base = analyzer_->Simplify(floordiv(base - left_pad, rhs), 6);
if (is_zero(new_base)) {
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index 63becf8eb7..d5f946fca0 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -1053,6 +1053,16 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
TVM_TRY_REWRITE(matches_one_of(floormod(x * y, y), floormod(y * x, y)),
ZeroWithTypeLike(y));
+ // x = ay + b, then (ay + b + (ny - ay - b) % y) % y -> (b + (-b) % y) % y -> 0
+ TVM_TRY_REWRITE_IF(
+ matches_one_of(floormod(x + floormod(z, y), y), floormod(floormod(z, y) + x, y)),
+ ZeroWithTypeLike(x), CanProveEqual(floormod(x.Eval() + z.Eval(), y.Eval()), 0));
+ // x = ay + b, then (ay + b - (ay + b) % +-y) % y -> (b - b % +-y) % y -> 0
+ TVM_TRY_REWRITE_IF(
+ matches_one_of(floormod(x - floormod(x, z), y), floormod(floormod(x, z) - x, y)),
+ ZeroWithTypeLike(x),
+ CanProveEqual(y.Eval() - z.Eval(), 0) || CanProveEqual(y.Eval() + z.Eval(), 0));
+
if (floormod(x, c1).Match(ret)) {
int64_t c1val = c1.Eval()->value;
if (c1val > 0) {
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py
index 1676855b31..912edcbced 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -1051,6 +1051,8 @@ class TestPadding:
# original extent is smaller than the divident
# it is not surjective wrt to the region [0, 16)
({x: 3}, {flm(x, 16)}),
+ # (x % c1) // c2 is not proved as surjective if c1 % c2 != 0
+ ({x: 255}, {fld(flm(x, 255), 16)}),
)
def test_padding(self, positive_test_case):
diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py
index 0b0a43a7d3..5b06275422 100644
--- a/tests/python/unittest/test_arith_rewrite_simplify.py
+++ b/tests/python/unittest/test_arith_rewrite_simplify.py
@@ -605,6 +605,23 @@ class TestFloorModTwo(BaseCompare):
)
+class TestFloorModPadded(BaseCompare):
+ """Special-case simplifications for divisibility proof
+ such that (x - x % k) must be divisible by k
+ """
+
+ x, y = te.var("x"), te.var("y")
+ test_case = tvm.testing.parameter(
+ TestCase(flm(x - flm(x, 9), 9), 0),
+ TestCase(flm(x - flm(x, -9), 9), 0),
+ TestCase(flm(x + flm(-x, 9), 9), 0),
+ TestCase(flm(x + flm(8 * x, 9), 9), 0),
+ TestCase(flm(x - flm(x, y), y), 0),
+ TestCase(flm(x - flm(x, -y), y), 0),
+ TestCase(flm(x + flm(-x, y), y), 0),
+ )
+
+
class TestMinIndex(BaseCompare):
x, y, z = te.var("x"), te.var("y"), te.var("z")
test_case = tvm.testing.parameter(
diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
index d268403c1b..d5d5e0634e 100644
--- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py
+++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
@@ -999,7 +999,7 @@ class TestDependentBufferIndicesOfPackedMatmul(BaseCompactTest):
) -> None:
for i0, i1 in T.grid(4, 1):
with T.block():
- C_local2 = T.alloc_buffer([1, 1, 15, 1000, 16], dtype="float32", scope="local")
+ C_local2 = T.alloc_buffer([1, 1, 16, 1000, 16], dtype="float32", scope="local")
C_local1 = T.alloc_buffer([255, 1000], dtype="float32", scope="local")
for ax0, ax1, ax2 in T.grid(255, 1000, 64):
with T.block("matmul"):