This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new dc53a6c29a [Arith] Simplify the result of non-divisible floordiv (#15881)
dc53a6c29a is described below

commit dc53a6c29a2b035d8d44d8983b1d7b121f09dbef
Author: Wuwei Lin <[email protected]>
AuthorDate: Sat Oct 7 09:23:54 2023 -0700

    [Arith] Simplify the result of non-divisible floordiv (#15881)
---
 src/arith/iter_affine_map.cc                       |  9 ++++++
 .../python/unittest/test_arith_iter_affine_map.py  | 34 ++++++++++++++++++++++
 2 files changed, 43 insertions(+)

diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index 1c782f4546..366784c04f 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -1904,6 +1904,15 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, P
                               /* lower_factor = */ padded->lower_factor * rhs,
                               /* extent = */ 
analyzer_->Simplify(floordiv(padded->extent, rhs)),
                               /* scale = */ padded->scale);
+  } else if (is_one(padded->lower_factor) &&
+             analyzer_->CanProveEqual(padded->extent, padded->source->extent)) 
{
+    // floordiv(floormod(floordiv(iter, lower_factor), ext), c)
+    // = floordiv(iter, c)
+    // when lower_factor = 1 and ext = iter.extent
+    new_split = IterSplitExpr(padded->source,
+                              /* lower_factor = */ rhs,
+                              /* extent = */ 
analyzer_->Simplify(ceildiv(padded->extent, rhs)),
+                              /* scale = */ padded->scale);
   } else {
     new_split = IterSplitExpr(IterMark(padded, padded->extent),
                               /* lower_factor = */ rhs,
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py
index 912edcbced..3a10ec05ef 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -1227,11 +1227,29 @@ def test_iter_map_simplify_unit_loop_order():
 
 
 def assert_normalize_to_iter_sum(index, input_iters, args, base):
+    """Assert the result of arith.normalize_to_iter_sum is correct
+
+    Parameters
+    ----------
+    index : tvm.tir.PrimExpr
+        The index to be normalized
+    input_iters : Mapping[Var, Range]
+        The input iterators
+    args : List[Union[tvm.arith.IterSplitExpr, Tuple[PrimExpr, PrimExpr]]]
+        The expected result. Ordered list of args of the expected IterSumExpr. Each arg can be
+        either IterSplitExpr or a tuple of (PrimExpr, PrimExpr) where the first element is the
+        iterator normalized to PrimExpr and the second element is the scale.
+    base : tvm.tir.PrimExpr
+        The expected base
+    """
     res = tvm.arith.normalize_to_iter_sum(index, input_iters)
 
     assert isinstance(res, tvm.arith.IterSumExpr)
     assert len(res.args) == len(args)
     for split, item in zip(res.args, args):
+        if isinstance(item, tvm.arith.IterSplitExpr):
+            tvm.ir.assert_structural_equal(split, item)
+            continue
         tvm.testing.assert_prim_expr_equal(split.scale, item[1])
         tvm.testing.assert_prim_expr_equal(
             tvm.arith.normalize_iter_map_to_expr(split), item[0] * item[1]
@@ -1245,6 +1263,7 @@ def test_normalize_to_iter_sum():
     z = tvm.tir.Var("z", "int64")
     a = tvm.tir.Var("a", "int64")
     n = tvm.tir.Var("n", "int64")
+    flm = tvm.tir.floormod
 
     assert_normalize_to_iter_sum(
         z + ((y + x * 4 + 2) * n) + 3,
@@ -1285,6 +1304,21 @@ def test_normalize_to_iter_sum():
         0,
     )
 
+    # non-divisible
+    assert_normalize_to_iter_sum(
+        x // 5,
+        var_dom([(x, 4096)]),
+        [
+            tvm.arith.IterSplitExpr(
+                tvm.arith.IterMark(x, 4096),
+                lower_factor=tvm.tir.const(5, "int64"),
+                extent=tvm.tir.const(820, "int64"),
+                scale=tvm.tir.const(1, "int64"),
+            )
+        ],
+        0,
+    )
+
     # iter simplify
     assert_normalize_to_iter_sum(
         z * 2 + 2 * y * 3 + 4 * (x // 4) + (x % 4),

Reply via email to