This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1392e64e0b [Arith] Allow constant values in InverseAffineIterMap (#12026)
1392e64e0b is described below

commit 1392e64e0bd9f55238256f5feb95eb2af90b6b97
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Jul 6 22:08:23 2022 -0700

    [Arith] Allow constant values in InverseAffineIterMap (#12026)
---
 src/arith/iter_affine_map.cc                        |  4 +++-
 tests/python/unittest/test_arith_iter_affine_map.py | 13 +++++++++++++
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index e1d6d316b4..d2aa16ded1 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -2163,7 +2163,9 @@ class InverseAffineIterMapTransformer {
    *        descending order of lower_factor.
    */
   void CheckFusePattern(const IterSumExpr sum_expr) {
-    ICHECK(sum_expr->args.size());
+    if (sum_expr->args.empty()) {
+      return;
+    }
     PrimExpr expected_scale = sum_expr->args.back()->scale;
     for (size_t i = sum_expr->args.size(); i > 0; i--) {
       ICHECK(analyzer_->CanProveEqual(sum_expr->args[i - 1]->scale, 
expected_scale));
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py
index 472ecac44f..7bc5ead298 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -869,6 +869,19 @@ def test_inverse_affine_iter_map():
     assert analyzer.can_prove_equal(res[l0[0]], l0_inverse)
 
 
+def test_inverse_affine_map_trivial_iter():
+    analyzer = tvm.arith.Analyzer()
+    l0 = create_iter("l0", 64)
+    l1 = create_iter("l1", 64)
+    iter_map = tvm.arith.detect_iter_map([0, l0[0], l1[0]], var_dom([l0, 
l1])).indices
+    outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in 
range(len(iter_map))]
+    res = tvm.arith.inverse_affine_iter_map(iter_map, outputs)
+    # output_0 is expected to be a constant, so it is not included in the inverse map
+    assert len(res) == 2
+    assert analyzer.can_prove_equal(res[l0[0]], outputs[1])
+    assert analyzer.can_prove_equal(res[l1[0]], outputs[2])
+
+
 def test_free_variables():
     x = tvm.tir.Var("x", "int32")
     y = tvm.tir.Var("y", "int32")

Reply via email to