This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c6578834bc [Arith] Simplify the output of InverseAffineIterMap (#11167)
c6578834bc is described below

commit c6578834bc34a8721844197df03e0cef83440adf
Author: Wuwei Lin <[email protected]>
AuthorDate: Fri Apr 29 10:36:38 2022 -0700

    [Arith] Simplify the output of InverseAffineIterMap (#11167)
    
    This PR simplifies the result of `InverseAffineIterMap` by assuming that the 
`outputs` param has the same range as the output range of the affine 
transformation. For example, for the iter map `i, j => i * 16 + j, i \in [0, 8), j 
\in [0, 16)`, the inverse after this PR is `m => m // 16, m % 16, m \in 
[0, 128)` instead of `m => (m // 16) % 8, m % 16`.
---
 include/tvm/arith/iter_affine_map.h                 | 2 ++
 src/arith/iter_affine_map.cc                        | 8 ++++++--
 tests/python/unittest/test_arith_iter_affine_map.py | 8 ++++----
 3 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/include/tvm/arith/iter_affine_map.h 
b/include/tvm/arith/iter_affine_map.h
index ed59be32b6..f8371b1a61 100644
--- a/include/tvm/arith/iter_affine_map.h
+++ b/include/tvm/arith/iter_affine_map.h
@@ -309,6 +309,8 @@ Array<PrimExpr> IterMapSimplify(const Array<PrimExpr>& 
indices, const Map<Var, R
  * the affine transformation specified by `iter_map` will be applied to 
`outputs` and the result
  * will be {l0: ((output_0*16) + output_1)}.
  *
+ * The range of `outputs` should be the same as the output range of the affine 
transformation.
+ *
  * \sa DetectIterMap
  *
  * \param iter_map The bijective affine iter map.
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index e7a73f4ea2..ec2680d8e6 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -1680,8 +1680,12 @@ class InverseAffineIterMapTransformer {
     CheckFusePattern(iter_map_expr);
     for (size_t i = iter_map_expr->args.size(); i > 0; i--) {
       const IterSplitExpr& split = iter_map_expr->args[i - 1];
-      backprop_.Set(split,
-                    backprop_.at(split) + floormod(floordiv(input, 
split->scale), split->extent));
+      PrimExpr prop_value = floordiv(input, split->scale);
+      // the first part has the same extent as the split expression, floormod 
is not needed
+      if (i > 1) {
+        prop_value = floormod(prop_value, split->extent);
+      }
+      backprop_.Set(split, backprop_.at(split) + prop_value);
     }
   }
 
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py 
b/tests/python/unittest/test_arith_iter_affine_map.py
index 5beec1c08c..f77a250ede 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -848,7 +848,7 @@ def test_inverse_affine_iter_map():
     outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in 
range(len(iter_map))]
     res = tvm.arith.inverse_affine_iter_map(iter_map, outputs)
     assert len(res) == 2
-    l0_inverse = floormod(floordiv(outputs[0], 4), 16) + outputs[1] * 16
+    l0_inverse = floordiv(outputs[0], 4) + outputs[1] * 16
     l1_inverse = floormod(outputs[0], 4) + outputs[2] * 4
     assert analyzer.can_prove_equal(res[l0[0]], l0_inverse)
     assert analyzer.can_prove_equal(res[l1[0]], l1_inverse)
@@ -867,7 +867,7 @@ def test_inverse_affine_iter_map():
     outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in 
range(len(iter_map))]
     res = tvm.arith.inverse_affine_iter_map(iter_map, outputs)
     assert len(res) == 3
-    l0_inverse = floormod(floordiv(outputs[0], 64), 16) + outputs[1] * 16
+    l0_inverse = floordiv(outputs[0], 64) + outputs[1] * 16
     l1_inverse = floormod(floordiv(outputs[0], 4), 4) + outputs[3] * 4
     l2_inverse = (
         floormod(outputs[0], 4) * 16 + floormod(floordiv(outputs[0], 16), 4) * 
4 + outputs[2]
@@ -887,8 +887,8 @@ def test_inverse_affine_iter_map():
     outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in 
range(len(iter_map))]
     res = tvm.arith.inverse_affine_iter_map(iter_map, outputs)
     assert len(res) == 1
-    l1_inverse = floormod(outputs[0], 8) * 8 + floormod(floordiv(outputs[0], 
8), 8)
-    l0_inverse = floormod(l1_inverse, 4) * 16 + floormod(floordiv(l1_inverse, 
4), 16)
+    l1_inverse = floormod(outputs[0], 8) * 8 + floordiv(outputs[0], 8)
+    l0_inverse = floormod(l1_inverse, 4) * 16 + floordiv(l1_inverse, 4)
 
     assert analyzer.can_prove_equal(res[l0[0]], l0_inverse)
 

Reply via email to