merrymercy commented on a change in pull request #6107:
URL: https://github.com/apache/incubator-tvm/pull/6107#discussion_r459810569



##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -923,5 +958,272 @@ String 
ComputeRootStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
   return ss.str();
 }
 
+/********** Primitives adding new stages **********/
+
+/*!
+ * \brief Common part for steps that add new stages(e.g. CacheReadStep, 
CacheWriteStep,
+ * RfactorStep). This will filter out all steps that can change the stages of 
ComputeDAG.
+ */
+Array<Step> GetStageModifiableSteps(Step current_step, const Array<Step>& 
transform_steps) {
+  Array<Step> ret_steps;
+  for (const Step& step : transform_steps) {
+    if (step->IsInstance<CacheWriteStepNode>() || 
step->IsInstance<CacheReadStepNode>()) {
+      ret_steps.push_back(step);
+    }
+    // TODO(jcf94): add rfactor support
+    if (step.same_as(current_step)) {
+      break;
+    }
+  }
+  return ret_steps;
+}
+
+/********** Cache Read **********/
+CacheReadStep::CacheReadStep(int stage_id, String scope_name,
+                             const Array<Integer>& reader_stage_ids) {
+  auto node = make_object<CacheReadStepNode>();
+  node->stage_id = stage_id;
+  node->scope_name = std::move(scope_name);
+  node->reader_stage_ids = reader_stage_ids;
+  data_ = std::move(node);
+}
+
+CacheReadStep::CacheReadStep(dmlc::JSONReader* reader) {
+  auto node = make_object<CacheReadStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::string string_value;
+  reader->Read(&string_value);
+  node->scope_name = std::move(string_value);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::vector<int> int_list;
+  reader->Read(&int_list);
+  Array<Integer> reader_stage_ids;
+  for (int i : int_list) {
+    reader_stage_ids.push_back(i);
+  }
+  node->reader_stage_ids = std::move(reader_stage_ids);
+  data_ = std::move(node);
+}
+
// Serialize this step as a JSON array: [record_prefix_str, stage_id, scope_name,
// reader_stage_ids]. The field order must match the JSONReader constructor above.
void CacheReadStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
  // NOTE(review): WriteArraySeperator + WriteString is used for string fields instead of
  // WriteArrayItem — presumably to control quoting/escaping in the record; confirm against
  // the dmlc::JSONWriter API before changing.
  writer->WriteArraySeperator();
  writer->WriteString(record_prefix_str);
  writer->WriteArrayItem(stage_id);
  writer->WriteArraySeperator();
  writer->WriteString(scope_name);
  writer->WriteArrayItem(IntArrayToVector(reader_stage_ids));
}
+
// Apply this cache_read step to `state`, inserting the new cache-read stage into the
// state's stage list and returning the id of the inserted stage.
int CacheReadStepNode::ApplyToState(State* state, const ComputeDAG& dag) const {
  StateNode* pstate = state->CopyOnWrite();
  // Replay only the stage-modifying steps up to (and including) this one, so the replayed
  // DAG already contains the op added by this cache_read.
  const ComputeDAG& current_compute_dag =
      dag.ReplayAndGetDAG(GetStageModifiableSteps(GetRef<Step>(this), (*state)->transform_steps));

  // target_stage -> target_stage + target_store
  // Update the op of the target stage, insert a new cache read stage behind, update the op of
  // later stages, then update the stage_id mapping in AttachMap
  int added_stage_id = stage_id + 1;
  Stage tmp_stage = pstate->stages[stage_id];
  tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[stage_id];
  pstate->stages.Set(stage_id, std::move(tmp_stage));
  // The new cache-read stage sits immediately after the target stage.
  pstate->stages.insert(pstate->stages.begin() + added_stage_id,
                        Stage(current_compute_dag->ops[added_stage_id]));
  // Every stage after the insertion point now maps to a shifted op in the replayed DAG.
  for (size_t i = added_stage_id + 1; i < pstate->stages.size(); ++i) {
    tmp_stage = pstate->stages[i];
    tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[i];
    pstate->stages.Set(i, std::move(tmp_stage));
  }
  // Shift stage ids recorded in the attach map to account for the inserted stage.
  pstate->attach_map = pstate->attach_map.ApplyStageIdOffset(added_stage_id);
  pstate->current_compute_dag = std::move(current_compute_dag);

  return added_stage_id;
}
+
+te::Tensor CacheReadStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                              StageToAxesMap* stage_to_axes,
+                                              te::Schedule* schedule) const {
+  const te::Stage& stage = (*stages)[stage_id];
+  Array<te::Operation> readers;
+  for (const auto& i : reader_stage_ids) {
+    readers.push_back((*stages)[i]->origin_op);
+  }
+  auto out = schedule->cache_read(stage->origin_op.output(0), scope_name, 
readers);
+
+  const auto& new_stage = (*schedule)[out->op];
+  UpdateStageToAxesMap(new_stage, stage_to_axes);
+  stages->insert(stages->begin() + stage_id + 1, new_stage);
+
+  return out;
+}
+
+String CacheReadStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, 
StageToAxesMap* stage_to_axes,
+                                           te::Schedule* schedule) const {
+  std::stringstream ss;
+  // Since the original stage will be changed after schedule apply, keep a 
copy here
+  // These information will be used to print Python API string later
+  auto stage = (*stages)[stage_id];
+  Array<te::Stage> reader_stages;
+  for (size_t i = 0; i < reader_stage_ids.size(); ++i) {
+    reader_stages.push_back((*stages)[reader_stage_ids[i]]);
+  }
+  auto out = ApplyToSchedule(stages, stage_to_axes, schedule);
+
+  ss << CleanName(out->op->name) << " = "
+     << "s.cache_read(" << CleanName(stage->op->name) << ", \"" << scope_name 
<< "\", ["
+     << CleanName(reader_stages[0]->op->name);
+  for (size_t i = 1; i < reader_stage_ids.size(); ++i) {
+    ss << ", " << CleanName(reader_stages[i]->op->name);
+  }
+  ss << "])\n";
+
+  // Print the iterators of the new added stage
+  const auto& iters = out->op->root_iter_vars();
+  for (size_t i = 0; i < iters.size(); ++i) {
+    ss << CleanName(iters[i]->var->name_hint);
+    if (i != iters.size() - 1) {
+      ss << ", ";
+    }
+  }
+  ss << " = "
+     << "tuple(" << CleanName(out->op->name) << ".op.axis)\n";
+
+  return ss.str();
+}
+
+/********** Cache Write **********/
+CacheWriteStep::CacheWriteStep(int stage_id, String scope_name) {
+  auto node = make_object<CacheWriteStepNode>();
+  node->stage_id = stage_id;
+  node->scope_name = std::move(scope_name);
+  data_ = std::move(node);
+}
+
+CacheWriteStep::CacheWriteStep(dmlc::JSONReader* reader) {
+  auto node = make_object<CacheWriteStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::string string_value;
+  reader->Read(&string_value);
+  node->scope_name = std::move(string_value);
+  data_ = std::move(node);
+}
+
// Serialize this step as a JSON array: [record_prefix_str, stage_id, scope_name].
// The field order must match the JSONReader constructor above.
void CacheWriteStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
  // NOTE(review): WriteArraySeperator + WriteString is used for string fields instead of
  // WriteArrayItem — presumably to control quoting/escaping in the record; confirm against
  // the dmlc::JSONWriter API before changing.
  writer->WriteArraySeperator();
  writer->WriteString(record_prefix_str);
  writer->WriteArrayItem(stage_id);
  writer->WriteArraySeperator();
  writer->WriteString(scope_name);
}
+
+int CacheWriteStepNode::ApplyToState(State* state, const ComputeDAG& dag) 
const {
+  StateNode* pstate = state->CopyOnWrite();
+  int last_dag_op_size = pstate->current_compute_dag
+                             ? 
pstate->current_compute_dag.value().as<ComputeDAGNode>()->ops.size()
+                             : dag->ops.size();
+  const ComputeDAG& current_compute_dag =
+      dag.ReplayAndGetDAG(GetStageModifiableSteps(GetRef<Step>(this), 
(*state)->transform_steps));
+  int added_ops = current_compute_dag->ops.size() - last_dag_op_size;
+  // TODO(jcf94): Update this check to equal after fixing the cache write bug 
in TVM

Review comment:
       Yeah, I can open an issue later.
   As for this specific line: this is a bug in TVM, and the feature is not used in
Ansor, so it is okay to skip it here.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to