MasterJH5574 commented on a change in pull request #8767:
URL: https://github.com/apache/tvm/pull/8767#discussion_r693518237



##########
File path: src/tir/schedule/primitive/loop_transformation.cc
##########
@@ -384,6 +510,182 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& 
loop_srefs) {
   self->Replace(loop_srefs[0], new_stmt, opaque_block_reuse);
   return self->stmt2ref.at(new_stmt.get());
 }
+/*!
+ * \brief collect an array of loop srefs into a set
+ * \param self The schedule state
+ * \param ordered_loop_srefs The array of loop srefs
+ * \return A set containing all loops in the array
+ * \throws ScheduleError If there are duplicate loops in the array
+ */
+std::unordered_set<const StmtSRefNode*> CollectLoopsIntoSet(
+    const ScheduleState& self, const Array<StmtSRef>& ordered_loop_srefs) {
+  std::unordered_set<const StmtSRefNode*> loop_srefs;
+  loop_srefs.reserve(ordered_loop_srefs.size());
+  for (const StmtSRef& loop_sref : ordered_loop_srefs) {
+    auto inserted = loop_srefs.insert(loop_sref.get());
+    if (!inserted.second) {
+      const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+      throw LoopMultiAppearanceError(self->mod, GetRef<For>(loop));
+    }
+  }
+  return loop_srefs;
+}
+
+/*!
+ * \brief Get the top and bottom boundary of reorder range (which should be a 
chain)
+ * \param self The schedule state
+ * \param loop_srefs The set containing the srefs to the loops to be reordered
+ * \return a pair containing the top and bottom boundary of the reorder range
+ * \throws ScheduleError If the loops to be reordered is not in a chain
+ */
+std::pair<const StmtSRefNode*, const StmtSRefNode*> GetBoundaryOfReorderRange(
+    const ScheduleState& self, const std::unordered_set<const StmtSRefNode*>& 
loop_srefs) {
+  const StmtSRefNode* top = nullptr;
+  const StmtSRefNode* bottom = *loop_srefs.begin();
+  std::unordered_set<const StmtSRefNode*> visited;
+  bool scope_block_visited = false;
+  bool first_traversal = true;
+  for (const StmtSRefNode* loop_sref : loop_srefs) {
+    if (visited.count(loop_sref)) {
+      continue;
+    }
+    for (const StmtSRefNode* v = loop_sref;; v = v->parent) {
+      // Case 1. If `v` corresponds to a block, stop traversal.
+      if (v->stmt->IsInstance<BlockNode>()) {
+        if (scope_block_visited) {
+          throw LoopsNotAChainError(self->mod, NullOpt,
+                                    
LoopsNotAChainError::ProblemKind::kNotUnderAScope);
+        }
+        scope_block_visited = true;
+        break;
+      }
+      // Case 2. If `v` corresponds to a previously-visited loop, stop 
traversal and update
+      // `bottom`.
+      if (visited.count(v)) {
+        if (v != bottom) {
+          throw LoopsNotAChainError(self->mod, GetRef<Stmt>(v->stmt),
+                                    
LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt);
+        }
+        bottom = loop_sref;
+        break;
+      }
+      // Case 3. Add `v` into `visited`
+      visited.insert(v);
+      // If it's the first traversal and the loop corresponding to `v` is in 
the input array,
+      // update `top`.
+      if (first_traversal && loop_srefs.count(v)) {
+        top = v;
+      }
+    }
+    first_traversal = false;
+  }
+  return std::make_pair(top, bottom);
+}
+
+/*!
+ * \brief get all the loops in the reorder range
+ * \param self The schedule state
+ * \param top The top boundary of the reorder range
+ * \param bottom The bottom boundary of the reorder range
+ * \return an array containing all the loops in the reorder range

Review comment:
       ```suggestion
    * \return An array containing all the loops in the reorder range
   ```

##########
File path: src/tir/schedule/primitive/loop_transformation.cc
##########
@@ -384,6 +510,182 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& 
loop_srefs) {
   self->Replace(loop_srefs[0], new_stmt, opaque_block_reuse);
   return self->stmt2ref.at(new_stmt.get());
 }
+/*!
+ * \brief collect an array of loop srefs into a set
+ * \param self The schedule state
+ * \param ordered_loop_srefs The array of loop srefs
+ * \return A set containing all loops in the array
+ * \throws ScheduleError If there are duplicate loops in the array
+ */
+std::unordered_set<const StmtSRefNode*> CollectLoopsIntoSet(
+    const ScheduleState& self, const Array<StmtSRef>& ordered_loop_srefs) {
+  std::unordered_set<const StmtSRefNode*> loop_srefs;
+  loop_srefs.reserve(ordered_loop_srefs.size());
+  for (const StmtSRef& loop_sref : ordered_loop_srefs) {
+    auto inserted = loop_srefs.insert(loop_sref.get());
+    if (!inserted.second) {
+      const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+      throw LoopMultiAppearanceError(self->mod, GetRef<For>(loop));
+    }
+  }
+  return loop_srefs;
+}
+
+/*!
+ * \brief Get the top and bottom boundary of reorder range (which should be a 
chain)
+ * \param self The schedule state
+ * \param loop_srefs The set containing the srefs to the loops to be reordered
+ * \return a pair containing the top and bottom boundary of the reorder range

Review comment:
       ```suggestion
    * \return A pair containing the top and bottom boundary of the reorder range
   ```

##########
File path: src/tir/schedule/primitive/loop_transformation.cc
##########
@@ -384,6 +510,182 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& 
loop_srefs) {
   self->Replace(loop_srefs[0], new_stmt, opaque_block_reuse);
   return self->stmt2ref.at(new_stmt.get());
 }
+/*!
+ * \brief collect an array of loop srefs into a set
+ * \param self The schedule state
+ * \param ordered_loop_srefs The array of loop srefs
+ * \return A set containing all loops in the array
+ * \throws ScheduleError If there are duplicate loops in the array
+ */
+std::unordered_set<const StmtSRefNode*> CollectLoopsIntoSet(
+    const ScheduleState& self, const Array<StmtSRef>& ordered_loop_srefs) {
+  std::unordered_set<const StmtSRefNode*> loop_srefs;
+  loop_srefs.reserve(ordered_loop_srefs.size());
+  for (const StmtSRef& loop_sref : ordered_loop_srefs) {
+    auto inserted = loop_srefs.insert(loop_sref.get());
+    if (!inserted.second) {
+      const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+      throw LoopMultiAppearanceError(self->mod, GetRef<For>(loop));
+    }
+  }
+  return loop_srefs;
+}
+
+/*!
+ * \brief Get the top and bottom boundary of reorder range (which should be a 
chain)
+ * \param self The schedule state
+ * \param loop_srefs The set containing the srefs to the loops to be reordered
+ * \return a pair containing the top and bottom boundary of the reorder range
+ * \throws ScheduleError If the loops to be reordered is not in a chain
+ */
+std::pair<const StmtSRefNode*, const StmtSRefNode*> GetBoundaryOfReorderRange(
+    const ScheduleState& self, const std::unordered_set<const StmtSRefNode*>& 
loop_srefs) {
+  const StmtSRefNode* top = nullptr;
+  const StmtSRefNode* bottom = *loop_srefs.begin();
+  std::unordered_set<const StmtSRefNode*> visited;
+  bool scope_block_visited = false;
+  bool first_traversal = true;
+  for (const StmtSRefNode* loop_sref : loop_srefs) {
+    if (visited.count(loop_sref)) {
+      continue;
+    }
+    for (const StmtSRefNode* v = loop_sref;; v = v->parent) {
+      // Case 1. If `v` corresponds to a block, stop traversal.
+      if (v->stmt->IsInstance<BlockNode>()) {
+        if (scope_block_visited) {
+          throw LoopsNotAChainError(self->mod, NullOpt,
+                                    
LoopsNotAChainError::ProblemKind::kNotUnderAScope);
+        }
+        scope_block_visited = true;
+        break;
+      }
+      // Case 2. If `v` corresponds to a previously-visited loop, stop 
traversal and update
+      // `bottom`.
+      if (visited.count(v)) {
+        if (v != bottom) {
+          throw LoopsNotAChainError(self->mod, GetRef<Stmt>(v->stmt),
+                                    
LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt);
+        }
+        bottom = loop_sref;
+        break;
+      }
+      // Case 3. Add `v` into `visited`
+      visited.insert(v);
+      // If it's the first traversal and the loop corresponding to `v` is in 
the input array,
+      // update `top`.
+      if (first_traversal && loop_srefs.count(v)) {
+        top = v;
+      }
+    }
+    first_traversal = false;
+  }
+  return std::make_pair(top, bottom);
+}
+
+/*!
+ * \brief get all the loops in the reorder range
+ * \param self The schedule state
+ * \param top The top boundary of the reorder range
+ * \param bottom The bottom boundary of the reorder range
+ * \return an array containing all the loops in the reorder range
+ * \throws ScheduleError If some loop in the reorder range is not single-branch
+ */
+std::vector<const StmtSRefNode*> GetLoopsInReorderRange(const ScheduleState& 
self,
+                                                        const StmtSRefNode* 
top,
+                                                        const StmtSRefNode* 
bottom) {
+  std::vector<const StmtSRefNode*> chain;
+  for (const StmtSRefNode* loop_sref = bottom; loop_sref != top;) {
+    const StmtSRefNode* parent_loop_sref = loop_sref->parent;
+    const ForNode* outer = parent_loop_sref->StmtAs<ForNode>();
+    const ForNode* inner = loop_sref->StmtAs<ForNode>();
+    ICHECK(outer != nullptr && inner != nullptr);
+    if (outer->body.get() != inner) {
+      throw LoopsNotAChainError(self->mod, GetRef<For>(outer),
+                                
LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt);
+    }
+    chain.push_back(loop_sref);
+    loop_sref = parent_loop_sref;
+  }
+  chain.push_back(top);
+  return chain;
+}
+
+/*!
+ * \brief Construct a loop chain in the new order
+ * \param self The schedule state
+ * \param chain The loops in the reorder range
+ * \param ordered_loop_srefs The loop srefs to be reordered
+ * \param loop_srefs The set containing loop srefs to be reordered
+ * \return the new loop chain

Review comment:
       ```suggestion
    * \return The new loop chain
   ```

##########
File path: src/tir/schedule/primitive/loop_transformation.cc
##########
@@ -384,6 +510,182 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& 
loop_srefs) {
   self->Replace(loop_srefs[0], new_stmt, opaque_block_reuse);
   return self->stmt2ref.at(new_stmt.get());
 }
+/*!
+ * \brief collect an array of loop srefs into a set

Review comment:
       ```suggestion
    * \brief Collect an array of loop srefs into a set
   ```

##########
File path: src/tir/schedule/primitive/loop_transformation.cc
##########
@@ -384,6 +510,182 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& 
loop_srefs) {
   self->Replace(loop_srefs[0], new_stmt, opaque_block_reuse);
   return self->stmt2ref.at(new_stmt.get());
 }
+/*!
+ * \brief collect an array of loop srefs into a set
+ * \param self The schedule state
+ * \param ordered_loop_srefs The array of loop srefs
+ * \return A set containing all loops in the array
+ * \throws ScheduleError If there are duplicate loops in the array
+ */
+std::unordered_set<const StmtSRefNode*> CollectLoopsIntoSet(
+    const ScheduleState& self, const Array<StmtSRef>& ordered_loop_srefs) {
+  std::unordered_set<const StmtSRefNode*> loop_srefs;
+  loop_srefs.reserve(ordered_loop_srefs.size());
+  for (const StmtSRef& loop_sref : ordered_loop_srefs) {
+    auto inserted = loop_srefs.insert(loop_sref.get());
+    if (!inserted.second) {
+      const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+      throw LoopMultiAppearanceError(self->mod, GetRef<For>(loop));
+    }
+  }
+  return loop_srefs;
+}
+
+/*!
+ * \brief Get the top and bottom boundary of reorder range (which should be a 
chain)
+ * \param self The schedule state
+ * \param loop_srefs The set containing the srefs to the loops to be reordered
+ * \return a pair containing the top and bottom boundary of the reorder range
+ * \throws ScheduleError If the loops to be reordered is not in a chain
+ */
+std::pair<const StmtSRefNode*, const StmtSRefNode*> GetBoundaryOfReorderRange(
+    const ScheduleState& self, const std::unordered_set<const StmtSRefNode*>& 
loop_srefs) {
+  const StmtSRefNode* top = nullptr;
+  const StmtSRefNode* bottom = *loop_srefs.begin();
+  std::unordered_set<const StmtSRefNode*> visited;
+  bool scope_block_visited = false;
+  bool first_traversal = true;
+  for (const StmtSRefNode* loop_sref : loop_srefs) {
+    if (visited.count(loop_sref)) {
+      continue;
+    }
+    for (const StmtSRefNode* v = loop_sref;; v = v->parent) {
+      // Case 1. If `v` corresponds to a block, stop traversal.
+      if (v->stmt->IsInstance<BlockNode>()) {
+        if (scope_block_visited) {
+          throw LoopsNotAChainError(self->mod, NullOpt,
+                                    
LoopsNotAChainError::ProblemKind::kNotUnderAScope);
+        }
+        scope_block_visited = true;
+        break;
+      }
+      // Case 2. If `v` corresponds to a previously-visited loop, stop 
traversal and update
+      // `bottom`.
+      if (visited.count(v)) {
+        if (v != bottom) {
+          throw LoopsNotAChainError(self->mod, GetRef<Stmt>(v->stmt),
+                                    
LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt);
+        }
+        bottom = loop_sref;
+        break;
+      }
+      // Case 3. Add `v` into `visited`
+      visited.insert(v);
+      // If it's the first traversal and the loop corresponding to `v` is in 
the input array,
+      // update `top`.
+      if (first_traversal && loop_srefs.count(v)) {
+        top = v;
+      }
+    }
+    first_traversal = false;
+  }
+  return std::make_pair(top, bottom);
+}
+
+/*!
+ * \brief get all the loops in the reorder range

Review comment:
       ```suggestion
    * \brief Get all the loops in the reorder range
   ```

##########
File path: src/tir/schedule/primitive/loop_transformation.cc
##########
@@ -384,6 +510,182 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& 
loop_srefs) {
   self->Replace(loop_srefs[0], new_stmt, opaque_block_reuse);
   return self->stmt2ref.at(new_stmt.get());
 }
+/*!
+ * \brief collect an array of loop srefs into a set
+ * \param self The schedule state
+ * \param ordered_loop_srefs The array of loop srefs
+ * \return A set containing all loops in the array
+ * \throws ScheduleError If there are duplicate loops in the array
+ */
+std::unordered_set<const StmtSRefNode*> CollectLoopsIntoSet(
+    const ScheduleState& self, const Array<StmtSRef>& ordered_loop_srefs) {
+  std::unordered_set<const StmtSRefNode*> loop_srefs;
+  loop_srefs.reserve(ordered_loop_srefs.size());
+  for (const StmtSRef& loop_sref : ordered_loop_srefs) {
+    auto inserted = loop_srefs.insert(loop_sref.get());
+    if (!inserted.second) {
+      const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+      throw LoopMultiAppearanceError(self->mod, GetRef<For>(loop));
+    }
+  }
+  return loop_srefs;
+}
+
+/*!
+ * \brief Get the top and bottom boundary of reorder range (which should be a 
chain)
+ * \param self The schedule state
+ * \param loop_srefs The set containing the srefs to the loops to be reordered
+ * \return a pair containing the top and bottom boundary of the reorder range
+ * \throws ScheduleError If the loops to be reordered is not in a chain
+ */
+std::pair<const StmtSRefNode*, const StmtSRefNode*> GetBoundaryOfReorderRange(
+    const ScheduleState& self, const std::unordered_set<const StmtSRefNode*>& 
loop_srefs) {
+  const StmtSRefNode* top = nullptr;
+  const StmtSRefNode* bottom = *loop_srefs.begin();
+  std::unordered_set<const StmtSRefNode*> visited;
+  bool scope_block_visited = false;
+  bool first_traversal = true;
+  for (const StmtSRefNode* loop_sref : loop_srefs) {
+    if (visited.count(loop_sref)) {
+      continue;
+    }
+    for (const StmtSRefNode* v = loop_sref;; v = v->parent) {
+      // Case 1. If `v` corresponds to a block, stop traversal.
+      if (v->stmt->IsInstance<BlockNode>()) {
+        if (scope_block_visited) {
+          throw LoopsNotAChainError(self->mod, NullOpt,
+                                    
LoopsNotAChainError::ProblemKind::kNotUnderAScope);
+        }
+        scope_block_visited = true;
+        break;
+      }
+      // Case 2. If `v` corresponds to a previously-visited loop, stop 
traversal and update
+      // `bottom`.
+      if (visited.count(v)) {
+        if (v != bottom) {
+          throw LoopsNotAChainError(self->mod, GetRef<Stmt>(v->stmt),
+                                    
LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt);
+        }
+        bottom = loop_sref;
+        break;
+      }
+      // Case 3. Add `v` into `visited`
+      visited.insert(v);
+      // If it's the first traversal and the loop corresponding to `v` is in 
the input array,
+      // update `top`.
+      if (first_traversal && loop_srefs.count(v)) {
+        top = v;
+      }
+    }
+    first_traversal = false;
+  }
+  return std::make_pair(top, bottom);
+}
+
+/*!
+ * \brief get all the loops in the reorder range
+ * \param self The schedule state
+ * \param top The top boundary of the reorder range
+ * \param bottom The bottom boundary of the reorder range
+ * \return an array containing all the loops in the reorder range
+ * \throws ScheduleError If some loop in the reorder range is not single-branch
+ */
+std::vector<const StmtSRefNode*> GetLoopsInReorderRange(const ScheduleState& 
self,
+                                                        const StmtSRefNode* 
top,
+                                                        const StmtSRefNode* 
bottom) {
+  std::vector<const StmtSRefNode*> chain;
+  for (const StmtSRefNode* loop_sref = bottom; loop_sref != top;) {
+    const StmtSRefNode* parent_loop_sref = loop_sref->parent;
+    const ForNode* outer = parent_loop_sref->StmtAs<ForNode>();
+    const ForNode* inner = loop_sref->StmtAs<ForNode>();
+    ICHECK(outer != nullptr && inner != nullptr);
+    if (outer->body.get() != inner) {
+      throw LoopsNotAChainError(self->mod, GetRef<For>(outer),
+                                
LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt);
+    }
+    chain.push_back(loop_sref);
+    loop_sref = parent_loop_sref;
+  }
+  chain.push_back(top);
+  return chain;
+}
+
+/*!
+ * \brief Construct a loop chain in the new order
+ * \param self The schedule state
+ * \param chain The loops in the reorder range
+ * \param ordered_loop_srefs The loop srefs to be reordered
+ * \param loop_srefs The set containing loop srefs to be reordered
+ * \return the new loop chain
+ * \throws ScheduleError If the domain of an outer loop depends on any of the 
inner loops after
+ * reordering
+ */
+For ConstructNewLoopChain(const ScheduleState& self, std::vector<const 
StmtSRefNode*> chain,
+                          const Array<StmtSRef>& ordered_loop_srefs,
+                          const std::unordered_set<const StmtSRefNode*>& 
loop_srefs) {
+  std::unordered_set<const VarNode*> inner_vars;
+  inner_vars.reserve(chain.size());
+  For new_loop{nullptr};
+  int index = static_cast<int>(ordered_loop_srefs.size()) - 1;
+  for (const StmtSRefNode* loop_sref : chain) {
+    const ForNode* copy = nullptr;
+    if (loop_srefs.count(loop_sref)) {
+      copy = ordered_loop_srefs[index]->StmtAs<ForNode>();
+      --index;
+    } else {
+      copy = loop_sref->StmtAs<ForNode>();
+    }
+    ICHECK(copy != nullptr);
+    ObjectPtr<ForNode> n = make_object<ForNode>(*copy);
+    if (new_loop.defined()) {
+      n->body = new_loop;
+    } else {
+      n->body = loop_sref->StmtAs<ForNode>()->body;
+    }
+    const VarNode* used_var = nullptr;
+    auto f_contain = [&inner_vars, &used_var](const VarNode* var) {
+      if (inner_vars.count(var)) {
+        used_var = var;
+        return true;
+      }
+      return false;
+    };
+    if (UsesVar(copy->min, f_contain) || UsesVar(copy->extent, f_contain)) {
+      throw DependentLoopError(self->mod, GetRef<For>(copy), 
used_var->name_hint);
+    }
+    inner_vars.insert(copy->loop_var.get());
+    new_loop = For(std::move(n));
+  }
+  return new_loop;
+}
+
+void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) {
+  if (ordered_loop_srefs.size() <= 1) {
+    return;
+  }
+  // Step 1. Check uniqueness.and collect the input loop srefs into a set

Review comment:
       ```suggestion
     // Step 1. Check uniqueness and collect the input loop srefs into a set
   ```

##########
File path: src/tir/schedule/primitive/loop_transformation.cc
##########
@@ -384,6 +510,182 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& 
loop_srefs) {
   self->Replace(loop_srefs[0], new_stmt, opaque_block_reuse);
   return self->stmt2ref.at(new_stmt.get());
 }
+/*!
+ * \brief collect an array of loop srefs into a set
+ * \param self The schedule state
+ * \param ordered_loop_srefs The array of loop srefs
+ * \return A set containing all loops in the array
+ * \throws ScheduleError If there are duplicate loops in the array
+ */
+std::unordered_set<const StmtSRefNode*> CollectLoopsIntoSet(
+    const ScheduleState& self, const Array<StmtSRef>& ordered_loop_srefs) {
+  std::unordered_set<const StmtSRefNode*> loop_srefs;
+  loop_srefs.reserve(ordered_loop_srefs.size());
+  for (const StmtSRef& loop_sref : ordered_loop_srefs) {
+    auto inserted = loop_srefs.insert(loop_sref.get());
+    if (!inserted.second) {
+      const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+      throw LoopMultiAppearanceError(self->mod, GetRef<For>(loop));
+    }
+  }
+  return loop_srefs;
+}
+
+/*!
+ * \brief Get the top and bottom boundary of reorder range (which should be a 
chain)
+ * \param self The schedule state
+ * \param loop_srefs The set containing the srefs to the loops to be reordered
+ * \return a pair containing the top and bottom boundary of the reorder range
+ * \throws ScheduleError If the loops to be reordered is not in a chain
+ */
+std::pair<const StmtSRefNode*, const StmtSRefNode*> GetBoundaryOfReorderRange(
+    const ScheduleState& self, const std::unordered_set<const StmtSRefNode*>& 
loop_srefs) {
+  const StmtSRefNode* top = nullptr;
+  const StmtSRefNode* bottom = *loop_srefs.begin();
+  std::unordered_set<const StmtSRefNode*> visited;
+  bool scope_block_visited = false;
+  bool first_traversal = true;
+  for (const StmtSRefNode* loop_sref : loop_srefs) {
+    if (visited.count(loop_sref)) {
+      continue;
+    }
+    for (const StmtSRefNode* v = loop_sref;; v = v->parent) {
+      // Case 1. If `v` corresponds to a block, stop traversal.
+      if (v->stmt->IsInstance<BlockNode>()) {
+        if (scope_block_visited) {
+          throw LoopsNotAChainError(self->mod, NullOpt,
+                                    
LoopsNotAChainError::ProblemKind::kNotUnderAScope);
+        }
+        scope_block_visited = true;
+        break;
+      }
+      // Case 2. If `v` corresponds to a previously-visited loop, stop 
traversal and update
+      // `bottom`.
+      if (visited.count(v)) {
+        if (v != bottom) {
+          throw LoopsNotAChainError(self->mod, GetRef<Stmt>(v->stmt),
+                                    
LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt);
+        }
+        bottom = loop_sref;
+        break;
+      }
+      // Case 3. Add `v` into `visited`
+      visited.insert(v);
+      // If it's the first traversal and the loop corresponding to `v` is in 
the input array,
+      // update `top`.
+      if (first_traversal && loop_srefs.count(v)) {
+        top = v;
+      }
+    }
+    first_traversal = false;
+  }
+  return std::make_pair(top, bottom);
+}
+
+/*!
+ * \brief get all the loops in the reorder range
+ * \param self The schedule state
+ * \param top The top boundary of the reorder range
+ * \param bottom The bottom boundary of the reorder range
+ * \return an array containing all the loops in the reorder range
+ * \throws ScheduleError If some loop in the reorder range is not single-branch
+ */
+std::vector<const StmtSRefNode*> GetLoopsInReorderRange(const ScheduleState& 
self,
+                                                        const StmtSRefNode* 
top,
+                                                        const StmtSRefNode* 
bottom) {
+  std::vector<const StmtSRefNode*> chain;
+  for (const StmtSRefNode* loop_sref = bottom; loop_sref != top;) {
+    const StmtSRefNode* parent_loop_sref = loop_sref->parent;
+    const ForNode* outer = parent_loop_sref->StmtAs<ForNode>();
+    const ForNode* inner = loop_sref->StmtAs<ForNode>();
+    ICHECK(outer != nullptr && inner != nullptr);
+    if (outer->body.get() != inner) {
+      throw LoopsNotAChainError(self->mod, GetRef<For>(outer),
+                                
LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt);
+    }
+    chain.push_back(loop_sref);
+    loop_sref = parent_loop_sref;
+  }
+  chain.push_back(top);
+  return chain;
+}
+
+/*!
+ * \brief Construct a loop chain in the new order
+ * \param self The schedule state
+ * \param chain The loops in the reorder range
+ * \param ordered_loop_srefs The loop srefs to be reordered
+ * \param loop_srefs The set containing loop srefs to be reordered
+ * \return the new loop chain
+ * \throws ScheduleError If the domain of an outer loop depends on any of the 
inner loops after
+ * reordering
+ */
+For ConstructNewLoopChain(const ScheduleState& self, std::vector<const 
StmtSRefNode*> chain,
+                          const Array<StmtSRef>& ordered_loop_srefs,
+                          const std::unordered_set<const StmtSRefNode*>& 
loop_srefs) {
+  std::unordered_set<const VarNode*> inner_vars;
+  inner_vars.reserve(chain.size());
+  For new_loop{nullptr};
+  int index = static_cast<int>(ordered_loop_srefs.size()) - 1;
+  for (const StmtSRefNode* loop_sref : chain) {
+    const ForNode* copy = nullptr;
+    if (loop_srefs.count(loop_sref)) {
+      copy = ordered_loop_srefs[index]->StmtAs<ForNode>();
+      --index;
+    } else {
+      copy = loop_sref->StmtAs<ForNode>();
+    }
+    ICHECK(copy != nullptr);
+    ObjectPtr<ForNode> n = make_object<ForNode>(*copy);
+    if (new_loop.defined()) {
+      n->body = new_loop;
+    } else {
+      n->body = loop_sref->StmtAs<ForNode>()->body;
+    }
+    const VarNode* used_var = nullptr;
+    auto f_contain = [&inner_vars, &used_var](const VarNode* var) {
+      if (inner_vars.count(var)) {
+        used_var = var;
+        return true;
+      }
+      return false;
+    };
+    if (UsesVar(copy->min, f_contain) || UsesVar(copy->extent, f_contain)) {
+      throw DependentLoopError(self->mod, GetRef<For>(copy), 
used_var->name_hint);
+    }
+    inner_vars.insert(copy->loop_var.get());
+    new_loop = For(std::move(n));
+  }
+  return new_loop;
+}
+
+void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) {
+  if (ordered_loop_srefs.size() <= 1) {
+    return;
+  }
+  // Step 1. Check uniqueness.and collect the input loop srefs into a set
+  std::unordered_set<const StmtSRefNode*> loop_srefs =
+      CollectLoopsIntoSet(self, ordered_loop_srefs);
+  // Step 2. Gather loops to be reordered
+  // For each loop sref in the input sref array, traverse upwards along its 
parent pointer in the
+  // sref tree, and stop on either a block, or a previously-visited loop
+  // - the top of the reorder range is the last loop visited in the first 
traversal which exists in
+  //   the input array
+  // - the bottom of the reorder range is the last loop in the input array 
which is not visited in
+  // the previous traversals
+  auto pair = GetBoundaryOfReorderRange(self, loop_srefs);
+  const StmtSRefNode* top = pair.first;
+  const StmtSRefNode* bottom = pair.second;

Review comment:
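       This way we avoid the named intermediate `pair` variable and its `.first`/`.second` accesses; the returned pair is unpacked directly into `top` and `bottom`.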
   ```suggestion
     const StmtSRefNode* top;
     const StmtSRefNode* bottom;
     std::tie(top, bottom) = GetBoundaryOfReorderRange(self, loop_srefs);
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to