merrymercy commented on a change in pull request #6512:
URL: https://github.com/apache/incubator-tvm/pull/6512#discussion_r491325191
##########
File path: src/auto_scheduler/search_policy/sketch_policy.cc
##########
@@ -390,135 +385,102 @@ Array<State> SketchPolicyNode::EvolutionarySearch(const
Array<State>& init_popul
Array<State>* pnow = &states_buf1;
Array<State>* pnext = &states_buf2;
- // The set of explored states to avoid redundancy.
- std::unordered_set<std::string> explored_set;
-
- // The heap to maintain the so far best states.
+ // A heap to keep the best states during evolution
using StateHeapItem = std::pair<State, float>;
auto cmp = [](const StateHeapItem& left, const StateHeapItem& right) {
return left.second > right.second;
};
- using StateHeap = std::priority_queue<StateHeapItem,
std::vector<StateHeapItem>, decltype(cmp)>;
- StateHeap heap(cmp);
- auto update_heap = [&heap, &explored_set](const Array<State>& states,
- const std::vector<float>& scores,
const int out_size) {
- float max_score = 0.0;
- for (size_t i = 0; i < states.size(); ++i) {
- const State& state = states[i];
+ std::vector<StateHeapItem> heap;
+ std::unordered_set<std::string> in_heap(measured_states_set_);
+ heap.reserve(out_size);
+
+ // auxiliary global variables
+ std::vector<float> pop_scores;
+ std::vector<double> pop_selection_probs;
+ float max_score = 0.0;
+ pop_scores.reserve(population);
+ pop_selection_probs.reserve(population);
+ std::uniform_real_distribution<> dis(0.0, 1.0);
+
+ // mutation rules
+ int mutation_success_ct, mutation_fail_ct;
+ mutation_success_ct = mutation_fail_ct = 0;
+ std::vector<float> rule_weights;
+ std::vector<double> rule_selection_probs;
+ for (const auto& rule : mutation_rules) {
+ rule_weights.push_back(rule->weight);
+ }
+ ComputePrefixSumProb(rule_weights, &rule_selection_probs);
+
+ // Genetic Algorithm
+ for (int k = 0; k < num_iters + 1; ++k) {
+ // Maintain the heap
+ *pnow = search_task->compute_dag.InferBound(*pnow);
+ PruneInvalidState(search_task, pnow);
Review comment:
Yes, I added the check inside `PruneInvalidState`. If all states are
invalid, we treat it as an internal error and abort the program immediately.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]