This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3f69ed43a8 [TIR] Finer predicate handling in cross-thread reduction
(#15374)
3f69ed43a8 is described below
commit 3f69ed43a8723e91abdbbc928a5826aa9461a953
Author: Ruihang Lai <[email protected]>
AuthorDate: Sat Jul 22 06:12:55 2023 -0700
[TIR] Finer predicate handling in cross-thread reduction (#15374)
This PR fixes the predicate handling logic of the cross-thread
reduction lowering pass.
For the cross-thread reduction write-back block, prior to this PR, its
predicate was the conjunction of `t == 0` for each reduction thread dim
of the cross-thread reduction. This is problematic when the write-back
buffer is stored in local memory: each thread is supposed to hold its
own copy of the final value, but the value is only stored by the first
thread. The old predicate also dropped the original reduction block's
predicate on spatial loop vars, so the write-back could go out of
bounds. In this PR, the predicate is changed to the conjunction of the
clauses from two parts:
* each clause of the original reduction block's predicate that does not
involve any reduction loop var (e.g., bounds on spatial loop vars),
* `t == 0` for each reduction thread dim, **only when the write-back
buffer is global or shared**.
The first part ensures that the write-back does not go out of bounds.
The second part ensures that when the write-back buffer is local, every
thread gets the value, and when the write-back buffer is non-local, only
one thread writes the value out.
Meanwhile, this PR makes the cross-thread broadcasting detection aware
of the storage scope of the buffers written by the consumer block.
Specifically, consider each block that consumes a buffer produced by a
cross-thread reduction under the same kernel (i.e., the same set of
`blockIdx`) as the cross-thread reduction block. When this consumer
block writes to local memory, it is not treated as a broadcast and no
predicate is added to it. Otherwise, the predicate is added according
to the broadcasting handling introduced by #15192.
---
src/tir/transforms/lower_cross_thread_reduction.cc | 120 +++++++++++++++---
...t_tir_transform_lower_cross_thread_reduction.py | 135 ++++++++++++++++++++-
2 files changed, 237 insertions(+), 18 deletions(-)
diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc
b/src/tir/transforms/lower_cross_thread_reduction.cc
index 79dbfb2a02..413894264e 100644
--- a/src/tir/transforms/lower_cross_thread_reduction.cc
+++ b/src/tir/transforms/lower_cross_thread_reduction.cc
@@ -426,12 +426,48 @@ Stmt TransformReductionBlock(const BlockRealizeNode*
realize, //
BufferStore(wb_buffers[i], BufferLoad(ct_buffers[i], {Integer(0)}),
wb_indices));
wb_regions.push_back(BufferRegion(wb_buffers[i], region));
}
+
+ // Construct the predicate of the write-back block. It is the conjunction
of
+ // - each predicate clause of the original block which contains spatial
loop var, and
+ // - `t == 0` for each reduction thread dim when the write-back buffer is
not local.
PrimExpr wb_predicate = const_true();
- for (const ForNode* loop : reduction_loops) {
- if (loop->thread_binding.defined()) {
- wb_predicate = wb_predicate && (loop->loop_var ==
IntImm(loop->loop_var->dtype, 0));
+ std::unordered_set<const VarNode*> reduction_loop_vars;
+ reduction_loop_vars.reserve(reduction_loops.size());
+ for (const ForNode* reduction_loop : reduction_loops) {
+ reduction_loop_vars.insert(reduction_loop->loop_var.get());
+ }
+ PostOrderVisit(realize->predicate, [&wb_predicate,
&reduction_loop_vars](const ObjectRef& obj) {
+ if (const auto* and_node = obj.as<AndNode>()) {
+ Array<PrimExpr> sub_exprs = {and_node->a, and_node->b};
+ for (PrimExpr sub_expr : sub_exprs) {
+ if (sub_expr->IsInstance<AndNode>()) {
+ continue;
+ }
+ bool is_reduction = [sub_expr, &reduction_loop_vars]() {
+ Array<Var> vars = UndefinedVars(sub_expr);
+ for (Var var : vars) {
+ if (reduction_loop_vars.find(var.get()) !=
reduction_loop_vars.end()) {
+ return true;
+ }
+ }
+ return false;
+ }();
+ if (!is_reduction) {
+ wb_predicate = wb_predicate && sub_expr;
+ }
+ }
+ return true;
+ }
+ return false;
+ });
+ if (wb_buffers[0].scope() != "local") {
+ for (const ForNode* loop : reduction_loops) {
+ if (loop->thread_binding.defined()) {
+ wb_predicate = wb_predicate && (loop->loop_var ==
IntImm(loop->loop_var->dtype, 0));
+ }
}
}
+
stmts.push_back(BlockRealize(
/*iter_values=*/std::move(bindings),
/*predicate=*/wb_predicate,
@@ -498,21 +534,45 @@ class CrossThreadReductionTransformer : public
StmtMutator {
}
// Check if the input block needs thread broadcast rewrite.
- // One block needs broadcast rewrite when there exists one or more thread
- // vars which vars free variables to this block.
+ // One block needs broadcast rewrite when
+ // 1. it consumes a buffer produced by cross-thread reduction under
+ // the same kernel (i.e., same group of blockIdx),
+ // 2. it writes to non-local memory,
+ // 3. at least one of the reduction thread vars of the cross-thread reduction
+ // is free to this block (i.e., not bound to the block).
std::vector<std::pair<ThreadScope, Range>> NeedCrossThreadBroadcast(
const BlockRealizeNode* realize) {
- std::unordered_map<ThreadScope, Range, ThreadScopeHash, ThreadScopeEqual>
unbound_thread2range =
- thread2range_;
+ Block block = realize->block;
+
+ // If the block writes to local memory, no rewrite is needed.
+ for (BufferRegion write_region : block->writes) {
+ if (write_region->buffer.scope() == "local") {
+ return {};
+ }
+ }
+
+ // Find out the reduction threads for the read-buffers which are produced
by
+ // cross-thread reduction.
+ std::unordered_map<ThreadScope, Range, ThreadScopeHash, ThreadScopeEqual>
thread2range;
+ for (BufferRegion read_region : block->reads) {
+ auto buf_it = crt_buf2threads_.find(read_region->buffer.get());
+ if (buf_it == crt_buf2threads_.end()) {
+ continue;
+ }
+ for (auto [scope, range] : buf_it->second) {
+ thread2range[scope] = range;
+ }
+ }
+
+ // Erase those threads which are not free to this block.
for (const ForNode* loop : loop_stack_) {
if (loop->thread_binding.defined()) {
ThreadScope scope =
ThreadScope::Create(loop->thread_binding.value()->thread_tag);
- unbound_thread2range.erase(scope);
+ thread2range.erase(scope);
}
}
-
std::vector<std::pair<ThreadScope, Range>> unbound_thread2range_list;
- for (auto [scope, range] : unbound_thread2range) {
+ for (auto [scope, range] : thread2range) {
unbound_thread2range_list.emplace_back(scope, range);
}
return unbound_thread2range_list;
@@ -582,13 +642,28 @@ class CrossThreadReductionTransformer : public
StmtMutator {
std::tie(reducer, combiner_lhs, combiner_rhs) =
GetReducerAndCombinerLhsRhs(NullOpt, init_values, updates);
+ // Condition 4. All reduction buffers should be all local or all non-local.
+ int is_local_buf = -1;
Array<Buffer> reduction_buffers;
reduction_buffers.reserve(updates.size());
for (const BufferStore& buf_store : updates) {
reduction_buffers.push_back(buf_store->buffer);
+ if (buf_store->buffer.scope() == "local") {
+ CHECK_NE(is_local_buf, 0)
+ << "ValueError: Cross-thread reduction requires all reduction
buffers to be all "
+ "local or all non-local. However, here some buffer is local
while some buffer is "
+ "shared or global.";
+ is_local_buf = 1;
+ } else {
+ CHECK_NE(is_local_buf, 1)
+ << "ValueError: Cross-thread reduction requires all reduction
buffers to be all "
+ "local or all non-local. However, here some buffer is local
while some buffer is "
+ "shared or global.";
+ is_local_buf = 0;
+ }
}
- // Condition 4. The block should be the last block under the first
reduction-related loop.
+ // Condition 5. The block should be the last block under the first
reduction-related loop.
bool visit = false;
PreOrderVisit(GetRef<For>(reduction_loops[0]), [block, &visit](const
ObjectRef& obj) {
if (const auto* realize = obj.as<BlockRealizeNode>()) {
@@ -631,8 +706,6 @@ class CrossThreadReductionTransformer : public StmtMutator {
if (scope.rank == 1 && scope.dim_index >= 0) {
is_thread_idx = true;
++thread_idx_depth;
- thread2range_[scope] = Range::FromMinExtent(loop->min, loop->extent);
- thread_loop_var2scope_[loop->loop_var.get()] = scope;
} else if (scope.rank == 0) {
is_block_idx = true;
++block_idx_depth;
@@ -649,7 +722,7 @@ class CrossThreadReductionTransformer : public StmtMutator {
--block_idx_depth;
}
if (is_block_idx || (is_thread_idx && thread_idx_depth == 0 &&
block_idx_depth == 0)) {
- thread2range_.clear();
+ crt_buf2threads_.clear();
}
// Replace `result` with the pre-stored result if `loop` appears as a key
in `loop2new_stmt_`.
@@ -716,6 +789,21 @@ class CrossThreadReductionTransformer : public StmtMutator
{
loop2new_stmt_[reduction_loops[0]] =
TransformReductionBlock(realize, it_buffers, ct_buffers,
reduction_buffers, wb_indices,
reducer, combiner_rhs, reduction_loops);
+
+ // Step 5. Record the reduction thread dims for the write-back buffers.
+ // The information is used for consumer block broadcasting detection.
+ std::vector<std::pair<ThreadScope, Range>> reduction_threads;
+ reduction_threads.reserve(reduction_loops.size());
+ for (const ForNode* loop : reduction_loops) {
+ if (loop->thread_binding.defined()) {
+ reduction_threads.emplace_back(
+ ThreadScope::Create(loop->thread_binding.value()->thread_tag),
+ Range::FromMinExtent(loop->min, loop->extent));
+ }
+ }
+ for (const Buffer& reduction_buf : reduction_buffers) {
+ crt_buf2threads_[reduction_buf.get()] = reduction_threads;
+ }
}
Stmt MakeCrossThreadBroadcast(
@@ -792,8 +880,8 @@ class CrossThreadReductionTransformer : public StmtMutator {
int block_idx_depth = 0;
int thread_idx_depth = 0;
- std::unordered_map<ThreadScope, Range, ThreadScopeHash, ThreadScopeEqual>
thread2range_;
- std::unordered_map<const VarNode*, ThreadScope> thread_loop_var2scope_;
+ std::unordered_map<const BufferNode*, std::vector<std::pair<ThreadScope,
Range>>>
+ crt_buf2threads_;
};
PrimFunc LowerCrossThreadReduction(PrimFunc f) {
diff --git
a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
index 6162233b65..f42f8ca85f 100644
--- a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
+++ b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
@@ -496,6 +496,64 @@ def lowered_single_reduction_loop_with_block_predicate(
)
[email protected]_func
+def spatial_reduction_loop_predicate(A: T.Buffer((2, 32), "float32"), B:
T.Buffer((2,), "float32")):
+ for i_0 in range(1):
+ for i_1 in T.thread_binding(16, thread="threadIdx.y"):
+ for k_0 in range(1):
+ for k_1 in T.thread_binding(64, thread="threadIdx.x"):
+ with T.block("block"):
+ vi = T.axis.spatial(2, i_0 * 16 + i_1)
+ vk = T.axis.reduce(32, k_0 * 64 + k_1)
+ T.where(i_0 * 16 + i_1 < 2 and k_0 * 64 + k_1 < 32)
+ T.reads(A[vi, vk])
+ T.writes(B[vi])
+ with T.init():
+ B[vi] = T.float32(0)
+ B[vi] = B[vi] + A[vi, vk]
+
+
[email protected]_func
+def lowered_reduction_spatial_loop_predicate(
+ A: T.Buffer((2, 32), "float32"), B: T.Buffer((2,), "float32")
+):
+ cross_thread_B = T.alloc_buffer((1,), strides=(1,), scope="local")
+ in_thread_B = T.alloc_buffer((1,), strides=(1,), scope="local")
+ for i_0 in range(1):
+ for i_1 in T.thread_binding(16, thread="threadIdx.y"):
+ for k_1 in T.thread_binding(64, thread="threadIdx.x"):
+ with T.block("block_in_thread_init"):
+ T.reads()
+ T.writes(in_thread_B[0])
+ in_thread_B[0] = T.float32(0)
+ for k_0 in range(1):
+ with T.block("block_in_thread"):
+ vi = T.axis.spatial(2, i_0 * 16 + i_1)
+ vk = T.axis.reduce(32, k_0 * 64 + k_1)
+ T.where(i_0 * 16 + i_1 < 2 and k_0 * 64 + k_1 < 32)
+ T.reads(A[vi, vk])
+ T.writes(in_thread_B[0])
+ in_thread_B[0] = in_thread_B[0] + A[vi, vk]
+ with T.block("block_cross_thread"):
+ T.reads(in_thread_B[0])
+ T.writes(cross_thread_B[0])
+ T.attr(
+ T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+ "reduce_scope",
+ T.reinterpret("handle", T.uint64(0)),
+ )
+ T.tvm_thread_allreduce(
+ T.uint32(1), in_thread_B[0], T.bool(True),
cross_thread_B[0], k_1
+ )
+ k_0 = T.int32()
+ with T.block("block_write_back"):
+ vi = T.axis.spatial(2, i_0 * 16 + i_1)
+ T.where(i_0 * 16 + i_1 < 2 and k_1 == 0)
+ T.reads(cross_thread_B[0])
+ T.writes(B[vi])
+ B[vi] = cross_thread_B[0]
+
+
@T.prim_func
def single_reduction_loop_with_tensorize(
input_A: T.Buffer((1, 64, 7, 7, 32), "uint8"),
@@ -1315,7 +1373,6 @@ def lowered_thread_broadcast_1(A: T.Buffer((256, 256),
"float32"), B: T.Buffer((
)
with T.block("sum_write_back"):
vi = T.axis.spatial(256, i)
- T.where(k == 0)
T.reads(cross_thread_temp_local[0])
T.writes(temp_local[vi])
temp_local[vi] = cross_thread_temp_local[0]
@@ -1428,7 +1485,7 @@ def lowered_thread_broadcast_2(lv1605:
T.Buffer((T.int64(1), T.int64(32), T.int6
with T.block("NT_matmul_write_back"):
v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n)
v1 = T.axis.spatial(n, ax0_ax1_fused % n)
- T.where(ax0_fused == T.int64(0))
+ T.where(T.int64(0) <= ax0_ax1_fused // n and ax0_ax1_fused
// n < T.int64(32) and T.int64(0) <= ax0_ax1_fused % n and ax0_ax1_fused % n <
n)
T.reads(cross_thread_var_NT_matmul_intermediate_local[0])
T.writes(var_NT_matmul_intermediate_local[T.int64(0), v0,
T.int64(0), v1])
var_NT_matmul_intermediate_local[T.int64(0), v0,
T.int64(0), v1] = cross_thread_var_NT_matmul_intermediate_local[0]
@@ -1442,6 +1499,72 @@ def lowered_thread_broadcast_2(lv1605:
T.Buffer((T.int64(1), T.int64(32), T.int6
var_compute_intermediate[T.int64(0), v0, T.int64(0), v1] =
T.Cast("float32", T.min(T.max(var_NT_matmul_intermediate_local[T.int64(0), v0,
T.int64(0), v1] * T.float16(0.088397790055248615), T.float16(-65504)),
lv1582[T.int64(0), T.int64(0), T.int64(0), v1]))
# fmt: on
+
[email protected]_func
+def no_thread_broadcast(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256,
256), "float32")):
+ temp_1_local = T.alloc_buffer((256,), scope="local")
+ temp_2_local = T.alloc_buffer((1,), scope="local")
+ for i in T.thread_binding(256, thread="blockIdx.x"):
+ for k in T.thread_binding(256, thread="threadIdx.x"):
+ with T.block("sum"):
+ vi, vk = T.axis.remap("SR", [i, k])
+ T.reads(A[vi, vk])
+ T.writes(temp_1_local[vi])
+ with T.init():
+ temp_1_local[vi] = T.float32(0)
+ temp_1_local[vi] = temp_1_local[vi] + A[vi, vk]
+ with T.block("add"):
+ vi = T.axis.spatial(256, i)
+ T.reads(temp_1_local[vi])
+ T.writes(temp_2_local[0])
+ temp_2_local[0] = temp_1_local[vi] + T.float32(1)
+ for j in T.thread_binding(256, thread="threadIdx.x"):
+ with T.block("sum"):
+ vi, vj = T.axis.remap("SR", [i, j])
+ T.reads(temp_2_local[0])
+ T.writes(B[vi, vj])
+ B[vi, vj] = A[vi, vj] + temp_2_local[0]
+
+
[email protected]_func
+def lowered_no_thread_broadcast(
+ A: T.Buffer((256, 256), "float32"), B: T.Buffer((256, 256), "float32")
+):
+ temp_1_local = T.alloc_buffer((256,), scope="local")
+ temp_2_local = T.alloc_buffer((1,), scope="local")
+ cross_thread_temp_1_local = T.alloc_buffer((1,), strides=(1,),
scope="local")
+ for i in T.thread_binding(256, thread="blockIdx.x"):
+ for k in T.thread_binding(256, thread="threadIdx.x"):
+ with T.block("sum_cross_thread"):
+ vi, vk = T.axis.remap("SR", [i, k])
+ T.reads(A[vi, vk])
+ T.writes(cross_thread_temp_1_local[0])
+ T.attr(
+ T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+ "reduce_scope",
+ T.reinterpret("handle", T.uint64(0)),
+ )
+ T.tvm_thread_allreduce(
+ T.uint32(1), A[vi, vk], T.bool(True),
cross_thread_temp_1_local[0], k
+ )
+ with T.block("sum_write_back"):
+ vi = T.axis.spatial(256, i)
+ T.reads(cross_thread_temp_1_local[0])
+ T.writes(temp_1_local[vi])
+ temp_1_local[vi] = cross_thread_temp_1_local[0]
+ with T.block("add"):
+ vi = T.axis.spatial(256, i)
+ T.reads(temp_1_local[vi])
+ T.writes(temp_2_local[0])
+ temp_2_local[0] = temp_1_local[vi] + T.float32(1)
+ for j in T.thread_binding(256, thread="threadIdx.x"):
+ with T.block("sum"):
+ vi, vj = T.axis.remap("SR", [i, j])
+ T.reads(temp_2_local[0])
+ T.writes(B[vi, vj])
+ B[vi, vj] = A[vi, vj] + temp_2_local[0]
+
+
# pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
@@ -1472,6 +1595,10 @@ def test_single_reduction_loop_with_block_predicate():
)
+def test_spatial_reduction_loop_predicate():
+ _check(spatial_reduction_loop_predicate,
lowered_reduction_spatial_loop_predicate)
+
+
def test_single_reduction_loop_with_tensorize():
_check(
single_reduction_loop_with_tensorize,
@@ -1534,6 +1661,10 @@ def test_thread_broadcast_rewrite_2():
_check(thread_broadcast_2, lowered_thread_broadcast_2)
+def test_no_thread_broadcast_rewrite():
+ _check(no_thread_broadcast, lowered_no_thread_broadcast)
+
+
def test_lower_te():
a = te.placeholder((32, 2, 2))
k1 = te.reduce_axis((0, 2), "k1")