jcf94 commented on a change in pull request #6073:
URL: https://github.com/apache/incubator-tvm/pull/6073#discussion_r457782201



##########
File path: src/auto_scheduler/loop_state.cc
##########
@@ -244,285 +197,73 @@ Iterator State::unroll(int stage_id, const Iterator& it, 
int max_unroll) {
   AnnotationStep step =
       AnnotationStep(stage_id, GetIndex(stage->iters, it), 
IteratorAnnotation::kUnroll);
   CopyOnWrite()->transform_steps.push_back(step);
-  return DoAnnotationStep(step);
+  return step->ApplyToState(this);
 }
 
-Iterator State::bind(int stage_id, const Iterator& it, IteratorAnnotation 
thread_type) {
+Iterator State::vectorize(int stage_id, const Iterator& it) {
   const Stage& stage = operator->()->stages[stage_id];
-  if (thread_type < IteratorAnnotation::kVThread || thread_type > 
IteratorAnnotation::kThreadZ) {
-    LOG(FATAL) << "thread_type error, valid: kVThread, kBlockX, kBlockY, "
-               << "kThreadX, kThreadY, kBlockZ, kThreadZ";
-  }
-  AnnotationStep step = AnnotationStep(stage_id, GetIndex(stage->iters, it), 
thread_type);
+  AnnotationStep step =
+      AnnotationStep(stage_id, GetIndex(stage->iters, it), 
IteratorAnnotation::kVectorize);
   CopyOnWrite()->transform_steps.push_back(step);
-  return DoAnnotationStep(step);
-}
-
-/********** Step implementations for state **********/
-void State::DoReorderStep(const ReorderStep& step) {
-  const Stage& stage = operator->()->stages[step->stage_id];
-  Array<Iterator> iters;
-  for (auto x : step->after_ids) {
-    iters.push_back(stage->iters[x]);
-  }
-  StateNode* pstate = CopyOnWrite();
-  pstate->stages.Set(step->stage_id,
-                     Stage(stage->op, stage->op_type, iters, 
stage->compute_at, stage->attrs));
+  return step->ApplyToState(this);
 }
 
-void State::DoComputeAtStep(const ComputeAtStep& step) {
-  const Stage& stage = operator->()->stages[step->stage_id];
-
-  // Remove the bound information of each iterator since they may not be 
accurate after
-  // compute at
-  Array<Iterator> new_iters;
-  for (const Iterator& it : stage->iters) {
-    new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, 
it->annotation));
-  }
-
-  StateNode* pstate = CopyOnWrite();
-  pstate->stages.Set(step->stage_id, Stage(stage->op, stage->op_type, 
std::move(new_iters),
-                                           ComputeAtKind::kIter, 
stage->attrs));
-  // Update attach map
-  pstate->attach_map.SetComputeAtIter(step->stage_id, step->target_stage_id, 
step->target_iter_id);
-}
-
-void State::DoComputeRootStep(const ComputeRootStep& step) {
-  const Stage& stage = operator->()->stages[step->stage_id];
-
-  // Remove the bound information of each iterator since they may not be 
accurate after
-  // compute root
-  Array<Iterator> new_iters;
-  for (const Iterator& it : stage->iters) {
-    new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, 
it->annotation));
-  }
-
-  StateNode* pstate = CopyOnWrite();
-  pstate->stages.Set(step->stage_id, Stage(stage->op, stage->op_type, 
std::move(new_iters),
-                                           ComputeAtKind::kRoot, 
stage->attrs));
-  // Update attach map
-  pstate->attach_map.DeleteStage(step->stage_id);
+Iterator State::fuse(int stage_id, const Array<Iterator>& iters) {
+  const Stage& stage = operator->()->stages[stage_id];
+  Array<Integer> indices;
+  GetIndices(stage->iters, iters, &indices);
+  FuseStep step = FuseStep(stage_id, indices);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return step->ApplyToState(this);
 }
 
