This is an automated email from the ASF dual-hosted git repository.

tkonolige pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4b5dd136d7  [Arith] Updated BufferDomainTouched to use 
IRVisitorWithAnalyzer (#11970)
4b5dd136d7 is described below

commit 4b5dd136d764fbef5f552ffce0759232c138e4e2
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Jul 13 11:17:53 2022 -0500

     [Arith] Updated BufferDomainTouched to use IRVisitorWithAnalyzer (#11970)
    
    * [Arith] Allow binding of Var in IntSetAnalyzer
    
    The other four subanalyzers in `arith::Analyzer` can each be provided
    with variable bindings/constraints that are remembered internally.
    This adds the same capability to `IntSetAnalyzer`, rather than
    requiring users to independently track and maintain a `Map<Var,
    IntSet>` containing the domain of each variable, and applies
    bindings/constraints alongside the other subanalyzers.
    
    * [Arith] Updated IRVisitorWithAnalyzer to mimic IRMutatorWithAnalyzer
    
    Previously, `IRVisitorWithAnalyzer` did not allow subclassing, and
    could only be used to collect bounds of variables along an entire
    statement, and could not be used to perform scope-dependent analysis.
    This commit removes `final` from `IRVisitorWithAnalyzer` and provides
    the same scope-based constraints/bindings during iteration as are
    provided by `IRMutatorWithAnalyzer`.
    
    * [Arith] Moved IRVisitorWithAnalyzer to tvm::arith namespace
    
    Changing for consistency, since `IRVisitorWithAnalyzer` is part of
    the `src/arith` directory and the analogous `IRMutatorWithAnalyzer` is
    already part of the `arith` namespace.
    
    * [Arith] Updated BufferDomainTouched to use IRVisitorWithAnalyzer
    
    This uses the earlier changes that allow subclasses of
    `IRVisitorWithAnalyzer` and expose bindings/constraints to
    `IntSetAnalyzer`.
    
    * Avoid accidental Bind with dynamic Range
    
    * [Arith] Do not visit SelectNode in IRVisitorWithAnalyzer
    
    Because both sides of a `Select` node are visited regardless of the
    condition, the `SelectNode::condition` should not be treated as a
    known value.
    
    * [Arith][IntSet] Track global and scope-dependent bounds separately
    
    Resolves a bug that was found in CI, where an earlier scope-dependent
    constraint was treated as a conflict by a later global bound.
    
    * [Arith] Recovery function for each subanalyzer
    
    This way, if a subanalyzer throws an exception during
    `EnterConstraint`, the other subanalyzers are still appropriately
    backed out of the constraint.
    
    * [Arith][IntSet] Use CanProve instead of CanProveGreaterEqual
    
    The `min_value - max_value` in the `CanProveGreaterEqual` argument can
    result in an exception being thrown for unsigned integers where
    subtraction would wrap.
    
    * [Arith] Allow vector expressions in IntSet::operator(PrimExpr)
    
    Since these are tracked when lowering expressions, we should allow
    post-vectorization expressions.
    
    To maintain previous behavior, this only applies when using the
    automatically tracked `Map<Var, IntSet> dom_map_`.  If an explicit
    domain map is passed, the previous behavior of raising an error for
    vectorized expressions still occurs.
    
    * Avoid comparisons between integer and handle datatypes
    
    * [Arith] IntSet, Combine() extension
    
    Previously, for boolean operators, the Combine() method didn't
    handle values without a known lower bound.
    
    * Added docstring
    
    * Naming consistency of `IntSetAnalyzer` methods.
    
    To be consistent with other subanalyzers, use "Update" when
    providing the analyzer with the same data structure as is used
    internally, and "Bind" when providing it with something that must
    be converted to the internal data structure.
---
 include/tvm/arith/analyzer.h          |  46 +++++++-
 src/arith/analyzer.cc                 |  26 +++--
 src/arith/domain_touched.cc           |  43 ++-----
 src/arith/int_set.cc                  | 211 +++++++++++++++++++++++++++++++---
 src/arith/ir_visitor_with_analyzer.cc | 126 ++++++++++++++++++++
 src/arith/ir_visitor_with_analyzer.h  |  45 ++++----
 src/tir/transforms/storage_flatten.cc |   1 +
 src/tir/transforms/texture_flatten.cc |   1 +
 8 files changed, 409 insertions(+), 90 deletions(-)

diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h
index 3704eff33e..ceb9f574f2 100644
--- a/include/tvm/arith/analyzer.h
+++ b/include/tvm/arith/analyzer.h
@@ -135,7 +135,7 @@ class ConstIntBoundAnalyzer {
    *
    * \param var The variable of interest.
    * \param info The bound information.
-   * \param allow_override Whether do we allow override of existing 
information.
+   * \param allow_override whether we allow override of existing information.
    */
   TVM_DLL void Update(const Var& var, const ConstIntBound& info, bool 
allow_override = false);
   /*!
@@ -224,7 +224,7 @@ class ModularSetAnalyzer {
    *
    * \param var The variable of interest.
    * \param info The bound information.
-   * \param allow_override Whether do we allow override of existing 
information.
+   * \param allow_override whether we allow override of existing information.
    */
   TVM_DLL void Update(const Var& var, const ModularSet& info, bool 
allow_override = false);
 
@@ -263,10 +263,16 @@ class RewriteSimplifier {
    *
    * \param var The variable of interest.
    * \param new_expr
-   * \param allow_override Whether do we allow override of existing 
information.
+   * \param allow_override Whether we allow override of existing information.
    */
   TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool 
allow_override = false);
 
+  /*!
+   * \brief Update the internal state to enter constraint.
+   * \param constraint A constraint expression.
+   *
+   * \return an exit function that must be called to cleanup the constraint 
can be nullptr.
+   */
   std::function<void()> EnterConstraint(const PrimExpr& constraint);
 
  private:
@@ -297,7 +303,7 @@ class CanonicalSimplifier {
    *
    * \param var The variable of interest.
    * \param new_expr
-   * \param allow_override Whether do we allow override of existing 
information.
+   * \param allow_override whether we allow override of existing information.
    */
   TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool 
allow_override = false);
 
@@ -347,7 +353,7 @@ class ConstraintContext {
   /*! \brief The constraint */
   PrimExpr constraint_;
   /*! \brief function to be called in recovery */
-  std::function<void()> exit_;
+  std::vector<std::function<void()>> recovery_functions_;
 };
 
 /*!
@@ -365,6 +371,36 @@ class IntSetAnalyzer {
    */
   TVM_DLL IntSet operator()(const PrimExpr& expr, const Map<Var, IntSet>& 
dom_map);
 
+  /*!
+   * \brief Find a symbolic integer set that contains all possible
+   *        values of expr given the domain of each variables, using
+   *        the domain map defined by bound variables.
+   *
+   * \param expr The expression of interest.
+   * \return the result of the analysis.
+   */
+  TVM_DLL IntSet operator()(const PrimExpr& expr);
+
+  /*!
+   * \brief Update binding of var to a new expression.
+   *
+   * \param var The variable of interest.
+   * \param new_interval_set The set of allowed values for this var.
+   * \param allow_override whether we allow override of existing information.
+   */
+  TVM_DLL void Update(const Var& var, const IntSet& new_interval_set, bool 
allow_override = false);
+
+  /*!
+   * \brief Update binding of var to a new expression.
+   *
+   * \param var The variable of interest.
+   * \param new_range The range of allowed values for this var.
+   * \param allow_override whether we allow override of existing information.
+   */
+  TVM_DLL void Bind(const Var& var, const Range& new_range, bool 
allow_override = false);
+
+  std::function<void()> EnterConstraint(const PrimExpr& constraint);
+
  private:
   friend class Analyzer;
   explicit IntSetAnalyzer(Analyzer* parent);
diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc
index b922138057..f32c9b2ff4 100644
--- a/src/arith/analyzer.cc
+++ b/src/arith/analyzer.cc
@@ -44,6 +44,7 @@ void Analyzer::Bind(const Var& var, const PrimExpr& expr, 
bool allow_override) {
   this->modular_set.Update(var, this->modular_set(new_expr), allow_override);
   this->rewrite_simplify.Update(var, new_expr, allow_override);
   this->canonical_simplify.Update(var, new_expr, allow_override);
+  this->int_set.Update(var, this->int_set(new_expr), allow_override);
 }
 
 void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) {
@@ -52,6 +53,7 @@ void Analyzer::Bind(const Var& var, const Range& range, bool 
allow_override) {
     this->Bind(var, range->min, allow_override);
   } else {
     this->const_int_bound.Bind(var, range, allow_override);
+    this->int_set.Bind(var, range, allow_override);
   }
   // skip modular_set
   // skip rewrite simplify
@@ -64,22 +66,22 @@ void Analyzer::Bind(const Map<Var, Range>& variables, bool 
allow_override) {
 }
 
 void ConstraintContext::EnterWithScope() {
-  ICHECK(exit_ == nullptr);
+  ICHECK(recovery_functions_.size() == 0);
   // entering the scope.
-  auto f0 = analyzer_->const_int_bound.EnterConstraint(constraint_);
-  auto f1 = analyzer_->modular_set.EnterConstraint(constraint_);
-  auto f2 = analyzer_->rewrite_simplify.EnterConstraint(constraint_);
-  // recovery function.
-  exit_ = [f0, f1, f2]() {
-    if (f2 != nullptr) f2();
-    if (f1 != nullptr) f1();
-    if (f0 != nullptr) f0();
-  };
+  
recovery_functions_.push_back(analyzer_->const_int_bound.EnterConstraint(constraint_));
+  
recovery_functions_.push_back(analyzer_->modular_set.EnterConstraint(constraint_));
+  
recovery_functions_.push_back(analyzer_->rewrite_simplify.EnterConstraint(constraint_));
+  
recovery_functions_.push_back(analyzer_->int_set.EnterConstraint(constraint_));
 }
 
 void ConstraintContext::ExitWithScope() {
-  ICHECK(exit_ != nullptr);
-  exit_();
+  while (recovery_functions_.size()) {
+    auto& func = recovery_functions_.back();
+    if (func) {
+      func();
+    }
+    recovery_functions_.pop_back();
+  }
 }
 
 bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) 
{
diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc
index 403ea47f4e..d2c5d79a09 100644
--- a/src/arith/domain_touched.cc
+++ b/src/arith/domain_touched.cc
@@ -30,6 +30,8 @@
 #include <unordered_map>
 #include <unordered_set>
 
+#include "ir_visitor_with_analyzer.h"
+
 namespace tvm {
 namespace arith {
 
@@ -56,7 +58,7 @@ using BufferDomainAccess = std::tuple<LoadAccess, 
StoreAccess, CombinedAccess>;
 }  // namespace
 
 // Find Read region of the tensor in the stmt.
-class BufferTouchedDomain final : public StmtExprVisitor {
+class BufferTouchedDomain final : public IRVisitorWithAnalyzer {
  public:
   BufferTouchedDomain(const Stmt& stmt) { operator()(stmt); }
 
@@ -90,39 +92,17 @@ class BufferTouchedDomain final : public StmtExprVisitor {
     return ret;
   }
 
-  void VisitStmt_(const ForNode* op) final {
-    const VarNode* var = op->loop_var.get();
-    dom_map_[var] = IntSet::FromRange(Range::FromMinExtent(op->min, 
op->extent));
-    StmtExprVisitor::VisitStmt_(op);
-    dom_map_.erase(var);
-  }
-
-  void VisitStmt_(const LetStmtNode* op) final {
-    dom_map_[op->var.get()] = arith::EvalSet(op->value, dom_map_);
-    StmtExprVisitor::VisitStmt_(op);
-    dom_map_.erase(op->var.get());
-  }
-
-  /* TODO: Thread extent unitest not generated.*/
-  void VisitStmt_(const AttrStmtNode* op) final {
-    if (op->attr_key == tir::attr::thread_extent) {
-      const IterVarNode* thread_axis = op->node.as<IterVarNode>();
-      ICHECK(thread_axis);
-      const VarNode* var = thread_axis->var.get();
-      dom_map_[var] = IntSet::FromRange(Range(make_zero(op->value.dtype()), 
op->value));
-      StmtExprVisitor::VisitStmt_(op);
-      dom_map_.erase(var);
-    } else {
-      StmtExprVisitor::VisitStmt_(op);
-    }
-  }
+ private:
+  using Parent = IRVisitorWithAnalyzer;
+  using Parent::VisitExpr_;
+  using Parent::VisitStmt_;
 
   void VisitExpr_(const BufferLoadNode* op) final {
     // Record load-exclusive buffer access
     Touch(&std::get<LoadAccess>(buffer_access_map_[op->buffer.get()]).set, 
op->indices);
     // Record load-store inclusive buffer access
     Touch(&std::get<CombinedAccess>(buffer_access_map_[op->buffer.get()]).set, 
op->indices);
-    StmtExprVisitor::VisitExpr_(op);
+    Parent::VisitExpr_(op);
   }
 
   void VisitStmt_(const BufferStoreNode* op) final {
@@ -130,11 +110,11 @@ class BufferTouchedDomain final : public StmtExprVisitor {
     Touch(&std::get<StoreAccess>(buffer_access_map_[op->buffer.get()]).set, 
op->indices);
     // Record load-store inclusive buffer access
     Touch(&std::get<CombinedAccess>(buffer_access_map_[op->buffer.get()]).set, 
op->indices);
-    StmtExprVisitor::VisitStmt_(op);
+    Parent::VisitStmt_(op);
   }
 
  private:
-  void Touch(BufferTouches* bounds, const Array<PrimExpr>& args) const {
+  void Touch(BufferTouches* bounds, const Array<PrimExpr>& args) {
     if (args.size() > bounds->size()) {
       bounds->resize(args.size());
     }
@@ -142,13 +122,12 @@ class BufferTouchedDomain final : public StmtExprVisitor {
       if (args[i].as<RampNode>()) {
         (*bounds)[i].emplace_back(IntSet::Vector(args[i]));
       } else {
-        (*bounds)[i].emplace_back(EvalSet(args[i], dom_map_));
+        (*bounds)[i].emplace_back(analyzer_.int_set(args[i]));
       }
     }
   }
 
   std::unordered_map<const BufferNode*, BufferDomainAccess> buffer_access_map_;
-  std::unordered_map<const VarNode*, IntSet> dom_map_;
 };
 
 Region DomainTouched(const Stmt& stmt, const Buffer& buffer, bool 
consider_loads,
diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc
index 48fae479b0..6d48ad1ed1 100644
--- a/src/arith/int_set.cc
+++ b/src/arith/int_set.cc
@@ -31,6 +31,7 @@
 #include <unordered_map>
 #include <utility>
 
+#include "constraint_extract.h"
 #include "interval_set.h"
 #include "pattern_match.h"
 
@@ -63,7 +64,7 @@ IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, 
IntervalSet b) {
   PrimExpr min_value = max(a->min_value, b->min_value);
   if ((max_value.dtype().is_int() || max_value.dtype().is_uint()) &&
       (min_value.dtype().is_int() || min_value.dtype().is_uint()) &&
-      analyzer->CanProveGreaterEqual(min_value - max_value, 1)) {
+      analyzer->CanProve(max_value < min_value)) {
     return IntervalSet::Empty();
   } else {
     return IntervalSet(min_value, max_value);
@@ -105,14 +106,14 @@ TVM_DECLARE_LOGICAL_OP(Not);
  * \note this can possibly relax the set.
  */
 template <typename Op>
-inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
+inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, 
DataType dtype) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
     PrimExpr res = TryConstFold<Op>(a->min_value, b->min_value);
     if (!res.defined()) res = Op(a->min_value, b->min_value);
     return IntervalSet::SinglePoint(res);
   }
   if (is_logical_op<Op>::value) {
-    return IntervalSet(make_const(a->min_value.dtype(), 0), 
make_const(a->min_value.dtype(), 1));
+    return IntervalSet(make_const(dtype, 0), make_const(dtype, 1));
   }
   if (a->IsEmpty()) return a;
   if (b->IsEmpty()) return b;
@@ -122,7 +123,8 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet 
a, IntervalSet b) {
 }
 
 template <>
-inline IntervalSet Combine<tir::Add>(Analyzer* analyer, IntervalSet a, 
IntervalSet b) {
+inline IntervalSet Combine<tir::Add>(Analyzer* analyer, IntervalSet a, 
IntervalSet b,
+                                     DataType /* dtype */) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
     return IntervalSet::SinglePoint(a->min_value + b->min_value);
   }
@@ -136,7 +138,8 @@ inline IntervalSet Combine<tir::Add>(Analyzer* analyer, 
IntervalSet a, IntervalS
 }
 
 template <>
-inline IntervalSet Combine<tir::Sub>(Analyzer* analyer, IntervalSet a, 
IntervalSet b) {
+inline IntervalSet Combine<tir::Sub>(Analyzer* analyer, IntervalSet a, 
IntervalSet b,
+                                     DataType /* dtype */) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
     return IntervalSet::SinglePoint(a->min_value - b->min_value);
   }
@@ -150,7 +153,8 @@ inline IntervalSet Combine<tir::Sub>(Analyzer* analyer, 
IntervalSet a, IntervalS
 }
 
 template <>
-inline IntervalSet Combine<tir::Mul>(Analyzer* analyzer, IntervalSet a, 
IntervalSet b) {
+inline IntervalSet Combine<tir::Mul>(Analyzer* analyzer, IntervalSet a, 
IntervalSet b,
+                                     DataType /* dtype */) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
     return IntervalSet::SinglePoint(a->min_value * b->min_value);
   }
@@ -183,7 +187,8 @@ inline IntervalSet Combine<tir::Mul>(Analyzer* analyzer, 
IntervalSet a, Interval
 }
 
 template <>
-inline IntervalSet Combine<tir::Div>(Analyzer* analyzer, IntervalSet a, 
IntervalSet b) {
+inline IntervalSet Combine<tir::Div>(Analyzer* analyzer, IntervalSet a, 
IntervalSet b,
+                                     DataType /* dtype */) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
     return IntervalSet::SinglePoint(a->min_value / b->min_value);
   }
@@ -216,7 +221,8 @@ inline IntervalSet Combine<tir::Div>(Analyzer* analyzer, 
IntervalSet a, Interval
 }
 
 template <>
-inline IntervalSet Combine<tir::Mod>(Analyzer* analyzer, IntervalSet a, 
IntervalSet b) {
+inline IntervalSet Combine<tir::Mod>(Analyzer* analyzer, IntervalSet a, 
IntervalSet b,
+                                     DataType /* dtype */) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
     return IntervalSet::SinglePoint(truncmod(a->min_value, b->min_value));
   }
@@ -244,7 +250,8 @@ inline IntervalSet Combine<tir::Mod>(Analyzer* analyzer, 
IntervalSet a, Interval
 }
 
 template <>
-inline IntervalSet Combine<tir::FloorDiv>(Analyzer* analyzer, IntervalSet a, 
IntervalSet b) {
+inline IntervalSet Combine<tir::FloorDiv>(Analyzer* analyzer, IntervalSet a, 
IntervalSet b,
+                                          DataType /* dtype */) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
     return IntervalSet::SinglePoint(floordiv(a->min_value, b->min_value));
   }
@@ -277,7 +284,8 @@ inline IntervalSet Combine<tir::FloorDiv>(Analyzer* 
analyzer, IntervalSet a, Int
 }
 
 template <>
-inline IntervalSet Combine<tir::FloorMod>(Analyzer* analyzer, IntervalSet a, 
IntervalSet b) {
+inline IntervalSet Combine<tir::FloorMod>(Analyzer* analyzer, IntervalSet a, 
IntervalSet b,
+                                          DataType /* dtype */) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
     return IntervalSet::SinglePoint(floormod(a->min_value, b->min_value));
   }
@@ -294,7 +302,10 @@ inline IntervalSet Combine<tir::FloorMod>(Analyzer* 
analyzer, IntervalSet a, Int
         // a mod b = a - (a / b) * b if a_max / b == a_min / b
         auto qmax = a->HasUpperBound() ? floordiv(a->max_value, divisor) : 
pos_inf();
         auto qmin = a->HasLowerBound() ? floordiv(a->min_value, divisor) : 
neg_inf();
-        if (analyzer->CanProve(qmax == qmin)) {
+        // We can compare +/- inf against each other, but cannot use
+        // operator== between the symbolic limits and an integer.
+        bool compatible_dtypes = !(qmin.dtype().is_handle() ^ 
qmax.dtype().is_handle());
+        if (compatible_dtypes && analyzer->CanProve(qmax == qmin)) {
           auto tmax = a->max_value - divisor * qmin;
           auto tmin = a->min_value - divisor * qmin;
           return IntervalSet(tmin, tmax);
@@ -311,7 +322,8 @@ inline IntervalSet Combine<tir::FloorMod>(Analyzer* 
analyzer, IntervalSet a, Int
 }
 
 template <>
-inline IntervalSet Combine<tir::Max>(Analyzer* analzyer, IntervalSet a, 
IntervalSet b) {
+inline IntervalSet Combine<tir::Max>(Analyzer* analzyer, IntervalSet a, 
IntervalSet b,
+                                     DataType /* dtype */) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
     return IntervalSet::SinglePoint(max(a->min_value, b->min_value));
   }
@@ -321,7 +333,8 @@ inline IntervalSet Combine<tir::Max>(Analyzer* analzyer, 
IntervalSet a, Interval
 }
 
 template <>
-inline IntervalSet Combine<tir::Min>(Analyzer* analzyer, IntervalSet a, 
IntervalSet b) {
+inline IntervalSet Combine<tir::Min>(Analyzer* analzyer, IntervalSet a, 
IntervalSet b,
+                                     DataType /* dtype */) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
     return IntervalSet::SinglePoint(min(a->min_value, b->min_value));
   }
@@ -423,10 +436,12 @@ class IntervalSetEvaluator : public 
ExprFunctor<IntervalSet(const PrimExpr&)> {
       int64_t vstride = stride.Eval()->value;
       if (vstride > 0) {
         return Combine<Add>(analyzer_, base,
-                            IntervalSet(make_zero(t), make_const(t, vstride * 
op->lanes - 1)));
+                            IntervalSet(make_zero(t), make_const(t, vstride * 
op->lanes - 1)),
+                            op->dtype);
       } else {
         return Combine<Add>(analyzer_, base,
-                            IntervalSet(make_const(t, vstride * op->lanes + 
1), make_zero(t)));
+                            IntervalSet(make_const(t, vstride * op->lanes + 
1), make_zero(t)),
+                            op->dtype);
       }
     }
     DLOG(WARNING) << "cannot evaluate set on expression " << 
GetRef<PrimExpr>(op);
@@ -490,7 +505,7 @@ class IntervalSetEvaluator : public 
ExprFunctor<IntervalSet(const PrimExpr&)> {
     if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) {
       return IntervalSet::SinglePoint(GetRef<PrimExpr>(op));
     }
-    return Combine<TOp>(analyzer_, a, b);
+    return Combine<TOp>(analyzer_, a, b, op->dtype);
   }
 
   // recursive depth
@@ -509,8 +524,37 @@ class IntSetAnalyzer::Impl {
     return IntervalSetEvaluator(analyzer_, dom_map).Eval(expr);
   }
 
+  IntSet Eval(const PrimExpr& expr) const {
+    return IntervalSetEvaluator(analyzer_, GetCurrentBounds(), 
true).Eval(expr);
+  }
+
+  void Bind(const Var& var, const Range& range, bool allow_override) {
+    Update(var, IntSet::FromRange(range), allow_override);
+  }
+
+  void Update(const Var& var, const IntSet& info, bool override_info);
+  void Bind(const Var& var, const PrimExpr& expr, bool override_info);
+  std::function<void()> EnterConstraint(const PrimExpr& constraint);
+
  private:
+  // Get the current variable bounds, including both global bounds and
+  // scope-dependent bounds.
+  Map<Var, IntSet> GetCurrentBounds() const;
+
+  // Utility function to split a boolean condition into the domain
+  // bounds implied by that condition.
+  static std::vector<std::pair<Var, IntSet>> DetectBoundInfo(const PrimExpr& 
cond);
+
+  // The parent arith::Analyzer
   Analyzer* analyzer_;
+
+  // Map of variables to global variable bounds (e.g. loop iterator
+  // ranges)
+  Map<Var, IntSet> dom_map_;
+
+  // Map of variables to implicit scope-dependent bounds (e.g. inside
+  // the body of an if-statement)
+  Map<Var, IntSet> constraints_;
 };
 
 IntSetAnalyzer::IntSetAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {}
@@ -521,6 +565,141 @@ IntSet IntSetAnalyzer::operator()(const PrimExpr& expr, 
const Map<Var, IntSet>&
   return impl_->Eval(expr, dom_map);
 }
 
+IntSet IntSetAnalyzer::operator()(const PrimExpr& expr) { return 
impl_->Eval(expr); }
+
+void IntSetAnalyzer::Update(const Var& var, const IntSet& info, bool 
allow_override) {
+  impl_->Update(var, info, allow_override);
+}
+
+void IntSetAnalyzer::Bind(const Var& var, const Range& range, bool 
allow_override) {
+  impl_->Bind(var, range, allow_override);
+}
+
+void IntSetAnalyzer::Impl::Update(const Var& var, const IntSet& info, bool 
can_override) {
+  if (!can_override) {
+    auto it = dom_map_.find(var);
+    if (it != dom_map_.end()) {
+      const IntSet& old_info = (*it).second;
+
+      ICHECK(ExprDeepEqual()(old_info.min(), info.min()))
+          << "Trying to update var \'" << var << "\'"
+          << " with a different minimum value: "
+          << "original=" << old_info.min() << ", new=" << info.min();
+
+      ICHECK(ExprDeepEqual()(old_info.max(), info.max()))
+          << "Trying to update var \'" << var << "\'"
+          << " with a different maximum value: "
+          << "original=" << old_info.max() << ", new=" << info.max();
+    }
+  }
+  dom_map_.Set(var, info);
+}
+
+void IntSetAnalyzer::Impl::Bind(const Var& var, const PrimExpr& expr, bool 
can_override) {
+  Update(var, Eval(expr), can_override);
+}
+
+Map<Var, IntSet> IntSetAnalyzer::Impl::GetCurrentBounds() const {
+  // If either constraints_ or dom_map_ is empty, return the other to
+  // avoid constructing a new map.
+  if (constraints_.empty()) {
+    return dom_map_;
+  } else if (dom_map_.empty()) {
+    return constraints_;
+  }
+
+  // If neither is empty, construct a merged domain map with
+  // information from both sources.
+  Map<Var, IntSet> merged = dom_map_;
+  for (const auto& pair : constraints_) {
+    auto it = merged.find(pair.first);
+    if (it == merged.end()) {
+      merged.Set(pair.first, pair.second);
+    } else {
+      merged.Set(pair.first, Intersect({pair.second, (*it).second}));
+    }
+  }
+  return merged;
+}
+
+std::vector<std::pair<Var, IntSet>> IntSetAnalyzer::Impl::DetectBoundInfo(
+    const PrimExpr& constraint) {
+  PVar<Var> x;
+  PVar<PrimExpr> limit;
+
+  std::vector<std::pair<Var, IntSet>> bounds;
+  for (const PrimExpr& subconstraint : ExtractConstraints(constraint)) {
+    if ((x <= limit).Match(subconstraint)) {
+      bounds.push_back({x.Eval(), IntSet::Interval(SymbolicLimits::neg_inf_, 
limit.Eval())});
+    } else if ((x < limit).Match(subconstraint)) {
+      bounds.push_back({x.Eval(), IntSet::Interval(SymbolicLimits::neg_inf_, 
limit.Eval() - 1)});
+    } else if ((x >= limit).Match(subconstraint)) {
+      bounds.push_back({x.Eval(), IntSet::Interval(limit.Eval(), 
SymbolicLimits::pos_inf_)});
+    } else if ((x > limit).Match(subconstraint)) {
+      bounds.push_back({x.Eval(), IntSet::Interval(limit.Eval() + 1, 
SymbolicLimits::pos_inf_)});
+    } else if ((x == limit).Match(subconstraint)) {
+      bounds.push_back({x.Eval(), IntSet::SinglePoint(limit.Eval())});
+    }
+
+    if ((limit >= x).Match(subconstraint)) {
+      bounds.push_back({x.Eval(), IntSet::Interval(SymbolicLimits::neg_inf_, 
limit.Eval())});
+    } else if ((limit > x).Match(subconstraint)) {
+      bounds.push_back({x.Eval(), IntSet::Interval(SymbolicLimits::neg_inf_, 
limit.Eval() - 1)});
+    } else if ((limit <= x).Match(subconstraint)) {
+      bounds.push_back({x.Eval(), IntSet::Interval(limit.Eval(), 
SymbolicLimits::pos_inf_)});
+    } else if ((limit < x).Match(subconstraint)) {
+      bounds.push_back({x.Eval(), IntSet::Interval(limit.Eval() + 1, 
SymbolicLimits::pos_inf_)});
+    } else if ((limit == x).Match(subconstraint)) {
+      bounds.push_back({x.Eval(), IntSet::SinglePoint(limit.Eval())});
+    }
+  }
+  return bounds;
+}
+
+std::function<void()> IntSetAnalyzer::EnterConstraint(const PrimExpr& 
constraint) {
+  return impl_->EnterConstraint(constraint);
+}
+
+std::function<void()> IntSetAnalyzer::Impl::EnterConstraint(const PrimExpr& 
constraint) {
+  Map<Var, IntSet> cached_values;
+
+  auto bounds = DetectBoundInfo(constraint);
+
+  if (bounds.size() == 0) return nullptr;
+
+  // Collect the current values of each var that is changes by this
+  // constraint.
+  for (const auto& pair : bounds) {
+    auto it = constraints_.find(pair.first);
+    if (it == constraints_.end()) {
+      cached_values.Set(pair.first, IntSet());
+    } else {
+      cached_values.Set(pair.first, (*it).second);
+    }
+  }
+
+  // Update all constraints
+  for (const auto& pair : bounds) {
+    auto it = constraints_.find(pair.first);
+    if (it == constraints_.end()) {
+      constraints_.Set(pair.first, pair.second);
+    } else {
+      constraints_.Set(pair.first, Intersect({pair.second, (*it).second}));
+    }
+  }
+
+  auto frecover = [cached_values, this]() {
+    for (const auto& it : cached_values) {
+      if (it.second.defined()) {
+        constraints_.Set(it.first, it.second);
+      } else {
+        constraints_.erase(it.first);
+      }
+    }
+  };
+  return frecover;
+}
+
 // Quickly adapt to IntSet interface
 // TODO(tqchen): revisit IntSet interface as well.
 Range IntSet::CoverRange(Range max_range) const {
diff --git a/src/arith/ir_visitor_with_analyzer.cc 
b/src/arith/ir_visitor_with_analyzer.cc
new file mode 100644
index 0000000000..75ae22ef99
--- /dev/null
+++ b/src/arith/ir_visitor_with_analyzer.cc
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/arith/ir_visitor_with_analyzer.cc
+ */
+#include "ir_visitor_with_analyzer.h"
+
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/op.h>
+
+namespace tvm {
+namespace arith {
+
+using namespace tir;
+
+void IRVisitorWithAnalyzer::VisitStmt_(const ForNode* op) {
+  analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
+  StmtExprVisitor::VisitStmt_(op);
+}
+
+void IRVisitorWithAnalyzer::VisitStmt_(const BlockNode* op) {
+  for (const auto& iter_var : op->iter_vars) {
+    analyzer_.Bind(iter_var->var, iter_var->dom);
+  }
+  StmtExprVisitor::VisitStmt_(op);
+}
+
+void IRVisitorWithAnalyzer::VisitStmt_(const LetStmtNode* op) {
+  this->VisitExpr(op->value);
+  analyzer_.Bind(op->var, op->value);
+  this->VisitStmt(op->body);
+}
+
+void IRVisitorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) {
+  this->VisitExpr(op->condition);
+
+  PrimExpr real_condition = ExtractRealCondition(op->condition);
+
+  {
+    With<ConstraintContext> constraint(&analyzer_, real_condition);
+    this->VisitStmt(op->then_case);
+  }
+  if (op->else_case.defined()) {
+    With<ConstraintContext> constraint(&analyzer_, 
analyzer_.rewrite_simplify(Not(real_condition)));
+    this->VisitStmt(op->else_case);
+  }
+}
+
+void IRVisitorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) {
+  if (op->attr_key == tir::attr::thread_extent || op->attr_key == 
tir::attr::virtual_thread) {
+    IterVar iv = Downcast<IterVar>(op->node);
+    ICHECK_NE(iv->thread_tag.length(), 0U);
+    analyzer_.Bind(iv->var, Range::FromMinExtent(0, op->value));
+  }
+  StmtExprVisitor::VisitStmt_(op);
+}
+
+void IRVisitorWithAnalyzer::VisitStmt_(const AssertStmtNode* op) {
+  this->VisitExpr(op->condition);
+  this->VisitExpr(op->message);
+  With<ConstraintContext> constraint(&analyzer_, op->condition);
+  this->VisitStmt(op->body);
+}
+
+void IRVisitorWithAnalyzer::VisitExpr_(const CallNode* op) {
+  // add condition context to if_then_else
+  static auto op_if_then_else = Op::Get("tir.if_then_else");
+  if (op->op.same_as(op_if_then_else)) {
+    PrimExpr cond = op->args[0];
+    this->VisitExpr(op->args[0]);
+    {
+      With<ConstraintContext> constraint(&analyzer_, cond);
+      this->VisitExpr(op->args[1]);
+    }
+    {
+      With<ConstraintContext> constraint(&analyzer_, 
analyzer_.rewrite_simplify(Not(cond)));
+      this->VisitExpr(op->args[2]);
+    }
+  } else {
+    StmtExprVisitor::VisitExpr_(op);
+  }
+}
+
+void IRVisitorWithAnalyzer::VisitExpr_(const LetNode* op) {
+  this->VisitExpr(op->value);
+  analyzer_.Bind(op->var, op->value);
+  this->VisitExpr(op->body);
+}
+
+void IRVisitorWithAnalyzer::VisitExpr_(const ReduceNode* op) {
+  for (const IterVar& iv : op->axis) {
+    analyzer_.Bind(iv->var, iv->dom);
+  }
+  StmtExprVisitor::VisitExpr_(op);
+}
+
+PrimExpr IRVisitorWithAnalyzer::ExtractRealCondition(PrimExpr condition) const 
{
+  if (auto call = condition.as<CallNode>()) {
+    if (call->op.same_as(builtin::likely())) {
+      return call->args[0];
+    }
+  }
+
+  return condition;
+}
+
+}  // namespace arith
+}  // namespace tvm
diff --git a/src/arith/ir_visitor_with_analyzer.h 
b/src/arith/ir_visitor_with_analyzer.h
index 058abc8c7d..f41a628f3c 100644
--- a/src/arith/ir_visitor_with_analyzer.h
+++ b/src/arith/ir_visitor_with_analyzer.h
@@ -30,42 +30,37 @@
 #include <tvm/tir/stmt_functor.h>
 
 namespace tvm {
-namespace tir {
+namespace arith {
 
-class IRVisitorWithAnalyzer final : public StmtExprVisitor {
+/*!
+ * \brief StmtExprVisitor that owns an arith::Analyzer and updates it with
+ *  bindings and constraints while walking the IR, so that subclasses can
+ *  query the analyzer during traversal.
+ */
+class IRVisitorWithAnalyzer : public tir::StmtExprVisitor {
  public:
+  /*! \brief Simplify an expression with the internal analyzer. */
   PrimExpr Simplify(const PrimExpr& expr) { return analyzer_.Simplify(expr); }
 
-  void VisitStmt_(const ForNode* op) {
-    analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
-    return StmtExprVisitor::VisitStmt_(op);
-  }
+  // Declaring overloads below would otherwise hide the base-class overloads
+  // for all node types that are not overridden here.
+  using StmtExprVisitor::VisitExpr_;
+  using StmtExprVisitor::VisitStmt_;
 
-  void VisitStmt_(const AttrStmtNode* op) {
-    if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) {
-      IterVar iv = Downcast<IterVar>(op->node);
-      ICHECK_NE(iv->thread_tag.length(), 0U);
-      analyzer_.Bind(iv->var, Range::FromMinExtent(0, op->value));
-      StmtExprVisitor::VisitStmt_(op);
-    } else {
-      StmtExprVisitor::VisitStmt_(op);
-    }
-  }
+  // Marked `override` so that a signature drifting away from the base-class
+  // virtuals is a compile error instead of a silently added overload.
+  void VisitStmt_(const tir::ForNode* op) override;
+  void VisitStmt_(const tir::BlockNode* op) override;
+  void VisitStmt_(const tir::LetStmtNode* op) override;
+  void VisitStmt_(const tir::IfThenElseNode* op) override;
+  void VisitStmt_(const tir::AttrStmtNode* op) override;
+  void VisitStmt_(const tir::AssertStmtNode* op) override;
+  void VisitExpr_(const tir::CallNode* op) override;
+  void VisitExpr_(const tir::LetNode* op) override;
+  void VisitExpr_(const tir::ReduceNode* op) override;
 
-  void VisitExpr_(const ReduceNode* op) {
-    // Setup the domain information before simplification.
-    for (const IterVar& iv : op->axis) {
-      analyzer_.Bind(iv->var, iv->dom);
-    }
-    // Recursively call simplification when necessary.
-    StmtExprVisitor::VisitExpr_(op);
-  }
+  // IRVisitorWithAnalyzer deliberately does not handle Select nodes,
+  // because both sides of a Select node are visited regardless of the
+  // condition.
 
  protected:
   /*! \brief internal analyzer field. */
   arith::Analyzer analyzer_;
+
+ private:
+  /*! \brief Return the argument of a builtin::likely() call, or the condition itself. */
+  PrimExpr ExtractRealCondition(PrimExpr condition) const;
 };
 
-}  // namespace tir
+}  // namespace arith
 }  // namespace tvm
 #endif  // TVM_ARITH_IR_VISITOR_WITH_ANALYZER_H_
diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc
index f2d9aba4fb..dd236537e9 100644
--- a/src/tir/transforms/storage_flatten.cc
+++ b/src/tir/transforms/storage_flatten.cc
@@ -47,6 +47,7 @@
 namespace tvm {
 namespace tir {
 
+using arith::IRVisitorWithAnalyzer;
 using runtime::StorageRank;
 using runtime::StorageScope;
 using runtime::ThreadScope;
diff --git a/src/tir/transforms/texture_flatten.cc b/src/tir/transforms/texture_flatten.cc
index a607e5914b..3c35b73bc8 100644
--- a/src/tir/transforms/texture_flatten.cc
+++ b/src/tir/transforms/texture_flatten.cc
@@ -38,6 +38,7 @@
 
 namespace tvm {
 namespace tir {
+using arith::IRVisitorWithAnalyzer;
 using runtime::ApplyTexture2DFlattening;
 using runtime::DefaultTextureLayoutSeparator;
 using runtime::IsTextureStorage;

Reply via email to