This is an automated email from the ASF dual-hosted git repository.

csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 25a0d47d2b [Arith][TIR] Check for constant offsets of known literal 
constraints (#13023)
25a0d47d2b is described below

commit 25a0d47d2b55f3404ea711a3ff28bf22f7cc0e17
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sat Oct 29 12:55:55 2022 -0500

    [Arith][TIR] Check for constant offsets of known literal constraints 
(#13023)
    
    Previously, the checks for a literal constraint would find exact
    matches for an inequality, but any alterations to the conditional
    would break this exact matching.  This commit introduces checks for
    constant offsets relative to a known value.  These checks are not
    always expressible using the existing `ConstIntSetAnalyzer`, which
    represents allowed values using a single contiguous
    region.  (e.g. `i!=5` is not representable, because it requires a
    region for `i<5` and another for `i>5`.)
    
    This implementation reuses the internal representation for
    inequalities introduced in https://github.com/apache/tvm/pull/12863,
    along with much of its implementation.  However, the indirect
    comparisons (e.g. using `a < b` and `b < c` to prove that `a < c`)
    introduced in that PR still require an explicit flag to be used.
---
 include/tvm/arith/analyzer.h                       |  11 +-
 src/arith/rewrite_simplify.cc                      |   7 +-
 src/arith/transitive_comparison_analyzer.cc        | 168 ++++++++++++++++-----
 .../python/unittest/test_tir_transform_simplify.py |  14 ++
 4 files changed, 155 insertions(+), 45 deletions(-)

diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h
index e2d60684da..885c23f491 100644
--- a/include/tvm/arith/analyzer.h
+++ b/include/tvm/arith/analyzer.h
@@ -409,10 +409,19 @@ class TransitiveComparisonAnalyzer {
    *
    * \param rhs The right-hand side of the comparison
    *
+   * \param propagate_inequalities If true, attempt to find a sequence
+   * of transitive inequalities that allow the lhs and rhs to be
+   * compared.  If false, only use the known comparisons that have been
+   * directly provided.  Using `propagate_inequalities = false` is
+   * roughly equivalent to comparing against all known inequality
+   * expressions using `ExprDeepEqual`, but also allows for constant
+   * offsets on either side of the inequality.
+   *
    * \return The most specific result that can be proven about the
    * comparison.  If nothing can be proven, returns kUnknown.
    */
-  TVM_DLL CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs);
+  TVM_DLL CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs,
+                                   bool propagate_inequalities = true);
 
   /*! \brief Bind a variable as being equal to a known expression
    *
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index 6cc2aa9e45..a42303e459 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -118,9 +118,7 @@ CompareResult RewriteSimplifier::Impl::TryCompare(const 
PrimExpr& x, const PrimE
 
   if (is_finished()) return output;
 
-  if (enabled_extensions_ & kTransitivelyProveInequalities) {
-    output = CompareResult(output & TryCompareUsingKnownInequalities(x, y));
-  }
+  output = CompareResult(output & TryCompareUsingKnownInequalities(x, y));
 
   return output;
 }
@@ -132,7 +130,8 @@ CompareResult 
RewriteSimplifier::Impl::TryCompareUsingConstIntBounds(const PrimE
 
 CompareResult RewriteSimplifier::Impl::TryCompareUsingKnownInequalities(const 
PrimExpr& x,
                                                                         const 
PrimExpr& y) {
-  return analyzer_->transitive_comparisons.TryCompare(x, y);
+  bool propagate_inequalities = enabled_extensions_ & 
kTransitivelyProveInequalities;
+  return analyzer_->transitive_comparisons.TryCompare(x, y, 
propagate_inequalities);
 }
 
 // try to prove x equals val
diff --git a/src/arith/transitive_comparison_analyzer.cc 
b/src/arith/transitive_comparison_analyzer.cc
index 9a835f7fde..b71096a479 100644
--- a/src/arith/transitive_comparison_analyzer.cc
+++ b/src/arith/transitive_comparison_analyzer.cc
@@ -43,10 +43,19 @@ class TransitiveComparisonAnalyzer::Impl {
    *
    * \param rhs The right-hand side of the comparison
    *
+   * \param propagate_inequalities If true, attempt to find a sequence
+   * of transitive inequalities that allow the lhs and rhs to be
+   * compared.  If false, only use the known comparisons that have been
+   * directly provided.  Using `propagate_inequalities = false` is
+   * roughly equivalent to comparing against all known values with
+   * `ExprDeepEqual`, but also allowing for constant offsets on either
+   * side of the inequality.
+   *
    * \return The most specific result that can be proven about the
    * comparison.  If nothing can be proven, returns kUnknown.
    */
-  CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const;
+  CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs,
+                           bool propagate_inequalities = true) const;
 
   /*! \brief Bind a variable as being equal to a known expression
    *
@@ -192,7 +201,37 @@ class TransitiveComparisonAnalyzer::Impl {
    */
   void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec);
 
