This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 236eb31f09 [BugFix][TIR] Fix multi-grouped multi-warp allreduce
(#15399)
236eb31f09 is described below
commit 236eb31f09e998ade8bbb395cf6b9de1032dd9c3
Author: Ruihang Lai <[email protected]>
AuthorDate: Tue Jul 25 07:54:49 2023 -0700
[BugFix][TIR] Fix multi-grouped multi-warp allreduce (#15399)
PRs #15327 and #15373 introduced the multi-warp allreduce implementation.
When those PRs were introduced, I tested correctness numerically
with a workload that takes a matrix of ones as input and computes
the sum over each row. Both PRs passed this numerical test, but
I did not realize that the test was incomplete and could not guarantee
correctness.
The previous implementation has bugs. They are exposed by changing
the input matrix from ones to random floating-point numbers.
This PR therefore fixes these issues and adds numerical tests
for multi-warp allreduce to `test_allreduce_cuda.py`. By removing
some redundant tests from that file, we hope to reduce testing
time somewhat while still guaranteeing correctness.
Sorry for not testing the implementation thoroughly earlier.
---
src/tir/transforms/lower_thread_allreduce.cc | 38 ++++++++++++----------
...rp_reduction_cuda.py => test_allreduce_cuda.py} | 8 +++--
.../test_tir_transform_lower_thread_all_reduce.py | 24 ++++++++------
3 files changed, 40 insertions(+), 30 deletions(-)
diff --git a/src/tir/transforms/lower_thread_allreduce.cc
b/src/tir/transforms/lower_thread_allreduce.cc
index 91a37dc35e..438dccff0b 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -76,12 +76,17 @@ class ThreadAllreduceBuilder final : public StmtExprMutator
{
auto node = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
if (auto it = alloc_remap_.find(node->buffer_var.get()); it !=
alloc_remap_.end()) {
- const AllocateNode* repl = it->second.as<AllocateNode>();
+ Buffer buf = Downcast<Buffer>(it->second);
auto write_ptr = node.CopyOnWrite();
- write_ptr->buffer_var = repl->buffer_var;
- write_ptr->dtype = repl->dtype;
- write_ptr->extents = repl->extents;
- write_ptr->condition = repl->condition;
+ write_ptr->buffer_var = buf->data;
+ write_ptr->dtype = buf->dtype;
+ write_ptr->extents = buf->shape;
+ write_ptr->condition = const_true(buf->dtype.lanes());
+
+ if (buf.scope() == "shared") {
+ // Use volatile access to shared buffer.
+ write_ptr->body = AttrStmt(buf->data, attr::volatile_scope, 1,
write_ptr->body);
+ }
}
return std::move(node);
}
@@ -344,15 +349,15 @@ class ThreadAllreduceBuilder final : public
StmtExprMutator {
// 4. Load staging buffer.
// Second round of allreduce.
for (size_t i = 0; i < size; ++i) {
- values[i] = BufferLoad(/*buffer=*/staging_shared_bufs[i],
/*indices=*/{reduce_index});
+ values[i] = BufferLoad(/*buffer=*/staging_shared_bufs[i],
+ /*indices=*/{group_index * n_warps +
reduce_index});
}
if (n_warps < warp_size_) {
- mask = mask & (((1 << n_warps) - 1) << group_index);
+ mask = mask & (((1 << n_warps) - 1) << (group_index * n_warps));
}
std::tie(reduce_results, local_bufs) = MakeWarpAllreduce(
values, types, combiner, reduce_index, n_warps, group_index, mask,
- /*predicate=*/reduce_index < make_const(reduce_index->dtype,
group_extent * n_warps),
- &seq);
+ /*predicate=*/reduce_index < make_const(reduce_index->dtype,
n_warps), &seq);
new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(),
local_bufs.end());
// 5. Create shared memory buffer(s) of `group_extent` elements,
storing
@@ -365,9 +370,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator
{
/*shape=*/{make_const(reduce_index->dtype, group_extent)},
/*dtype=*/buffers[i]->dtype, /*name=*/"red_result",
/*storage_scope=*/"shared");
write_result.push_back(
- BufferStore(broadcast_shared_buf, reduce_results[i],
{zero_index}));
+ BufferStore(broadcast_shared_buf, reduce_results[i],
{group_index}));
// Update `reduce_results`, pointing to the value loaded from the
shared memory buffer.
- reduce_results[i] = BufferLoad(broadcast_shared_buf, {zero_index});
+ reduce_results[i] = BufferLoad(broadcast_shared_buf, {group_index});
}
seq.push_back(IfThenElse(reduce_index == zero_index,
SeqStmt::Flatten(write_result)));
seq.push_back(SyncThread("shared"));
@@ -382,7 +387,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator
{
load_remap_[buffers[i]->data.get()] = reduce_results[i];
auto node = Allocate(buf->data, types[i], buf->shape, pred,
Evaluate(0));
- alloc_remap_[buffers[i]->data.get()] = node;
+ alloc_remap_[buffers[i]->data.get()] = buf;
var_remap_[buffers[i]->data.get()] = buf->data;
buf_remap_[buffers[i].get()] = buf;
}
@@ -400,7 +405,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator
{
// previous iteration on the same buffer.
seq.emplace_back(SyncThread("shared"));
for (size_t idx = 0; idx < size; ++idx) {
- shared_bufs[idx] = decl_buffer({1}, types[idx], "red_buf" +
std::to_string(idx), "shared");
+ shared_bufs[idx] = decl_buffer({IntImm(group_index->dtype,
group_extent * reduce_extent)},
+ types[idx], "red_buf" +
std::to_string(idx), "shared");
seq.emplace_back(BufferStore(shared_bufs[idx], values[idx],
{BufIndex(reduce_index, group_index,
reduce_extent)}));
}
@@ -414,9 +420,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator
{
{BufIndex(make_zero(reduce_index.dtype()),
group_index, reduce_extent)});
ICHECK_EQ(load->dtype, types[idx]);
load_remap_[buffers[idx]->data.get()] = load;
- alloc_remap_[buffers[idx]->data.get()] =
- Allocate(shared_bufs[idx]->data, types[idx],
- {PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred,
Evaluate(0));
+ alloc_remap_[buffers[idx]->data.get()] = shared_bufs[idx];
var_remap_[buffers[idx]->data.get()] = shared_bufs[idx]->data;
buf_remap_[buffers[idx].get()] = shared_bufs[idx];
}
@@ -772,7 +776,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator
{
// The load remap
std::unordered_map<const VarNode*, PrimExpr> load_remap_;
// Allocate remap
- std::unordered_map<const VarNode*, Stmt> alloc_remap_;
+ std::unordered_map<const VarNode*, Buffer> alloc_remap_;
// BufferVar remap
std::unordered_map<const VarNode*, Var> var_remap_;
// Buffer remap
diff --git a/tests/python/unittest/test_subwarp_reduction_cuda.py
b/tests/python/unittest/test_allreduce_cuda.py
similarity index 94%
rename from tests/python/unittest/test_subwarp_reduction_cuda.py
rename to tests/python/unittest/test_allreduce_cuda.py
index 7a7b1b06ba..e9a8ef81cf 100644
--- a/tests/python/unittest/test_subwarp_reduction_cuda.py
+++ b/tests/python/unittest/test_allreduce_cuda.py
@@ -48,7 +48,7 @@ def reduce_max(a: T.handle, b: T.handle, d1: T.int32, d2:
T.int32, d3: T.int32)
@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
-def test_cuda_subwarp_reduction():
+def test_allreduce_cuda():
def check_sum(d1: int, d2: int, d3: int):
_, _, _d1, _d2, _d3 = reduce.params
mod = reduce.specialize({_d1: d1, _d2: d2, _d3: d3})
@@ -95,10 +95,12 @@ def test_cuda_subwarp_reduction():
for d1 in range(1, 5):
for d2 in range(1, 5):
- for d3 in range(2, 33):
+ for d3 in [2, 4, 8, 12, 16, 32, 48, 64, 100, 128, 201, 256, 512,
1024]:
+ if d1 * d2 * d3 > 1024:
+ continue
check_sum(d1, d2, d3)
check_max(d1, d2, d3)
if __name__ == "__main__":
- test_cuda_subwarp_reduction()
+ test_allreduce_cuda()
diff --git
a/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
b/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
index 1fb8aea66e..9d53b1f9df 100644
--- a/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
+++ b/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
@@ -387,6 +387,7 @@ class TestMultiWarpReduce1(BaseCompare):
for i in range(128):
threadIdx_x = T.launch_thread("threadIdx.x", 128)
red_result = T.allocate([1], "float32", "shared")
+ T.attr(red_result, "volatile_scope", 1)
red_result_1 = T.Buffer((1,), data=red_result, scope="shared")
with T.attr(
T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
@@ -463,6 +464,7 @@ class TestMultiWarpReduce2(BaseCompare):
T.func_attr({"target": T.target("cuda", host="llvm")})
threadIdx_x = T.launch_thread("threadIdx.x", 1024)
red_result = T.allocate([1], "float32", "shared")
+ T.attr(red_result, "volatile_scope", 1)
red_result_1 = T.Buffer((1,), data=red_result, scope="shared")
with T.attr(
T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
@@ -550,6 +552,7 @@ class TestMultiGroupMultiWarpReduction(BaseCompare):
T.func_attr({"target": T.target("cuda", host="llvm")})
threadIdx_y = T.launch_thread("threadIdx.y", 4)
red_result = T.allocate([4], "float32", "shared")
+ T.attr(red_result, "volatile_scope", 1)
threadIdx_x = T.launch_thread("threadIdx.x", 128)
red_result_1 = T.Buffer((4,), data=red_result, scope="shared")
with T.attr(
@@ -585,11 +588,11 @@ class TestMultiGroupMultiWarpReduction(BaseCompare):
red_buf_staging_1[threadIdx_y * 4 + threadIdx_x // 32] =
red_buf0_2[0]
T.tvm_storage_sync("shared")
red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
- if threadIdx_x < 16:
- red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
+ if threadIdx_x < 4:
+ red_buf0_3[0] = red_buf_staging_1[threadIdx_y * 4 +
threadIdx_x]
mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
mask_3[0] = T.bitwise_and(
- T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(15,
threadIdx_y))
+ T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(15,
threadIdx_y * 4))
)
t0_3 = T.Buffer((1,), data=t0, scope="local")
t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 2, 32,
32)
@@ -597,11 +600,11 @@ class TestMultiGroupMultiWarpReduction(BaseCompare):
t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32,
32)
red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
if threadIdx_x == 0:
- red_result_1[0] = red_buf0_3[0]
+ red_result_1[threadIdx_y] = red_buf0_3[0]
T.tvm_storage_sync("shared")
if threadIdx_x == 0:
B_1 = T.Buffer((4,), data=B.data)
- B_1[threadIdx_y] = red_result_1[0]
+ B_1[threadIdx_y] = red_result_1[threadIdx_y]
class TestMultiGroupMultiWarpPredicatedReduction(BaseCompare):
@@ -636,6 +639,7 @@ class
TestMultiGroupMultiWarpPredicatedReduction(BaseCompare):
threadIdx_y = T.launch_thread("threadIdx.y", 2)
in_thread_B = T.allocate([1], "float32", "local")
red_result = T.allocate([2], "float32", "shared")
+ T.attr(red_result, "volatile_scope", 1)
threadIdx_x = T.launch_thread("threadIdx.x", 512)
in_thread_B_1 = T.Buffer((1,), data=in_thread_B, scope="local")
in_thread_B_1[0] = T.float32(0)
@@ -675,11 +679,11 @@ class
TestMultiGroupMultiWarpPredicatedReduction(BaseCompare):
red_buf_staging_1[threadIdx_y * 16 + threadIdx_x // 32] =
red_buf0_2[0]
T.tvm_storage_sync("shared")
red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
- if threadIdx_x < 32:
- red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
+ if threadIdx_x < 16:
+ red_buf0_3[0] = red_buf_staging_1[threadIdx_y * 16 +
threadIdx_x]
mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
mask_3[0] = T.bitwise_and(
- T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(65535,
threadIdx_y))
+ T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(65535,
threadIdx_y * 16))
)
t0_3 = T.Buffer((1,), data=t0, scope="local")
t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 8, 32,
32)
@@ -691,11 +695,11 @@ class
TestMultiGroupMultiWarpPredicatedReduction(BaseCompare):
t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32,
32)
red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
if threadIdx_x == 0:
- red_result_1[0] = red_buf0_3[0]
+ red_result_1[threadIdx_y] = red_buf0_3[0]
T.tvm_storage_sync("shared")
if threadIdx_x == 0:
B_1 = T.Buffer((2,), data=B.data)
- B_1[threadIdx_y] = red_result_1[0]
+ B_1[threadIdx_y] = red_result_1[threadIdx_y]
if __name__ == "__main__":