This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fb2315a266 [Fix][Arith] Analyzer simplification starts with canonical (#13875)
fb2315a266 is described below

commit fb2315a266e01c2c1a0cf6fdcde326ea393387c2
Author: Ruihang Lai <[email protected]>
AuthorDate: Fri Feb 3 09:34:38 2023 -0500

    [Fix][Arith] Analyzer simplification starts with canonical (#13875)
    
    This PR updates the order of arithmetic analyzer simplification, by
    adding a stage of canonical simplification at the very beginning so
    that every simplification always starts with a canonical round. This
    is because the rewrite simplification may destroy some PrimExpr property
    that the canonical simplification can make use of. Therefore, adding
    the canonical one in the front can maximize the use of canonical
    simplification.
---
 src/arith/analyzer.cc                              |  4 +++
 src/arith/canonical_simplify.cc                    | 12 ++++---
 .../unittest/test_arith_canonical_simplify.py      | 14 ++++++++
 tests/python/unittest/test_arith_intset.py         | 11 ++-----
 tests/python/unittest/test_arith_simplify.py       | 38 ++++++++++++++++++++++
 tests/python/unittest/test_tir_buffer.py           |  2 +-
 .../python/unittest/test_tir_schedule_analysis.py  |  6 ++--
 7 files changed, 70 insertions(+), 17 deletions(-)

diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc
index 921f8ac709..4714cf1df5 100644
--- a/src/arith/analyzer.cc
+++ b/src/arith/analyzer.cc
@@ -129,6 +129,10 @@ bool Analyzer::CanProve(const PrimExpr& expr) {
 PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) {
   PrimExpr res = expr;
 
+  // Always starts with a canonical simplification, as some structural property
+  // of an expression might be destroyed by rewrite simplification.
+  res = this->canonical_simplify(res);
+
   for (int i = 0; i < steps; ++i) {
     if (tir::is_const_int(res)) {
       return res;
diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc
index 39d626aaf2..11fb041511 100644
--- a/src/arith/canonical_simplify.cc
+++ b/src/arith/canonical_simplify.cc
@@ -335,6 +335,8 @@ class SumExprNode : public CanonicalExprNode {
    * \return whether the cast can be safely pushed to children
    */
   bool CanPushCastToChildren(DataType dtype, Analyzer* analyzer) const {
+    bool is_min_value = dtype.bits() == 64 ? base == std::numeric_limits<int64_t>::lowest()
+                                           : base == -(1LL << (dtype.bits() - 1));
     // cast(dtype, arg_1 + arg_2 + ... arg_n) ==
     // cast(dtype, arg_1) + ... + cast(dtype, arg_n)
     // iff it is an upcast (dtype.bits >= self.dtype.bits) or all of
@@ -351,7 +353,7 @@ class SumExprNode : public CanonicalExprNode {
         }
       }
     }
-    if (base > 0) {
+    if (base > 0 || is_min_value) {
       res = res + make_const(dtype, base);
       if (!CastIsSafe(dtype, res, analyzer)) {
         return false;
@@ -366,7 +368,7 @@ class SumExprNode : public CanonicalExprNode {
         }
       }
     }
-    if (base < 0) {
+    if (base < 0 && !is_min_value) {
       res = res - make_const(dtype, -base);
       if (!CastIsSafe(dtype, res, analyzer)) {
         return false;
@@ -497,6 +499,8 @@ class SumExprNode : public CanonicalExprNode {
     return args;
   }
   static PrimExpr Normalize_(DataType dtype, const std::vector<SplitExpr>& args, int64_t base) {
+    bool is_min_value = dtype.bits() == 64 ? base == std::numeric_limits<int64_t>::lowest()
+                                           : base == -(1LL << (dtype.bits() - 1));
     // Positive scales first
     PrimExpr res = make_const(dtype, 0);
     for (size_t i = 0; i < args.size(); ++i) {
@@ -504,7 +508,7 @@ class SumExprNode : public CanonicalExprNode {
         res = res + args[i]->Normalize();
       }
     }
-    if (base > 0) {
+    if (base > 0 || is_min_value) {
       res = res + make_const(dtype, base);
     }
     // negative scales follows using sub.
@@ -513,7 +517,7 @@ class SumExprNode : public CanonicalExprNode {
         res = res - args[i]->NormalizeWithScale(-1);
       }
     }
-    if (base < 0) {
+    if (base < 0 && !is_min_value) {
       res = res - make_const(dtype, -base);
     }
     return res;
diff --git a/tests/python/unittest/test_arith_canonical_simplify.py b/tests/python/unittest/test_arith_canonical_simplify.py
index 9db3035fd9..5914305600 100644
--- a/tests/python/unittest/test_arith_canonical_simplify.py
+++ b/tests/python/unittest/test_arith_canonical_simplify.py
@@ -372,5 +372,19 @@ def test_simplify_cast():
     ck.verify(res, 2)
 
 
+def test_simplify_normalize_min_value_expr():
+    ck = CanonicalChecker()
+    x = te.var("x", "int32")
+
+    ck.verify(te.min_value("int32") - x == 0, x == te.min_value("int32"))
+    ck.verify(te.min_value("int32") + x == 0, False)
+    ck.verify(0 == te.min_value("int32") - x, x == te.min_value("int32"))
+    ck.verify(0 == te.min_value("int32") + x, False)
+    ck.verify(-x + te.min_value("int32") == 0, x == te.min_value("int32"))
+    ck.verify(x + te.min_value("int32") == 0, False)
+    ck.verify(0 == -x + te.min_value("int32"), x == te.min_value("int32"))
+    ck.verify(0 == x + te.min_value("int32"), False)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py
index 24228fb527..da3fd94f81 100644
--- a/tests/python/unittest/test_arith_intset.py
+++ b/tests/python/unittest/test_arith_intset.py
@@ -182,7 +182,6 @@ def check_region_bound(expect_region, var_dom, mode, predicate=None):
                 expect_begin, expect_end = expect_desc[binding]
                 result_begin = analyzer.simplify(intset.min_value, 3)
                 result_end = analyzer.simplify(intset.max_value + 1, 3)
-                print(result_end)
                 assert analyzer.can_prove_equal(
                     result_begin - expect_begin, 0
                 ), f"{result_begin} vs {expect_begin}"
@@ -306,10 +305,7 @@ def test_region_lower_bound_for_non_perfect_tile():
             + h2: {
                 (): (
                     tvm.tir.max(h3 * 8, 1),
-                    tvm.tir.max(h3 * 8, 1)
-                    - tvm.tir.max(h3 * 8, 214)
-                    - tvm.tir.max(1 - h3 * 8, 0)
-                    + 224,
+                    tvm.tir.min(0, h3 * 8 - 214) + 224,
                 ),
                 ((h3, 0),): (1, 10),  # h3 == 0: region is [1, 10)
                 ((h3, 10),): (h3 * 8, h3 * 8 + 10),  # 0 < h3 <= 26: region is [h3 * 8, h3 * 8 + 10)
@@ -333,10 +329,7 @@ def test_region_lower_bound_for_non_perfect_tile():
             + h1: {
                 (): (
                     tvm.tir.max(h3 * 8, 1),
-                    tvm.tir.max(h3 * 8, 1)
-                    - tvm.tir.max(h3 * 8, 214)
-                    - tvm.tir.max(1 - h3 * 8, 0)
-                    + 224,
+                    tvm.tir.min(0, h3 * 8 - 214) + 224,
                 ),
                 ((h3, 0),): (1, 10),
                 ((h3, 10),): (h3 * 8, h3 * 8 + 10),
diff --git a/tests/python/unittest/test_arith_simplify.py b/tests/python/unittest/test_arith_simplify.py
new file mode 100644
index 0000000000..aa9d5179aa
--- /dev/null
+++ b/tests/python/unittest/test_arith_simplify.py
@@ -0,0 +1,38 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.testing
+from tvm import tir
+
+
+def test_simplify_reshape_flattened_index():
+    ana = tvm.arith.Analyzer()
+
+    i0 = tir.Var("i0", "int64")
+    i1 = tir.Var("i1", "int64")
+    ana.bind(i0, tvm.ir.Range(0, 8))
+    ana.bind(i1, tvm.ir.Range(0, 3))
+
+    i_flattened = i0 * 3 + i1
+    assert tvm.ir.structural_equal(
+        ana.simplify((i_flattened) // 12 * 12 + (i_flattened) % 12 // 4 * 4 + (i_flattened) % 4),
+        i_flattened,
+    )
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/unittest/test_tir_buffer.py b/tests/python/unittest/test_tir_buffer.py
index 55c8316739..95ad81db88 100644
--- a/tests/python/unittest/test_tir_buffer.py
+++ b/tests/python/unittest/test_tir_buffer.py
@@ -150,7 +150,7 @@ def test_buffer_index_merge_mult_mod():
     index_simplified = A.offset_of(
         (idxd(idxm(k0, idxd(k1, s)), n), idxm(idxm(k0, idxd(k1, s)), n) + idxm(k0, k1))
     )
-    index_direct = A.offset_of((0, idxm(k0, k1) + idxm(k0, idxd(k1, s))))
+    index_direct = A.offset_of((0, idxm(k0, idxd(k1, s)) + idxm(k0, k1)))
     assert_simplified_equal(index_simplified, index_direct)
     # Test Case3
     index_simplified = A.offset_of(
diff --git a/tests/python/unittest/test_tir_schedule_analysis.py 
b/tests/python/unittest/test_tir_schedule_analysis.py
index 38bd4bba14..349c4734c9 100644
--- a/tests/python/unittest/test_tir_schedule_analysis.py
+++ b/tests/python/unittest/test_tir_schedule_analysis.py
@@ -126,7 +126,7 @@ def test_suggest_index_map_winograd():
             floordiv(i0, 2),
             floordiv(i1, 2),
             floormod(i0, 2),
-            floormod(((i1 * 4) + floordiv(i2, 32)), 8),
+            floormod(i1, 2) * 4 + floordiv(i2, 32),
             floormod(i2, 32),
             floordiv(i3, 32),
             floormod(i3, 32),
@@ -137,8 +137,8 @@ def test_suggest_index_map_winograd():
     expected_inverse_index_map = IndexMap.from_func(
         lambda i0, i1, i2, i3, i4, i5, i6: (
             ((i0 * 2) + i2),
-            ((i1 * 2) + floordiv(((i3 * 32) + i4), 128)),
-            floormod(((i3 * 32) + i4), 128),
+            i1 * 2 + floordiv(i3, 4),
+            floormod(i3, 4) * 32 + i4,
             ((i5 * 32) + i6),
         )
     )

Reply via email to