This is an automated email from the ASF dual-hosted git repository.
tkonolige pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4b5dd136d7 [Arith] Updated BufferDomainTouched to use
IRVisitorWithAnalyzer (#11970)
4b5dd136d7 is described below
commit 4b5dd136d764fbef5f552ffce0759232c138e4e2
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Jul 13 11:17:53 2022 -0500
[Arith] Updated BufferDomainTouched to use IRVisitorWithAnalyzer (#11970)
* [Arith] Allow binding of Var in IntSetAnalyzer
The other four subanalyzers in `arith::Analyzer` can each be provided
with variable bindings/constraints that are remembered internally.
This adds the same capability to `IntSetAnalyzer`, rather than
requiring users to independently track and maintain a `Map<Var,
IntSet>` containing the domain of each variable, and applies
bindings/constraints alongside the other subanalyzers.
* [Arith] Updated IRVisitorWithAnalyzer to mimic IRMutatorWithAnalyzer
Previously, `IRVisitorWithAnalyzer` did not allow subclassing, and
could only be used to collect bounds of variables along an entire
statement, and could not be used to perform scope-dependent analysis.
This commit removes `final` from `IRVisitorWithAnalyzer` and provides
the same scope-based constraints/bindings during iteration as are
provided by `IRMutatorWithAnalyzer`.
* [Arith] Moved IRVisitorWithAnalyzer to tvm::arith namespace
Changed for consistency: `IRVisitorWithAnalyzer` is part of
the `src/arith` directory, and the analogous `IRMutatorWithAnalyzer` is
already part of the `arith` namespace.
* [Arith] Updated BufferDomainTouched to use IRVisitorWithAnalyzer
This uses the earlier changes that allow subclasses of
`IRVisitorWithAnalyzer` and expose bindings/constraints to
`IntSetAnalyzer`.
* Avoid accidental Bind with dynamic Range
* [Arith] Do not visit SelectNode in IRVisitorWithAnalyzer
Because both sides of a `Select` node are visited regardless of the
condition, the `SelectNode::condition` should not be treated as a
known value.
* [Arith][IntSet] Track global and scope-dependent bounds separately
Resolves a bug that was found in CI, where an earlier scope-dependent
constraint was treated as a conflict by a later global bound.
* [Arith] Recovery function for each subanalyzer
This way, if a subanalyzer throws an exception during
`EnterConstraint`, the other subanalyzers are still appropriately
backed out of the constraint.
* [Arith][IntSet] Use CanProve instead of CanProveGreaterEqual
The `min_value - max_value` in the `CanProveGreaterEqual` argument can
result in an exception being thrown for unsigned integers where
subtraction would wrap.
* [Arith] Allow vector expressions in IntSet::operator(PrimExpr)
Since these bounds are tracked while lowering expressions, the analyzer
should also accept post-vectorization expressions.
To maintain previous behavior, this only applies when using the
automatically tracked `Map<Var, IntSet> dom_map_`. If an explicit
domain map is passed, the previous behavior of raising an error for
vectorized expressions still occurs.
* Avoid comparisons between integer and handle datatypes
* [Arith] IntSet, Combine() extension
Previously, for boolean operators, the Combine() method didn't handle
values without a known lower bound.
* Added docstring
* Naming consistency of `IntSetAnalyzer` methods.
To be consistent with other subanalyzers, using "Update" when
providing the analyzer with the same data structure as is used
internally, and "Bind" used when providing it with something that must
be converted to the internal data structure.
---
include/tvm/arith/analyzer.h | 46 +++++++-
src/arith/analyzer.cc | 26 +++--
src/arith/domain_touched.cc | 43 ++-----
src/arith/int_set.cc | 211 +++++++++++++++++++++++++++++++---
src/arith/ir_visitor_with_analyzer.cc | 126 ++++++++++++++++++++
src/arith/ir_visitor_with_analyzer.h | 45 ++++----
src/tir/transforms/storage_flatten.cc | 1 +
src/tir/transforms/texture_flatten.cc | 1 +
8 files changed, 409 insertions(+), 90 deletions(-)
diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h
index 3704eff33e..ceb9f574f2 100644
--- a/include/tvm/arith/analyzer.h
+++ b/include/tvm/arith/analyzer.h
@@ -135,7 +135,7 @@ class ConstIntBoundAnalyzer {
*
* \param var The variable of interest.
* \param info The bound information.
- * \param allow_override Whether do we allow override of existing
information.
+ * \param allow_override whether we allow override of existing information.
*/
TVM_DLL void Update(const Var& var, const ConstIntBound& info, bool
allow_override = false);
/*!
@@ -224,7 +224,7 @@ class ModularSetAnalyzer {
*
* \param var The variable of interest.
* \param info The bound information.
- * \param allow_override Whether do we allow override of existing
information.
+ * \param allow_override whether we allow override of existing information.
*/
TVM_DLL void Update(const Var& var, const ModularSet& info, bool
allow_override = false);
@@ -263,10 +263,16 @@ class RewriteSimplifier {
*
* \param var The variable of interest.
* \param new_expr
- * \param allow_override Whether do we allow override of existing
information.
+ * \param allow_override Whether we allow override of existing information.
*/
TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool
allow_override = false);
+ /*!
+ * \brief Update the internal state to enter constraint.
+ * \param constraint A constraint expression.
+ *
+ * \return an exit function that must be called to clean up the constraint;
it may be nullptr.
+ */
std::function<void()> EnterConstraint(const PrimExpr& constraint);
private:
@@ -297,7 +303,7 @@ class CanonicalSimplifier {
*
* \param var The variable of interest.
* \param new_expr
- * \param allow_override Whether do we allow override of existing
information.
+ * \param allow_override whether we allow override of existing information.
*/
TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool
allow_override = false);
@@ -347,7 +353,7 @@ class ConstraintContext {
/*! \brief The constraint */
PrimExpr constraint_;
/*! \brief function to be called in recovery */
- std::function<void()> exit_;
+ std::vector<std::function<void()>> recovery_functions_;
};
/*!
@@ -365,6 +371,36 @@ class IntSetAnalyzer {
*/
TVM_DLL IntSet operator()(const PrimExpr& expr, const Map<Var, IntSet>&
dom_map);
+ /*!
+ * \brief Find a symbolic integer set that contains all possible
+ * values of expr given the domain of each variable, using
+ * the domain map defined by bound variables.
+ *
+ * \param expr The expression of interest.
+ * \return the result of the analysis.
+ */
+ TVM_DLL IntSet operator()(const PrimExpr& expr);
+
+ /*!
+ * \brief Update binding of var to a new interval set.
+ *
+ * \param var The variable of interest.
+ * \param new_interval_set The set of allowed values for this var.
+ * \param allow_override whether we allow override of existing information.
+ */
+ TVM_DLL void Update(const Var& var, const IntSet& new_interval_set, bool
allow_override = false);
+
+ /*!
+ * \brief Bind var to a new range of allowed values.
+ *
+ * \param var The variable of interest.
+ * \param new_range The range of allowed values for this var.
+ * \param allow_override whether we allow override of existing information.
+ */
+ TVM_DLL void Bind(const Var& var, const Range& new_range, bool
allow_override = false);
+
+ std::function<void()> EnterConstraint(const PrimExpr& constraint);
+
private:
friend class Analyzer;
explicit IntSetAnalyzer(Analyzer* parent);
diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc
index b922138057..f32c9b2ff4 100644
--- a/src/arith/analyzer.cc
+++ b/src/arith/analyzer.cc
@@ -44,6 +44,7 @@ void Analyzer::Bind(const Var& var, const PrimExpr& expr,
bool allow_override) {
this->modular_set.Update(var, this->modular_set(new_expr), allow_override);
this->rewrite_simplify.Update(var, new_expr, allow_override);
this->canonical_simplify.Update(var, new_expr, allow_override);
+ this->int_set.Update(var, this->int_set(new_expr), allow_override);
}
void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) {
@@ -52,6 +53,7 @@ void Analyzer::Bind(const Var& var, const Range& range, bool
allow_override) {
this->Bind(var, range->min, allow_override);
} else {
this->const_int_bound.Bind(var, range, allow_override);
+ this->int_set.Bind(var, range, allow_override);
}
// skip modular_set
// skip rewrite simplify
@@ -64,22 +66,22 @@ void Analyzer::Bind(const Map<Var, Range>& variables, bool
allow_override) {
}
void ConstraintContext::EnterWithScope() {
- ICHECK(exit_ == nullptr);
+ ICHECK(recovery_functions_.size() == 0);
// entering the scope.
- auto f0 = analyzer_->const_int_bound.EnterConstraint(constraint_);
- auto f1 = analyzer_->modular_set.EnterConstraint(constraint_);
- auto f2 = analyzer_->rewrite_simplify.EnterConstraint(constraint_);
- // recovery function.
- exit_ = [f0, f1, f2]() {
- if (f2 != nullptr) f2();
- if (f1 != nullptr) f1();
- if (f0 != nullptr) f0();
- };
+
recovery_functions_.push_back(analyzer_->const_int_bound.EnterConstraint(constraint_));
+
recovery_functions_.push_back(analyzer_->modular_set.EnterConstraint(constraint_));
+
recovery_functions_.push_back(analyzer_->rewrite_simplify.EnterConstraint(constraint_));
+
recovery_functions_.push_back(analyzer_->int_set.EnterConstraint(constraint_));
}
void ConstraintContext::ExitWithScope() {
- ICHECK(exit_ != nullptr);
- exit_();
+ while (recovery_functions_.size()) {
+ auto& func = recovery_functions_.back();
+ if (func) {
+ func();
+ }
+ recovery_functions_.pop_back();
+ }
}
bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound)
{
diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc
index 403ea47f4e..d2c5d79a09 100644
--- a/src/arith/domain_touched.cc
+++ b/src/arith/domain_touched.cc
@@ -30,6 +30,8 @@
#include <unordered_map>
#include <unordered_set>
+#include "ir_visitor_with_analyzer.h"
+
namespace tvm {
namespace arith {
@@ -56,7 +58,7 @@ using BufferDomainAccess = std::tuple<LoadAccess,
StoreAccess, CombinedAccess>;
} // namespace
// Find Read region of the tensor in the stmt.
-class BufferTouchedDomain final : public StmtExprVisitor {
+class BufferTouchedDomain final : public IRVisitorWithAnalyzer {
public:
BufferTouchedDomain(const Stmt& stmt) { operator()(stmt); }
@@ -90,39 +92,17 @@ class BufferTouchedDomain final : public StmtExprVisitor {
return ret;
}
- void VisitStmt_(const ForNode* op) final {
- const VarNode* var = op->loop_var.get();
- dom_map_[var] = IntSet::FromRange(Range::FromMinExtent(op->min,
op->extent));
- StmtExprVisitor::VisitStmt_(op);
- dom_map_.erase(var);
- }
-
- void VisitStmt_(const LetStmtNode* op) final {
- dom_map_[op->var.get()] = arith::EvalSet(op->value, dom_map_);
- StmtExprVisitor::VisitStmt_(op);
- dom_map_.erase(op->var.get());
- }
-
- /* TODO: Thread extent unitest not generated.*/
- void VisitStmt_(const AttrStmtNode* op) final {
- if (op->attr_key == tir::attr::thread_extent) {
- const IterVarNode* thread_axis = op->node.as<IterVarNode>();
- ICHECK(thread_axis);
- const VarNode* var = thread_axis->var.get();
- dom_map_[var] = IntSet::FromRange(Range(make_zero(op->value.dtype()),
op->value));
- StmtExprVisitor::VisitStmt_(op);
- dom_map_.erase(var);
- } else {
- StmtExprVisitor::VisitStmt_(op);
- }
- }
+ private:
+ using Parent = IRVisitorWithAnalyzer;
+ using Parent::VisitExpr_;
+ using Parent::VisitStmt_;
void VisitExpr_(const BufferLoadNode* op) final {
// Record load-exclusive buffer access
Touch(&std::get<LoadAccess>(buffer_access_map_[op->buffer.get()]).set,
op->indices);
// Record load-store inclusive buffer access
Touch(&std::get<CombinedAccess>(buffer_access_map_[op->buffer.get()]).set,
op->indices);
- StmtExprVisitor::VisitExpr_(op);
+ Parent::VisitExpr_(op);
}
void VisitStmt_(const BufferStoreNode* op) final {
@@ -130,11 +110,11 @@ class BufferTouchedDomain final : public StmtExprVisitor {
Touch(&std::get<StoreAccess>(buffer_access_map_[op->buffer.get()]).set,
op->indices);
// Record load-store inclusive buffer access
Touch(&std::get<CombinedAccess>(buffer_access_map_[op->buffer.get()]).set,
op->indices);
- StmtExprVisitor::VisitStmt_(op);
+ Parent::VisitStmt_(op);
}
private:
- void Touch(BufferTouches* bounds, const Array<PrimExpr>& args) const {
+ void Touch(BufferTouches* bounds, const Array<PrimExpr>& args) {
if (args.size() > bounds->size()) {
bounds->resize(args.size());
}
@@ -142,13 +122,12 @@ class BufferTouchedDomain final : public StmtExprVisitor {
if (args[i].as<RampNode>()) {
(*bounds)[i].emplace_back(IntSet::Vector(args[i]));
} else {
- (*bounds)[i].emplace_back(EvalSet(args[i], dom_map_));
+ (*bounds)[i].emplace_back(analyzer_.int_set(args[i]));
}
}
}
std::unordered_map<const BufferNode*, BufferDomainAccess> buffer_access_map_;
- std::unordered_map<const VarNode*, IntSet> dom_map_;
};
Region DomainTouched(const Stmt& stmt, const Buffer& buffer, bool
consider_loads,
diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc
index 48fae479b0..6d48ad1ed1 100644
--- a/src/arith/int_set.cc
+++ b/src/arith/int_set.cc
@@ -31,6 +31,7 @@
#include <unordered_map>
#include <utility>
+#include "constraint_extract.h"
#include "interval_set.h"
#include "pattern_match.h"
@@ -63,7 +64,7 @@ IntervalSet Intersect(Analyzer* analyzer, IntervalSet a,
IntervalSet b) {
PrimExpr min_value = max(a->min_value, b->min_value);
if ((max_value.dtype().is_int() || max_value.dtype().is_uint()) &&
(min_value.dtype().is_int() || min_value.dtype().is_uint()) &&
- analyzer->CanProveGreaterEqual(min_value - max_value, 1)) {
+ analyzer->CanProve(max_value < min_value)) {
return IntervalSet::Empty();
} else {
return IntervalSet(min_value, max_value);
@@ -105,14 +106,14 @@ TVM_DECLARE_LOGICAL_OP(Not);
* \note this can possibly relax the set.
*/
template <typename Op>
-inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
+inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b,
DataType dtype) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
PrimExpr res = TryConstFold<Op>(a->min_value, b->min_value);
if (!res.defined()) res = Op(a->min_value, b->min_value);
return IntervalSet::SinglePoint(res);
}
if (is_logical_op<Op>::value) {
- return IntervalSet(make_const(a->min_value.dtype(), 0),
make_const(a->min_value.dtype(), 1));
+ return IntervalSet(make_const(dtype, 0), make_const(dtype, 1));
}
if (a->IsEmpty()) return a;
if (b->IsEmpty()) return b;
@@ -122,7 +123,8 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet
a, IntervalSet b) {
}
template <>
-inline IntervalSet Combine<tir::Add>(Analyzer* analyer, IntervalSet a,
IntervalSet b) {
+inline IntervalSet Combine<tir::Add>(Analyzer* analyer, IntervalSet a,
IntervalSet b,
+ DataType /* dtype */) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(a->min_value + b->min_value);
}
@@ -136,7 +138,8 @@ inline IntervalSet Combine<tir::Add>(Analyzer* analyer,
IntervalSet a, IntervalS
}
template <>
-inline IntervalSet Combine<tir::Sub>(Analyzer* analyer, IntervalSet a,
IntervalSet b) {
+inline IntervalSet Combine<tir::Sub>(Analyzer* analyer, IntervalSet a,
IntervalSet b,
+ DataType /* dtype */) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(a->min_value - b->min_value);
}
@@ -150,7 +153,8 @@ inline IntervalSet Combine<tir::Sub>(Analyzer* analyer,
IntervalSet a, IntervalS
}
template <>
-inline IntervalSet Combine<tir::Mul>(Analyzer* analyzer, IntervalSet a,
IntervalSet b) {
+inline IntervalSet Combine<tir::Mul>(Analyzer* analyzer, IntervalSet a,
IntervalSet b,
+ DataType /* dtype */) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(a->min_value * b->min_value);
}
@@ -183,7 +187,8 @@ inline IntervalSet Combine<tir::Mul>(Analyzer* analyzer,
IntervalSet a, Interval
}
template <>
-inline IntervalSet Combine<tir::Div>(Analyzer* analyzer, IntervalSet a,
IntervalSet b) {
+inline IntervalSet Combine<tir::Div>(Analyzer* analyzer, IntervalSet a,
IntervalSet b,
+ DataType /* dtype */) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(a->min_value / b->min_value);
}
@@ -216,7 +221,8 @@ inline IntervalSet Combine<tir::Div>(Analyzer* analyzer,
IntervalSet a, Interval
}
template <>
-inline IntervalSet Combine<tir::Mod>(Analyzer* analyzer, IntervalSet a,
IntervalSet b) {
+inline IntervalSet Combine<tir::Mod>(Analyzer* analyzer, IntervalSet a,
IntervalSet b,
+ DataType /* dtype */) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(truncmod(a->min_value, b->min_value));
}
@@ -244,7 +250,8 @@ inline IntervalSet Combine<tir::Mod>(Analyzer* analyzer,
IntervalSet a, Interval
}
template <>
-inline IntervalSet Combine<tir::FloorDiv>(Analyzer* analyzer, IntervalSet a,
IntervalSet b) {
+inline IntervalSet Combine<tir::FloorDiv>(Analyzer* analyzer, IntervalSet a,
IntervalSet b,
+ DataType /* dtype */) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(floordiv(a->min_value, b->min_value));
}
@@ -277,7 +284,8 @@ inline IntervalSet Combine<tir::FloorDiv>(Analyzer*
analyzer, IntervalSet a, Int
}
template <>
-inline IntervalSet Combine<tir::FloorMod>(Analyzer* analyzer, IntervalSet a,
IntervalSet b) {
+inline IntervalSet Combine<tir::FloorMod>(Analyzer* analyzer, IntervalSet a,
IntervalSet b,
+ DataType /* dtype */) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(floormod(a->min_value, b->min_value));
}
@@ -294,7 +302,10 @@ inline IntervalSet Combine<tir::FloorMod>(Analyzer*
analyzer, IntervalSet a, Int
// a mod b = a - (a / b) * b if a_max / b == a_min / b
auto qmax = a->HasUpperBound() ? floordiv(a->max_value, divisor) :
pos_inf();
auto qmin = a->HasLowerBound() ? floordiv(a->min_value, divisor) :
neg_inf();
- if (analyzer->CanProve(qmax == qmin)) {
+ // We can compare +/- inf against each other, but cannot use
+ // operator== between the symbolic limits and an integer.
+ bool compatible_dtypes = !(qmin.dtype().is_handle() ^
qmax.dtype().is_handle());
+ if (compatible_dtypes && analyzer->CanProve(qmax == qmin)) {
auto tmax = a->max_value - divisor * qmin;
auto tmin = a->min_value - divisor * qmin;
return IntervalSet(tmin, tmax);
@@ -311,7 +322,8 @@ inline IntervalSet Combine<tir::FloorMod>(Analyzer*
analyzer, IntervalSet a, Int
}
template <>
-inline IntervalSet Combine<tir::Max>(Analyzer* analzyer, IntervalSet a,
IntervalSet b) {
+inline IntervalSet Combine<tir::Max>(Analyzer* analzyer, IntervalSet a,
IntervalSet b,
+ DataType /* dtype */) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(max(a->min_value, b->min_value));
}
@@ -321,7 +333,8 @@ inline IntervalSet Combine<tir::Max>(Analyzer* analzyer,
IntervalSet a, Interval
}
template <>
-inline IntervalSet Combine<tir::Min>(Analyzer* analzyer, IntervalSet a,
IntervalSet b) {
+inline IntervalSet Combine<tir::Min>(Analyzer* analzyer, IntervalSet a,
IntervalSet b,
+ DataType /* dtype */) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(min(a->min_value, b->min_value));
}
@@ -423,10 +436,12 @@ class IntervalSetEvaluator : public
ExprFunctor<IntervalSet(const PrimExpr&)> {
int64_t vstride = stride.Eval()->value;
if (vstride > 0) {
return Combine<Add>(analyzer_, base,
- IntervalSet(make_zero(t), make_const(t, vstride *
op->lanes - 1)));
+ IntervalSet(make_zero(t), make_const(t, vstride *
op->lanes - 1)),
+ op->dtype);
} else {
return Combine<Add>(analyzer_, base,
- IntervalSet(make_const(t, vstride * op->lanes +
1), make_zero(t)));
+ IntervalSet(make_const(t, vstride * op->lanes +
1), make_zero(t)),
+ op->dtype);
}
}
DLOG(WARNING) << "cannot evaluate set on expression " <<
GetRef<PrimExpr>(op);
@@ -490,7 +505,7 @@ class IntervalSetEvaluator : public
ExprFunctor<IntervalSet(const PrimExpr&)> {
if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) {
return IntervalSet::SinglePoint(GetRef<PrimExpr>(op));
}
- return Combine<TOp>(analyzer_, a, b);
+ return Combine<TOp>(analyzer_, a, b, op->dtype);
}
// recursive depth
@@ -509,8 +524,37 @@ class IntSetAnalyzer::Impl {
return IntervalSetEvaluator(analyzer_, dom_map).Eval(expr);
}
+ IntSet Eval(const PrimExpr& expr) const {
+ return IntervalSetEvaluator(analyzer_, GetCurrentBounds(),
true).Eval(expr);
+ }
+
+ void Bind(const Var& var, const Range& range, bool allow_override) {
+ Update(var, IntSet::FromRange(range), allow_override);
+ }
+
+ void Update(const Var& var, const IntSet& info, bool override_info);
+ void Bind(const Var& var, const PrimExpr& expr, bool override_info);
+ std::function<void()> EnterConstraint(const PrimExpr& constraint);
+
private:
+ // Get the current variable bounds, including both global bounds and
+ // scope-dependent bounds.
+ Map<Var, IntSet> GetCurrentBounds() const;
+
+ // Utility function to split a boolean condition into the domain
+ // bounds implied by that condition.
+ static std::vector<std::pair<Var, IntSet>> DetectBoundInfo(const PrimExpr&
cond);
+
+ // The parent arith::Analyzer
Analyzer* analyzer_;
+
+ // Map of variables to global variable bounds (e.g. loop iterator
+ // ranges)
+ Map<Var, IntSet> dom_map_;
+
+ // Map of variables to implicit scope-dependent bounds (e.g. inside
+ // the body of an if-statement)
+ Map<Var, IntSet> constraints_;
};
IntSetAnalyzer::IntSetAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {}
@@ -521,6 +565,141 @@ IntSet IntSetAnalyzer::operator()(const PrimExpr& expr,
const Map<Var, IntSet>&
return impl_->Eval(expr, dom_map);
}
+IntSet IntSetAnalyzer::operator()(const PrimExpr& expr) { return
impl_->Eval(expr); }
+
+void IntSetAnalyzer::Update(const Var& var, const IntSet& info, bool
allow_override) {
+ impl_->Update(var, info, allow_override);
+}
+
+void IntSetAnalyzer::Bind(const Var& var, const Range& range, bool
allow_override) {
+ impl_->Bind(var, range, allow_override);
+}
+
+void IntSetAnalyzer::Impl::Update(const Var& var, const IntSet& info, bool
can_override) {
+ if (!can_override) {
+ auto it = dom_map_.find(var);
+ if (it != dom_map_.end()) {
+ const IntSet& old_info = (*it).second;
+
+ ICHECK(ExprDeepEqual()(old_info.min(), info.min()))
+ << "Trying to update var \'" << var << "\'"
+ << " with a different minimum value: "
+ << "original=" << old_info.min() << ", new=" << info.min();
+
+ ICHECK(ExprDeepEqual()(old_info.max(), info.max()))
+ << "Trying to update var \'" << var << "\'"
+ << " with a different maximum value: "
+ << "original=" << old_info.max() << ", new=" << info.max();
+ }
+ }
+ dom_map_.Set(var, info);
+}
+
+void IntSetAnalyzer::Impl::Bind(const Var& var, const PrimExpr& expr, bool
can_override) {
+ Update(var, Eval(expr), can_override);
+}
+
+Map<Var, IntSet> IntSetAnalyzer::Impl::GetCurrentBounds() const {
+ // If either constraints_ or dom_map_ is empty, return the other to
+ // avoid constructing a new map.
+ if (constraints_.empty()) {
+ return dom_map_;
+ } else if (dom_map_.empty()) {
+ return constraints_;
+ }
+
+ // If neither is empty, construct a merged domain map with
+ // information from both sources.
+ Map<Var, IntSet> merged = dom_map_;
+ for (const auto& pair : constraints_) {
+ auto it = merged.find(pair.first);
+ if (it == merged.end()) {
+ merged.Set(pair.first, pair.second);
+ } else {
+ merged.Set(pair.first, Intersect({pair.second, (*it).second}));
+ }
+ }
+ return merged;
+}
+
+std::vector<std::pair<Var, IntSet>> IntSetAnalyzer::Impl::DetectBoundInfo(
+ const PrimExpr& constraint) {
+ PVar<Var> x;
+ PVar<PrimExpr> limit;
+
+ std::vector<std::pair<Var, IntSet>> bounds;
+ for (const PrimExpr& subconstraint : ExtractConstraints(constraint)) {
+ if ((x <= limit).Match(subconstraint)) {
+ bounds.push_back({x.Eval(), IntSet::Interval(SymbolicLimits::neg_inf_,
limit.Eval())});
+ } else if ((x < limit).Match(subconstraint)) {
+ bounds.push_back({x.Eval(), IntSet::Interval(SymbolicLimits::neg_inf_,
limit.Eval() - 1)});
+ } else if ((x >= limit).Match(subconstraint)) {
+ bounds.push_back({x.Eval(), IntSet::Interval(limit.Eval(),
SymbolicLimits::pos_inf_)});
+ } else if ((x > limit).Match(subconstraint)) {
+ bounds.push_back({x.Eval(), IntSet::Interval(limit.Eval() + 1,
SymbolicLimits::pos_inf_)});
+ } else if ((x == limit).Match(subconstraint)) {
+ bounds.push_back({x.Eval(), IntSet::SinglePoint(limit.Eval())});
+ }
+
+ if ((limit >= x).Match(subconstraint)) {
+ bounds.push_back({x.Eval(), IntSet::Interval(SymbolicLimits::neg_inf_,
limit.Eval())});
+ } else if ((limit > x).Match(subconstraint)) {
+ bounds.push_back({x.Eval(), IntSet::Interval(SymbolicLimits::neg_inf_,
limit.Eval() - 1)});
+ } else if ((limit <= x).Match(subconstraint)) {
+ bounds.push_back({x.Eval(), IntSet::Interval(limit.Eval(),
SymbolicLimits::pos_inf_)});
+ } else if ((limit < x).Match(subconstraint)) {
+ bounds.push_back({x.Eval(), IntSet::Interval(limit.Eval() + 1,
SymbolicLimits::pos_inf_)});
+ } else if ((limit == x).Match(subconstraint)) {
+ bounds.push_back({x.Eval(), IntSet::SinglePoint(limit.Eval())});
+ }
+ }
+ return bounds;
+}
+
+std::function<void()> IntSetAnalyzer::EnterConstraint(const PrimExpr&
constraint) {
+ return impl_->EnterConstraint(constraint);
+}
+
+std::function<void()> IntSetAnalyzer::Impl::EnterConstraint(const PrimExpr&
constraint) {
+ Map<Var, IntSet> cached_values;
+
+ auto bounds = DetectBoundInfo(constraint);
+
+ if (bounds.size() == 0) return nullptr;
+
+ // Collect the current values of each var that is changed by this
+ // constraint.
+ for (const auto& pair : bounds) {
+ auto it = constraints_.find(pair.first);
+ if (it == constraints_.end()) {
+ cached_values.Set(pair.first, IntSet());
+ } else {
+ cached_values.Set(pair.first, (*it).second);
+ }
+ }
+
+ // Update all constraints
+ for (const auto& pair : bounds) {
+ auto it = constraints_.find(pair.first);
+ if (it == constraints_.end()) {
+ constraints_.Set(pair.first, pair.second);
+ } else {
+ constraints_.Set(pair.first, Intersect({pair.second, (*it).second}));
+ }
+ }
+
+ auto frecover = [cached_values, this]() {
+ for (const auto& it : cached_values) {
+ if (it.second.defined()) {
+ constraints_.Set(it.first, it.second);
+ } else {
+ constraints_.erase(it.first);
+ }
+ }
+ };
+ return frecover;
+}
+
// Quickly adapt to IntSet interface
// TODO(tqchen): revisit IntSet interface as well.
Range IntSet::CoverRange(Range max_range) const {
diff --git a/src/arith/ir_visitor_with_analyzer.cc
b/src/arith/ir_visitor_with_analyzer.cc
new file mode 100644
index 0000000000..75ae22ef99
--- /dev/null
+++ b/src/arith/ir_visitor_with_analyzer.cc
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/arith/ir_visitor_with_analyzer.cc
+ */
+#include "ir_visitor_with_analyzer.h"
+
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/op.h>
+
+namespace tvm {
+namespace arith {
+
+using namespace tir;
+
+void IRVisitorWithAnalyzer::VisitStmt_(const ForNode* op) {
+ analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
+ StmtExprVisitor::VisitStmt_(op);
+}
+
+void IRVisitorWithAnalyzer::VisitStmt_(const BlockNode* op) {
+ for (const auto& iter_var : op->iter_vars) {
+ analyzer_.Bind(iter_var->var, iter_var->dom);
+ }
+ StmtExprVisitor::VisitStmt_(op);
+}
+
+void IRVisitorWithAnalyzer::VisitStmt_(const LetStmtNode* op) {
+ this->VisitExpr(op->value);
+ analyzer_.Bind(op->var, op->value);
+ this->VisitStmt(op->body);
+}
+
+void IRVisitorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) {
+ this->VisitExpr(op->condition);
+
+ PrimExpr real_condition = ExtractRealCondition(op->condition);
+
+ {
+ With<ConstraintContext> constraint(&analyzer_, real_condition);
+ this->VisitStmt(op->then_case);
+ }
+ if (op->else_case.defined()) {
+ With<ConstraintContext> constraint(&analyzer_,
analyzer_.rewrite_simplify(Not(real_condition)));
+ this->VisitStmt(op->else_case);
+ }
+}
+
+void IRVisitorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) {
+ if (op->attr_key == tir::attr::thread_extent || op->attr_key ==
tir::attr::virtual_thread) {
+ IterVar iv = Downcast<IterVar>(op->node);
+ ICHECK_NE(iv->thread_tag.length(), 0U);
+ analyzer_.Bind(iv->var, Range::FromMinExtent(0, op->value));
+ }
+ StmtExprVisitor::VisitStmt_(op);
+}
+
+void IRVisitorWithAnalyzer::VisitStmt_(const AssertStmtNode* op) {
+ this->VisitExpr(op->condition);
+ this->VisitExpr(op->message);
+ With<ConstraintContext> constraint(&analyzer_, op->condition);
+ this->VisitStmt(op->body);
+}
+
+void IRVisitorWithAnalyzer::VisitExpr_(const CallNode* op) {
+ // add condition context to if_then_else
+ static auto op_if_then_else = Op::Get("tir.if_then_else");
+ if (op->op.same_as(op_if_then_else)) {
+ PrimExpr cond = op->args[0];
+ this->VisitExpr(op->args[0]);
+ {
+ With<ConstraintContext> constraint(&analyzer_, cond);
+ this->VisitExpr(op->args[1]);
+ }
+ {
+ With<ConstraintContext> constraint(&analyzer_,
analyzer_.rewrite_simplify(Not(cond)));
+ this->VisitExpr(op->args[2]);
+ }
+ } else {
+ StmtExprVisitor::VisitExpr_(op);
+ }
+}
+
+void IRVisitorWithAnalyzer::VisitExpr_(const LetNode* op) {
+ this->VisitExpr(op->value);
+ analyzer_.Bind(op->var, op->value);
+ this->VisitExpr(op->body);
+}
+
+void IRVisitorWithAnalyzer::VisitExpr_(const ReduceNode* op) {
+ for (const IterVar& iv : op->axis) {
+ analyzer_.Bind(iv->var, iv->dom);
+ }
+ StmtExprVisitor::VisitExpr_(op);
+}
+
+PrimExpr IRVisitorWithAnalyzer::ExtractRealCondition(PrimExpr condition) const
{
+ if (auto call = condition.as<CallNode>()) {
+ if (call->op.same_as(builtin::likely())) {
+ return call->args[0];
+ }
+ }
+
+ return condition;
+}
+
+} // namespace arith
+} // namespace tvm
diff --git a/src/arith/ir_visitor_with_analyzer.h
b/src/arith/ir_visitor_with_analyzer.h
index 058abc8c7d..f41a628f3c 100644
--- a/src/arith/ir_visitor_with_analyzer.h
+++ b/src/arith/ir_visitor_with_analyzer.h
@@ -30,42 +30,37 @@
#include <tvm/tir/stmt_functor.h>
namespace tvm {
-namespace tir {
+namespace arith {
-class IRVisitorWithAnalyzer final : public StmtExprVisitor {
+class IRVisitorWithAnalyzer : public tir::StmtExprVisitor {
public:
PrimExpr Simplify(const PrimExpr& expr) { return analyzer_.Simplify(expr); }
- void VisitStmt_(const ForNode* op) {
- analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
- return StmtExprVisitor::VisitStmt_(op);
- }
+ using StmtExprVisitor::VisitExpr_;
+ using StmtExprVisitor::VisitStmt_;
- void VisitStmt_(const AttrStmtNode* op) {
- if (op->attr_key == attr::thread_extent || op->attr_key ==
attr::virtual_thread) {
- IterVar iv = Downcast<IterVar>(op->node);
- ICHECK_NE(iv->thread_tag.length(), 0U);
- analyzer_.Bind(iv->var, Range::FromMinExtent(0, op->value));
- StmtExprVisitor::VisitStmt_(op);
- } else {
- StmtExprVisitor::VisitStmt_(op);
- }
- }
+ void VisitStmt_(const tir::ForNode* op);
+ void VisitStmt_(const tir::BlockNode* op);
+ void VisitStmt_(const tir::LetStmtNode* op);
+ void VisitStmt_(const tir::IfThenElseNode* op);
+ void VisitStmt_(const tir::AttrStmtNode* op);
+ void VisitStmt_(const tir::AssertStmtNode* op);
+ void VisitExpr_(const tir::CallNode* op);
+ void VisitExpr_(const tir::LetNode* op);
+ void VisitExpr_(const tir::ReduceNode* op);
- void VisitExpr_(const ReduceNode* op) {
- // Setup the domain information before simplification.
- for (const IterVar& iv : op->axis) {
- analyzer_.Bind(iv->var, iv->dom);
- }
- // Recursively call simplification when necessary.
- StmtExprVisitor::VisitExpr_(op);
- }
+ // IRVisitorWithAnalyzer deliberately does not handle Select nodes,
+ // because both sides of a Select node are visited regardless of the
+ // condition.
protected:
/*! \brief internal analyzer field. */
arith::Analyzer analyzer_;
+
+ private:
+ PrimExpr ExtractRealCondition(PrimExpr condition) const;
};
-} // namespace tir
+} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_IR_VISITOR_WITH_ANALYZER_H_
diff --git a/src/tir/transforms/storage_flatten.cc
b/src/tir/transforms/storage_flatten.cc
index f2d9aba4fb..dd236537e9 100644
--- a/src/tir/transforms/storage_flatten.cc
+++ b/src/tir/transforms/storage_flatten.cc
@@ -47,6 +47,7 @@
namespace tvm {
namespace tir {
+using arith::IRVisitorWithAnalyzer;
using runtime::StorageRank;
using runtime::StorageScope;
using runtime::ThreadScope;
diff --git a/src/tir/transforms/texture_flatten.cc
b/src/tir/transforms/texture_flatten.cc
index a607e5914b..3c35b73bc8 100644
--- a/src/tir/transforms/texture_flatten.cc
+++ b/src/tir/transforms/texture_flatten.cc
@@ -38,6 +38,7 @@
namespace tvm {
namespace tir {
+using arith::IRVisitorWithAnalyzer;
using runtime::ApplyTexture2DFlattening;
using runtime::DefaultTextureLayoutSeparator;
using runtime::IsTextureStorage;