This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0545962  [Bugfix] [tir] do not simplify 'Any() - Any()' to 0 (#8266)
0545962 is described below

commit 0545962002e953dd69045b826ec707cb9da261e5
Author: Huang, Guangtai <[email protected]>
AuthorDate: Sat Jul 17 14:56:18 2021 +0800

    [Bugfix] [tir] do not simplify 'Any() - Any()' to 0 (#8266)
    
    * fix
    
    * fix lint
    
    * remove
    
    * address comments
---
 python/tvm/tir/expr.py                               | 2 +-
 src/tir/analysis/deep_equal.cc                       | 3 +++
 tests/python/unittest/test_arith_rewrite_simplify.py | 3 +++
 3 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py
index 6e86157..4ba8c54 100644
--- a/python/tvm/tir/expr.py
+++ b/python/tvm/tir/expr.py
@@ -1214,7 +1214,7 @@ class Let(PrimExprWithOp):
 
 
 @tvm._ffi.register_object("tir.Any")
-class Any(PrimExpr):
+class Any(PrimExprWithOp):
     """Any node.
 
     span : Optional[Span]
diff --git a/src/tir/analysis/deep_equal.cc b/src/tir/analysis/deep_equal.cc
index 7eb8013..7f48cc4 100644
--- a/src/tir/analysis/deep_equal.cc
+++ b/src/tir/analysis/deep_equal.cc
@@ -59,6 +59,9 @@ bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
     auto* prhs = rhs.as<IntImmNode>();
     return plhs->dtype == prhs->dtype && plhs->value == prhs->value;
   }
+  if (lhs.as<AnyNode>()) {
+    return false;
+  }
   return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false);
 }
 
diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py
index c3afa6c..231c376 100644
--- a/tests/python/unittest/test_arith_rewrite_simplify.py
+++ b/tests/python/unittest/test_arith_rewrite_simplify.py
@@ -275,6 +275,7 @@ def test_add_index_simplify():
 def test_sub_index_simplify():
     ck = RewriteChecker()
     x, y, z = te.var("x"), te.var("y"), te.var("z")
+    a, b = tvm.tir.Any(), tvm.tir.Any()
 
     ck.verify(x + y - y, x)
     ck.verify(x + y - x, y)
@@ -293,6 +294,8 @@ def test_sub_index_simplify():
 
     # mul co-efficient foldng
     ck.verify(x - x, 0)
+    ck.verify(a - a, 0)
+    ck.verify(a - b, a - b)
     ck.verify(x * y - x, x * (y + (-1)))
     ck.verify(x * y - 10 * x, x * (y + (-10)))
     ck.verify(y * x - x * z, x * (y - z))

Reply via email to