This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9f0f301c6f [TIR][Analyzer] Simplify `x==x` expressions for all dtypes (#17158)
9f0f301c6f is described below

commit 9f0f301c6f6de7548c6b2026bcb51590e0881ac5
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Jul 24 08:24:15 2024 -0500

    [TIR][Analyzer] Simplify `x==x` expressions for all dtypes (#17158)
    
    * [TIR][Analyzer] Simplify `x==x` expressions for all dtypes
    
    Prior to this commit, there was no rule to simplify `x == x` into
    `True`.  In some cases, despite not having an explicit rewrite rule in
    `RewriteSimplifier`, the `RewriteSimplifier::CanProve` function would
    check if `x-x` simplifies to zero, relying on the rewrite rules used
    for `tir::Sub`.  However, the rule to rewrite `x-x` into zero was only
    enabled for `int32`, `int64`, and floating-point types, so relying on
    this behavior was inconsistent.
    
    This commit updates the rewrite rules for both `tir::EQ` and
    `tir::Sub` to check for simplification of `x-x` or `x==x`, regardless
    of the datatype.  This change preserves the fast-path for index
    data-types, in which `int32` and `int64` expressions may be simplified
    without checking for side effects.  For all other dtypes, the
    cancellation only applies when evaluating `x` has no side effects.
    
    * Add comment about simplifications of NaN/Inf
---
 src/arith/rewrite_simplify.cc                     | 21 ++++++++++++-
 tests/python/arith/test_arith_rewrite_simplify.py | 36 +++++++++++++++++++++++
 tests/python/arith/test_arith_simplify.py         | 29 ++++++++++++++++++
 3 files changed, 85 insertions(+), 1 deletion(-)

diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index f4d4a9048c..3682054e8e 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -543,6 +543,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
   PVar<IntImm> c1, c2, c3;
   // Pattern var for lanes in broadcast and ramp
   PVar<PrimExpr> lanes;
+
   // Vector rules
   if (op->dtype.is_scalable_or_fixed_length_vector()) {
     TVM_TRY_REWRITE(ramp(b1, s1, lanes) - ramp(b2, s2, lanes), ramp(b1 - b2, s1 - s2, lanes));
@@ -697,9 +698,15 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
     TVM_TRY_RECURSIVE_REWRITE(x - (y + c1), (x - y) + (0 - c1));
     TVM_TRY_RECURSIVE_REWRITE(x - (y - z), (x + z) - y);
     TVM_TRY_RECURSIVE_REWRITE(x - y * c1, x + y * (0 - c1));
-  } else if (op->dtype.is_float()) {
+  } else {
     // Cancellation rules.  Deliberately off of the integer path, to
     // avoid introducing checks on the side effects for the fast path.
+    //
+    // These simplifications do not preserve NaN/Inf that may occur in
+    // the inputs.  For IEEE floats, `NaN - NaN` is `NaN`, and does
+    // not cancel out.  However, since models should not encounter NaN
+    // in the first place, this allows better simplification for the
+    // supported path.
     TVM_TRY_REWRITE_IF(x - x, ZeroWithTypeLike(x),
                        SideEffect(x.Eval()) <= CallEffectKind::kReadState);
     TVM_TRY_REWRITE_IF((x + y) - y, x, SideEffect(y.Eval()) <= CallEffectKind::kReadState);
@@ -1678,6 +1685,7 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) {
   // Pattern var match IntImm
   PVar<IntImm> c1, c2;
   PVar<PrimExpr> lanes;
+  PConst<PrimExpr> ctrue(make_const(ret->dtype, true));
 
   // vector rule
   if (ret->dtype.is_scalable_or_fixed_length_vector()) {
@@ -1698,6 +1706,17 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) {
     TVM_TRY_REWRITE(c1 - x == c2, x == c1 - c2);
     TVM_TRY_REWRITE(x + c1 == c2, x == c2 - c1);
     TVM_TRY_RECURSIVE_REWRITE(x * y == 0, x == 0 || y == 0);
+    TVM_TRY_REWRITE(x == x, ctrue);
+  } else {
+    // Mimic the cancellation rules for SubNode.  For index datatypes
+    // (handled in the branch above), we skip the check for side
+    // effects; all other dtypes require that evaluating `x` does no
+    // more than read state.
+    //
+    // These simplifications do not preserve NaN that may occur in
+    // the inputs.  For IEEE floats, `NaN == NaN` is `false`, so
+    // `x == x` is not always true (`Inf == Inf` is unaffected).
+    // However, since models should not encounter NaN in the first
+    // place, this allows better simplification for the supported
+    // path.
+    TVM_TRY_REWRITE_IF(x == x, ctrue, SideEffect(x.Eval()) <= CallEffectKind::kReadState);
   }
   return std::move(ret);
 }
diff --git a/tests/python/arith/test_arith_rewrite_simplify.py b/tests/python/arith/test_arith_rewrite_simplify.py
index 1ebaab53af..90f0aeef47 100644
--- a/tests/python/arith/test_arith_rewrite_simplify.py
+++ b/tests/python/arith/test_arith_rewrite_simplify.py
@@ -321,6 +321,42 @@ class TestSelect(BaseCompare):
     )
 
 
+class TestCancellation(BaseCompare):
+    var_int8 = tir.Var("var_int8", "int8")
+    var_int32 = tir.Var("var_int32", "int32")
+    var_int64 = tir.Var("var_int64", "int64")
+    var_uint8 = tir.Var("var_uint8", "uint8")
+    var_uint32 = tir.Var("var_uint32", "uint32")
+    var_uint64 = tir.Var("var_uint64", "uint64")
+
+    test_case = tvm.testing.parameter(
+        TestCase(tir.const(5, "int64") - tir.const(5, "int64"), tir.const(0, "int64")),
+        TestCase(tir.const(5, "uint8") - tir.const(5, "uint8"), tir.const(0, "uint8")),
+        TestCase(var_int8 - var_int8, tir.const(0, "int8")),
+        TestCase(var_int32 - var_int32, tir.const(0, "int32")),
+        TestCase(var_int64 - var_int64, tir.const(0, "int64")),
+        TestCase(var_uint8 - var_uint8, tir.const(0, "uint8")),
+        TestCase(var_uint32 - var_uint32, tir.const(0, "uint32")),
+        TestCase(var_uint64 - var_uint64, tir.const(0, "uint64")),
+        TestCase(tir.EQ(tir.const(5, "int64"), tir.const(5, "int64")), tir.const(True, "bool")),
+        TestCase(tir.EQ(tir.const(5, "uint8"), tir.const(5, "uint8")), tir.const(True, "bool")),
+        TestCase(tir.EQ(var_int8, var_int8), tir.const(True, "bool")),
+        TestCase(tir.EQ(var_int32, var_int32), tir.const(True, "bool")),
+        TestCase(tir.EQ(var_int64, var_int64), tir.const(True, "bool")),
+        TestCase(tir.EQ(var_uint8, var_uint8), tir.const(True, "bool")),
+        TestCase(tir.EQ(var_uint32, var_uint32), tir.const(True, "bool")),
+        TestCase(tir.EQ(var_uint64, var_uint64), tir.const(True, "bool")),
+        TestCase(tir.NE(tir.const(5, "int64"), tir.const(5, "int64")), tir.const(False, "bool")),
+        TestCase(tir.NE(tir.const(5, "uint8"), tir.const(5, "uint8")), tir.const(False, "bool")),
+        TestCase(tir.NE(var_int8, var_int8), tir.const(False, "bool")),
+        TestCase(tir.NE(var_int32, var_int32), tir.const(False, "bool")),
+        TestCase(tir.NE(var_int64, var_int64), tir.const(False, "bool")),
+        TestCase(tir.NE(var_uint8, var_uint8), tir.const(False, "bool")),
+        TestCase(tir.NE(var_uint32, var_uint32), tir.const(False, "bool")),
+        TestCase(tir.NE(var_uint64, var_uint64), tir.const(False, "bool")),
+    )
+
+
 class TestAddIndex(BaseCompare):
     x, y, z = te.var("x"), te.var("y"), te.var("z")
 
diff --git a/tests/python/arith/test_arith_simplify.py b/tests/python/arith/test_arith_simplify.py
index 9a0245d274..3b02377400 100644
--- a/tests/python/arith/test_arith_simplify.py
+++ b/tests/python/arith/test_arith_simplify.py
@@ -38,6 +38,35 @@ def test_simplify_reshape_flattened_index():
     )
 
 
+dtype = tvm.testing.parameter(
+    "uint8",
+    "uint16",
+    "uint32",
+    "uint64",
+    "int8",
+    "int16",
+    "int32",
+    "int64",
+    "float16",
+    "float32",
+    "float64",
+)
+
+
+def test_can_prove_self_identity(dtype):
+    ana = tvm.arith.Analyzer()
+
+    n = tir.Var("n", dtype)
+    assert ana.can_prove(n == n)
+
+
+def test_can_prove_self_equal_to_self(dtype):
+    ana = tvm.arith.Analyzer()
+
+    n = tir.Var("n", dtype)
+    assert ana.can_prove_equal(n, n)
+
+
 def test_simplify_symbolic_comparison():
     ana = tvm.arith.Analyzer()
 

Reply via email to