This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c6578834bc [Arith] Simplify the output of InverseAffineIterMap (#11167)
c6578834bc is described below
commit c6578834bc34a8721844197df03e0cef83440adf
Author: Wuwei Lin <[email protected]>
AuthorDate: Fri Apr 29 10:36:38 2022 -0700
[Arith] Simplify the output of InverseAffineIterMap (#11167)
This PR simplifies the result of `InverseAffineIterMap` by assuming the
`outputs` param has the same range as the output range of the affine
transformation. For example, for iter map `i, j => i * 16 + j, i \in [0, 8), j
\in [0, 16)`, after this PR, the inverse will be `m => m // 16, m % 16, m \in
[0, 128)` instead of `m => (m // 16) % 8, m % 16`
---
include/tvm/arith/iter_affine_map.h | 2 ++
src/arith/iter_affine_map.cc | 8 ++++++--
tests/python/unittest/test_arith_iter_affine_map.py | 8 ++++----
3 files changed, 12 insertions(+), 6 deletions(-)
diff --git a/include/tvm/arith/iter_affine_map.h
b/include/tvm/arith/iter_affine_map.h
index ed59be32b6..f8371b1a61 100644
--- a/include/tvm/arith/iter_affine_map.h
+++ b/include/tvm/arith/iter_affine_map.h
@@ -309,6 +309,8 @@ Array<PrimExpr> IterMapSimplify(const Array<PrimExpr>&
indices, const Map<Var, R
* the affine transformation specified by `iter_map` will be applied to
`outputs` and the result
* will be {l0: ((output_0*16) + output_1)}.
*
+ * The range of `outputs` should be the same as the output range of the affine
transformation.
+ *
* \sa DetectIterMap
*
* \param iter_map The bijective affine iter map.
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index e7a73f4ea2..ec2680d8e6 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -1680,8 +1680,12 @@ class InverseAffineIterMapTransformer {
CheckFusePattern(iter_map_expr);
for (size_t i = iter_map_expr->args.size(); i > 0; i--) {
const IterSplitExpr& split = iter_map_expr->args[i - 1];
- backprop_.Set(split,
- backprop_.at(split) + floormod(floordiv(input,
split->scale), split->extent));
+ PrimExpr prop_value = floordiv(input, split->scale);
+ // the first part has the same extent as the split expression, floormod
is not needed
+ if (i > 1) {
+ prop_value = floormod(prop_value, split->extent);
+ }
+ backprop_.Set(split, backprop_.at(split) + prop_value);
}
}
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py
b/tests/python/unittest/test_arith_iter_affine_map.py
index 5beec1c08c..f77a250ede 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -848,7 +848,7 @@ def test_inverse_affine_iter_map():
outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in
range(len(iter_map))]
res = tvm.arith.inverse_affine_iter_map(iter_map, outputs)
assert len(res) == 2
- l0_inverse = floormod(floordiv(outputs[0], 4), 16) + outputs[1] * 16
+ l0_inverse = floordiv(outputs[0], 4) + outputs[1] * 16
l1_inverse = floormod(outputs[0], 4) + outputs[2] * 4
assert analyzer.can_prove_equal(res[l0[0]], l0_inverse)
assert analyzer.can_prove_equal(res[l1[0]], l1_inverse)
@@ -867,7 +867,7 @@ def test_inverse_affine_iter_map():
outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in
range(len(iter_map))]
res = tvm.arith.inverse_affine_iter_map(iter_map, outputs)
assert len(res) == 3
- l0_inverse = floormod(floordiv(outputs[0], 64), 16) + outputs[1] * 16
+ l0_inverse = floordiv(outputs[0], 64) + outputs[1] * 16
l1_inverse = floormod(floordiv(outputs[0], 4), 4) + outputs[3] * 4
l2_inverse = (
floormod(outputs[0], 4) * 16 + floormod(floordiv(outputs[0], 16), 4) *
4 + outputs[2]
@@ -887,8 +887,8 @@ def test_inverse_affine_iter_map():
outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in
range(len(iter_map))]
res = tvm.arith.inverse_affine_iter_map(iter_map, outputs)
assert len(res) == 1
- l1_inverse = floormod(outputs[0], 8) * 8 + floormod(floordiv(outputs[0],
8), 8)
- l0_inverse = floormod(l1_inverse, 4) * 16 + floormod(floordiv(l1_inverse,
4), 16)
+ l1_inverse = floormod(outputs[0], 8) * 8 + floordiv(outputs[0], 8)
+ l0_inverse = floormod(l1_inverse, 4) * 16 + floordiv(l1_inverse, 4)
assert analyzer.can_prove_equal(res[l0[0]], l0_inverse)