This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new df29e82629 [TIR][CUDA] Fix sub-warp reduction using "max" (#12275)
df29e82629 is described below

commit df29e826290a3eba606a93b59f640e025b5cbd4d
Author: MoebiusMeow <[email protected]>
AuthorDate: Thu Aug 4 02:02:52 2022 +0800

    [TIR][CUDA] Fix sub-warp reduction using "max" (#12275)
    
    * upd subwarp unittest
    
    * fix range check in sub-warp reduction
    
    * upd: sub-warp max unit test
---
 src/tir/transforms/lower_thread_allreduce.cc       | 14 +++++++-
 .../python/unittest/test_subwarp_reduction_cuda.py | 40 ++++++++++++++++++++--
 2 files changed, 51 insertions(+), 3 deletions(-)

diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index aeb819c516..43f7a103db 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -420,7 +420,19 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
           Buffer buf = shared_bufs[i];
           stores[i] = BufferStore(buf, ret[i], zero_indices);
         }
-        seq.push_back(SeqStmt::Flatten(stores));
+
+        // During the sub-warp reduction, values from inactive threads could be read,
+        // which is an undefined behavior according to the cuda document.
+        //
+        // In practise, the return value are usually 0, which does no harm to sum reduction.
+        // However, the result can be incorrect in max or prod reduction.
+        // Therefore an additional range check has to be performed to ensure the correctness.
+        if (offset * 2 > reduce_extent) {
+          PrimExpr cond = reduce_index + offset < reduce_extent;
+          seq.push_back(IfThenElse(cond, SeqStmt::Flatten(stores)));
+        } else {
+          seq.push_back(SeqStmt::Flatten(stores));
+        }
       }
 
       // Broadcast the reduction result from lane 0 to all other lanes.
diff --git a/tests/python/unittest/test_subwarp_reduction_cuda.py b/tests/python/unittest/test_subwarp_reduction_cuda.py
index 8778c75f56..7a7b1b06ba 100644
--- a/tests/python/unittest/test_subwarp_reduction_cuda.py
+++ b/tests/python/unittest/test_subwarp_reduction_cuda.py
@@ -33,10 +33,23 @@ def reduce(a: T.handle, b: T.handle, d1: T.int32, d2: T.int32, d3: T.int32) -> N
             B[vi, vj, vk] = B[vi, vj, vk] + A[vi, vj, vk, vl]
 
 
+@T.prim_func
+def reduce_max(a: T.handle, b: T.handle, d1: T.int32, d2: T.int32, d3: T.int32) -> None:
+    A = T.match_buffer(a, [1, d1, d2, d3])
+    B = T.match_buffer(b, [1, d1, d2])
+
+    for i, j, k, l in T.grid(1, d1, d2, d3):
+        with T.block("reduce"):
+            vi, vj, vk, vl = T.axis.remap("SSSR", [i, j, k, l])
+            with T.init():
+                B[vi, vj, vk] = T.float32(-3.4028234663852886e38)
+            B[vi, vj, vk] = T.max(B[vi, vj, vk], A[vi, vj, vk, vl])
+
+
 @tvm.testing.requires_gpu
 @tvm.testing.requires_cuda
 def test_cuda_subwarp_reduction():
-    def check(d1: int, d2: int, d3: int):
+    def check_sum(d1: int, d2: int, d3: int):
         _, _, _d1, _d2, _d3 = reduce.params
         mod = reduce.specialize({_d1: d1, _d2: d2, _d3: d3})
         sch = tvm.tir.Schedule(mod)
@@ -58,10 +71,33 @@ def test_cuda_subwarp_reduction():
         f(a, b)
         tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-6, atol=1e-6)
 
+    def check_max(d1: int, d2: int, d3: int):
+        _, _, _d1, _d2, _d3 = reduce_max.params
+        mod = reduce_max.specialize({_d1: d1, _d2: d2, _d3: d3})
+        sch = tvm.tir.Schedule(mod)
+        blk = sch.get_block("reduce")
+        i, j, k, l = sch.get_loops(blk)
+        sch.bind(i, "blockIdx.x")
+        sch.bind(j, "threadIdx.z")
+        sch.bind(k, "threadIdx.y")
+        sch.bind(l, "threadIdx.x")
+        f = tvm.build(sch.mod["main"], target="cuda")
+
+        # prepare input and output array
+        a_np = -np.random.rand(1, d1, d2, d3).astype("float32")
+        b_np = a_np.max(axis=-1).astype("float32")
+        a = tvm.nd.array(a_np, tvm.cuda(0))
+        b = tvm.nd.array(np.zeros_like(b_np), tvm.cuda(0))
+
+        # launch kernel
+        f(a, b)
+        tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-6, atol=1e-6)
+
     for d1 in range(1, 5):
         for d2 in range(1, 5):
             for d3 in range(2, 33):
-                check(d1, d2, d3)
+                check_sum(d1, d2, d3)
+                check_max(d1, d2, d3)
 
 
 if __name__ == "__main__":

Reply via email to