comaniac commented on a change in pull request #6310:
URL: https://github.com/apache/incubator-tvm/pull/6310#discussion_r475802287



##########
File path: src/auto_scheduler/search_policy/sketch_policy_rules.cc
##########
@@ -580,5 +580,254 @@ InitPopulationRule::ResultKind 
InitVectorization::Apply(SketchPolicyNode* policy
   return ResultKind::kValid;
 }
 
+MutationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* policy, 
State* state) const {
+  int max_innermost_split_factor =
+      GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor);
+
+  // Extract all SplitStep
+  std::vector<size_t> split_step_ids;
+  for (size_t i = 0; i < (*state)->transform_steps.size(); ++i) {
+    if (auto ps = (*state)->transform_steps[i].as<SplitStepNode>()) {
+      if (!ps->extent.defined() || 
!ps->extent.value()->IsInstance<IntImmNode>()) {
+        continue;
+      }
+      auto innermost_factor = 
ps->lengths.back().value_or(max_innermost_split_factor + 1);
+      if (GetIntImm(innermost_factor) <= max_innermost_split_factor) {
+        split_step_ids.push_back(i);
+      }
+    }
+  }
+  if (split_step_ids.empty()) {
+    // No tile size could be mutated.
+    return ResultKind::kInvalid;
+  }
+
+  // Select a SplitStep with extent larger than one to mutate.
+  int retry_ct = 0;
+  int64_t extent = 1;
+  int step_id;
+  const SplitStepNode* ps;
+
+  do {
+    step_id = split_step_ids[(policy->rand_gen)() % split_step_ids.size()];
+    ps = (*state)->transform_steps[step_id].as<SplitStepNode>();
+    CHECK(ps != nullptr);
+    extent = GetIntImm(ps->extent.value());
+    retry_ct += 1;
+  } while (retry_ct < static_cast<int>(split_step_ids.size()) << 2 && (extent 
== 1 || extent == 0));
+
+  if (extent <= 1) {
+    // Cannot find a step with extent larger than one.
+    return ResultKind::kInvalid;
+  }
+
+  // Fetch the current tile sizes.
+  std::vector<int> lengths(ps->lengths.size() + 1, 1);
+  for (int i = 0; i < static_cast<int>(ps->lengths.size()); ++i) {
+    lengths[i + 1] = GetIntImm(ps->lengths[i].value());
+  }
+  lengths[0] = extent / ElementProduct(lengths);
+
+  // Random permute the tile size order.
+  std::vector<int> random_perm;
+  RandomPermutation(lengths.size(), &random_perm, &(policy->rand_gen));
+
+  // Try to divide a factor from one tile size and multiple it to another.
+  for (size_t i = 0; i < random_perm.size(); ++i) {
+    size_t src_idx = random_perm[i];
+    int length = lengths[src_idx];
+    if (length == 1) {
+      continue;
+    }
+
+    size_t dst_idx = random_perm[(i + 1) % random_perm.size()];
+    const std::vector<int>& factors = policy->split_memo.GetFactors(length);
+    CHECK_GE(factors.size(), 1);
+
+    int divide_factor;
+    if (dst_idx == lengths.size() - 1) {
+      // Maintain the restriction of 
hardware_params.max_innermost_split_factor.
+      int max_factor_index = static_cast<int>(factors.size()) - 1;
+      for (; max_factor_index >= 1; max_factor_index--) {
+        if (factors[max_factor_index] * lengths[dst_idx] <= 
max_innermost_split_factor) {
+          break;
+        }
+      }
+      if (max_factor_index == 0) {
+        // Failed on this dst_idx, try next one.
+        continue;
+      }
+      divide_factor = factors[1 + (policy->rand_gen)() % (max_factor_index)];
+    } else {
+      divide_factor = factors[1 + (policy->rand_gen)() % (factors.size() - 1)];
+    }
+
+    // Divide one factor from lengths[src_idx] and multiply it to 
lengths[dst_idx].
+    Array<Integer> new_lengths;
+    for (size_t j = 1; j < lengths.size(); ++j) {
+      if (j == src_idx) {
+        new_lengths.push_back(Integer(lengths[j] / divide_factor));
+      } else if (j == dst_idx) {
+        new_lengths.push_back(Integer(lengths[j] * divide_factor));
+      } else {
+        new_lengths.push_back(Integer(lengths[j]));
+      }
+    }
+
+    StateNode* pstate = state->CopyOnWrite();
+    pstate->transform_steps.Set(
+        step_id, SplitStep(ps->stage_id, ps->iter_id, ps->extent,
+                           Array<Optional<Integer>>(new_lengths.begin(), 
new_lengths.end()),
+                           ps->inner_to_outer));
+    return ResultKind::kValid;
+  }
+  return ResultKind::kInvalid;
+}
+
+MutationRule::ResultKind MutateMaxUnrollFactor::Apply(SketchPolicyNode* policy,
+                                                      State* state) const {
+  // Extract all auto_unroll_max_step pragma steps.
+  std::vector<int> annotate_steps;
+  for (size_t i = 0; i < (*state)->transform_steps.size(); ++i) {
+    if (auto ps = (*state)->transform_steps[i].as<PragmaStepNode>()) {
+      if (StrStartsWith(ps->pragma_type, "auto_unroll_max_step")) {
+        annotate_steps.push_back(i);
+      }
+    }
+  }
+  if (annotate_steps.empty()) {
+    return ResultKind::kInvalid;
+  }
+
+  // Random pick up one unroll factor candidate.
+  auto cands = (IsGPUTask(policy->search_task))? &gpu_unroll_cands_: 
&cpu_unroll_cands_;
+  auto new_factor = std::to_string((*cands)[(policy->rand_gen)() % 
cands->size()]);
+
+  // Random pick up and mutate an unroll step.
+  auto step_id = annotate_steps[(policy->rand_gen)() % annotate_steps.size()];
+  auto ps = (*state)->transform_steps[step_id].as<PragmaStepNode>();
+  CHECK(ps);
+  StateNode* pstate = state->CopyOnWrite();
+  pstate->transform_steps.Set(step_id,
+                              PragmaStep(ps->stage_id, ps->iter_id,
+                                         std::string("auto_unroll_max_step") + 
"$" + new_factor));
+  return ResultKind::kValid;
+}
+
+MutationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNode* policy,
+                                                      State* state) const {
+  // FIXME (@comaniac, @jc94): Combine initial population rules with the 
mutation rules.
+  static InitChangeComputeLocation mutate_compute_location;
+  if (mutate_compute_location.Apply(policy, state) == 
InitPopulationRule::ResultKind::kInvalid) {
+    return ResultKind::kInvalid;
+  }
+  return ResultKind::kValid;
+}

Review comment:
       Since this is really a mutation rule rather than an initial population 
rule, I moved it to `MutateComputeLocationRule`. Please check the latest 
commit to see whether that makes sense.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to