MasterJH5574 commented on code in PR #15192:
URL: https://github.com/apache/tvm/pull/15192#discussion_r1250424365
##########
src/tir/transforms/lower_cross_thread_reduction.cc:
##########
@@ -578,9 +621,31 @@ class CrossThreadReductionTransformer : public StmtMutator {
Stmt VisitStmt_(const ForNode* loop) final {
loop_stack_.push_back(loop);
loop_range_map_.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
+
+ // Collect loop-thread information:
+ // - when encountering a threadIdx loop, we keep note of its domain and
+ // the "loop var -> thread scope" relation, in order to collect all existing
+ // threads within a thread block.
+ // - we are careful about thread block boundary for safety.
+ int old_thread_block_depth = thread_block_depth_;
+ if (loop->kind == ForKind::kThreadBinding) {
+ ThreadScope scope = ThreadScope::Create(loop->thread_binding.value()->thread_tag);
+ if (scope.rank == 0 || !thread_block_depth_) {
Review Comment:
I've now changed the logic to:
> When exiting a loop,
> * if it is a `blockIdx` loop, clear the map,
> * if it is a `threadIdx` loop, clear the map when both `threadIdx` depth and `blockIdx` depth are 0.
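The exit rule quoted above can be illustrated with a small, self-contained C++ sketch. It is not the PR's implementation: `Loop`, `ThreadRank`, `ThreadInfoTracker`, its two depth counters, and its "loop var -> thread tag" map are hypothetical stand-ins for the mutator's state. It assumes the rank convention the diff relies on (rank 0 for `blockIdx`, rank 1 for `threadIdx`), ignores `vthread`, and assumes each depth counter is decremented before the zero check on exit.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>

// Hypothetical stand-in for a loop: only its variable name and thread tag matter.
struct Loop {
  std::string var;         // loop variable name
  std::string thread_tag;  // e.g. "blockIdx.x", "threadIdx.y", or "" for a serial loop
};

// Assumed rank convention: 0 for blockIdx, 1 for threadIdx, -1 for non-thread loops.
int ThreadRank(const std::string& tag) {
  if (tag.rfind("blockIdx.", 0) == 0) return 0;
  if (tag.rfind("threadIdx.", 0) == 0) return 1;
  return -1;
}

// Hypothetical tracker mirroring the enter/exit bookkeeping described in the comment.
class ThreadInfoTracker {
 public:
  void EnterLoop(const Loop& loop) {
    int rank = ThreadRank(loop.thread_tag);
    if (rank == 0) {
      ++block_idx_depth_;
    } else if (rank == 1) {
      ++thread_idx_depth_;
      // Remember which thread this loop var is bound to.
      thread_map_[loop.var] = loop.thread_tag;
    }
  }

  void ExitLoop(const Loop& loop) {
    int rank = ThreadRank(loop.thread_tag);
    if (rank == 0) {
      // Leaving a blockIdx loop: the thread block ends, so always clear.
      --block_idx_depth_;
      thread_map_.clear();
    } else if (rank == 1) {
      // Leaving a threadIdx loop: clear only when no threadIdx or blockIdx
      // loop still encloses the current position.
      --thread_idx_depth_;
      if (thread_idx_depth_ == 0 && block_idx_depth_ == 0) {
        thread_map_.clear();
      }
    }
  }

  std::size_t NumThreadVars() const { return thread_map_.size(); }

 private:
  int block_idx_depth_ = 0;
  int thread_idx_depth_ = 0;
  std::map<std::string, std::string> thread_map_;  // loop var -> thread tag
};
```

In this sketch, with an outer `blockIdx.x` loop, leaving the inner `threadIdx.x` loop keeps its entry, so thread bindings seen earlier in the same block stay visible; the map is emptied only once the `blockIdx.x` loop is left. Without any `blockIdx` loop, it is emptied when the outermost `threadIdx` loop is left.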