yzh119 commented on code in PR #15327:
URL: https://github.com/apache/tvm/pull/15327#discussion_r1264589363
##########
src/tir/transforms/lower_thread_allreduce.cc:
##########
@@ -299,131 +298,75 @@ class ThreadAllreduceBuilder final : public
StmtExprMutator {
// broadcast results from lane 0 to all other lanes and store
// the final reduction result to the proper location.
//
- if (is_warp_reduction(types, group_extent, reduce_extent,
contiguous_reduce_extent)) {
- ICHECK_LE(reduce_extent, warp_size_) << "not a warp reduction";
- //
- // This is the index to the reduction variable, one reduction
- // variable per warp. Local scope seems easier to reason without
- // relying on a pattern match pass to fix it later.
- Array<PrimExpr> zero_indices = {0};
-
- for (size_t idx = 0; idx < size; ++idx) {
- Array<PrimExpr> shape = {1};
-
- Buffer buffer = decl_buffer(shape, types[idx], "red_buf" +
std::to_string(idx));
- Var buffer_var = buffer->data;
-
- shared_buffer_vars[idx] = buffer_var;
- shared_bufs[idx] = buffer;
-
- PrimExpr pred = const_true(types[idx].lanes());
- seq.emplace_back(BufferStore(shared_bufs[idx], values[idx],
zero_indices));
-
- // Uses a local variable to store the shuffled data. Later
- // on, an allocation will be built for this local variable.
- local_bufs.push_back(decl_buffer(shape, types[idx], "t" +
std::to_string(idx)));
- }
-
- // The mask for this reducer, as this reducer may sit inside
- // a divergent control flow. Here it uses a variable to cache the current
- // active channels.
- //
+ if (IsWarpReduction(types, group_extent, reduce_extent,
contiguous_reduce_extent)) {
+ std::vector<PrimExpr> reduce_results;
DataType mask_dtype = DataType::UInt(32);
- Buffer mask_buffer = decl_buffer({1}, mask_dtype, "mask");
- {
- PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
- if (group_extent > 1) {
- mask = mask & (make_const(mask_dtype, (1ll << reduce_extent) - 1)
- << (reduce_extent * cast(mask_dtype, group_index)));
+ PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
+
+ if (reduce_extent <= warp_size_) {
Review Comment:
Enter single/sub warp reduction branch when `reduce_extent` is less than or
equal to `warp_size_`
##########
src/tir/transforms/lower_thread_allreduce.cc:
##########
@@ -676,8 +747,10 @@ class ThreadAllreduceBuilder final : public
StmtExprMutator {
if (reduce_extent == 1) {
return false; // no need to warp reduce
} else {
- if (warp_size_ % reduce_extent == 0) {
- return true; // warp size is multiple of reduce extent
+ if (warp_size_ % reduce_extent == 0 ||
Review Comment:
Create some bool variables to make this logic clearer, e.g.:
`is_subwarp_reduction = warp_size_ % reduce_extent == 0`
`is_multiwarp_reduction = max_num_threads_ != -1 && max_num_threads_ <=
warp_size_ * warp_size_ && reduce_extent % warp_size_ == 0`
##########
src/tir/transforms/lower_thread_allreduce.cc:
##########
@@ -676,8 +747,10 @@ class ThreadAllreduceBuilder final : public
StmtExprMutator {
if (reduce_extent == 1) {
return false; // no need to warp reduce
} else {
- if (warp_size_ % reduce_extent == 0) {
- return true; // warp size is multiple of reduce extent
+ if (warp_size_ % reduce_extent == 0 ||
+ (max_num_threads_ != -1 && max_num_threads_ <= warp_size_ *
warp_size_ &&
Review Comment:
Shall we throw an error if `max_num_threads_ == -1`?
##########
src/tir/transforms/lower_thread_allreduce.cc:
##########
@@ -299,131 +298,75 @@ class ThreadAllreduceBuilder final : public
StmtExprMutator {
// broadcast results from lane 0 to all other lanes and store
// the final reduction result to the proper location.
//
- if (is_warp_reduction(types, group_extent, reduce_extent,
contiguous_reduce_extent)) {
- ICHECK_LE(reduce_extent, warp_size_) << "not a warp reduction";
- //
- // This is the index to the reduction variable, one reduction
- // variable per warp. Local scope seems easier to reason without
- // relying on a pattern match pass to fix it later.
- Array<PrimExpr> zero_indices = {0};
-
- for (size_t idx = 0; idx < size; ++idx) {
- Array<PrimExpr> shape = {1};
-
- Buffer buffer = decl_buffer(shape, types[idx], "red_buf" +
std::to_string(idx));
- Var buffer_var = buffer->data;
-
- shared_buffer_vars[idx] = buffer_var;
- shared_bufs[idx] = buffer;
-
- PrimExpr pred = const_true(types[idx].lanes());
- seq.emplace_back(BufferStore(shared_bufs[idx], values[idx],
zero_indices));
-
- // Uses a local variable to store the shuffled data. Later
- // on, an allocation will be built for this local variable.
- local_bufs.push_back(decl_buffer(shape, types[idx], "t" +
std::to_string(idx)));
- }
-
- // The mask for this reducer, as this reducer may sit inside
- // a divergent control flow. Here it uses a variable to cache the current
- // active channels.
- //
+ if (IsWarpReduction(types, group_extent, reduce_extent,
contiguous_reduce_extent)) {
+ std::vector<PrimExpr> reduce_results;
DataType mask_dtype = DataType::UInt(32);
- Buffer mask_buffer = decl_buffer({1}, mask_dtype, "mask");
- {
- PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
- if (group_extent > 1) {
- mask = mask & (make_const(mask_dtype, (1ll << reduce_extent) - 1)
- << (reduce_extent * cast(mask_dtype, group_index)));
+ PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
+
+ if (reduce_extent <= warp_size_) {
+ if (group_extent > 1 && reduce_extent < warp_size_) {
Review Comment:
If `reduce_extent` equals the warp size, we skip the mask here (0xFFF...).
##########
src/tir/transforms/lower_thread_allreduce.cc:
##########
@@ -676,8 +747,10 @@ class ThreadAllreduceBuilder final : public
StmtExprMutator {
if (reduce_extent == 1) {
return false; // no need to warp reduce
} else {
- if (warp_size_ % reduce_extent == 0) {
- return true; // warp size is multiple of reduce extent
+ if (warp_size_ % reduce_extent == 0 ||
+ (max_num_threads_ != -1 && max_num_threads_ <= warp_size_ *
warp_size_ &&
+ reduce_extent % warp_size_ == 0)) {
+ return true; // warp size is multiple or factor of reduce extent
Review Comment:
Please update the comment here.
##########
src/tir/transforms/lower_thread_allreduce.cc:
##########
@@ -299,131 +298,75 @@ class ThreadAllreduceBuilder final : public
StmtExprMutator {
// broadcast results from lane 0 to all other lanes and store
// the final reduction result to the proper location.
//
- if (is_warp_reduction(types, group_extent, reduce_extent,
contiguous_reduce_extent)) {
- ICHECK_LE(reduce_extent, warp_size_) << "not a warp reduction";
- //
- // This is the index to the reduction variable, one reduction
- // variable per warp. Local scope seems easier to reason without
- // relying on a pattern match pass to fix it later.
- Array<PrimExpr> zero_indices = {0};
-
- for (size_t idx = 0; idx < size; ++idx) {
- Array<PrimExpr> shape = {1};
-
- Buffer buffer = decl_buffer(shape, types[idx], "red_buf" +
std::to_string(idx));
- Var buffer_var = buffer->data;
-
- shared_buffer_vars[idx] = buffer_var;
- shared_bufs[idx] = buffer;
-
- PrimExpr pred = const_true(types[idx].lanes());
- seq.emplace_back(BufferStore(shared_bufs[idx], values[idx],
zero_indices));
-
- // Uses a local variable to store the shuffled data. Later
- // on, an allocation will be built for this local variable.
- local_bufs.push_back(decl_buffer(shape, types[idx], "t" +
std::to_string(idx)));
- }
-
- // The mask for this reducer, as this reducer may sit inside
- // a divergent control flow. Here it uses a variable to cache the current
- // active channels.
- //
+ if (IsWarpReduction(types, group_extent, reduce_extent,
contiguous_reduce_extent)) {
+ std::vector<PrimExpr> reduce_results;
DataType mask_dtype = DataType::UInt(32);
- Buffer mask_buffer = decl_buffer({1}, mask_dtype, "mask");
- {
- PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
- if (group_extent > 1) {
- mask = mask & (make_const(mask_dtype, (1ll << reduce_extent) - 1)
- << (reduce_extent * cast(mask_dtype, group_index)));
+ PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
+
+ if (reduce_extent <= warp_size_) {
+ if (group_extent > 1 && reduce_extent < warp_size_) {
+ mask = mask &
+ (((1 << reduce_extent) - 1) << (reduce_extent *
cast(mask_dtype, group_index)));
}
- seq.emplace_back(BufferStore(mask_buffer, mask, zero_indices));
- // Push the buffer description. Later this will have an
- // allocation built for it.
- local_bufs.push_back(mask_buffer);
- }
+ std::tie(reduce_results, new_alloc_bufs) = MakeWarpAllreduce(
+ values, types, combiner, reduce_index, reduce_extent, group_index,
mask, NullOpt, &seq);
+ } else {
Review Comment:
Enter multiwarp reduction branch otherwise.
##########
src/tir/transforms/lower_thread_allreduce.cc:
##########
@@ -299,131 +298,75 @@ class ThreadAllreduceBuilder final : public
StmtExprMutator {
// broadcast results from lane 0 to all other lanes and store
// the final reduction result to the proper location.
//
Review Comment:
Please explain the two-stage logic in the comments as we did previously.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]