This is an automated email from the ASF dual-hosted git repository.

wrongtest pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0c2ab1bb42 [Arith] Support eq in detect_clip_bound (#13746)
0c2ab1bb42 is described below

commit 0c2ab1bb42fc960ba23416f3ae4068bece8ca2e2
Author: wrongtest <[email protected]>
AuthorDate: Sat Jan 28 13:42:53 2023 +0800

    [Arith] Support eq in detect_clip_bound (#13746)
    
    * Support eq in detect_clip_bound
    
    * follow review suggestion
---
 src/arith/detect_linear_equation.cc                | 38 +++++++++++++++++-----
 .../unittest/test_arith_detect_clip_bound.py       | 13 ++++++++
 2 files changed, 42 insertions(+), 9 deletions(-)

diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc
index 8ea8f168b6..576ac1716e 100644
--- a/src/arith/detect_linear_equation.cc
+++ b/src/arith/detect_linear_equation.cc
@@ -189,6 +189,7 @@ bool DetectClipBound(const PrimExpr& cond,
   PostOrderVisit(cond, fvisit);
   if (flag != 1) return false;
   // canonical form: exp >= 0
+  bool is_eq = false;
   PrimExpr canonical;
   if (const LTNode* op = cond.as<LTNode>()) {
     if (!op->a.dtype().is_int()) return false;
@@ -202,6 +203,10 @@ bool DetectClipBound(const PrimExpr& cond,
   } else if (const GENode* op = cond.as<GENode>()) {
     if (!op->a.dtype().is_int()) return false;
     canonical = op->a - op->b;
+  } else if (const EQNode* op = cond.as<EQNode>()) {
+    if (!op->a.dtype().is_int()) return false;
+    canonical = op->a - op->b;
+    is_eq = true;
   } else {
     return false;
   }
@@ -210,25 +215,40 @@ bool DetectClipBound(const PrimExpr& cond,
   if (!LinearEqDetector(var).Detect(canonical, &ret)) return false;
   ret.coeff = analyzer.Simplify(ret.coeff);
   IntervalEntry& p = (*bmap)[var.get()];
+
+  Optional<PrimExpr> min_value;
+  Optional<PrimExpr> max_value;
   if (is_const_int(ret.coeff, 1)) {
     // var + shift >=0 -> var >= -shift
+    min_value = -ret.base;
+    if (is_eq) {
+      max_value = min_value;
+    }
+  } else if (is_const_int(ret.coeff, -1)) {
+    // -var + shift >=0 -> var <= shift
+    max_value = ret.base;
+    if (is_eq) {
+      min_value = max_value;
+    }
+  }
+  if (!min_value.defined() && !max_value.defined()) {
+    return false;
+  }
+  if (min_value.defined()) {
     if (p.min_value.defined()) {
-      p.min_value = max(p.min_value, -ret.base);
+      p.min_value = max(p.min_value, min_value.value());
     } else {
-      p.min_value = -ret.base;
+      p.min_value = min_value.value();
     }
-    return true;
   }
-  if (is_const_int(ret.coeff, -1)) {
-    // -var + shift >=0 -> var <= shift
+  if (max_value.defined()) {
     if (p.max_value.defined()) {
-      p.max_value = min(p.max_value, ret.base);
+      p.max_value = min(p.max_value, max_value.value());
     } else {
-      p.max_value = ret.base;
+      p.max_value = max_value.value();
     }
-    return true;
   }
-  return false;
+  return true;
 }
 
 template <typename OP>
diff --git a/tests/python/unittest/test_arith_detect_clip_bound.py b/tests/python/unittest/test_arith_detect_clip_bound.py
index 0a9d75fcea..03fff11f77 100644
--- a/tests/python/unittest/test_arith_detect_clip_bound.py
+++ b/tests/python/unittest/test_arith_detect_clip_bound.py
@@ -39,5 +39,18 @@ def test_basic():
     tvm.testing.assert_prim_expr_equal(m[2], 4)
 
 
+def test_trivial_eq():
+    a = te.var("a")
+    b = te.var("b")
+    m = tvm.arith.detect_clip_bound(b == 3, [a, b])
+    tvm.testing.assert_prim_expr_equal(m[2], 3)
+    tvm.testing.assert_prim_expr_equal(m[3], 3)
+    m = tvm.arith.detect_clip_bound(tvm.tir.all(a == 4, b == 3), [a, b])
+    tvm.testing.assert_prim_expr_equal(m[0], 4)
+    tvm.testing.assert_prim_expr_equal(m[1], 4)
+    tvm.testing.assert_prim_expr_equal(m[2], 3)
+    tvm.testing.assert_prim_expr_equal(m[3], 3)
+
+
 if __name__ == "__main__":
     test_basic()

Reply via email to