junrushao1994 commented on a change in pull request #8767:
URL: https://github.com/apache/tvm/pull/8767#discussion_r693398936



##########
File path: src/tir/schedule/primitive/loop_transformation.cc
##########
@@ -385,6 +511,113 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& 
loop_srefs) {
   return self->stmt2ref.at(new_stmt.get());
 }
 
+void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) {
+  std::unordered_set<const StmtSRefNode*> loop_srefs;
+  loop_srefs.reserve(ordered_loop_srefs.size());
+  if (ordered_loop_srefs.empty() || ordered_loop_srefs.size() == 1) {
+    return;
+  }
+  // Step 1. check uniqueness
+  for (const StmtSRef& loop_sref : ordered_loop_srefs) {
+    const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+    // uniqueness check
+    auto inserted = loop_srefs.insert(loop_sref.get());
+    if (!inserted.second) {
+      throw LoopMultiAppearanceError(self->mod, GetRef<For>(loop));
+    }
+  }
+  // Step 2. gather loops to be reordered
+  // For each loop, traverse upwards along the parent pointer, and stop on 
either a block, or a
+  // previously-visited loop
+  // - the top of the reorder range is the last loop visited in the first 
traverse which exists in
+  //   the input array
+  // - the bottom of the reorder range is the last loop in the input array 
which is not visited in
+  // the previous traverses
+  const StmtSRefNode* top = nullptr;
+  const StmtSRefNode* bottom = nullptr;
+  // Maps a parent sref to its child sref
+  std::unordered_map<const StmtSRefNode*, const StmtSRefNode*> successor;
+  for (size_t i = 0; i < ordered_loop_srefs.size(); i++) {
+    const StmtSRefNode* sref = ordered_loop_srefs[i].get();
+    // if sref is not visited before, update `bottom`
+    if (!successor.count(sref->parent)) {
+      bottom = sref;
+    }
+    while (true) {
+      // stop at blocknode
+      if (sref->stmt->IsInstance<BlockNode>()) {
+        if (i != 0) {
+          throw LoopsNotAChainError(self->mod, NullOpt,
+                                    
LoopsNotAChainError::ProblemKind::kNotUnderAScope);
+        } else {
+          break;
+        }
+      }
+      const StmtSRefNode* parent_sref = sref->parent;
+      // stop at previously-visited loop
+      if (successor.count(parent_sref)) {
+        if (successor[parent_sref] == sref) {
+          break;
+        } else {
+          throw LoopsNotAChainError(self->mod, GetRef<Stmt>(parent_sref->stmt),
+                                    
LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt);
+        }
+      } else {
+        successor[parent_sref] = sref;
+      }
+      // if it's the first traverse and the loop is in the input array, update 
`top`
+      if (loop_srefs.count(sref) && i == 0) {
+        top = sref;
+      }
+      sref = parent_sref;
+    }
+  }
+  // Step 3. Check that loops are single-branch
+  const ForNode* outer_loop = TVM_SREF_TO_FOR(outer_loop, 
GetRef<StmtSRef>(top));
+  for (const StmtSRefNode* loop_sref = top; loop_sref != bottom;) {
+    loop_sref = successor[loop_sref];
+    const ForNode* inner_loop = TVM_SREF_TO_FOR(inner_loop, 
GetRef<StmtSRef>(loop_sref));
+    if (outer_loop->body.get() != inner_loop) {
+      throw LoopsNotAChainError(self->mod, GetRef<For>(outer_loop),
+                                
LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt);
+    }

Review comment:
       It is a utility function shared by several parts of the codebase, so it
does make sense to me to give it a semantically meaningful name.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to