MasterJH5574 commented on a change in pull request #8767: URL: https://github.com/apache/tvm/pull/8767#discussion_r693292570
########## File path: src/tir/schedule/primitive/loop_transformation.cc ########## @@ -385,6 +511,113 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) { return self->stmt2ref.at(new_stmt.get()); } +void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) { + std::unordered_set<const StmtSRefNode*> loop_srefs; + loop_srefs.reserve(ordered_loop_srefs.size()); + if (ordered_loop_srefs.empty() || ordered_loop_srefs.size() == 1) { + return; + } + // Step 1. check uniqueness + for (const StmtSRef& loop_sref : ordered_loop_srefs) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + // uniqueness check + auto inserted = loop_srefs.insert(loop_sref.get()); + if (!inserted.second) { + throw LoopMultiAppearanceError(self->mod, GetRef<For>(loop)); + } + } + // Step 2. gather loops to be reordered + // For each loop, traverse upwards along the parent pointer, and stop on either a block, or a + // previously-visited loop + // - the top of the reorder range is the last loop visited in the first traverse which exists in + // the input array + // - the bottom of the reorder range is the last loop in the input array which is not visited in + // the previous traverses + const StmtSRefNode* top = nullptr; + const StmtSRefNode* bottom = nullptr; + // Maps a parent sref to its child sref + std::unordered_map<const StmtSRefNode*, const StmtSRefNode*> successor; + for (size_t i = 0; i < ordered_loop_srefs.size(); i++) { + const StmtSRefNode* sref = ordered_loop_srefs[i].get(); + // if sref is not visited before, update `bottom` + if (!successor.count(sref->parent)) { + bottom = sref; + } + while (true) { + // stop at blocknode + if (sref->stmt->IsInstance<BlockNode>()) { + if (i != 0) { + throw LoopsNotAChainError(self->mod, NullOpt, + LoopsNotAChainError::ProblemKind::kNotUnderAScope); + } else { + break; + } + } + const StmtSRefNode* parent_sref = sref->parent; + // stop at previously-visited loop + if (successor.count(parent_sref)) { + 
if (successor[parent_sref] == sref) { + break; + } else { + throw LoopsNotAChainError(self->mod, GetRef<Stmt>(parent_sref->stmt), + LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt); + } + } else { + successor[parent_sref] = sref; + } + // if it's the first traverse and the loop is in the input array, update `top` + if (loop_srefs.count(sref) && i == 0) { + top = sref; + } + sref = parent_sref; + } Review comment: * Avoid nesting too much. * List the different cases clearly. ```suggestion while (true) { // Case 1. If `sref` corresponds to a block, stop traversal. if (sref->stmt->IsInstance<BlockNode>()) { if (i != 0) { throw LoopsNotAChainError(self->mod, NullOpt, LoopsNotAChainError::ProblemKind::kNotUnderAScope); } break; } ICHECK(sref->stmt->IsInstance<ForNode>()); const StmtSRefNode* parent_sref = sref->parent; // Case 2. If `sref` corresponds to a previously-visited loop, stop traversal. if (successor.count(parent_sref)) { if (successor[parent_sref] != sref) { throw LoopsNotAChainError(self->mod, GetRef<Stmt>(parent_sref->stmt), LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt); } break; } // Case 3. Mark `sref` as visited by setting the sref successor of its parent sref. successor[parent_sref] = sref; // If it's the first traversal and the loop corresponding to `sref` is in the input array, update `top`. if (loop_srefs.count(sref) && i == 0) { top = sref; } sref = parent_sref; } ``` ########## File path: src/tir/schedule/primitive/loop_transformation.cc ########## @@ -385,6 +511,113 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) { return self->stmt2ref.at(new_stmt.get()); } +void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) { + std::unordered_set<const StmtSRefNode*> loop_srefs; + loop_srefs.reserve(ordered_loop_srefs.size()); + if (ordered_loop_srefs.empty() || ordered_loop_srefs.size() == 1) { + return; + } + // Step 1. 
check uniqueness + for (const StmtSRef& loop_sref : ordered_loop_srefs) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + // uniqueness check + auto inserted = loop_srefs.insert(loop_sref.get()); + if (!inserted.second) { + throw LoopMultiAppearanceError(self->mod, GetRef<For>(loop)); + } + } + // Step 2. gather loops to be reordered Review comment: I wonder whether it's possible to move step 2 to a separate function. @junrushao1994 What do you think? ########## File path: src/tir/schedule/primitive.h ########## @@ -63,6 +63,21 @@ TVM_DLL Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref, * \return The sref to the fused loop */ TVM_DLL StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs); +/*! + * \brief Reorder a list of loops. It doesn't require the loops to be consecutive. + * It requires: + * 1) The loops are in the same chain. That means: the loops can be ordered to [l_1, l_2, ... , + * l_n] where l_i is an ancestor of l_{i+1} and there are only single-branch loops between + * l_1 and l_n (which also indicates they are under the same scope). + * 2) After reordering, the domain of an outer loop cannot depend on any of the inner loops Review comment: ```suggestion * 2) After reordering, the domain of an outer loop cannot depend on any of the inner loops. ``` ########## File path: python/tvm/tir/schedule/schedule.py ########## @@ -442,6 +442,65 @@ def after_split(a: ty.handle, b: ty.handle) -> None: # that there is at most one None in `factors` return _ffi_api.ScheduleSplit(self, loop, factors) # type: ignore # pylint: disable=no-member + def reorder(self, *ordered_loops: List[LoopRV]) -> None: + """ + Reorder a list of loops. It doesn't require the loops to be consecutive. + It requires: + 1) The loops are in the same chain. That means: the loops can be ordered to [l_1, l_2, ...
, + l_n] where l_i is an ancestor of l_{i+1} and there are only single-branch loops between + l_1 and l_n (which also indicates they are under the same scope). + 2) After reordering, the domain of an outer loop cannot depend on any of the inner loops Review comment: ```suggestion 2) After reordering, the domain of an outer loop cannot depend on any of the inner loops. ``` ########## File path: include/tvm/tir/schedule/schedule.h ########## @@ -219,6 +219,19 @@ class ScheduleNode : public runtime::Object { * \return The new loops after split */ virtual Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) = 0; + /*! + * \brief Reorder a list of loops. It doesn't require the loops to be consecutive. + * It requires: + * 1) The loops are in the same chain. That means: the loops can be ordered to [l_1, l_2, ... , + * l_n] where l_i is an ancestor of l_{i+1} and there are only single-branch loops between + * l_1 and l_n (which also indicates they are under the same scope). + * 2) After reordering, the domain of an outer loop cannot depend on any of the inner loops Review comment: ```suggestion * 2) After reordering, the domain of an outer loop cannot depend on any of the inner loops. ``` ########## File path: src/tir/schedule/primitive/loop_transformation.cc ########## @@ -385,6 +511,113 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) { return self->stmt2ref.at(new_stmt.get()); } +void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) { + std::unordered_set<const StmtSRefNode*> loop_srefs; + loop_srefs.reserve(ordered_loop_srefs.size()); + if (ordered_loop_srefs.empty() || ordered_loop_srefs.size() == 1) { + return; + } + // Step 1. 
check uniqueness + for (const StmtSRef& loop_sref : ordered_loop_srefs) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + // uniqueness check + auto inserted = loop_srefs.insert(loop_sref.get()); + if (!inserted.second) { + throw LoopMultiAppearanceError(self->mod, GetRef<For>(loop)); + } + } + // Step 2. gather loops to be reordered + // For each loop, traverse upwards along the parent pointer, and stop on either a block, or a + // previously-visited loop + // - the top of the reorder range is the last loop visited in the first traverse which exists in + // the input array + // - the bottom of the reorder range is the last loop in the input array which is not visited in + // the previous traverses + const StmtSRefNode* top = nullptr; + const StmtSRefNode* bottom = nullptr; + // Maps a parent sref to its child sref + std::unordered_map<const StmtSRefNode*, const StmtSRefNode*> successor; + for (size_t i = 0; i < ordered_loop_srefs.size(); i++) { + const StmtSRefNode* sref = ordered_loop_srefs[i].get(); + // if sref is not visited before, update `bottom` + if (!successor.count(sref->parent)) { + bottom = sref; + } Review comment: ```suggestion // If `sref` is not visited before, update `bottom`. if (!successor.count(sref->parent)) { bottom = sref; } ``` ########## File path: src/tir/schedule/primitive/loop_transformation.cc ########## @@ -385,6 +511,113 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) { return self->stmt2ref.at(new_stmt.get()); } +void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) { + std::unordered_set<const StmtSRefNode*> loop_srefs; + loop_srefs.reserve(ordered_loop_srefs.size()); + if (ordered_loop_srefs.empty() || ordered_loop_srefs.size() == 1) { + return; + } + // Step 1. 
check uniqueness + for (const StmtSRef& loop_sref : ordered_loop_srefs) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + // uniqueness check + auto inserted = loop_srefs.insert(loop_sref.get()); + if (!inserted.second) { + throw LoopMultiAppearanceError(self->mod, GetRef<For>(loop)); + } + } + // Step 2. gather loops to be reordered Review comment: ```suggestion // Step 2. Gather loops to be reordered ``` ########## File path: src/tir/schedule/primitive/loop_transformation.cc ########## @@ -385,6 +511,113 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) { return self->stmt2ref.at(new_stmt.get()); } +void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) { + std::unordered_set<const StmtSRefNode*> loop_srefs; + loop_srefs.reserve(ordered_loop_srefs.size()); + if (ordered_loop_srefs.empty() || ordered_loop_srefs.size() == 1) { + return; + } + // Step 1. check uniqueness + for (const StmtSRef& loop_sref : ordered_loop_srefs) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + // uniqueness check + auto inserted = loop_srefs.insert(loop_sref.get()); + if (!inserted.second) { + throw LoopMultiAppearanceError(self->mod, GetRef<For>(loop)); + } + } + // Step 2. 
gather loops to be reordered + // For each loop, traverse upwards along the parent pointer, and stop on either a block, or a + // previously-visited loop + // - the top of the reorder range is the last loop visited in the first traverse which exists in + // the input array + // - the bottom of the reorder range is the last loop in the input array which is not visited in + // the previous traverses Review comment: ```suggestion // - the top of the reorder range is the last loop visited in the first traversal which exists in // the input array // - the bottom of the reorder range is the last loop in the input array which is not visited in // the previous traversals ``` ########## File path: src/tir/schedule/primitive/loop_transformation.cc ########## @@ -385,6 +511,113 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) { return self->stmt2ref.at(new_stmt.get()); } +void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) { + std::unordered_set<const StmtSRefNode*> loop_srefs; + loop_srefs.reserve(ordered_loop_srefs.size()); + if (ordered_loop_srefs.empty() || ordered_loop_srefs.size() == 1) { + return; + } + // Step 1. check uniqueness + for (const StmtSRef& loop_sref : ordered_loop_srefs) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + // uniqueness check + auto inserted = loop_srefs.insert(loop_sref.get()); + if (!inserted.second) { + throw LoopMultiAppearanceError(self->mod, GetRef<For>(loop)); + } + } + // Step 2. 
gather loops to be reordered + // For each loop, traverse upwards along the parent pointer, and stop on either a block, or a + // previously-visited loop + // - the top of the reorder range is the last loop visited in the first traverse which exists in + // the input array + // - the bottom of the reorder range is the last loop in the input array which is not visited in + // the previous traverses + const StmtSRefNode* top = nullptr; + const StmtSRefNode* bottom = nullptr; + // Maps a parent sref to its child sref + std::unordered_map<const StmtSRefNode*, const StmtSRefNode*> successor; + for (size_t i = 0; i < ordered_loop_srefs.size(); i++) { + const StmtSRefNode* sref = ordered_loop_srefs[i].get(); + // if sref is not visited before, update `bottom` + if (!successor.count(sref->parent)) { + bottom = sref; + } + while (true) { + // stop at blocknode + if (sref->stmt->IsInstance<BlockNode>()) { + if (i != 0) { + throw LoopsNotAChainError(self->mod, NullOpt, + LoopsNotAChainError::ProblemKind::kNotUnderAScope); + } else { + break; + } + } + const StmtSRefNode* parent_sref = sref->parent; + // stop at previously-visited loop + if (successor.count(parent_sref)) { + if (successor[parent_sref] == sref) { + break; + } else { + throw LoopsNotAChainError(self->mod, GetRef<Stmt>(parent_sref->stmt), + LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt); + } + } else { + successor[parent_sref] = sref; + } + // if it's the first traverse and the loop is in the input array, update `top` + if (loop_srefs.count(sref) && i == 0) { + top = sref; + } + sref = parent_sref; + } + } + // Step 3. 
Check that loops are single-branch + const ForNode* outer_loop = TVM_SREF_TO_FOR(outer_loop, GetRef<StmtSRef>(top)); + for (const StmtSRefNode* loop_sref = top; loop_sref != bottom;) { + loop_sref = successor[loop_sref]; + const ForNode* inner_loop = TVM_SREF_TO_FOR(inner_loop, GetRef<StmtSRef>(loop_sref)); + if (outer_loop->body.get() != inner_loop) { + throw LoopsNotAChainError(self->mod, GetRef<For>(outer_loop), + LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt); + } + outer_loop = inner_loop; + } + // Step 4. Check the block below has all its block_var to be data-parallel or reduction Review comment: ```suggestion // Step 4. Check the block below has all its block_var to be data-parallel or reduction, and the block has an affine binding. ``` ########## File path: src/tir/schedule/primitive/loop_transformation.cc ########## @@ -385,6 +511,113 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) { return self->stmt2ref.at(new_stmt.get()); } +void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) { + std::unordered_set<const StmtSRefNode*> loop_srefs; + loop_srefs.reserve(ordered_loop_srefs.size()); + if (ordered_loop_srefs.empty() || ordered_loop_srefs.size() == 1) { + return; + } + // Step 1. check uniqueness + for (const StmtSRef& loop_sref : ordered_loop_srefs) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + // uniqueness check + auto inserted = loop_srefs.insert(loop_sref.get()); + if (!inserted.second) { + throw LoopMultiAppearanceError(self->mod, GetRef<For>(loop)); + } + } + // Step 2. 
gather loops to be reordered + // For each loop, traverse upwards along the parent pointer, and stop on either a block, or a + // previously-visited loop Review comment: ```suggestion // For each loop sref in the input sref array, traverse upwards along its parent pointer in the sref tree, and stop on either a block, or a // previously-visited loop ``` ########## File path: src/tir/schedule/primitive/loop_transformation.cc ########## @@ -385,6 +511,113 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) { return self->stmt2ref.at(new_stmt.get()); } +void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) { + std::unordered_set<const StmtSRefNode*> loop_srefs; + loop_srefs.reserve(ordered_loop_srefs.size()); + if (ordered_loop_srefs.empty() || ordered_loop_srefs.size() == 1) { + return; + } Review comment: ```suggestion if (ordered_loop_srefs.empty() || ordered_loop_srefs.size() == 1) { return; } std::unordered_set<const StmtSRefNode*> loop_srefs; loop_srefs.reserve(ordered_loop_srefs.size()); ``` ########## File path: src/tir/schedule/primitive/loop_transformation.cc ########## @@ -385,6 +511,113 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) { return self->stmt2ref.at(new_stmt.get()); } +void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) { + std::unordered_set<const StmtSRefNode*> loop_srefs; + loop_srefs.reserve(ordered_loop_srefs.size()); + if (ordered_loop_srefs.empty() || ordered_loop_srefs.size() == 1) { + return; + } + // Step 1. check uniqueness + for (const StmtSRef& loop_sref : ordered_loop_srefs) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + // uniqueness check Review comment: I think we don't need this line. 
```suggestion ``` ########## File path: src/tir/schedule/primitive/loop_transformation.cc ########## @@ -385,6 +511,113 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) { return self->stmt2ref.at(new_stmt.get()); } +void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) { + std::unordered_set<const StmtSRefNode*> loop_srefs; + loop_srefs.reserve(ordered_loop_srefs.size()); + if (ordered_loop_srefs.empty() || ordered_loop_srefs.size() == 1) { + return; + } + // Step 1. check uniqueness Review comment: ```suggestion // Step 1. Check uniqueness. ``` ########## File path: src/tir/schedule/primitive/loop_transformation.cc ########## @@ -385,6 +511,113 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) { return self->stmt2ref.at(new_stmt.get()); } +void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) { + std::unordered_set<const StmtSRefNode*> loop_srefs; + loop_srefs.reserve(ordered_loop_srefs.size()); + if (ordered_loop_srefs.empty() || ordered_loop_srefs.size() == 1) { + return; + } + // Step 1. check uniqueness + for (const StmtSRef& loop_sref : ordered_loop_srefs) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + // uniqueness check + auto inserted = loop_srefs.insert(loop_sref.get()); + if (!inserted.second) { + throw LoopMultiAppearanceError(self->mod, GetRef<For>(loop)); + } + } + // Step 2. 
gather loops to be reordered + // For each loop, traverse upwards along the parent pointer, and stop on either a block, or a + // previously-visited loop + // - the top of the reorder range is the last loop visited in the first traverse which exists in + // the input array + // - the bottom of the reorder range is the last loop in the input array which is not visited in + // the previous traverses + const StmtSRefNode* top = nullptr; + const StmtSRefNode* bottom = nullptr; + // Maps a parent sref to its child sref + std::unordered_map<const StmtSRefNode*, const StmtSRefNode*> successor; + for (size_t i = 0; i < ordered_loop_srefs.size(); i++) { + const StmtSRefNode* sref = ordered_loop_srefs[i].get(); + // if sref is not visited before, update `bottom` + if (!successor.count(sref->parent)) { + bottom = sref; + } + while (true) { + // stop at blocknode + if (sref->stmt->IsInstance<BlockNode>()) { + if (i != 0) { + throw LoopsNotAChainError(self->mod, NullOpt, + LoopsNotAChainError::ProblemKind::kNotUnderAScope); + } else { + break; + } + } + const StmtSRefNode* parent_sref = sref->parent; + // stop at previously-visited loop + if (successor.count(parent_sref)) { + if (successor[parent_sref] == sref) { + break; + } else { + throw LoopsNotAChainError(self->mod, GetRef<Stmt>(parent_sref->stmt), + LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt); + } + } else { + successor[parent_sref] = sref; + } + // if it's the first traverse and the loop is in the input array, update `top` + if (loop_srefs.count(sref) && i == 0) { + top = sref; + } + sref = parent_sref; + } + } + // Step 3. 
Check that loops are single-branch + const ForNode* outer_loop = TVM_SREF_TO_FOR(outer_loop, GetRef<StmtSRef>(top)); + for (const StmtSRefNode* loop_sref = top; loop_sref != bottom;) { + loop_sref = successor[loop_sref]; + const ForNode* inner_loop = TVM_SREF_TO_FOR(inner_loop, GetRef<StmtSRef>(loop_sref)); + if (outer_loop->body.get() != inner_loop) { + throw LoopsNotAChainError(self->mod, GetRef<For>(outer_loop), + LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt); + } + outer_loop = inner_loop; + } + // Step 4. Check the block below has all its block_var to be data-parallel or reduction + BlockPropertyError::CheckBlockIterTypeAndAffineBinding(self, bottom); + // Step 5. Replace the original loops with the reordered loops and check that outer loop is + // not dependent on inner loop + std::unordered_set<const VarNode*> inner_vars; + std::function<Stmt(const StmtSRefNode*, int index)> f_reorder = + [&bottom, &loop_srefs, &successor, &ordered_loop_srefs, &inner_vars, &self, &f_reorder]( + const StmtSRefNode* loop, int index) -> Stmt { + const ForNode* copy = loop_srefs.count(loop) ? ordered_loop_srefs[index++]->StmtAs<ForNode>() + : loop->StmtAs<ForNode>(); + ObjectPtr<ForNode> n = make_object<ForNode>(*copy); + if (loop == bottom) { + // stop recursion at bottom loop Review comment: ```suggestion // Stop recursion at the bottom loop ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org