yzhliu commented on a change in pull request #5618: URL: https://github.com/apache/incubator-tvm/pull/5618#discussion_r448630687
########## File path: src/arith/solve_linear_inequality.cc ########## @@ -0,0 +1,648 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/arith/solve_linear_inequality.cc + * \brief Solve linear inequalities. + */ +#include <tvm/arith/analyzer.h> +#include <tvm/arith/int_solver.h> +#include <tvm/arith/pattern.h> +#include <tvm/runtime/data_type.h> +#include <tvm/runtime/registry.h> +#include <tvm/tir/analysis.h> +#include <tvm/tir/expr.h> +#include <tvm/tir/op.h> +#include <tvm/tir/stmt_functor.h> + +#include "int_operator.h" + +namespace tvm { +namespace arith { + +using namespace tvm::runtime; +using namespace tvm::tir; + +// Count one symbol for a node of type OP without visiting any of its children. +#define PLUS_ONE(OP) \ + void VisitExpr_(const OP* op) final { num_symbols_++; } + +// Count one symbol for a binary node of type OP, then visit both operands `a` and `b`. +#define PLUS_ONE_BINARY(OP) \ + void VisitExpr_(const OP* op) final { \ + num_symbols_++; \ + VisitExpr(op->a); \ + VisitExpr(op->b); \ + } + +/*! + * \brief Calculate the expression complexity as the number of symbols it contains.
+ * \note Only the node types handled below add to the count; any other node contributes nothing + * itself (NOTE(review): presumably its children are still visited by the ExprVisitor defaults; + * confirm). Eval() accumulates into a member counter that is never reset, so use a fresh + * instance per expression, as ExprLess does. + */ +class ExprComplexity : public ExprVisitor { + public: + size_t Eval(const PrimExpr& expr) { + VisitExpr(expr); + return num_symbols_; + } + + PLUS_ONE_BINARY(AddNode) + PLUS_ONE_BINARY(SubNode) + PLUS_ONE_BINARY(MulNode) + PLUS_ONE_BINARY(DivNode) + PLUS_ONE_BINARY(ModNode) + PLUS_ONE_BINARY(FloorDivNode) + PLUS_ONE_BINARY(FloorModNode) + PLUS_ONE_BINARY(MinNode) + PLUS_ONE_BINARY(MaxNode) + PLUS_ONE_BINARY(EQNode) + PLUS_ONE_BINARY(NENode) + PLUS_ONE_BINARY(LTNode) + PLUS_ONE_BINARY(LENode) + PLUS_ONE_BINARY(GTNode) + PLUS_ONE_BINARY(GENode) + PLUS_ONE_BINARY(AndNode) + PLUS_ONE_BINARY(OrNode) + PLUS_ONE(VarNode) + PLUS_ONE(FloatImmNode) + PLUS_ONE(IntImmNode) + void VisitExpr_(const NotNode* op) final { + num_symbols_++; + VisitExpr(op->a); + } + + private: + size_t num_symbols_{0}; +}; + +/*! + * \brief Order expressions by ExprComplexity: an expression with fewer symbols compares as smaller. + */ +struct ExprLess { + bool operator()(const PrimExpr& l, const PrimExpr& r) const { + return ExprComplexity().Eval(l) < ExprComplexity().Eval(r); + } +}; + +/*! + * \brief Combine the information into an array of (in)equalities. + * + * For each variable v, in the order of `variables`, emits `coef * v == e`, `coef * v >= e` and + * `coef * v <= e` for every e in the equal, lower and upper bounds of v respectively, then + * appends every entry of `relations` unchanged. + * \param variables The variables; their order fixes the order of the output. + * \param bounds The bounds of each variable; must hold exactly one entry per variable (CHECKed). + * \param relations Extra (in)equalities appended after the bound conditions. + * \return The combined list of conditions. + */ +Array<PrimExpr> as_conditions(const Array<Var>& variables, const Map<Var, IntGrpBounds>& bounds, + const Array<PrimExpr>& relations) { + Array<PrimExpr> res; + // use variables to keep the order of iteration + // so as to get rid of any non-determinism.
+ CHECK_EQ(variables.size(), bounds.size()); + for (const auto v : variables) { + CHECK(bounds.count(v)); + const auto& bnds = bounds[v]; + PrimExpr lhs = bnds->coef * v; + for (const PrimExpr& rhs : bnds->equal) { + res.push_back(tir::EQ(lhs, rhs)); + } + for (const PrimExpr& rhs : bnds->lower) { + res.push_back(tir::GE(lhs, rhs)); + } + for (const PrimExpr& rhs : bnds->upper) { + res.push_back(tir::LE(lhs, rhs)); + } + } + for (const PrimExpr& e : relations) { + res.push_back(e); + } + return res; +} + +/*! + * \brief Print the current and next inequality sets and the positive/negative coefficient + * lists to std::cout, for debugging. + * \note `rest` is accepted but not printed. + */ +void DebugPrint( + const std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>& current_ineq_set, + const std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>& next_ineq_set, + const std::vector<PrimExpr>& rest, const std::vector<std::pair<int64_t, PrimExpr>>& coef_pos, + const std::vector<std::pair<int64_t, PrimExpr>>& coef_neg) { + std::cout << "Current ineq set:\n["; + for (auto& ineq : current_ineq_set) { + std::cout << ineq << ", "; + } + std::cout << "]\n"; + + std::cout << "Next ineq set:\n["; + for (auto& ineq : next_ineq_set) { + std::cout << ineq << ", "; + } + std::cout << "]\n"; + + std::cout << "coef_pos:\n["; + for (auto& coef : coef_pos) { + std::cout << "(" << coef.first << ", " << coef.second << "), "; + } + std::cout << "]\n"; + + std::cout << "coef_neg:\n["; + for (auto& coef : coef_neg) { + std::cout << "(" << coef.first << ", " << coef.second << "), "; + } + std::cout << "]\n"; +} + +/*!
+ * \brief Normalize comparisons so that the right-hand side is zero: `a op b` becomes + * `Simplify(a - b) op 0`. `a > b` and `a >= b` are first flipped to `b < a` and `b <= a`; for + * int/uint operands `a < b` is further rewritten as `a - b + 1 <= 0`. EQ and NE keep their + * operator, and `<` on non-integer operands stays `<`. + */ +class NormalizeComparisons : public ExprMutator { + public: + PrimExpr VisitExpr_(const EQNode* op) override { return Make<EQ>(op->a, op->b); } + PrimExpr VisitExpr_(const NENode* op) override { return Make<NE>(op->a, op->b); } + PrimExpr VisitExpr_(const LTNode* op) override { return Make<LT>(op->a, op->b); } + PrimExpr VisitExpr_(const LENode* op) override { return Make<LE>(op->a, op->b); } + PrimExpr VisitExpr_(const GTNode* op) override { return Make<LT>(op->b, op->a); } + PrimExpr VisitExpr_(const GENode* op) override { return Make<LE>(op->b, op->a); } + + private: + template <class T> + PrimExpr Make(const PrimExpr& a, const PrimExpr& b) { + // rewrite LT to LE for ints: over the integers a < b <=> a - b + 1 <= 0. + // NOTE(review): for uint dtypes `a - b + 1` may wrap around; confirm this rewrite is intended. + if (std::is_same<T, LT>::value && (a.dtype().is_int() || a.dtype().is_uint())) { + return LE(analyzer_.Simplify(a - b + 1), make_zero(a.dtype())); + } + return T(analyzer_.Simplify(a - b), make_zero(a.dtype())); + } + arith::Analyzer analyzer_; +}; + +void AddInequality(std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* inequality_set, + const PrimExpr& new_ineq, Analyzer* analyzer) { + if (analyzer->CanProve(new_ineq) || inequality_set->find(new_ineq) != inequality_set->end()) { + // redundant: follows from the vranges + // or has already been added + return; + } + for (auto iter = inequality_set->begin(); iter != inequality_set->end();) { + if (const LENode* new_le = new_ineq.as<LENode>()) { Review comment: nice catch. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: [email protected]
