This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8f9f605dd5 [ARITH] Enhance buffer shape bound deduction to include offset (#15228)
8f9f605dd5 is described below
commit 8f9f605dd599a623cf56159fe24bd8fd08489c2b
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Jul 4 21:38:16 2023 -0400
[ARITH] Enhance buffer shape bound deduction to include offset (#15228)
This PR enhances the buffer shape hint so that a shape expression with a
constant offset, such as n - 1, lets the analyzer deduce n >= 1.
---
src/arith/analyzer.cc | 32 +++++++++++++++++-----
src/arith/product_normal_form.h | 18 ++++++++++++
.../python/unittest/test_tir_transform_simplify.py | 16 ++++++++---
3 files changed, 55 insertions(+), 11 deletions(-)
diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc
index 9e5b1414ed..3e5b8834eb 100644
--- a/src/arith/analyzer.cc
+++ b/src/arith/analyzer.cc
@@ -65,17 +65,34 @@ void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) {
}
void Analyzer::MarkGlobalNonNegValue(const PrimExpr& value) {
- // split out the symbolic and non-symbolic part
+ // decompose value as symbol * scale + offset
+ int64_t offset = 0;
+ PrimExpr symbol_scale = tir::make_const(value.dtype(), 0);
+
+ auto fcollect_sum = [&](PrimExpr val, int sign) {
+ if (const auto* intimm = val.as<IntImmNode>()) {
+ offset += intimm->value * sign;
+ } else {
+ if (sign > 0) {
+ symbol_scale = symbol_scale + val;
+ } else {
+ symbol_scale = symbol_scale - val;
+ }
+ }
+ };
+ UnpackSum(value, fcollect_sum);
+
+ // split out the symbol and non-symbolic part
int64_t cscale = 1;
- PrimExpr symbolic = tir::make_const(value.dtype(), 1);
- auto fcollect = [&](PrimExpr val) {
+ PrimExpr symbol = tir::make_const(value.dtype(), 1);
+ auto fcollect_prod = [&](PrimExpr val) {
if (const auto* intimm = val.as<IntImmNode>()) {
cscale *= intimm->value;
} else {
- symbolic = symbolic * val;
+ symbol = symbol * val;
}
};
- UnpackReduction<tir::MulNode>(value, fcollect);
+ UnpackReduction<tir::MulNode>(symbol_scale, fcollect_prod);
if (cscale <= 0) return;
// override the constant int bound by marking it as non-negative
// NOTE: there might be future opportunities of more bound hint
@@ -83,7 +100,7 @@ void Analyzer::MarkGlobalNonNegValue(const PrimExpr& value) {
//
// We may consider enhance the sub analyzer to directly take
// MarkPositiveVar so their bounds do not overlap
- if (const auto* var_ptr = symbolic.as<VarNode>()) {
+ if (const auto* var_ptr = symbol.as<VarNode>()) {
Var var = GetRef<Var>(var_ptr);
// skip non-index type, keep it to be compatible
// with any_dim that do not represent any value
@@ -92,7 +109,8 @@ void Analyzer::MarkGlobalNonNegValue(const PrimExpr& value) {
// mark the constant bound is sufficient
// we cannot mark interval set as that will cause relaxation of the var
// during bound proof which is not our intention
- this->const_int_bound.Update(var, ConstIntBound(0,
ConstIntBound::kPosInf), allow_override);
+ this->const_int_bound.Update(var, ConstIntBound(-offset,
ConstIntBound::kPosInf),
+ allow_override);
}
}
diff --git a/src/arith/product_normal_form.h b/src/arith/product_normal_form.h
index 768a3a2b8b..d27ca76650 100644
--- a/src/arith/product_normal_form.h
+++ b/src/arith/product_normal_form.h
@@ -47,6 +47,24 @@ inline void UnpackReduction(const PrimExpr& value, FLeaf fleaf) {
}
}
+/*!
+ * \brief Unpack a chain of Add/Sub nodes, calling fleaf on each non-Add/Sub leaf.
+ * \param value The expression value. \param fleaf Called as fleaf(leaf, sign).
+ * \param sign Sign (+1 or -1) of value in the enclosing sum; negated for a Sub rhs.
+ */
+template <typename FLeaf>
+inline void UnpackSum(const PrimExpr& value, FLeaf fleaf, int sign = 1) {
+ if (const tir::AddNode* node = value.as<tir::AddNode>()) {
+ UnpackSum(node->a, fleaf, sign);
+ UnpackSum(node->b, fleaf, sign);
+ } else if (const tir::SubNode* node = value.as<tir::SubNode>()) {
+ UnpackSum(node->a, fleaf, sign);
+ UnpackSum(node->b, fleaf, -sign);
+ } else {
+ fleaf(value, sign);
+ }
+}
+
/*!
* \brief Helper function to multiply extent and and re-normalize.
*
diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py
index 79fd5e1434..c779d92f9c 100644
--- a/tests/python/unittest/test_tir_transform_simplify.py
+++ b/tests/python/unittest/test_tir_transform_simplify.py
@@ -1734,10 +1734,6 @@ class TestSimplifyTrivialLetStride(BaseBeforeAfter):
class TestBufferShapeConstraint(BaseBeforeAfter):
- """If enabled, rewrite boolean expressions into AND of OR"""
-
- convert_boolean_to_and_of_ors = True
-
def before(a: T.handle):
n = T.int64()
A = T.match_buffer(a, (n * 32,), "float32")
@@ -1749,5 +1745,17 @@ class TestBufferShapeConstraint(BaseBeforeAfter):
A[T.int64(0)] = T.float32(0)
+class TestBufferShapeConstraintWithOffset(BaseBeforeAfter):  # "+ 1 - 2" presumably covers both Add and Sub
+ def before(a: T.handle):
+ n = T.int64()
+ A = T.match_buffer(a, (n * 32 + 1 - 2,), "float32")  # extent n*32 - 1 >= 0 implies n >= 1
+ A[T.min(T.int64(1), n)] = T.float32(0)  # with n >= 1 known, min(1, n) folds to 1
+
+ def expected(a: T.handle):
+ n = T.int64()
+ A = T.match_buffer(a, (n * 32 + 1 - 2,), "float32")
+ A[T.int64(1)] = T.float32(0)
+
+
if __name__ == "__main__":
tvm.testing.main()