This is an automated email from the ASF dual-hosted git repository.
csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 25a0d47d2b [Arith][TIR] Check for constant offsets of known literal
constraints (#13023)
25a0d47d2b is described below
commit 25a0d47d2b55f3404ea711a3ff28bf22f7cc0e17
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sat Oct 29 12:55:55 2022 -0500
[Arith][TIR] Check for constant offsets of known literal constraints
(#13023)
Previously, the checks for a literal constraint only found exact
matches for a known inequality, so any alteration to the conditional
(such as adding a constant offset) would break the match. This commit
introduces checks for constant offsets relative to a known value.
These checks cannot always be expressed using the existing
`ConstIntBoundAnalyzer`, which represents the allowed values as a
single contiguous range (e.g. `i != 5` cannot be represented, because
it requires one range for `i < 5` and another for `i > 5`).
This change reuses the internal representation for inequalities
introduced in https://github.com/apache/tvm/pull/12863, along with
much of that PR's implementation. However, the indirect comparisons
introduced in that PR (e.g. using `a < b` and `b < c` to prove that
`a < c`) are still only used when the `kTransitivelyProveInequalities`
extension flag is explicitly enabled.
---
include/tvm/arith/analyzer.h | 11 +-
src/arith/rewrite_simplify.cc | 7 +-
src/arith/transitive_comparison_analyzer.cc | 168 ++++++++++++++++-----
.../python/unittest/test_tir_transform_simplify.py | 14 ++
4 files changed, 155 insertions(+), 45 deletions(-)
diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h
index e2d60684da..885c23f491 100644
--- a/include/tvm/arith/analyzer.h
+++ b/include/tvm/arith/analyzer.h
@@ -409,10 +409,19 @@ class TransitiveComparisonAnalyzer {
*
* \param rhs The right-hand side of the comparison
*
+ * \param propagate_inequalities If true, attempt to find a sequence
+ * of transitive inequalities that allow the lhs and rhs to be
+ * compared. If false, only use the known comparison that have been
+ * directly provided. Using `propagate_inequalities = false` is
+ * roughly equivalent to comparing against all known inequality
+ * expressions using `ExprDeepEqual`, but also allows for constant
+ * offsets on either side of the inequality.
+ *
* \return The most specific result that can be proven about the
* comparison. If nothing can be proven, returns kUnknown.
*/
- TVM_DLL CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs);
+ TVM_DLL CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs,
+ bool propagate_inequalities = true);
/*! \brief Bind a variable as being equal to a known expression
*
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index 6cc2aa9e45..a42303e459 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -118,9 +118,7 @@ CompareResult RewriteSimplifier::Impl::TryCompare(const
PrimExpr& x, const PrimE
if (is_finished()) return output;
- if (enabled_extensions_ & kTransitivelyProveInequalities) {
- output = CompareResult(output & TryCompareUsingKnownInequalities(x, y));
- }
+ output = CompareResult(output & TryCompareUsingKnownInequalities(x, y));
return output;
}
@@ -132,7 +130,8 @@ CompareResult
RewriteSimplifier::Impl::TryCompareUsingConstIntBounds(const PrimE
CompareResult RewriteSimplifier::Impl::TryCompareUsingKnownInequalities(const
PrimExpr& x,
const
PrimExpr& y) {
- return analyzer_->transitive_comparisons.TryCompare(x, y);
+ bool propagate_inequalities = enabled_extensions_ &
kTransitivelyProveInequalities;
+ return analyzer_->transitive_comparisons.TryCompare(x, y,
propagate_inequalities);
}
// try to prove x equals val
diff --git a/src/arith/transitive_comparison_analyzer.cc
b/src/arith/transitive_comparison_analyzer.cc
index 9a835f7fde..b71096a479 100644
--- a/src/arith/transitive_comparison_analyzer.cc
+++ b/src/arith/transitive_comparison_analyzer.cc
@@ -43,10 +43,19 @@ class TransitiveComparisonAnalyzer::Impl {
*
* \param rhs The right-hand side of the comparison
*
+ * \param propagate_inequalities If true, attempt to find a sequence
+ * of transitive inequalities that allow the lhs and rhs to be
+ * compared. If false, only use the known comparison that have been
+ * directly provided. Using `propagate_inequalities = false` is
+ * roughly equivalent to comparing against all known values with
+ * `ExprDeepEqual`, but also allowing for constant offsets on either
+ * side of the inequality.
+ *
* \return The most specific result that can be proven about the
* comparison. If nothing can be proven, returns kUnknown.
*/
- CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const;
+ CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs,
+ bool propagate_inequalities = true) const;
/*! \brief Bind a variable as being equal to a known expression
*
@@ -192,7 +201,37 @@ class TransitiveComparisonAnalyzer::Impl {
*/
void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec);
- /*! \brief Attempt to compare the expressions, starting at the lhs.
+ /*! Collect known comparisons between LHS and RHS, without propagation
+ *
+ * Allows the internal representation to handle any constant
+ * offsets, without searching for a sequence of inequalities.
+ *
+ * \param lhs_key The left-hand side of the comparison
+ *
+ * \param rhs_key The right-hand side of the comparison
+ *
+ * \returns A subset of `knowns_` and `scoped_knowns_`, filtered to
+ * only include comparisons between `lhs_key` and `rhs_key`,
+ * normalized such that `lhs_key` is on the left-hand side.
+ */
+ std::vector<Comparison> CollectDirectComparisons(Key lhs_key, Key rhs_key)
const;
+
+ /*! Collect known comparisons between LHS and RHS, with propagation
+ *
+ * \param lhs_key The left-hand side of the comparison
+ *
+ * \param rhs_key The right-hand side of the comparison
+ *
+ * \returns All comparisons between `lhs_key` and `rhs_key`,
+ * including the explicitly-provided comparisons in `knowns_` and
+ * `scoped_knowns_`, and comparisons provable through a series of
+ * comparisons through other values. All comparisons returned are
+ * between `lhs_key` and `rhs_key`, and are normalized such that
+ * `lhs_key` is on the left-hand side.
+ */
+ std::vector<Comparison> CollectIndirectComparisons(Key lhs_key, Key rhs_key)
const;
+
+ /*! \brief Internal function used by CollectIndirectComparisons
*
* Perform a depth-first search through the space of known
* expressions, starting at the LHS of a comparison. In this
@@ -208,14 +247,29 @@ class TransitiveComparisonAnalyzer::Impl {
* expression D, then combine the comparisons that compose the path
* into the expression A<=D-4.
*
- * \param lhs The left-hand side of the comparison
+ * \param lhs_key The left-hand side of the comparison
*
- * \param rhs The right-hand side of the comparison
+ * \param rhs_key The right-hand side of the comparison
+ *
+ * \returns A vector of comparisons between the two expressions.
+ */
+ std::vector<Comparison> DFSFromLHS(Key lhs_key, Key rhs_key) const;
+
+ /*! \brief Combine a set of comparisons that share a LHS and RHS
+ *
+ * \param lhs_to_rhs The comparisons to merge. These should all
+ * have the same LHS and RHS. This parameter will typically be the
+ * result from `CollectDirectComparisons` or
+ * `CollectIndirectComparisons`.
*
- * \return The result of the comparison
+ * \param offset The constant offset in the comparison being proven.
+ * This is extracted from any additive/subtractive constants in the
+ * `PrimExpr` arguments to `TryCompare`.
+ *
+ * \returns The possible comparisons between LHS and RHS provided
+ * inequalities.
*/
- CompareResult DFSFromLHS(Key lhs_key, Key rhs_key, int64_t offset, const
PrimExpr& lhs,
- const PrimExpr& rhs) const;
+ CompareResult MergeComparisons(const std::vector<Comparison>& lhs_to_rhs,
int64_t offset) const;
/*! \brief Previous Range bindings
*
@@ -475,8 +529,9 @@ bool
TransitiveComparisonAnalyzer::Impl::Comparison::Implies(
TransitiveComparisonAnalyzer::TransitiveComparisonAnalyzer() :
impl_(std::make_unique<Impl>()) {}
TransitiveComparisonAnalyzer::~TransitiveComparisonAnalyzer() {}
-CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs,
const PrimExpr& rhs) {
- return impl_->TryCompare(lhs, rhs);
+CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs,
const PrimExpr& rhs,
+ bool
propagate_inequalities) {
+ return impl_->TryCompare(lhs, rhs, propagate_inequalities);
}
void TransitiveComparisonAnalyzer::Bind(const Var& var, const PrimExpr& expr,
bool allow_override) {
@@ -547,7 +602,8 @@ std::function<void()>
TransitiveComparisonAnalyzer::Impl::EnterConstraint(const
}
CompareResult TransitiveComparisonAnalyzer::Impl::TryCompare(const PrimExpr&
lhs_expr,
- const PrimExpr&
rhs_expr) const {
+ const PrimExpr&
rhs_expr,
+ bool
propagate_inequalities) const {
// Currently only supports integer checks
if (!lhs_expr.dtype().is_int() || !rhs_expr.dtype().is_int()) {
return CompareResult::kUnknown;
@@ -575,29 +631,59 @@ CompareResult
TransitiveComparisonAnalyzer::Impl::TryCompare(const PrimExpr& lhs
return CompareResult::kUnknown;
}
- auto from_lhs = DFSFromLHS(lhs_key.value(), rhs_key.value(), offset, lhs,
rhs);
- auto from_rhs = Reverse(DFSFromLHS(rhs_key.value(), lhs_key.value(),
-offset, rhs, lhs));
- auto output = from_lhs & from_rhs;
+ auto lhs_to_rhs = [&]() {
+ if (propagate_inequalities) {
+ return CollectIndirectComparisons(lhs_key.value(), rhs_key.value());
+ } else {
+ return CollectDirectComparisons(lhs_key.value(), rhs_key.value());
+ }
+ }();
+ return MergeComparisons(lhs_to_rhs, offset);
+}
+
+std::vector<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::CollectDirectComparisons(Key lhs_key, Key
rhs_key) const {
+ std::vector<Comparison> output;
+
+ auto append_known = [&](Comparison cmp) {
+ if (auto normalized = cmp.WithLHS(lhs_key)) {
+ if (normalized.value().rhs_ == rhs_key) {
+ output.push_back(normalized.value());
+ }
+ }
+ };
+
+ for (const auto& known : knowns_) {
+ append_known(known);
+ }
+ for (const auto& known : scoped_knowns_) {
+ append_known(known);
+ }
return output;
}
-CompareResult TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key
lhs_key_input, Key rhs_key_input,
- int64_t
offset_input,
- const PrimExpr&
lhs_input,
- const PrimExpr&
rhs_input) const {
- Key lhs_key = lhs_key_input;
- Key rhs_key = rhs_key_input;
- int64_t offset = offset_input;
+std::vector<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::CollectIndirectComparisons(Key lhs_key,
Key rhs_key) const {
+ auto output = DFSFromLHS(lhs_key, rhs_key);
+ for (Comparison cmp : DFSFromLHS(rhs_key, lhs_key)) {
+ auto opt_normalized = cmp.WithLHS(lhs_key);
+ ICHECK(opt_normalized.has_value());
+ output.push_back(opt_normalized.value());
+ }
+ return output;
+}
+std::vector<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key lhs_key, Key rhs_key) const
{
// Everything in `to_visit` has lhs as its lhs.
std::unordered_set<Key> seen;
std::unordered_set<Key> to_visit;
- std::unordered_map<Key, std::vector<Comparison>> compared_to_x;
+ std::unordered_map<Key, std::vector<Comparison>> compared_to_lhs;
// Utility function to add a new known statement
auto declare_known = [&](Comparison cmp) {
- std::vector<Comparison>& knowns = compared_to_x[cmp.rhs_];
+ std::vector<Comparison>& knowns = compared_to_lhs[cmp.rhs_];
// The comparison adds no new information, no modification
// required.
@@ -646,8 +732,8 @@ CompareResult
TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key lhs_key_input,
Key middle_key = *to_visit.begin();
to_visit.erase(to_visit.begin());
- std::vector<Comparison>& prev_knowns_using_middle =
compared_to_x.at(middle_key);
- ICHECK(compared_to_x.count(middle_key));
+ std::vector<Comparison>& prev_knowns_using_middle =
compared_to_lhs.at(middle_key);
+ ICHECK(compared_to_lhs.count(middle_key));
std::vector<Comparison> new_knowns_using_lhs;
@@ -721,27 +807,29 @@ CompareResult
TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key lhs_key_input,
}
}
- // It's possible that we don't have any transitive comparisons that
- // can prove something about LHS and RHS.
- auto it = compared_to_x.find(rhs_key);
- if (it == compared_to_x.end()) {
- return CompareResult::kUnknown;
+ if (auto it = compared_to_lhs.find(rhs_key); it != compared_to_lhs.end()) {
+ return it->second;
+ } else {
+ // There are known comparisons involving the LHS and the RHS, but
+ // no path that connects the two expressions.
+ return {};
}
+}
- const std::vector<Comparison>& known_between_lhs_and_rhs = it->second;
-
+CompareResult TransitiveComparisonAnalyzer::Impl::MergeComparisons(
+ const std::vector<Comparison>& lhs_to_rhs, int64_t offset) const {
// Just because we found a comparison involving LHS and RHS doesn't
// mean that it's useful. e.g. Knowing that `x < y` doesn't let us
// prove whether `x + 5 < y`.
CompareResult result = CompareResult::kUnknown;
- for (const auto& known : known_between_lhs_and_rhs) {
- switch (known.result_) {
+ for (const auto& cmp : lhs_to_rhs) {
+ switch (cmp.result_) {
case CompareResult::kInconsistent:
result = CompareResult::kInconsistent;
break;
case CompareResult::kEQ:
- if (offset == known.offset_) {
+ if (offset == cmp.offset_) {
result = result & CompareResult::kEQ;
} else {
result = result & CompareResult::kNE;
@@ -749,23 +837,23 @@ CompareResult
TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key lhs_key_input,
break;
case CompareResult::kLE:
- if (known.offset_ < offset) {
+ if (cmp.offset_ < offset) {
result = result & CompareResult::kLT;
- } else if (known.offset_ <= offset) {
+ } else if (cmp.offset_ <= offset) {
result = result & CompareResult::kLE;
}
break;
case CompareResult::kGE:
- if (known.offset_ > offset) {
+ if (cmp.offset_ > offset) {
result = result & CompareResult::kGT;
- } else if (known.offset_ >= offset) {
+ } else if (cmp.offset_ >= offset) {
result = result & CompareResult::kGE;
}
break;
case CompareResult::kNE:
- if (offset == known.offset_) {
+ if (offset == cmp.offset_) {
result = result & CompareResult::kNE;
}
break;
@@ -779,7 +867,7 @@ CompareResult
TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key lhs_key_input,
return CompareResult::kInconsistent;
default:
- LOG(FATAL) << "Invalid CompareResult: " <<
static_cast<int>(known.result_);
+ LOG(FATAL) << "Invalid CompareResult: " <<
static_cast<int>(cmp.result_);
return CompareResult::kInconsistent;
}
}
diff --git a/tests/python/unittest/test_tir_transform_simplify.py
b/tests/python/unittest/test_tir_transform_simplify.py
index 91ef60f9d3..4c5499edcf 100644
--- a/tests/python/unittest/test_tir_transform_simplify.py
+++ b/tests/python/unittest/test_tir_transform_simplify.py
@@ -989,5 +989,19 @@ class
TestSimplifyLHSOfBooleanOrUsingRHSWithoutConst(BaseBeforeAfter):
A[0] = n < m + 10
+class TestProvableConditionWithOffset(BaseBeforeAfter):
+ """Use scoped-constraint to prove inequalities"""
+
+ transitively_prove_inequalities = False
+
+ def before(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32):
+ if i < j:
+ A[0] = i < j + 1
+
+ def expected(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32):
+ if i < j:
+ A[0] = True
+
+
if __name__ == "__main__":
tvm.testing.main()