This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new eee04c0  [ANSOR] Auto-scheduler tutorial for GPU and necessary refactor/fix (#6512)
eee04c0 is described below

commit eee04c089caf2b31e44709b6ca50dfa6e3c391a4
Author: Lianmin Zheng <[email protected]>
AuthorDate: Sat Sep 19 15:38:29 2020 -0700

    [ANSOR] Auto-scheduler tutorial for GPU and necessary refactor/fix (#6512)
    
    * add gpu tutorial
    
    * refactor mutation in evolutionary search
    
    * update
    
    * update double matmul
    
    * fix lint
    
    * add double matmul test
    
    * fix mutate compute location
    
    * fix sketch search policy
    
    * fix lint
    
    * update
    
    * address comments
    
    * fix PruneInvalidStates
---
 docs/api/python/auto_scheduler.rst                 |  15 ++
 include/tvm/auto_scheduler/search_policy.h         |   3 +-
 python/tvm/auto_scheduler/measure.py               |  14 +-
 python/tvm/auto_scheduler/search_policy.py         |   6 +-
 python/tvm/auto_scheduler/workload_registry.py     |  16 +-
 python/tvm/micro/session.py                        |   6 +-
 src/auto_scheduler/search_policy/sketch_policy.cc  | 266 +++++++++------------
 src/auto_scheduler/search_policy/sketch_policy.h   |  16 +-
 .../search_policy/sketch_policy_rules.cc           | 193 +++++++--------
 .../search_policy/sketch_policy_rules.h            |  63 +++--
 src/auto_scheduler/search_policy/utils.cc          |  83 ++++++-
 src/auto_scheduler/search_policy/utils.h           |  60 ++++-
 src/auto_scheduler/transform_step.cc               |   4 +-
 .../python/unittest/test_auto_scheduler_common.py  |  13 +
 .../test_auto_scheduler_evolutionary_search.py     |   2 +-
 .../unittest/test_auto_scheduler_search_policy.py  |   2 +-
 .../test_auto_scheduler_sketch_generation.py       |  46 ++--
 ...une_matmul_x86.py => tune_conv2d_layer_cuda.py} | 152 ++++++------
 tutorials/auto_scheduler/tune_matmul_x86.py        |  50 ++--
 19 files changed, 578 insertions(+), 432 deletions(-)

diff --git a/docs/api/python/auto_scheduler.rst b/docs/api/python/auto_scheduler.rst
index 85ff22f..a7c190a 100644
--- a/docs/api/python/auto_scheduler.rst
+++ b/docs/api/python/auto_scheduler.rst
@@ -31,5 +31,20 @@ tvm.auto_scheduler.auto_schedule
 
 .. autofunction:: tvm.auto_scheduler.auto_schedule.auto_schedule
 
+tvm.auto_scheduler.workload_registry
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
+.. autofunction:: tvm.auto_scheduler.workload_registry.register_workload
 
+
+tvm.auto_scheduler.measure
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. automodule:: tvm.auto_scheduler.measure
+
+.. autoclass:: tvm.auto_scheduler.measure.LocalRPCMeasureContext
+
+.. autoclass:: tvm.auto_scheduler.measure.LocalRunner
+
+.. autoclass:: tvm.auto_scheduler.measure.LocalBuilder
+
+.. autoclass:: tvm.auto_scheduler.measure.RPCRunner
diff --git a/include/tvm/auto_scheduler/search_policy.h b/include/tvm/auto_scheduler/search_policy.h
index 176b10c..ddb0dd2 100755
--- a/include/tvm/auto_scheduler/search_policy.h
+++ b/include/tvm/auto_scheduler/search_policy.h
@@ -65,6 +65,7 @@
 #include <tvm/auto_scheduler/search_task.h>
 #include <tvm/node/node.h>
 
+#include <string>
 #include <unordered_set>
 #include <vector>
 
@@ -191,7 +192,7 @@ class SearchPolicyNode : public Object {
    * We store the string format of a state for redundancy check. This is used 
to make sure a
    * measured state will never be measured again.
    */
-  std::unordered_set<String> measured_states_set_;
+  std::unordered_set<std::string> measured_states_set_;
   /*! \brief The array of already measured states.
    *  The good states can be used as the initial population in evolutionary 
search. */
   std::vector<State> measured_states_vector_;
diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py
index c57b39b..eebccf4 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""
+r"""
 Distributed measurement infrastructure to measure the runtime costs of tensor 
programs.
 
 These functions are responsible for building the tvm module, uploading it to
@@ -25,8 +25,8 @@ We separate the measurement into two steps: build and run.
 A builder builds the executable binary files and a runner runs the binary 
files to
 get the measurement results. The flow of data structures is
 
-                `ProgramBuilder`                 `ProgramRunner`
-`MeasureInput` -----------------> `BuildResult` ----------------> 
`MeasureResult`
+  .                `ProgramBuilder`                 `ProgramRunner`
+  `MeasureInput` -----------------> `BuildResult` ----------------> 
`MeasureResult`
 
 We implement these in python to utilize python's multiprocessing and error 
handling.
 """
@@ -222,7 +222,7 @@ class LocalRunner(ProgramRunner):
         where the first "1" is warm up and will be discarded.
         The returned result contains `repeat` costs,
         each of which is an average of `number` costs.
-    min_repeat_ms : int = 0
+    min_repeat_ms : int = 100
         The minimum duration of one `repeat` in milliseconds.
         By default, one `repeat` contains `number` runs. If this parameter is 
set,
         the parameters `number` will be dynamically adjusted to meet the
@@ -244,7 +244,7 @@ class LocalRunner(ProgramRunner):
         timeout=10,
         number=3,
         repeat=1,
-        min_repeat_ms=0,
+        min_repeat_ms=100,
         cooldown_interval=0.0,
         enable_cpu_cache_flush=False,
     ):
@@ -289,7 +289,7 @@ class RPCRunner(ProgramRunner):
         where the first "1" is warm up and will be discarded.
         The returned result contains `repeat` costs,
         each of which is an average of `number` costs.
-    min_repeat_ms : int = 0
+    min_repeat_ms : int = 100
         The minimum duration of one `repeat` in milliseconds.
         By default, one `repeat` contains `number` runs. If this parameter is 
set,
         the parameters `number` will be dynamically adjusted to meet the
@@ -316,7 +316,7 @@ class RPCRunner(ProgramRunner):
         timeout=10,
         number=3,
         repeat=1,
-        min_repeat_ms=0,
+        min_repeat_ms=100,
         cooldown_interval=0.0,
         enable_cpu_cache_flush=False,
     ):
diff --git a/python/tvm/auto_scheduler/search_policy.py b/python/tvm/auto_scheduler/search_policy.py
index a9d3236..6e278ae 100644
--- a/python/tvm/auto_scheduler/search_policy.py
+++ b/python/tvm/auto_scheduler/search_policy.py
@@ -91,7 +91,7 @@ class SketchPolicy(SearchPolicy):
     ----------
     task : SearchTask
         The SearchTask for the computation declaration.
-    schedule_cost_model : CostModel = RandomModel()
+    program_cost_model : CostModel = RandomModel()
         The cost model to estimate the complete schedules.
     params : Optional[Dict[str, Any]]
         Parameters of the search policy.
