This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 1392e64e0b [Arith] Allow constant values in InverseAffineIterMap (#12026)
1392e64e0b is described below
commit 1392e64e0bd9f55238256f5feb95eb2af90b6b97
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Jul 6 22:08:23 2022 -0700
[Arith] Allow constant values in InverseAffineIterMap (#12026)
---
src/arith/iter_affine_map.cc | 4 +++-
tests/python/unittest/test_arith_iter_affine_map.py | 13 +++++++++++++
2 files changed, 16 insertions(+), 1 deletion(-)
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index e1d6d316b4..d2aa16ded1 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -2163,7 +2163,9 @@ class InverseAffineIterMapTransformer {
* descending order of lower_factor.
*/
void CheckFusePattern(const IterSumExpr sum_expr) {
- ICHECK(sum_expr->args.size());
+ if (sum_expr->args.empty()) {
+ return;
+ }
PrimExpr expected_scale = sum_expr->args.back()->scale;
for (size_t i = sum_expr->args.size(); i > 0; i--) {
ICHECK(analyzer_->CanProveEqual(sum_expr->args[i - 1]->scale, expected_scale));
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py
index 472ecac44f..7bc5ead298 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -869,6 +869,19 @@ def test_inverse_affine_iter_map():
assert analyzer.can_prove_equal(res[l0[0]], l0_inverse)
+def test_inverse_affine_map_trivial_iter():
+ analyzer = tvm.arith.Analyzer()
+ l0 = create_iter("l0", 64)
+ l1 = create_iter("l1", 64)
+ iter_map = tvm.arith.detect_iter_map([0, l0[0], l1[0]], var_dom([l0, l1])).indices
+ outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))]
+ res = tvm.arith.inverse_affine_iter_map(iter_map, outputs)
+ # output_0 is expected to be constant and it is not included in the inverse map
+ assert len(res) == 2
+ assert analyzer.can_prove_equal(res[l0[0]], outputs[1])
+ assert analyzer.can_prove_equal(res[l1[0]], outputs[2])
+
+
def test_free_variables():
x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")