jcf94 commented on a change in pull request #6310:
URL: https://github.com/apache/incubator-tvm/pull/6310#discussion_r473530560



##########
File path: src/auto_scheduler/search_policy/sketch_policy.cc
##########
@@ -327,8 +335,150 @@ Array<State> SketchPolicyNode::EvolutionarySearch(const 
Array<State>& init_popul
   Array<State> best_states;
   auto tic_begin = std::chrono::high_resolution_clock::now();
 
-  // TODO(comaniac, merrymercy, jcf94): Since we haven't finished porting the 
cost model part
-  // yet, currently delete the implementation of EvolutionarySearch. To be 
added later.
+  size_t population = init_population.size();
+  int num_iters =
+      static_cast<int>(GetIntParam(params, 
SketchParamKey::EvolutionarySearch::num_iters));
+  double mutation_prob = static_cast<double>(
+      GetDoubleParam(params, 
SketchParamKey::EvolutionarySearch::mutation_prob));
+
+  // Two ping pong buffers to avoid copy.
+  Array<State> states_buf1{init_population}, states_buf2;
+  states_buf1.reserve(population);
+  states_buf2.reserve(population);
+  Array<State>* pnow = &states_buf1;
+  Array<State>* pnext = &states_buf2;
+
+  // The set of explored states to avoid redendants.
+  std::unordered_set<std::string> explored_set;
+
+  // The heap to maintain the so far best states.
+  using StateHeapItem = std::pair<State, float>;
+  auto cmp = [](const StateHeapItem& left, const StateHeapItem& right) {
+    return left.second > right.second;
+  };
+  using StateHeap = std::priority_queue<StateHeapItem, 
std::vector<StateHeapItem>, decltype(cmp)>;
+  StateHeap heap(cmp);
+  auto update_heap = [&heap, &explored_set](const Array<State>& states,
+                                            const std::vector<float>& scores, 
const int out_size) {
+    float max_score = 0.0;
+    for (size_t i = 0; i < states.size(); ++i) {
+      const State& state = states[i];
+      std::string state_str = state.ToStr();
+
+      // Skip redundant states.
+      if (explored_set.count(state_str) > 0) {
+        continue;
+      }
+      explored_set.insert(state_str);
+
+      if (static_cast<int>(heap.size()) < out_size) {
+        // Directly push item if the heap is not full yet.
+        heap.push({state, scores[i]});
+      } else if (scores[i] > heap.top().second) {
+        // Replace the worst state in the heap with the new state.
+        heap.pop();
+        heap.push({state, scores[i]});
+      }
+      max_score = (scores[i] > max_score) ? scores[i] : max_score;
+    }
+    return max_score;
+  };
+
+  // Cost model predicted scores.
+  std::vector<float> scores;
+  scores.reserve(population);
+
+  // The function to generate prefix sum probabilities based on the given 
scores.
+  auto assign_prob = [](const std::vector<float>& scores, std::vector<double>* 
prefix_sum_probs) {
+    // Compute selection probabilities.
+    double sum = 0.0;
+    prefix_sum_probs->resize(scores.size());
+    for (size_t i = 0; i < scores.size(); ++i) {
+      sum += std::max(scores[i], 0.0f);
+      (*prefix_sum_probs)[i] = sum;
+    }
+    for (size_t i = 0; i < scores.size(); ++i) {
+      (*prefix_sum_probs)[i] /= sum;
+    }
+  };
+
+  // State selection probabilities.
+  std::uniform_real_distribution<> uniform_dist(0.0, 1.0);
+  std::vector<double> state_select_probs;
+  state_select_probs.reserve(population);
+
+  // Mutation rule selection probabilities.
+  std::vector<double> rule_select_probs;
+  rule_select_probs.reserve(mutation_rules.size());
+  std::vector<float> rule_levels;
+  for (const auto& rule : mutation_rules) {
+    rule_levels.push_back(rule->GetLevel());
+  }
+  assign_prob(rule_levels, &rule_select_probs);
+
+  // Evaluate the init populations.
+  search_task->compute_dag.InferBound(*pnow);

Review comment:
       ```suggestion
     *pnow = search_task->compute_dag.InferBound(*pnow);
   ```
   Note that this implementation currently does not modify the input state 
in place; instead, it returns the resulting state.

##########
File path: src/auto_scheduler/search_policy/sketch_policy_rules.h
##########
@@ -201,6 +201,26 @@ class InitVectorization : public InitPopulationRule {
   ResultKind Apply(SketchPolicyNode* policy, State* state) const final;
 };
 
+/********** Mutation **********/
+
+/*! \brief The base class for mutation rules used in the evolutionary search. 
*/
+class MutationRule : public InitPopulationRule {

Review comment:
       Inheriting from this class seems strange; we can just make it an 
independent class.
   
   The updated implementation in the last PR uses a macro-based 
definition format instead:
   
https://github.com/apache/incubator-tvm/pull/6269/files#diff-27486a3e525b590252f09c3fca9d2718




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to