This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new dc53a6c29a [Arith] Simplify the result of non-divisible floordiv (#15881)
dc53a6c29a is described below
commit dc53a6c29a2b035d8d44d8983b1d7b121f09dbef
Author: Wuwei Lin <[email protected]>
AuthorDate: Sat Oct 7 09:23:54 2023 -0700
[Arith] Simplify the result of non-divisible floordiv (#15881)
---
src/arith/iter_affine_map.cc | 9 ++++++
.../python/unittest/test_arith_iter_affine_map.py | 34 ++++++++++++++++++++++
2 files changed, 43 insertions(+)
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index 1c782f4546..366784c04f 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -1904,6 +1904,15 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, P
/* lower_factor = */ padded->lower_factor * rhs,
/* extent = */
analyzer_->Simplify(floordiv(padded->extent, rhs)),
/* scale = */ padded->scale);
+ } else if (is_one(padded->lower_factor) &&
+ analyzer_->CanProveEqual(padded->extent, padded->source->extent))
{
+ // floordiv(floormod(floordiv(iter, lower_factor), ext), c)
+ // = floordiv(iter, c)
+ // when lower_factor = 1 and ext = iter.extent
+ new_split = IterSplitExpr(padded->source,
+ /* lower_factor = */ rhs,
+ /* extent = */
analyzer_->Simplify(ceildiv(padded->extent, rhs)),
+ /* scale = */ padded->scale);
} else {
new_split = IterSplitExpr(IterMark(padded, padded->extent),
/* lower_factor = */ rhs,
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py
index 912edcbced..3a10ec05ef 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -1227,11 +1227,29 @@ def test_iter_map_simplify_unit_loop_order():
def assert_normalize_to_iter_sum(index, input_iters, args, base):
+ """Assert the result of arith.normalize_to_iter_sum is correct
+
+ Parameters
+ ----------
+ index : tvm.tir.PrimExpr
+ The index to be normalized
+ input_iters : Mapping[Var, Range]
+ The input iterators
+ args : List[Union[tvm.arith.IterSplitExpr, Tuple[PrimExpr, PrimExpr]]]
+ The expected result. Ordered list of args of the expected IterSumExpr.
Each arg can be
+ either IterSplitExpr or a tuple of (PrimExpr, PrimExpr) where the
first element is the
+ iterator normalized to PrimExpr and the second element is the scale.
+ base : tvm.tir.PrimExpr
+ The expected base
+ """
res = tvm.arith.normalize_to_iter_sum(index, input_iters)
assert isinstance(res, tvm.arith.IterSumExpr)
assert len(res.args) == len(args)
for split, item in zip(res.args, args):
+ if isinstance(item, tvm.arith.IterSplitExpr):
+ tvm.ir.assert_structural_equal(split, item)
+ continue
tvm.testing.assert_prim_expr_equal(split.scale, item[1])
tvm.testing.assert_prim_expr_equal(
tvm.arith.normalize_iter_map_to_expr(split), item[0] * item[1]
@@ -1245,6 +1263,7 @@ def test_normalize_to_iter_sum():
z = tvm.tir.Var("z", "int64")
a = tvm.tir.Var("a", "int64")
n = tvm.tir.Var("n", "int64")
+ flm = tvm.tir.floormod
assert_normalize_to_iter_sum(
z + ((y + x * 4 + 2) * n) + 3,
@@ -1285,6 +1304,21 @@ def test_normalize_to_iter_sum():
0,
)
+ # non-divisible
+ assert_normalize_to_iter_sum(
+ x // 5,
+ var_dom([(x, 4096)]),
+ [
+ tvm.arith.IterSplitExpr(
+ tvm.arith.IterMark(x, 4096),
+ lower_factor=tvm.tir.const(5, "int64"),
+ extent=tvm.tir.const(820, "int64"),
+ scale=tvm.tir.const(1, "int64"),
+ )
+ ],
+ 0,
+ )
+
# iter simplify
assert_normalize_to_iter_sum(
z * 2 + 2 * y * 3 + 4 * (x // 4) + (x % 4),