This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e13110f Support sub warp reduction for CUDA target. (#10207)
e13110f is described below
commit e13110f217dfdce04566d62999d07fdcd1ed1b72
Author: Zihao Ye <[email protected]>
AuthorDate: Mon Feb 14 19:41:35 2022 -0800
Support sub warp reduction for CUDA target. (#10207)
* upd
* upd
* upd
* lint
* fix
* upd docstring
* upd
---
src/tir/transforms/lower_thread_allreduce.cc | 101 ++++++++++++++-------
.../python/unittest/test_subwarp_reduction_cuda.py | 68 ++++++++++++++
2 files changed, 138 insertions(+), 31 deletions(-)
diff --git a/src/tir/transforms/lower_thread_allreduce.cc
b/src/tir/transforms/lower_thread_allreduce.cc
index 6f7c09c..1c6aa16 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -218,6 +218,33 @@ class ThreadAllreduceBuilder final : public
StmtExprMutator {
int reduce_extent, group_extent;
PrimExpr reduce_index = FlattenThread(vred, &reduce_extent);
PrimExpr group_index = FlattenThread(vpar, &group_extent);
+
+ // the longest contiguous reduce extent after flattening
+ int contiguous_reduce_extent = 1;
+ std::vector<std::tuple<int, int, bool>> block_threads; //
tuple(dim_index, extent, is_reduce)
+ for (const ThreadEntry& thr : vred) {
+ if (thr.scope.rank == 1) { // threadIdx
+ block_threads.emplace_back(thr.scope.dim_index, thr.extent, true);
+ }
+ }
+ for (const ThreadEntry& thr : vpar) {
+ if (thr.scope.rank == 1) { // threadIdx
+ block_threads.emplace_back(thr.scope.dim_index, thr.extent, false);
+ }
+ }
+ // sort according to dim_index
+ std::sort(block_threads.begin(), block_threads.end());
+ for (auto&& thr_attr : block_threads) {
+ int dim_index, extent;
+ bool is_reduce;
+ std::tie(dim_index, extent, is_reduce) = thr_attr;
+ if (is_reduce) {
+ contiguous_reduce_extent *= extent;
+ } else {
+ break;
+ }
+ }
+
std::vector<Stmt> seq;
std::vector<Var> shared_bufs(size);
std::vector<Stmt> local_vars;
@@ -238,9 +265,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator
{
// broadcast results from lane 0 to all other lanes and store
// the final reduction result to the proper location.
//
- if (is_warp_reduction(types)) {
- // TODO(tvm-team) sub-warp reduction support.
- ICHECK_EQ(reduce_extent, warp_size_) << "not a warp reduction";
+ if (is_warp_reduction(types, group_extent, reduce_extent,
contiguous_reduce_extent)) {
+ ICHECK_LE(reduce_extent, warp_size_) << "not a warp reduction";
//
// This is the index to the reduction variable, one reduction
// variable per warp. Local scope seems easier to reason without
@@ -269,6 +295,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator
{
{
PrimExpr pred = const_true(1);
PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
+ if (group_extent > 1) {
+ mask = mask & (((1 << reduce_extent) - 1) << (reduce_extent *
group_index));
+ }
seq.emplace_back(Store(mask_var, mask, index, pred));
// Push allocation with an empty body. Later this will be fixed
// when the entire body is ready.
@@ -277,7 +306,11 @@ class ThreadAllreduceBuilder final : public
StmtExprMutator {
}
// Emit reductions within a warp.
- for (int offset = warp_size_ / 2; offset > 0; offset /= 2) {
+ int start_offset = 1;
+ while (start_offset * 2 < reduce_extent) {
+ start_offset *= 2;
+ }
+ for (int offset = start_offset; offset > 0; offset /= 2) {
// Load reduction values, no synchronization needed.
Array<PrimExpr> a, b;
for (size_t i = 0; i < size; ++i) {
@@ -323,13 +356,14 @@ class ThreadAllreduceBuilder final : public
StmtExprMutator {
// Broadcast the reduction result from lane 0 to all other lanes.
// This avoids to emit predicated stores, as all threads are
- // uniformmly writting the same result.
- // uniformly writing the same result.
//
for (size_t i = 0; i < size; ++i) {
Var var = shared_bufs[i];
PrimExpr pred = const_true(types[i].lanes());
PrimExpr val = Load(types[i], var, index, pred);
- PrimExpr splat = WarpShuffle(builtin::tvm_warp_shuffle(), mask_var,
val, 0);
+ PrimExpr splat =
+ WarpShuffle(builtin::tvm_warp_shuffle(), mask_var, val,
reduce_extent * group_index);
seq.push_back(Store(var, splat, index, pred));
}
@@ -346,7 +380,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator
{
warp_allocs_.insert(node.get());
}
} else {
- int threadx_extent = 1;
if (reduce_extent == 1) {
// special case, no reduction is needed.
std::vector<Stmt> stores(size);
@@ -357,10 +390,6 @@ class ThreadAllreduceBuilder final : public
StmtExprMutator {
}
return SeqStmt::Flatten(stores);
}
- // Whether the threadIdx.x is involved in reduction.
- if (vred[0].scope.dim_index == 0) {
- threadx_extent = vred[0].extent;
- }
// This sync is necessary because there might be incomplete read of
// previous iteration on the same buffer.
seq.emplace_back(SyncThread("shared"));
@@ -372,7 +401,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator
{
}
seq.emplace_back(SyncThread("shared"));
seq.emplace_back(MakeBufAllreduce(combiner, types, shared_bufs,
reduce_index, group_index,
- reduce_extent, threadx_extent));
+ reduce_extent, group_extent,
contiguous_reduce_extent));
for (size_t idx = 0; idx < size; ++idx) {
ICHECK(!load_remap_.count(buffers[idx]));
PrimExpr pred = const_true(types[idx].lanes());
@@ -402,7 +431,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator
{
// make allreduce.
Stmt MakeBufAllreduce(const CommReducerNode* combiner, const
std::vector<DataType>& types,
const Array<Var>& shared_bufs, PrimExpr reduce_index,
PrimExpr group_index,
- int reduce_extent, int threadx_extent) {
+ int reduce_extent, int group_extent, int
contiguous_reduce_extent) {
// Get next power of two
int reduce_align = 1;
while (reduce_extent > reduce_align) {
@@ -444,9 +473,13 @@ class ThreadAllreduceBuilder final : public
StmtExprMutator {
seq.emplace_back(IfThenElse(cond, freduce(reduce_align)));
seq.emplace_back(SyncThread("shared"));
}
- ICHECK(threadx_extent >= 1 && warp_size_ >= 1);
+
// normal synchronization
- while (reduce_align > threadx_extent || reduce_align > warp_size_) {
+ bool warp_align = group_extent == 1 || contiguous_reduce_extent %
warp_size_ == 0;
+ while (reduce_align > contiguous_reduce_extent || reduce_align >
warp_size_ || !warp_align) {
+ if (reduce_align == 1) {
+ break;
+ }
reduce_align = reduce_align >> 1;
PrimExpr cond = reduce_index < reduce_align;
seq.emplace_back(IfThenElse(cond, freduce(reduce_align)));
@@ -534,22 +567,21 @@ class ThreadAllreduceBuilder final : public
StmtExprMutator {
}
// Emit warp shuffle calls.
- PrimExpr WarpShuffle(const Op& op, Var mask_var, PrimExpr val, int
delta_or_lane) {
+ PrimExpr WarpShuffle(const Op& op, Var mask_var, PrimExpr val, PrimExpr
delta_or_lane) {
PrimExpr pred = const_true(1);
PrimExpr index(0);
PrimExpr mask = Load(DataType::UInt(32), mask_var, index, pred);
PrimExpr width = IntImm(DataType::Int(32), warp_size_);
- Array<PrimExpr> args{mask, val, IntImm(DataType::Int(32), delta_or_lane),
width, width};
+ Array<PrimExpr> args{mask, val, delta_or_lane, width, width};
return Call(val.dtype(), op, args);
}
- // Check if this is a reduction on threadIdx.x and its extent matches
- // the warp size.
+ // Check if we can use warp level reduction.
//
- // TODO(tvm-team) reduction with a sub-warp of 8 or 16 threads.
// Note: The ROCm backend will only have warp reductions for now.
// Also, the warp/wavefront size differs (64 on rocm, 32 on cuda).
- bool is_warp_reduction(const std::vector<DataType>& types) const {
+ bool is_warp_reduction(const std::vector<DataType>& types, int group_extent,
int reduce_extent,
+ int contiguous_reduce_extent) const {
// Only cuda target supports warp reductions.
if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm"))
return false;
@@ -575,18 +607,25 @@ class ThreadAllreduceBuilder final : public
StmtExprMutator {
return false;
}
- const AttrStmtNode* op = thread_extents_.back();
- DCHECK_EQ(op->attr_key, attr::thread_extent);
-
- IterVar iv = Downcast<IterVar>(op->node);
- ThreadEntry e;
- e.scope = runtime::ThreadScope::Create(iv->thread_tag);
- e.extent = 0;
- if (auto ptr = op->value.as<IntImmNode>()) {
- e.extent = static_cast<int>(ptr->value);
+ // reduce region must be contiguous.
+ if (contiguous_reduce_extent != reduce_extent) {
+ return false;
}
- return e.extent == warp_size_ && e.scope.dim_index == 0 && e.scope.rank ==
1;
+ // whether reduce_extent and group_extent are valid for warp reduction.
+ if (target_->kind->name == "rocm") {
+ return reduce_extent == warp_size_;
+ } else { // target_->kind->name == "cuda"
+ if (reduce_extent == 1) {
+ return false; // no need to warp reduce
+ } else {
+ if (warp_size_ % reduce_extent == 0) {
+ return true; // warp size is multiple of reduce extent
+ } else {
+ return group_extent == 1 && reduce_extent <= warp_size_;
+ }
+ }
+ }
}
// The target.
diff --git a/tests/python/unittest/test_subwarp_reduction_cuda.py
b/tests/python/unittest/test_subwarp_reduction_cuda.py
new file mode 100644
index 0000000..8778c75
--- /dev/null
+++ b/tests/python/unittest/test_subwarp_reduction_cuda.py
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.testing
+import numpy as np
+from tvm.script import tir as T
+
+
[email protected]_func
+def reduce(a: T.handle, b: T.handle, d1: T.int32, d2: T.int32, d3: T.int32) ->
None:
+ A = T.match_buffer(a, [1, d1, d2, d3])
+ B = T.match_buffer(b, [1, d1, d2])
+
+ for i, j, k, l in T.grid(1, d1, d2, d3):
+ with T.block("reduce"):
+ vi, vj, vk, vl = T.axis.remap("SSSR", [i, j, k, l])
+ with T.init():
+ B[vi, vj, vk] = 0.0
+ B[vi, vj, vk] = B[vi, vj, vk] + A[vi, vj, vk, vl]
+
+
[email protected]_gpu
[email protected]_cuda
+def test_cuda_subwarp_reduction():
+ def check(d1: int, d2: int, d3: int):
+ _, _, _d1, _d2, _d3 = reduce.params
+ mod = reduce.specialize({_d1: d1, _d2: d2, _d3: d3})
+ sch = tvm.tir.Schedule(mod)
+ blk = sch.get_block("reduce")
+ i, j, k, l = sch.get_loops(blk)
+ sch.bind(i, "blockIdx.x")
+ sch.bind(j, "threadIdx.z")
+ sch.bind(k, "threadIdx.y")
+ sch.bind(l, "threadIdx.x")
+ f = tvm.build(sch.mod["main"], target="cuda")
+
+ # prepare input and output array
+ a_np = np.random.rand(1, d1, d2, d3).astype("float32")
+ b_np = a_np.sum(axis=-1).astype("float32")
+ a = tvm.nd.array(a_np, tvm.cuda(0))
+ b = tvm.nd.array(np.zeros_like(b_np), tvm.cuda(0))
+
+ # launch kernel
+ f(a, b)
+ tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-6, atol=1e-6)
+
+ for d1 in range(1, 5):
+ for d2 in range(1, 5):
+ for d3 in range(2, 33):
+ check(d1, d2, d3)
+
+
+if __name__ == "__main__":
+ test_cuda_subwarp_reduction()