tqchen commented on code in PR #15192:
URL: https://github.com/apache/tvm/pull/15192#discussion_r1249603596
##########
src/tir/transforms/lower_cross_thread_reduction.cc:
##########
@@ -645,13 +707,66 @@ class CrossThreadReductionTransformer : public
StmtMutator {
it_buffers = MakeScratchpads(reduction_buffers,
/*is_cross_thread_buffer=*/false);
new_buffers.insert(new_buffers.end(), it_buffers.value().begin(),
it_buffers.value().end());
}
- // Step 5. Transform.
+ // Step 4. Transform.
loop2new_stmt_[reduction_loops[0]] =
TransformReductionBlock(realize, it_buffers, ct_buffers,
reduction_buffers, wb_indices,
reducer, combiner_rhs, reduction_loops);
- // Step 6. Return an empty statement, because the transformation result
will be inserted when
- // returning to the first reduction-related loop.
- return Stmt{nullptr};
+ }
+
+ Stmt MakeCrossThreadBroadcast(
+ const BlockRealizeNode* realize,
+ const std::vector<std::pair<ThreadScope, Range>>& unbound_thread2range) {
+ // Step 1. Generate loop var for each unbound thread.
+ // Update the block predicate with clauses of `thread_var == min`.
+ PrimExpr predicate = realize->predicate;
+ Array<Var> loop_vars;
+ loop_vars.reserve(unbound_thread2range.size());
+ for (auto [scope, range] : unbound_thread2range) {
+ std::string dim_index(1, static_cast<char>(scope.dim_index + 'x'));
+ Var loop_var("t" + dim_index, range->min->dtype);
+ loop_vars.push_back(loop_var);
+ predicate = (loop_var == range->min) && predicate;
+ }
+
+ // Step 2. Update the BlockRealize with the new predicate.
+ ObjectPtr<BlockRealizeNode> p_realize =
make_object<BlockRealizeNode>(*realize);
+ p_realize->predicate = std::move(predicate);
+
+ // Step 3. Wrap the updated BlockRealize with the new loops.
+ Stmt body(p_realize);
+ for (int i = 0; i < static_cast<int>(unbound_thread2range.size()); ++i) {
+ std::string dim_index(1,
static_cast<char>(unbound_thread2range[i].first.dim_index + 'x'));
+ body = For(
+ /*loop_var=*/loop_vars[i], //
+ /*min=*/unbound_thread2range[i].second->min, //
+ /*extent=*/unbound_thread2range[i].second->extent, //
+ /*kind=*/ForKind::kThreadBinding, //
+ /*body=*/body, //
+ /*thread_binding=*/
+ IterVar(NullValue<Range>(), Var(""), IterVarType::kThreadIndex,
+ "threadIdx." + dim_index));
+ }
+ return body;
+ }
+
+ Stmt VisitStmt_(const BlockRealizeNode* realize) final {
+ // Part 1. Check if the block needs cross-thread reduction rewrite.
+ std::vector<const ForNode*> reduction_loops =
NeedCrossThreadReduction(realize);
+ if (!reduction_loops.empty()) {
+ // Return an empty statement, because the transformation result will
+ // be inserted when returning to the first reduction-related loop.
+ MakeCrossThreadReduction(realize, reduction_loops);
+ return Stmt{nullptr};
+ }
+
Review Comment:
Only check if we already have a cross-thread reduction; this will reduce the
number of checks needed for other BlockRealize nodes.
##########
src/tir/transforms/lower_cross_thread_reduction.cc:
##########
@@ -578,9 +621,31 @@ class CrossThreadReductionTransformer : public StmtMutator
{
Stmt VisitStmt_(const ForNode* loop) final {
loop_stack_.push_back(loop);
loop_range_map_.Set(loop->loop_var, Range::FromMinExtent(loop->min,
loop->extent));
+
+ // Collect loop-thread information:
+ // - when encountering a threadIdx loop, we keep note of its domain and
+ // the "loop var -> thread scope" relation, in order to collect all
existing
+ // threads within a thread block.
+ // - we are careful about thread block boundary for safety.
+ int old_thread_block_depth = thread_block_depth_;
+ if (loop->kind == ForKind::kThreadBinding) {
+ ThreadScope scope =
ThreadScope::Create(loop->thread_binding.value()->thread_tag);
+ if (scope.rank == 0 || !thread_block_depth_) {
Review Comment:
Clear immediately after VisitLoop, since that helps to remove prior state.
##########
src/tir/transforms/lower_cross_thread_reduction.cc:
##########
@@ -478,6 +497,30 @@ class CrossThreadReductionTransformer : public StmtMutator
{
return need ? reduction_loops : std::vector<const ForNode*>{};
}
+ // Check if the input block needs thread broadcast rewrite.
+ // One block needs broadcast rewrite when there exist one or more thread
+ // vars which are free variables to this block.
+ std::vector<std::pair<ThreadScope, Range>> NeedCrossThreadBroadcast(
+ const BlockRealizeNode* realize) {
+ std::unordered_map<ThreadScope, Range, ThreadScopeHash, ThreadScopeEqual>
unbound_thread2range =
+ thread2range_;
+ for (const PrimExpr& iter_value : realize->iter_values) {
Review Comment:
Checking iter_values has a few disadvantages; for example, a unit loop could
be unbound. Instead, check the surrounding `loop_stack_`, which is much more
robust.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]