yzhliu commented on a change in pull request #5618:
URL: https://github.com/apache/incubator-tvm/pull/5618#discussion_r441203304



##########
File path: src/arith/solve_linear_inequality.cc
##########
@@ -0,0 +1,629 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/arith/solve_linear_inequality.cc
+ * \brief Solve linear inequalities.
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/arith/pattern.h>
+#include <tvm/runtime/data_type.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "int_operator.h"
+
+namespace tvm {
+namespace arith {
+
+using namespace tvm::runtime;
+using namespace tvm::tir;
+/*!
+ * \brief Generators for the VisitExpr_ overloads used by ExprComplexity.
+ *
+ * PLUS_ONE(OP) emits an overload for node type OP that only increments
+ * `num_symbols_`; it does not look at any operand.
+ * PLUS_ONE_BINARY(OP) emits an overload that increments `num_symbols_` and
+ * then visits the operands `op->a` and `op->b`.
+ *
+ * Both expansions refer to `num_symbols_` (and, for the binary form,
+ * `VisitExpr`), so they only make sense inside a class providing those names.
+ */
+#define PLUS_ONE(OP) \
+  void VisitExpr_(const OP* op) final { num_symbols_++; }
+
+#define PLUS_ONE_BINARY(OP)             \
+  void VisitExpr_(const OP* op) final { \
+    num_symbols_++;                     \
+    VisitExpr(op->a);                   \
+    VisitExpr(op->b);                   \
+  }
+
+/*!
+ * \brief Calculate the expression complexity based on the number of symbols it
+ *  contains.
+ *
+ * Each node type handled below adds one to the count. Binary operators and
+ * Not additionally visit their operands; the Var, FloatImm and IntImm
+ * overloads do not recurse. Node types without an overload here are left to
+ * the ExprVisitor base class (its behaviour is not visible in this file).
+ *
+ * \note `num_symbols_` starts at 0 and Eval() never resets it, so calling
+ *  Eval() twice on the same instance returns a running total.
+ */
+class ExprComplexity : public ExprVisitor {
+ public:
+  size_t Eval(const PrimExpr& expr) {
+    VisitExpr(expr);
+    return num_symbols_;
+  }
+
+  PLUS_ONE_BINARY(AddNode)
+  PLUS_ONE_BINARY(SubNode)
+  PLUS_ONE_BINARY(MulNode)
+  PLUS_ONE_BINARY(DivNode)
+  PLUS_ONE_BINARY(ModNode)
+  PLUS_ONE_BINARY(FloorDivNode)
+  PLUS_ONE_BINARY(FloorModNode)
+  PLUS_ONE_BINARY(MinNode)
+  PLUS_ONE_BINARY(MaxNode)
+  PLUS_ONE_BINARY(EQNode)
+  PLUS_ONE_BINARY(NENode)
+  PLUS_ONE_BINARY(LTNode)
+  PLUS_ONE_BINARY(LENode)
+  PLUS_ONE_BINARY(GTNode)
+  PLUS_ONE_BINARY(GENode)
+  PLUS_ONE_BINARY(AndNode)
+  PLUS_ONE_BINARY(OrNode)
+  PLUS_ONE(VarNode)
+  PLUS_ONE(FloatImmNode)
+  PLUS_ONE(IntImmNode)
+  void VisitExpr_(const NotNode* op) final {
+    num_symbols_++;
+    VisitExpr(op->a);
+  }
+
+ private:
+  size_t num_symbols_{0};
+};
+
+struct ExprLess {
+  bool operator()(const PrimExpr& l, const PrimExpr& r) const {
+    return ExprComplexity().Eval(l) < ExprComplexity().Eval(r);

Review comment:
       I switched to `unordered_set` instead. The only place I still use 
`ExprLess` is to pick the simplest form of an equation 
(`std::sort(equal_list.begin(), equal_list.end(), ExprLess());`), so I think 
it is OK to keep it this simple.

##########
File path: src/arith/solve_linear_inequality.cc
##########
@@ -0,0 +1,629 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/arith/solve_linear_inequality.cc
+ * \brief Solve linear inequalities.
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/arith/pattern.h>
+#include <tvm/runtime/data_type.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "int_operator.h"
+
+namespace tvm {
+namespace arith {
+
+using namespace tvm::runtime;
+using namespace tvm::tir;
+/*!
+ * \brief Generators for the VisitExpr_ overloads used by ExprComplexity.
+ *
+ * PLUS_ONE(OP) emits an overload for node type OP that only increments
+ * `num_symbols_`; it does not look at any operand.
+ * PLUS_ONE_BINARY(OP) emits an overload that increments `num_symbols_` and
+ * then visits the operands `op->a` and `op->b`.
+ *
+ * Both expansions refer to `num_symbols_` (and, for the binary form,
+ * `VisitExpr`), so they only make sense inside a class providing those names.
+ */
+#define PLUS_ONE(OP) \
+  void VisitExpr_(const OP* op) final { num_symbols_++; }
+
+#define PLUS_ONE_BINARY(OP)             \
+  void VisitExpr_(const OP* op) final { \
+    num_symbols_++;                     \
+    VisitExpr(op->a);                   \
+    VisitExpr(op->b);                   \
+  }
+
+/*!
+ * \brief Calculate the expression complexity based on the number of symbols it
+ *  contains.
+ *
+ * Each node type handled below adds one to the count. Binary operators and
+ * Not additionally visit their operands; the Var, FloatImm and IntImm
+ * overloads do not recurse. Node types without an overload here are left to
+ * the ExprVisitor base class (its behaviour is not visible in this file).
+ *
+ * \note `num_symbols_` starts at 0 and Eval() never resets it, so calling
+ *  Eval() twice on the same instance returns a running total.
+ */
+class ExprComplexity : public ExprVisitor {
+ public:
+  size_t Eval(const PrimExpr& expr) {
+    VisitExpr(expr);
+    return num_symbols_;
+  }
+
+  PLUS_ONE_BINARY(AddNode)
+  PLUS_ONE_BINARY(SubNode)
+  PLUS_ONE_BINARY(MulNode)
+  PLUS_ONE_BINARY(DivNode)
+  PLUS_ONE_BINARY(ModNode)
+  PLUS_ONE_BINARY(FloorDivNode)
+  PLUS_ONE_BINARY(FloorModNode)
+  PLUS_ONE_BINARY(MinNode)
+  PLUS_ONE_BINARY(MaxNode)
+  PLUS_ONE_BINARY(EQNode)
+  PLUS_ONE_BINARY(NENode)
+  PLUS_ONE_BINARY(LTNode)
+  PLUS_ONE_BINARY(LENode)
+  PLUS_ONE_BINARY(GTNode)
+  PLUS_ONE_BINARY(GENode)
+  PLUS_ONE_BINARY(AndNode)
+  PLUS_ONE_BINARY(OrNode)
+  PLUS_ONE(VarNode)
+  PLUS_ONE(FloatImmNode)
+  PLUS_ONE(IntImmNode)
+  void VisitExpr_(const NotNode* op) final {
+    num_symbols_++;
+    VisitExpr(op->a);
+  }
+
+ private:
+  size_t num_symbols_{0};
+};
+/*!
+ * \brief Comparator ordering expressions by their ExprComplexity symbol count.
+ *
+ * Returns true when `l` has strictly fewer counted symbols than `r`. A fresh
+ * ExprComplexity is constructed for each operand on every call, so both
+ * expressions are traversed per comparison.
+ */
+struct ExprLess {
+  bool operator()(const PrimExpr& l, const PrimExpr& r) const {
+    return ExprComplexity().Eval(l) < ExprComplexity().Eval(r);
+  }
+};
+
+/*!
+ * \brief Combine the information into an array of (in)equalities.
+ *
+ * \param bounds Per-variable bounds. For every entry (v, b) the left-hand side
+ *  is `b->coef * v`, and one condition is emitted per element of `b->equal`
+ *  (lhs == rhs), `b->lower` (lhs >= rhs) and `b->upper` (lhs <= rhs), in that
+ *  order.
+ * \param relations Extra conditions appended unchanged after the bound
+ *  conditions.
+ * \return The generated conditions followed by the elements of `relations`.
+ */
+Array<PrimExpr> as_conditions(const Map<Var, IntGrpBounds>& bounds,
+                              const Array<PrimExpr>& relations) {
+  Array<PrimExpr> res;
+  for (const auto iter : bounds) {
+    const Var& v = iter.first;
+    const auto& bnds = iter.second;
+    PrimExpr lhs = bnds->coef * v;
+    for (const PrimExpr& rhs : bnds->equal) {
+      res.push_back(tir::EQNode::make(lhs, rhs));
+    }
+    for (const PrimExpr& rhs : bnds->lower) {
+      res.push_back(tir::GENode::make(lhs, rhs));
+    }
+    for (const PrimExpr& rhs : bnds->upper) {
+      res.push_back(tir::LENode::make(lhs, rhs));
+    }
+  }
+  for (const PrimExpr& e : relations) {
+    res.push_back(e);
+  }
+  return res;
+}
+/*!
+ * \brief Dump the given inequality sets and coefficient lists to std::cout.
+ *
+ * Prints, each as a bracketed comma-separated list, the contents of
+ * `current_ineq_set`, `next_ineq_set`, `coef_pos` and `coef_neg`; the pairs in
+ * the last two are printed as "(first, second)".
+ *
+ * \note `rest` is accepted but never printed.
+ */
+void DebugPrint(
+    const std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>& 
current_ineq_set,
+    const std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>& 
next_ineq_set,
+    const std::vector<PrimExpr>& rest, const std::vector<std::pair<int64_t, 
PrimExpr>>& coef_pos,
+    const std::vector<std::pair<int64_t, PrimExpr>>& coef_neg) {
+  std::cout << "Current ineq set:\n[";
+  for (auto& ineq : current_ineq_set) {
+    std::cout << ineq << ", ";
+  }
+  std::cout << "]\n";
+
+  std::cout << "Next ineq set:\n[";
+  for (auto& ineq : next_ineq_set) {
+    std::cout << ineq << ", ";
+  }
+  std::cout << "]\n";
+
+  std::cout << "coef_pos:\n[";
+  for (auto& coef : coef_pos) {
+    std::cout << "(" << coef.first << ", " << coef.second << "), ";
+  }
+  std::cout << "]\n";
+
+  std::cout << "coef_neg:\n[";
+  for (auto& coef : coef_neg) {
+    std::cout << "(" << coef.first << ", " << coef.second << "), ";
+  }
+  std::cout << "]\n";
+}
+
+/*!
+ * \brief normalize to the form `expr <= 0`
+ *
+ * Every comparison `a OP b` is rewritten so that its right-hand side is zero
+ * and its left-hand side is the simplified difference of the operands:
+ *  - EQ, NE and LE keep their operator: `simplify(a - b) OP 0`.
+ *  - LT on int/uint operands becomes `simplify(a - b + 1) <= 0`; for other
+ *    dtypes it stays `simplify(a - b) < 0`.
+ *  - GT and GE are handled as LT and LE with the operands swapped.
+ * The overrides build the result from op->a and op->b directly; they do not
+ * call the base-class visitor on the operands.
+ *
+ * \note EQ and NE results remain EQ/NE nodes and non-integer LT remains LT,
+ *  so not every result is literally of the form `expr <= 0`.
+ */
+class NormalizeComparisons : public ExprMutator {
+ public:
+  PrimExpr VisitExpr_(const EQNode* op) override { return Make<EQNode>(op->a, 
op->b); }
+  PrimExpr VisitExpr_(const NENode* op) override { return Make<NENode>(op->a, 
op->b); }
+  PrimExpr VisitExpr_(const LTNode* op) override { return Make<LTNode>(op->a, 
op->b); }
+  PrimExpr VisitExpr_(const LENode* op) override { return Make<LENode>(op->a, 
op->b); }
+  PrimExpr VisitExpr_(const GTNode* op) override { return Make<LTNode>(op->b, 
op->a); }
+  PrimExpr VisitExpr_(const GENode* op) override { return Make<LENode>(op->b, 
op->a); }
+
+ private:
+  template <class TNode>
+  PrimExpr Make(const PrimExpr& a, const PrimExpr& b) {
+    // rewrite LT to LE for ints
+    if (std::is_same<TNode, LTNode>::value && (a.dtype().is_int() || 
a.dtype().is_uint())) {
+      return LENode::make(analyzer_.Simplify(a - b + 1), make_zero(a.dtype()));
+    }
+    return TNode::make(analyzer_.Simplify(a - b), make_zero(a.dtype()));
+  }
+  arith::Analyzer analyzer_;
+};
+/*!
+ * \brief Insert `new_ineq` into `inequality_set` unless it is redundant.
+ *
+ * Nothing is inserted when `analyzer` can prove `new_ineq` outright or when
+ * the set's lookup already finds it. If `new_ineq` is an LE node it is
+ * additionally compared with every LE node in the set, looking only at the
+ * left operands (`->a`):
+ *  - if `new.a - old.a <= 0` is provable, the function returns without
+ *    inserting;
+ *  - otherwise, if `old.a - new.a <= 0` is provable, the old entry is erased.
+ * Entries erased before such an early return stay erased.
+ *
+ * \param inequality_set Set to update in place.
+ * \param new_ineq Candidate (in)equality.
+ * \param analyzer Analyzer used for the CanProve queries.
+ *
+ * \note The right operands of the LE nodes are ignored, which presumably
+ *  relies on them all being zero (normalized form) - TODO confirm.
+ */
+void AddInequality(std::unordered_set<PrimExpr, StructuralHash, 
StructuralEqual>* inequality_set,
+                   const PrimExpr& new_ineq, Analyzer* analyzer) {
+  if (analyzer->CanProve(new_ineq) || inequality_set->find(new_ineq) != 
inequality_set->end()) {
+    // redundant: follows from the vranges
+    // or has already been added
+    return;
+  }
+  for (auto iter = inequality_set->begin(); iter != inequality_set->end();) {
+    if (const LENode* new_le = new_ineq.as<LENode>()) {
+      const LENode* le = iter->as<LENode>();
+      if (le && analyzer->CanProve(new_le->a - le->a <= 0)) {
+        return;  // an existing entry's left side is provably >= the new one
+      } else if (le && analyzer->CanProve(le->a - new_le->a <= 0)) {
+        iter = inequality_set->erase(iter);  // new left side provably >= this one
+      } else {
+        ++iter;
+      }
+    } else {
+      ++iter;
+    }
+  }
+
+  inequality_set->insert(new_ineq);
+}
+/*!
+ * \brief Split the formulas of `current_ineq_set` by the sign of the
+ *  coefficient detected for `var`.
+ *
+ * Each LE or EQ node has its left operand passed to
+ * arith::DetectLinearEquation together with `var`. When that returns a
+ * non-empty result whose first element is constant, the formula is dispatched
+ * on that coefficient c (the second element e is stored next to it):
+ *  - c == 0: the formula is forwarded to `next_ineq_set` via AddInequality;
+ *  - LE, c > 0: (c, e) is appended to `coef_pos`;
+ *  - LE, c < 0: (c, e) is appended to `coef_neg`;
+ *  - EQ, c != 0: treated as two inequalities, so the pair with the positive
+ *    coefficient goes to `coef_pos` and its negation to `coef_neg`.
+ * Anything else (other node types, empty or non-constant detection result) is
+ * appended to `rest`.
+ *
+ * \param var Variable with respect to which polarity is determined.
+ * \param current_ineq_set Formulas to classify (read only).
+ * \param next_ineq_set Receives formulas whose detected coefficient is zero.
+ * \param rest Receives formulas that could not be classified.
+ * \param coef_pos Receives (c, e) pairs with c > 0.
+ * \param coef_neg Receives (c, e) pairs with c < 0.
+ * \param analyzer Forwarded to AddInequality.
+ *
+ * \note `rest`, `coef_pos` and `coef_neg` are appended to, not cleared.
+ * \note NOTE(review): `*as_const_int(coef[0])` is dereferenced after only an
+ *  is_const() check - confirm is_const() cannot succeed for a value that
+ *  as_const_int() rejects.
+ */
+void ClassifyByPolarity(
+    const Var& var,
+    const std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>& 
current_ineq_set,
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* 
next_ineq_set,
+    std::vector<PrimExpr>* rest, std::vector<std::pair<int64_t, PrimExpr>>* 
coef_pos,
+    std::vector<std::pair<int64_t, PrimExpr>>* coef_neg, Analyzer* analyzer) {
+  // Take formulas from current_ineq_set and classify them according to 
polarity wrt var
+  // and store to coef_pos and coef_neg respectively.
+  for (const PrimExpr& ineq : current_ineq_set) {
+    if (const LENode* le = ineq.as<LENode>()) {
+      Array<PrimExpr> coef = arith::DetectLinearEquation(le->a, {var});
+      if (!coef.empty() && is_const(coef[0])) {
+        int64_t coef0 = *as_const_int(coef[0]);
+        if (coef0 == 0) {
+          // zero polarity, straight to next_ineq_set
+          AddInequality(next_ineq_set, ineq, analyzer);
+        } else if (coef0 > 0) {
+          coef_pos->push_back({coef0, coef[1]});
+        } else if (coef0 < 0) {
+          coef_neg->push_back({coef0, coef[1]});
+        }
+        continue;
+      }
+    } else if (const EQNode* eq = ineq.as<EQNode>()) {
+      Array<PrimExpr> coef = arith::DetectLinearEquation(eq->a, {var});
+      if (!coef.empty() && is_const(coef[0])) {
+        int64_t coef0 = *as_const_int(coef[0]);
+        if (coef0 == 0) {
+          // zero polarity, straight to next_ineq_set
+          AddInequality(next_ineq_set, ineq, analyzer);
+        } else if (coef0 > 0) {
+          // Equalities may be considered as pairs of two inequalities
+          coef_pos->push_back({coef0, coef[1]});
+          coef_neg->push_back({-coef0, -coef[1]});
+        } else if (coef0 < 0) {
+          coef_pos->push_back({-coef0, -coef[1]});
+          coef_neg->push_back({coef0, coef[1]});
+        }
+        continue;
+      }
+    }
+
+    // if nothing worked, put it in rest
+    rest->push_back(ineq);
+  }
+}
+/*!
+ * \brief Move bounds that appear in both sets into `equalities`.
+ *
+ * For every element of `upper_bounds` that is also found in `lower_bounds`
+ * (lookup through the set's StructuralHash / StructuralEqual), the element is
+ * inserted into `equalities` and erased from both bound sets.
+ *
+ * \param upper_bounds Upper bounds; matched entries are removed.
+ * \param lower_bounds Lower bounds; matched entries are removed.
+ * \param equalities Receives the matched entries.
+ */
+void MoveEquality(std::unordered_set<PrimExpr, StructuralHash, 
StructuralEqual>* upper_bounds,
+                  std::unordered_set<PrimExpr, StructuralHash, 
StructuralEqual>* lower_bounds,
+                  std::unordered_set<PrimExpr, StructuralHash, 
StructuralEqual>* equalities) {
+  // those exist in both upper & lower bounds will be moved to equalities
+  for (auto ub = upper_bounds->begin(); ub != upper_bounds->end();) {
+    auto lb = lower_bounds->find(*ub);
+    if (lb != lower_bounds->end()) {
+      equalities->insert(*lb);
+      lower_bounds->erase(lb);
+      ub = upper_bounds->erase(ub);  // erase() yields the iterator to continue from
+    } else {
+      ++ub;
+    }
+  }
+}
+
+PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& 
system_to_solve) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(system_to_solve->ranges);
+
+  // The algorithm consists in doing the following things for each variable v
+  // - Take formulas from `current_ineq_set_to_solve` and
+  //   classify them according to polarity wrt v.
+  // - Combine each formula of positive polarity (wrt v)
+  //   with each formula of negative polarity.
+  // - Put the resulting combinations into `next_ineq_set_to_solve`
+  //   along with unclassifiable formulas.
+  // - Replace `current_ineq_set_to_solve` with `next_ineq_set_to_solve`
+  //   and move to the next variable.
+
+  // normalized inequality
+  std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> 
current_ineq_set_to_solve;
+  std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> 
next_ineq_set_to_solve;
+  // A vector of pairs (c, e), c > 0, representing formulas of the form c*v + 
e <= 0
+  std::vector<std::pair<int64_t, PrimExpr>> coef_pos;
+  // A vector of pairs (c, e), c < 0, representing formulas of the form c*v + 
e <= 0
+  std::vector<std::pair<int64_t, PrimExpr>> coef_neg;
+
+  // formulas we don't know what to do with
+  std::vector<PrimExpr> rest;
+
+  // Simplify each inequality into the form `expr <= 0` and add to current 
formulas
+  for (const PrimExpr& ineq : system_to_solve->relations) {
+    AddInequality(&current_ineq_set_to_solve, 
NormalizeComparisons()(analyzer.Simplify(ineq, 3)),
+                  &analyzer);
+  }
+
+  Map<Var, IntGrpBounds> res_bounds;
+  for (const Var& v : system_to_solve->variables) {
+    CHECK(!res_bounds.count(v))
+        << "Variable " << v
+        << " appears more than one time in the `variables` which might be a 
bug";
+
+    next_ineq_set_to_solve.clear();
+    coef_pos.clear();
+    coef_neg.clear();
+
+    // Add bounds from vranges
+    if (system_to_solve->ranges.count(v)) {
+      const Range& range = system_to_solve->ranges[v];
+      PrimExpr range_lbound = analyzer.Simplify(range->min, 3);
+      PrimExpr range_ubound = analyzer.Simplify(range->min + range->extent - 
1, 3);
+      coef_neg.push_back({-1, range_lbound});
+      coef_pos.push_back({1, -range_ubound});
+    }
+
+    ClassifyByPolarity(v, current_ineq_set_to_solve, &next_ineq_set_to_solve, 
&rest, &coef_pos,
+                       &coef_neg, &analyzer);
+
+    // Combine each positive inequality with each negative one (by adding them 
together)
+    int64_t gcd_x, gcd_y;
+    for (const auto& pos : coef_pos) {
+      for (const auto& neg : coef_neg) {
+        auto first_gcd = ExtendedEuclidean(pos.first, -neg.first, &gcd_x, 
&gcd_y);
+        PrimExpr c_pos = make_const(v.dtype(), neg.first / first_gcd);
+        PrimExpr c_neg = make_const(v.dtype(), pos.first / first_gcd);
+        // eliminate the current variable
+        PrimExpr new_lhs = c_neg * neg.second - c_pos * pos.second;
+        PrimExpr new_ineq = LENode::make(new_lhs, 
make_zero(pos.second.dtype()));
+        // we need rewrite_simplify -> canonical_simplify -> rewrite_simplify
+        // to help simplify things like (((y + 10) - (-1*(y - 20))) <= 0) => y 
- 5 <= 0
+        // with steps = 2 it's (y*2) - 10 <= 0
+        new_ineq = NormalizeComparisons()(analyzer.Simplify(new_ineq, 3));
+        AddInequality(&next_ineq_set_to_solve, new_ineq, &analyzer);
+      }
+    }
+
+    // Now we have to generate resulting (in)equalities for the variable v
+
+    // Find the common denominator in a sense
+    // We will generate formulas of the form coef_lcm*v <= bound
+    int64_t coef_lcm = 1;
+    for (const auto& pos : coef_pos) {
+      coef_lcm = LeastCommonMultiple(coef_lcm, pos.first);
+    }
+    for (const auto& neg : coef_neg) {
+      coef_lcm = LeastCommonMultiple(coef_lcm, -neg.first);
+    }
+
+    // The resulting lower and upper bounds stored in sorted vectors
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> upper_bounds;
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> lower_bounds;
+    upper_bounds.reserve(coef_pos.size());
+    lower_bounds.reserve(coef_neg.size());
+
+    for (const auto& pos : coef_pos) {
+      PrimExpr bound = make_const(v.dtype(), -coef_lcm / pos.first) * 
pos.second;
+      bound = analyzer.Simplify(bound, 3);
+      // Don't add if any of the existing bounds is better
+      if (std::any_of(upper_bounds.begin(), upper_bounds.end(),
+                      [&bound, &analyzer](const PrimExpr& o) {
+                        return analyzer.CanProve(o - bound <= 0);
+                      })) {
+        continue;
+      }
+      // Erase all worse bounds
+      for (auto iter = upper_bounds.begin(); iter != upper_bounds.end();) {
+        if (analyzer.CanProve(*iter - bound >= 0)) {
+          iter = upper_bounds.erase(iter);
+        } else {
+          ++iter;
+        }
+      }
+      // Add the upper bound
+      upper_bounds.insert(bound);
+    }
+    for (const auto& neg : coef_neg) {
+      PrimExpr bound = make_const(v.dtype(), -coef_lcm / neg.first) * 
neg.second;
+      bound = analyzer.Simplify(bound, 3);
+      // Don't add if any of the existing bounds is better
+      if (std::any_of(lower_bounds.begin(), lower_bounds.end(),
+                      [&bound, &analyzer](const PrimExpr& o) {
+                        return analyzer.CanProve(o - bound >= 0);
+                      })) {
+        continue;
+      }
+      // Erase all worse bounds
+      for (auto iter = lower_bounds.begin(); iter != lower_bounds.end();) {
+        if (analyzer.CanProve(*iter - bound <= 0)) {
+          iter = lower_bounds.erase(iter);
+        } else {
+          ++iter;
+        }
+      }
+      // Add the lower bound
+      lower_bounds.insert(bound);
+    }
+
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> equal;
+    equal.reserve(std::min(upper_bounds.size(), lower_bounds.size()));
+    MoveEquality(&upper_bounds, &lower_bounds, &equal);
+    std::vector<PrimExpr> equal_list(equal.begin(), equal.end());
+    std::sort(equal_list.begin(), equal_list.end(), ExprLess());
+
+    // Write it to the result.
+    IntGrpBounds bnds(make_const(v.dtype(), coef_lcm),
+                      Array<PrimExpr>(lower_bounds.begin(), 
lower_bounds.end()),
+                      Array<PrimExpr>(equal_list.begin(), equal_list.end()),
+                      Array<PrimExpr>(upper_bounds.begin(), 
upper_bounds.end()));

Review comment:
       I think this will be fine as long as I use `unordered_set` for the set 
operations. 
   Also, `FindBestBound` enumerates all the pairs, so the order should not matter.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to