This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new cac0445f5e [Arith][BoundDeducer] Forbid non-supported expr type in bound deducer (#11323)
cac0445f5e is described below

commit cac0445f5e65ac0357ab7db141006d1004750ac4
Author: wrongtest <[email protected]>
AuthorDate: Wed May 25 02:36:27 2022 +0800

    [Arith][BoundDeducer] Forbid non-supported expr type in bound deducer (#11323)
---
 src/arith/bound_deducer.cc                       | 12 ++++------
 tests/python/unittest/test_arith_deduce_bound.py | 28 ++++++++++++++++++++----
 2 files changed, 28 insertions(+), 12 deletions(-)

diff --git a/src/arith/bound_deducer.cc b/src/arith/bound_deducer.cc
index 9275ec1bc3..ba6b11dbb7 100644
--- a/src/arith/bound_deducer.cc
+++ b/src/arith/bound_deducer.cc
@@ -71,7 +71,7 @@ std::vector<const Object*> GetPath(PrimExpr target, PrimExpr expr) {
 enum CompareOp { kGreater, kLess, kEqual };
 
 // a visitor to deduce the bound of a variable from a expression
-class BoundDeducer : public ExprVisitor {
+class BoundDeducer : public ExprFunctor<void(const PrimExpr&)> {
  public:
   friend class BoundDeduceInputChecker;
   friend class Converter;
@@ -85,20 +85,16 @@ class BoundDeducer : public ExprVisitor {
   void VisitExpr(const PrimExpr& e) final {
     if (!success_) return;
     if (iter_ < path_.size() && e.get() == path_[iter_++]) {
-      ExprVisitor::VisitExpr(e);
+      ExprFunctor::VisitExpr(e);
     } else {
       success_ = false;
       return;
     }
   }
 
-  void VisitExpr_(const LTNode* op) final { success_ = false; }
+  void VisitExprDefault_(const Object* op) final { success_ = false; }
 
-  void VisitExpr_(const LENode* op) final { success_ = false; }
-
-  void VisitExpr_(const GTNode* op) final { success_ = false; }
-
-  void VisitExpr_(const GENode* op) final { success_ = false; }
+  void VisitExpr_(const VarNode* op) final {}
 
   void VisitExpr_(const AddNode* op) final {
     bool left = op->a.get() == path_[iter_];
diff --git a/tests/python/unittest/test_arith_deduce_bound.py b/tests/python/unittest/test_arith_deduce_bound.py
index 5c6976ab50..ef478b4c2f 100644
--- a/tests/python/unittest/test_arith_deduce_bound.py
+++ b/tests/python/unittest/test_arith_deduce_bound.py
@@ -14,9 +14,11 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import pytest
 import tvm
 import tvm.testing
 from tvm import te
+from tvm.tir.buffer import decl_buffer
 
 
 def test_deduce():
@@ -210,8 +212,26 @@ def test_deduce_complex():
     test_complex(2, 6, -4)
 
 
+def test_deduce_non_support():
+    a = te.var("a")
+
+    def test_non_support(lhs):
+        res = tvm.arith.deduce_bound(a, lhs < 10, {}, {})
+        assert res.is_nothing()
+
+    test_non_support(tvm.tir.floordiv(a, 16))
+    test_non_support(tvm.tir.floormod(a, 16))
+    test_non_support(tvm.tir.Min(a, 16))
+    test_non_support(tvm.tir.Max(a, 16))
+    test_non_support(tvm.tir.LE(a, 16))
+    test_non_support(tvm.tir.LT(a, 16))
+    test_non_support(tvm.tir.GE(a, 16))
+    test_non_support(tvm.tir.GT(a, 16))
+    test_non_support(tvm.tir.EQ(a, 16))
+    test_non_support(tvm.tir.NE(a, 16))
+    test_non_support(tvm.tir.log(a))
+    test_non_support(tvm.tir.BufferLoad(decl_buffer([16], "int32"), [a]))
+
+
 if __name__ == "__main__":
-    test_check()
-    test_deduce()
-    test_deduce_basic()
-    test_deduce_complex()
+    pytest.main([__file__])

Reply via email to