junrushao1994 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r717228066
##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
namespace tvm {
namespace tir {
+/*!
+ * \brief A helper class to create a new scope that contains the decomposed init body
+ * and the replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+ /*!
+ * \brief The open interface to users to call the helper class
+ * \param old_scope_root The original block scope before decomposition
+ * \param target_loop The loop we insert the decomposed init body before
+ * \param decomposed_body The decomposed init body
+ * \param old_reduction_block The reduction block we want to decompose
+ * \return The new block scope and the updated reduction block
+ */
+ static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+ Stmt decomposed_body, Block
old_reduction_block) {
+ DecomposeReductionBlockReplacer replacer(std::move(target_loop),
std::move(decomposed_body),
+ std::move(old_reduction_block));
+ return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+ replacer.new_reduction_block_);
+ }
+
+ private:
+ explicit DecomposeReductionBlockReplacer(For target_loop, Stmt
decomposed_body,
+ Block old_reduction_block)
+ : target_loop_(std::move(target_loop)),
+ decomposed_body_(std::move(decomposed_body)),
+ old_reduction_block_(std::move(old_reduction_block)) {}
+
+ Stmt VisitStmt_(const ForNode* loop) final {
+ Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+ if (loop == target_loop_.get()) {
+ return SeqStmt({decomposed_body_, mutated_stmt});
+ } else {
+ return mutated_stmt;
+ }
+ }
+
+ Stmt VisitStmt_(const BlockNode* block) final {
+ if (block == old_reduction_block_.get()) {
+ ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+ p_new_block->name_hint = p_new_block->name_hint + "_update";
+ p_new_block->init = NullOpt;
+ new_reduction_block_ = Block(p_new_block);
+ return new_reduction_block_;
+ } else {
+ return StmtMutator::VisitStmt_(block);
+ }
+ }
+
+ Stmt VisitStmt_(const SeqStmtNode* seq) final {
+ Array<Stmt> new_stmts;
+ new_stmts.reserve(seq->seq.size());
+ for (const Stmt& old_stmt : seq->seq) {
+ new_stmts.push_back(VisitStmt(old_stmt));
+ }
+ return SeqStmt::Flatten(new_stmts);
+ }
Review comment:
Why do we need special handling of SeqStmt btw?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]