Lunderberg commented on code in PR #12863: URL: https://github.com/apache/tvm/pull/12863#discussion_r989445868
########## src/arith/transitive_comparison_analyzer.cc: ########## @@ -0,0 +1,683 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/arith/transitive_comparison_analyzer.cc + */ + +#include <tvm/arith/analyzer.h> +#include <tvm/tir/analysis.h> +#include <tvm/tir/expr.h> + +#include <optional> +#include <vector> + +#include "constraint_extract.h" +#include "pattern_match.h" + +namespace tvm { +namespace arith { + +using namespace tir; + +class TransitiveComparisonAnalyzer::Impl { + public: + /*! \brief Using previously specified knowns, compare the expressions provided + * + * \param lhs The left-hand side of the comparison + * + * \param rhs The right-hand side of the comparison + * + * \return The most specific result that can be proven about the + * comparison. If nothing can be proven, returns kUnknown. + */ + CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const; + + /*! \brief Bind a variable as being equal to a known expression + * + * \param var The variable of interest. + * \param expr The bound expression + * \param allow_override Whether to allow override of existing information.
+ */ + void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false); + + /*! \brief Bind a variable as being within a specified range + * + * \param var The variable of interest. + * \param expr The known range + * \param allow_override Whether to allow override of existing information. + */ + void Bind(const tir::Var& var, const Range& expr, bool allow_override = false); + + /*! + * \brief Update the internal state to enter a constraint. + * \param expr A constraint expression. + * + * \return An exit function that must be called to clean up. May be + * `nullptr`, if no cleanup is required. + */ + std::function<void()> EnterConstraint(const PrimExpr& expr); + + private: + // Utility class to avoid needing to repeatedly call ExprDeepEqual + enum class Key : size_t {}; + std::optional<Key> ExprToPreviousKey(const PrimExpr& expr) const; + Key ExprToKey(const PrimExpr& expr); + std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key; + + /*! \brief Internal representation of a comparison operator */ + struct Comparison { + /*! \brief Construct a comparison that represents `lhs OP rhs + + * offset`, where the operation is specified by the CompareResult. + */ + Comparison(Key lhs, Key rhs, int64_t offset, CompareResult result); + + /*! \brief Utility function to validate that all GT and LT results + * have been normalized out + */ + bool IsNormalized() const; + + /*! \brief Move the specified expression to the LHS. + * + * \param new_lhs The argument that should be moved to the LHS of the + * comparison. + * + * \return If possible, returns a comparison that is equivalent to + * the current comparison, but with the specified LHS. If not + * possible, returns nullopt. + */ + std::optional<Comparison> WithLHS(Key new_lhs) const; + + /*! \brief Create the negation of the current comparison */ + Comparison Negated() const; + + /*!
\brief Check whether this comparison implies another + * + * Returns true if this comparison being true implies that the + * other comparison must also be true. Returns false if the other + * comparison cannot be shown to be true. + */ + bool Implies(const Comparison& other) const; + + // The LHS of the comparison + Key lhs_; + + // The RHS of the comparison, not including any constant offset. + Key rhs_; + + // Additive offset on rhs + int64_t offset_{0}; + + // The comparison operator. + CompareResult result_{CompareResult::kInconsistent}; + }; + + /*! \brief Generate a Comparison representing the given expression */ + std::optional<Comparison> FromExpr(const PrimExpr& expr); + + /*! \brief Utility function used by Bind and EnterConstraint + * + * \param expr The comparison expression, to be converted into + * internal Comparison objects. + * + * \param vec The vector to which the Comparison objects should be + * appended. + */ + void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec); + + /*! \brief Attempt to compare, starting at the lhs. + * + * Taking each available `Comparison` as a graph edge, search for a + * path from lhs to rhs. For example, the priors (a<=b), (b<=c+1) + * and (c<=d-5) can be used to prove that (a<=d-4). + * + * \param lhs The left-hand side of the comparison + * + * \param rhs The right-hand side of the comparison + * + * \return The result of the comparison + */ + CompareResult TryCompareFromLHS(Key lhs_key, Key rhs_key, int64_t offset, const PrimExpr& lhs, + const PrimExpr& rhs) const; + + /*! \brief Previous Range bindings + * + * Tracked separately to handle the `allow_override` option used by + * all sub-analyzers when binding variables. + */ + Map<Var, Range> prev_bindings_; + + /*! \brief Known comparisons based on definitionally-true statements + * + * For example, a Let binding, or the range of an iterator. + */ + std::vector<Comparison> knowns_; + + /*!
\brief Known comparisons based on scope-based statements + * + * For example, the condition of an IfThenElse, which is known to be + * true while within the if scope. + */ + std::vector<Comparison> scoped_knowns_; +}; + +namespace { + +// Internal utility, return the CompareResult resulting from swapping +// the left-hand side with the right-hand side. +CompareResult Reverse(CompareResult res) { + switch (res) { + case CompareResult::kInconsistent: + return CompareResult::kInconsistent; + case CompareResult::kEQ: + return CompareResult::kEQ; + case CompareResult::kLT: + return CompareResult::kGT; + case CompareResult::kLE: + return CompareResult::kGE; + case CompareResult::kGT: + return CompareResult::kLT; + case CompareResult::kGE: + return CompareResult::kLE; + case CompareResult::kNE: + return CompareResult::kNE; + case CompareResult::kUnknown: + return CompareResult::kUnknown; + default: + LOG(FATAL) << "Invalid CompareResult: " << static_cast<int>(res); + return CompareResult::kInconsistent; + } +} + +// Internal utility, return the CompareResult resulting from negating +// the comparison. +CompareResult Negate(CompareResult res) { + switch (res) { + case CompareResult::kInconsistent: + return CompareResult::kInconsistent; + case CompareResult::kUnknown: + return CompareResult::kUnknown; + default: + return CompareResult(~static_cast<int>(res) & static_cast<int>(CompareResult::kUnknown)); + } +} + +// Internal utility, extract constant offsets out of the two sides of +// a comparison. Given lhs and rhs, return a tuple of three elements +// (lhs_inner, rhs_inner, offset), such that (lhs OP rhs) and +// (lhs_inner OP rhs_inner + offset) are equivalent.
+std::tuple<PrimExpr, PrimExpr, int64_t> ExtractOffsets(const PrimExpr& lhs, const PrimExpr& rhs) { + auto extract_offset = [](const PrimExpr& expr) -> std::pair<PrimExpr, int64_t> { + PVar<PrimExpr> x; + PVar<IntImm> c; + if ((x + c).Match(expr)) { + return {x.Eval(), c.Eval()->value}; + } else if ((x - c).Match(expr)) { + return {x.Eval(), -c.Eval()->value}; + } else if (c.Match(expr)) { + return {0, c.Eval()->value}; + } else { + return {expr, 0}; + } + }; + + auto lhs_split = extract_offset(lhs); + auto rhs_split = extract_offset(rhs); + return {lhs_split.first, rhs_split.first, rhs_split.second - lhs_split.second}; +} + +} // namespace + +std::optional<TransitiveComparisonAnalyzer::Impl::Comparison> +TransitiveComparisonAnalyzer::Impl::FromExpr(const PrimExpr& expr) { + CompareResult res; + PVar<PrimExpr> x, y; + if ((x <= y).Match(expr)) { + res = CompareResult::kLE; + } else if ((x >= y).Match(expr)) { + res = CompareResult::kGE; + } else if ((x < y).Match(expr)) { + res = CompareResult::kLT; + } else if ((x > y).Match(expr)) { + res = CompareResult::kGT; + } else if ((x == y).Match(expr)) { + res = CompareResult::kEQ; + } else if ((x != y).Match(expr)) { + res = CompareResult::kNE; + } else { + return std::nullopt; + } + + PrimExpr lhs_expr = x.Eval(); + PrimExpr rhs_expr = y.Eval(); + + if (lhs_expr.as<IntImmNode>() && rhs_expr.as<IntImmNode>()) { + return std::nullopt; + } + + auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr); + Key lhs_key = ExprToKey(lhs); + Key rhs_key = ExprToKey(rhs); + + return Comparison(lhs_key, rhs_key, offset, res); +} + +TransitiveComparisonAnalyzer::Impl::Comparison::Comparison(Key lhs, Key rhs, int64_t offset, + CompareResult result) + : lhs_(lhs), rhs_(rhs), offset_(offset), result_(result) { + if (result_ == CompareResult::kLT) { + result_ = CompareResult::kLE; + offset_ -= 1; Review Comment: Sounds good, and added a description in the constructor, along with an explanation of why the internal representation 
normalizes everything to GE/LE, instead of normalizing to GT/LT. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
