This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8f9f605dd5 [ARITH] Enhance buffer shape bound deduction to include 
offset (#15228)
8f9f605dd5 is described below

commit 8f9f605dd599a623cf56159fe24bd8fd08489c2b
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Jul 4 21:38:16 2023 -0400

    [ARITH] Enhance buffer shape bound deduction to include offset (#15228)
    
    This PR enhances the buffer shape hint so that a shape expression such as
    n - 1 lets the analyzer deduce n >= 1.
---
 src/arith/analyzer.cc                              | 32 +++++++++++++++++-----
 src/arith/product_normal_form.h                    | 18 ++++++++++++
 .../python/unittest/test_tir_transform_simplify.py | 16 ++++++++---
 3 files changed, 55 insertions(+), 11 deletions(-)

diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc
index 9e5b1414ed..3e5b8834eb 100644
--- a/src/arith/analyzer.cc
+++ b/src/arith/analyzer.cc
@@ -65,17 +65,34 @@ void Analyzer::Bind(const Var& var, const Range& range, 
bool allow_override) {
 }
 
 void Analyzer::MarkGlobalNonNegValue(const PrimExpr& value) {
-  // split out the symbolic and non-symbolic part
+  // decompose value as symbol * scale + offset
+  int64_t offset = 0;
+  PrimExpr symbol_scale = tir::make_const(value.dtype(), 0);
+
+  auto fcollect_sum = [&](PrimExpr val, int sign) {
+    if (const auto* intimm = val.as<IntImmNode>()) {
+      offset += intimm->value * sign;
+    } else {
+      if (sign > 0) {
+        symbol_scale = symbol_scale + val;
+      } else {
+        symbol_scale = symbol_scale - val;
+      }
+    }
+  };
+  UnpackSum(value, fcollect_sum);
+
+  // split out the symbol and non-symbolic part
   int64_t cscale = 1;
-  PrimExpr symbolic = tir::make_const(value.dtype(), 1);
-  auto fcollect = [&](PrimExpr val) {
+  PrimExpr symbol = tir::make_const(value.dtype(), 1);
+  auto fcollect_prod = [&](PrimExpr val) {
     if (const auto* intimm = val.as<IntImmNode>()) {
       cscale *= intimm->value;
     } else {
-      symbolic = symbolic * val;
+      symbol = symbol * val;
     }
   };
-  UnpackReduction<tir::MulNode>(value, fcollect);
+  UnpackReduction<tir::MulNode>(symbol_scale, fcollect_prod);
   if (cscale <= 0) return;
   // override the constant int bound by marking it as non-negative
   // NOTE: there might be future opportunities of more bound hint
@@ -83,7 +100,7 @@ void Analyzer::MarkGlobalNonNegValue(const PrimExpr& value) {
   //
   // We may consider enhance the sub analyzer to directly take
   // MarkPositiveVar so their bounds do not overlap
-  if (const auto* var_ptr = symbolic.as<VarNode>()) {
+  if (const auto* var_ptr = symbol.as<VarNode>()) {
     Var var = GetRef<Var>(var_ptr);
     // skip non-index type, keep it to be compatible
     // with any_dim that do not represent any value
@@ -92,7 +109,8 @@ void Analyzer::MarkGlobalNonNegValue(const PrimExpr& value) {
     // mark the constant bound is sufficient
     // we cannot mark interval set as that will cause relaxation of the var
     // during bound proof which is not our intention
-    this->const_int_bound.Update(var, ConstIntBound(0, 
ConstIntBound::kPosInf), allow_override);
+    this->const_int_bound.Update(var, ConstIntBound(-offset, 
ConstIntBound::kPosInf),
+                                 allow_override);
   }
 }
 
diff --git a/src/arith/product_normal_form.h b/src/arith/product_normal_form.h
index 768a3a2b8b..d27ca76650 100644
--- a/src/arith/product_normal_form.h
+++ b/src/arith/product_normal_form.h
@@ -47,6 +47,24 @@ inline void UnpackReduction(const PrimExpr& value, FLeaf 
fleaf) {
   }
 }
 
+/**
+ * \brief Unpack chain of add sub by calling each leaf via fleaf
+ * \param value The expression value.
+ * \tparam FLeaf The callback function at leaf.
+ */
+template <typename FLeaf>
+inline void UnpackSum(const PrimExpr& value, FLeaf fleaf, int sign = 1) {
+  if (const tir::AddNode* node = value.as<tir::AddNode>()) {
+    UnpackSum(node->a, fleaf, sign);
+    UnpackSum(node->b, fleaf, sign);
+  } else if (const tir::SubNode* node = value.as<tir::SubNode>()) {
+    UnpackSum(node->a, fleaf, sign);
+    UnpackSum(node->b, fleaf, -sign);
+  } else {
+    fleaf(value, sign);
+  }
+}
+
 /*!
  * \brief Helper function to multiply extent and and re-normalize.
  *
diff --git a/tests/python/unittest/test_tir_transform_simplify.py 
b/tests/python/unittest/test_tir_transform_simplify.py
index 79fd5e1434..c779d92f9c 100644
--- a/tests/python/unittest/test_tir_transform_simplify.py
+++ b/tests/python/unittest/test_tir_transform_simplify.py
@@ -1734,10 +1734,6 @@ class TestSimplifyTrivialLetStride(BaseBeforeAfter):
 
 
 class TestBufferShapeConstraint(BaseBeforeAfter):
-    """If enabled, rewrite boolean expressions into AND of OR"""
-
-    convert_boolean_to_and_of_ors = True
-
     def before(a: T.handle):
         n = T.int64()
         A = T.match_buffer(a, (n * 32,), "float32")
@@ -1749,5 +1745,17 @@ class TestBufferShapeConstraint(BaseBeforeAfter):
         A[T.int64(0)] = T.float32(0)
 
 
+class TestBufferShapeConstraintWithOffset(BaseBeforeAfter):
+    def before(a: T.handle):
+        n = T.int64()
+        A = T.match_buffer(a, (n * 32 + 1 - 2,), "float32")
+        A[T.min(T.int64(1), n)] = T.float32(0)
+
+    def expected(a: T.handle):
+        n = T.int64()
+        A = T.match_buffer(a, (n * 32 + 1 - 2,), "float32")
+        A[T.int64(1)] = T.float32(0)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to