This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new cac0445f5e [Arith][BoundDeducer] Forbid non-supported expr type in bound deducer (#11323)
cac0445f5e is described below
commit cac0445f5e65ac0357ab7db141006d1004750ac4
Author: wrongtest <[email protected]>
AuthorDate: Wed May 25 02:36:27 2022 +0800
[Arith][BoundDeducer] Forbid non-supported expr type in bound deducer (#11323)
---
src/arith/bound_deducer.cc | 12 ++++------
tests/python/unittest/test_arith_deduce_bound.py | 28 ++++++++++++++++++++----
2 files changed, 28 insertions(+), 12 deletions(-)
diff --git a/src/arith/bound_deducer.cc b/src/arith/bound_deducer.cc
index 9275ec1bc3..ba6b11dbb7 100644
--- a/src/arith/bound_deducer.cc
+++ b/src/arith/bound_deducer.cc
@@ -71,7 +71,7 @@ std::vector<const Object*> GetPath(PrimExpr target, PrimExpr expr) {
enum CompareOp { kGreater, kLess, kEqual };
// a visitor to deduce the bound of a variable from a expression
-class BoundDeducer : public ExprVisitor {
+class BoundDeducer : public ExprFunctor<void(const PrimExpr&)> {
public:
friend class BoundDeduceInputChecker;
friend class Converter;
@@ -85,20 +85,16 @@ class BoundDeducer : public ExprVisitor {
void VisitExpr(const PrimExpr& e) final {
if (!success_) return;
if (iter_ < path_.size() && e.get() == path_[iter_++]) {
- ExprVisitor::VisitExpr(e);
+ ExprFunctor::VisitExpr(e);
} else {
success_ = false;
return;
}
}
- void VisitExpr_(const LTNode* op) final { success_ = false; }
+ void VisitExprDefault_(const Object* op) final { success_ = false; }
- void VisitExpr_(const LENode* op) final { success_ = false; }
-
- void VisitExpr_(const GTNode* op) final { success_ = false; }
-
- void VisitExpr_(const GENode* op) final { success_ = false; }
+ void VisitExpr_(const VarNode* op) final {}
void VisitExpr_(const AddNode* op) final {
bool left = op->a.get() == path_[iter_];
diff --git a/tests/python/unittest/test_arith_deduce_bound.py b/tests/python/unittest/test_arith_deduce_bound.py
index 5c6976ab50..ef478b4c2f 100644
--- a/tests/python/unittest/test_arith_deduce_bound.py
+++ b/tests/python/unittest/test_arith_deduce_bound.py
@@ -14,9 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import pytest
import tvm
import tvm.testing
from tvm import te
+from tvm.tir.buffer import decl_buffer
def test_deduce():
@@ -210,8 +212,26 @@ def test_deduce_complex():
test_complex(2, 6, -4)
+def test_deduce_non_support():
+ a = te.var("a")
+
+ def test_non_support(lhs):
+ res = tvm.arith.deduce_bound(a, lhs < 10, {}, {})
+ assert res.is_nothing()
+
+ test_non_support(tvm.tir.floordiv(a, 16))
+ test_non_support(tvm.tir.floormod(a, 16))
+ test_non_support(tvm.tir.Min(a, 16))
+ test_non_support(tvm.tir.Max(a, 16))
+ test_non_support(tvm.tir.LE(a, 16))
+ test_non_support(tvm.tir.LT(a, 16))
+ test_non_support(tvm.tir.GE(a, 16))
+ test_non_support(tvm.tir.GT(a, 16))
+ test_non_support(tvm.tir.EQ(a, 16))
+ test_non_support(tvm.tir.NE(a, 16))
+ test_non_support(tvm.tir.log(a))
+ test_non_support(tvm.tir.BufferLoad(decl_buffer([16], "int32"), [a]))
+
+
if __name__ == "__main__":
- test_check()
- test_deduce()
- test_deduce_basic()
- test_deduce_complex()
+ pytest.main([__file__])