This is an automated email from the ASF dual-hosted git repository.
wrongtest pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c3fe08fb20 [ARITH] support floordiv in deduce bound (#13880)
c3fe08fb20 is described below
commit c3fe08fb20abc69386f645a81a0abe49f64d61ba
Author: wrongtest <[email protected]>
AuthorDate: Wed Feb 1 15:15:16 2023 +0800
[ARITH] support floordiv in deduce bound (#13880)
* support floordiv in deduce bound
* add rule for (x // -positive)
* leave todo for x // a == b
---
src/arith/bound_deducer.cc | 61 +++++++++++++++++++++---
tests/python/unittest/test_arith_deduce_bound.py | 38 ++++++++++++++-
2 files changed, 91 insertions(+), 8 deletions(-)
diff --git a/src/arith/bound_deducer.cc b/src/arith/bound_deducer.cc
index d4a3101378..7cfe8681be 100644
--- a/src/arith/bound_deducer.cc
+++ b/src/arith/bound_deducer.cc
@@ -94,6 +94,13 @@ class BoundDeducer : public ExprFunctor<void(const PrimExpr&)> {
void VisitExprDefault_(const Object* op) final { success_ = false; }
+  SignType GetSignType(const PrimExpr& expr) {
+    // Unsigned expressions are reported as positive without looking up expr_map_
+    // (NOTE(review): this also covers an unsigned value that may be zero).
+    // Every other expression takes the sign of its interval recorded in expr_map_.
+    return expr.dtype().is_uint() ? SignType::kPositive : expr_map_[expr].GetSignType();
+  }
+
void VisitExpr_(const VarNode* op) final {}
void VisitExpr_(const AddNode* op) final {
@@ -119,13 +126,7 @@ class BoundDeducer : public ExprFunctor<void(const PrimExpr&)> {
PrimExpr operand = left ? op->b : op->a;
PrimExpr target_var = left ? op->a : op->b;
- SignType sign_operand;
- if (operand.dtype().is_uint()) {
- sign_operand = kPositive;
- } else {
- sign_operand = expr_map_[operand].GetSignType();
- }
-
+ SignType sign_operand = GetSignType(operand);
if (sign_operand == SignType::kNegative) {
comp_op = ReverseOp(comp_op);
} else if (sign_operand == SignType::kUnknown) {
@@ -162,6 +163,52 @@ class BoundDeducer : public ExprFunctor<void(const PrimExpr&)> {
this->VisitExpr(left ? op->a : op->b);
}
+  void VisitExpr_(const FloorDivNode* op) final {
+    if (op->b.get() == path_[iter_]) {
+      // Skip cases where the target var is in the divisor.
+      success_ = false;
+      return;
+    }
+    PrimExpr divisor = op->b;
+    if (analyzer_.CanProveEqual(divisor, 0)) {
+      // Skip zero divisor
+      success_ = false;
+      return;
+    }
+    // Negative divisor: x // -d == -((x + d - 1) // d); flip comp_op and negate result_.
+    SignType sign_operand = GetSignType(divisor);
+    if (sign_operand == SignType::kNegative) {
+      comp_op = ReverseOp(comp_op);
+      divisor = -divisor;
+      result_ = -result_;
+    } else if (sign_operand == SignType::kUnknown || sign_operand == SignType::kZero) {
+      // Unknown sign, or a divisor known to be zero (kZero): no bound can be derived.
+      success_ = false;
+      return;
+    }
+
+    if (comp_op == kGreater) {
+      // (x // 6 >= 4 --> x >= 4 * 6)
+      result_ = result_ * divisor;
+    } else if (comp_op == kEqual) {
+      // The bound is not single directional
+      // (x // 6 == 4 --> 30 > x >= 24)
+      // TODO(@wrongtest): support bidirectional bound
+      success_ = false;
+      return;
+    } else {
+      // (x // 6 <= 4 --> x <= 4 * 6 + 5)
+      result_ = result_ * divisor + divisor - 1;
+    }
+    if (sign_operand == SignType::kNegative) {
+      // (x // -6 >= 4 --> -((x + 6 - 1) // 6) >= 4 --> (x + 6 - 1) // 6 <= -4)
+      // The bound above is on x + 6 - 1; shift it back by 6 - 1 to bound x itself.
+      result_ = result_ - divisor + 1;
+    }
+
+    this->VisitExpr(op->a);
+  }
+
PrimExpr result_;
CompareOp comp_op{kGreater};
bool success_{true};
diff --git a/tests/python/unittest/test_arith_deduce_bound.py b/tests/python/unittest/test_arith_deduce_bound.py
index d5e0303b05..45ecb62755 100644
--- a/tests/python/unittest/test_arith_deduce_bound.py
+++ b/tests/python/unittest/test_arith_deduce_bound.py
@@ -219,7 +219,6 @@ def test_deduce_non_support():
res = tvm.arith.deduce_bound(a, lhs < 10, {}, {})
assert res.is_nothing()
- test_non_support(tvm.tir.floordiv(a, 16))
test_non_support(tvm.tir.floormod(a, 16))
test_non_support(tvm.tir.Min(a, 16))
test_non_support(tvm.tir.Max(a, 16))
@@ -233,5 +232,42 @@ def test_deduce_non_support():
test_non_support(tvm.tir.BufferLoad(decl_buffer([16], "int32"), [a]))
+def test_deduce_floordiv():
+ def do_test(gen_expr, dom_map, expect_min, expect_max):
+ a = te.var("a")
+ expr = gen_expr(a)
+ res = tvm.arith.deduce_bound(a, expr, dom_map, dom_map)
+ if isinstance(expect_min, str):
+ assert str(res.min_value) == expect_min
+ else:
+ tvm.testing.assert_prim_expr_equal(res.min_value, expect_min)
+ if isinstance(expect_max, str):
+ assert str(res.max_value) == expect_max
+ else:
+ tvm.testing.assert_prim_expr_equal(res.max_value, expect_max)
+
+ # test basic cases
+ do_test(lambda a: a // 8 > 3, {}, 32, "pos_inf")
+ do_test(lambda a: a // 8 >= 3, {}, 24, "pos_inf")
+ do_test(lambda a: a // 8 < 3, {}, "neg_inf", 23)
+ do_test(lambda a: a // 8 <= 3, {}, "neg_inf", 31)
+ do_test(lambda a: a // 8 == 3, {}, "pos_inf", "neg_inf")
+ do_test(lambda a: a // 8 > -3, {}, -16, "pos_inf")
+ do_test(lambda a: a // 8 >= -3, {}, -24, "pos_inf")
+ do_test(lambda a: a // -8 > 3, {}, "neg_inf", -32)
+ do_test(lambda a: a // -8 >= 3, {}, "neg_inf", -24)
+ do_test(lambda a: a // -8 < 3, {}, -23, "pos_inf")
+ do_test(lambda a: a // -8 <= 3, {}, -31, "pos_inf")
+ do_test(lambda a: 8 // a >= 2, {}, "pos_inf", "neg_inf")
+
+ # test nested cases
+ b = te.var("b")
+ bs = {b: tvm.arith.IntervalSet(2, 6)}
+ do_test(lambda a: b * 3 + a // 8 < 63, bs, "neg_inf", 359)
+ do_test(lambda a: b * 3 + a // 8 <= 63, bs, "neg_inf", 367)
+ do_test(lambda a: b * 3 + a // 8 > 63, bs, 464, "pos_inf")
+ do_test(lambda a: b * 3 + a // 8 >= 63, bs, 456, "pos_inf")
+
+
if __name__ == "__main__":
tvm.testing.main()