This is an automated email from the ASF dual-hosted git repository.
wrongtest pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0c2ab1bb42 [Arith] Support eq in detect_clip_bound (#13746)
0c2ab1bb42 is described below
commit 0c2ab1bb42fc960ba23416f3ae4068bece8ca2e2
Author: wrongtest <[email protected]>
AuthorDate: Sat Jan 28 13:42:53 2023 +0800
[Arith] Support eq in detect_clip_bound (#13746)
* Support eq in detect_clip_bound
* follow review suggestion
---
src/arith/detect_linear_equation.cc | 38 +++++++++++++++++-----
.../unittest/test_arith_detect_clip_bound.py | 13 ++++++++
2 files changed, 42 insertions(+), 9 deletions(-)
diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc
index 8ea8f168b6..576ac1716e 100644
--- a/src/arith/detect_linear_equation.cc
+++ b/src/arith/detect_linear_equation.cc
@@ -189,6 +189,7 @@ bool DetectClipBound(const PrimExpr& cond,
PostOrderVisit(cond, fvisit);
if (flag != 1) return false;
// canonical form: exp >= 0
+ bool is_eq = false;
PrimExpr canonical;
if (const LTNode* op = cond.as<LTNode>()) {
if (!op->a.dtype().is_int()) return false;
@@ -202,6 +203,10 @@ bool DetectClipBound(const PrimExpr& cond,
} else if (const GENode* op = cond.as<GENode>()) {
if (!op->a.dtype().is_int()) return false;
canonical = op->a - op->b;
+ } else if (const EQNode* op = cond.as<EQNode>()) {
+ if (!op->a.dtype().is_int()) return false;
+ canonical = op->a - op->b;
+ is_eq = true;
} else {
return false;
}
@@ -210,25 +215,40 @@ bool DetectClipBound(const PrimExpr& cond,
if (!LinearEqDetector(var).Detect(canonical, &ret)) return false;
ret.coeff = analyzer.Simplify(ret.coeff);
IntervalEntry& p = (*bmap)[var.get()];
+
+ Optional<PrimExpr> min_value;
+ Optional<PrimExpr> max_value;
if (is_const_int(ret.coeff, 1)) {
// var + shift >=0 -> var >= -shift
+ min_value = -ret.base;
+ if (is_eq) {
+ max_value = min_value;
+ }
+ } else if (is_const_int(ret.coeff, -1)) {
+ // -var + shift >=0 -> var <= shift
+ max_value = ret.base;
+ if (is_eq) {
+ min_value = max_value;
+ }
+ }
+ if (!min_value.defined() && !max_value.defined()) {
+ return false;
+ }
+ if (min_value.defined()) {
if (p.min_value.defined()) {
- p.min_value = max(p.min_value, -ret.base);
+ p.min_value = max(p.min_value, min_value.value());
} else {
- p.min_value = -ret.base;
+ p.min_value = min_value.value();
}
- return true;
}
- if (is_const_int(ret.coeff, -1)) {
- // -var + shift >=0 -> var <= shift
+ if (max_value.defined()) {
if (p.max_value.defined()) {
- p.max_value = min(p.max_value, ret.base);
+ p.max_value = min(p.max_value, max_value.value());
} else {
- p.max_value = ret.base;
+ p.max_value = max_value.value();
}
- return true;
}
- return false;
+ return true;
}
template <typename OP>
diff --git a/tests/python/unittest/test_arith_detect_clip_bound.py b/tests/python/unittest/test_arith_detect_clip_bound.py
index 0a9d75fcea..03fff11f77 100644
--- a/tests/python/unittest/test_arith_detect_clip_bound.py
+++ b/tests/python/unittest/test_arith_detect_clip_bound.py
@@ -39,5 +39,18 @@ def test_basic():
tvm.testing.assert_prim_expr_equal(m[2], 4)
+def test_trivial_eq():
+ a = te.var("a")
+ b = te.var("b")
+ m = tvm.arith.detect_clip_bound(b == 3, [a, b])
+ tvm.testing.assert_prim_expr_equal(m[2], 3)
+ tvm.testing.assert_prim_expr_equal(m[3], 3)
+ m = tvm.arith.detect_clip_bound(tvm.tir.all(a == 4, b == 3), [a, b])
+ tvm.testing.assert_prim_expr_equal(m[0], 4)
+ tvm.testing.assert_prim_expr_equal(m[1], 4)
+ tvm.testing.assert_prim_expr_equal(m[2], 3)
+ tvm.testing.assert_prim_expr_equal(m[3], 3)
+
+
if __name__ == "__main__":
test_basic()