-  /*! \brief Attempt to compare the expressions, starting at the lhs.
+  /*! \brief Collect known comparisons between LHS and RHS, without propagation
+   *
+   * Allows the internal representation to handle any constant
+   * offsets, without searching for a sequence of inequalities.
+   *
+   * \param lhs_key The left-hand side of the comparison
+   *
+   * \param rhs_key The right-hand side of the comparison
+   *
+   * \returns A subset of `knowns_` and `scoped_knowns_`, filtered to
+   * only include comparisons between `lhs_key` and `rhs_key`,
+   * normalized such that `lhs_key` is on the left-hand side.
+   */
+  std::vector<Comparison> CollectDirectComparisons(Key lhs_key, Key rhs_key) 
const;
+
+  /*! \brief Collect known comparisons between LHS and RHS, with propagation
+   *
+   * \param lhs_key The left-hand side of the comparison
+   *
+   * \param rhs_key The right-hand side of the comparison
+   *
+   * \returns All comparisons between `lhs_key` and `rhs_key`,
+   * including the explicitly-provided comparisons in `knowns_` and
+   * `scoped_knowns_`, and comparisons provable through a series of
+   * comparisons through other values.  All comparisons returned are
+   * between `lhs_key` and `rhs_key`, and are normalized such that
+   * `lhs_key` is on the left-hand side.
+   */
+  std::vector<Comparison> CollectIndirectComparisons(Key lhs_key, Key rhs_key) 
const;
+
+  /*! \brief Internal function used by CollectIndirectComparisons
    *
    * Perform a depth-first search through the space of known
    * expressions, starting at the LHS of a comparison.  In this
@@ -208,14 +247,29 @@ class TransitiveComparisonAnalyzer::Impl {
    * expression D, then combine the comparisons that compose the path
    * into the expression A<=D-4.
    *
-   * \param lhs The left-hand side of the comparison
+   * \param lhs_key The left-hand side of the comparison
    *
-   * \param rhs The right-hand side of the comparison
+   * \param rhs_key The right-hand side of the comparison
+   *
+   * \returns A vector of comparisons between the two expressions.
+   */
+  std::vector<Comparison> DFSFromLHS(Key lhs_key, Key rhs_key) const;
+
+  /*! \brief Combine a set of comparisons that share a LHS and RHS
+   *
+   * \param lhs_to_rhs The comparisons to merge.  These should all
+   * have the same LHS and RHS.  This parameter will typically be the
+   * result from `CollectDirectComparisons` or
+   * `CollectIndirectComparisons`.
    *
-   * \return The result of the comparison
+   * \param offset The constant offset in the comparison being proven.
+   * This is extracted from any additive/subtractive constants in the
+   * `PrimExpr` arguments to `TryCompare`.
+   *
+   * \returns The possible comparisons between LHS and RHS, given the
+   * provided inequalities.
    */
