This is an automated email from the ASF dual-hosted git repository.

wrongtest pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c3fe08fb20 [ARITH] support floordiv in deduce bound (#13880)
c3fe08fb20 is described below

commit c3fe08fb20abc69386f645a81a0abe49f64d61ba
Author: wrongtest <[email protected]>
AuthorDate: Wed Feb 1 15:15:16 2023 +0800

    [ARITH] support floordiv in deduce bound (#13880)
    
    * support floordiv in deduce bound
    
    * add rule for (x // -positive)
    
    * leave todo for x // a == b
---
 src/arith/bound_deducer.cc                       | 61 +++++++++++++++++++++---
 tests/python/unittest/test_arith_deduce_bound.py | 38 ++++++++++++++-
 2 files changed, 91 insertions(+), 8 deletions(-)

diff --git a/src/arith/bound_deducer.cc b/src/arith/bound_deducer.cc
index d4a3101378..7cfe8681be 100644
--- a/src/arith/bound_deducer.cc
+++ b/src/arith/bound_deducer.cc
@@ -94,6 +94,13 @@ class BoundDeducer : public ExprFunctor<void(const PrimExpr&)> {
 
   void VisitExprDefault_(const Object* op) final { success_ = false; }
 
+  SignType GetSignType(const PrimExpr& e) {  // Sign of e; signed exprs query expr_map_.
+    if (e.dtype().is_uint()) {
+      return kPositive;  // uint: assumed positive (NOTE(review): may still be 0).
+    }
+    return expr_map_[e].GetSignType();
+  }
+
   void VisitExpr_(const VarNode* op) final {}
 
   void VisitExpr_(const AddNode* op) final {
@@ -119,13 +126,7 @@ class BoundDeducer : public ExprFunctor<void(const PrimExpr&)> {
     PrimExpr operand = left ? op->b : op->a;
     PrimExpr target_var = left ? op->a : op->b;
 
-    SignType sign_operand;
-    if (operand.dtype().is_uint()) {
-      sign_operand = kPositive;
-    } else {
-      sign_operand = expr_map_[operand].GetSignType();
-    }
-
+    SignType sign_operand = GetSignType(operand);
     if (sign_operand == SignType::kNegative) {
       comp_op = ReverseOp(comp_op);
     } else if (sign_operand == SignType::kUnknown) {
@@ -162,6 +163,52 @@ class BoundDeducer : public ExprFunctor<void(const PrimExpr&)> {
     this->VisitExpr(left ? op->a : op->b);
   }
 
+  void VisitExpr_(const FloorDivNode* op) final {  // Deduce through a // b, recursing into a.
+    if (op->b.get() == path_[iter_]) {
+      // Target var lies in the divisor (b); only the dividend (a) is handled.
+      success_ = false;
+      return;
+    }
+    PrimExpr divisor = op->b;
+    if (analyzer_.CanProveEqual(divisor, 0)) {
+      // Provably zero divisor: no bound can be deduced.
+      success_ = false;
+      return;
+    }
+
+    SignType sign_operand = GetSignType(divisor);
+    if (sign_operand == SignType::kNegative) {
+      comp_op = ReverseOp(comp_op);  // x // -d == -((x + d - 1) // d): flip op,
+      divisor = -divisor;  // use d = -b > 0, negate r, and solve for x + d - 1;
+      result_ = -result_;  // the result is shifted back to x further below.
+    } else if (sign_operand == SignType::kUnknown) {
+      // Unknown divisor sign: cannot tell whether comp_op must be reversed.
+      success_ = false;
+      return;
+    }  // NOTE(review): signs other than kNegative/kUnknown fall through as positive.
+
+    if (comp_op == kGreater) {
+      // x // d >= r  <=>  x >= r * d   (x // 6 >= 4 --> x >= 24)
+      result_ = result_ * divisor;
+    } else if (comp_op == kEqual) {
+      // The bound is two-sided, which is not supported here:
+      // (x // 6 == 4 --> 30 > x >= 24)
+      // TODO(@wrongtest): support bidirectional bound
+      success_ = false;
+      return;
+    } else {
+      // x // d <= r  <=>  x <= r * d + d - 1   (x // 6 <= 4 --> x <= 29)
+      result_ = result_ * divisor + divisor - 1;
+    }
+    if (sign_operand == SignType::kNegative) {
+      // x // -6 >= 4 --> -((x + 6 - 1) // 6) >= 4 --> (x + 6 - 1) // 6 <= -4;
+      // the bound computed above is for x + 6 - 1, so shift it back by 6 - 1.
+      result_ = result_ - divisor + 1;
+    }
+
+    this->VisitExpr(op->a);  // Continue deducing into the dividend.
+  }
+
   PrimExpr result_;
   CompareOp comp_op{kGreater};
   bool success_{true};
diff --git a/tests/python/unittest/test_arith_deduce_bound.py b/tests/python/unittest/test_arith_deduce_bound.py
index d5e0303b05..45ecb62755 100644
--- a/tests/python/unittest/test_arith_deduce_bound.py
+++ b/tests/python/unittest/test_arith_deduce_bound.py
@@ -219,7 +219,6 @@ def test_deduce_non_support():
         res = tvm.arith.deduce_bound(a, lhs < 10, {}, {})
         assert res.is_nothing()
 
-    test_non_support(tvm.tir.floordiv(a, 16))
     test_non_support(tvm.tir.floormod(a, 16))
     test_non_support(tvm.tir.Min(a, 16))
     test_non_support(tvm.tir.Max(a, 16))
@@ -233,5 +232,42 @@ def test_deduce_non_support():
     test_non_support(tvm.tir.BufferLoad(decl_buffer([16], "int32"), [a]))
 
 
+def test_deduce_floordiv():
+    def do_test(gen_expr, dom_map, expect_min, expect_max):  # str expects match str(bound)
+        a = te.var("a")
+        expr = gen_expr(a)
+        res = tvm.arith.deduce_bound(a, expr, dom_map, dom_map)
+        if isinstance(expect_min, str):
+            assert str(res.min_value) == expect_min
+        else:
+            tvm.testing.assert_prim_expr_equal(res.min_value, expect_min)
+        if isinstance(expect_max, str):
+            assert str(res.max_value) == expect_max
+        else:
+            tvm.testing.assert_prim_expr_equal(res.max_value, expect_max)
+
+    # basic cases: positive (8) and negative (-8) constant divisors
+    do_test(lambda a: a // 8 > 3, {}, 32, "pos_inf")
+    do_test(lambda a: a // 8 >= 3, {}, 24, "pos_inf")
+    do_test(lambda a: a // 8 < 3, {}, "neg_inf", 23)
+    do_test(lambda a: a // 8 <= 3, {}, "neg_inf", 31)
+    do_test(lambda a: a // 8 == 3, {}, "pos_inf", "neg_inf")  # == unsupported: empty bound
+    do_test(lambda a: a // 8 > -3, {}, -16, "pos_inf")
+    do_test(lambda a: a // 8 >= -3, {}, -24, "pos_inf")
+    do_test(lambda a: a // -8 > 3, {}, "neg_inf", -32)
+    do_test(lambda a: a // -8 >= 3, {}, "neg_inf", -24)
+    do_test(lambda a: a // -8 < 3, {}, -23, "pos_inf")
+    do_test(lambda a: a // -8 <= 3, {}, -31, "pos_inf")
+    do_test(lambda a: 8 // a >= 2, {}, "pos_inf", "neg_inf")  # var in divisor: unsupported
+
+    # nested cases: each expected bound holds for every b in [2, 6]
+    b = te.var("b")
+    bs = {b: tvm.arith.IntervalSet(2, 6)}
+    do_test(lambda a: b * 3 + a // 8 < 63, bs, "neg_inf", 359)
+    do_test(lambda a: b * 3 + a // 8 <= 63, bs, "neg_inf", 367)
+    do_test(lambda a: b * 3 + a // 8 > 63, bs, 464, "pos_inf")
+    do_test(lambda a: b * 3 + a // 8 >= 63, bs, 456, "pos_inf")
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to