junrushao1994 commented on a change in pull request #8767:
URL: https://github.com/apache/tvm/pull/8767#discussion_r691686504
##########
File path: src/tir/schedule/primitive/loop_transformation.cc
##########
@@ -385,6 +540,108 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
return self->stmt2ref.at(new_stmt.get());
}
+void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) {
+ std::unordered_set<const StmtSRefNode*> loop_srefs;
+ loop_srefs.reserve(ordered_loop_srefs.size());
+ if (ordered_loop_srefs.empty() || ordered_loop_srefs.size() == 1) {
+ return;
+ }
+ // Step 1. check uniqueness
+ for (const StmtSRef loop_sref : ordered_loop_srefs) {
+ const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+ // uniqueness check
+ auto inserted = loop_srefs.insert(loop_sref.get());
+ if (!inserted.second) {
+ throw LoopMultiAppearanceError(self->mod, GetRef<For>(loop));
+ }
+ }
+ // Step 2. gather loops to be reordered
+ // The algorithm is to scan the inverse preorder of the whole loop tree in the scope.
+ // For some Loop x, it is potentially in the reorder range if
+ // - x is in the reorder list
+ // - x has only one child which is a loop and is potentially in the reorder range
+ // After the inverse DFS, we can know the exact reorder range
+ // `top` and `bottom` denote the boundary of the loop range that need reordering
+ const StmtSRefNode* top = nullptr;
+ const StmtSRefNode* bottom = nullptr;
+ // Maps a parent sref to its child sref
+ std::unordered_map<const StmtSRefNode*, const StmtSRefNode*> successor;
+ int n_loops_not_found = ordered_loop_srefs.size();
+ // Gather all the loops under the block scope
+ std::vector<const StmtSRefNode*> inverse_preorder_loops = GetLoopsInversePreOrderUnderScope(
+ self, GetScopeRoot(self, ordered_loop_srefs[0], /*require_stage_pipeline=*/true));
+ for (const StmtSRefNode* loop : inverse_preorder_loops) {
+ bool is_in_reorder_list = loop_srefs.count(loop);
+ bool has_successor_in_reorder_list = successor.count(loop);
+ if (is_in_reorder_list || has_successor_in_reorder_list) {
+ const StmtSRefNode* parent = loop->parent;
+ // If the successor of `parent` exists, then `parent` can't be a single-branch loop
+ auto inserted = successor.insert({parent, loop});
+ if (!inserted.second) {
+ throw LoopsNotALineError(self->mod, GetRef<Stmt>(parent->stmt),
+ LoopsNotALineError::kHaveNonSingleBranchStmt);
+ }
+ // `bottom` is the first loop encountered
+ if (bottom == nullptr) {
+ bottom = loop;
+ }
+ // `top` is the last loop encountered
+ if (is_in_reorder_list) {
+ top = loop;
+ --n_loops_not_found;
+ }
+ }
+ }
Review comment:
I have some reservations about the overall design of the first two steps.
Ideally, we should not need to visit the entire scope, nor use complicated
code to extract the relationship between these loop nests.
Here is what I propose:
1. Have a class called `LoopNest`
2. `LoopNest` is initialized by a list of loops
3. `LoopNest` stores a consecutive chain of loops, where each element is the only child of the previous one in the list.
4. Reordering is then simply a permutation of the loop nest.
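A minimal, self-contained C++ sketch of the proposed `LoopNest`, assuming a simplified hypothetical `Loop` struct in place of TVM's `For`/`StmtSRef` types. Only the class name comes from the proposal; the constructor signature, `Reorder` taking a permutation of positions (rather than a list of loops), and `Vars` are illustrative assumptions, not TVM API. The constructor walks only the listed loops, checking uniqueness and the single-child chain property, and reordering just permutes the loop headers.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

// Hypothetical, simplified stand-in for a loop in a loop tree. This is NOT
// TVM's For/StmtSRef; a real loop header also carries min, kind, annotations.
struct Loop {
  std::string var;              // loop variable (part of the header)
  int extent;                   // trip count (part of the header)
  std::vector<Loop*> children;  // child statements, simplified to loops only
};

// Hypothetical LoopNest: a chain of loops, outermost first, where every
// element is the only child of the previous one.
class LoopNest {
 public:
  explicit LoopNest(std::vector<Loop*> loops) : loops_(std::move(loops)) {
    std::unordered_set<const Loop*> seen;
    for (std::size_t i = 0; i < loops_.size(); ++i) {
      // Uniqueness: each loop may appear only once in the chain.
      if (!seen.insert(loops_[i]).second) {
        throw std::invalid_argument("loop " + loops_[i]->var + " appears more than once");
      }
      // Chain check: the previous loop must have this loop as its only child.
      if (i > 0) {
        const Loop* prev = loops_[i - 1];
        if (prev->children.size() != 1 || prev->children[0] != loops_[i]) {
          throw std::invalid_argument("loops do not form a single-branch chain");
        }
      }
    }
  }

  // Permutes the loop headers: afterwards, position i holds the header that
  // was at position perm[i]. The chain shape is left untouched.
  void Reorder(const std::vector<std::size_t>& perm) {
    if (perm.size() != loops_.size()) {
      throw std::invalid_argument("permutation size mismatch");
    }
    std::vector<bool> used(perm.size(), false);
    for (std::size_t p : perm) {
      if (p >= perm.size() || used[p]) {
        throw std::invalid_argument("not a permutation");
      }
      used[p] = true;
    }
    // Copy headers first so in-place assignment does not clobber sources.
    std::vector<std::pair<std::string, int>> headers;
    headers.reserve(perm.size());
    for (std::size_t p : perm) {
      headers.emplace_back(loops_[p]->var, loops_[p]->extent);
    }
    assert(headers.size() == loops_.size());
    for (std::size_t i = 0; i < loops_.size(); ++i) {
      loops_[i]->var = headers[i].first;
      loops_[i]->extent = headers[i].second;
    }
  }

  // Loop variables from outermost to innermost.
  std::vector<std::string> Vars() const {
    std::vector<std::string> vars;
    for (const Loop* loop : loops_) vars.push_back(loop->var);
    return vars;
  }

 private:
  std::vector<Loop*> loops_;
};
```

For a nest `i { j { k } }`, `LoopNest({&i, &j, &k}).Reorder({2, 0, 1})` yields the order `k, i, j`. In this sketch validation is linear in the length of the chain rather than in the size of the scope, and since every loop in the chain has exactly one child, swapping headers cannot break the tree shape.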
--
This is an automated message from the Apache Git Service.