This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5029477268 [TIR] Allreduce broadcast result to each thread in multi-warp case (#15373)
5029477268 is described below

commit 50294772681a81ab94d55a6f6b036ea84220bd0d
Author: Ruihang Lai <[email protected]>
AuthorDate: Fri Jul 21 17:00:03 2023 -0700

    [TIR] Allreduce broadcast result to each thread in multi-warp case (#15373)
    
    PR #15327 introduced warp-level primitive support in multi-warp
    allreduce. However, because of the two-stage shuffle-down reduction
    used by the multi-warp allreduce implementation, PR #15327 did not
    broadcast the allreduce result to every reduction thread. This
    behavior does not match the semantics of allreduce and is not ideal
    for many use cases. Therefore, this PR completes the implementation
    by inserting a stage that writes the reduction results to shared
    memory, so that every reduction thread across all reduction warps
    can access them.
    
    This shared-memory write-back stage is inserted only in multi-warp
    allreduce cases. In single-warp allreduce, a `shfl_sync` is used to
    broadcast the reduction results across the reduction threads. In the
    multi-warp setting, warp-level primitives cannot broadcast a value
    across warps, so shared memory has to be used instead.
    
    Numerical correctness was verified locally.
---
 src/tir/transforms/lower_thread_allreduce.cc       | 94 +++++++++-------------
 .../test_tir_transform_lower_thread_all_reduce.py  | 52 +++++++-----
 2 files changed, 70 insertions(+), 76 deletions(-)

diff --git a/src/tir/transforms/lower_thread_allreduce.cc 
b/src/tir/transforms/lower_thread_allreduce.cc
index b47e837711..91a37dc35e 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -38,27 +38,6 @@
 namespace tvm {
 namespace tir {
 
-class UpdatePointerStorageScopeAllReduce final : public 
UpdatePointerStorageScope {
- public:
-  explicit UpdatePointerStorageScopeAllReduce(
-      const std::unordered_map<const VarNode*, String>& new_storage_scopes)
-      : UpdatePointerStorageScope(new_storage_scopes) {}
-
-  Stmt VisitStmt_(const AllocateNode* op) final {
-    auto remapped = Downcast<Var>(StmtExprMutator::VisitExpr(op->buffer_var));
-    auto new_scope = GetPtrStorageScope(remapped);
-    if (new_scope != GetPtrStorageScope(op->buffer_var)) {
-      Stmt body = StmtExprMutator::VisitStmt(op->body);
-      if (new_scope == "shared") {
-        // use volatile access to shared buffer.
-        body = AttrStmt(remapped, attr::volatile_scope, 1, body);
-      }
-      return Allocate(remapped, op->dtype, op->extents, op->condition, body, 
op->annotations);
-    }
-    return StmtExprMutator::VisitStmt_(op);
-  }
-};
-
 class ThreadAllreduceBuilder final : public StmtExprMutator {
  public:
   explicit ThreadAllreduceBuilder(const TargetNode* target)
@@ -98,11 +77,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
 
     if (auto it = alloc_remap_.find(node->buffer_var.get()); it != 
alloc_remap_.end()) {
       const AllocateNode* repl = it->second.as<AllocateNode>();
-      if (warp_allocs_.count(repl)) {
-        new_storage_scopes_[repl->buffer_var.get()] = "local";
-      } else {
-        new_storage_scopes_[repl->buffer_var.get()] = "shared";
-      }
       auto write_ptr = node.CopyOnWrite();
       write_ptr->buffer_var = repl->buffer_var;
       write_ptr->dtype = repl->dtype;
@@ -161,8 +135,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator 
{
     return std::move(store);
   }
 
-  std::unordered_map<const VarNode*, String> new_storage_scopes_;
-
  private:
   // Thread entry
   struct ThreadEntry {
@@ -310,6 +282,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator 
{
     // In the second stage we use the first 16 lanes of the first warp to 
reduce
     // the remaining elements, and this reduction can also be optimized by
     // shuffle_down warp-level primitives.
+    PrimExpr zero_index = make_const(reduce_index->dtype, 0);
     if (IsWarpReduction(types, group_extent, reduce_extent, 
contiguous_reduce_extent)) {
       std::vector<PrimExpr> reduce_results;
       DataType mask_dtype = DataType::UInt(32);
@@ -322,6 +295,18 @@ class ThreadAllreduceBuilder final : public 
StmtExprMutator {
         }
         std::tie(reduce_results, new_alloc_bufs) = MakeWarpAllreduce(
             values, types, combiner, reduce_index, reduce_extent, group_index, 
mask, NullOpt, &seq);
+
+        // Broadcast the reduction result from lane 0 to all other lanes.
+        // This avoids to emit predicated stores, as all threads are
+        // uniformly writing the same result.
+        for (size_t i = 0; i < size; ++i) {
+          Buffer buf = Downcast<BufferLoad>(reduce_results[i])->buffer;
+          PrimExpr val = BufferLoad(buf, {zero_index});
+          ICHECK_EQ(val->dtype, types[i]);
+          PrimExpr splat = WarpShuffle(builtin::tvm_warp_shuffle(), 
new_alloc_bufs.back(), val,
+                                       reduce_extent * group_index);
+          seq.push_back(BufferStore(buf, splat, {zero_index}));
+        }
       } else {
         int n_warps = reduce_extent / warp_size_;
         std::vector<Buffer> local_bufs;
@@ -352,7 +337,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator 
{
               /*value=*/reduce_results[i],
               /*indices=*/{group_index * n_warps + floordiv(reduce_index, 
warp_size_)}));
         }
-        PrimExpr cond = floormod(reduce_index, warp_size_) == 
make_const(reduce_index->dtype, 0);
+        PrimExpr cond = floormod(reduce_index, warp_size_) == zero_index;
         seq.push_back(IfThenElse(cond, SeqStmt::Flatten(write_staging_buf)));
         seq.push_back(SyncThread("shared"));
 
@@ -369,6 +354,23 @@ class ThreadAllreduceBuilder final : public 
StmtExprMutator {
             /*predicate=*/reduce_index < make_const(reduce_index->dtype, 
group_extent * n_warps),
             &seq);
         new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(), 
local_bufs.end());
+
+        // 5. Create shared memory buffer(s) of `group_extent` elements, 
storing
+        // the allreduce results so each thread can access.
+        std::vector<Stmt> write_result;
+        write_result.reserve(size);
+        for (size_t i = 0; i < size; ++i) {
+          
new_alloc_bufs.push_back(Downcast<BufferLoad>(reduce_results[i])->buffer);
+          Buffer broadcast_shared_buf = decl_buffer(
+              /*shape=*/{make_const(reduce_index->dtype, group_extent)},
+              /*dtype=*/buffers[i]->dtype, /*name=*/"red_result", 
/*storage_scope=*/"shared");
+          write_result.push_back(
+              BufferStore(broadcast_shared_buf, reduce_results[i], 
{zero_index}));
+          // Update `reduce_results`, pointing to the value loaded from the 
shared memory buffer.
+          reduce_results[i] = BufferLoad(broadcast_shared_buf, {zero_index});
+        }
+        seq.push_back(IfThenElse(reduce_index == zero_index, 
SeqStmt::Flatten(write_result)));
+        seq.push_back(SyncThread("shared"));
       }
 
       // Write back allreduce results and update existing allocations.
@@ -379,12 +381,10 @@ class ThreadAllreduceBuilder final : public 
StmtExprMutator {
         ICHECK_EQ(reduce_results[i]->dtype, types[i]);
         load_remap_[buffers[i]->data.get()] = reduce_results[i];
 
-        Array<PrimExpr> extents{PrimExpr(1)};
-        auto node = Allocate(buf->data, types[i], extents, pred, Evaluate(0));
+        auto node = Allocate(buf->data, types[i], buf->shape, pred, 
Evaluate(0));
         alloc_remap_[buffers[i]->data.get()] = node;
         var_remap_[buffers[i]->data.get()] = buf->data;
         buf_remap_[buffers[i].get()] = buf;
-        warp_allocs_.insert(node.get());
       }
     } else {
       std::vector<Buffer> shared_bufs(size);
@@ -400,7 +400,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator 
{
       // previous iteration on the same buffer.
       seq.emplace_back(SyncThread("shared"));
       for (size_t idx = 0; idx < size; ++idx) {
-        shared_bufs[idx] = decl_buffer({1}, types[idx], "red_buf" + 
std::to_string(idx));
+        shared_bufs[idx] = decl_buffer({1}, types[idx], "red_buf" + 
std::to_string(idx), "shared");
         seq.emplace_back(BufferStore(shared_bufs[idx], values[idx],
                                      {BufIndex(reduce_index, group_index, 
reduce_extent)}));
       }
@@ -426,9 +426,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator 
{
     Stmt body = SeqStmt::Flatten(seq);
     for (Buffer buf : new_alloc_bufs) {
       body = Allocate(buf->data, buf->dtype, buf->shape, 
const_true(buf->dtype.lanes()), body);
-      if (buf.scope() != "shared") {
-        new_storage_scopes_[buf->data.get()] = "local";
-      }
     }
 
     return body;
@@ -457,12 +454,13 @@ class ThreadAllreduceBuilder final : public 
StmtExprMutator {
     std::vector<Stmt> load_values;
     load_values.reserve(n_buffers);
     for (int idx = 0; idx < n_buffers; ++idx) {
-      shared_bufs.push_back(decl_buffer(shape, dtypes[idx], "red_buf" + 
std::to_string(idx)));
+      shared_bufs.push_back(
+          decl_buffer(shape, dtypes[idx], "red_buf" + std::to_string(idx), 
"local"));
       load_values.push_back(BufferStore(shared_bufs[idx], src_values[idx], 
zero_indices));
 
       // Uses a local variable to store the shuffled data.  Later
       // on, an allocation will be built for this local variable.
-      local_bufs.push_back(decl_buffer(shape, dtypes[idx], "t" + 
std::to_string(idx)));
+      local_bufs.push_back(decl_buffer(shape, dtypes[idx], "t" + 
std::to_string(idx), "local"));
     }
 
     if (predicate.defined()) {
@@ -474,7 +472,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator 
{
     // The mask for this reducer, as this reducer may sit inside
     // a divergent control flow. Here it uses a variable to cache the current
     // active channels.
-    Buffer mask_buffer = decl_buffer(shape, mask->dtype, "mask");
+    Buffer mask_buffer = decl_buffer(shape, mask->dtype, "mask", "local");
     {
       seq->emplace_back(BufferStore(mask_buffer, mask, zero_indices));
       // Push the buffer description.  Later this will have an
@@ -543,18 +541,6 @@ class ThreadAllreduceBuilder final : public 
StmtExprMutator {
       }
     }
 
-    // Broadcast the reduction result from lane 0 to all other lanes.
-    // This avoids to emit predicated stores, as all threads are
-    // uniformly writing the same result.
-    for (int i = 0; i < n_buffers; ++i) {
-      Buffer buf = shared_bufs[i];
-      PrimExpr val = BufferLoad(buf, zero_indices);
-      ICHECK_EQ(val->dtype, dtypes[i]);
-      PrimExpr splat =
-          WarpShuffle(builtin::tvm_warp_shuffle(), mask_buffer, val, 
reduce_extent * group_index);
-      seq->push_back(BufferStore(buf, splat, zero_indices));
-    }
-
     std::vector<PrimExpr> reduce_results;
     reduce_results.reserve(n_buffers);
     for (int i = 0; i < n_buffers; ++i) {
@@ -791,8 +777,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator 
{
   std::unordered_map<const VarNode*, Var> var_remap_;
   // Buffer remap
   std::unordered_map<const BufferNode*, Buffer> buf_remap_;
-  // Allocate from warp reductions
-  std::unordered_set<const void*> warp_allocs_;
   // Internal analyzer
   arith::Analyzer analyzer_;
 };
@@ -806,9 +790,7 @@ Pass LowerThreadAllreduce() {
     ICHECK(target.defined()) << "LowerThreadAllreduce: Require the target 
attribute";
     const TargetNode* target_node = target.as<TargetNode>();
     ThreadAllreduceBuilder thread_all_reduce(target_node);
-    auto reduce_body = thread_all_reduce(n->body);
-    n->body =
-        
UpdatePointerStorageScopeAllReduce(thread_all_reduce.new_storage_scopes_)(reduce_body);
+    n->body = thread_all_reduce(n->body);
     return f;
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {});
diff --git 
a/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py 
b/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
index f354dfe9ca..1fb8aea66e 100644
--- a/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
+++ b/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
@@ -386,13 +386,14 @@ class TestMultiWarpReduce1(BaseCompare):
         T.func_attr({"target": T.target("cuda", host="llvm")})
         for i in range(128):
             threadIdx_x = T.launch_thread("threadIdx.x", 128)
-            red_buf0 = T.allocate([1], "float32", "local")
-            red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
+            red_result = T.allocate([1], "float32", "shared")
+            red_result_1 = T.Buffer((1,), data=red_result, scope="shared")
             with T.attr(
                 T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
                 "reduce_scope",
                 T.reinterpret("handle", T.uint64(0)),
             ):
+                red_buf0 = T.allocate([1], "float32", "local")
                 mask = T.allocate([1], "uint32", "local")
                 t0 = T.allocate([1], "float32", "local")
                 red_buf0_1 = T.allocate([1], "float32", "local")
@@ -415,11 +416,11 @@ class TestMultiWarpReduce1(BaseCompare):
                 red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
                 t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 1, 
32, 32)
                 red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
-                red_buf0_2[0] = T.tvm_warp_shuffle(mask_2[0], red_buf0_2[0], 
0, 32, 32)
                 red_buf_staging_1 = T.Buffer((4,), data=red_buf_staging, 
scope="shared")
                 if threadIdx_x % 32 == 0:
                     red_buf_staging_1[threadIdx_x // 32] = red_buf0_2[0]
                 T.tvm_storage_sync("shared")
+                red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
                 if threadIdx_x < 4:
                     red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
                 mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
@@ -429,10 +430,12 @@ class TestMultiWarpReduce1(BaseCompare):
                 red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
                 t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 
32, 32)
                 red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
-                red_buf0_3[0] = T.tvm_warp_shuffle(mask_3[0], red_buf0_3[0], 
0, 32, 32)
+                if threadIdx_x == 0:
+                    red_result_1[0] = red_buf0_3[0]
+                T.tvm_storage_sync("shared")
             if threadIdx_x == 0:
                 B_1 = T.Buffer((128,), data=B.data)
-                B_1[i] = red_buf0_3[0]
+                B_1[i] = red_result_1[0]
 
 
 class TestMultiWarpReduce2(BaseCompare):
@@ -459,13 +462,14 @@ class TestMultiWarpReduce2(BaseCompare):
     def expected(A: T.Buffer((1, 1024), "float32"), B: T.Buffer((1,), 
"float32")):
         T.func_attr({"target": T.target("cuda", host="llvm")})
         threadIdx_x = T.launch_thread("threadIdx.x", 1024)
-        red_buf0 = T.allocate([1], "float32", "local")
-        red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
+        red_result = T.allocate([1], "float32", "shared")
+        red_result_1 = T.Buffer((1,), data=red_result, scope="shared")
         with T.attr(
             T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
             "reduce_scope",
             T.reinterpret("handle", T.uint64(0)),
         ):
+            red_buf0 = T.allocate([1], "float32", "local")
             mask = T.allocate([1], "uint32", "local")
             t0 = T.allocate([1], "float32", "local")
             red_buf0_1 = T.allocate([1], "float32", "local")
@@ -488,11 +492,11 @@ class TestMultiWarpReduce2(BaseCompare):
             red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
             t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 1, 32, 
32)
             red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
-            red_buf0_2[0] = T.tvm_warp_shuffle(mask_2[0], red_buf0_2[0], 0, 
32, 32)
             red_buf_staging_1 = T.Buffer((32,), data=red_buf_staging, 
scope="shared")
             if threadIdx_x % 32 == 0:
                 red_buf_staging_1[threadIdx_x // 32] = red_buf0_2[0]
             T.tvm_storage_sync("shared")
+            red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
             if threadIdx_x < 32:
                 red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
             mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
@@ -508,10 +512,12 @@ class TestMultiWarpReduce2(BaseCompare):
             red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
             t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 
32)
             red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
-            red_buf0_3[0] = T.tvm_warp_shuffle(mask_3[0], red_buf0_3[0], 0, 
32, 32)
+            if threadIdx_x == 0:
+                red_result_1[0] = red_buf0_3[0]
+            T.tvm_storage_sync("shared")
         if threadIdx_x == 0:
             B_1 = T.Buffer((1,), data=B.data)
-            B_1[0] = red_buf0_3[0]
+            B_1[0] = red_result_1[0]
 
 
 class TestMultiGroupMultiWarpReduction(BaseCompare):
@@ -543,14 +549,15 @@ class TestMultiGroupMultiWarpReduction(BaseCompare):
     def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), 
"float32")):
         T.func_attr({"target": T.target("cuda", host="llvm")})
         threadIdx_y = T.launch_thread("threadIdx.y", 4)
-        red_buf0 = T.allocate([1], "float32", "local")
+        red_result = T.allocate([4], "float32", "shared")
         threadIdx_x = T.launch_thread("threadIdx.x", 128)
-        red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
+        red_result_1 = T.Buffer((4,), data=red_result, scope="shared")
         with T.attr(
             T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
             "reduce_scope",
             T.reinterpret("handle", T.uint64(0)),
         ):
+            red_buf0 = T.allocate([1], "float32", "local")
             mask = T.allocate([1], "uint32", "local")
             t0 = T.allocate([1], "float32", "local")
             red_buf0_1 = T.allocate([1], "float32", "local")
@@ -573,11 +580,11 @@ class TestMultiGroupMultiWarpReduction(BaseCompare):
             red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
             t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 1, 32, 
32)
             red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
-            red_buf0_2[0] = T.tvm_warp_shuffle(mask_2[0], red_buf0_2[0], 32 * 
threadIdx_y, 32, 32)
             red_buf_staging_1 = T.Buffer((16,), data=red_buf_staging, 
scope="shared")
             if threadIdx_x % 32 == 0:
                 red_buf_staging_1[threadIdx_y * 4 + threadIdx_x // 32] = 
red_buf0_2[0]
             T.tvm_storage_sync("shared")
+            red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
             if threadIdx_x < 16:
                 red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
             mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
@@ -589,10 +596,12 @@ class TestMultiGroupMultiWarpReduction(BaseCompare):
             red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
             t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 
32)
             red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
-            red_buf0_3[0] = T.tvm_warp_shuffle(mask_3[0], red_buf0_3[0], 4 * 
threadIdx_y, 32, 32)
+            if threadIdx_x == 0:
+                red_result_1[0] = red_buf0_3[0]
+            T.tvm_storage_sync("shared")
         if threadIdx_x == 0:
             B_1 = T.Buffer((4,), data=B.data)
-            B_1[threadIdx_y] = red_buf0_3[0]
+            B_1[threadIdx_y] = red_result_1[0]
 
 
 class TestMultiGroupMultiWarpPredicatedReduction(BaseCompare):
@@ -626,19 +635,20 @@ class 
TestMultiGroupMultiWarpPredicatedReduction(BaseCompare):
         T.func_attr({"target": T.target("cuda", host="llvm")})
         threadIdx_y = T.launch_thread("threadIdx.y", 2)
         in_thread_B = T.allocate([1], "float32", "local")
-        red_buf0 = T.allocate([1], "float32", "local")
+        red_result = T.allocate([2], "float32", "shared")
         threadIdx_x = T.launch_thread("threadIdx.x", 512)
         in_thread_B_1 = T.Buffer((1,), data=in_thread_B, scope="local")
         in_thread_B_1[0] = T.float32(0)
         if threadIdx_x < 70:
             A_1 = T.Buffer((140,), data=A.data)
             in_thread_B_1[0] = in_thread_B_1[0] + A_1[threadIdx_y * 70 + 
threadIdx_x]
-        red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
+        red_result_1 = T.Buffer((2,), data=red_result, scope="shared")
         with T.attr(
             T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
             "reduce_scope",
             T.reinterpret("handle", T.uint64(0)),
         ):
+            red_buf0 = T.allocate([1], "float32", "local")
             mask = T.allocate([1], "uint32", "local")
             t0 = T.allocate([1], "float32", "local")
             red_buf0_1 = T.allocate([1], "float32", "local")
@@ -660,11 +670,11 @@ class 
TestMultiGroupMultiWarpPredicatedReduction(BaseCompare):
             red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
             t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 1, 32, 
32)
             red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
-            red_buf0_2[0] = T.tvm_warp_shuffle(mask_2[0], red_buf0_2[0], 32 * 
threadIdx_y, 32, 32)
             red_buf_staging_1 = T.Buffer((32,), data=red_buf_staging, 
scope="shared")
             if threadIdx_x % 32 == 0:
                 red_buf_staging_1[threadIdx_y * 16 + threadIdx_x // 32] = 
red_buf0_2[0]
             T.tvm_storage_sync("shared")
+            red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
             if threadIdx_x < 32:
                 red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
             mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
@@ -680,10 +690,12 @@ class 
TestMultiGroupMultiWarpPredicatedReduction(BaseCompare):
             red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
             t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 
32)
             red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
-            red_buf0_3[0] = T.tvm_warp_shuffle(mask_3[0], red_buf0_3[0], 16 * 
threadIdx_y, 32, 32)
+            if threadIdx_x == 0:
+                red_result_1[0] = red_buf0_3[0]
+            T.tvm_storage_sync("shared")
         if threadIdx_x == 0:
             B_1 = T.Buffer((2,), data=B.data)
-            B_1[threadIdx_y] = red_buf0_3[0]
+            B_1[threadIdx_y] = red_result_1[0]
 
 
 if __name__ == "__main__":

Reply via email to