LeiWang1999 commented on code in PR #17133:
URL: https://github.com/apache/tvm/pull/17133#discussion_r1710928665


##########
src/tir/transforms/lower_cross_thread_reduction.cc:
##########
@@ -527,96 +806,805 @@ Stmt TransformReductionBlock(const BlockRealizeNode* 
realize,            //
 }
 
 /*!
- * \brief Detect cross-thread reduction pattern and then transform
+ * \brief Inject the lowered allreduce block transformed from the input 
reduction block
+ * \param realize The block-realize which contains the old reduction block
+ * \param ct_buffers The buffers to store cross-thread reduction results
+ * \param wb_buffers The buffers to store the final reduction results
+ * \param old_wb_indices The indices used to access the write-back buffers 
when storing the final
+ * reduction results into the write-back buffers
+ * \param reducer The reduction function
+ * \param combiner_lhs The LHS values of the combiner
+ * \param reduction_loops The reduction loops
  */
-class CrossThreadReductionTransformer : public StmtMutator {
- private:
-  // Check if the input block needs cross-thread reduction.
-  std::vector<const ForNode*> NeedCrossThreadReduction(const BlockRealizeNode* 
realize) {
-    // Step 0. If the block is the root block, just return.
-    if (block_stack_.empty()) {
-      return {};
-    }
+Stmt InjectReductionBlock(const BlockRealizeNode* realize,                    
//

Review Comment:
   Do you mean the trailing `//` there? I just copied them from other functions in this file.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to