This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9f0f301c6f [TIR][Analyzer] Simplify `x==x` expressions for all dtypes
(#17158)
9f0f301c6f is described below
commit 9f0f301c6f6de7548c6b2026bcb51590e0881ac5
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Jul 24 08:24:15 2024 -0500
[TIR][Analyzer] Simplify `x==x` expressions for all dtypes (#17158)
* [TIR][Analyzer] Simplify `x==x` expressions for all dtypes
Prior to this commit, there was no rule to simplify `x == x` into
`True`. In some cases, despite not having an explicit rewrite rule in
`RewriteSimplifier`, the `RewriteSimplifier::CanProve` function would
check if `x-x` simplifies to zero, relying on the rewrite rules used
for `tir::Sub`. However, the rule to rewrite `x-x` into zero was only
enabled for `int32`, `int64`, and floating-point types, so relying on
this behavior was inconsistent.
This commit updates the rewrite rules for both `tir::EQ` and
`tir::Sub` to check for simplification of `x-x` or `x==x`, regardless
of the datatype. This change preserves the fast-path for index
data-types, in which `int32` and `int64` expressions may be simplified
without checking for side effects. For all other dtypes, the
cancellation only applies when evaluating `x` has no side effects
beyond reading state (`SideEffect(x) <= CallEffectKind::kReadState`).
* Add comment about simplifications of NaN/Inf
---
src/arith/rewrite_simplify.cc | 21 ++++++++++++-
tests/python/arith/test_arith_rewrite_simplify.py | 36 +++++++++++++++++++++++
tests/python/arith/test_arith_simplify.py | 29 ++++++++++++++++++
3 files changed, 85 insertions(+), 1 deletion(-)
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index f4d4a9048c..3682054e8e 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -543,6 +543,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode*
op) {
PVar<IntImm> c1, c2, c3;
// Pattern var for lanes in broadcast and ramp
PVar<PrimExpr> lanes;
+
// Vector rules
if (op->dtype.is_scalable_or_fixed_length_vector()) {
TVM_TRY_REWRITE(ramp(b1, s1, lanes) - ramp(b2, s2, lanes), ramp(b1 - b2,
s1 - s2, lanes));
@@ -697,9 +698,15 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const
SubNode* op) {
TVM_TRY_RECURSIVE_REWRITE(x - (y + c1), (x - y) + (0 - c1));
TVM_TRY_RECURSIVE_REWRITE(x - (y - z), (x + z) - y);
TVM_TRY_RECURSIVE_REWRITE(x - y * c1, x + y * (0 - c1));
- } else if (op->dtype.is_float()) {
+ } else {
// Cancellation rules. Deliberately off of the integer path, to
// avoid introducing checks on the side effects for the fast path.
+ //
+ // These simplifications do not preserve NaN/Inf that may occur in
+ // the inputs. For IEEE floats, `NaN - NaN` is `NaN`, and does
+ // not cancel out. However, since models should not encounter NaN
+ // in the first place, this allows better simplification for the
+ // supported path.
TVM_TRY_REWRITE_IF(x - x, ZeroWithTypeLike(x),
SideEffect(x.Eval()) <= CallEffectKind::kReadState);
TVM_TRY_REWRITE_IF((x + y) - y, x, SideEffect(y.Eval()) <=
CallEffectKind::kReadState);
@@ -1678,6 +1685,7 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ
ret) {
// Pattern var match IntImm
PVar<IntImm> c1, c2;
PVar<PrimExpr> lanes;
+ PConst<PrimExpr> ctrue(make_const(ret->dtype, true));
// vector rule
if (ret->dtype.is_scalable_or_fixed_length_vector()) {
@@ -1698,6 +1706,17 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ
ret) {
TVM_TRY_REWRITE(c1 - x == c2, x == c1 - c2);
TVM_TRY_REWRITE(x + c1 == c2, x == c2 - c1);
TVM_TRY_RECURSIVE_REWRITE(x * y == 0, x == 0 || y == 0);
+ TVM_TRY_REWRITE(x == x, ctrue);
+ } else {
+ // Mimic the cancellation rules for SubNode. For Index datatypes,
+ // we skip the check for side effects.
+ //
+ // These simplifications do not preserve NaN/Inf that may occur in
+ // the inputs. For IEEE floats, `NaN == NaN` is false, so `x == x`
+ // is not always true. However, since models should not encounter NaN
+ // in the first place, this allows better simplification for the
+ // supported path.
+ TVM_TRY_REWRITE_IF(x == x, ctrue, SideEffect(x.Eval()) <=
CallEffectKind::kReadState);
}
return std::move(ret);
}
diff --git a/tests/python/arith/test_arith_rewrite_simplify.py
b/tests/python/arith/test_arith_rewrite_simplify.py
index 1ebaab53af..90f0aeef47 100644
--- a/tests/python/arith/test_arith_rewrite_simplify.py
+++ b/tests/python/arith/test_arith_rewrite_simplify.py
@@ -321,6 +321,42 @@ class TestSelect(BaseCompare):
)
+class TestCancellation(BaseCompare):
+ var_int8 = tir.Var("var_int8", "int8")
+ var_int32 = tir.Var("var_int32", "int32")
+ var_int64 = tir.Var("var_int64", "int64")
+ var_uint8 = tir.Var("var_uint8", "uint8")
+ var_uint32 = tir.Var("var_uint32", "uint32")
+ var_uint64 = tir.Var("var_uint64", "uint64")
+
+ test_case = tvm.testing.parameter(
+ TestCase(tir.const(5, "int64") - tir.const(5, "int64"), tir.const(0,
"int64")),
+ TestCase(tir.const(5, "uint8") - tir.const(5, "uint8"), tir.const(0,
"uint8")),
+ TestCase(var_int8 - var_int8, tir.const(0, "int8")),
+ TestCase(var_int32 - var_int32, tir.const(0, "int32")),
+ TestCase(var_int64 - var_int64, tir.const(0, "int64")),
+ TestCase(var_uint8 - var_uint8, tir.const(0, "uint8")),
+ TestCase(var_uint32 - var_uint32, tir.const(0, "uint32")),
+ TestCase(var_uint64 - var_uint64, tir.const(0, "uint64")),
+ TestCase(tir.EQ(tir.const(5, "int64"), tir.const(5, "int64")),
tir.const(True, "bool")),
+ TestCase(tir.EQ(tir.const(5, "uint8"), tir.const(5, "uint8")),
tir.const(True, "bool")),
+ TestCase(tir.EQ(var_int8, var_int8), tir.const(True, "bool")),
+ TestCase(tir.EQ(var_int32, var_int32), tir.const(True, "bool")),
+ TestCase(tir.EQ(var_int64, var_int64), tir.const(True, "bool")),
+ TestCase(tir.EQ(var_uint8, var_uint8), tir.const(True, "bool")),
+ TestCase(tir.EQ(var_uint32, var_uint32), tir.const(True, "bool")),
+ TestCase(tir.EQ(var_uint64, var_uint64), tir.const(True, "bool")),
+ TestCase(tir.NE(tir.const(5, "int64"), tir.const(5, "int64")),
tir.const(False, "bool")),
+ TestCase(tir.NE(tir.const(5, "uint8"), tir.const(5, "uint8")),
tir.const(False, "bool")),
+ TestCase(tir.NE(var_int8, var_int8), tir.const(False, "bool")),
+ TestCase(tir.NE(var_int32, var_int32), tir.const(False, "bool")),
+ TestCase(tir.NE(var_int64, var_int64), tir.const(False, "bool")),
+ TestCase(tir.NE(var_uint8, var_uint8), tir.const(False, "bool")),
+ TestCase(tir.NE(var_uint32, var_uint32), tir.const(False, "bool")),
+ TestCase(tir.NE(var_uint64, var_uint64), tir.const(False, "bool")),
+ )
+
+
class TestAddIndex(BaseCompare):
x, y, z = te.var("x"), te.var("y"), te.var("z")
diff --git a/tests/python/arith/test_arith_simplify.py
b/tests/python/arith/test_arith_simplify.py
index 9a0245d274..3b02377400 100644
--- a/tests/python/arith/test_arith_simplify.py
+++ b/tests/python/arith/test_arith_simplify.py
@@ -38,6 +38,35 @@ def test_simplify_reshape_flattened_index():
)
+dtype = tvm.testing.parameter(
+ "uint8",
+ "uint16",
+ "uint32",
+ "uint64",
+ "int8",
+ "int16",
+ "int32",
+ "int64",
+ "float16",
+ "float32",
+ "float64",
+)
+
+
+def test_can_prove_self_identity(dtype):
+ ana = tvm.arith.Analyzer()
+
+ n = tir.Var("n", dtype)
+ assert ana.can_prove(n == n)
+
+
+def test_can_prove_self_equal_to_self(dtype):
+ ana = tvm.arith.Analyzer()
+
+ n = tir.Var("n", dtype)
+ assert ana.can_prove_equal(n, n)
+
+
def test_simplify_symbolic_comparison():
ana = tvm.arith.Analyzer()