yincs-intellif commented on code in PR #14398:
URL: https://github.com/apache/tvm/pull/14398#discussion_r1155450325
##########
src/tir/schedule/primitive/loop_transformation.cc:
##########
@@ -451,6 +451,163 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref, const Array
return result_srefs;
}
+class LoopReconstructor : private StmtMutator {
+ public:
+ explicit LoopReconstructor(Block scope_root,
+ const std::vector<std::vector<const ForNode*>>&
loops)
+ : scope_root_(scope_root), loops_(loops) {}
+
+ using StmtMutator::operator();
+
+ /*!
+ * \brief Create the new nest loops induced by the given loops
+ */
+ void MakeNewLoop() {
+ Array<Var> new_loop_vars;
+ Array<PrimExpr> new_loop_extents;
+ Array<Stmt> new_stmts;
+ for (size_t i = 0; i < loops_.size(); i++) {
+ Map<Var, PrimExpr> var_map;
+ for (size_t j = 0; j < loops_[i].size(); j++) {
+ if (i == 0) {
+ Var merged_var = loops_[i][j]->loop_var.copy_with_suffix("_m");
+ new_loop_vars.push_back(merged_var);
+ new_loop_extents.push_back(loops_[i][j]->extent);
+ }
+ var_map.Set(loops_[i][j]->loop_var, new_loop_vars[j]);
+ }
+ auto new_stmt = Substitute(loops_[i][0]->body, var_map);
+ new_stmts.push_back(new_stmt);
+ this->need_remove_loop_.push_back(loops_[i].back());
+ }
+ auto new_loop = For(new_loop_vars[0], Integer(0), new_loop_extents[0],
ForKind::kSerial,
+ SeqStmt(std::move(new_stmts)));
+ this->new_inner_loop_ = new_loop;
+ for (size_t i = 1; i < new_loop_vars.size(); ++i) {
+ const Var& loop_var = new_loop_vars[i];
+ const PrimExpr& loop_extent = new_loop_extents[i];
+ new_loop = For(loop_var, Integer(0), loop_extent, ForKind::kSerial,
new_loop);
+ }
+ this->new_outer_loop_ = new_loop;
+ }
+
+ private:
+ Stmt VisitStmt_(const BlockNode* block) final {
+ if (block != scope_root_.get()) {
+ return GetRef<Block>(block);
+ }
+ return StmtMutator::VisitStmt_(block);
+ }
+
+ Stmt VisitStmt_(const ForNode* loop) final {
+ if (loop == need_remove_loop_.back()) {
+ return new_outer_loop_;
+ } else if (std::count(need_remove_loop_.begin(), need_remove_loop_.end(),
loop)) {
+ return Evaluate(0);
+ }
+ return StmtMutator::VisitStmt_(loop);
+ }
+
+ Stmt VisitStmt_(const SeqStmtNode* seq_stmt) final {
+ auto ret = Downcast<SeqStmt>(StmtMutator::VisitSeqStmt_(seq_stmt, true));
+ Array<Stmt> filtered;
+ for (Stmt stmt : ret->seq) {
+ if (!is_no_op(stmt)) {
+ filtered.push_back(std::move(stmt));
+ }
+ }
+ ret = SeqStmt(filtered);
+ if (ret->size() == 0) {
+ return Evaluate(0);
+ } else if (ret->size() == 1) {
+ return ret->seq[0];
+ } else {
+ return std::move(ret);
+ }
+ }
+
+ public:
+ /*! \brief The root block of the block scope */
+ Block scope_root_;
+ /*! \brief The given loops to be merge */
+ const std::vector<std::vector<const ForNode*>>& loops_;
+ /*! \brief The outermost new loop to replace the original loop */
+ For new_outer_loop_{nullptr};
+ /*! \brief The innermost new loop to replace the original loop */
+ For new_inner_loop_{nullptr};
+ /*! \brief The loops to be removed */
+ std::vector<const ForNode*> need_remove_loop_;
+};
+
+StmtSRef Merge(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
+ // Invariance
+ // - The total repeat number has not changed for each direct child block.
+ // - The execution order has not changed. (The block executes with the same
+ // args and the same order with before.)
+ arith::Analyzer analyzer;
+ StmtSRef scope_root_sref;
+ StmtSRef lca = GetSRefLowestCommonAncestor(loop_srefs);
+ std::vector<std::vector<const ForNode*>> lca_nest_loops;
+ // Step 1. check correctness
+ std::vector<const ForNode*> nest_loop_loops;
+ std::vector<PrimExpr> nest_loop_extents;
+ for (size_t i = 0; i < loop_srefs.size(); i++) {
+ const StmtSRef& sref = loop_srefs[i];
+ auto scope_root_sref_ = GetScopeRoot(self, sref,
/*require_stage_pipeline=*/false);
Review Comment:
This is already checked at line 581.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]