MasterJH5574 commented on code in PR #15192:
URL: https://github.com/apache/tvm/pull/15192#discussion_r1250425042
##########
src/tir/transforms/lower_cross_thread_reduction.cc:
##########
@@ -645,13 +707,66 @@ class CrossThreadReductionTransformer : public StmtMutator {
it_buffers = MakeScratchpads(reduction_buffers,
/*is_cross_thread_buffer=*/false);
new_buffers.insert(new_buffers.end(), it_buffers.value().begin(),
it_buffers.value().end());
}
- // Step 5. Transform.
+ // Step 4. Transform.
loop2new_stmt_[reduction_loops[0]] =
TransformReductionBlock(realize, it_buffers, ct_buffers,
reduction_buffers, wb_indices,
reducer, combiner_rhs, reduction_loops);
- // Step 6. Return an empty statement, because the transformation result
will be inserted when
- // returning to the first reduction-related loop.
- return Stmt{nullptr};
+ }
+
+ Stmt MakeCrossThreadBroadcast(
+ const BlockRealizeNode* realize,
+ const std::vector<std::pair<ThreadScope, Range>>& unbound_thread2range) {
+ // Step 1. Generate loop var for each unbound thread.
+ // Update the block predicate with clauses of `thread_var == min`.
+ PrimExpr predicate = realize->predicate;
+ Array<Var> loop_vars;
+ loop_vars.reserve(unbound_thread2range.size());
+ for (auto [scope, range] : unbound_thread2range) {
+ std::string dim_index(1, static_cast<char>(scope.dim_index + 'x'));
+ Var loop_var("t" + dim_index, range->min->dtype);
+ loop_vars.push_back(loop_var);
+ predicate = (loop_var == range->min) && predicate;
+ }
+
+ // Step 2. Update the BlockRealize with the new predicate.
+ ObjectPtr<BlockRealizeNode> p_realize =
make_object<BlockRealizeNode>(*realize);
+ p_realize->predicate = std::move(predicate);
+
+ // Step 3. Wrap the updated BlockRealize with the new loops.
+ Stmt body(p_realize);
+ for (int i = 0; i < static_cast<int>(unbound_thread2range.size()); ++i) {
+ std::string dim_index(1,
static_cast<char>(unbound_thread2range[i].first.dim_index + 'x'));
+ body = For(
+ /*loop_var=*/loop_vars[i], //
+ /*min=*/unbound_thread2range[i].second->min, //
+ /*extent=*/unbound_thread2range[i].second->extent, //
+ /*kind=*/ForKind::kThreadBinding, //
+ /*body=*/body, //
+ /*thread_binding=*/
+ IterVar(NullValue<Range>(), Var(""), IterVarType::kThreadIndex,
+ "threadIdx." + dim_index));
+ }
+ return body;
+ }
+
+ Stmt VisitStmt_(const BlockRealizeNode* realize) final {
+ // Part 1. Check if the block needs cross-thread reduction rewrite.
+ std::vector<const ForNode*> reduction_loops =
NeedCrossThreadReduction(realize);
+ if (!reduction_loops.empty()) {
+ // Return an empty statement, because the transformation result will
+ // be inserted when returning to the first reduction-related loop.
+ MakeCrossThreadReduction(realize, reduction_loops);
+ return Stmt{nullptr};
+ }
+
Review Comment:
Yep updated, as this pattern basically appears only when cross-thread
reduction exists.
##########
src/tir/transforms/lower_cross_thread_reduction.cc:
##########
@@ -478,6 +497,30 @@ class CrossThreadReductionTransformer : public StmtMutator {
return need ? reduction_loops : std::vector<const ForNode*>{};
}
+ // Check if the input block needs thread broadcast rewrite.
+ // One block needs broadcast rewrite when there exists one or more thread
+ // vars that are free variables of this block.
+ std::vector<std::pair<ThreadScope, Range>> NeedCrossThreadBroadcast(
+ const BlockRealizeNode* realize) {
+ std::unordered_map<ThreadScope, Range, ThreadScopeHash, ThreadScopeEqual>
unbound_thread2range =
+ thread2range_;
+ for (const PrimExpr& iter_value : realize->iter_values) {
Review Comment:
Thanks for the good catch!! Updated.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]