This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fb64be3f78 [ARITH] Allow Analyzer to MarkGlobalNonNegValue (#15193)
fb64be3f78 is described below

commit fb64be3f7807df18c2df6ebf5e68178e564ab0b4
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun Jul 2 23:02:11 2023 -0400

    [ARITH] Allow Analyzer to MarkGlobalNonNegValue (#15193)
    
    This PR introduces a utility function MarkGlobalNonNegValue.
    This function allows the analyzer to mark buffer shapes in function arguments
    as non-negative globally and opens doors for more symbolic simplification.
---
 include/tvm/arith/analyzer.h                       | 16 +++++++++
 src/arith/analyzer.cc                              | 33 +++++++++++++++++++
 src/arith/const_int_bound.cc                       | 38 ++++++++++++++++++----
 src/arith/ir_mutator_with_analyzer.cc              |  9 +++++
 src/arith/ir_mutator_with_analyzer.h               |  8 +++++
 src/tir/transforms/flatten_buffer.cc               |  1 +
 src/tir/transforms/simplify.cc                     | 21 +++++-------
 .../python/unittest/test_arith_const_int_bound.py  | 17 ++++++++++
 .../python/unittest/test_tir_transform_simplify.py | 16 +++++++++
 9 files changed, 140 insertions(+), 19 deletions(-)

diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h
index efa26a31d0..fb837ca3d0 100644
--- a/include/tvm/arith/analyzer.h
+++ b/include/tvm/arith/analyzer.h
@@ -618,6 +618,22 @@ class TVM_DLL Analyzer {
   TransitiveComparisonAnalyzer transitive_comparisons;
   /*! \brief constructor */
   Analyzer();
+  /*!
+   * \brief Mark the value as non-negative globally in the analyzer.
+   *
+   * Only call this function if the non-neg condition is global and
+   * not context-dependent.
+   *
+   * This function does best-effort propagation to the sub-analyzers.
+   *
+   * \note We expose this function because non-negative global values,
+   * such as symbolic buffer shapes in function arguments, are really
+   * important to ensure the best simplification, and usually they
+   * can be handled in a simpler way than the generic constraints.
+   *
+   * This function may call into the Update function of the sub-analyzers.
+   */
+  void MarkGlobalNonNegValue(const PrimExpr& value);
   /*!
    * \brief Notify all the sub-analyzers that var
    *        is created and binded to expr.
diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc
index 722a2cd00e..9e5b1414ed 100644
--- a/src/arith/analyzer.cc
+++ b/src/arith/analyzer.cc
@@ -25,6 +25,7 @@
 #include <tvm/tir/expr.h>
 #include <tvm/tir/op.h>
 
+#include "const_fold.h"
 #include "product_normal_form.h"
 
 namespace tvm {
@@ -63,6 +64,38 @@ void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) {
   // skip rewrite simplify
 }
 
+void Analyzer::MarkGlobalNonNegValue(const PrimExpr& value) {
+  // split out the symbolic part; of the constant factors only the sign matters
+  int64_t cscale = 1;
+  PrimExpr symbolic = tir::make_const(value.dtype(), 1);
+  auto fcollect = [&](PrimExpr val) {
+    if (const auto* intimm = val.as<IntImmNode>()) {
+      cscale *= (intimm->value > 0) - (intimm->value < 0);
+    } else {
+      symbolic = symbolic * val;
+    }
+  };
+  UnpackReduction<tir::MulNode>(value, fcollect);
+  if (cscale <= 0) return;
+  // override the constant int bound by marking it as non-negative
+  // NOTE: there might be future opportunities of more bound hint
+  // this is a simple step and covers all the current needs
+  //
+  // We may consider enhancing the sub-analyzers to directly support
+  // marking non-negative vars so their bounds do not overlap
+  if (const auto* var_ptr = symbolic.as<VarNode>()) {
+    Var var = GetRef<Var>(var_ptr);
+    // skip non-index type, keep it to be compatible
+    // with any_dim that do not represent any value
+    if (!IsIndexType(var.dtype())) return;
+    bool allow_override = true;
+    // mark the constant bound is sufficient
+    // we cannot mark interval set as that will cause relaxation of the var
+    // during bound proof which is not our intention
+    this->const_int_bound.Update(var, ConstIntBound(0, ConstIntBound::kPosInf), allow_override);
+  }
+}
+
 void Analyzer::Bind(const Map<Var, Range>& variables, bool allow_override) {
   for (const auto& iter : variables) {
     this->Bind(iter.first, iter.second, allow_override);
diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc
index 68ade3bb54..8ce5025231 100644
--- a/src/arith/const_int_bound.cc
+++ b/src/arith/const_int_bound.cc
@@ -195,6 +195,31 @@ class ConstIntBoundAnalyzer::Impl
     return Intersect(a, b);
   }
 
+  /*!
+   * \brief Process the divisor by making assumption that divide by zero
+   * won't happen in a valid program.
+   *
+   * This is important for us to get a lot of symbolic shape bounds right,
+   * where most conditions know that the shape n >= 0, but in cases
+   * when mod or divide of n occur, the intention is actually n > 0.
+   *
+   * \param divisor The input divisor entry
+   * \return The processed entry
+   */
+  Entry AssumeNoZeroDivisor(Entry divisor) {
+    ICHECK(!divisor.is_const(0)) << "Find divide by zero";
+    // NOTE: here we make the assumption that
+    // divide by zero won't happen in a valid program
+    // this is important for us to get a lot of symbolic shape bound right
+    // where most conditions know that the shape n >= 0, but in cases
+    // when mod or divide of n occur, the intention is actually n > 0
+    if (divisor.min_value == 0) {
+      divisor.min_value = 1;
+      ICHECK_GE(divisor.max_value, 1);
+    }
+    return divisor;
+  }
+
   Entry VisitExpr_(const IntImmNode* op) final { return MakeBound(op->value, op->value); }
 
   Entry VisitExpr_(const AddNode* op) final {
@@ -223,14 +248,14 @@ class ConstIntBoundAnalyzer::Impl
 
   Entry VisitExpr_(const DivNode* op) final {
     Entry a = VisitExpr(op->a);
-    Entry b = VisitExpr(op->b);
-    ICHECK(!b.is_const(0)) << "divide by zero";
+    Entry b = AssumeNoZeroDivisor(VisitExpr(op->b));
     return HandleDivision(a, b, op->dtype, InfAwareDiv);
   }
 
   Entry VisitExpr_(const ModNode* op) final {
     Entry a = VisitExpr(op->a);
-    Entry b = VisitExpr(op->b);
+    Entry b = AssumeNoZeroDivisor(VisitExpr(op->b));
+
     if (b.min_value > 0) {
       int64_t b_max_cap = InfAwareAdd(b.max_value, -1);
       if (a.min_value >= 0) {
@@ -252,8 +277,7 @@ class ConstIntBoundAnalyzer::Impl
 
   Entry VisitExpr_(const FloorDivNode* op) final {
     Entry a = VisitExpr(op->a);
-    Entry b = VisitExpr(op->b);
-    ICHECK(!b.is_const(0)) << "floordiv by zero";
+    Entry b = AssumeNoZeroDivisor(VisitExpr(op->b));
     return HandleDivision(a, b, op->dtype, InfAwareFloorDiv);
   }
 
@@ -276,7 +300,8 @@ class ConstIntBoundAnalyzer::Impl
      * That is, min(0, b_min + 1) <= floormod(a, b) <= max(0, b_max - 1)
      */
     Entry a = VisitExpr(op->a);
-    Entry b = VisitExpr(op->b);
+    Entry b = AssumeNoZeroDivisor(VisitExpr(op->b));
+
     if (b.min_value > 0) {
       int64_t b_max_cap = InfAwareAdd(b.max_value, -1);
       if (a.min_value >= 0) {
@@ -457,7 +482,6 @@ class ConstIntBoundAnalyzer::Impl
     // at a negative value and ends at a positive one, narrow it down to
     // be closer to 0, because BinaryOpBoundary only checks end-points of
     // the domain ranges.
-
     // If the range of b contains 0, then some infinity will be involved
     if (b.min_value <= 0 && 0 <= b.max_value && dt.is_int()) {
       Entry b_neg = b.min_value < 0 ? MakeBound(b.min_value, -1) : Everything(dt);
diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc
index c201a245e1..1f087d9934 100644
--- a/src/arith/ir_mutator_with_analyzer.cc
+++ b/src/arith/ir_mutator_with_analyzer.cc
@@ -30,6 +30,15 @@ namespace arith {
 
 using namespace tir;
 
+void IRMutatorWithAnalyzer::MarkBufferMapShapes(const tir::PrimFunc& func) {
+  // Mark all the symbolic buffer shape values in the buffer map as non-negative values.
+  for (auto kv : func->buffer_map) {
+    for (PrimExpr shape : kv.second->shape) {
+      analyzer_->MarkGlobalNonNegValue(shape);
+    }
+  }
+}
+
 Stmt IRMutatorWithAnalyzer::VisitStmt_(const ForNode* op) {
   // record the loop variable as iterators
   Range dom = Range::FromMinExtent(op->min, op->extent);
diff --git a/src/arith/ir_mutator_with_analyzer.h b/src/arith/ir_mutator_with_analyzer.h
index ed62c91df9..f04b40e7ae 100644
--- a/src/arith/ir_mutator_with_analyzer.h
+++ b/src/arith/ir_mutator_with_analyzer.h
@@ -62,6 +62,14 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator {
   PrimExpr VisitExpr_(const tir::ReduceNode* op) override;
 
  protected:
+  /*!
+   * \brief Mark all the buffer shape values in the buffer map as non-negative values.
+   *
+   * \note Call this function before visiting the function's body to maximize
+   *       simplification efficiency.
+   */
+  void MarkBufferMapShapes(const tir::PrimFunc& func);
+
   /*! \brief internal analyzer field. */
   Analyzer* analyzer_;
   // the following two fields are useful in case we want
diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc
index ffdff45a7d..f37c21593f 100644
--- a/src/tir/transforms/flatten_buffer.cc
+++ b/src/tir/transforms/flatten_buffer.cc
@@ -42,6 +42,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
     arith::Analyzer ana;
     auto pass = BufferFlattener(&ana);
     auto writer = func.CopyOnWrite();
+    pass.MarkBufferMapShapes(func);
     writer->body = pass.VisitStmt(func->body);
     // The buffers in func->buffer_map are deliberately left
     // unflattened, as they are used for validation of user-provided
diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc
index 130cbe37c1..44d64df63d 100644
--- a/src/tir/transforms/simplify.cc
+++ b/src/tir/transforms/simplify.cc
@@ -142,20 +142,24 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.Simplify", SimplifyConfig);
 
 class StmtSimplifier : public IRMutatorWithAnalyzer {
  public:
-  static Stmt Apply(Stmt stmt, Analyzer* analyzer, Optional<SimplifyConfig> config_opt = NullOpt) {
+  static PrimFunc Apply(PrimFunc func, Analyzer* analyzer,
+                        Optional<SimplifyConfig> config_opt = NullOpt) {
     auto config = config_opt.value_or(AttrsWithDefaultValues<arith::SimplifyConfig>());
     analyzer->rewrite_simplify.SetEnabledExtensions(config->GetEnabledExtensions());
 
     std::optional<ControlFlowGraph> touch_pattern = std::nullopt;
     if (config->propagate_knowns_to_prove_conditional ||
         config->propagate_knowns_to_simplify_expressions) {
-      touch_pattern = ControlFlowGraph(stmt);
+      touch_pattern = ControlFlowGraph(func->body);
     }
 
-    std::unordered_set<const VarNode*> used_in_buffer_def = CollectVarsUsedInBufferDefinition(stmt);
+    std::unordered_set<const VarNode*> used_in_buffer_def =
+        CollectVarsUsedInBufferDefinition(func->body);
     StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern),
                               std::move(used_in_buffer_def));
-    return simplifier(std::move(stmt));
+    simplifier.MarkBufferMapShapes(func);
+    func.CopyOnWrite()->body = simplifier(func->body);
+    return func;
   }
 
  private:
@@ -335,11 +339,6 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
 }  // namespace arith
 
 namespace tir {
-
-Stmt Simplify(Stmt stmt, arith::Analyzer* analyzer) {
-  return arith::StmtSimplifier::Apply(stmt, analyzer);
-}
-
 namespace transform {
 
 Pass Simplify() {
@@ -347,9 +346,7 @@ Pass Simplify() {
     arith::Analyzer analyzer;
     auto cfg = ctx->GetConfig<arith::SimplifyConfig>("tir.Simplify");
 
-    auto* n = f.CopyOnWrite();
-    n->body = arith::StmtSimplifier::Apply(std::move(n->body), &analyzer, cfg);
-    return f;
+    return arith::StmtSimplifier::Apply(f, &analyzer, cfg);
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.Simplify", {});
 }
diff --git a/tests/python/unittest/test_arith_const_int_bound.py b/tests/python/unittest/test_arith_const_int_bound.py
index d9ea36206b..5667c79aac 100644
--- a/tests/python/unittest/test_arith_const_int_bound.py
+++ b/tests/python/unittest/test_arith_const_int_bound.py
@@ -339,6 +339,23 @@ def test_floormod_negative_divisor():
     assert bd.max_value == 6
 
 
+def test_divmod_assume_no_zero_divisor():
+    # Dividing by a non-negative expression assumes that division by zero won't occur;
+    # this assumption is important to get the best result from symbolic shape programs
+    analyzer = tvm.arith.Analyzer()
+    flm, fld = tvm.te.floormod, tvm.te.floordiv
+    a, b = te.var("a"), te.var("b")
+    analyzer.update(a, tvm.arith.ConstIntBound(0, 6))
+    analyzer.update(b, tvm.arith.ConstIntBound(0, tvm.arith.ConstIntBound.POS_INF))
+    bd = analyzer.const_int_bound(fld(a, b))
+    assert bd.min_value == 0
+    assert bd.max_value == 6
+
+    bd = analyzer.const_int_bound(flm(a, b))
+    assert bd.min_value == 0
+    assert bd.max_value == 6
+
+
 def test_multiple_condition():
     analyzer = tvm.arith.Analyzer()
     flm, fld = tvm.te.floormod, tvm.te.floordiv
diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py
index 1f25405ec9..79fd5e1434 100644
--- a/tests/python/unittest/test_tir_transform_simplify.py
+++ b/tests/python/unittest/test_tir_transform_simplify.py
@@ -1733,5 +1733,21 @@ class TestSimplifyTrivialLetStride(BaseBeforeAfter):
     expected = before
 
 
+class TestBufferShapeConstraint(BaseBeforeAfter):
+    """Symbolic buffer shapes are known non-negative, so min(0, n) simplifies to 0"""
+
+    convert_boolean_to_and_of_ors = True
+
+    def before(a: T.handle):
+        n = T.int64()
+        A = T.match_buffer(a, (n * 32,), "float32")
+        A[T.min(T.int64(0), n)] = T.float32(0)
+
+    def expected(a: T.handle):
+        n = T.int64()
+        A = T.match_buffer(a, (n * 32,), "float32")
+        A[T.int64(0)] = T.float32(0)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to