MasterJH5574 commented on code in PR #15192:
URL: https://github.com/apache/tvm/pull/15192#discussion_r1250367349
##########
src/tir/transforms/lower_cross_thread_reduction.cc:
##########
@@ -645,13 +707,66 @@ class CrossThreadReductionTransformer : public
StmtMutator {
it_buffers = MakeScratchpads(reduction_buffers,
/*is_cross_thread_buffer=*/false);
new_buffers.insert(new_buffers.end(), it_buffers.value().begin(),
it_buffers.value().end());
}
- // Step 5. Transform.
+ // Step 4. Transform.
loop2new_stmt_[reduction_loops[0]] =
TransformReductionBlock(realize, it_buffers, ct_buffers,
reduction_buffers, wb_indices,
reducer, combiner_rhs, reduction_loops);
- // Step 6. Return an empty statement, because the transformation result
will be inserted when
- // returning to the first reduction-related loop.
- return Stmt{nullptr};
+ }
+
+ Stmt MakeCrossThreadBroadcast(
+ const BlockRealizeNode* realize,
+ const std::vector<std::pair<ThreadScope, Range>>& unbound_thread2range) {
+ // Step 1. Generate loop var for each unbound thread.
+ // Update the block predicate with clauses of `thread_var == min`.
+ PrimExpr predicate = realize->predicate;
+ Array<Var> loop_vars;
+ loop_vars.reserve(unbound_thread2range.size());
+ for (auto [scope, range] : unbound_thread2range) {
+ std::string dim_index(1, static_cast<char>(scope.dim_index + 'x'));
+ Var loop_var("t" + dim_index, range->min->dtype);
+ loop_vars.push_back(loop_var);
+ predicate = (loop_var == range->min) && predicate;
+ }
+
+ // Step 2. Update the BlockRealize with the new predicate.
+ ObjectPtr<BlockRealizeNode> p_realize =
make_object<BlockRealizeNode>(*realize);
+ p_realize->predicate = std::move(predicate);
+
+ // Step 3. Wrap the updated BlockRealize with the new loops.
+ Stmt body(p_realize);
+ for (int i = 0; i < static_cast<int>(unbound_thread2range.size()); ++i) {
+ std::string dim_index(1,
static_cast<char>(unbound_thread2range[i].first.dim_index + 'x'));
+ body = For(
+ /*loop_var=*/loop_vars[i], //
Review Comment:
The trailing `//` forces a line break for the formatter, so that it doesn't format the call into
```c++
For(/*loop_var=*/loop_vars[i], /*min=*/...)
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]