merrymercy commented on a change in pull request #6073:
URL: https://github.com/apache/incubator-tvm/pull/6073#discussion_r457773012
##########
File path: src/auto_scheduler/measure_record.cc
##########
@@ -82,98 +63,23 @@ struct Handler<::tvm::Array<::tvm::auto_scheduler::Step>> {
inline static void Write(dmlc::JSONWriter* writer,
const ::tvm::Array<::tvm::auto_scheduler::Step>&
data) {
writer->BeginArray(false);
- for (size_t i = 0; i < data.size(); ++i) {
+ for (const auto& step : data) {
writer->WriteArraySeperator();
writer->BeginArray(false);
- if (auto ps = data[i].as<::tvm::auto_scheduler::ReorderStepNode>()) {
- writer->WriteArrayItem(std::string("RE"));
- writer->WriteArrayItem(ps->stage_id);
- writer->WriteArrayItem(IntArrayToVector(ps->after_ids));
- } else if (auto ps = data[i].as<::tvm::auto_scheduler::SplitStepNode>())
{
- writer->WriteArrayItem(std::string("SP"));
- writer->WriteArrayItem(ps->stage_id);
- writer->WriteArrayItem(ps->iter_id);
- writer->WriteArrayItem(ps->extent ?
::tvm::auto_scheduler::GetIntImm(ps->extent.value())
- : 0);
- writer->WriteArrayItem(IntArrayToVector(ps->lengths));
- writer->WriteArrayItem(static_cast<int>(ps->inner_to_outer));
- } else if (auto ps = data[i].as<::tvm::auto_scheduler::FuseStepNode>()) {
- writer->WriteArrayItem(std::string("FU"));
- writer->WriteArrayItem(ps->stage_id);
- writer->WriteArrayItem(IntArrayToVector(ps->fused_ids));
- } else {
- LOG(FATAL) << "Invalid step: " << data[i];
- }
+ step->WriteToRecord(writer);
writer->EndArray();
}
writer->EndArray();
}
inline static void Read(dmlc::JSONReader* reader,
::tvm::Array<::tvm::auto_scheduler::Step>* data) {
- std::vector<int> int_list;
- bool s, inner_to_outer;
- std::string name, scope_name, pragma_type, ti_func_name;
- int stage_id, iter_id, extent;
-
reader->BeginArray();
data->clear();
while (reader->NextArrayItem()) {
reader->BeginArray();
- s = reader->NextArrayItem();
- CHECK(s);
- reader->Read(&name);
- if (name == "RE") {
- s = reader->NextArrayItem();
- CHECK(s);
- reader->Read(&stage_id);
- s = reader->NextArrayItem();
- CHECK(s);
- reader->Read(&int_list);
- ::tvm::Array<::tvm::Integer> after_ids;
- for (const auto& i : int_list) {
- after_ids.push_back(i);
- }
- data->push_back(::tvm::auto_scheduler::ReorderStep(stage_id,
after_ids));
- } else if (name == "SP") {
- s = reader->NextArrayItem();
- CHECK(s);
- reader->Read(&stage_id);
- s = reader->NextArrayItem();
- CHECK(s);
- reader->Read(&iter_id);
- s = reader->NextArrayItem();
- CHECK(s);
- reader->Read(&extent);
- s = reader->NextArrayItem();
- CHECK(s);
- reader->Read(&int_list);
- s = reader->NextArrayItem();
- CHECK(s);
- reader->Read(&inner_to_outer);
- ::tvm::Array<::tvm::Optional<::tvm::Integer>> lengths;
- for (const auto& i : int_list) {
- lengths.push_back(::tvm::Integer(i));
- }
- data->push_back(::tvm::auto_scheduler::SplitStep(
- stage_id, iter_id, extent == 0 ? ::tvm::PrimExpr() : extent,
lengths, inner_to_outer));
- } else if (name == "FU") {
- s = reader->NextArrayItem();
- CHECK(s);
- reader->Read(&stage_id);
- s = reader->NextArrayItem();
- CHECK(s);
- reader->Read(&int_list);
- ::tvm::Array<::tvm::Integer> fused_ids;
- for (const auto& i : int_list) {
- fused_ids.push_back(i);
- }
- data->push_back(::tvm::auto_scheduler::FuseStep(stage_id, fused_ids));
- } else {
- LOG(FATAL) << "Invalid step format";
- }
- s = reader->NextArrayItem();
- CHECK(!s);
+ data->push_back(::tvm::auto_scheduler::StepReadFromRecord(reader));
+ CHECK(!reader->NextArrayItem());
Review comment:
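`NextArrayItem` has a side effect (it advances the reader's cursor), so I
don't want to put the call inside the debug-oriented `CHECK`. In my opinion,
the code should still run correctly if every `CHECK` were deleted. Store the
result in a local first, e.g. `bool s = reader->NextArrayItem(); CHECK(!s);`.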
##########
File path: src/auto_scheduler/loop_state.cc
##########
@@ -244,285 +197,73 @@ Iterator State::unroll(int stage_id, const Iterator& it,
int max_unroll) {
AnnotationStep step =
AnnotationStep(stage_id, GetIndex(stage->iters, it),
IteratorAnnotation::kUnroll);
CopyOnWrite()->transform_steps.push_back(step);
- return DoAnnotationStep(step);
+ return step->ApplyToState(this);
}
-Iterator State::bind(int stage_id, const Iterator& it, IteratorAnnotation
thread_type) {
+Iterator State::vectorize(int stage_id, const Iterator& it) {
const Stage& stage = operator->()->stages[stage_id];
- if (thread_type < IteratorAnnotation::kVThread || thread_type >
IteratorAnnotation::kThreadZ) {
- LOG(FATAL) << "thread_type error, valid: kVThread, kBlockX, kBlockY, "
- << "kThreadX, kThreadY, kBlockZ, kThreadZ";
- }
- AnnotationStep step = AnnotationStep(stage_id, GetIndex(stage->iters, it),
thread_type);
+ AnnotationStep step =
+ AnnotationStep(stage_id, GetIndex(stage->iters, it),
IteratorAnnotation::kVectorize);
CopyOnWrite()->transform_steps.push_back(step);
- return DoAnnotationStep(step);
-}
-
-/********** Step implementations for state **********/
-void State::DoReorderStep(const ReorderStep& step) {
- const Stage& stage = operator->()->stages[step->stage_id];
- Array<Iterator> iters;
- for (auto x : step->after_ids) {
- iters.push_back(stage->iters[x]);
- }
- StateNode* pstate = CopyOnWrite();
- pstate->stages.Set(step->stage_id,
- Stage(stage->op, stage->op_type, iters,
stage->compute_at, stage->attrs));
+ return step->ApplyToState(this);
}
-void State::DoComputeAtStep(const ComputeAtStep& step) {
- const Stage& stage = operator->()->stages[step->stage_id];
-
- // Remove the bound information of each iterator since they may not be
accurate after
- // compute at
- Array<Iterator> new_iters;
- for (const Iterator& it : stage->iters) {
- new_iters.push_back(Iterator(it->name, Range(), it->iter_kind,
it->annotation));
- }
-
- StateNode* pstate = CopyOnWrite();
- pstate->stages.Set(step->stage_id, Stage(stage->op, stage->op_type,
std::move(new_iters),
- ComputeAtKind::kIter,
stage->attrs));
- // Update attach map
- pstate->attach_map.SetComputeAtIter(step->stage_id, step->target_stage_id,
step->target_iter_id);
-}
-
-void State::DoComputeRootStep(const ComputeRootStep& step) {
- const Stage& stage = operator->()->stages[step->stage_id];
-
- // Remove the bound information of each iterator since they may not be
accurate after
- // compute root
- Array<Iterator> new_iters;
- for (const Iterator& it : stage->iters) {
- new_iters.push_back(Iterator(it->name, Range(), it->iter_kind,
it->annotation));
- }
-
- StateNode* pstate = CopyOnWrite();
- pstate->stages.Set(step->stage_id, Stage(stage->op, stage->op_type,
std::move(new_iters),
- ComputeAtKind::kRoot,
stage->attrs));
- // Update attach map
- pstate->attach_map.DeleteStage(step->stage_id);
+Iterator State::fuse(int stage_id, const Array<Iterator>& iters) {
+ const Stage& stage = operator->()->stages[stage_id];
+ Array<Integer> indices;
+ GetIndices(stage->iters, iters, &indices);
+ FuseStep step = FuseStep(stage_id, indices);
+ CopyOnWrite()->transform_steps.push_back(step);
+ return step->ApplyToState(this);
}
-void State::DoComputeInlineStep(const ComputeInlineStep& step) {
- const Stage& stage = operator->()->stages[step->stage_id];
-
- // Check the validity of compute_inline
- for (size_t i = 0; i < stage->iters.size(); ++i) {
- CHECK_EQ(operator->()->attach_map->iter_to_attached_stages.count(
- std::make_pair(step->stage_id, i)),
- 0)
- << "Invalid compute_inline: There are some other stages that are
attached to the "
- << "target stage";
- }
-
- StateNode* pstate = CopyOnWrite();
- auto new_stage = pstate->stages[step->stage_id];
- new_stage.CopyOnWrite()->compute_at = ComputeAtKind::kInlined;
- pstate->stages.Set(step->stage_id, std::move(new_stage));
- // Update attach map
- pstate->attach_map.DeleteStage(step->stage_id);
+void State::reorder(int stage_id, const Array<Iterator>& order) {
+ const Stage& stage = operator->()->stages[stage_id];
+ CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators "
+ << "should be specified";
+ Array<Integer> after_ids;
+ GetIndices(stage->iters, order, &after_ids);
+ ReorderStep step = ReorderStep(stage_id, after_ids);
+ CopyOnWrite()->transform_steps.push_back(step);
+ step->ApplyToState(this);
}
-// common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep
-Array<Iterator> State::DoSplitStepCommon(int stage_id, int iter_id,
- const Array<Optional<Integer>>&
lengths,
- bool inner_to_outer) {
+Array<Iterator> State::split(int stage_id, const Iterator& it,
+ const Array<Optional<Integer>>& lengths, bool
inner_to_outer) {
const Stage& stage = operator->()->stages[stage_id];
- const Iterator& it = stage->iters[iter_id];
- size_t old_iter_size = stage->iters.size();
- bool concrete = true;
-
- Optional<PrimExpr> tosplit_min, tosplit_extent;
- if (it->range.defined()) {
- tosplit_min = it->range->min;
- tosplit_extent = it->range->extent;
- } else {
- tosplit_min = NullOpt;
- tosplit_extent = NullOpt;
- }
-
- Array<Iterator> outs;
- for (size_t i = 0; i < lengths.size(); ++i) {
- Optional<Integer> l;
- String name;
- if (inner_to_outer) {
- l = lengths[lengths.size() - i - 1];
- name = it->name + "." + std::to_string(lengths.size() - i);
- } else {
- l = lengths[i];
- name = it->name + "." + std::to_string(i);
- }
- Iterator res;
- if (l && tosplit_min && tosplit_extent) {
- res = Iterator(name, Range::FromMinExtent(tosplit_min.value(),
l.value()), it->iter_kind,
- IteratorAnnotation::kNone);
- tosplit_min = Integer(0);
- tosplit_extent = indexdiv(tosplit_extent.value() + l.value() - 1,
l.value());
- } else {
- res = Iterator(name, Range(), it->iter_kind, IteratorAnnotation::kNone);
- tosplit_min = NullOpt;
- tosplit_extent = NullOpt;
- concrete = false;
- }
- outs.push_back(std::move(res));
- }
-
- Range range;
- if (tosplit_min && tosplit_extent) {
- range = Range::FromMinExtent(tosplit_min.value(), tosplit_extent.value());
- }
- if (inner_to_outer) {
- outs.push_back(Iterator(it->name + ".0", range, it->iter_kind,
IteratorAnnotation::kNone));
- // Reverse the Iterator array
- Array<Iterator> temp(outs.rbegin(), outs.rend());
- outs = std::move(temp);
- } else {
- outs.push_back(Iterator(it->name + "." + std::to_string(lengths.size()),
range, it->iter_kind,
- IteratorAnnotation::kNone));
- }
-
- Array<Iterator> new_iters;
- new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin()
+ iter_id);
- new_iters.insert(new_iters.end(), outs.begin(), outs.end());
- new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id + 1,
stage->iters.end());
-
- StateNode* pstate = CopyOnWrite();
- pstate->stages.Set(stage_id,
- Stage(stage->op, stage->op_type, new_iters,
stage->compute_at, stage->attrs));
- pstate->concrete &= concrete;
-
- // Two vectors are used to represent the iterator relation before and after
split
- // The original iterators in AttachMap will be updated with the new iterators
- std::vector<IterKey> from_iters;
- std::vector<IterKey> to_iters;
- for (size_t i = iter_id; i < old_iter_size; ++i) {
- from_iters.emplace_back(stage_id, i);
- to_iters.emplace_back(stage_id, i + lengths.size());
- }
- pstate->attach_map.UpdateIters(from_iters, to_iters);
-
- return outs;
+ SplitStep step =
+ SplitStep(stage_id, GetIndex(stage->iters, it),
+ it->range.defined() ? it->range->extent : PrimExpr(), lengths,
inner_to_outer);
+ CopyOnWrite()->transform_steps.push_back(step);
+ return step->ApplyToState(this);
}
-Array<Iterator> State::DoSplitStep(const SplitStep& step) {
- return DoSplitStepCommon(step->stage_id, step->iter_id, step->lengths,
step->inner_to_outer);
+void State::compute_at(int stage_id, int target_stage_id, const Iterator&
target_iter) {
+ const Stage& target_stage = operator->()->stages[target_stage_id];
+ ComputeAtStep step =
+ ComputeAtStep(stage_id, target_stage_id, GetIndex(target_stage->iters,
target_iter));
+ CopyOnWrite()->transform_steps.push_back(step);
+ step->ApplyToState(this);
}
-Iterator State::DoFuseStep(const FuseStep& step) {
- int stage_id = step->stage_id;
- const Stage& stage = operator->()->stages[stage_id];
- size_t old_iter_size = static_cast<int>(stage->iters.size());
-
- String new_name;
- PrimExpr new_extent = 1;
- IteratorKind new_iter_kind = IteratorKind::kSpecial;
-
- for (size_t i = 0; i < step->fused_ids.size(); ++i) {
- if (i > 0) {
- CHECK_EQ(step->fused_ids[i]->value, step->fused_ids[i - 1]->value + 1);
- }
-
- if (i != step->fused_ids.size() - 1) {
- const auto& iter_to_attached_stage =
operator->()->attach_map->iter_to_attached_stages;
- if (iter_to_attached_stage.find(std::make_pair(stage_id,
step->fused_ids[i])) !=
- iter_to_attached_stage.end()) {
- LOG(FATAL) << "Invalid Fuse. Trying to fuse iterators that have been
attached by some "
- << "stages. State before fusion:\n"
- << *this;
- }
- }
-
- const Iterator& it = stage->iters[step->fused_ids[i]];
- new_name = new_name + it->name + "@";
-
- if (it->range.defined() && new_extent.defined()) {
- new_extent = new_extent * it->range->extent;
- } else {
- new_extent = PrimExpr();
- }
-
- if (i == 0) {
- new_iter_kind = it->iter_kind;
- } else {
- if (new_iter_kind != it->iter_kind) {
- new_iter_kind = IteratorKind::kMixed;
- }
- }
- }
-
- Range range;
- if (new_extent.defined()) {
- range = Range::FromMinExtent(0, new_extent);
- }
- Iterator new_it = Iterator(new_name, range, new_iter_kind,
IteratorAnnotation::kNone);
- Array<Iterator> new_iters;
- new_iters.insert(new_iters.end(), stage->iters.begin(),
- stage->iters.begin() + step->fused_ids.front());
- new_iters.push_back(new_it);
- new_iters.insert(new_iters.end(), stage->iters.begin() +
step->fused_ids.back() + 1,
- stage->iters.end());
-
- StateNode* pstate = CopyOnWrite();
- pstate->stages.Set(stage_id,
- Stage(stage->op, stage->op_type, new_iters,
stage->compute_at, stage->attrs));
-
- // Two vectors are used to represent the iterator relation before and after
fuse
- // The original iterators in AttachMap will be updated with the new iterators
- std::vector<IterKey> from_iters;
- std::vector<IterKey> to_iters;
- const size_t begin_id = step->fused_ids.front(), end_id =
step->fused_ids.back();
- for (size_t i = 0; i < old_iter_size; ++i) {
- if (i <= begin_id) {
- continue;
- } else if (i > end_id) {
- // move forward
- from_iters.emplace_back(stage_id, i);
- to_iters.emplace_back(stage_id, i - end_id + begin_id);
- } else {
- // move to the fused id
- from_iters.emplace_back(stage_id, i);
- to_iters.emplace_back(stage_id, begin_id);
- }
- }
- pstate->attach_map.UpdateIters(from_iters, to_iters);
-
- return new_it;
+void State::compute_inline(int stage_id) {
+ ComputeInlineStep step = ComputeInlineStep(stage_id);
+ CopyOnWrite()->transform_steps.push_back(step);
+ step->ApplyToState(this);
}
-Iterator State::DoAnnotationStep(const AnnotationStep& step) {
- const Stage& stage = operator->()->stages[step->stage_id];
- Iterator it = stage->iters[step->iter_id];
-
- CHECK(it->annotation == IteratorAnnotation::kNone);
- Iterator new_it = Iterator(it->name, it->range, it->iter_kind,
step->annotation);
- Stage new_stage = stage;
- new_stage.CopyOnWrite()->iters.Set(step->iter_id, new_it);
- CopyOnWrite()->stages.Set(step->stage_id, std::move(new_stage));
- return new_it;
+void State::compute_root(int stage_id) {
+ ComputeRootStep step = ComputeRootStep(stage_id);
+ CopyOnWrite()->transform_steps.push_back(step);
+ step->ApplyToState(this);
}
-void State::DoSteps(const ComputeDAG& dag) {
+void State::ApplySteps(const ComputeDAG& dag) {
CHECK(operator->()->stages.size()) << "Invalid State with empty operation
stages.";
+ // Call each step's ApplyToState method
for (const auto& step : operator->()->transform_steps) {
- if (auto ps = step.as<ReorderStepNode>()) {
- DoReorderStep(GetRef<ReorderStep>(ps));
- } else if (auto ps = step.as<ComputeAtStepNode>()) {
- DoComputeAtStep(GetRef<ComputeAtStep>(ps));
- } else if (auto ps = step.as<ComputeRootStepNode>()) {
- DoComputeRootStep(GetRef<ComputeRootStep>(ps));
- } else if (auto ps = step.as<ComputeInlineStepNode>()) {
- DoComputeInlineStep(GetRef<ComputeInlineStep>(ps));
- } else if (auto ps = step.as<SplitStepNode>()) {
- DoSplitStep(GetRef<SplitStep>(ps));
- } else if (auto ps = step.as<FuseStepNode>()) {
- DoFuseStep(GetRef<FuseStep>(ps));
- } else if (auto ps = step.as<AnnotationStepNode>()) {
- DoAnnotationStep(GetRef<AnnotationStep>(ps));
- } else {
- LOG(FATAL) << "Invalid step: " << step;
- }
+ StepApplyToState(step, this, dag);
Review comment:
1. We can remove `StepApplyToState`.
2. The `dag` argument will be used by other steps, such as `cache_write` and
`rfactor`.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]