merrymercy commented on a change in pull request #6310:
URL: https://github.com/apache/incubator-tvm/pull/6310#discussion_r478134161



##########
File path: src/auto_scheduler/search_policy/sketch_policy_rules.cc
##########
@@ -908,7 +795,362 @@ InitPopulationRule::ResultKind 
InitThreadBind::Apply(SketchPolicyNode* policy, S
       state->bind(stage_id, iters1[1], IteratorAnnotation::kThreadX);
     }
   }
+  return ResultKind::kValid;
+}
+
+PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* 
policy,
+                                                           State* state) const 
{
+  int max_innermost_split_factor =
+      GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor);
+
+  // Extract all SplitStep
+  std::vector<size_t> split_step_ids;
+  for (size_t i = 0; i < (*state)->transform_steps.size(); ++i) {
+    if (auto ps = (*state)->transform_steps[i].as<SplitStepNode>()) {
+      if (!ps->extent.defined() || 
!ps->extent.value()->IsInstance<IntImmNode>()) {
+        continue;
+      }
+      auto innermost_factor = 
ps->lengths.back().value_or(max_innermost_split_factor + 1);
+      if (GetIntImm(innermost_factor) <= max_innermost_split_factor) {
+        split_step_ids.push_back(i);
+      }
+    }
+  }
+  if (split_step_ids.empty()) {
+    // No tile size could be mutated.
+    return ResultKind::kInvalid;
+  }
+
+  // Select a SplitStep with extent larger than one to mutate.
+  int retry_ct = 0;
+  int64_t extent = 1;
+  int step_id;
+  const SplitStepNode* ps;
+
+  do {
+    step_id = split_step_ids[(policy->rand_gen)() % split_step_ids.size()];
+    ps = (*state)->transform_steps[step_id].as<SplitStepNode>();
+    CHECK(ps != nullptr);
+    extent = GetIntImm(ps->extent.value());
+    retry_ct += 1;
+  } while (retry_ct < static_cast<int>(split_step_ids.size()) << 2 && (extent 
== 1 || extent == 0));
+
+  if (extent <= 1) {
+    // Cannot find a step with extent larger than one.
+    return ResultKind::kInvalid;
+  }
+
+  // Fetch the current tile sizes.
+  std::vector<int> lengths(ps->lengths.size() + 1, 1);
+  for (int i = 0; i < static_cast<int>(ps->lengths.size()); ++i) {
+    lengths[i + 1] = GetIntImm(ps->lengths[i].value());
+  }
+  lengths[0] = extent / ElementProduct(lengths);
+
+  // Random permute the tile size order.
+  std::vector<int> random_perm;
+  RandomPermutation(lengths.size(), &random_perm, &(policy->rand_gen));
+
+  // Try to divide a factor from one tile size and multiple it to another.
+  for (size_t i = 0; i < random_perm.size(); ++i) {
+    size_t src_idx = random_perm[i];
+    int length = lengths[src_idx];
+    if (length == 1) {
+      continue;
+    }
+
+    size_t dst_idx = random_perm[(i + 1) % random_perm.size()];
+    const std::vector<int>& factors = policy->split_memo.GetFactors(length);
+    CHECK_GE(factors.size(), 1);
+
+    int divide_factor;
+    if (dst_idx == lengths.size() - 1) {
+      // Maintain the restriction of 
hardware_params.max_innermost_split_factor.
+      int max_factor_index = static_cast<int>(factors.size()) - 1;
+      for (; max_factor_index >= 1; max_factor_index--) {
+        if (factors[max_factor_index] * lengths[dst_idx] <= 
max_innermost_split_factor) {
+          break;
+        }
+      }
+      if (max_factor_index == 0) {
+        // Failed on this dst_idx, try next one.
+        continue;
+      }
+      divide_factor = factors[1 + (policy->rand_gen)() % (max_factor_index)];
+    } else {
+      divide_factor = factors[1 + (policy->rand_gen)() % (factors.size() - 1)];
+    }
+
+    // Divide one factor from lengths[src_idx] and multiply it to 
lengths[dst_idx].
+    Array<Integer> new_lengths;
+    for (size_t j = 1; j < lengths.size(); ++j) {
+      if (j == src_idx) {
+        new_lengths.push_back(Integer(lengths[j] / divide_factor));
+      } else if (j == dst_idx) {
+        new_lengths.push_back(Integer(lengths[j] * divide_factor));
+      } else {
+        new_lengths.push_back(Integer(lengths[j]));
+      }
+    }
+
+    StateNode* pstate = state->CopyOnWrite();
+    pstate->transform_steps.Set(
+        step_id, SplitStep(ps->stage_id, ps->iter_id, ps->extent,
+                           Array<Optional<Integer>>(new_lengths.begin(), 
new_lengths.end()),
+                           ps->inner_to_outer));
+    return ResultKind::kValid;
+  }
+  return ResultKind::kInvalid;
+}
+
+PopulationGenerationRule::ResultKind 
MutateMaxUnrollFactor::Apply(SketchPolicyNode* policy,
+                                                                  State* 
state) const {
+  // Extract all auto_unroll_max_step pragma steps.
+  std::vector<int> annotate_steps;
+  for (size_t i = 0; i < (*state)->transform_steps.size(); ++i) {
+    if (auto ps = (*state)->transform_steps[i].as<PragmaStepNode>()) {
+      if (StrStartsWith(ps->pragma_type, "auto_unroll_max_step")) {
+        annotate_steps.push_back(i);
+      }
+    }
+  }
+  if (annotate_steps.empty()) {
+    return ResultKind::kInvalid;
+  }
+
+  // Random pick up one unroll factor candidate.
+  auto cands = (IsGPUTask(policy->search_task)) ? &gpu_unroll_cands_ : 
&cpu_unroll_cands_;
+  auto new_factor = std::to_string((*cands)[(policy->rand_gen)() % 
cands->size()]);
+
+  // Random pick up and mutate an unroll step.
+  auto step_id = annotate_steps[(policy->rand_gen)() % annotate_steps.size()];
+  auto ps = (*state)->transform_steps[step_id].as<PragmaStepNode>();
+  CHECK(ps);
+  StateNode* pstate = state->CopyOnWrite();
+  pstate->transform_steps.Set(step_id,
+                              PragmaStep(ps->stage_id, ps->iter_id,
+                                         std::string("auto_unroll_max_step") + 
"$" + new_factor));
+  return ResultKind::kValid;
+}
+
+PopulationGenerationRule::ResultKind 
MutateComputeLocation::Apply(SketchPolicyNode* policy,
+                                                                  State* 
state) const {
+  if (GetIntParam(policy->params, 
SketchParamKey::disable_change_compute_location)) {
+    return ResultKind::kValid;
+  }
+
+  for (int stage_id = static_cast<int>((*state)->stages.size()) - 1; stage_id 
>= 0; stage_id--) {
+    const Stage& stage = (*state)->stages[stage_id];
+    // Skip the inlined stages and placeholders
+    if (stage->op_type == StageKind::kPlaceholder || stage->compute_at == 
ComputeAtKind::kInlined) {
+      continue;
+    }
+    // Skip the tiled stages
+    if (IsTiled(stage) || NeedsMultilevelTiling(policy->search_task, *state, 
stage_id)) {
+      continue;
+    }
+
+    int target_stage_id = GetSingleConsumerId(policy->search_task, *state, 
stage_id);
+    if (target_stage_id < 0) {
+      continue;
+    }
+    const Stage& target_stage = (*state)->stages[target_stage_id];
+
+    std::vector<std::pair<int, int>> candidates;
+    bool target_compute_at_other = target_stage->compute_at == 
ComputeAtKind::kIter;
+    bool target_is_tiled = IsTiled(target_stage);
+
+    bool visited_reduce = false;
+    // enumerate compute_at location at target_stage
+    // TODO(merrymercy): More analysis here to make smarter choices
+    for (size_t i = 0; i < target_stage->iters.size(); ++i) {
+      const Iterator& target_iter = target_stage->iters[i];
+      if (target_iter->iter_kind == IteratorKind::kReduction) {
+        visited_reduce = true;
+        if (!target_is_tiled) {  // Do not go into reduce iter
+          break;
+        }
+      } else if (target_iter->iter_kind == IteratorKind::kSpatial) {
+        if (visited_reduce) {  // Do not go into inner tile
+          break;
+        }
+      }
+
+      if (target_iter->annotation == IteratorAnnotation::kUnroll) {
+        // Do not go into the unroll region of const tensor indices
+        break;
+      }
+
+      if (GetExtent(target_iter) == 1) {
+        // Skip iterators with length of 1
+        continue;
+      }
+      if (target_compute_at_other && target_iter->iter_kind == 
IteratorKind::kSpatial &&
+          StrEndsWith(target_iter->name, ".0")) {
+        // Skip the first level iterators if target stage compute_at another 
stage
+        // In this case, the lengths of first level iterators are always one
+        continue;
+      }
+      candidates.emplace_back(target_stage_id, i);
+
+      if 
((*state)->attach_map->iter_to_attached_stages.count(std::make_pair(target_stage_id,
 i))) {
+        break;
+      }
+    }
+
+    // if the target_stage is already compute_at another stage X, try also 
compute_at X
+    // We call stage X as `target_target_stage`
+    if (target_compute_at_other) {
+      int target_target_stage_id;
+      target_target_stage_id = 
(*state)->attach_map->stage_to_attach_iter.at(target_stage_id).first;
+      const Stage& target_target_stage = 
(*state)->stages[target_target_stage_id];
+
+      for (size_t i = 0; i < target_target_stage->iters.size(); ++i) {
+        const Iterator& target_target_iter = target_target_stage->iters[i];
+        if (target_target_iter->iter_kind == IteratorKind::kReduction ||
+            (*state)->attach_map->iter_to_attached_stages.count(
+                std::make_pair(target_target_stage_id, i))) {
+          break;
+        }
+
+        if (target_target_iter->annotation == IteratorAnnotation::kUnroll) {
+          // Do not go into the unroll region of const tensor indices
+          break;
+        }
+
+        if (GetExtent(target_target_iter) == 1) {  // skip iterators with 
length of 1
+          continue;
+        }
+
+        candidates.emplace_back(target_target_stage_id, i);
+      }
+    }
+
+    int choice = (policy->rand_gen)() % (candidates.size() + 2);
+
+    if (choice == 0) {
+      if (!HasReduceIter(stage)) {
+        const auto& stage_to_attach_iter = 
(*state)->attach_map->stage_to_attach_iter;
+        if (stage_to_attach_iter.find(stage_id) != stage_to_attach_iter.end()) 
{
+          state->compute_inline(stage_id);
+        }
+      }
+    } else if (choice == 1) {
+      state->compute_root(stage_id);
+    } else {
+      choice = choice - 2;
+      const Stage& stage = (*state)->stages[candidates[choice].first];
+      state->compute_at(stage_id, candidates[choice].first,
+                        stage->iters[candidates[choice].second]);
+    }
+  }
+
+  *state = policy->search_task->compute_dag.InferBound(*state);

Review comment:
       @comaniac  We should not remove it. I think it is useful in 
   SampleInitPopulation, and it can be improved further to make the 
   initial population better.
   To solve the issue, we can move the common part into a new function and 
   create two separate rules:
   one rule calls only the common part, while the other rule calls the common 
   part + InferBound.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to