This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0545962 [Bugfix] [tir] do not simplify 'Any() - Any()' to 0 (#8266)
0545962 is described below
commit 0545962002e953dd69045b826ec707cb9da261e5
Author: Huang, Guangtai <[email protected]>
AuthorDate: Sat Jul 17 14:56:18 2021 +0800
[Bugfix] [tir] do not simplify 'Any() - Any()' to 0 (#8266)
* fix
* fix lint
* remove
* address comments
---
python/tvm/tir/expr.py | 2 +-
src/tir/analysis/deep_equal.cc | 3 +++
tests/python/unittest/test_arith_rewrite_simplify.py | 3 +++
3 files changed, 7 insertions(+), 1 deletion(-)
diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py
index 6e86157..4ba8c54 100644
--- a/python/tvm/tir/expr.py
+++ b/python/tvm/tir/expr.py
@@ -1214,7 +1214,7 @@ class Let(PrimExprWithOp):
@tvm._ffi.register_object("tir.Any")
-class Any(PrimExpr):
+class Any(PrimExprWithOp):
"""Any node.
span : Optional[Span]
diff --git a/src/tir/analysis/deep_equal.cc b/src/tir/analysis/deep_equal.cc
index 7eb8013..7f48cc4 100644
--- a/src/tir/analysis/deep_equal.cc
+++ b/src/tir/analysis/deep_equal.cc
@@ -59,6 +59,9 @@ bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
auto* prhs = rhs.as<IntImmNode>();
return plhs->dtype == prhs->dtype && plhs->value == prhs->value;
}
+ if (lhs.as<AnyNode>()) {
+ return false;
+ }
return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false);
}
diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py
index c3afa6c..231c376 100644
--- a/tests/python/unittest/test_arith_rewrite_simplify.py
+++ b/tests/python/unittest/test_arith_rewrite_simplify.py
@@ -275,6 +275,7 @@ def test_add_index_simplify():
def test_sub_index_simplify():
ck = RewriteChecker()
x, y, z = te.var("x"), te.var("y"), te.var("z")
+ a, b = tvm.tir.Any(), tvm.tir.Any()
ck.verify(x + y - y, x)
ck.verify(x + y - x, y)
@@ -293,6 +294,8 @@ def test_sub_index_simplify():
# mul co-efficient foldng
ck.verify(x - x, 0)
+ ck.verify(a - a, 0)
+ ck.verify(a - b, a - b)
ck.verify(x * y - x, x * (y + (-1)))
ck.verify(x * y - 10 * x, x * (y + (-10)))
ck.verify(y * x - x * z, x * (y - z))