-void State::DoComputeInlineStep(const ComputeInlineStep& step) {
-  const Stage& stage = operator->()->stages[step->stage_id];
-
-  // Check the validity of compute_inline
-  for (size_t i = 0; i < stage->iters.size(); ++i) {
-    CHECK_EQ(operator->()->attach_map->iter_to_attached_stages.count(
-                 std::make_pair(step->stage_id, i)),
-             0)
-        << "Invalid compute_inline: There are some other stages that are 
attached to the "
-        << "target stage";
-  }
-
-  StateNode* pstate = CopyOnWrite();
-  auto new_stage = pstate->stages[step->stage_id];
-  new_stage.CopyOnWrite()->compute_at = ComputeAtKind::kInlined;
-  pstate->stages.Set(step->stage_id, std::move(new_stage));
-  // Update attach map
-  pstate->attach_map.DeleteStage(step->stage_id);
+void State::reorder(int stage_id, const Array<Iterator>& order) {
+  const Stage& stage = operator->()->stages[stage_id];
+  CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators "
+                                              << "should be specified";
+  Array<Integer> after_ids;
+  GetIndices(stage->iters, order, &after_ids);
+  ReorderStep step = ReorderStep(stage_id, after_ids);
+  CopyOnWrite()->transform_steps.push_back(step);
+  step->ApplyToState(this);
 }
 
-// common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep
-Array<Iterator> State::DoSplitStepCommon(int stage_id, int iter_id,
-                                         const Array<Optional<Integer>>& 
lengths,
-                                         bool inner_to_outer) {
+Array<Iterator> State::split(int stage_id, const Iterator& it,
+                             const Array<Optional<Integer>>& lengths, bool 
inner_to_outer) {
   const Stage& stage = operator->()->stages[stage_id];
-  const Iterator& it = stage->iters[iter_id];
-  size_t old_iter_size = stage->iters.size();
-  bool concrete = true;
-
-  Optional<PrimExpr> tosplit_min, tosplit_extent;
-  if (it->range.defined()) {
-    tosplit_min = it->range->min;
-    tosplit_extent = it->range->extent;
-  } else {
-    tosplit_min = NullOpt;
-    tosplit_extent = NullOpt;
-  }
-
-  Array<Iterator> outs;
-  for (size_t i = 0; i < lengths.size(); ++i) {
-    Optional<Integer> l;
-    String name;
-    if (inner_to_outer) {
-      l = lengths[lengths.size() - i - 1];
-      name = it->name + "." + std::to_string(lengths.size() - i);
-    } else {
-      l = lengths[i];
-      name = it->name + "." + std::to_string(i);
-    }
-    Iterator res;
-    if (l && tosplit_min && tosplit_extent) {
-      res = Iterator(name, Range::FromMinExtent(tosplit_min.value(), 
l.value()), it->iter_kind,
-                     IteratorAnnotation::kNone);
-      tosplit_min = Integer(0);
-      tosplit_extent = indexdiv(tosplit_extent.value() + l.value() - 1, 
l.value());
-    } else {
-      res = Iterator(name, Range(), it->iter_kind, IteratorAnnotation::kNone);
-      tosplit_min = NullOpt;
-      tosplit_extent = NullOpt;
-      concrete = false;
-    }
-    outs.push_back(std::move(res));
-  }
-
-  Range range;
-  if (tosplit_min && tosplit_extent) {
-    range = Range::FromMinExtent(tosplit_min.value(), tosplit_extent.value());
-  }
-  if (inner_to_outer) {
-    outs.push_back(Iterator(it->name + ".0", range, it->iter_kind, 
IteratorAnnotation::kNone));
-    // Reverse the Iterator array
-    Array<Iterator> temp(outs.rbegin(), outs.rend());
-    outs = std::move(temp);
-  } else {
-    outs.push_back(Iterator(it->name + "." + std::to_string(lengths.size()), 
range, it->iter_kind,
-                            IteratorAnnotation::kNone));
-  }
-
-  Array<Iterator> new_iters;
-  new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() 
+ iter_id);
-  new_iters.insert(new_iters.end(), outs.begin(), outs.end());
-  new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id + 1, 
stage->iters.end());
-
-  StateNode* pstate = CopyOnWrite();
-  pstate->stages.Set(stage_id,
-                     Stage(stage->op, stage->op_type, new_iters, 
stage->compute_at, stage->attrs));
-  pstate->concrete &= concrete;
-
-  // Two vectors are used to represent the iterator relation before and after 
split
-  // The original iterators in AttachMap will be updated with the new iterators
-  std::vector<IterKey> from_iters;
-  std::vector<IterKey> to_iters;
-  for (size_t i = iter_id; i < old_iter_size; ++i) {
-    from_iters.emplace_back(stage_id, i);
-    to_iters.emplace_back(stage_id, i + lengths.size());
-  }
-  pstate->attach_map.UpdateIters(from_iters, to_iters);
-
-  return outs;
+  SplitStep step =
+      SplitStep(stage_id, GetIndex(stage->iters, it),
+                it->range.defined() ? it->range->extent : PrimExpr(), lengths, 
inner_to_outer);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return step->ApplyToState(this);
 }
 