-  CompareResult DFSFromLHS(Key lhs_key, Key rhs_key, int64_t offset, const 
PrimExpr& lhs,
-                           const PrimExpr& rhs) const;
+  CompareResult MergeComparisons(const std::vector<Comparison>& lhs_to_rhs, 
int64_t offset) const;
 
   /*! \brief Previous Range bindings
    *
@@ -475,8 +529,9 @@ bool 
TransitiveComparisonAnalyzer::Impl::Comparison::Implies(
 TransitiveComparisonAnalyzer::TransitiveComparisonAnalyzer() : 
impl_(std::make_unique<Impl>()) {}
 TransitiveComparisonAnalyzer::~TransitiveComparisonAnalyzer() {}
 
-CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs, 
const PrimExpr& rhs) {
-  return impl_->TryCompare(lhs, rhs);
+CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs, 
const PrimExpr& rhs,
+                                                       bool 
propagate_inequalities) {
+  return impl_->TryCompare(lhs, rhs, propagate_inequalities);
 }
 
 void TransitiveComparisonAnalyzer::Bind(const Var& var, const PrimExpr& expr, 
bool allow_override) {
@@ -547,7 +602,8 @@ std::function<void()> 
TransitiveComparisonAnalyzer::Impl::EnterConstraint(const
 }
 
 CompareResult TransitiveComparisonAnalyzer::Impl::TryCompare(const PrimExpr& 
lhs_expr,
-                                                             const PrimExpr& 
rhs_expr) const {
+                                                             const PrimExpr& 
rhs_expr,
+                                                             bool 
propagate_inequalities) const {
   // Currently only supports integer checks
   if (!lhs_expr.dtype().is_int() || !rhs_expr.dtype().is_int()) {
     return CompareResult::kUnknown;
@@ -575,29 +631,59 @@ CompareResult 
TransitiveComparisonAnalyzer::Impl::TryCompare(const PrimExpr& lhs
     return CompareResult::kUnknown;
   }
 
-  auto from_lhs = DFSFromLHS(lhs_key.value(), rhs_key.value(), offset, lhs, 
rhs);
-  auto from_rhs = Reverse(DFSFromLHS(rhs_key.value(), lhs_key.value(), 
-offset, rhs, lhs));
-  auto output = from_lhs & from_rhs;
+  auto lhs_to_rhs = [&]() {
+    if (propagate_inequalities) {
+      return CollectIndirectComparisons(lhs_key.value(), rhs_key.value());
+    } else {
+      return CollectDirectComparisons(lhs_key.value(), rhs_key.value());
+    }
+  }();
+  return MergeComparisons(lhs_to_rhs, offset);
+}
+
+std::vector<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::CollectDirectComparisons(Key lhs_key, Key 
rhs_key) const {
+  std::vector<Comparison> output;
+
+  auto append_known = [&](Comparison cmp) {
+    if (auto normalized = cmp.WithLHS(lhs_key)) {
+      if (normalized.value().rhs_ == rhs_key) {
+        output.push_back(normalized.value());
+      }
+    }
+  };
+
+  for (const auto& known : knowns_) {
+    append_known(known);
+  }
+  for (const auto& known : scoped_knowns_) {
+    append_known(known);
+  }
 
   return output;
 }
 
-CompareResult TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key 
lhs_key_input, Key rhs_key_input,
-                                                             int64_t 
offset_input,
-                                                             const PrimExpr& 
lhs_input,
-                                                             const PrimExpr& 
rhs_input) const {
-  Key lhs_key = lhs_key_input;
-  Key rhs_key = rhs_key_input;
-  int64_t offset = offset_input;
+std::vector<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::CollectIndirectComparisons(Key lhs_key, 
Key rhs_key) const {
+  auto output = DFSFromLHS(lhs_key, rhs_key);
+  for (Comparison cmp : DFSFromLHS(rhs_key, lhs_key)) {
+    auto opt_normalized = cmp.WithLHS(lhs_key);
+    ICHECK(opt_normalized.has_value());
+    output.push_back(opt_normalized.value());
+  }
+  return output;
+}
 
+std::vector<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key lhs_key, Key rhs_key) const 
{
   // Everything in `to_visit` has lhs as its lhs.
   std::unordered_set<Key> seen;
   std::unordered_set<Key> to_visit;
-  std::unordered_map<Key, std::vector<Comparison>> compared_to_x;
+  std::unordered_map<Key, std::vector<Comparison>> compared_to_lhs;
 
   // Utility function to add a new known statement
   auto declare_known = [&](Comparison cmp) {
-    std::vector<Comparison>& knowns = compared_to_x[cmp.rhs_];
+    std::vector<Comparison>& knowns = compared_to_lhs[cmp.rhs_];
 
     // The comparison adds no new information, no modification
     // required.
@@ -646,8 +732,8 @@ CompareResult 
TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key lhs_key_input,
     Key middle_key = *to_visit.begin();
     to_visit.erase(to_visit.begin());
 
-    std::vector<Comparison>& prev_knowns_using_middle = 
compared_to_x.at(middle_key);
-    ICHECK(compared_to_x.count(middle_key));
+    std::vector<Comparison>& prev_knowns_using_middle = 
compared_to_lhs.at(middle_key);
+    ICHECK(compared_to_lhs.count(middle_key));
 
     std::vector<Comparison> new_knowns_using_lhs;
 
@@ -721,27 +807,29 @@ CompareResult 
TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key lhs_key_input,
     }
   }
 
-  // It's possible that we don't have any transitive comparisons that
-  // can prove something about LHS and RHS.
-  auto it = compared_to_x.find(rhs_key);
-  if (it == compared_to_x.end()) {
-    return CompareResult::kUnknown;
+  if (auto it = compared_to_lhs.find(rhs_key); it != compared_to_lhs.end()) {
+    return it->second;
+  } else {
+    // There are known comparisons involving the LHS and the RHS, but
+    // no path that connects the two expressions.
+    return {};
   }
+}
 
-  const std::vector<Comparison>& known_between_lhs_and_rhs = it->second;
-
+CompareResult TransitiveComparisonAnalyzer::Impl::MergeComparisons(
+    const std::vector<Comparison>& lhs_to_rhs, int64_t offset) const {
   // Just because we found a comparison involving LHS and RHS doesn't
   // mean that it's useful.  e.g. Knowing that `x < y` doesn't let us
   // prove whether `x + 5 < y`.
   CompareResult result = CompareResult::kUnknown;
-  for (const auto& known : known_between_lhs_and_rhs) {
-    switch (known.result_) {
+  for (const auto& cmp : lhs_to_rhs) {
+    switch (cmp.result_) {
       case CompareResult::kInconsistent:
         result = CompareResult::kInconsistent;
         break;
 
       case CompareResult::kEQ:
-        if (offset == known.offset_) {
+        if (offset == cmp.offset_) {
           result = result & CompareResult::kEQ;
         } else {
           result = result & CompareResult::kNE;
@@ -749,23 +837,23 @@ CompareResult 
TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key lhs_key_input,
         break;
 
       case CompareResult::kLE:
-        if (known.offset_ < offset) {
+        if (cmp.offset_ < offset) {
           result = result & CompareResult::kLT;
-        } else if (known.offset_ <= offset) {
+        } else if (cmp.offset_ <= offset) {
           result = result & CompareResult::kLE;
         }
         break;
 
       case CompareResult::kGE:
-        if (known.offset_ > offset) {
+        if (cmp.offset_ > offset) {
           result = result & CompareResult::kGT;
-        } else if (known.offset_ >= offset) {
+        } else if (cmp.offset_ >= offset) {
           result = result & CompareResult::kGE;
         }
         break;
 
       case CompareResult::kNE:
-        if (offset == known.offset_) {
+        if (offset == cmp.offset_) {
           result = result & CompareResult::kNE;
         }
         break;
@@ -779,7 +867,7 @@ CompareResult 
TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key lhs_key_input,
         return CompareResult::kInconsistent;
 
       default:
-        LOG(FATAL) << "Invalid CompareResult: " << 
static_cast<int>(known.result_);
+        LOG(FATAL) << "Invalid CompareResult: " << 
static_cast<int>(cmp.result_);
         return CompareResult::kInconsistent;
     }
   }
diff --git a/tests/python/unittest/test_tir_transform_simplify.py 
b/tests/python/unittest/test_tir_transform_simplify.py
index 91ef60f9d3..4c5499edcf 100644
--- a/tests/python/unittest/test_tir_transform_simplify.py
+++ b/tests/python/unittest/test_tir_transform_simplify.py
@@ -989,5 +989,19 @@ class 
TestSimplifyLHSOfBooleanOrUsingRHSWithoutConst(BaseBeforeAfter):
         A[0] = n < m + 10
 
 
+class TestProvableConditionWithOffset(BaseBeforeAfter):
+    """Use scoped-constraint to prove inequalities"""
+
+    transitively_prove_inequalities = False
+
+    def before(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32):
+        if i < j:
+            A[0] = i < j + 1
+
+    def expected(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32):
+        if i < j:
+            A[0] = True
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to