This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3c6f9c9bcc [Arith] Added simplification rule for multiple equality compares (#15628)
3c6f9c9bcc is described below

commit 3c6f9c9bcc2b3fa2ca30ae1f6174b4f536f6d368
Author: Eric Lunderberg <[email protected]>
AuthorDate: Mon Aug 28 19:28:01 2023 -0400

    [Arith] Added simplification rule for multiple equality compares (#15628)
    
    The expression `(x==y) && (x==z)` can only be true when `y==z`.  When
    `y` and `z` are constants, this allows better constant folding by
    rewriting `(x==c1) && (x==c2)` into `(x==c1) && (c1==c2)`.
    
    This commit adds the above rewrite, together with the corresponding
    rewrite of the negated expression, `(x!=c1) || (x!=c2)` into
    `(x!=c1) || (c1!=c2)`.
---
 src/arith/rewrite_simplify.cc                        | 2 ++
 tests/python/unittest/test_arith_rewrite_simplify.py | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index 40088fd963..63becf8eb7 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -1856,6 +1856,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const 
AndNode* op) {
                      }),
                      cfalse, c2.Eval()->value > c1.Eval()->value);
 
+  TVM_TRY_REWRITE((x == c1) && (x == c2), (x == c1) && (c1 == c2));
   TVM_TRY_REWRITE(matches_one_of(x == c1 && x != c2, x != c2 && x == c1), x == 
c1 && c1 != c2);
 
   TVM_TRY_RECURSIVE_REWRITE(matches_one_of(floordiv(x, c2) == c1 && 
floormod(x, c2) == c3,
@@ -2000,6 +2001,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const 
OrNode* op) {
   TVM_TRY_REWRITE_IF(x <= c1 || c2 <= x, ctrue, c2.Eval()->value <= 
c1.Eval()->value + 1);
   TVM_TRY_REWRITE_IF(c2 <= x || x <= c1, ctrue, c2.Eval()->value <= 
c1.Eval()->value + 1);
 
+  TVM_TRY_REWRITE(x != c1 || x != c2, x != c1 || c1 != c2);
   TVM_TRY_REWRITE(x != c1 || x == c2, x != c1 || c1 == c2);
   TVM_TRY_REWRITE(x == c2 || x != c1, x != c1 || c1 == c2);
 
diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py 
b/tests/python/unittest/test_arith_rewrite_simplify.py
index 46ac0f9751..0b0a43a7d3 100644
--- a/tests/python/unittest/test_arith_rewrite_simplify.py
+++ b/tests/python/unittest/test_arith_rewrite_simplify.py
@@ -951,6 +951,7 @@ class TestLogical(BaseCompare):
         TestCase(tvm.tir.And(x <= 1, 2 <= x), tvm.tir.const(False, "bool")),
         TestCase(tvm.tir.And(2 <= x, x <= 1), tvm.tir.const(False, "bool")),
         TestCase(tvm.tir.And(x == 1, x != 2), x == 1),
+        TestCase(tvm.tir.And(x == 1, x == 2), tvm.tir.const(False, "bool")),
         TestCase(tvm.tir.Or(tvm.tir.EQ(x, y), tvm.tir.NE(x, y)), 
tvm.tir.const(True, "bool")),
         TestCase(tvm.tir.Or(tvm.tir.NE(x, y), tvm.tir.EQ(x, y)), 
tvm.tir.const(True, "bool")),
         TestCase(tvm.tir.Or(x > y, tvm.tir.Not(x > y)), tvm.tir.const(True, 
"bool")),
@@ -965,6 +966,7 @@ class TestLogical(BaseCompare):
         TestCase(tvm.tir.Or(x <= 1, 2 <= x), tvm.tir.const(True, "bool")),
         TestCase(tvm.tir.Or(2 <= x, x <= 1), tvm.tir.const(True, "bool")),
         TestCase(tvm.tir.Or(x != 1, x == 2), x != 1),
+        TestCase(tvm.tir.Or(x != 1, x != 2), tvm.tir.const(True, "bool")),
         TestCase(
             tvm.tir.Or(x == 1, tvm.tir.Or(y == 1, z == 1)),
             tvm.tir.Or(tvm.tir.Or(x == 1, y == 1), z == 1),

Reply via email to