merrymercy commented on a change in pull request #6142:
URL: https://github.com/apache/incubator-tvm/pull/6142#discussion_r461125355



##########
File path: include/tvm/auto_scheduler/loop_state.h
##########
@@ -359,6 +359,29 @@ class State : public ObjectRef {
   TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it,
                                 const Array<Optional<Integer>>& lengths,
                                 bool inner_to_outer = true);
+  /*!
+   * \brief Schedule primitive extends to split step.
+   * \param stage_id The index of the stage to be split.
+   * \param it The iterator to be split.
+   * \param src_step_id The index of the split step to follow in the history.

Review comment:
       ```suggestion
      * \param src_step_id The index of the split step to be followed in the history.
   ```

##########
File path: python/tvm/auto_scheduler/loop_state.py
##########
@@ -301,6 +310,89 @@ def split(self, stage, iterator, lengths, 
inner_to_outer=True):
                                                      iterator, lengths, 
inner_to_outer)
         return res
 
+    def follow_split(self, stage, iterator, src_step_id, n_split):
+        """ Schedule primitive extends to split step.
+
+        This step is used to follow a former SplitStep, keeps their iterator 
structures to be same.

Review comment:
       ```suggestion
           This step is used to follow a former SplitStep and keep their iterator structures the same.
   ```

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -778,6 +795,184 @@ String SplitStepNode::PrintAsPythonAPI(Array<te::Stage>* 
stages,
 
 /********** Steps working on multiple stages **********/
 
+/********** Follow Split **********/
+FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, int src_step_id, 
int n_split) {
+  auto node = make_object<FollowSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_id = src_step_id;
+  node->n_split = n_split;
+  data_ = std::move(node);
+}
+
+void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArrayItem(src_step_id);
+  writer->WriteArrayItem(n_split);
+}
+
+void FollowSplitStepNode::ExtractSplitLengths(const Array<Step>& 
transform_steps,
+                                              Array<Optional<Integer>>* 
lengths) const {
+  CHECK_LT(src_step_id, transform_steps.size());
+  auto ps = transform_steps[src_step_id].as<SplitStepNode>();
+  CHECK(ps != nullptr);
+
+  // get lengths from src step
+  lengths->reserve(n_split);
+  int j = 0;
+  for (; j < n_split - 1; ++j) {
+    lengths->push_back(ps->lengths[j]);
+  }
+  PrimExpr last_factor = 1;
+  for (; j < static_cast<int>(ps->lengths.size()); ++j) {
+    if (ps->lengths[j]) {
+      last_factor *= ps->lengths[j].value();
+    } else {
+      last_factor = PrimExpr();
+      break;
+    }
+  }
+  if (last_factor.defined()) {
+    lengths->push_back(Downcast<Integer>(last_factor));
+  } else {
+    lengths->push_back(NullOpt);
+  }
+}
+
+FollowSplitStep::FollowSplitStep(dmlc::JSONReader* reader) {
+  auto node = make_object<FollowSplitStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->iter_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->src_step_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->n_split);
+
+  data_ = std::move(node);
+}
+
+Array<Iterator> FollowSplitStepNode::ApplyToState(State* state) const {
+  Array<Optional<Integer>> lengths;
+  ExtractSplitLengths((*state)->transform_steps, &lengths);
+  return ApplySplitToState(state, stage_id, iter_id, lengths, true);
+}
+
+Array<IterVar> FollowSplitStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                                    StageToAxesMap* 
stage_to_axes,
+                                                    const Array<Step>& 
transform_steps) const {
+  Array<Optional<Integer>> lengths;
+  ExtractSplitLengths(transform_steps, &lengths);
+  return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, 
lengths, true);
+}
+
+String FollowSplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                             StageToAxesMap* stage_to_axes,
+                                             const Array<Step>& 
transform_steps) const {
+  Array<Optional<Integer>> lengths;
+  ExtractSplitLengths(transform_steps, &lengths);
+  return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, 
lengths, true);
+}
+
+/********** Follow Fused Split **********/
+FollowFusedSplitStep::FollowFusedSplitStep(int stage_id, int iter_id,
+                                           const Array<Integer>& src_step_ids, 
int level,
+                                           bool factor_or_nparts) {
+  auto node = make_object<FollowFusedSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_ids = src_step_ids;
+  node->level = level;
+  node->factor_or_nparts = factor_or_nparts;
+  data_ = std::move(node);
+}
+
+FollowFusedSplitStep::FollowFusedSplitStep(dmlc::JSONReader* reader) {
+  auto node = make_object<FollowFusedSplitStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);

Review comment:
       Yes, it depends on the order. This is a limitation of TVM's JSON serialization support.
   However, as you suggested, the only two functions that rely on this order (`FollowFusedSplitStep::FollowFusedSplitStep` and `FollowFusedSplitStepNode::WriteToRecord`) have already been placed next to each other in this file.

##########
File path: include/tvm/auto_scheduler/loop_state.h
##########
@@ -359,6 +359,29 @@ class State : public ObjectRef {
   TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it,
                                 const Array<Optional<Integer>>& lengths,
                                 bool inner_to_outer = true);
+  /*!
+   * \brief Schedule primitive extends to split step.
+   * \param stage_id The index of the stage to be split.
+   * \param it The iterator to be split.
+   * \param src_step_id The index of the split step to follow in the history.
+   * \param n_split The number of split level.
+   * \return The splitted new Iterators.
+   */
+  TVM_DLL Array<Iterator> follow_split(int stage_id, const Iterator& it, int 
src_step_id,
+                                       int n_split);
+  /*!
+   * \brief Schedule primitive extends to split step.
+   * \param stage_id The index of the stage to be split.
+   * \param it The iterator to be split.
+   * \param src_step_ids The indices of the split steps to follow in the 
history.

Review comment:
       Propagate this change to other places.

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -111,6 +115,10 @@ void StepApplyToState(const Step& step, State* state, 
const ComputeDAG& dag) {
     ps->ApplyToState(state);
   } else if (auto ps = step.as<SplitStepNode>()) {
     ps->ApplyToState(state);
+  } else if (auto ps = step.as<FollowSplitStepNode>()) {

Review comment:
       The `ps` variables in different branches have different types, so these branches cannot be merged.

##########
File path: include/tvm/auto_scheduler/loop_state.h
##########
@@ -359,6 +359,29 @@ class State : public ObjectRef {
   TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it,
                                 const Array<Optional<Integer>>& lengths,
                                 bool inner_to_outer = true);
+  /*!
+   * \brief Schedule primitive extends to split step.
+   * \param stage_id The index of the stage to be split.
+   * \param it The iterator to be split.
+   * \param src_step_id The index of the split step to follow in the history.
+   * \param n_split The number of split level.
+   * \return The splitted new Iterators.
+   */
+  TVM_DLL Array<Iterator> follow_split(int stage_id, const Iterator& it, int 
src_step_id,
+                                       int n_split);
+  /*!
+   * \brief Schedule primitive extends to split step.
+   * \param stage_id The index of the stage to be split.
+   * \param it The iterator to be split.
+   * \param src_step_ids The indices of the split steps to follow in the 
history.

Review comment:
       ```suggestion
      * \param src_step_ids The indices of the split steps to be followed in the history.
   ```




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to