This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new cd09ab64b5 [Runtime] Reorganize PagedKVCache attn kernel invocation
(#17237)
cd09ab64b5 is described below
commit cd09ab64b5ccf6ff0a96d887a968acd4602188a8
Author: Ruihang Lai <[email protected]>
AuthorDate: Sat Aug 3 20:01:01 2024 -0400
[Runtime] Reorganize PagedKVCache attn kernel invocation (#17237)
This PR reorganizes the attention kernel invocation logic in the
PagedKVCache, so that in cases of sequence fork, we can effectively
merge one ragged-prefill kernel and a decode kernel into a single
decode kernel.
---
src/relax/transform/fuse_ops.cc | 2 +-
src/runtime/relax_vm/paged_kv_cache.cc | 127 +++++++++++++++++----------------
2 files changed, 65 insertions(+), 64 deletions(-)
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index e791aeab06..85c739e083 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -646,7 +646,7 @@ class FunctionCreator : public ExprMutator {
return tvm::tir::UndefinedVars(prim_value->value).empty();
} else if (const auto* shape_expr = expr.as<ShapeExprNode>()) {
return std::all_of(shape_expr->values.begin(), shape_expr->values.end(),
- [this](const PrimExpr& e) { return
tvm::tir::UndefinedVars(e).empty(); });
+ [](const PrimExpr& e) { return
tvm::tir::UndefinedVars(e).empty(); });
}
return false;
}
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc
b/src/runtime/relax_vm/paged_kv_cache.cc
index 5aa1411ec1..cf5de97202 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -1535,7 +1535,7 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
CHECK_EQ(chunked_block_ids_arr[num_depths_ - 1].size(), cur_batch_size_);
}
- append_before_attn_ = !support_sliding_window_ && num_depths_ == 1 &&
use_decode_kernel_[0];
+ append_before_attn_ = !support_sliding_window_ &&
use_decode_kernel_.back();
if (NeedKernelBeginForward() && num_qo_heads_ / num_kv_heads_ >= 4) {
// When GQA group size is at least 4 and FlashInfer is enabled,
// we always use prefill kernel for better performance.
@@ -2220,39 +2220,33 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
return;
}
- if (append_before_attn_) {
- if (!support_sliding_window_) {
+ if (!append_before_attn_) {
+ if (is_chain_) {
+ f_attention_prefill_ragged_begin_forward_.value()(
+ temp_attn_workspace_[0],
cur_append_lengths_indptr_host_.as_ndarray(),
+ cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_,
num_qo_heads_,
+ num_kv_heads_, head_dim_, copy_stream_);
+ } else {
+ LOG(FATAL) << "Kernel BeginForward doesn't support tree attn.";
+ }
+ }
+ for (int d = 0; d < num_depths_; ++d) {
+ if (page_indices_on_depths_view_[d]->shape[0] == 0) {
+ continue;
+ }
+ CHECK(!support_sliding_window_) << "Kernel BeginForward doesn't support
sliding window.";
+ if (use_decode_kernel_[d]) {
f_attention_decode_begin_forward_.value()(
- /*depth=*/0, temp_attn_workspace_[1],
page_indptr_on_depths_host_[0].as_ndarray(),
- last_page_len_on_depths_host_[0].as_ndarray(), num_qo_heads_,
num_kv_heads_, head_dim_,
+ d, temp_attn_workspace_[d + 1],
page_indptr_on_depths_host_[d].as_ndarray(),
+ last_page_len_on_depths_host_[d].as_ndarray(), num_qo_heads_,
num_kv_heads_, head_dim_,
page_size_,
/*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_);
- }
- } else {
- f_attention_prefill_ragged_begin_forward_.value()(
- temp_attn_workspace_[0],
cur_append_lengths_indptr_host_.as_ndarray(),
- cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_,
num_qo_heads_,
- num_kv_heads_, head_dim_, copy_stream_);
- if (support_sliding_window_) {
- return;
- }
- for (int d = 0; d < num_depths_; ++d) {
- if (page_indices_on_depths_view_[d]->shape[0] == 0) {
- continue;
- }
- if (use_decode_kernel_[d]) {
- f_attention_decode_begin_forward_.value()(
- d, temp_attn_workspace_[d + 1],
page_indptr_on_depths_host_[d].as_ndarray(),
- last_page_len_on_depths_host_[d].as_ndarray(), num_qo_heads_,
num_kv_heads_,
- head_dim_, page_size_,
- /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_);
- } else {
- f_attention_prefill_begin_forward_.value()(
- /*depth=*/d, temp_attn_workspace_[d + 1],
qo_indptr_on_depths_host_[d].as_ndarray(),
- page_indptr_on_depths_host_[d].as_ndarray(),
- static_cast<int>(page_indptr_on_depths_host_[d].size()) - 1,
num_qo_heads_,
- num_kv_heads_, head_dim_, page_size_, copy_stream_);
- }
+ } else {
+ f_attention_prefill_begin_forward_.value()(
+ /*depth=*/d, temp_attn_workspace_[d + 1],
qo_indptr_on_depths_host_[d].as_ndarray(),
+ page_indptr_on_depths_host_[d].as_ndarray(),
+ static_cast<int>(page_indptr_on_depths_host_[d].size()) - 1,
num_qo_heads_,
+ num_kv_heads_, head_dim_, page_size_, copy_stream_);
}
}
}
@@ -2271,15 +2265,11 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
PackedFunc f_decode =
!support_sliding_window_ ? f_attention_decode_ :
f_attention_decode_sliding_window_;
CHECK_GE(num_depths_, 1) << "The number of effective depths must be
greater or equal to 1.";
- if (append_before_attn_) {
- f_decode(
- /*depth=*/0, q_data, pages_[local_layer_id],
page_indptr_on_depths_view_[0],
- page_indices_on_depths_view_[0], length_info_on_depths_view_[0],
- k_rope_pos_offset_view_[0], q_rope_position_map_view_, output,
merged_attn_scores_view_,
- /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_,
rotary_theta_,
- attn_score_scaling_factor);
- } else {
- // Compute appended text self-attention
+
+ bool is_first_kernel = true;
+ if (!append_before_attn_) {
+ // The first part of attention, which only involves the q and the newly
appended k/v.
+ is_first_kernel = false;
if (is_chain_) {
// If the batch does not form a tree, use raggedness prefill kernel.
f_attention_prefill_ragged_(q_data, cur_append_length_indptr_view_,
k_data, v_data,
@@ -2301,32 +2291,43 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
merged_attn_scores_view_, /*rotary_mode=*/rope_mode_ ==
RoPEMode::kInline,
rotary_scale_, rotary_theta_, attn_score_scaling_factor,
cur_batch_size_);
}
+ }
- for (int d = 0; d < num_depths_; ++d) {
- if (page_indices_on_depths_view_[d]->shape[0] == 0) {
- continue;
- }
- if (use_decode_kernel_[d]) {
- // Use decode kernel for depth d
- f_decode(/*depth=*/d, q_data, pages_[local_layer_id],
page_indptr_on_depths_view_[d],
- page_indices_on_depths_view_[d],
length_info_on_depths_view_[d],
- k_rope_pos_offset_view_[d], q_rope_position_map_view_,
temp_attn_output_view_,
- temp_attn_scores_view_,
- /*rotary_mode=*/rope_mode_ == RoPEMode::kInline,
rotary_scale_, rotary_theta_,
- attn_score_scaling_factor);
- } else {
- // Use prefill kernel for depth d
- f_prefill(
- /*depth=*/d, q_data, qo_indptr_on_depths_view_[d],
pages_[local_layer_id],
- page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d],
- length_info_on_depths_view_[d], k_rope_pos_offset_view_[d],
q_rope_position_map_view_,
- temp_attn_output_view_, temp_attn_scores_view_,
- /*causal=*/0,
- /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_,
rotary_theta_,
- attn_score_scaling_factor);
- }
+ for (int d = 0; d < num_depths_; ++d) {
+ if (page_indices_on_depths_view_[d]->shape[0] == 0) {
+ continue;
+ }
+ NDArray attn_output;
+ NDArray attn_scores;
+ if (is_first_kernel) {
+ attn_output = output;
+ attn_scores = merged_attn_scores_view_;
+ } else {
+ attn_output = temp_attn_output_view_;
+ attn_scores = temp_attn_scores_view_;
+ }
+ if (use_decode_kernel_[d]) {
+ // Use decode kernel for depth d
+ f_decode(/*depth=*/d, q_data, pages_[local_layer_id],
page_indptr_on_depths_view_[d],
+ page_indices_on_depths_view_[d],
length_info_on_depths_view_[d],
+ k_rope_pos_offset_view_[d], q_rope_position_map_view_,
attn_output, attn_scores,
+ /*rotary_mode=*/rope_mode_ == RoPEMode::kInline,
rotary_scale_, rotary_theta_,
+ attn_score_scaling_factor);
+ } else {
+ // Use prefill kernel for depth d
+ f_prefill(/*depth=*/d, q_data, qo_indptr_on_depths_view_[d],
pages_[local_layer_id],
+ page_indptr_on_depths_view_[d],
page_indices_on_depths_view_[d],
+ length_info_on_depths_view_[d], k_rope_pos_offset_view_[d],
+ q_rope_position_map_view_, attn_output, attn_scores,
/*causal=*/0,
+ /*rotary_mode=*/rope_mode_ == RoPEMode::kInline,
rotary_scale_, rotary_theta_,
+ attn_score_scaling_factor);
+ }
+
+ if (!is_first_kernel) {
f_merge_inplace_(output, merged_attn_scores_view_,
temp_attn_output_view_,
temp_attn_scores_view_);
+ } else {
+ is_first_kernel = false;
}
}
}