wrongtest-intellif commented on code in PR #17133:
URL: https://github.com/apache/tvm/pull/17133#discussion_r1714550534


##########
src/tir/transforms/lower_cross_thread_reduction.cc:
##########
@@ -527,96 +806,805 @@ Stmt TransformReductionBlock(const BlockRealizeNode* 
realize,            //
 }
 
 /*!
- * \brief Detect cross-thread reduction pattern and then transform
+ * \brief Inject the lowered allreduce block transformed from the input 
reduction block
+ * \param realize The block-realize which contains the old reduction block
+ * \param ct_buffers The buffers to store cross-thread reduction results
+ * \param wb_buffers The buffers to store the final reduction results
+ * \param old_wb_indices The indices used to access the write-back buffers 
when storing the final
+ * reduction results into the write-back buffers
+ * \param reducer The reduction function
+ * \param combiner_lhs The LHS values of the combiner
+ * \param reduction_loops The reduction loops
  */
-class CrossThreadReductionTransformer : public StmtMutator {
- private:
-  // Check if the input block needs cross-thread reduction.
-  std::vector<const ForNode*> NeedCrossThreadReduction(const BlockRealizeNode* 
realize) {
-    // Step 0. If the block is the root block, just return.
-    if (block_stack_.empty()) {
-      return {};
-    }
+Stmt InjectReductionBlock(const BlockRealizeNode* realize,                    
//
+                          const Array<Buffer>& ct_buffers,                    
//
+                          const Array<Buffer>& wb_buffers,                    
//
+                          const Array<PrimExpr>& old_wb_indices,              
//
+                          const CommReducer& reducer,                         
//
+                          const Array<PrimExpr>& combiner_lhs,                
//
+                          const std::vector<const ForNode*>& reduction_loops  
//
+                          ) {
+  int n_buffers = wb_buffers.size();
+  const BlockNode* block = realize->block.get();
 
-    // Step 1. If the block is not a reduction block, cross-thread reduction 
is not needed.
-    if (!IsReductionBlock(GetRef<BlockRealize>(realize), loop_range_map_,
-                          GetRef<Block>(block_stack_.back()), &analyzer_)) {
-      return {};
+  auto f_create_buffer_regions = [](Array<Buffer> buffers) {
+    Array<BufferRegion> regions;
+    regions.reserve(buffers.size());
+    for (const Buffer& buffer : buffers) {
+      regions.push_back(BufferRegion(buffer, {Range::FromMinExtent(0, 1)}));
     }
+    return regions;
+  };
 
-    // Step 2. Collect all the vars that appear in the bindings of reduction 
block iters.
-    std::unordered_set<const VarNode*> reduction_vars;
-    GetVarsTouchedByBlockIters(GetRef<BlockRealize>(realize), nullptr, 
&reduction_vars);
+  Array<BufferRegion> ct_buffer_regions = f_create_buffer_regions(ct_buffers);
+  Optional<Array<BufferRegion>> it_buffer_regions = NullOpt;
+  // In total, the block is transformed into at most 4 statements
+  // - Stmt 1: initialize the buffer for in-thread reduction
+  // - Stmt 2: do in-thread reduction
+  // - Stmt 3: do cross-thread reduction
+  // - Stmt 4: write cross-thread reduction result to the original buffer
+  Array<Stmt> stmts;
+  stmts.reserve(4);
 
-    // Step 3. Collect the loops whose loop vars appear in the bindings of 
reduction block iters.
-    // We call these loops "reduction-related".
-    // Step 4. See whether at least one reduction-related loop is bound to 
thread axis in GPU - if
-    // so, cross-thread reduction is needed. If none of the reduction-related 
loops is bound to
-    // thread axis, cross-thread reduction is not needed for the input block.
-    bool need = false;
-    std::vector<const ForNode*> reduction_loops;
-    for (const ForNode* loop : loop_stack_) {
-      if (reduction_vars.count(loop->loop_var.get())) {
-        // Step 3. Collect the loop.
-        reduction_loops.push_back(loop);
-        // Step 4. See whether the loop is bound to some thread axis.
-        if (loop->thread_binding.defined()) {
-          need = true;
-        }
+  // Stmt 3: do cross-thread reduction
+  {
+    // Step 3.1. Create the parameters to the intrinsic
+    Array<PrimExpr> parameters;
+    parameters.reserve(reduction_loops.size() + 4);
+    // 1-st argument: number of buffers
+    parameters.push_back(make_const(DataType::UInt(32), n_buffers));
+    // Next `n_buffers` arguments: sources
+    parameters.insert(parameters.end(), combiner_lhs.begin(), 
combiner_lhs.end());
+    // Next argument: predicate
+    parameters.push_back(const_true());
+    // Next `n_buffers` arguments: destinations
+    for (int i = 0; i < n_buffers; ++i) {
+      parameters.push_back(BufferLoad(ct_buffers[i], {0}));
+    }
+    // Next arguments: all the reduction threads
+    for (const ForNode* reduction_loop : reduction_loops) {
+      if (reduction_loop->thread_binding.defined()) {
+        parameters.push_back(reduction_loop->loop_var);
       }
     }
-    return need ? reduction_loops : std::vector<const ForNode*>{};
-  }
-
-  // Check if the input block needs thread broadcast rewrite.
-  // One block needs broadcast rewrite when
-  // 1. it consumes a buffer produced by cross-thread reduction under
-  // the same kernel (i.e., same group of blockIdx),
-  // 2. it writes to non-local memory,
-  // 3. at least one of the reduction thread vars of the cross-thread reduction
-  // is free to this block (i.e., not bound to the block).
-  std::vector<std::pair<ThreadScope, Range>> NeedCrossThreadBroadcast(
-      const BlockRealizeNode* realize) {
-    Block block = realize->block;
-
-    // If the block writes to local memory, no rewrite is needed.
-    for (BufferRegion write_region : block->writes) {
-      if (write_region->buffer.scope() == "local") {
-        return {};
+    // Step 3.2. Create the block and the block-realize.
+    Array<IterVar> iter_vars = block->iter_vars;
+    Array<PrimExpr> bindings = realize->iter_values;
+    Array<BufferRegion> reads = block->writes;
+
+    // Blockized block should also be considered
+    if (HasChildBlocksChecker::Check(GetRef<Block>(block))) {
+      Array<BlockRealize> child_blocks =
+          HasChildBlocksChecker::GetChildBlockRealizes(GetRef<Block>(block));
+
+      ICHECK_GT(child_blocks.size(), 0) << "Child blocks should be more than 
0";
+
+      // If has ChildBlocks, the reads should be analyzed from the child blocks
+      reads.clear();
+      for (BlockRealize child_block : child_blocks) {
+        Array<IterVar> child_iter_vars = child_block->block->iter_vars;
+        Array<PrimExpr> child_bindings = child_block->iter_values;
+        iter_vars.insert(iter_vars.end(), child_iter_vars.begin(), 
child_iter_vars.end());
+        bindings.insert(bindings.end(), child_bindings.begin(), 
child_bindings.end());
+        reads.insert(reads.end(), child_block->block->writes.begin(),
+                     child_block->block->writes.end());
       }
     }
 
-    // Find out the reduction threads for the read-buffers which are produced 
by
-    // cross-thread reduction.
-    std::unordered_map<ThreadScope, Range, ThreadScopeHash, ThreadScopeEqual> 
thread2range;
-    for (BufferRegion read_region : block->reads) {
-      auto buf_it = crt_buf2threads_.find(read_region->buffer.get());
-      if (buf_it == crt_buf2threads_.end()) {
-        continue;
+    // Remove unused iter vars which introduced by blockize
+    // otherwise may generate duplicated for loops
+    Array<IterVar> iter_vars_used;
+    Array<PrimExpr> bindings_used;
+    auto f_inject = [&iter_vars, &bindings, &iter_vars_used,
+                     &bindings_used](const VarNode* var) -> bool {
+      for (size_t i = 0; i < iter_vars.size(); ++i) {
+        const IterVar& iter_var = iter_vars[i];
+        if (iter_var->var.get() == var) {
+          if (std::find(iter_vars_used.begin(), iter_vars_used.end(), 
iter_var) ==
+              iter_vars_used.end()) {
+            iter_vars_used.push_back(iter_var);
+            bindings_used.push_back(bindings[i]);
+          }
+          return true;
+        }
       }
-      for (auto [scope, range] : buf_it->second) {
-        thread2range[scope] = range;
+      return false;
+    };
+    for (const BufferRegion& read : reads) {
+      for (const Range& range : read->region) {
+        UsedIterVarCollector::Collect(range->min, f_inject);
+        UsedIterVarCollector::Collect(range->extent, f_inject);
       }
     }
 
-    // Erase those threads which are not free to this block.
-    for (const ForNode* loop : loop_stack_) {
-      if (loop->thread_binding.defined()) {
-        ThreadScope scope = 
ThreadScope::Create(loop->thread_binding.value()->thread_tag);
-        thread2range.erase(scope);
+    Block cross_thread_block =
+        Block(/*iter_vars=*/std::move(iter_vars_used),
+              /*reads=*/std::move(reads),
+              /*writes=*/ct_buffer_regions,
+              /*name_hint=*/block->name_hint + "_cross_thread",
+              /*body=*/
+              AttrStmt(/*node=*/reducer,
+                       /*attr_key=*/tir::attr::reduce_scope,
+                       /*value=*/make_zero(DataType::Handle()),
+                       /*body=*/
+                       Evaluate(Call(/*dtype=*/DataType::Handle(),
+                                     
/*op=*/tir::builtin::tvm_thread_allreduce(),
+                                     /*args=*/std::move(parameters)))));
+    ObjectPtr<BlockNode> cross_thread_block_node =

Review Comment:
   We could also use `block.CopyOnWrite()->annotations.Set(...)` if passing the 
full list of constructor parameters looks ugly.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to