This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a84a2cbe07 [ARITH] Enhance CanProve to handle symbolic bound (#14523)
a84a2cbe07 is described below

commit a84a2cbe07dfe608acd79453affc53330ebab1f0
Author: Tianqi Chen <[email protected]>
AuthorDate: Sat Apr 8 10:22:50 2023 -0400

    [ARITH] Enhance CanProve to handle symbolic bound (#14523)
    
    This PR enhances CanProve to handle symbolic bound.
    Such analysis is essential to eliminate predicates in
    dynamic shape workloads.
    
    We also simplify the int set analysis single-point check to avoid
    recursion and improve the overall analysis speed.
    
    Added CanProveSinglePoint to serve previous stronger checks.
    
    The new CanProve comes with an additional strength argument; levels
    beyond the default should only be used in top-level settings for stronger analysis.
    
    Added comment for future implementation efficiency.
    
    Testcases are added to cover the cases.
---
 include/tvm/arith/analyzer.h                      | 23 +++++++++++-
 include/tvm/arith/int_set.h                       | 16 +++++++++
 python/tvm/arith/__init__.py                      |  2 +-
 python/tvm/arith/analyzer.py                      | 27 ++++++++++++++
 src/arith/analyzer.cc                             | 43 +++++++++++++++++++++--
 src/arith/int_set.cc                              | 24 +++++++++++--
 src/arith/interval_set.h                          |  9 +++--
 src/arith/rewrite_simplify.cc                     |  8 +++++
 src/arith/rewrite_simplify.h                      |  1 -
 src/tir/analysis/block_access_region_detector.cc  |  4 ++-
 src/tir/schedule/primitive/compute_at.cc          |  7 ++--
 src/tir/schedule/primitive/loop_transformation.cc |  2 +-
 tests/python/unittest/test_arith_simplify.py      | 31 ++++++++++++++++
 13 files changed, 178 insertions(+), 19 deletions(-)

diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h
index 885c23f491..e64426aca3 100644
--- a/include/tvm/arith/analyzer.h
+++ b/include/tvm/arith/analyzer.h
@@ -59,6 +59,22 @@ enum DivMode {
   kFloorDiv
 };
 
+/*!
+ * \brief The strength used in top-level condition proves
+ * \note The higher, the more time consuming it can be.
+ *
+ * Do not use level beyond kDefault in internal recursive rewriting in arith
+ * analysis and only use it at top-level simplification to avoid speed issues.
+ */
+enum class ProofStrength : int {
+  /*! \brief default strength, can be used in. */
+  kDefault = 0,
+  /*!
+   * \brief Prove using symbolic bound analysis
+   */
+  kSymbolicBound = 1
+};
+
 /*!
  * \brief Constant integer up and lower bound(inclusive).
  *  Useful for value bound analysis.
@@ -656,11 +672,16 @@ class TVM_DLL Analyzer {
    * \brief Whether can we prove condition.
    *
    * \param cond The expression to be proved.
+   * \param strength the strength of the prove.
+   *
    * \return The result.
    *
    * \note Analyzer will call into sub-analyzers to get the result.
+   * Do not use strength beyond default in sub-analyzers and
+   * only use it in top-level predicate analysis.
    */
-  bool CanProve(const PrimExpr& cond);
+  bool CanProve(const PrimExpr& cond, ProofStrength strength = 
ProofStrength::kDefault);
+
   /*!
    * \brief Simplify expr.
    *
diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h
index 60d7c53d28..f09564d050 100644
--- a/include/tvm/arith/int_set.h
+++ b/include/tvm/arith/int_set.h
@@ -85,6 +85,22 @@ class IntSet : public ObjectRef {
   bool IsEverything() const;
   /*! \return Whether the set is a single point */
   bool IsSinglePoint() const;
+  /*!
+   * \brief Check if we can prove it is a single point.
+   *
+   * Unlike IsSinglePoint, which only checks ptr equality
+   * this function will invoke analyzer to do stronger proofs
+   * but also takes longer time.
+   *
+   * Use this function in some of the primitives but do not
+   * use it in the inner loop of simplification.
+   *
+   * \param ana Analyzer used in the proof.
+   * \return Whether we can prove it is a single point
+   */
+  bool CanProveSinglePoint(Analyzer* ana) const;
+  // TODO(tvm-team): update all CanProve to explicitly take
+  // analyzer to encourage more analyzer reuse
   /*! \return Whether the set is proved to be bigger than 0 */
   bool CanProvePositive() const;
   /*! \return Whether the set is proved to be smaller than 0 */
diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py
index 423aafe5d6..401836aa19 100644
--- a/python/tvm/arith/__init__.py
+++ b/python/tvm/arith/__init__.py
@@ -23,7 +23,7 @@ from .int_set import (
     estimate_region_strict_bound,
     estimate_region_upper_bound,
 )
-from .analyzer import ModularSet, ConstIntBound, Analyzer
+from .analyzer import ModularSet, ConstIntBound, Analyzer, ProofStrength
 from .bound import deduce_bound
 from .pattern import detect_linear_equation, detect_clip_bound, 
detect_common_subexpr
 from .int_solver import solve_linear_equations, solve_linear_inequalities
diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py
index 28adbe9d81..5ea2dfad9d 100644
--- a/python/tvm/arith/analyzer.py
+++ b/python/tvm/arith/analyzer.py
@@ -15,11 +15,19 @@
 # specific language governing permissions and limitations
 # under the License.
 """Arithmetic data structure and utility"""
+from enum import IntEnum
 import tvm._ffi
 from tvm.runtime import Object
 from . import _ffi_api
 
 
+class ProofStrength(IntEnum):
+    """Proof strength of the analysis"""
+
+    DEFAULT = 0
+    SYMBOLIC_BOUND = 1
+
+
 @tvm._ffi.register_object("arith.ModularSet")
 class ModularSet(Object):
     """Represent range of (coeff * x + base) for x in Z"""
@@ -91,6 +99,7 @@ class Analyzer:
         self._int_set = _mod("int_set")
         self._enter_constraint_context = _mod("enter_constraint_context")
         self._can_prove_equal = _mod("can_prove_equal")
+        self._can_prove = _mod("can_prove")
 
     def const_int_bound(self, expr):
         """Find constant integer bound for expr.
@@ -190,6 +199,24 @@ class Analyzer:
         """
         return self._int_set(expr, dom_map)
 
+    def can_prove(self, expr, strength=ProofStrength.DEFAULT):
+        """Check whether we can prove expr to be true.
+
+        Parameters
+        ----------
+        expr : PrimExpr
+            The expression.
+
+        strength: ProofStrength
+            The proof strength
+
+        Returns
+        -------
+        result : bool
+            The result.
+        """
+        return self._can_prove(expr, strength)
+
     def bind(self, var, expr):
         """Bind a variable to the expression.
 
diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc
index 4714cf1df5..89dcb8301a 100644
--- a/src/arith/analyzer.cc
+++ b/src/arith/analyzer.cc
@@ -115,15 +115,47 @@ bool Analyzer::CanProveEqual(const PrimExpr& lhs, const 
PrimExpr& rhs) {
   return CanProve(lhs - rhs == 0);
 }
 
-bool Analyzer::CanProve(const PrimExpr& expr) {
+bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) {
   // Avoid potentially expensive simplification unless required.
   if (const auto* ptr = expr.as<IntImmNode>()) {
     return ptr->value != 0;
   }
-
   PrimExpr simplified = Simplify(expr);
   const int64_t* as_int = tir::as_const_int(simplified);
-  return as_int && *as_int;
+  if (as_int && *as_int) return true;
+  if (strength >= ProofStrength::kSymbolicBound) {
+    // NOTE: we intentionally only pattern match common bound predicate i < 
bound
+    // and put this implementation at the top-level.
+    // This is to avoid repetitive calling of this function
+    // that causes speed issues.
+    // This strategy can only be called from top-level and not from 
sub-analyzers.
+    Optional<PrimExpr> pos_diff;
+    int lower_bound = 0;
+    if (const auto* ptr_lt = expr.as<tir::LTNode>()) {
+      pos_diff = ptr_lt->b - ptr_lt->a;
+      lower_bound = 1;
+    }
+    if (const auto* ptr_le = expr.as<tir::LENode>()) {
+      pos_diff = ptr_le->b - ptr_le->a;
+      lower_bound = 0;
+    }
+    if (const auto* ptr_gt = expr.as<tir::GTNode>()) {
+      pos_diff = ptr_gt->a - ptr_gt->b;
+      lower_bound = 1;
+    }
+    if (const auto* ptr_ge = expr.as<tir::GENode>()) {
+      pos_diff = ptr_ge->a - ptr_ge->b;
+      lower_bound = 0;
+    }
+    if (pos_diff) {
+      IntSet iset = this->int_set(this->Simplify(pos_diff.value()));
+      if (iset.HasLowerBound()) {
+        ConstIntBound relaxed_lower_bound = 
this->const_int_bound(this->Simplify(iset.min()));
+        if (relaxed_lower_bound->min_value >= lower_bound) return true;
+      }
+    }
+  }
+  return false;
 }
 
 PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) {
@@ -189,6 +221,11 @@ 
TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValu
           self->Bind(args[0], args[1].operator PrimExpr());
         }
       });
+    } else if (name == "can_prove") {
+      return PackedFunc([self](TVMArgs args, TVMRetValue* ret) {
+        int strength = args[1];
+        *ret = self->CanProve(args[0], static_cast<ProofStrength>(strength));
+      });
     } else if (name == "enter_constraint_context") {
       return PackedFunc([self](TVMArgs args, TVMRetValue* ret) {
         // can't use make_shared due to noexcept(false) decl in destructor,
diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc
index a75d316a7e..1ad182aa83 100644
--- a/src/arith/int_set.cc
+++ b/src/arith/int_set.cc
@@ -492,6 +492,11 @@ class IntervalSetEvaluator : public 
ExprFunctor<IntervalSet(const PrimExpr&)> {
 
   IntervalSet VisitExpr_(const CastNode* op) final {
     IntervalSet value_set = this->Eval(op->value);
+    // short cut for the int set.
+    if (value_set->min_value.same_as(value_set->max_value)) {
+      if (value_set->IsEmpty()) return value_set;
+      return IntervalSet::SinglePoint(cast(op->dtype, value_set->min_value));
+    }
     PrimExpr min_value =
         value_set->HasLowerBound() ? cast(op->dtype, value_set->min_value) : 
neg_inf();
     PrimExpr max_value =
@@ -723,6 +728,13 @@ bool IntSet::IsSinglePoint() const {
   return (s_int && s_int->IsSinglePoint());
 }
 
+bool IntSet::CanProveSinglePoint(Analyzer* ana) const {
+  const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
+  if (!s_int) return false;
+  if (s_int->IsSinglePoint()) return true;
+  return ana->CanProveEqual(s_int->min_value, s_int->max_value);
+}
+
 bool IntSet::CanProvePositive() const {
   Analyzer analyzer;
   const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
@@ -943,9 +955,15 @@ IntSet EvalSet(PrimExpr e, const Map<Var, IntSet>& 
dom_map) {
 }
 
 IntSet IntSet::Vector(PrimExpr x) {
-  Analyzer ana;
-  Map<Var, IntSet> dmap;
-  return IntervalSetEvaluator(&ana, dmap, {}, true).Eval(x);
+  // short cut: simply get single point
+  if (x.dtype().lanes() == 1) {
+    return IntSet::SinglePoint(x);
+  } else {
+    // vector case.
+    Analyzer ana;
+    Map<Var, IntSet> dmap;
+    return IntervalSetEvaluator(&ana, dmap, {}, true).Eval(x);
+  }
 }
 
 IntSet EvalSet(PrimExpr e, const Map<IterVar, IntSet>& dom_map) {
diff --git a/src/arith/interval_set.h b/src/arith/interval_set.h
index 98fe5bdc2b..dc40fa9d4d 100644
--- a/src/arith/interval_set.h
+++ b/src/arith/interval_set.h
@@ -60,12 +60,11 @@ class IntervalSetNode : public IntSetNode {
   bool HasLowerBound() const { return !is_neg_inf(min_value) && !IsEmpty(); }
   /*! \return Whether the interval is a single point. */
   bool IsSinglePoint() const {
-    if (min_value.same_as(max_value)) {
-      return true;
-    }
-    Analyzer analyzer;
-    return analyzer.CanProveEqual(min_value, max_value);
+    // NOTE: we are only doing cheap check as this is a frequently called 
routine,
+    // do manual prove of min and max for stronger single point check.
+    return min_value.same_as(max_value);
   }
+
   /*! \return whether interval represent nothing */
   bool IsEmpty() const {
     // during computations, either extreme could occur.
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index e44ef31da1..528a272ef4 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -168,6 +168,12 @@ CompareResult 
RewriteSimplifier::Impl::TryCompareUsingKnownInequalities(const Pr
 
 // try to prove x equals val
 CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, int64_t 
val) {
+  // NOTE on implementation: this function can be called many times and can be 
a bottleneck,
+  // As a result, we keep comparison here lightweight.
+  // We only do constant int bound analysis here.
+  //
+  // For stronger comparison proof that is out of the recursive simplification
+  // consider using Analyzer::CanProve with a higher ProofStrength
   PrimExpr diff = this->VisitExpr(x);
   if (const auto* ptr = diff.as<IntImmNode>()) {
     if (ptr->value == val) {
@@ -194,6 +200,8 @@ CompareResult RewriteSimplifier::Impl::TryCompare(const 
PrimExpr& x, int64_t val
   if (dbound->max_value <= val) {
     return CompareResult::kLE;
   }
+
+  // modular analysis
   if (val == 0) {
     ModularSet dmod = analyzer_->modular_set(diff);
     if (dmod->base != 0) {
diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h
index b8e7fcdd94..22e7a0b74c 100644
--- a/src/arith/rewrite_simplify.h
+++ b/src/arith/rewrite_simplify.h
@@ -104,7 +104,6 @@ class RewriteSimplifier::Impl : public 
IRMutatorWithAnalyzer {
 
   // maximum number of recursion allowed during a single pass.
   static const constexpr int kMaxRecurDepth = 5;
-
   /*!
    * \brief try to compare x against val.
    * \param x The expression to be evaluated.
diff --git a/src/tir/analysis/block_access_region_detector.cc 
b/src/tir/analysis/block_access_region_detector.cc
index 409356c2b1..057cec475d 100644
--- a/src/tir/analysis/block_access_region_detector.cc
+++ b/src/tir/analysis/block_access_region_detector.cc
@@ -76,6 +76,8 @@ class BlockReadWriteDetector : public StmtExprVisitor {
   Map<Var, Buffer> buffer_var_map_;
   /*! \brief The target buffer var mapping to its matching */
   std::unordered_map<const VarNode*, MatchBufferRegion> match_buffers_;
+  /*! \brief Internal analyzer. */
+  arith::Analyzer ana_;
 
   /*!
    * \brief Update read/write buffers and regions with provided buffer and 
region
@@ -318,7 +320,7 @@ Array<BufferRegion> BlockReadWriteDetector::CollectRegions(
     ICHECK_EQ(buffers[i]->shape.size(), regions[i].size());
     for (size_t j = 0; j < regions[i].size(); j++) {
       const tvm::arith::IntSet& range = regions[i][j];
-      if (range.IsSinglePoint()) {
+      if (range.CanProveSinglePoint(&ana_)) {
         PrimExpr min = range.min();
         region.push_back(Range::FromMinExtent(min, make_const(min.dtype(), 
1)));
       } else {
diff --git a/src/tir/schedule/primitive/compute_at.cc 
b/src/tir/schedule/primitive/compute_at.cc
index 988c73c3f0..75ea308de8 100644
--- a/src/tir/schedule/primitive/compute_at.cc
+++ b/src/tir/schedule/primitive/compute_at.cc
@@ -455,8 +455,9 @@ void UpdateBlockVarDomainDimwise(
     arith::IntSet required = required_region[i];
     PrimExpr dim_max = max(buffer->shape[i] - 1, 0);
 
-    if (provided.IsSinglePoint() && is_const_int(provided.min())) {
-      ICHECK(required.IsSinglePoint() && 
analyzer->CanProveEqual(provided.min(), required.min()));
+    if (provided.CanProveSinglePoint(analyzer) && 
is_const_int(provided.min())) {
+      ICHECK(required.CanProveSinglePoint(analyzer) &&
+             analyzer->CanProveEqual(provided.min(), required.min()));
       continue;
     }
 
@@ -515,7 +516,7 @@ bool UpdateBlockVarDomainAffine(const BufferNode* buffer, 
const Array<IterVar>&
                                 std::unordered_map<const VarNode*, 
BlockVarDomainInfo>* iter_doms) {
   // we only support single point provided region now, which could cover most 
cases
   for (const auto& intset : provided_region) {
-    if (!intset.IsSinglePoint()) return false;
+    if (!intset.CanProveSinglePoint(analyzer)) return false;
   }
   // calculate forward mapping (block vars -> provided region point)
   Map<Var, Range> dom_map;
diff --git a/src/tir/schedule/primitive/loop_transformation.cc 
b/src/tir/schedule/primitive/loop_transformation.cc
index d9c58a0381..a26843b7bd 100644
--- a/src/tir/schedule/primitive/loop_transformation.cc
+++ b/src/tir/schedule/primitive/loop_transformation.cc
@@ -430,7 +430,7 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& 
loop_sref, const Array
       &opaque_block_reuse)(std::move(new_stmt));
   // Step 3. Update predicate to guard the loop
   PrimExpr predicate = substitute_value < loop->extent;
-  if (!analyzer.CanProve(predicate)) {
+  if (!analyzer.CanProve(predicate, arith::ProofStrength::kSymbolicBound)) {
     new_stmt = 
BlockPredicateAppender(/*predicate=*/predicate)(std::move(new_stmt));
   }
   // Step 4. Generate nested loops to replace the original loop and simplify 
the binding
diff --git a/tests/python/unittest/test_arith_simplify.py 
b/tests/python/unittest/test_arith_simplify.py
index aa9d5179aa..754bf36d7a 100644
--- a/tests/python/unittest/test_arith_simplify.py
+++ b/tests/python/unittest/test_arith_simplify.py
@@ -34,5 +34,36 @@ def test_simplify_reshape_flattened_index():
     )
 
 
+def test_simplify_symbolic_comparison():
+    ana = tvm.arith.Analyzer()
+
+    i0 = tir.Var("i0", "int64")
+    i1 = tir.Var("i1", "int64")
+    n, m = tvm.tir.SizeVar("n", "int64"), tvm.tir.SizeVar("m", "int64")
+    outer = (n + 31) // 32
+    ana.bind(i0, tvm.ir.Range(0, outer))
+    ana.bind(i1, tvm.ir.Range(0, 32))
+    PS = tvm.arith.ProofStrength
+
+    assert not ana.can_prove(i0 * 32 + i1 < (n + 31) // 32 * 32, PS.DEFAULT)
+    assert ana.can_prove(i0 * 32 + i1 < (n + 31) // 32 * 32, PS.SYMBOLIC_BOUND)
+    assert ana.can_prove(i0 * 32 + i1 < (n + 31) // 32 * 32 + m, 
PS.SYMBOLIC_BOUND)
+    assert ana.can_prove(i0 * 32 + i1 + 1 <= (n + 31) // 32 * 32, 
PS.SYMBOLIC_BOUND)
+    assert ana.can_prove((n + 31) // 32 * 32 >= i0 * 32 + i1 + 1, 
PS.SYMBOLIC_BOUND)
+    assert ana.can_prove((n + 31) // 32 * 32 >= i0 * 32 + i1, 
PS.SYMBOLIC_BOUND)
+
+
+def test_regression_simplify_inf_recursion():
+    ana = tvm.arith.Analyzer()
+    cond = tir.Var("cond", "int32")
+
+    res = (tvm.tir.NE(cond, 0).astype("int8") - tvm.tir.NE(cond, 
0).astype("int8")).astype(
+        "int32"
+    ) == 0
+    # regression in a previous case
+    # try compare and int set recursive call can cause infinite loop
+    ana.rewrite_simplify(res)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to