This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6ca0bea2d8 [Fix][TIR] LowerThreadAllreduce warp reduction mask (#17307)
6ca0bea2d8 is described below

commit 6ca0bea2d89bf11a315332983486437b6a4a90f2
Author: Ruihang Lai <[email protected]>
AuthorDate: Wed Aug 28 19:31:02 2024 -0400

    [Fix][TIR] LowerThreadAllreduce warp reduction mask (#17307)
    
    The warp reduction implemented by the "shuffle down" primitive takes a
    mask denoting the active threads within the warp that participate
    in this shuffle.
    
    Previously we computed this mask explicitly, but in practice we found
    that doing so results in a "CUDA illegal instruction" error on the
    NVIDIA H100 GPU when the mask is set, and that the issue disappears
    when the mask is left as the plain active mask. Therefore, this PR
    updates the allreduce lowering to remove the mask update.
    
    Confirmed the correctness on the following devices:
    * NVIDIA H100,
    * NVIDIA RTX 4090,
    * AMD Radeon 7900 XTX,
    * Apple M2 Ultra.
---
 src/tir/transforms/lower_thread_allreduce.cc              |  7 -------
 .../test_tir_transform_lower_thread_all_reduce.py         | 15 ++++-----------
 2 files changed, 4 insertions(+), 18 deletions(-)

diff --git a/src/tir/transforms/lower_thread_allreduce.cc 
b/src/tir/transforms/lower_thread_allreduce.cc
index 37d8f67580..dde33fa267 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -294,10 +294,6 @@ class ThreadAllreduceBuilder final : public 
StmtExprMutator {
       PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
 
       if (reduce_extent <= warp_size_) {
-        if (group_extent > 1 && reduce_extent < warp_size_) {
-          mask = mask &
-                 (((1 << reduce_extent) - 1) << (reduce_extent * 
cast(mask_dtype, group_index)));
-        }
         std::tie(reduce_results, new_alloc_bufs) = MakeWarpAllreduce(
             values, types, combiner, reduce_index, reduce_extent, group_index, 
mask, NullOpt, &seq);
 
@@ -352,9 +348,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator 
{
           values[i] = BufferLoad(/*buffer=*/staging_shared_bufs[i],
                                  /*indices=*/{group_index * n_warps + 
reduce_index});
         }
-        if (n_warps < warp_size_) {
-          mask = mask & (((1 << n_warps) - 1) << (group_index * n_warps));
-        }
         std::tie(reduce_results, local_bufs) = MakeWarpAllreduce(
             values, types, combiner, reduce_index, n_warps, group_index, mask,
             /*predicate=*/reduce_index < make_const(reduce_index->dtype, 
n_warps), &seq);
diff --git 
a/tests/python/tir-transform/test_tir_transform_lower_thread_all_reduce.py 
b/tests/python/tir-transform/test_tir_transform_lower_thread_all_reduce.py
index d8c9568da9..18d6339349 100644
--- a/tests/python/tir-transform/test_tir_transform_lower_thread_all_reduce.py
+++ b/tests/python/tir-transform/test_tir_transform_lower_thread_all_reduce.py
@@ -342,10 +342,7 @@ class TestMultiGroupMask1(BaseCompare):
             t0 = T.decl_buffer([1], "float32", scope="local")
             A_1 = T.Buffer((256,), data=A.data)
             red_buf0_1[0] = A_1[threadIdx_y * 8 + threadIdx_x]
-            mask[0] = T.bitwise_and(
-                T.tvm_warp_activemask(),
-                T.shift_left(T.uint32(255), T.uint32(8) * T.Cast("uint32", 
threadIdx_y)),
-            )
+            mask[0] = T.tvm_warp_activemask()
             t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0_1[0], 4, 32, 32)
             red_buf0_1[0] = red_buf0_1[0] + t0[0]
             t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0_1[0], 2, 32, 32)
@@ -421,7 +418,7 @@ class TestMultiWarpReduce1(BaseCompare):
                 T.tvm_storage_sync("shared")
                 if threadIdx_x < 4:
                     red_buf0[0] = red_buf_staging[threadIdx_x]
-                mask[0] = T.bitwise_and(T.tvm_warp_activemask(), T.uint32(15))
+                mask[0] = T.tvm_warp_activemask()
                 t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 2, 32, 
32)
                 red_buf0[0] = red_buf0[0] + t0[0]
                 t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 1, 32, 
32)
@@ -573,9 +570,7 @@ class TestMultiGroupMultiWarpReduction(BaseCompare):
             T.tvm_storage_sync("shared")
             if threadIdx_x < 4:
                 red_buf0[0] = red_buf_staging[threadIdx_y * 4 + threadIdx_x]
-            mask[0] = T.bitwise_and(
-                T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(15, 
threadIdx_y * 4))
-            )
+            mask[0] = T.tvm_warp_activemask()
             t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 2, 32, 32)
             red_buf0[0] = red_buf0[0] + t0[0]
             t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 1, 32, 32)
@@ -657,9 +652,7 @@ class 
TestMultiGroupMultiWarpPredicatedReduction(BaseCompare):
             T.tvm_storage_sync("shared")
             if threadIdx_x < 16:
                 red_buf0[0] = red_buf_staging[threadIdx_y * 16 + threadIdx_x]
-            mask[0] = T.bitwise_and(
-                T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(65535, 
threadIdx_y * 16))
-            )
+            mask[0] = T.tvm_warp_activemask()
             t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 8, 32, 32)
             red_buf0[0] = red_buf0[0] + t0[0]
             t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 4, 32, 32)

Reply via email to