Lunderberg commented on code in PR #12863: URL: https://github.com/apache/tvm/pull/12863#discussion_r989464193
########## src/arith/transitive_comparison_analyzer.cc: ########## @@ -0,0 +1,683 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/arith/transitive_comparison_analyzer.cc + */ + +#include <tvm/arith/analyzer.h> +#include <tvm/tir/analysis.h> +#include <tvm/tir/expr.h> + +#include <optional> +#include <vector> + +#include "constraint_extract.h" +#include "pattern_match.h" + +namespace tvm { +namespace arith { + +using namespace tir; + +class TransitiveComparisonAnalyzer::Impl { + public: + /* \brief Using previously specified knowns, compare the expressions provided + * + * \param lhs The left-hand side of the comparison + * + * \param rhs The right-hand side of the comparison + * + * \return The most specific result that can be proven about the + * comparison. If nothing can be proven, returns kUnknown. + */ + CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const; + + /*! \brief Bind a variable as being equal to a known expression + * + * \param var The variable of interest. + * \param expr The bound expression + * \param allow_override Whether to allow override of existing information. + */ + void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false); + + /*! \brief Bind a variable as being within a specified range + * + * \param var The variable of interest. + * \param range The known range + * \param allow_override Whether to allow override of existing information. + */ + void Bind(const tir::Var& var, const Range& expr, bool allow_override = false); + + /*! + * \brief Update the internal state to enter constraint. + * \param constraint A constraint expression. + * + * \return An exit function that must be called to cleanup. May be + * `nullptr`, if no cleanup is required. + */ + std::function<void()> EnterConstraint(const PrimExpr& expr); + + private: + // Utility class to avoid needing to repeatedly call ExprDeepEqual + enum class Key : size_t {}; + std::optional<Key> ExprToPreviousKey(const PrimExpr& expr) const; + Key ExprToKey(const PrimExpr& expr); + std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key; + + /*! \brief Internal representation of a comparison operator */ + struct Comparison { + /*! \brief Construct a comparison that represents `lhs OP rhs + + * offset`, where the operation is specified by the CompareResult. + */ + Comparison(Key lhs, Key rhs, int64_t offset, CompareResult result); + + /*! \brief Utility function to validate that all GT and LT results + * have been normalized out + */ + bool IsNormalized() const; + + /*! \brief Move the specified expression to the LHS. + * + * \param new_lhs The argument that should be moved to the LHS of the + * comparison. + * + * \return If possible, returns a comparison that is equivalent to + * the current comparison, but with the specified LHS. If not + * possible, returns nullopt. + */ + std::optional<Comparison> WithLHS(Key new_lhs) const; + + /*! \brief Create the negation of the current comparison */ + Comparison Negated() const; + + /*! \brief Check the this comparison implies + * + * Returns true if this comparison being true implies that the + * other comparison must also be true. Returns false if the other + * comparison cannot be shown to be true. + */ + bool Implies(const Comparison& other) const; + + // The LHS of the comparison + Key lhs_; + + // The RHS of the comparison, not including any constant offset. + Key rhs_; + + // Additive offset on rhs + int64_t offset_{0}; + + // The comparison operator. + CompareResult result_{CompareResult::kInconsistent}; + }; + + /*! \brief Generate a Comparison representing the given expression */ + std::optional<Comparison> FromExpr(const PrimExpr& expr); + + /*! \brief Utility function used by Bind and EnterConstraint + * + * \param expr The comparison expression, to be converted into + * internal Comparison objects. + * + * \param vec The vector to which the Comparison objects should be + * appended. + */ + void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec); + + /*! \brief Attempt to compare, starting at the lhs. + * + * Taking each available `Comparison` as a node edge, search for a + * path from lhs to rhs. For example, the priors (a<=b), (b<=c+1) + * and (c<=d-5) can be used to prove that (a<=d-4). + * + * \param lhs The left-hand side of the comparison + * + * \param rhs The right-hand side of the comparison + * + * \return The result of the comparison + */ + CompareResult TryCompareFromLHS(Key lhs_key, Key rhs_key, int64_t offset, const PrimExpr& lhs, + const PrimExpr& rhs) const; + + /*! \brief Previous Range bindings + * + * Tracked separatedly to handle the `allow_override` option used by + * all sub-analyzers when binding variables. + */ + Map<Var, Range> prev_bindings_; + + /*! \brief Known comparisons based on definitionally-true statements + * + * For example, a Let binding, or the range of an iterator. + */ + std::vector<Comparison> knowns_; + + /*! \brief Known comparisons based on of scope-based statements + * + * For example, the condition of an IfThenElse, which is known to be + * true while within the if scope. + */ + std::vector<Comparison> scoped_knowns_; +}; + +namespace { + +// Internal utility, return the CompareResult resulting from swapping +// the left-hand side with the right-hand side. +CompareResult Reverse(CompareResult res) { + switch (res) { + case CompareResult::kInconsistent: + return CompareResult::kInconsistent; + case CompareResult::kEQ: + return CompareResult::kEQ; + case CompareResult::kLT: + return CompareResult::kGT; + case CompareResult::kLE: + return CompareResult::kGE; + case CompareResult::kGT: + return CompareResult::kLT; + case CompareResult::kGE: + return CompareResult::kLE; + case CompareResult::kNE: + return CompareResult::kNE; + case CompareResult::kUnknown: + return CompareResult::kUnknown; + default: + LOG(FATAL) << "Invalid CompareResult: " << static_cast<int>(res); + return CompareResult::kInconsistent; + } +} + +// Internal utility, return the CompareResult resulting from negating +// the comparison. +CompareResult Negate(CompareResult res) { + switch (res) { + case CompareResult::kInconsistent: + return CompareResult::kInconsistent; + case CompareResult::kUnknown: + return CompareResult::kUnknown; + default: + return CompareResult(~static_cast<int>(res) & static_cast<int>(CompareResult::kUnknown)); + } +} + +// Internal utility, extract constant offsets out of the two sides of +// a comparison. Given lhs and rhs, return a tuple of three elements +// (lhs_inner, rhs_inner, offset), such that (lhs OP rhs) and +// (lhs_inner OP rhs_inner + offset) are equivalent. +std::tuple<PrimExpr, PrimExpr, int64_t> ExtractOffsets(const PrimExpr& lhs, const PrimExpr& rhs) { + auto extract_offset = [](const PrimExpr& expr) -> std::pair<PrimExpr, int64_t> { + PVar<PrimExpr> x; + PVar<IntImm> c; + if ((x + c).Match(expr)) { + return {x.Eval(), c.Eval()->value}; + } else if ((x - c).Match(expr)) { + return {x.Eval(), -c.Eval()->value}; + } else if (c.Match(expr)) { + return {0, c.Eval()->value}; + } else { + return {expr, 0}; + } + }; + + auto lhs_split = extract_offset(lhs); + auto rhs_split = extract_offset(rhs); + return {lhs_split.first, rhs_split.first, rhs_split.second - lhs_split.second}; +} + +} // namespace + +std::optional<TransitiveComparisonAnalyzer::Impl::Comparison> +TransitiveComparisonAnalyzer::Impl::FromExpr(const PrimExpr& expr) { + CompareResult res; + PVar<PrimExpr> x, y; + if ((x <= y).Match(expr)) { + res = CompareResult::kLE; + } else if ((x >= y).Match(expr)) { + res = CompareResult::kGE; + } else if ((x < y).Match(expr)) { + res = CompareResult::kLT; + } else if ((x > y).Match(expr)) { + res = CompareResult::kGT; + } else if ((x == y).Match(expr)) { + res = CompareResult::kEQ; + } else if ((x != y).Match(expr)) { + res = CompareResult::kNE; + } else { + return std::nullopt; + } + + PrimExpr lhs_expr = x.Eval(); + PrimExpr rhs_expr = y.Eval(); + + if (lhs_expr.as<IntImmNode>() && rhs_expr.as<IntImmNode>()) { + return std::nullopt; + } + + auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr); + Key lhs_key = ExprToKey(lhs); + Key rhs_key = ExprToKey(rhs); + + return Comparison(lhs_key, rhs_key, offset, res); +} + +TransitiveComparisonAnalyzer::Impl::Comparison::Comparison(Key lhs, Key rhs, int64_t offset, + CompareResult result) + : lhs_(lhs), rhs_(rhs), offset_(offset), result_(result) { + if (result_ == CompareResult::kLT) { + result_ = CompareResult::kLE; + offset_ -= 1; + } + if (result_ == CompareResult::kGT) { + result_ = CompareResult::kGE; + offset_ += 1; + } +} + +std::optional<TransitiveComparisonAnalyzer::Impl::Key> +TransitiveComparisonAnalyzer::Impl::ExprToPreviousKey(const PrimExpr& expr) const { + auto it = expr_to_key.find(expr); + if (it != expr_to_key.end()) { + return it->second; + } else { + return std::nullopt; + } +} + +TransitiveComparisonAnalyzer::Impl::Key TransitiveComparisonAnalyzer::Impl::ExprToKey( + const PrimExpr& expr) { + if (auto prev = ExprToPreviousKey(expr)) { + return prev.value(); + } else { + Key new_key = Key(expr_to_key.size()); + expr_to_key[expr] = new_key; + return new_key; + } +} + +bool TransitiveComparisonAnalyzer::Impl::Comparison::IsNormalized() const { + // These < and > should be removed during normalization. + return result_ != CompareResult::kLT && result_ != CompareResult::kGT; +} + +std::optional<TransitiveComparisonAnalyzer::Impl::Comparison> +TransitiveComparisonAnalyzer::Impl::Comparison::WithLHS(Key new_lhs) const { + if (new_lhs == lhs_) { + return *this; + } else if (new_lhs == rhs_) { + return Comparison(rhs_, lhs_, -offset_, Reverse(result_)); + } else { + return std::nullopt; + } +} + +TransitiveComparisonAnalyzer::Impl::Comparison +TransitiveComparisonAnalyzer::Impl::Comparison::Negated() const { + return Comparison(lhs_, rhs_, offset_, Negate(result_)); +} + +bool TransitiveComparisonAnalyzer::Impl::Comparison::Implies( + const TransitiveComparisonAnalyzer::Impl::Comparison& other) const { + ICHECK(lhs_ == other.lhs_); + ICHECK(rhs_ == other.rhs_); + ICHECK(IsNormalized()); + ICHECK(other.IsNormalized()); + + if (result_ == other.result_ && offset_ == other.offset_) { + // if c1 == c2, x != y + c1 => x != y + c2 + // if c1 == c2, x == y + c1 => x == y + c2 + return true; + } + + if (other.result_ == CompareResult::kLE && offset_ <= other.offset_) { + if (result_ == CompareResult::kEQ || result_ == CompareResult::kLE) { + // if c1 <= c2, x <= y + c1 => x <= y + c2 + // if c1 <= c2, x == y + c1 => x <= y + c2 + return true; + } + } + + if (other.result_ == CompareResult::kGE && offset_ >= other.offset_) { + if (result_ == CompareResult::kEQ || result_ == CompareResult::kGE) { + // if c1 >= c2, x == y + c1 => x >= y + c2 + // if c1 >= c2, x >= y + c1 => x >= y + c2 + return true; + } + } + + if (other.result_ == CompareResult::kNE) { + if (result_ == CompareResult::kEQ && offset_ != other.offset_) { + // if c1 != c2, x == y + c1 => x != y + c2 + return true; + } + + if (result_ == CompareResult::kLE && offset_ < other.offset_) { + // if c1 < c2, x <= y + c1 => x < y + c2 => x != y + c2 + return true; + } + + if (result_ == CompareResult::kGE && offset_ > other.offset_) { + // if c1 != c2, x >= y + c1 => x > y + c2 => x != y + c2 + return true; + } + } + + return false; +} + +TransitiveComparisonAnalyzer::TransitiveComparisonAnalyzer() : impl_(std::make_unique<Impl>()) {} +TransitiveComparisonAnalyzer::~TransitiveComparisonAnalyzer() {} + +CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) { + return impl_->TryCompare(lhs, rhs); +} + +void TransitiveComparisonAnalyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { + impl_->Bind(var, expr, allow_override); +} +void TransitiveComparisonAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) { + impl_->Bind(var, range, allow_override); +} + +std::function<void()> TransitiveComparisonAnalyzer::EnterConstraint(const PrimExpr& constraint) { + return impl_->EnterConstraint(constraint); +} + +void TransitiveComparisonAnalyzer::Impl::AddKnown(const PrimExpr& expr, + std::vector<Comparison>* vec) { + for (const auto& subexpr : ExtractConstraints(expr)) { + if (tir::SideEffect(expr) <= tir::CallEffectKind::kPure) { + if (auto cmp = FromExpr(subexpr)) { + vec->push_back(cmp.value()); + } + } + } +} + +void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const Range& range, + bool allow_override) { + auto it = prev_bindings_.find(var); + if (it != prev_bindings_.end()) { + ExprDeepEqual expr_equal; + bool differs_from_previous = !expr_equal(range->min, (*it).second->min) || + !expr_equal(range->extent, (*it).second->extent); + if (differs_from_previous) { + ICHECK(allow_override) << "Binding of variable " << var << " as " << range + << " conflicts with previous binding as " << (*it).second; + if (auto key = ExprToPreviousKey(var)) { + knowns_.erase(std::remove_if(knowns_.begin(), knowns_.end(), + [&](const auto& known) { return known.lhs_ == key.value(); }), + knowns_.end()); + } + } + } + + prev_bindings_.Set(var, range); + + if (is_const_int(range->extent, 1)) { + AddKnown(var == range->min, &knowns_); + } else { + AddKnown(var >= range->min, &knowns_); + AddKnown(var < range->min + range->extent, &knowns_); + } +} + +void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const PrimExpr& expr, + bool allow_override) { + Bind(var, Range::FromMinExtent(expr, 1), allow_override); +} + +std::function<void()> TransitiveComparisonAnalyzer::Impl::EnterConstraint(const PrimExpr& expr) { + size_t old_literal_size = scoped_knowns_.size(); + AddKnown(expr, &scoped_knowns_); + size_t new_literal_size = scoped_knowns_.size(); + + PrimExpr temp = expr; + auto frecover = [old_literal_size, new_literal_size, this, temp]() { + ICHECK_EQ(scoped_knowns_.size(), new_literal_size); + scoped_knowns_.erase(scoped_knowns_.begin() + old_literal_size, scoped_knowns_.end()); + }; + return frecover; +} + +CompareResult TransitiveComparisonAnalyzer::Impl::TryCompare(const PrimExpr& lhs_expr, + const PrimExpr& rhs_expr) const { + // Currently only supports integer checks + if (!lhs_expr.dtype().is_int() || !rhs_expr.dtype().is_int()) { + return CompareResult::kUnknown; + } + + // Bail out early if possible. This int check should have been + // constant-folded earlier, so this check shouldn't occur. + auto* x_int = lhs_expr.as<IntImmNode>(); + auto* y_int = rhs_expr.as<IntImmNode>(); + if (x_int && y_int) { + if (x_int->value < y_int->value) { + return CompareResult::kLT; + } else if (x_int->value > y_int->value) { + return CompareResult::kGT; + } else { + return CompareResult::kEQ; + } + } + + auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr); + auto lhs_key = ExprToPreviousKey(lhs); + auto rhs_key = ExprToPreviousKey(rhs); + + if (!lhs_key.has_value() || !rhs_key.has_value()) { + return CompareResult::kUnknown; + } + + auto from_lhs = TryCompareFromLHS(lhs_key.value(), rhs_key.value(), offset, lhs, rhs); + auto from_rhs = Reverse(TryCompareFromLHS(rhs_key.value(), lhs_key.value(), -offset, rhs, lhs)); + auto output = from_lhs & from_rhs; + + return output; +} + +CompareResult TransitiveComparisonAnalyzer::Impl::TryCompareFromLHS( Review Comment: Hmm, good point. Renamed to `DFSFromLHS`, which hopefully works with the updated documentation to reduce the verbosity. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
