This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new df29e82629 [TIR][CUDA] Fix sub-warp reduction using "max" (#12275)
df29e82629 is described below
commit df29e826290a3eba606a93b59f640e025b5cbd4d
Author: MoebiusMeow <[email protected]>
AuthorDate: Thu Aug 4 02:02:52 2022 +0800
[TIR][CUDA] Fix sub-warp reduction using "max" (#12275)
* upd subwarp unittest
* fix range check in sub-warp reduction
* upd: sub-warp max unit test
---
src/tir/transforms/lower_thread_allreduce.cc | 14 +++++++-
.../python/unittest/test_subwarp_reduction_cuda.py | 40 ++++++++++++++++++++--
2 files changed, 51 insertions(+), 3 deletions(-)
diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index aeb819c516..43f7a103db 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -420,7 +420,19 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
Buffer buf = shared_bufs[i];
stores[i] = BufferStore(buf, ret[i], zero_indices);
}
- seq.push_back(SeqStmt::Flatten(stores));
+
+ // During the sub-warp reduction, values from inactive threads could be read,
+ // which is an undefined behavior according to the cuda document.
+ //
+ // In practise, the return value are usually 0, which does no harm to sum reduction.
+ // However, the result can be incorrect in max or prod reduction.
+ // Therefore an additional range check has to be performed to ensure the correctness.
+ if (offset * 2 > reduce_extent) {
+ PrimExpr cond = reduce_index + offset < reduce_extent;
+ seq.push_back(IfThenElse(cond, SeqStmt::Flatten(stores)));
+ } else {
+ seq.push_back(SeqStmt::Flatten(stores));
+ }
}
// Broadcast the reduction result from lane 0 to all other lanes.
diff --git a/tests/python/unittest/test_subwarp_reduction_cuda.py b/tests/python/unittest/test_subwarp_reduction_cuda.py
index 8778c75f56..7a7b1b06ba 100644
--- a/tests/python/unittest/test_subwarp_reduction_cuda.py
+++ b/tests/python/unittest/test_subwarp_reduction_cuda.py
@@ -33,10 +33,23 @@ def reduce(a: T.handle, b: T.handle, d1: T.int32, d2: T.int32, d3: T.int32) -> N
B[vi, vj, vk] = B[vi, vj, vk] + A[vi, vj, vk, vl]
+@T.prim_func
+def reduce_max(a: T.handle, b: T.handle, d1: T.int32, d2: T.int32, d3: T.int32) -> None:
+ A = T.match_buffer(a, [1, d1, d2, d3])
+ B = T.match_buffer(b, [1, d1, d2])
+
+ for i, j, k, l in T.grid(1, d1, d2, d3):
+ with T.block("reduce"):
+ vi, vj, vk, vl = T.axis.remap("SSSR", [i, j, k, l])
+ with T.init():
+ B[vi, vj, vk] = T.float32(-3.4028234663852886e38)
+ B[vi, vj, vk] = T.max(B[vi, vj, vk], A[vi, vj, vk, vl])
+
+
@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_subwarp_reduction():
- def check(d1: int, d2: int, d3: int):
+ def check_sum(d1: int, d2: int, d3: int):
_, _, _d1, _d2, _d3 = reduce.params
mod = reduce.specialize({_d1: d1, _d2: d2, _d3: d3})
sch = tvm.tir.Schedule(mod)
@@ -58,10 +71,33 @@ def test_cuda_subwarp_reduction():
f(a, b)
tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-6, atol=1e-6)
+ def check_max(d1: int, d2: int, d3: int):
+ _, _, _d1, _d2, _d3 = reduce_max.params
+ mod = reduce_max.specialize({_d1: d1, _d2: d2, _d3: d3})
+ sch = tvm.tir.Schedule(mod)
+ blk = sch.get_block("reduce")
+ i, j, k, l = sch.get_loops(blk)
+ sch.bind(i, "blockIdx.x")
+ sch.bind(j, "threadIdx.z")
+ sch.bind(k, "threadIdx.y")
+ sch.bind(l, "threadIdx.x")
+ f = tvm.build(sch.mod["main"], target="cuda")
+
+ # prepare input and output array
+ a_np = -np.random.rand(1, d1, d2, d3).astype("float32")
+ b_np = a_np.max(axis=-1).astype("float32")
+ a = tvm.nd.array(a_np, tvm.cuda(0))
+ b = tvm.nd.array(np.zeros_like(b_np), tvm.cuda(0))
+
+ # launch kernel
+ f(a, b)
+ tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-6, atol=1e-6)
+
for d1 in range(1, 5):
for d2 in range(1, 5):
for d3 in range(2, 33):
- check(d1, d2, d3)
+ check_sum(d1, d2, d3)
+ check_max(d1, d2, d3)
if __name__ == "__main__":