-Array<Iterator> State::DoSplitStep(const SplitStep& step) {
-  return DoSplitStepCommon(step->stage_id, step->iter_id, step->lengths, 
step->inner_to_outer);
+void State::compute_at(int stage_id, int target_stage_id, const Iterator& 
target_iter) {
+  const Stage& target_stage = operator->()->stages[target_stage_id];
+  ComputeAtStep step =
+      ComputeAtStep(stage_id, target_stage_id, GetIndex(target_stage->iters, 
target_iter));
+  CopyOnWrite()->transform_steps.push_back(step);
+  step->ApplyToState(this);
 }
 
-Iterator State::DoFuseStep(const FuseStep& step) {
-  int stage_id = step->stage_id;
-  const Stage& stage = operator->()->stages[stage_id];
-  size_t old_iter_size = static_cast<int>(stage->iters.size());
-
-  String new_name;
-  PrimExpr new_extent = 1;
-  IteratorKind new_iter_kind = IteratorKind::kSpecial;
-
-  for (size_t i = 0; i < step->fused_ids.size(); ++i) {
-    if (i > 0) {
-      CHECK_EQ(step->fused_ids[i]->value, step->fused_ids[i - 1]->value + 1);
-    }
-
-    if (i != step->fused_ids.size() - 1) {
-      const auto& iter_to_attached_stage = 
operator->()->attach_map->iter_to_attached_stages;
-      if (iter_to_attached_stage.find(std::make_pair(stage_id, 
step->fused_ids[i])) !=
-          iter_to_attached_stage.end()) {
-        LOG(FATAL) << "Invalid Fuse. Trying to fuse iterators that have been 
attached by some "
-                   << "stages. State before fusion:\n"
-                   << *this;
-      }
-    }
-
-    const Iterator& it = stage->iters[step->fused_ids[i]];
-    new_name = new_name + it->name + "@";
-
-    if (it->range.defined() && new_extent.defined()) {
-      new_extent = new_extent * it->range->extent;
-    } else {
-      new_extent = PrimExpr();
-    }
-
-    if (i == 0) {
-      new_iter_kind = it->iter_kind;
-    } else {
-      if (new_iter_kind != it->iter_kind) {
-        new_iter_kind = IteratorKind::kMixed;
-      }
-    }
-  }
-
-  Range range;
-  if (new_extent.defined()) {
-    range = Range::FromMinExtent(0, new_extent);
-  }
-  Iterator new_it = Iterator(new_name, range, new_iter_kind, 
IteratorAnnotation::kNone);
-  Array<Iterator> new_iters;
-  new_iters.insert(new_iters.end(), stage->iters.begin(),
-                   stage->iters.begin() + step->fused_ids.front());
-  new_iters.push_back(new_it);
-  new_iters.insert(new_iters.end(), stage->iters.begin() + 
step->fused_ids.back() + 1,
-                   stage->iters.end());
-
-  StateNode* pstate = CopyOnWrite();
-  pstate->stages.Set(stage_id,
-                     Stage(stage->op, stage->op_type, new_iters, 
stage->compute_at, stage->attrs));
-
-  // Two vectors are used to represent the iterator relation before and after 
fuse
-  // The original iterators in AttachMap will be updated with the new iterators
-  std::vector<IterKey> from_iters;
-  std::vector<IterKey> to_iters;
-  const size_t begin_id = step->fused_ids.front(), end_id = 
step->fused_ids.back();
-  for (size_t i = 0; i < old_iter_size; ++i) {
-    if (i <= begin_id) {
-      continue;
-    } else if (i > end_id) {
-      // move forward
-      from_iters.emplace_back(stage_id, i);
-      to_iters.emplace_back(stage_id, i - end_id + begin_id);
-    } else {
-      // move to the fused id
-      from_iters.emplace_back(stage_id, i);
-      to_iters.emplace_back(stage_id, begin_id);
-    }
-  }
-  pstate->attach_map.UpdateIters(from_iters, to_iters);
-
-  return new_it;
+void State::compute_inline(int stage_id) {
+  ComputeInlineStep step = ComputeInlineStep(stage_id);
+  CopyOnWrite()->transform_steps.push_back(step);
+  step->ApplyToState(this);
 }
 
-Iterator State::DoAnnotationStep(const AnnotationStep& step) {
-  const Stage& stage = operator->()->stages[step->stage_id];
-  Iterator it = stage->iters[step->iter_id];
-
-  CHECK(it->annotation == IteratorAnnotation::kNone);
-  Iterator new_it = Iterator(it->name, it->range, it->iter_kind, 
step->annotation);
-  Stage new_stage = stage;
-  new_stage.CopyOnWrite()->iters.Set(step->iter_id, new_it);
-  CopyOnWrite()->stages.Set(step->stage_id, std::move(new_stage));
-  return new_it;
+void State::compute_root(int stage_id) {
+  ComputeRootStep step = ComputeRootStep(stage_id);
+  CopyOnWrite()->transform_steps.push_back(step);
+  step->ApplyToState(this);
 }
 
-void State::DoSteps(const ComputeDAG& dag) {
+void State::ApplySteps(const ComputeDAG& dag) {
   CHECK(operator->()->stages.size()) << "Invalid State with empty operation 
stages.";
 
+  // Call each step's ApplyToState method
   for (const auto& step : operator->()->transform_steps) {
-    if (auto ps = step.as<ReorderStepNode>()) {
-      DoReorderStep(GetRef<ReorderStep>(ps));
-    } else if (auto ps = step.as<ComputeAtStepNode>()) {
-      DoComputeAtStep(GetRef<ComputeAtStep>(ps));
-    } else if (auto ps = step.as<ComputeRootStepNode>()) {
-      DoComputeRootStep(GetRef<ComputeRootStep>(ps));
-    } else if (auto ps = step.as<ComputeInlineStepNode>()) {
-      DoComputeInlineStep(GetRef<ComputeInlineStep>(ps));
-    } else if (auto ps = step.as<SplitStepNode>()) {
-      DoSplitStep(GetRef<SplitStep>(ps));
-    } else if (auto ps = step.as<FuseStepNode>()) {
-      DoFuseStep(GetRef<FuseStep>(ps));
-    } else if (auto ps = step.as<AnnotationStepNode>()) {
-      DoAnnotationStep(GetRef<AnnotationStep>(ps));
-    } else {
-      LOG(FATAL) << "Invalid step: " << step;
-    }
+    StepApplyToState(step, this, dag);

Review comment:
      `step->ApplyToState()` is actually not a virtual function (its return type
differs between step types), so I added `StepApplyToState` to gather all of
those step-type checks into `transform_steps.cc`.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to