@@ -129,7 +129,7 @@ class SketchPolicy(SearchPolicy):
     def __init__(
         self,
         task,
-        schedule_cost_model=RandomModel(),
+        program_cost_model=RandomModel(),
         params=None,
         seed=None,
         verbose=1,
@@ -145,7 +145,7 @@ class SketchPolicy(SearchPolicy):
         self.__init_handle_by_constructor__(
             _ffi_api.SketchPolicy,
             task,
-            schedule_cost_model,
+            program_cost_model,
             params,
             seed or random.randint(1, 1 << 30),
             verbose,
diff --git a/python/tvm/auto_scheduler/workload_registry.py b/python/tvm/auto_scheduler/workload_registry.py
index f0c8398..1d9ee6d 100644
--- a/python/tvm/auto_scheduler/workload_registry.py
+++ b/python/tvm/auto_scheduler/workload_registry.py
@@ -55,13 +55,15 @@ def register_workload(func_name, f=None, override=False):
 
     Examples
     --------
-    @auto_scheduler.register_workload
-    def matmul(N, M, K):
-        A = te.placeholder((N, K), name='A')
-        B = te.placeholder((K, M), name='B')
-        k = te.reduce_axis((0, K), name='k')
-        C = te.compute((N, M), lambda i, j: tvm.sum(A[i][k] * B[k][j], 
axis=[k]), name='C')
-        return [A, B, C]
+    .. code-block:: python
+
+      @auto_scheduler.register_workload
+      def matmul(N, M, K):
+          A = te.placeholder((N, K), name='A')
+          B = te.placeholder((K, M), name='B')
+          k = te.reduce_axis((0, K), name='k')
+          C = te.compute((N, M), lambda i, j: tvm.sum(A[i][k] * B[k][j], 
axis=[k]), name='C')
+          return [A, B, C]
     """
     global WORKLOAD_FUNC_REGISTRY
 
diff --git a/python/tvm/micro/session.py b/python/tvm/micro/session.py
index 084f467..3565040 100644
--- a/python/tvm/micro/session.py
+++ b/python/tvm/micro/session.py
@@ -22,10 +22,14 @@ import time
 
 from .._ffi import get_global_func
 from ..contrib import graph_runtime
-from .base import _rpc_connect
 from ..rpc import RPCSession
 from .transport import TransportLogger
 
+try:
+    from .base import _rpc_connect
+except ImportError:
+    raise ImportError("micro tvm is not enabled. Set USE_MICRO to ON in 
config.cmake")
+
 
 class Session:
     """MicroTVM Device Session
diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc
index ffc0094..6b4b6ae 100644
--- a/src/auto_scheduler/search_policy/sketch_policy.cc
+++ b/src/auto_scheduler/search_policy/sketch_policy.cc
@@ -31,6 +31,7 @@
 #include <algorithm>
 #include <iomanip>
 #include <limits>
+#include <memory>
 #include <queue>
 #include <set>
 #include <string>
@@ -45,7 +46,6 @@ namespace tvm {
 namespace auto_scheduler {
 
 /********** Sketch generation rules **********/
-
 static RuleSkipStage rule_skip_stage;
 static RuleAlwaysInline rule_always_inline;
 static RuleMultiLevelTiling rule_multi_level_tiling;
@@ -58,7 +58,6 @@ static RuleSimplifyComputeWithConstTensor rule_simplify_compute_with_const_tenso
 static RuleSpecialComputeLocationGPU rule_special_compute_location_gpu;
 
 /********** Init population rules **********/
-
 static InitFillTileSize init_fill_tile_size;
 static InitChangeComputeLocation init_change_compute_location;
 static InitParallel init_parallel;
@@ -66,23 +65,15 @@ static InitUnroll init_unroll;
 static InitVectorization init_vectorization;
 static InitThreadBind init_thread_bind;
 
-/********** Mutation rules **********/
-
-static MutateTileSize mutate_tile_size;
-static MutateMaxUnrollFactor mutate_max_unroll_factor;
-static MutateComputeLocation mutate_compute_location;
-static MutateParallel mutate_parallel;
-
 /********** Sketch policy **********/
-
 TVM_REGISTER_NODE_TYPE(SketchPolicyNode);
 
-SketchPolicy::SketchPolicy(SearchTask task, CostModel schedule_cost_model,
+SketchPolicy::SketchPolicy(SearchTask task, CostModel program_cost_model,
                            Map<String, ObjectRef> params, int seed, int 
verbose,
                            Optional<Array<SearchCallback>> 
init_search_callbacks) {
   auto node = make_object<SketchPolicyNode>();
   node->search_task = std::move(task);
-  node->schedule_cost_model = std::move(schedule_cost_model);
+  node->program_cost_model = std::move(program_cost_model);
   node->rand_gen = std::mt19937(seed);
   node->params = std::move(params);
   node->verbose = verbose;
@@ -97,18 +88,32 @@ SketchPolicy::SketchPolicy(SearchTask task, CostModel schedule_cost_model,
     node->RunCallbacks(init_search_callbacks.value());
   }
 
-  // Notice: Some rules require us to skip all the rest rules after they are 
applied.
-  // So the rules below should be ordered carefully.
+  // NOTE: There are strong dependency among the rules below,
+  // so the order to push them into the vector should be considered carefully.
   if (IsCPUTask(node->search_task)) {
-    // The default sketch rules for CPU policy
+    // Sketch Generation Rules
     node->sketch_rules.push_back(&rule_always_inline);
     node->sketch_rules.push_back(&rule_simplify_compute_with_const_tensor);
     node->sketch_rules.push_back(&rule_add_rfactor);
     node->sketch_rules.push_back(&rule_add_cache_write_stage);
     node->sketch_rules.push_back(&rule_multi_level_tiling_with_fusion);
     node->sketch_rules.push_back(&rule_multi_level_tiling);
-  } else if (IsCUDATask(node->search_task)) {
-    // The default sketch rules for CUDA policy
+    node->sketch_rules.push_back(&rule_skip_stage);
+
+    // Initial Population Generation Rules
+    node->init_rules.push_back(&init_fill_tile_size);
+    node->init_rules.push_back(&init_change_compute_location);
+    node->init_rules.push_back(&init_parallel);
+    node->init_rules.push_back(&init_unroll);
+    node->init_rules.push_back(&init_vectorization);
+
+    // Mutation Rules for Evolutionary Search
+    node->mutation_rules.push_back(std::make_shared<MutateTileSize>(0.90));
+    node->mutation_rules.push_back(std::make_shared<MutateAutoUnroll>(0.04));
+    
node->mutation_rules.push_back(std::make_shared<MutateComputeLocation>(0.05));
+    node->mutation_rules.push_back(std::make_shared<MutateParallel>(0.01));
+  } else if (IsGPUTask(node->search_task)) {
+    // Sketch Generation Rules
     node->sketch_rules.push_back(&rule_add_cache_read_stage);
     node->sketch_rules.push_back(&rule_always_inline);
     node->sketch_rules.push_back(&rule_special_compute_location_gpu);
@@ -117,32 +122,20 @@ SketchPolicy::SketchPolicy(SearchTask task, CostModel schedule_cost_model,
     node->sketch_rules.push_back(&rule_add_cache_write_stage);
     node->sketch_rules.push_back(&rule_multi_level_tiling_with_fusion);
     node->sketch_rules.push_back(&rule_multi_level_tiling);
-  } else {
-    LOG(FATAL) << "No default sketch rules for target: " << task->target;
-  }
-  node->sketch_rules.push_back(&rule_skip_stage);  // This should always be 
the last rule
+    node->sketch_rules.push_back(&rule_skip_stage);
 
-  node->init_rules.push_back(&init_fill_tile_size);  // This should always be 
the first rule
-  if (IsCPUTask(node->search_task)) {
-    // The default init population rules for CPU policy
-    node->init_rules.push_back(&init_change_compute_location);
-    node->init_rules.push_back(&init_parallel);
-    node->init_rules.push_back(&init_unroll);
-    node->init_rules.push_back(&init_vectorization);
-  } else if (IsCUDATask(node->search_task)) {
-    // The default init population rules for CUDA policy
+    // Initial Population Generation Rules
+    node->init_rules.push_back(&init_fill_tile_size);
     node->init_rules.push_back(&init_thread_bind);
     node->init_rules.push_back(&init_unroll);
+
+    // Mutation Rules for Evolutionary Search
+    node->mutation_rules.push_back(std::make_shared<MutateTileSize>(0.90));
+    node->mutation_rules.push_back(std::make_shared<MutateAutoUnroll>(0.10));
   } else {
-    LOG(FATAL) << "No default init rules for target: " << task->target;
+    LOG(FATAL) << "No default sketch rules for target: " << task->target;
   }
 
-  // The default mutation rules.
-  node->mutation_rules.push_back(&mutate_tile_size);
-  node->mutation_rules.push_back(&mutate_max_unroll_factor);
-  node->mutation_rules.push_back(&mutate_compute_location);
-  node->mutation_rules.push_back(&mutate_parallel);
-
   data_ = std::move(node);
 }
 
@@ -169,7 +162,7 @@ State SketchPolicyNode::Search(int n_trials, int early_stopping, int num_measure
       if (!inputs.empty()) {
         // Retrain cost models before the next search round
         PrintTitle("Train cost model", verbose);
-        schedule_cost_model->Update(inputs, results);
+        program_cost_model->Update(inputs, results);
       }
 
       // Search one round to get promising states
@@ -179,9 +172,7 @@ State SketchPolicyNode::Search(int n_trials, int early_stopping, int num_measure
 
       // Infer bound. This is necessary for computing the correct ToStr() for 
redundancy check
       best_states = search_task->compute_dag.InferBound(best_states);
-      PruneInvalidState(search_task, &best_states);
       random_states = search_task->compute_dag.InferBound(random_states);
-      PruneInvalidState(search_task, &random_states);
 
       // Pick `num_measure_per_iter` states to measure, check hash to remove 
already measured state
       // Also pick some random states to do eps-greedy
@@ -242,14 +233,16 @@ Array<State> SketchPolicyNode::SearchOneRound(int num_random_states, Array<State
                static_cast<int>(
                    GetDoubleParam(params, 
SketchParamKey::EvolutionarySearch::use_measured_ratio) *
                    population));
-  bool is_cost_model_reasonable = 
!schedule_cost_model->IsInstance<RandomModelNode>();
+  bool is_cost_model_reasonable = 
!program_cost_model->IsInstance<RandomModelNode>();
 
   // 1. Generate sketches
-  const Array<State>& sketches = GenerateSketches();
+  if (sketch_cache_.empty()) {
+    sketch_cache_ = GenerateSketches();
+  }
 
   // 2. Sample the init population
   Array<State> init_population = SampleInitPopulation(
-      sketches, is_cost_model_reasonable ? population - num_use_measured : 
population);
+      sketch_cache_, is_cost_model_reasonable ? population - num_use_measured 
: population);
 
   // 3. If the cost model is useless (i.e. RandomCostModel), just random pick 
some generated
   // states, else perform evolutionary search
@@ -260,7 +253,7 @@ Array<State> SketchPolicyNode::SearchOneRound(int num_random_states, Array<State
       init_population.push_back(measured_states_vector_[indices[i]]);
     }
     // Sample some random states for eps-greedy
-    *random_states = RandomSampleStates(init_population, &rand_gen, 
num_random_states * 10);
+    *random_states = RandomSampleStates(init_population, &rand_gen, 
num_random_states * 3);
     return EvolutionarySearch(init_population, num_measure_per_iter_ * 2);
   } else {
     PruneInvalidState(search_task, &init_population);
@@ -278,7 +271,7 @@ Array<State> SketchPolicyNode::GenerateSketches() {
 
   // A map that maps state to its current working position (stage_id)
   std::unordered_map<State, int, ObjectHash, ObjectEqual> cur_stage_id_map;
-  cur_stage_id_map[init_state] = static_cast<int>(init_state->stages.size() - 
1);
+  cur_stage_id_map[init_state] = static_cast<int>(init_state->stages.size()) - 
1;
 
   // Derivation rule based enumeration
   Array<State> out_states;
@@ -379,7 +372,7 @@ Array<State> SketchPolicyNode::EvolutionarySearch(const Array<State>& init_popul
   Array<State> best_states;
   auto tic_begin = std::chrono::high_resolution_clock::now();
 
-  size_t population = init_population.size();
+  size_t population = GetIntParam(params, 
SketchParamKey::EvolutionarySearch::population);
   int num_iters = GetIntParam(params, 
SketchParamKey::EvolutionarySearch::num_iters);
   double mutation_prob = GetDoubleParam(params, 
SketchParamKey::EvolutionarySearch::mutation_prob);
 
@@ -390,135 +383,102 @@ Array<State> SketchPolicyNode::EvolutionarySearch(const Array<State>& init_popul
   Array<State>* pnow = &states_buf1;
   Array<State>* pnext = &states_buf2;
 
-  // The set of explored states to avoid redundancy.
-  std::unordered_set<std::string> explored_set;
-
-  // The heap to maintain the so far best states.
+  // A heap to keep the best states during evolution
   using StateHeapItem = std::pair<State, float>;
   auto cmp = [](const StateHeapItem& left, const StateHeapItem& right) {
     return left.second > right.second;
   };
-  using StateHeap = std::priority_queue<StateHeapItem, 
std::vector<StateHeapItem>, decltype(cmp)>;
-  StateHeap heap(cmp);
-  auto update_heap = [&heap, &explored_set](const Array<State>& states,
-                                            const std::vector<float>& scores, 
const int out_size) {
-    float max_score = 0.0;
-    for (size_t i = 0; i < states.size(); ++i) {
-      const State& state = states[i];
+  std::vector<StateHeapItem> heap;
+  std::unordered_set<std::string> in_heap(measured_states_set_);
+  heap.reserve(out_size);
+
+  // auxiliary global variables
+  std::vector<float> pop_scores;
+  std::vector<double> pop_selection_probs;
+  float max_score = 0.0;
+  pop_scores.reserve(population);
+  pop_selection_probs.reserve(population);
+  std::uniform_real_distribution<> dis(0.0, 1.0);
+
+  // mutation rules
+  int mutation_success_ct, mutation_fail_ct;
+  mutation_success_ct = mutation_fail_ct = 0;
+  std::vector<float> rule_weights;
+  std::vector<double> rule_selection_probs;
+  for (const auto& rule : mutation_rules) {
+    rule_weights.push_back(rule->weight);
+  }
+  ComputePrefixSumProb(rule_weights, &rule_selection_probs);
+
+  // Genetic Algorithm
+  for (int k = 0; k < num_iters + 1; ++k) {
+    // Maintain the heap
+    *pnow = search_task->compute_dag.InferBound(*pnow);
+    PruneInvalidState(search_task, pnow);
+    program_cost_model->Predict(search_task, *pnow, &pop_scores);
+
+    for (size_t i = 0; i < pnow->size(); ++i) {
+      const State& state = (*pnow)[i];
       std::string state_str = state.ToStr();
 
-      // Skip redundant states.
-      if (explored_set.count(state_str) > 0) {
-        continue;
-      }
-      explored_set.insert(state_str);
-
-      if (static_cast<int>(heap.size()) < out_size) {
-        // Directly push item if the heap is not full yet.
-        heap.push({state, scores[i]});
-      } else if (scores[i] > heap.top().second) {
-        // Replace the worst state in the heap with the new state.
-        heap.pop();
-        heap.push({state, scores[i]});
+      if (in_heap.count(state_str) == 0) {
+        if (static_cast<int>(heap.size()) < out_size) {
+          heap.emplace_back((*pnow)[i], pop_scores[i]);
+          std::push_heap(heap.begin(), heap.end(), cmp);
+          in_heap.insert(state_str);
+        } else if (pop_scores[i] > heap.front().second) {
+          std::string old_state_str = heap.front().first.ToStr();
+          in_heap.erase(old_state_str);
+          in_heap.insert(state_str);
+
+          std::pop_heap(heap.begin(), heap.end(), cmp);
+          heap.back() = StateHeapItem(state, pop_scores[i]);
+          std::push_heap(heap.begin(), heap.end(), cmp);
+        }
+        if (pop_scores[i] > max_score) {
+          max_score = pop_scores[i];
+        }
       }
-      max_score = (scores[i] > max_score) ? scores[i] : max_score;
     }
-    return max_score;
-  };
 
-  // Cost model predicted scores.
-  std::vector<float> scores;
-  scores.reserve(population);
-
-  // The function to generate prefix sum probabilities based on the given 
scores.
-  auto assign_prob = [](const std::vector<float>& scores, std::vector<double>* 
prefix_sum_probs) {
-    // Compute selection probabilities.
-    double sum = 0.0;
-    prefix_sum_probs->resize(scores.size());
-    for (size_t i = 0; i < scores.size(); ++i) {
-      sum += std::max(scores[i], 0.0f);
-      (*prefix_sum_probs)[i] = sum;
+    // Print statistical information
+    if (k % 5 == 0 || k == num_iters) {
+      StdCout(verbose) << "GA Iter: " << k << std::fixed << 
std::setprecision(4)
+                       << "\tMax score: " << max_score << "\tMin score: " << 
heap.front().second
+                       << "\t#Pop: " << pnow->size() << "\t#M+: " << 
mutation_success_ct / (k + 1)
+                       << "\t#M-: " << mutation_fail_ct / (k + 1) << std::endl;
     }
-    for (size_t i = 0; i < scores.size(); ++i) {
-      (*prefix_sum_probs)[i] /= sum;
+    if (k == num_iters) {
+      break;
     }
-  };
 
-  // State selection probabilities.
-  std::uniform_real_distribution<> uniform_dist(0.0, 1.0);
-  std::vector<double> state_select_probs;
-  state_select_probs.reserve(population);
+    // Compute selection probability
+    ComputePrefixSumProb(pop_scores, &pop_selection_probs);
 
-  // Mutation rule selection probabilities.
-  std::vector<double> rule_select_probs;
-  rule_select_probs.reserve(mutation_rules.size());
-  std::vector<float> rule_levels;
-  for (const auto& rule : mutation_rules) {
-    rule_levels.push_back(rule->GetLevel(search_task));
-  }
-  assign_prob(rule_levels, &rule_select_probs);
-
-  // Evaluate the init populations.
-  *pnow = search_task->compute_dag.InferBound(*pnow);
-  PruneInvalidState(search_task, pnow);
-  CHECK_GT(pnow->size(), 0) << "All initial populations are invalid";
-  schedule_cost_model->Predict(search_task, *pnow, &scores);
-
-  // Maintain the best states in the heap.
-  float max_score = update_heap(*pnow, scores, out_size);
-
-  // Genetic algorithm.
-  for (auto iter_idx = 1; iter_idx <= num_iters; ++iter_idx) {
-    // Assign the selection probability to each state based on the cost model 
scores.
-    assign_prob(scores, &state_select_probs);
-
-    // TODO(@comaniac): Perform cross over.
-
-    // Perform mutations.
-    size_t fail_ct = 0;
-    while (pnext->size() < population && fail_ct < population * 2) {
-      // Select a state to be mutated.
-      State tmp_s = (*pnow)[RandomChoose(state_select_probs, &rand_gen)];
-      if (uniform_dist(rand_gen) < mutation_prob) {
-        // Select a rule and mutate the state.
-        const auto& rule = mutation_rules[RandomChoose(rule_select_probs, 
&rand_gen)];
+    // Do mutation
+    while (pnext->size() < population) {
+      State tmp_s = (*pnow)[RandomChoose(pop_selection_probs, &rand_gen)];
+
+      if (dis(rand_gen) < mutation_prob) {
+        const auto& rule = mutation_rules[RandomChoose(rule_selection_probs, 
&rand_gen)];
         if (rule->Apply(this, &tmp_s) == 
PopulationGenerationRule::ResultKind::kValid) {
           pnext->push_back(std::move(tmp_s));
+          mutation_success_ct++;
         } else {
-          fail_ct++;
+          mutation_fail_ct++;
         }
       } else {
-        // Do not mutate this state in this round.
         pnext->push_back(std::move(tmp_s));
       }
     }
 
-    // Evaluate the new populations.
-    *pnext = search_task->compute_dag.InferBound(*pnext);
-    PruneInvalidState(search_task, pnext);
-
-    // Throw away all states generated in this iterations if all new states 
are invalid.
-    if (pnext->size() > 0) {
-      std::swap(pnext, pnow);
-      schedule_cost_model->Predict(search_task, *pnow, &scores);
-
-      // Maintain the best states in the heap.
-      float iter_max_score = update_heap(*pnow, scores, out_size);
-      max_score = (iter_max_score > max_score) ? iter_max_score : max_score;
-    }
+    std::swap(pnext, pnow);
     pnext->clear();
-
-    if (iter_idx % 5 == 0 || iter_idx == num_iters) {
-      StdCout(verbose) << "GA Iter: " << iter_idx << std::fixed << 
std::setprecision(4)
-                       << "\tMax Score: " << max_score << "\tPop Size: " << 
pnow->size()
-                       << std::endl;
-    }
   }
 
-  // Copy best states in the heap to the output.
-  while (!heap.empty()) {
-    auto item = heap.top();
-    heap.pop();
+  // Copy best states in the heap to out_states
+  std::sort(heap.begin(), heap.end(), cmp);
+  for (auto& item : heap) {
     best_states.push_back(std::move(item.first));
   }
 
@@ -580,10 +540,10 @@ Array<MeasureInput> SketchPolicyNode::PickStatesWithEpsGreedy(const Array<State>
 }
 
 TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicy")
-    .set_body_typed([](SearchTask task, CostModel schedule_cost_model,
-                       Map<String, ObjectRef> params, int seed, int verbose,
+    .set_body_typed([](SearchTask task, CostModel program_cost_model, 
Map<String, ObjectRef> params,
+                       int seed, int verbose,
                        Optional<Array<SearchCallback>> init_search_callbacks) {
-      return SketchPolicy(task, schedule_cost_model, params, seed, verbose, 
init_search_callbacks);
+      return SketchPolicy(task, program_cost_model, params, seed, verbose, 
init_search_callbacks);
     });
 
 TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicyGenerateSketches")
diff --git a/src/auto_scheduler/search_policy/sketch_policy.h b/src/auto_scheduler/search_policy/sketch_policy.h
index 2d93d87..21aaa6e 100644
--- a/src/auto_scheduler/search_policy/sketch_policy.h
+++ b/src/auto_scheduler/search_policy/sketch_policy.h
@@ -34,6 +34,7 @@
 #include <tvm/auto_scheduler/cost_model.h>
 #include <tvm/auto_scheduler/search_policy.h>
 
+#include <memory>
 #include <set>
 #include <string>
 #include <unordered_set>
@@ -88,15 +89,15 @@ struct SketchParamKey {
 class SketchPolicyNode : public SearchPolicyNode {
  public:
   /*! \brief The cost model to estimate the complete schedules. */
-  CostModel schedule_cost_model;
+  CostModel program_cost_model;
   /*! \brief The parameters map for this search policy. */
   Map<String, ObjectRef> params;
   /*! \brief The rules to generate sketches. */
   std::vector<SketchGenerationRule*> sketch_rules;
-  /*! \brief The rules to generate initial states. */
+  /*! \brief The rules to generate initial population. */
   std::vector<PopulationGenerationRule*> init_rules;
-  /*! \brief The rules to mutate states. */
-  std::vector<PopulationMutationRule*> mutation_rules;
+  /*! \brief The rules to mutate states in the evolutionary search. */
+  std::vector<std::shared_ptr<PopulationMutationRule>> mutation_rules;
   /*! \brief Random generator. */
   std::mt19937 rand_gen;
   /*! \brief Memorize split space for Split. */
@@ -154,6 +155,9 @@ class SketchPolicyNode : public SearchPolicyNode {
 
   /*! \brief The number of states to measure per iteration. */
   int num_measure_per_iter_;
+
+  /*! \brief The cached sketches */
+  Array<State> sketch_cache_;
 };
 
 /*!
@@ -165,14 +169,14 @@ class SketchPolicy : public SearchPolicy {
   /*!
    * \brief The constructor.
    * \param task  The SearchTask for the computation declaration.
-   * \param schedule_cost_model The cost model for complete programs.
+   * \param program_cost_model The cost model for complete programs.
    * \param params The parameters map for this search process.
    * \param seed The random seed of this search process.
    * \param verbose Verbose level. 0 for silent, 1 to output information 
during schedule
    * search.
    * \param init_search_callbacks SearchCallback to be called before schedule 
search.
    */
-  SketchPolicy(SearchTask task, CostModel schedule_cost_model, Map<String, 
ObjectRef> params,
+  SketchPolicy(SearchTask task, CostModel program_cost_model, Map<String, 
ObjectRef> params,
                int seed, int verbose, Optional<Array<SearchCallback>> 
init_search_callbacks);
 
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SketchPolicy, SearchPolicy, 
SketchPolicyNode);
diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc b/src/auto_scheduler/search_policy/sketch_policy_rules.cc
index dab6e4d..228dda4 100644
--- a/src/auto_scheduler/search_policy/sketch_policy_rules.cc
+++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc
@@ -34,6 +34,9 @@
 namespace tvm {
 namespace auto_scheduler {
 
+static std::vector<int> auto_unroll_configs_cpu = {0, 16, 64, 512};
+static std::vector<int> auto_unroll_configs_gpu = {0, 16, 64, 512, 1024};
+
 /********** Sketch Generation Rule **********/
 /********** RuleSkipStage **********/
 
@@ -472,9 +475,8 @@ PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* p
   return ResultKind::kValid;
 }
 
-PopulationGenerationRule::ResultKind 
MutateComputeLocationCommon(SketchPolicyNode* policy,
-                                                                 State* state,
-                                                                 bool 
infer_bound = true) {
+PopulationGenerationRule::ResultKind 
InitChangeComputeLocation::Apply(SketchPolicyNode* policy,
+                                                                      State* 
state) const {
   if (GetIntParam(policy->params, 
SketchParamKey::disable_change_compute_location)) {
     return PopulationGenerationRule::ResultKind::kValid;
   }
@@ -490,81 +492,8 @@ PopulationGenerationRule::ResultKind 
MutateComputeLocationCommon(SketchPolicyNod
       continue;
     }
 
-    int target_stage_id = GetSingleConsumerId(policy->search_task, *state, 
stage_id);
-    if (target_stage_id < 0) {
-      continue;
-    }
-    const Stage& target_stage = (*state)->stages[target_stage_id];
-
-    std::vector<std::pair<int, int>> candidates;
-    bool target_compute_at_other = target_stage->compute_at == 
ComputeAtKind::kIter;
-    bool target_is_tiled = IsTiled(target_stage);
-
-    bool visited_reduce = false;
-    // enumerate compute_at location at target_stage
-    // TODO(merrymercy): More analysis here to make smarter choices
-    for (size_t i = 0; i < target_stage->iters.size(); ++i) {
-      const Iterator& target_iter = target_stage->iters[i];
-      if (target_iter->iter_kind == IteratorKind::kReduction) {
-        visited_reduce = true;
-        if (!target_is_tiled) {  // Do not go into reduce iter
-          break;
-        }
-      } else if (target_iter->iter_kind == IteratorKind::kSpatial) {
-        if (visited_reduce) {  // Do not go into inner tile
-          break;
-        }
-      }
-
-      if (target_iter->annotation == IteratorAnnotation::kUnroll) {
-        // Do not go into the unroll region of const tensor indices
-        break;
-      }
-
-      if (GetExtent(target_iter) == 1) {
-        // Skip iterators with length of 1
-        continue;
-      }
-      if (target_compute_at_other && target_iter->iter_kind == 
IteratorKind::kSpatial &&
-          StrEndsWith(target_iter->name, ".0")) {
-        // Skip the first level iterators if target stage compute_at another 
stage
-        // In this case, the lengths of first level iterators are always one
-        continue;
-      }
-      candidates.emplace_back(target_stage_id, i);
-
-      if 
((*state)->attach_map->iter_to_attached_stages.count(std::make_pair(target_stage_id,
 i))) {
-        break;
-      }
-    }
-
-    // if the target_stage is already compute_at another stage X, try also 
compute_at X
-    // We call stage X as `target_target_stage`
-    if (target_compute_at_other) {
-      int target_target_stage_id;
-      target_target_stage_id = 
(*state)->attach_map->stage_to_attach_iter.at(target_stage_id).first;
-      const Stage& target_target_stage = 
(*state)->stages[target_target_stage_id];
-
-      for (size_t i = 0; i < target_target_stage->iters.size(); ++i) {
-        const Iterator& target_target_iter = target_target_stage->iters[i];
-        if (target_target_iter->iter_kind == IteratorKind::kReduction ||
-            (*state)->attach_map->iter_to_attached_stages.count(
-                std::make_pair(target_target_stage_id, i))) {
-          break;
-        }
-
-        if (target_target_iter->annotation == IteratorAnnotation::kUnroll) {
-          // Do not go into the unroll region of const tensor indices
-          break;
-        }
-
-        if (GetExtent(target_target_iter) == 1) {  // skip iterators with 
length of 1
-          continue;
-        }
-
-        candidates.emplace_back(target_target_stage_id, i);
-      }
-    }
+    std::vector<std::pair<int, int>> candidates =
+        GetComputeLocationCandidates(policy->search_task, *state, stage_id);
 
     int choice = (policy->rand_gen)() % (candidates.size() + 2);
 
@@ -585,17 +514,10 @@ PopulationGenerationRule::ResultKind 
MutateComputeLocationCommon(SketchPolicyNod
     }
   }
 
-  if (infer_bound) {
-    *state = policy->search_task->compute_dag.InferBound(*state);
-  }
+  *state = policy->search_task->compute_dag.InferBound(*state);
   return PopulationGenerationRule::ResultKind::kValid;
 }
 
-PopulationGenerationRule::ResultKind 
InitChangeComputeLocation::Apply(SketchPolicyNode* policy,
-                                                                      State* 
state) const {
-  return MutateComputeLocationCommon(policy, state, true);
-}
-
 PopulationGenerationRule::ResultKind InitParallel::Apply(SketchPolicyNode* 
policy,
                                                          State* state) const {
   std::function<void(const SketchPolicyNode&, State*, int stage_id, int 
iter_offset)>
@@ -663,9 +585,8 @@ PopulationGenerationRule::ResultKind 
InitParallel::Apply(SketchPolicyNode* polic
 
 PopulationGenerationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* 
policy,
                                                        State* state) const {
-  std::vector<int> auto_unroll_configs = IsGPUTask(policy->search_task)
-                                             ? std::vector<int>({0, 16, 64, 
512, 1024})
-                                             : std::vector<int>({0, 16, 64, 
512});
+  std::vector<int>& auto_unroll_configs =
+      IsGPUTask(policy->search_task) ? auto_unroll_configs_gpu : 
auto_unroll_configs_cpu;
   for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
     const Stage& stage = (*state)->stages[stage_id];
     // Skip the inlined stage and placeholder stage
@@ -801,6 +722,10 @@ PopulationGenerationRule::ResultKind 
InitThreadBind::Apply(SketchPolicyNode* pol
 
     // Deal with the cross-thread reduction generated by 
RuleCrossThreadReduction
     if (HasCrossThreadReduction(*state, stage_id)) {
+      if (stage->compute_at != ComputeAtKind::kRoot) {
+        continue;
+      }
+
       Iterator fused_it;
       *state = std::move(FuseAllOuterSpaceIterators(*state, stage_id, 
&fused_it));
       state->bind(stage_id, fused_it, IteratorAnnotation::kBlockX);
@@ -983,6 +908,7 @@ PopulationGenerationRule::ResultKind 
MutateTileSize::Apply(SketchPolicyNode* pol
       continue;
     }
 
+    // Divide lengths[src_idx] by one of its factors and multiply lengths[dst_idx] by it
     size_t dst_idx = random_perm[(i + 1) % random_perm.size()];
     const std::vector<int>& factors = policy->split_memo.GetFactors(length);
     CHECK_GE(factors.size(), 1);
@@ -1017,6 +943,8 @@ PopulationGenerationRule::ResultKind 
MutateTileSize::Apply(SketchPolicyNode* pol
       }
     }
 
+    CHECK_LE(GetIntImm(new_lengths.back()), max_innermost_split_factor);
+
     StateNode* pstate = state->CopyOnWrite();
     pstate->transform_steps.Set(
         step_id, SplitStep(ps->stage_id, ps->iter_id, ps->extent,
@@ -1027,39 +955,98 @@ PopulationGenerationRule::ResultKind 
MutateTileSize::Apply(SketchPolicyNode* pol
   return ResultKind::kInvalid;
 }
 
-PopulationGenerationRule::ResultKind MutateMaxUnrollFactor::Apply(SketchPolicyNode* policy,
-                                                                  State* state) const {
+PopulationGenerationRule::ResultKind MutateAutoUnroll::Apply(SketchPolicyNode* policy,
+                                                             State* state) const {
   // Extract all auto_unroll_max_step pragma steps.
-  std::vector<int> annotate_steps;
+  std::vector<int> pragma_steps;
   for (size_t i = 0; i < (*state)->transform_steps.size(); ++i) {
     if (auto ps = (*state)->transform_steps[i].as<PragmaStepNode>()) {
       if (StrStartsWith(ps->pragma_type, "auto_unroll_max_step")) {
-        annotate_steps.push_back(i);
+        pragma_steps.push_back(i);
       }
     }
   }
-  if (annotate_steps.empty()) {
+  if (pragma_steps.empty()) {
     return ResultKind::kInvalid;
   }
 
-  // Random pick up one unroll factor candidate.
-  auto cands = (IsGPUTask(policy->search_task)) ? &gpu_unroll_cands_ : &cpu_unroll_cands_;
-  auto new_factor = std::to_string((*cands)[(policy->rand_gen)() % cands->size()]);
+  const std::vector<int>& auto_unroll_configs =
+      IsGPUTask(policy->search_task) ? auto_unroll_configs_gpu : auto_unroll_configs_cpu;
 
-  // Random pick up and mutate an unroll step.
-  auto step_id = annotate_steps[(policy->rand_gen)() % annotate_steps.size()];
+  // Randomly pick an auto_unroll_max_step pragma step
+  auto step_id = pragma_steps[(policy->rand_gen)() % pragma_steps.size()];
   auto ps = (*state)->transform_steps[step_id].as<PragmaStepNode>();
   CHECK(ps);
+
+  // Mutate its value to a randomly chosen candidate (possibly the current value)
+  auto val = std::to_string(auto_unroll_configs[(policy->rand_gen)() % auto_unroll_configs.size()]);
   StateNode* pstate = state->CopyOnWrite();
-  pstate->transform_steps.Set(step_id,
-                              PragmaStep(ps->stage_id, ps->iter_id,
-                                         std::string("auto_unroll_max_step") + "$" + new_factor));
+  pstate->transform_steps.Set(step_id, PragmaStep(ps->stage_id, ps->iter_id,
+                                                  std::string("auto_unroll_max_step") + "$" + val));
   return ResultKind::kValid;
 }
 
 PopulationGenerationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNode* policy,
                                                                   State* state) const {
-  return MutateComputeLocationCommon(policy, state, false);
+  if (GetIntParam(policy->params, SketchParamKey::disable_change_compute_location)) {
+    return PopulationGenerationRule::ResultKind::kInvalid;
+  }
+
+  // Collect compute_at steps whose current stage is neither tiled nor needs multi-level tiling.
+  std::vector<int> compute_at_steps;
+  for (size_t s = 0; s < (*state)->transform_steps.size(); ++s) {
+    if (auto ps = (*state)->transform_steps[s].as<ComputeAtStepNode>()) {
+      int stage_inc = GetTargetStageIDInState(*state, s) - ps->stage_id;
+
+      if (IsTiled((*state)->stages[ps->stage_id + stage_inc])) {
+        continue;
+      }
+
+      if (NeedsMultilevelTiling(policy->search_task, *state, ps->stage_id + stage_inc)) {
+        continue;
+      }
+      compute_at_steps.push_back(s);
+    }
+  }
+  if (compute_at_steps.empty()) {
+    return PopulationGenerationRule::ResultKind::kInvalid;
+  }
+
+  // Randomly pick one of the collected compute_at steps
+  size_t step_id = compute_at_steps[(policy->rand_gen)() % compute_at_steps.size()];
+  auto ps = (*state)->transform_steps[step_id].as<ComputeAtStepNode>();
+  CHECK(ps != nullptr);  // must be checked before the first dereference below
+  int stage_inc = GetTargetStageIDInState(*state, step_id) - ps->stage_id;
+
+  std::vector<std::pair<int, int>> candidates =
+      GetComputeLocationCandidates(policy->search_task, *state, ps->stage_id + stage_inc);
+
+  if (candidates.empty()) {
+    return PopulationGenerationRule::ResultKind::kInvalid;
+  }
+
+  int choice = (policy->rand_gen)() % (candidates.size());
+  int new_compute_at_stage_id = candidates[choice].first;
+  int new_compute_at_iter_id = candidates[choice].second;
+
+  // Replay a new state. NOTE(review): the new target id is rebased by this stage's stage_inc.
+  State tmp_s = policy->search_task->compute_dag->init_state;
+  for (size_t s = 0; s < (*state)->transform_steps.size(); ++s) {
+    if (s == step_id) {
+      tmp_s.CopyOnWrite()->transform_steps.push_back(
+          ComputeAtStep(ps->stage_id, new_compute_at_stage_id - stage_inc, new_compute_at_iter_id));
+    } else {
+      tmp_s.CopyOnWrite()->transform_steps.push_back((*state)->transform_steps[s]);
+    }
+    try {
+      StepApplyToState(tmp_s->transform_steps.back(), &tmp_s, policy->search_task->compute_dag);
+    } catch (const dmlc::Error&) {
+      return PopulationGenerationRule::ResultKind::kInvalid;  // mutated history is not replayable
+    }
+  }
+
+  *state = tmp_s;
+  return PopulationGenerationRule::ResultKind::kValid;
 }
 
 PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* 
policy,
diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.h 
b/src/auto_scheduler/search_policy/sketch_policy_rules.h
index 418fbda..4098df2 100644
--- a/src/auto_scheduler/search_policy/sketch_policy_rules.h
+++ b/src/auto_scheduler/search_policy/sketch_policy_rules.h
@@ -124,7 +124,7 @@ 
DEFINE_SKETCH_GENERATION_RULE(RuleSpecialComputeLocationGPU);
 
 /********** Init Population **********/
 
-/*! \brief The base class for derivation rules used in the initial population. 
*/
+/*! \brief The base class for rules used to annotate the sketches to get the 
initial population. */
 class PopulationGenerationRule {
  public:
   /*! \brief Result enumeration of the apply function. */
@@ -138,8 +138,12 @@ class PopulationGenerationRule {
    * \return The result of this rule, indicate if there's any valid state 
generated.
    */
   virtual ResultKind Apply(SketchPolicyNode* policy, State* state) const = 0;
+
+  /*! \brief The destructor */
+  virtual ~PopulationGenerationRule() = default;
 };
 
+// A helper to define population initialization rules
 #define DEFINE_INIT_POPULATION_RULE(rule_name)                            \
   class rule_name : public PopulationGenerationRule {                     \
    public:                                                                \
@@ -149,7 +153,7 @@ class PopulationGenerationRule {
 /*! \brief The rule that fills the incomplete SplitSteps. */
 DEFINE_INIT_POPULATION_RULE(InitFillTileSize);
 
-/*! \brief The rule that randomly changes the computation location for some 
stages, which do not
+/*! \brief The rule that randomly changes the computation location for some 
stages that do not
  * need tiling and are not strictly inlineable(e.g. data padding). */
 DEFINE_INIT_POPULATION_RULE(InitChangeComputeLocation);
 
@@ -170,50 +174,37 @@ DEFINE_INIT_POPULATION_RULE(InitThreadBind);
 /*! \brief The base class for mutation rules used in the evolutionary search. */
 class PopulationMutationRule : public PopulationGenerationRule {
  public:
-  /*!
-   * \brief Get the priority level of this mutation rule.
-   * \return The priority level of this mutation rule. Higher the better.
+  /*! \brief The constructor
+   * \param selection_weight The probability of applying this rule is
+   *        proportional to this weight
    */
-  virtual int GetLevel(const SearchTask& task) const = 0;
+  explicit PopulationMutationRule(double selection_weight) : weight(selection_weight) {}
+
+  /*! \brief The selection weight of this rule */
+  double weight;
 };
 
-// A helper to define mutation rules with a constant rule level.
-#define DEFINE_MUTATE_POPULATION_RULE(rule_name, rule_level)                \
-  class rule_name : public PopulationMutationRule {                         \
-   public:                                                                  \
-    ResultKind Apply(SketchPolicyNode* policy, State* state) const final;   \
-    int GetLevel(const SearchTask& task) const final { return rule_level; } \
+// A helper to define mutation rules; the generated ctor passes the weight to the base class
+#define DEFINE_MUTATE_POPULATION_RULE(rule_name)                          \
+  class rule_name : public PopulationMutationRule {                       \
+   public:                                                                \
+    explicit rule_name(double weight) : PopulationMutationRule(weight) {} \
+    ResultKind Apply(SketchPolicyNode* policy, State* state) const final; \
   };
 
 /*! \brief The rule that mutates tile size by randomly dividing a tile size by 
a factor
     and multipling it to another tile size. */
-DEFINE_MUTATE_POPULATION_RULE(MutateTileSize, 100);
-
-/*! \brief The rule that mutates the fusion iterators annotated by parallel. */
-DEFINE_MUTATE_POPULATION_RULE(MutateParallel, 50);
+DEFINE_MUTATE_POPULATION_RULE(MutateTileSize);
 
-/*! \brief The rule that mutates the factor of a randomly selected auto max 
unroll step. */
-class MutateMaxUnrollFactor : public PopulationMutationRule {
- public:
-  ResultKind Apply(SketchPolicyNode* policy, State* state) const final;
-  int GetLevel(const SearchTask& task) const final { return 10; }
-
-  const std::vector<int> cpu_unroll_cands_ = {0, 16, 64, 512, 1024};
-  const std::vector<int> gpu_unroll_cands_ = {0, 16, 64, 512};
-};
+/*! \brief The rule that mutates the number of fused outer iterators annotated 
by parallel. */
+DEFINE_MUTATE_POPULATION_RULE(MutateParallel);
 
-/*! \brief The rule that randomly changes the computation location for some 
stages, which do not
+/*! \brief The rule that randomly changes the computation location for some 
stages that do not
  * need tiling and are not strictly inlineable(e.g. data padding). */
-class MutateComputeLocation : public PopulationMutationRule {
- public:
-  ResultKind Apply(SketchPolicyNode* policy, State* state) const final;
-  int GetLevel(const SearchTask& task) const final {
-    if (IsGPUTask(task)) {
-      return 0;
-    }
-    return 5;
-  }
-};
+DEFINE_MUTATE_POPULATION_RULE(MutateComputeLocation);
+
+/*! \brief The rule that mutates the value of a randomly selected auto_unroll_max_step 
pragma step. */
+DEFINE_MUTATE_POPULATION_RULE(MutateAutoUnroll);
 
 }  // namespace auto_scheduler
 }  // namespace tvm
diff --git a/src/auto_scheduler/search_policy/utils.cc 
b/src/auto_scheduler/search_policy/utils.cc
index 62ffce4..744573a 100644
--- a/src/auto_scheduler/search_policy/utils.cc
+++ b/src/auto_scheduler/search_policy/utils.cc
@@ -67,6 +67,87 @@ Array<Integer> GetSpatialSplitStepIds(const State& s, int 
stage_id) {
   return spatial_split_step_ids;
 }
 
+std::vector<std::pair<int, int>> GetComputeLocationCandidates(const SearchTask& task,
+                                                              const State& state, int stage_id) {
+  const int target_stage_id = GetSingleConsumerId(task, state, stage_id);
+  if (target_stage_id < 0) {
+    return {};
+  }
+  const Stage& target_stage = state->stages[target_stage_id];
+
+  std::vector<std::pair<int, int>> candidates;
+  const bool target_compute_at_other = target_stage->compute_at == ComputeAtKind::kIter;
+  const bool target_is_tiled = IsTiled(target_stage);
+
+  bool visited_reduce = false;
+  // Enumerate (target_stage_id, iter_id) compute_at locations on target_stage
+  // TODO(merrymercy): More analysis here to make smarter choices
+  for (size_t i = 0; i < target_stage->iters.size(); ++i) {
+    const Iterator& target_iter = target_stage->iters[i];
+    if (target_iter->iter_kind == IteratorKind::kReduction) {
+      visited_reduce = true;
+      if (!target_is_tiled) {  // Do not go into reduce iter
+        break;
+      }
+    } else if (target_iter->iter_kind == IteratorKind::kSpatial) {
+      if (visited_reduce) {  // Do not go into inner tile
+        break;
+      }
+    }
+
+    if (target_iter->annotation == IteratorAnnotation::kUnroll) {
+      // Do not go into the unroll region of const tensor indices
+      break;
+    }
+
+    if (GetExtent(target_iter) == 1) {
+      // Skip iterators with length of 1
+      continue;
+    }
+    if (target_compute_at_other && target_iter->iter_kind == IteratorKind::kSpatial &&
+        StrEndsWith(target_iter->name, ".0")) {
+      // Skip the first level iterators if target stage compute_at another stage
+      // In this case, the lengths of first level iterators are always one
+      continue;
+    }
+    candidates.emplace_back(target_stage_id, i);
+
+    if (state->attach_map->iter_to_attached_stages.count(std::make_pair(target_stage_id, i))) {
+      break;
+    }
+  }
+
+  // If target_stage is already computed at another stage X, also try compute_at X.
+  // We call stage X `target_target_stage`.
+  if (target_compute_at_other) {
+    const int target_target_stage_id =
+        state->attach_map->stage_to_attach_iter.at(target_stage_id).first;
+    const Stage& target_target_stage = state->stages[target_target_stage_id];
+
+    for (size_t i = 0; i < target_target_stage->iters.size(); ++i) {
+      const Iterator& target_target_iter = target_target_stage->iters[i];
+      if (target_target_iter->iter_kind == IteratorKind::kReduction ||
+          state->attach_map->iter_to_attached_stages.count(
+              std::make_pair(target_target_stage_id, i))) {
+        break;
+      }
+
+      if (target_target_iter->annotation == IteratorAnnotation::kUnroll) {
+        // Do not go into the unroll region of const tensor indices
+        break;
+      }
+
+      if (GetExtent(target_target_iter) == 1) {  // skip iterators with length of 1
+        continue;
+      }
+
+      candidates.emplace_back(target_target_stage_id, i);
+    }
+  }
+
+  return candidates;
+}
+
 State DoMultiLevelTiling(const State& state, int stage_id, const std::string& 
format,
                          std::vector<int>* spatial_split_step_ids) {
   // Temporal object to be used if the input pointer is nullptr
@@ -327,7 +408,7 @@ void PruneInvalidState(const SearchTask& task, 
Array<State>* states) {
   }
 
   if (pt == 0) {
-    LOG(INFO) << "All states are invalid.";
+    LOG(FATAL) << "Internal error: All states are invalid.";
   } else {
     states->resize(pt);
   }
diff --git a/src/auto_scheduler/search_policy/utils.h 
b/src/auto_scheduler/search_policy/utils.h
index d2ba128..75bf0d0 100644
--- a/src/auto_scheduler/search_policy/utils.h
+++ b/src/auto_scheduler/search_policy/utils.h
@@ -372,7 +372,8 @@ inline bool HasSingleElementwiseMatchedConsumer(const 
SearchTask& task, const St
     *target_stage_id = *consumers.begin();
     if (ElementwiseMatch(task, state, stage_id, *target_stage_id) &&
         (!(HasReduceIter(state->stages[stage_id]) &&
-           HasReduceIter(state->stages[*target_stage_id])))) {
+           HasReduceIter(state->stages[*target_stage_id]))) &&
+        (!StrEndsWith(state->stages[*target_stage_id]->op->name, ".shared"))) {
       return true;
     }
   }
@@ -535,6 +536,22 @@ inline Iterator 
GetLastReduceIteratorInOutermostReduceTile(const Stage& stage) {
   return stage->iters[0];
 }
 
+/*! \brief Get the id, in the current state, of the stage that a history step applies to.
+ * Later CacheRead/CacheWrite/Rfactor steps may shift stage ids, making the recorded id stale. */
+inline int GetTargetStageIDInState(const State& s, int step_id) {
+  const int orig_stage_id = s->transform_steps[step_id]->stage_id;
+  int stage_inc = 0;
+  for (size_t i = step_id + 1; i < s->transform_steps.size(); ++i) {
+    const auto& later = s->transform_steps[i];
+    if ((later->IsInstance<CacheWriteStepNode>() || later->IsInstance<CacheReadStepNode>() ||
+         later->IsInstance<RfactorStepNode>()) &&
+        later->stage_id <= orig_stage_id + stage_inc) {
+      ++stage_inc;
+    }
+  }
+  return orig_stage_id + stage_inc;
+}
+
 /*! \brief Get all split steps for one stage. */
 inline void GetSplitStepIds(const State& s, int stage_id, std::vector<int>* 
split_step_ids) {
   for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) {
@@ -615,6 +632,32 @@ inline Array<State> RandomSampleStates(const Array<State>& 
in_states, std::mt199
   return out_states;
 }
 
+/*! \brief Compute prefix-sum probabilities from weights (negative weights count as 0). */
+inline void ComputePrefixSumProb(const std::vector<float>& weights,
+                                 std::vector<double>* prefix_sum_probs) {
+  double sum = 0.0;  // accumulate in double to match the output precision
+  prefix_sum_probs->resize(weights.size());
+  for (size_t i = 0; i < weights.size(); ++i) {
+    sum += std::max(weights[i], 0.0f);
+    (*prefix_sum_probs)[i] = sum;
+  }
+  for (size_t i = 0; i < weights.size(); ++i) {
+    // Without any positive weight, fall back to uniform instead of dividing by zero (NaN).
+    (*prefix_sum_probs)[i] = sum > 0 ? (*prefix_sum_probs)[i] / sum : (i + 1.0) / weights.size();
+  }
+}
+
+/*! \brief Randomly choose an index according to a prefix sum probability. */
+inline int RandomChoose(const std::vector<double>& prefix_sum_probs, std::mt19937* random_gen) {
+  CHECK(!prefix_sum_probs.empty());
+  std::uniform_real_distribution<> dis(0.0, 1.0);
+  double x = dis(*random_gen);
+
+  // Clamp to the last index in case rounding left the final prefix sum below x.
+  auto it = std::lower_bound(prefix_sum_probs.begin(), prefix_sum_probs.end(), x);
+  return static_cast<int>(std::min(it, prefix_sum_probs.end() - 1) - prefix_sum_probs.begin());
+}
+
 /*! \brief Print a title */
 inline void PrintTitle(const std::string& title, int verbose) {
   StdCout(verbose) << Chars('-', 60) << "\n"
@@ -648,6 +691,10 @@ class SplitFactorizationMemo {
 /*! \brief Get the indexes of SplitStep that processes on spatial iterator. */
 Array<Integer> GetSpatialSplitStepIds(const State& s, int stage_id);
 
+/*! \brief Get candidate compute_at locations, as (stage_id, iter_id) pairs, for a stage. */
+std::vector<std::pair<int, int>> GetComputeLocationCandidates(const 
SearchTask& task,
+                                                              const State& 
state, int stage_id);
+
 // Apply multi-level tiling structure according to a string format,
 // where "S" stands a space level, "R" stands for a reduction level.
 // For example, if the format is "SSRSRS", then we will
@@ -662,17 +709,6 @@ State DoMultiLevelTiling(const State& state, int stage_id, 
const std::string& fo
 State FollowTiling(const State& state, int stage_id, const std::vector<int>& 
split_step_ids,
                    int n_split);
 
-// Random choose an index according to a prefix sum probability.
-inline int RandomChoose(const std::vector<double>& prefix_sum_probs, 
std::mt19937* random_gen) {
-  std::uniform_real_distribution<> dis(0.0, 1.0);
-  double x = dis(*random_gen);
-
-  CHECK(!prefix_sum_probs.empty());
-
-  return std::lower_bound(prefix_sum_probs.begin(), prefix_sum_probs.end(), x) 
-
-         prefix_sum_probs.begin();
-}
-
 // Prune invalid states and return the results in-place.
 void PruneInvalidState(const SearchTask& task, Array<State>* states);
 
diff --git a/src/auto_scheduler/transform_step.cc 
b/src/auto_scheduler/transform_step.cc
index cec83bb..2a93497 100755
--- a/src/auto_scheduler/transform_step.cc
+++ b/src/auto_scheduler/transform_step.cc
@@ -780,7 +780,9 @@ Array<Iterator> ApplySplitToState(State* state, int 
stage_id, int iter_id,
       res = Iterator(name, Range(), it->iter_kind, IteratorAnnotation::kNone);
       tosplit_min = NullOpt;
       tosplit_extent = NullOpt;
-      concrete = false;
+      if (!l.defined()) {
+        concrete = false;
+      }
     }
     outs.push_back(std::move(res));
   }
diff --git a/tests/python/unittest/test_auto_scheduler_common.py 
b/tests/python/unittest/test_auto_scheduler_common.py
index 33e498e..eaf328c 100644
--- a/tests/python/unittest/test_auto_scheduler_common.py
+++ b/tests/python/unittest/test_auto_scheduler_common.py
@@ -40,6 +40,19 @@ def matmul_auto_scheduler_test(N, M, K):
     return [A, B, C]
 
 
+@auto_scheduler.register_workload
+def double_matmul_auto_scheduler_test(N):  # E = (A x B) x C; inputs are N x N float32
+    A = te.placeholder((N, N), name="A", dtype="float32")
+    B = te.placeholder((N, N), name="B", dtype="float32")
+    C = te.placeholder((N, N), name="C", dtype="float32")
+    k = te.reduce_axis((0, N), name="k")
+    D = te.compute((N, N), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), 
name="D")
+    k = te.reduce_axis((0, N), name="k")  # a new reduction axis for the second matmul
+    E = te.compute((N, N), lambda i, j: te.sum(D[i][k] * C[k][j], axis=[k]), 
name="E")
+
+    return [A, B, C, E]
+
+
 # Test for register_workload with different name
 @auto_scheduler.register_workload("matmul_auto_scheduler_test_rename_1")
 def matmul_auto_scheduler_test_rename_0(N, M, K):
diff --git a/tests/python/unittest/test_auto_scheduler_evolutionary_search.py 
b/tests/python/unittest/test_auto_scheduler_evolutionary_search.py
index eb706b7..bf6efd0 100644
--- a/tests/python/unittest/test_auto_scheduler_evolutionary_search.py
+++ b/tests/python/unittest/test_auto_scheduler_evolutionary_search.py
@@ -47,7 +47,7 @@ def test_evo_search():
     workload_key = 
auto_scheduler.make_workload_key(matmul_auto_scheduler_test, (10, 10, 4))
     dag = auto_scheduler.ComputeDAG(workload_key)
     task = auto_scheduler.SearchTask(dag, workload_key, 
tvm.target.Target("llvm"))
-    policy = auto_scheduler.SketchPolicy(task, 
schedule_cost_model=MockCostModel(), verbose=0)
+    policy = auto_scheduler.SketchPolicy(task, 
program_cost_model=MockCostModel(), verbose=0)
     states = policy.sample_initial_population(50)
     pruned_states = []
     for state in states:
diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py 
b/tests/python/unittest/test_auto_scheduler_search_policy.py
index 6ec96a6..04b54b2 100644
--- a/tests/python/unittest/test_auto_scheduler_search_policy.py
+++ b/tests/python/unittest/test_auto_scheduler_search_policy.py
@@ -57,7 +57,7 @@ def search_common(
             search_policy = auto_scheduler.EmptyPolicy(task)
         elif search_policy == "sketch":
             search_policy = auto_scheduler.SketchPolicy(
-                task, schedule_cost_model=cost_model, 
init_search_callbacks=init_search_callbacks
+                task, program_cost_model=cost_model, 
init_search_callbacks=init_search_callbacks
             )
 
         tuning_options = auto_scheduler.TuningOptions(
diff --git a/tests/python/unittest/test_auto_scheduler_sketch_generation.py 
b/tests/python/unittest/test_auto_scheduler_sketch_generation.py
index fa67756..5a687da 100644
--- a/tests/python/unittest/test_auto_scheduler_sketch_generation.py
+++ b/tests/python/unittest/test_auto_scheduler_sketch_generation.py
@@ -25,6 +25,7 @@ from tvm.auto_scheduler.loop_state import Stage
 
 from test_auto_scheduler_common import (
     matmul_auto_scheduler_test,
+    double_matmul_auto_scheduler_test,
     conv2d_nchw_bn_relu_auto_scheduler_test,
     max_pool2d_auto_scheduler_test,
     min_nm_auto_scheduler_test,
@@ -73,9 +74,9 @@ def assert_has_cross_thread_reduction(state, stage_id):
 def test_cpu_matmul_sketch():
     sketches = generate_sketches(matmul_auto_scheduler_test, (512, 512, 512), 
"llvm")
     """ 3 multi-level tiling sketches
-        0 - Multi-level tiling
-        1 - Multi-level tiling with cache write on position 0
-        2 - Multi-level tiling with cache write on position 1
+        No.0 : Multi-level tiling
+        No.1 : Multi-level tiling with cache write on position 0
+        No.2 : Multi-level tiling with cache write on position 1
     """
     assert len(sketches) == 3
     # Sketch 0
@@ -92,11 +93,11 @@ def test_cpu_matmul_sketch():
 
     sketches = generate_sketches(matmul_auto_scheduler_test, (8, 8, 512), 
"llvm")
     """ 2 rfactor sketches + 3 multi-level tiling sketches
-        0 - Rfactor with factor position 0
-        1 - Rfactor with factor position 1
-        2 - Multi-level tiling
-        3 - Multi-level tiling with cache write on position 0
-        4 - Multi-level tiling with cache write on position 1
+        No.0 : Rfactor with factor position 0
+        No.1 : Rfactor with factor position 1
+        No.2 : Multi-level tiling
+        No.3 : Multi-level tiling with cache write on position 0
+        No.4 : Multi-level tiling with cache write on position 1
     """
     assert len(sketches) == 5
     # Sketch 0
@@ -116,15 +117,20 @@ def test_cpu_matmul_sketch():
     assert_compute_at_condition(sketches[4].stages[2], "iter")
     assert sketches[3] != sketches[4]
 
+    sketches = generate_sketches(double_matmul_auto_scheduler_test, (512,), 
"llvm")
+    """ 3 multi-level tiling sketches for one matmul, so 3 * 3 = 9 sketches in 
total """
+    assert len(sketches) == 9
+    assert_is_tiled(sketches[8].stages[5])
+
 
 def test_cpu_conv2d_bn_relu_sketch():
     sketches = generate_sketches(
         conv2d_nchw_bn_relu_auto_scheduler_test, (1, 56, 56, 512, 512, 3, 1, 
1), "llvm"
     )
     """ 3 multi-level tiling sketches
-        0 - Conv2d multi-level tiling with fusion on position 0
-        1 - Conv2d multi-level tiling with fusion on position 1
-        2 - Conv2d multi-level tiling without fusion
+        No.0 : Conv2d multi-level tiling with fusion on position 0
+        No.1 : Conv2d multi-level tiling with fusion on position 1
+        No.2 : Conv2d multi-level tiling without fusion
     """
     assert len(sketches) == 3
     # Sketch 0
@@ -164,9 +170,9 @@ def test_cpu_max_pool2d_sketch():
 def test_cpu_min_sketch():
     sketches = generate_sketches(min_nm_auto_scheduler_test, (10, 1024), 
"llvm")
     """ 2 rfactor sketches + 1 default sketch
-        0 - Rfactor with factor position 0
-        1 - Rfactor with factor position 1
-        2 - Default sketch
+        No.0 : Rfactor with factor position 0
+        No.1 : Rfactor with factor position 1
+        No.2 : Default sketch
     """
     assert len(sketches) == 3
     # Sketch 0
@@ -209,9 +215,9 @@ def test_cpu_conv2d_winograd_sketch():
         conv2d_winograd_nhwc_auto_scheduler_test, (1, 28, 28, 128, 128, 3, 1, 
1), "llvm"
     )
     """ 3 multi-level tiling sketches
-        0 - Bgemm multi-level tiling
-        1 - Bgemm multi-level tiling with cache write on position 0
-        2 - Bgemm multi-level tiling with cache write on position 1
+        No.0 : Bgemm multi-level tiling
+        No.1 : Bgemm multi-level tiling with cache write on position 0
+        No.2 : Bgemm multi-level tiling with cache write on position 1
     """
     assert len(sketches) == 3
     # Sketch 0
@@ -277,6 +283,12 @@ def test_cuda_matmul_sketch():
     assert_compute_at_condition(sketches[1].stages[4], "iter")
     assert_is_tiled(sketches[1].stages[5])
 
+    sketches = generate_sketches(double_matmul_auto_scheduler_test, (512,), 
"cuda")
+    """ 1 multi-level tiling sketch for one matmul, so 1 x 1 = 1 sketch in 
total """
+    assert len(sketches) == 1
+    assert_compute_at_condition(sketches[0].stages[5], "root")
+    assert_compute_at_condition(sketches[0].stages[6], "iter")
+
 
 @tvm.testing.requires_cuda
 def test_cuda_conv2d_bn_relu_sketch():
diff --git a/tutorials/auto_scheduler/tune_matmul_x86.py 
b/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py
similarity index 53%
copy from tutorials/auto_scheduler/tune_matmul_x86.py
copy to tutorials/auto_scheduler/tune_conv2d_layer_cuda.py
index 1a9af42..98e66bb 100644
--- a/tutorials/auto_scheduler/tune_matmul_x86.py
+++ b/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py
@@ -15,11 +15,14 @@
 # specific language governing permissions and limitations
 # under the License.
 """
-Auto-scheduling matrix multiplication for CPU
-=============================================
+.. _auto-scheduler-conv-gpu:
+
+Auto-scheduling a convolution layer for GPU
+===========================================
 **Author**: `Lianmin Zheng <https://github.com/merrymercy>`_, \
             `Chengfan Jia <https://github.com/jcf94/>`_
 
+
 Different from the existing :ref:`autotvm <tutorials-autotvm-sec>` which 
relies on 
 manual templates to define the search space, the auto-scheduler does not 
require any templates.
 The auto-scheduler is template-free, so users only need to write the 
computation declaration without
@@ -27,58 +30,70 @@ any schedule commands or templates.
 The auto-scheduler can automatically generate a large
 search space and find a good schedule in the space.
 
-We use matrix multiplication as an example in this tutorial.
+We use a convolution layer as an example in this tutorial.
 """
 
 import numpy as np
 import tvm
-from tvm import te, testing, auto_scheduler
+from tvm import te, testing, auto_scheduler, topi
+from tvm.topi.testing import conv2d_nchw_python
 
 ######################################################################
 # Define the computation
 # ^^^^^^^^^^^^^^^^^^^^^^
-# To begin with, we define the computation of a matmul with bias add.
+# To begin with, let us define the computation of a convolution layer.
 # The function should return the list of input/output tensors.
 # From these tensors, the auto-scheduler can get the whole computational graph.
 
 
 @auto_scheduler.register_workload
-def matmul_add(N, L, M, dtype):
-    A = te.placeholder((N, L), name="A", dtype=dtype)
-    B = te.placeholder((L, M), name="B", dtype=dtype)
-    C = te.placeholder((N, M), name="C", dtype=dtype)
-
-    k = te.reduce_axis((0, L), name="k")
-    matmul = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], 
axis=k), name="matmul")
-    out = te.compute((N, M), lambda i, j: matmul[i, j] + C[i, j], name="out")
-
-    return [A, B, C, out]
+def conv2d_layer(N, H, W, CO, CI, KH, KW, stride, padding):
+    data = te.placeholder((N, CI, H, W), name="data")
+    kernel = te.placeholder((CO, CI, KH, KW), name="kernel")
+    bias = te.placeholder((1, CO, 1, 1), name="bias")
+    conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, 
out_dtype="float32")
+    out = topi.nn.relu(conv + bias)
+    return [data, kernel, bias, out]
 
 
 ######################################################################
 # Create the search task
 # ^^^^^^^^^^^^^^^^^^^^^^
-# We then create a search task with N=L=M=128 and dtype="float32"
+# We then create a search task for the last convolution layer in ResNet.
 
-target = tvm.target.Target("llvm")
-task = auto_scheduler.create_task(matmul_add, (128, 128, 128, "float32"), 
target)
+target = tvm.target.Target("cuda")
+
+# the last convolution layer in ResNet
+N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), 
(1, 1)
+task = auto_scheduler.create_task(conv2d_layer, (N, H, W, CO, CI, KH, KW, 
strides, padding), target)
 
 # Inspect the computational graph
 print(task.compute_dag)
 
 ######################################################################
-# Next, we set parameters for the auto-scheduler.
+# Next, we set parameters for the auto-scheduler. These parameters
+# mainly specify how we do the measurement during the search and auto-tuning.
 #
+# * `measure_ctx` runs measurements in a separate process to provide isolation.
+#   This protects the main process from GPU crashes that happen during
+#   measurement and avoids other runtime conflicts.
+# * `min_repeat_ms` defines the minimum duration of one "repeat" in every measurement.
+#   This warms up the GPU, which is necessary to get accurate measurement results.
+#   Typically, we recommend a value of at least 300 ms.
 # * `num_measure_trials` is the number of measurement trials we can use during 
the search.
 #   We only make 10 trials in this tutorial for a fast demonstration. In 
practice, 1000 is a
 #   good value for the search to converge. You can do more trials according to 
your time budget.
-# * In addition, we use `RecordToFile` to dump measurement records into a file 
`matmul.json`.
+# * In addition, we use `RecordToFile` to dump measurement records into a file 
`conv2d.json`.
 #   The measurement records can be used to query the history best, resume the 
search,
 #   and do more analyses later.
-# * see :any:`auto_schedule.TuningOptions`: for more parameters
+# * see :any:`auto_scheduler.auto_schedule.TuningOptions` and
+#   :any:`auto_scheduler.measure.LocalRPCMeasureContext` for more parameters.
 
+measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
 tune_option = auto_scheduler.TuningOptions(
-    num_measure_trials=10, 
measure_callbacks=[auto_scheduler.RecordToFile("matmul.json")]
+    num_measure_trials=10,
+    runner=measure_ctx.runner,
+    measure_callbacks=[auto_scheduler.RecordToFile("conv2d.json")],
 )
 
 ######################################################################
@@ -93,31 +108,46 @@ sch, args = auto_scheduler.auto_schedule(task, 
tuning_options=tune_option)
 ######################################################################
 # We can lower the schedule to see the IR after auto-scheduling.
 # The auto-scheduler correctly performs optimizations including multi-level 
tiling,
-# parallelization, vectorization, unrolling and fusion.
+# cooperative fetching, unrolling and operator fusion.
 
 print(tvm.lower(sch, args, simple_mode=True))
 
 ######################################################################
-# Check correctness
-# ^^^^^^^^^^^^^^^^^
-# We build the binary and check its correctness
-
-func = tvm.build(sch, args)
-a_np = np.random.uniform(size=(128, 128)).astype(np.float32)
-b_np = np.random.uniform(size=(128, 128)).astype(np.float32)
-c_np = np.random.uniform(size=(128, 128)).astype(np.float32)
-d_np = a_np.dot(b_np) + c_np
-
-d_tvm = tvm.nd.empty(d_np.shape)
-func(tvm.nd.array(a_np), tvm.nd.array(b_np), tvm.nd.array(c_np), d_tvm)
-
-tvm.testing.assert_allclose(d_np, d_tvm.asnumpy(), rtol=1e-3)
+# Check correctness and evaluate performance
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# We build the binary and check its correctness and performance.
+
+func = tvm.build(sch, args, target)
+
+# check correctness
+data_np = np.random.uniform(size=(N, CI, H, W)).astype(np.float32)
+weight_np = np.random.uniform(size=(CO, CI, KH, KW)).astype(np.float32)
+bias_np = np.random.uniform(size=(1, CO, 1, 1)).astype(np.float32)
+conv_np = conv2d_nchw_python(data_np, weight_np, strides, padding)
+out_np = np.maximum(conv_np + bias_np, 0.0)
+
+ctx = tvm.gpu()
+data_tvm = tvm.nd.array(data_np, ctx=ctx)
+weight_tvm = tvm.nd.array(weight_np, ctx=ctx)
+bias_tvm = tvm.nd.array(bias_np, ctx=ctx)
+out_tvm = tvm.nd.empty(out_np.shape, ctx=ctx)
+func(data_tvm, weight_tvm, bias_tvm, out_tvm)
+
+# Check results
+tvm.testing.assert_allclose(out_np, out_tvm.asnumpy(), rtol=1e-3)
+
+# Evaluate execution time
+evaluator = func.time_evaluator(func.entry_name, ctx, min_repeat_ms=500)
+print(
+    "Execution time of this operator: %.3f ms"
+    % (np.median(evaluator(data_tvm, weight_tvm, bias_tvm, out_tvm).results) * 
1000)
+)
 
 ######################################################################
 # Using the record file
 # ^^^^^^^^^^^^^^^^^^^^^
 # During the search, all measuremnt records are dumpped into the record
-# file "matmul.json". The measurement records can be used to re-apply search 
results,
+# file "conv2d.json". The measurement records can be used to re-apply search 
results,
 # resume the search, and perform other analyses.
 
 ######################################################################
@@ -125,16 +155,17 @@ tvm.testing.assert_allclose(d_np, d_tvm.asnumpy(), 
rtol=1e-3)
 # print the equivalent python schedule API, and build the binary again.
 
 # Load the measuremnt record for the best schedule
-inp, res = auto_scheduler.load_best("matmul.json", task.workload_key)
+inp, res = auto_scheduler.load_best("conv2d.json", task.workload_key)
 
 # Print equivalent python schedule API. This can be used for debugging and
 # learning the behavior of the auto-scheduler.
+print("Equivalent python schedule:")
 print(task.compute_dag.print_python_code_from_state(inp.state))
 
 # Rebuild the binary. This shows how you can apply the best schedule from a
 # log file without reruning the search again.
 sch, args = task.compute_dag.apply_steps_from_state(inp.state)
-func = tvm.build(sch, args)
+func = tvm.build(sch, args, target)
 
 ######################################################################
 # A more complicated example is to resume the search.
@@ -143,31 +174,18 @@ func = tvm.build(sch, args)
 # In the example below we resume the status and do more 5 trials.
 
 
-def resume_search(task, log_file):
-    cost_model = auto_scheduler.XGBModel()
-    cost_model.update_from_file(log_file)
-    search_policy = auto_scheduler.SketchPolicy(
-        task, cost_model, 
init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]
-    )
-    tune_option = auto_scheduler.TuningOptions(
-        num_measure_trials=5, 
measure_callbacks=[auto_scheduler.RecordToFile(log_file)]
-    )
-    sch, args = auto_scheduler.auto_schedule(task, search_policy, 
tuning_options=tune_option)
-
-
-# resume_search(task, "matmul.json")
+log_file = "conv2d.json"
+cost_model = auto_scheduler.XGBModel()
+cost_model.update_from_file(log_file)
+search_policy = auto_scheduler.SketchPolicy(
+    task, cost_model, 
init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]
+)
+tune_option = auto_scheduler.TuningOptions(
+    num_measure_trials=5,
+    runner=measure_ctx.runner,
+    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
+)
+sch, args = auto_scheduler.auto_schedule(task, search_policy, 
tuning_options=tune_option)
 
-######################################################################
-# .. note::
-#   We cannot run the line above because of the conflict between
-#   python's multiprocessing and tvm's thread pool.
-#   After running a tvm generated binary (L112), the python's multiprocessing
-#   library will hang forever.
-#   You have to make sure that you don't run any tvm generated binaries before
-#   calling ansor's search. To run the L156 above, you should comment out 
L112-114.
-#
-#   You should be careful about this problem in your applications.
-#   There are other workarounds for this problem.
-#   For example, you can start a new thread/process (with the builtin python 
library
-#   threading or multiprocessing) and run the tvm binaries in the new 
thread/process.
-#   This provides an isolation and avoids the conflict in the main 
thread/process.
+# kill the measurement process
+del measure_ctx
diff --git a/tutorials/auto_scheduler/tune_matmul_x86.py 
b/tutorials/auto_scheduler/tune_matmul_x86.py
index 1a9af42..918030d 100644
--- a/tutorials/auto_scheduler/tune_matmul_x86.py
+++ b/tutorials/auto_scheduler/tune_matmul_x86.py
@@ -37,7 +37,7 @@ from tvm import te, testing, auto_scheduler
 ######################################################################
 # Define the computation
 # ^^^^^^^^^^^^^^^^^^^^^^
-# To begin with, we define the computation of a matmul with bias add.
+# To begin with, let us define the computation of a matmul with bias add.
 # The function should return the list of input/output tensors.
 # From these tensors, the auto-scheduler can get the whole computational graph.
 
@@ -59,6 +59,9 @@ def matmul_add(N, L, M, dtype):
 # Create the search task
 # ^^^^^^^^^^^^^^^^^^^^^^
 # We then create a search task with N=L=M=128 and dtype="float32"
+# If your machine supports AVX instructions, you can
+# - replace "llvm" below with "llvm -mcpu=core-avx2" to enable AVX2
+# - replace "llvm" below with "llvm -mcpu=skylake-avx512" to enable AVX-512
 
 target = tvm.target.Target("llvm")
 task = auto_scheduler.create_task(matmul_add, (128, 128, 128, "float32"), 
target)
@@ -75,7 +78,7 @@ print(task.compute_dag)
 # * In addition, we use `RecordToFile` to dump measurement records into a file 
`matmul.json`.
 #   The measurement records can be used to query the history best, resume the 
search,
 #   and do more analyses later.
-# * see :any:`auto_schedule.TuningOptions`: for more parameters
+# * see :any:`auto_scheduler.auto_schedule.TuningOptions` for more parameters.
 
 tune_option = auto_scheduler.TuningOptions(
     num_measure_trials=10, 
measure_callbacks=[auto_scheduler.RecordToFile("matmul.json")]
@@ -93,25 +96,38 @@ sch, args = auto_scheduler.auto_schedule(task, 
tuning_options=tune_option)
 ######################################################################
 # We can lower the schedule to see the IR after auto-scheduling.
 # The auto-scheduler correctly performs optimizations including multi-level 
tiling,
-# parallelization, vectorization, unrolling and fusion.
+# parallelization, vectorization, unrolling and operator fusion.
 
 print(tvm.lower(sch, args, simple_mode=True))
 
 ######################################################################
-# Check correctness
-# ^^^^^^^^^^^^^^^^^
-# We build the binary and check its correctness
+# Check correctness and evaluate performance
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# We build the binary and check its correctness and performance.
 
 func = tvm.build(sch, args)
 a_np = np.random.uniform(size=(128, 128)).astype(np.float32)
 b_np = np.random.uniform(size=(128, 128)).astype(np.float32)
 c_np = np.random.uniform(size=(128, 128)).astype(np.float32)
-d_np = a_np.dot(b_np) + c_np
-
-d_tvm = tvm.nd.empty(d_np.shape)
-func(tvm.nd.array(a_np), tvm.nd.array(b_np), tvm.nd.array(c_np), d_tvm)
+out_np = a_np.dot(b_np) + c_np
+
+ctx = tvm.cpu()
+a_tvm = tvm.nd.array(a_np, ctx=ctx)
+b_tvm = tvm.nd.array(b_np, ctx=ctx)
+c_tvm = tvm.nd.array(c_np, ctx=ctx)
+out_tvm = tvm.nd.empty(out_np.shape, ctx=ctx)
+func(a_tvm, b_tvm, c_tvm, out_tvm)
+
+# Check results
+tvm.testing.assert_allclose(out_np, out_tvm.asnumpy(), rtol=1e-3)
+
+# Evaluate execution time.
+evaluator = func.time_evaluator(func.entry_name, ctx, min_repeat_ms=500)
+print(
+    "Execution time of this operator: %.3f ms"
+    % (np.median(evaluator(a_tvm, b_tvm, c_tvm, out_tvm).results) * 1000)
+)
 
-tvm.testing.assert_allclose(d_np, d_tvm.asnumpy(), rtol=1e-3)
 
 ######################################################################
 # Using the record file
@@ -129,6 +145,7 @@ inp, res = auto_scheduler.load_best("matmul.json", 
task.workload_key)
 
 # Print equivalent python schedule API. This can be used for debugging and
 # learning the behavior of the auto-scheduler.
+print("Equivalent python schedule:")
 print(task.compute_dag.print_python_code_from_state(inp.state))
 
 # Rebuild the binary. This shows how you can apply the best schedule from a
@@ -161,13 +178,16 @@ def resume_search(task, log_file):
 # .. note::
 #   We cannot run the line above because of the conflict between
 #   python's multiprocessing and tvm's thread pool.
-#   After running a tvm generated binary (L112), the python's multiprocessing
-#   library will hang forever.
-#   You have to make sure that you don't run any tvm generated binaries before
-#   calling ansor's search. To run the L156 above, you should comment out 
L112-114.
+#   After running a TVM-generated binary, Python's multiprocessing library
+#   will hang forever. You have to make sure that you don't run any TVM-generated
+#   binaries before calling the auto-scheduler's search.
+#   To run the function above, you should comment out all code in the
+#   "Check correctness and evaluate performance" section.
 #
 #   You should be careful about this problem in your applications.
 #   There are other workarounds for this problem.
 #   For example, you can start a new thread/process (with the builtin python 
library
 #   threading or multiprocessing) and run the tvm binaries in the new 
thread/process.
 #   This provides an isolation and avoids the conflict in the main 
thread/process.
+#   You can also use :any:`auto_scheduler.measure.LocalRPCMeasureContext` for the auto-scheduler,
+#   as shown in the GPU tutorial (:ref:`auto-scheduler-conv-gpu`).

Reply via email to