jcf94 commented on a change in pull request #6073:
URL: https://github.com/apache/incubator-tvm/pull/6073#discussion_r457796865
##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -51,129 +53,539 @@ const char* IteratorAnnotationString[] = {
"tensorize" // kTensorized = 11
};
-/********** Reorder **********/
-ReorderStep::ReorderStep(int stage_id, const Array<Integer>& after_ids) {
- auto node = make_object<ReorderStepNode>();
- node->stage_id = stage_id;
- for (const auto& x : after_ids) {
- CHECK(x->IsInstance<IntImmNode>());
+Step StepReadFromRecord(dmlc::JSONReader* reader) {
+ std::string name;
+ CHECK(reader->NextArrayItem());
+ reader->Read(&name);
+ if (name == AnnotationStepNode::record_prefix_str) {
+ return AnnotationStep(reader);
+ } else if (name == FuseStepNode::record_prefix_str) {
+ return FuseStep(reader);
+ } else if (name == ReorderStepNode::record_prefix_str) {
+ return ReorderStep(reader);
+ } else if (name == SplitStepNode::record_prefix_str) {
+ return SplitStep(reader);
+ } else if (name == ComputeAtStepNode::record_prefix_str) {
+ return ComputeAtStep(reader);
+ } else if (name == ComputeInlineStepNode::record_prefix_str) {
+ return ComputeInlineStep(reader);
+ } else if (name == ComputeRootStepNode::record_prefix_str) {
+ return ComputeRootStep(reader);
+ } else {
+ LOG(FATAL) << "Invalid step format: " << name;
}
- node->after_ids = after_ids;
- data_ = std::move(node);
+ return Step();
}
-void ReorderStepNode::ApplyToSchedule(Array<te::Stage>* stages,
- StageToAxesMap* stage_to_axes) const {
- auto stage = (*stages)[stage_id];
- const Array<IterVar>& axes = stage_to_axes->at(stage);
- CHECK_EQ(after_ids.size(), axes.size());
-
- Array<IterVar> new_axes;
- new_axes.reserve(axes.size());
- for (auto i : after_ids) {
- new_axes.push_back(axes[i]);
+void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) {
+ if (auto ps = step.as<AnnotationStepNode>()) {
+ ps->ApplyToState(state);
+ } else if (auto ps = step.as<FuseStepNode>()) {
+ ps->ApplyToState(state);
+ } else if (auto ps = step.as<ReorderStepNode>()) {
+ ps->ApplyToState(state);
+ } else if (auto ps = step.as<SplitStepNode>()) {
+ ps->ApplyToState(state);
+ } else if (auto ps = step.as<ComputeAtStepNode>()) {
+ ps->ApplyToState(state);
+ } else if (auto ps = step.as<ComputeInlineStepNode>()) {
+ ps->ApplyToState(state);
+ } else if (auto ps = step.as<ComputeRootStepNode>()) {
+ ps->ApplyToState(state);
+ } else {
+ LOG(FATAL) << "Invalid step: " << step;
}
- stage.reorder(new_axes);
-
- stage_to_axes->Set(stage, std::move(new_axes));
- stages->Set(stage_id, std::move(stage));
}
-String ReorderStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
- StageToAxesMap* stage_to_axes) const {
- const auto& stage = (*stages)[stage_id];
- std::stringstream ss;
-
- ss << "s[" << CleanName(stage->op->name) << "].reorder(";
- for (size_t i = 0; i < after_ids.size(); ++i) {
- ss << CleanName((*stage_to_axes)[stage][after_ids[i]]->var->name_hint);
- if (i != after_ids.size() - 1) {
- ss << ", ";
- }
+void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages,
Review comment:
Hmm, `step->ApplyToState`, `step->ApplyToSchedule`, and
`step->PrintAsPythonAPI` are all difficult to implement as virtual functions,
so they are dispatched by type here for now; we can look for a better way to handle this later.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]