This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e13110f  Support sub warp reduction for CUDA target. (#10207)
e13110f is described below

commit e13110f217dfdce04566d62999d07fdcd1ed1b72
Author: Zihao Ye <[email protected]>
AuthorDate: Mon Feb 14 19:41:35 2022 -0800

    Support sub warp reduction for CUDA target. (#10207)
    
    * upd
    
    * upd
    
    * upd
    
    * lint
    
    * fix
    
    * upd docstring
    
    * upd
---
 src/tir/transforms/lower_thread_allreduce.cc       | 101 ++++++++++++++-------
 .../python/unittest/test_subwarp_reduction_cuda.py |  68 ++++++++++++++
 2 files changed, 138 insertions(+), 31 deletions(-)

diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index 6f7c09c..1c6aa16 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -218,6 +218,33 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     int reduce_extent, group_extent;
     PrimExpr reduce_index = FlattenThread(vred, &reduce_extent);
     PrimExpr group_index = FlattenThread(vpar, &group_extent);
+
+    // the longest contiguous reduce extent after flattening
+    int contiguous_reduce_extent = 1;
+    std::vector<std::tuple<int, int, bool>> block_threads;  // tuple(dim_index, extent, is_reduce)
+    for (const ThreadEntry& thr : vred) {
+      if (thr.scope.rank == 1) {  // threadIdx
+        block_threads.emplace_back(thr.scope.dim_index, thr.extent, true);
+      }
+    }
+    for (const ThreadEntry& thr : vpar) {
+      if (thr.scope.rank == 1) {  // threadIdx
+        block_threads.emplace_back(thr.scope.dim_index, thr.extent, false);
+      }
+    }
+    // sort according to dim_index
+    std::sort(block_threads.begin(), block_threads.end());
+    for (auto&& thr_attr : block_threads) {
+      int dim_index, extent;
+      bool is_reduce;
+      std::tie(dim_index, extent, is_reduce) = thr_attr;
+      if (is_reduce) {
+        contiguous_reduce_extent *= extent;
+      } else {
+        break;
+      }
+    }
+
     std::vector<Stmt> seq;
     std::vector<Var> shared_bufs(size);
     std::vector<Stmt> local_vars;
@@ -238,9 +265,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     // broadcast results from lane 0 to all other lanes and store
     // the final reduction result to the proper location.
     //
-    if (is_warp_reduction(types)) {
-      // TODO(tvm-team) sub-warp reduction support.
-      ICHECK_EQ(reduce_extent, warp_size_) << "not a warp reduction";
+    if (is_warp_reduction(types, group_extent, reduce_extent, contiguous_reduce_extent)) {
+      ICHECK_LE(reduce_extent, warp_size_) << "not a warp reduction";
       //
       // This is the index to the reduction variable, one reduction
       // variable per warp. Local scope seems easier to reason without
@@ -269,6 +295,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       {
         PrimExpr pred = const_true(1);
         PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
+        if (group_extent > 1) {
+          mask = mask & (((1 << reduce_extent) - 1) << (reduce_extent * group_index));
+        }
         seq.emplace_back(Store(mask_var, mask, index, pred));
         // Push allocation with an empty body. Later this will be fixed
         // when the entire body is ready.
@@ -277,7 +306,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       }
 
       // Emit reductions within a warp.
-      for (int offset = warp_size_ / 2; offset > 0; offset /= 2) {
+      int start_offset = 1;
+      while (start_offset * 2 < reduce_extent) {
+        start_offset *= 2;
+      }
+      for (int offset = start_offset; offset > 0; offset /= 2) {
         // Load reduction values, no synchronization needed.
         Array<PrimExpr> a, b;
         for (size_t i = 0; i < size; ++i) {
@@ -323,13 +356,14 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
 
       // Broadcast the reduction result from lane 0 to all other lanes.
       // This avoids to emit predicated stores, as all threads are
-      // uniformmly writting the same result.
+      // uniformly writing the same result.
       //
       for (size_t i = 0; i < size; ++i) {
         Var var = shared_bufs[i];
         PrimExpr pred = const_true(types[i].lanes());
         PrimExpr val = Load(types[i], var, index, pred);
-        PrimExpr splat = WarpShuffle(builtin::tvm_warp_shuffle(), mask_var, val, 0);
+        PrimExpr splat =
+            WarpShuffle(builtin::tvm_warp_shuffle(), mask_var, val, reduce_extent * group_index);
         seq.push_back(Store(var, splat, index, pred));
       }
 
@@ -346,7 +380,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
         warp_allocs_.insert(node.get());
       }
     } else {
-      int threadx_extent = 1;
       if (reduce_extent == 1) {
         // special case, no reduction is needed.
         std::vector<Stmt> stores(size);
@@ -357,10 +390,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
         }
         return SeqStmt::Flatten(stores);
       }
-      // Whether the threadIdx.x is involved in reduction.
-      if (vred[0].scope.dim_index == 0) {
-        threadx_extent = vred[0].extent;
-      }
       // This sync is necessary because there might be incomplete read of
       // previous iteration on the same buffer.
       seq.emplace_back(SyncThread("shared"));
@@ -372,7 +401,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       }
       seq.emplace_back(SyncThread("shared"));
       seq.emplace_back(MakeBufAllreduce(combiner, types, shared_bufs, reduce_index, group_index,
-                                        reduce_extent, threadx_extent));
+                                        reduce_extent, group_extent, contiguous_reduce_extent));
       for (size_t idx = 0; idx < size; ++idx) {
         ICHECK(!load_remap_.count(buffers[idx]));
         PrimExpr pred = const_true(types[idx].lanes());
@@ -402,7 +431,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   // make allreduce.
   Stmt MakeBufAllreduce(const CommReducerNode* combiner, const std::vector<DataType>& types,
                         const Array<Var>& shared_bufs, PrimExpr reduce_index, PrimExpr group_index,
-                        int reduce_extent, int threadx_extent) {
+                        int reduce_extent, int group_extent, int contiguous_reduce_extent) {
     // Get next power of two
     int reduce_align = 1;
     while (reduce_extent > reduce_align) {
@@ -444,9 +473,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       seq.emplace_back(IfThenElse(cond, freduce(reduce_align)));
       seq.emplace_back(SyncThread("shared"));
     }
-    ICHECK(threadx_extent >= 1 && warp_size_ >= 1);
+
     // normal synchronization
-    while (reduce_align > threadx_extent || reduce_align > warp_size_) {
+    bool warp_align = group_extent == 1 || contiguous_reduce_extent % warp_size_ == 0;
+    while (reduce_align > contiguous_reduce_extent || reduce_align > warp_size_ || !warp_align) {
+      if (reduce_align == 1) {
+        break;
+      }
       reduce_align = reduce_align >> 1;
       PrimExpr cond = reduce_index < reduce_align;
       seq.emplace_back(IfThenElse(cond, freduce(reduce_align)));
@@ -534,22 +567,21 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   }
 
   // Emit warp shuffle  calls.
-  PrimExpr WarpShuffle(const Op& op, Var mask_var, PrimExpr val, int delta_or_lane) {
+  PrimExpr WarpShuffle(const Op& op, Var mask_var, PrimExpr val, PrimExpr delta_or_lane) {
     PrimExpr pred = const_true(1);
     PrimExpr index(0);
     PrimExpr mask = Load(DataType::UInt(32), mask_var, index, pred);
     PrimExpr width = IntImm(DataType::Int(32), warp_size_);
-    Array<PrimExpr> args{mask, val, IntImm(DataType::Int(32), delta_or_lane), width, width};
+    Array<PrimExpr> args{mask, val, delta_or_lane, width, width};
     return Call(val.dtype(), op, args);
   }
 
-  // Check if this is a reduction on threadIdx.x and its extent matches
-  // the warp size.
+  // Check if we can use warp level reduction.
   //
-  // TODO(tvm-team) reduction with a sub-warp of 8 or 16 threads.
   // Note: The ROCm backend will only have warp reductions for now.
   // Also, the warp/wavefront size differs (64 on rocm, 32 on cuda).
-  bool is_warp_reduction(const std::vector<DataType>& types) const {
+  bool is_warp_reduction(const std::vector<DataType>& types, int group_extent, int reduce_extent,
+                         int contiguous_reduce_extent) const {
     // Only cuda target supports warp reductions.
     if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm")) return false;
 
@@ -575,18 +607,25 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       return false;
     }
 
-    const AttrStmtNode* op = thread_extents_.back();
-    DCHECK_EQ(op->attr_key, attr::thread_extent);
-
-    IterVar iv = Downcast<IterVar>(op->node);
-    ThreadEntry e;
-    e.scope = runtime::ThreadScope::Create(iv->thread_tag);
-    e.extent = 0;
-    if (auto ptr = op->value.as<IntImmNode>()) {
-      e.extent = static_cast<int>(ptr->value);
+    // reduce region must be contiguous.
+    if (contiguous_reduce_extent != reduce_extent) {
+      return false;
     }
 
-    return e.extent == warp_size_ && e.scope.dim_index == 0 && e.scope.rank == 1;
+    // whether reduce_extent and group_extent are valid for warp reduction.
+    if (target_->kind->name == "rocm") {
+      return reduce_extent == warp_size_;
+    } else {  // target_->kind->name == "cuda"
+      if (reduce_extent == 1) {
+        return false;  // no need to warp reduce
+      } else {
+        if (warp_size_ % reduce_extent == 0) {
+          return true;  // warp size is multiple of reduce extent
+        } else {
+          return group_extent == 1 && reduce_extent <= warp_size_;
+        }
+      }
+    }
   }
 
   // The target.
diff --git a/tests/python/unittest/test_subwarp_reduction_cuda.py b/tests/python/unittest/test_subwarp_reduction_cuda.py
new file mode 100644
index 0000000..8778c75
--- /dev/null
+++ b/tests/python/unittest/test_subwarp_reduction_cuda.py
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.testing
+import numpy as np
+from tvm.script import tir as T
+
+
+@T.prim_func
+def reduce(a: T.handle, b: T.handle, d1: T.int32, d2: T.int32, d3: T.int32) -> None:
+    A = T.match_buffer(a, [1, d1, d2, d3])
+    B = T.match_buffer(b, [1, d1, d2])
+
+    for i, j, k, l in T.grid(1, d1, d2, d3):
+        with T.block("reduce"):
+            vi, vj, vk, vl = T.axis.remap("SSSR", [i, j, k, l])
+            with T.init():
+                B[vi, vj, vk] = 0.0
+            B[vi, vj, vk] = B[vi, vj, vk] + A[vi, vj, vk, vl]
+
+
+@tvm.testing.requires_gpu
+@tvm.testing.requires_cuda
+def test_cuda_subwarp_reduction():
+    def check(d1: int, d2: int, d3: int):
+        _, _, _d1, _d2, _d3 = reduce.params
+        mod = reduce.specialize({_d1: d1, _d2: d2, _d3: d3})
+        sch = tvm.tir.Schedule(mod)
+        blk = sch.get_block("reduce")
+        i, j, k, l = sch.get_loops(blk)
+        sch.bind(i, "blockIdx.x")
+        sch.bind(j, "threadIdx.z")
+        sch.bind(k, "threadIdx.y")
+        sch.bind(l, "threadIdx.x")
+        f = tvm.build(sch.mod["main"], target="cuda")
+
+        # prepare input and output array
+        a_np = np.random.rand(1, d1, d2, d3).astype("float32")
+        b_np = a_np.sum(axis=-1).astype("float32")
+        a = tvm.nd.array(a_np, tvm.cuda(0))
+        b = tvm.nd.array(np.zeros_like(b_np), tvm.cuda(0))
+
+        # launch kernel
+        f(a, b)
+        tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-6, atol=1e-6)
+
+    for d1 in range(1, 5):
+        for d2 in range(1, 5):
+            for d3 in range(2, 33):
+                check(d1, d2, d3)
+
+
+if __name__ == "__main__":
+    test_cuda_subwarp_reduction()

Reply via email to