jiuqi-yang commented on a change in pull request #6142:
URL: https://github.com/apache/incubator-tvm/pull/6142#discussion_r461279732



##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -778,6 +795,184 @@ String SplitStepNode::PrintAsPythonAPI(Array<te::Stage>* 
stages,
 
 /********** Steps working on multiple stages **********/
 
+/********** Follow Split **********/
+FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, int src_step_id, 
int n_split) {
+  auto node = make_object<FollowSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_id = src_step_id;
+  node->n_split = n_split;
+  data_ = std::move(node);
+}
+
+void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArrayItem(src_step_id);
+  writer->WriteArrayItem(n_split);
+}
+
+void FollowSplitStepNode::ExtractSplitLengths(const Array<Step>& 
transform_steps,
+                                              Array<Optional<Integer>>* 
lengths) const {
+  CHECK_LT(src_step_id, transform_steps.size());
+  auto ps = transform_steps[src_step_id].as<SplitStepNode>();
+  CHECK(ps != nullptr);
+
+  // get lengths from src step
+  lengths->reserve(n_split);
+  int j = 0;
+  for (; j < n_split - 1; ++j) {
+    lengths->push_back(ps->lengths[j]);
+  }
+  PrimExpr last_factor = 1;

Review comment:
       I think the following loop to get the last splitting factor for the 
follow_split step is required. Otherwise we may miss a factor.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to