jcf94 commented on a change in pull request #6529:
URL: https://github.com/apache/incubator-tvm/pull/6529#discussion_r493160895



##########
File path: src/auto_scheduler/search_policy/utils.cc
##########
@@ -414,19 +414,52 @@ void PruneInvalidState(const SearchTask& task, 
Array<State>* states) {
   }
 }
 
+/********** SplitFactorizationMemo **********/
+
+void SplitFactorizationMemo::ReadWriteLock::GetRead() {

Review comment:
       Updated the docstring and added the comments.

##########
File path: src/auto_scheduler/search_policy/sketch_policy.cc
##########
@@ -332,29 +333,45 @@ Array<State> SketchPolicyNode::GenerateSketches() {
 }
 
 Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& 
sketches, int out_size) {
-  int fail_ct = 0;
+  std::atomic<int> fail_ct(0);
   Array<State> out_states;
+  std::vector<std::mt19937> rand_seeds;
+  rand_seeds.reserve(out_size);
+  for (int i = 0; i < out_size; i++) {
+    rand_seeds.push_back(std::mt19937(rand_gen()));
+  }
   auto tic_begin = std::chrono::high_resolution_clock::now();
 
   while (static_cast<int>(out_states.size()) < out_size && fail_ct < out_size) 
{
-    // Random choose a starting sketch
-    // TODO(jcf94, merrymercy): Maybe choose sketches in different possibility 
for they may have
-    // different potential on generating state with better performance
-    State tmp_s = sketches[(rand_gen)() % sketches.size()];
-
-    // Derivation rule based enumeration
-    bool valid = true;
-    for (const auto& rule : init_rules) {
-      if (rule->Apply(this, &tmp_s) == 
PopulationGenerationRule::ResultKind::kInvalid) {
-        valid = false;
-        break;
+    std::vector<State> temp_states(out_size);
+
+    support::parallel_for(0, out_size - out_states.size(),
+        [this, &temp_states, &sketches, &rand_seeds, &fail_ct](int index) {
+      // Random choose a starting sketch
+      // TODO(jcf94, merrymercy): Maybe choose sketches in different 
possibility for they may have
+      // different potential on generating state with better performance

Review comment:
       Yes, this comment is just for the line
`State tmp_s = sketches[(rand_seeds[index])() % sketches.size()];`, which
randomly chooses a sketch as the starting point for this InitPopulation thread.

##########
File path: src/auto_scheduler/search_policy/sketch_policy.cc
##########
@@ -334,28 +335,44 @@ Array<State> SketchPolicyNode::GenerateSketches() {
 Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& 
sketches, int out_size) {
   int fail_ct = 0;
   Array<State> out_states;
+  std::vector<std::mt19937> rand_gens;
+  rand_gens.reserve(out_size);
+  for (int i = 0; i < out_size; i++) {
+    rand_gens.push_back(std::mt19937(rand_gen()));
+  }
   auto tic_begin = std::chrono::high_resolution_clock::now();
 
   while (static_cast<int>(out_states.size()) < out_size && fail_ct < out_size) 
{
-    // Random choose a starting sketch
-    // TODO(jcf94, merrymercy): Maybe choose sketches in different possibility 
for they may have
-    // different potential on generating state with better performance
-    State tmp_s = sketches[(rand_gen)() % sketches.size()];
-
-    // Derivation rule based enumeration
-    bool valid = true;
-    for (const auto& rule : init_rules) {
-      if (rule->Apply(this, &tmp_s) == 
PopulationGenerationRule::ResultKind::kInvalid) {
-        valid = false;
-        break;
+    std::vector<State> temp_states(out_size);
+
+    support::parallel_for(0, out_size - out_states.size(),
+                          [this, &temp_states, &sketches, &rand_gens](int 
index) {
+                            // Random choose a starting sketch
+                            // TODO(jcf94, merrymercy): Maybe choose sketches 
in different
+                            // possibility for they may have different 
potential on generating state
+                            // with better performance
+                            State tmp_s = sketches[(rand_gens[index])() % 
sketches.size()];
+                            // Derivation rule based enumeration
+                            bool valid = true;
+                            for (const auto& rule : init_rules) {
+                              if (rule->Apply(this, &tmp_s, &rand_gens[index]) 
==
+                                  
PopulationGenerationRule::ResultKind::kInvalid) {
+                                valid = false;
+                                break;
+                              }
+                            }
+                            if (valid) {
+                              temp_states[index] = std::move(tmp_s);
+                            }
+                          });
+
+    for (int i = 0; i < out_size; i++) {
+      if (temp_states[i].defined()) {
+        out_states.push_back(std::move(temp_states[i]));

Review comment:
       This part is outside of the `parallel_for` block, so it is fine.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to