This is an automated email from the ASF dual-hosted git repository.
liuyizhi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 00257f3 [Autodiff] Deterministic gradient compute (#7321)
00257f3 is described below
commit 00257f347faad0b3ec2e9624413015bef34d451f
Author: Haozheng Fan <[email protected]>
AuthorDate: Thu Jan 28 08:32:04 2021 +0800
[Autodiff] Deterministic gradient compute (#7321)
* fix unstable compute
* fix
* fix
* lint
* sort linear equation
* sort inequalities
* fix
* fix find
* lint
* fix find
* lint
---
src/arith/solve_linear_equation.cc | 9 +++---
src/arith/solve_linear_inequality.cc | 54 ++++++++++++++++++------------------
src/te/autodiff/ad_simplify.cc | 26 +++++++++--------
3 files changed, 46 insertions(+), 43 deletions(-)
diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc
index 22bf736..d66e75d 100644
--- a/src/arith/solve_linear_equation.cc
+++ b/src/arith/solve_linear_equation.cc
@@ -427,11 +427,10 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol
// We have to transform ranges of the old variables into relations over new
variables because
// new ranges are not enough usually.
- for (const auto& p : system_to_solve->ranges) {
- const Var& old_var = p.first;
- const Range& old_range = p.second;
- if (old_to_new_map.count(old_var)) {
- PrimExpr express_by_new_vars = old_to_new_map[old_var];
+ for (const auto& old_var : system_to_solve->variables) {
+ if (system_to_solve->ranges.find(old_var) !=
system_to_solve->ranges.end()) {
+ const Range& old_range = system_to_solve->ranges.at(old_var);
+ PrimExpr express_by_new_vars = old_to_new_map.at(old_var);
PrimExpr lower_cond = analyzer_solution.Simplify(old_range->min <=
express_by_new_vars);
PrimExpr upper_cond =
analyzer_solution.Simplify(express_by_new_vars < old_range->min +
old_range->extent);
diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc
index f4de9ff..dd90448 100644
--- a/src/arith/solve_linear_inequality.cc
+++ b/src/arith/solve_linear_inequality.cc
@@ -94,11 +94,10 @@ struct ExprLess {
}
};
-void DebugPrint(
- const std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>&
current_ineq_set,
- const std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>&
next_ineq_set,
- const std::vector<PrimExpr>& rest, const std::vector<std::pair<int64_t,
PrimExpr>>& coef_pos,
- const std::vector<std::pair<int64_t, PrimExpr>>& coef_neg) {
+void DebugPrint(const std::vector<PrimExpr>& current_ineq_set,
+ const std::vector<PrimExpr>& next_ineq_set, const
std::vector<PrimExpr>& rest,
+ const std::vector<std::pair<int64_t, PrimExpr>>& coef_pos,
+ const std::vector<std::pair<int64_t, PrimExpr>>& coef_neg) {
std::cout << "Current ineq set:\n[";
for (auto& ineq : current_ineq_set) {
std::cout << ineq << ", ";
@@ -148,9 +147,12 @@ class NormalizeComparisons : public ExprMutator {
arith::Analyzer analyzer_;
};
-void AddInequality(std::unordered_set<PrimExpr, StructuralHash,
StructuralEqual>* inequality_set,
- const PrimExpr& new_ineq, Analyzer* analyzer) {
- if (analyzer->CanProve(new_ineq) || inequality_set->find(new_ineq) !=
inequality_set->end()) {
+void AddInequality(std::vector<PrimExpr>* inequality_set, const PrimExpr&
new_ineq,
+ Analyzer* analyzer) {
+ if (analyzer->CanProve(new_ineq) ||
+ std::find_if(inequality_set->begin(), inequality_set->end(), [&](const
PrimExpr& e) {
+ return StructuralEqual()(e, new_ineq);
+ }) != inequality_set->end()) {
// redundant: follows from the vranges
// or has already been added
return;
@@ -168,15 +170,13 @@ void AddInequality(std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>
}
}
- inequality_set->insert(new_ineq);
+ inequality_set->push_back(new_ineq);
}
-void ClassifyByPolarity(
- const Var& var,
- const std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>&
current_ineq_set,
- std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>*
next_ineq_set,
- std::vector<PrimExpr>* rest, std::vector<std::pair<int64_t, PrimExpr>>*
coef_pos,
- std::vector<std::pair<int64_t, PrimExpr>>* coef_neg, Analyzer* analyzer) {
+void ClassifyByPolarity(const Var& var, const std::vector<PrimExpr>&
current_ineq_set,
+ std::vector<PrimExpr>* next_ineq_set,
std::vector<PrimExpr>* rest,
+ std::vector<std::pair<int64_t, PrimExpr>>* coef_pos,
+ std::vector<std::pair<int64_t, PrimExpr>>* coef_neg,
Analyzer* analyzer) {
// Take formulas from current_ineq_set and classify them according to
polarity wrt var
// and store to coef_pos and coef_neg respectively.
for (const PrimExpr& ineq : current_ineq_set) {
@@ -218,14 +218,14 @@ void ClassifyByPolarity(
}
}
-void MoveEquality(std::unordered_set<PrimExpr, StructuralHash,
StructuralEqual>* upper_bounds,
- std::unordered_set<PrimExpr, StructuralHash,
StructuralEqual>* lower_bounds,
- std::unordered_set<PrimExpr, StructuralHash,
StructuralEqual>* equalities) {
+void MoveEquality(std::vector<PrimExpr>* upper_bounds, std::vector<PrimExpr>*
lower_bounds,
+ std::vector<PrimExpr>* equalities) {
// those exist in both upper & lower bounds will be moved to equalities
for (auto ub = upper_bounds->begin(); ub != upper_bounds->end();) {
- auto lb = lower_bounds->find(*ub);
+ auto lb = std::find_if(lower_bounds->begin(), lower_bounds->end(),
+ [&](const PrimExpr& e) { return
StructuralEqual()(e, *ub); });
if (lb != lower_bounds->end()) {
- equalities->insert(*lb);
+ equalities->push_back(*lb);
lower_bounds->erase(lb);
ub = upper_bounds->erase(ub);
} else {
@@ -249,8 +249,8 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
// and move to the next variable.
// normalized inequality
- std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>
current_ineq_set_to_solve;
- std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>
next_ineq_set_to_solve;
+ std::vector<PrimExpr> current_ineq_set_to_solve;
+ std::vector<PrimExpr> next_ineq_set_to_solve;
// A vector of pairs (c, e), c > 0, representing formulas of the form c*v +
e <= 0
std::vector<std::pair<int64_t, PrimExpr>> coef_pos;
// A vector of pairs (c, e), c < 0, representing formulas of the form c*v +
e <= 0
@@ -321,8 +321,8 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
}
// The resulting lower and upper bounds
- std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> upper_bounds;
- std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> lower_bounds;
+ std::vector<PrimExpr> upper_bounds;
+ std::vector<PrimExpr> lower_bounds;
upper_bounds.reserve(coef_pos.size());
lower_bounds.reserve(coef_neg.size());
@@ -345,7 +345,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
}
}
// Add the upper bound
- upper_bounds.insert(bound);
+ upper_bounds.push_back(bound);
}
for (const auto& neg : coef_neg) {
PrimExpr bound = make_const(v.dtype(), -coef_lcm / neg.first) *
neg.second;
@@ -366,10 +366,10 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
}
}
// Add the lower bound
- lower_bounds.insert(bound);
+ lower_bounds.push_back(bound);
}
- std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> equal;
+ std::vector<PrimExpr> equal;
equal.reserve(std::min(upper_bounds.size(), lower_bounds.size()));
MoveEquality(&upper_bounds, &lower_bounds, &equal);
std::vector<PrimExpr> equal_list(equal.begin(), equal.end());
diff --git a/src/te/autodiff/ad_simplify.cc b/src/te/autodiff/ad_simplify.cc
index cc0e820..96f278e 100644
--- a/src/te/autodiff/ad_simplify.cc
+++ b/src/te/autodiff/ad_simplify.cc
@@ -413,15 +413,17 @@ class FactorOutAtomicFormulasFunctor
auto res_b = VisitExpr(op->b);
// For the And case we return the union of the sets of atomic formulas
- std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
- res_set.reserve(res_a.atomic_formulas.size() +
res_b.atomic_formulas.size());
+ std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_a_set;
+ res_a_set.reserve(res_a.atomic_formulas.size());
std::copy(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(),
- std::inserter(res_set, res_set.end()));
- std::copy(res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(),
- std::inserter(res_set, res_set.end()));
-
- std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
+ std::inserter(res_a_set, res_a_set.end()));
+ std::vector<PrimExpr> res = res_a.atomic_formulas;
+ for (const auto& e : res_b.atomic_formulas) {
+ if (res_a_set.find(e) == res_a_set.end()) {
+ res.emplace_back(e);
+ }
+ }
// And the residuals are combined with &&
return {res, res_a.rest && res_b.rest};
}
@@ -443,10 +445,13 @@ class FactorOutAtomicFormulasFunctor
// For the Or case we intersect the sets of atomic formulas
std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
+ std::vector<PrimExpr> res;
res_set.reserve(std::min(res_a.atomic_formulas.size(),
res_b.atomic_formulas.size()));
- for (const auto& res_b_formula : res_b_set) {
+ res.reserve(std::min(res_a.atomic_formulas.size(),
res_b.atomic_formulas.size()));
+ for (const auto& res_b_formula : res_b.atomic_formulas) {
if (res_a_set.count(res_b_formula)) {
res_set.insert(res_b_formula);
+ res.push_back(res_b_formula);
}
}
@@ -454,13 +459,13 @@ class FactorOutAtomicFormulasFunctor
// which are left behind, and then combine them with the residuals into
the new residual.
std::vector<PrimExpr> new_cond_a;
new_cond_a.reserve(res_a.atomic_formulas.size() - res_set.size());
- for (const auto& formula : res_a_set) {
+ for (const auto& formula : res_a.atomic_formulas) {
if (!res_set.count(formula)) new_cond_a.emplace_back(formula);
}
std::vector<PrimExpr> new_cond_b;
new_cond_b.reserve(res_b.atomic_formulas.size() - res_set.size());
- for (const auto& formula : res_b_set) {
+ for (const auto& formula : res_b.atomic_formulas) {
if (!res_set.count(formula)) new_cond_b.emplace_back(formula);
}
@@ -468,7 +473,6 @@ class FactorOutAtomicFormulasFunctor
res_b.atomic_formulas = std::move(new_cond_b);
PrimExpr new_rest = res_a.to_expr() || res_b.to_expr();
- std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
return {res, new_rest};
}