This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 76b954a09e [3rdparty] Bump FlashInfer (#17236)
76b954a09e is described below
commit 76b954a09e781b7f664b1d345e1494123c19484c
Author: Ruihang Lai <[email protected]>
AuthorDate: Sat Aug 3 04:28:02 2024 -0400
[3rdparty] Bump FlashInfer (#17236)
This PR bumps FlashInfer and updates PagedKVCache accordingly
for performance improvement.
Some notes on this bump:
* When the Grouped-Query Attention group size is at least 4 and
FlashInfer is enabled, we use the prefill attn kernel for better
performance.
* We enlarge the temporary workspace for FlashInfer accordingly (from 8MB
to 128MB), as the current version of FlashInfer may consume a much larger
workspace. We skip allocating the workspace when FlashInfer is not enabled.
* We reduce the max block depth to 2, having observed that cascade
inference helps little when the batch size is not large and prompt
reuse is low.
---
3rdparty/flashinfer | 2 +-
src/runtime/relax_vm/paged_kv_cache.cc | 48 +++++++++++++++-------
..._builtin_paged_attention_kv_cache_flashinfer.py | 13 +++++-
...runtime_builtin_paged_attention_kv_cache_tir.py | 13 +++++-
4 files changed, 58 insertions(+), 18 deletions(-)
diff --git a/3rdparty/flashinfer b/3rdparty/flashinfer
index 7e9cc7ff42..0dd801d202 160000
--- a/3rdparty/flashinfer
+++ b/3rdparty/flashinfer
@@ -1 +1 @@
-Subproject commit 7e9cc7ff42ca283c317061a877305d09a395fad2
+Subproject commit 0dd801d2027af89f3603cbbf68a76e9503bb2f57
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc
b/src/runtime/relax_vm/paged_kv_cache.cc
index 2fb8a72f42..5aa1411ec1 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -54,11 +54,11 @@ namespace relax_vm {
* \brief The maximum allowed block depth (a.k.a. number of common
* prefixes) in paged KV cache.
*/
-constexpr const int kPagedKVCacheMaxBlockDepth = 5;
+constexpr const int kPagedKVCacheMaxBlockDepth = 2;
/*! \brief The maximum tree size of a single sequence in tree attention. */
constexpr const int kTreeAttnMaxTreeSize = 256;
/*! \brief The 8MB workspace size for attention auxiliary data. */
-constexpr const int kAttnWorkspaceByte = 8 * 1024 * 1024;
+constexpr const int kAttnWorkspaceByte = 128 * 1024 * 1024;
/*! \brief The id of the temporary logical page, which is useful for sliding
window. */
constexpr const int kPagedKVCacheTempPageId = -1;
@@ -119,6 +119,9 @@ struct Block {
void Reset() {
page_ids.clear();
seq_length = 0;
+ start_pos = 0;
+ sink_length = 0;
+ sliding_window_offset = 0;
parent_idx = -1;
external_ref_cnt = 0;
}
@@ -169,11 +172,9 @@ struct Sequence {
this->last_block_idx = last_block_idx;
int32_t block_ptr = last_block_idx;
// Go through each block in the sequence, sum up the length.
- int depth = 0;
while (true) {
const Block& block = global_block_pool->at(block_ptr);
this->seq_length += block.seq_length;
- ++depth;
if (block.parent_idx == -1) {
break;
}
@@ -1078,8 +1079,10 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
dtype_aux_, preferred_host_device);
for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) {
- temp_attn_workspace_.push_back(
- NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32),
device));
+ if (NeedKernelBeginForward()) {
+ temp_attn_workspace_.push_back(
+ NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32),
device));
+ }
qo_indptr_on_depths_view_.push_back(NDArray());
page_indptr_on_depths_view_.push_back(NDArray());
page_indices_on_depths_view_.push_back(NDArray());
@@ -1087,8 +1090,10 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
k_rope_pos_offset_view_.push_back(NDArray());
}
// Additional workspace for the "prefill with ragged kv" kernel.
- temp_attn_workspace_.push_back(
- NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
+ if (NeedKernelBeginForward()) {
+ temp_attn_workspace_.push_back(
+ NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32),
device));
+ }
temp_attn_q_device_ =
NDArray::Empty({prefill_chunk_size_, num_qo_heads, head_dim}, dtype,
device);
@@ -1531,6 +1536,12 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
}
append_before_attn_ = !support_sliding_window_ && num_depths_ == 1 &&
use_decode_kernel_[0];
+ if (NeedKernelBeginForward() && num_qo_heads_ / num_kv_heads_ >= 4) {
+ // When GQA group size is at least 4 and FlashInfer is enabled,
+ // we always use prefill kernel for better performance.
+ std::fill(use_decode_kernel_.begin(), use_decode_kernel_.end(),
/*value=*/false);
+ }
+
if (append_before_attn_) {
// Right now we use different kernels when depth is 1 or not 1.
// For the case where maximum depth is 1, we create the auxiliary
@@ -2196,11 +2207,16 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
use_decode_kernel};
}
+ /*! \brief Check whether BeginForward for kernels is needed. */
+ bool NeedKernelBeginForward() {
+ return f_attention_prefill_begin_forward_.defined() &&
+ f_attention_decode_begin_forward_.defined() &&
+ f_attention_prefill_ragged_begin_forward_.defined();
+ }
+
/*! \brief Invoke the "begin forward" functions of underlying kernels. */
void KernelBeginForward() {
- if (!f_attention_prefill_begin_forward_.defined() ||
- !f_attention_decode_begin_forward_.defined() ||
- !f_attention_prefill_ragged_begin_forward_.defined()) {
+ if (!NeedKernelBeginForward()) {
return;
}
@@ -2214,8 +2230,9 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
}
} else {
f_attention_prefill_ragged_begin_forward_.value()(
- temp_attn_workspace_[0],
cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_,
- num_qo_heads_, num_kv_heads_, head_dim_, copy_stream_);
+ temp_attn_workspace_[0],
cur_append_lengths_indptr_host_.as_ndarray(),
+ cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_,
num_qo_heads_,
+ num_kv_heads_, head_dim_, copy_stream_);
if (support_sliding_window_) {
return;
}
@@ -2232,8 +2249,9 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
} else {
f_attention_prefill_begin_forward_.value()(
/*depth=*/d, temp_attn_workspace_[d + 1],
qo_indptr_on_depths_host_[d].as_ndarray(),
- length_info_on_depths_view_[d]->shape[0], num_qo_heads_,
num_kv_heads_, head_dim_,
- copy_stream_);
+ page_indptr_on_depths_host_[d].as_ndarray(),
+ static_cast<int>(page_indptr_on_depths_host_[d].size()) - 1,
num_qo_heads_,
+ num_kv_heads_, head_dim_, page_size_, copy_stream_);
}
}
}
diff --git
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
index bade04a7d7..cab10f84cd 100644
---
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
+++
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
@@ -29,7 +29,7 @@ from tvm.runtime import ShapeTuple
from tvm.script import tir as T
reserved_nseq = 32
-maximum_total_seq_length = 1024
+maximum_total_seq_length = 2048
prefill_chunk_size = 512
page_size = 16
num_layers = 4
@@ -249,6 +249,7 @@ def _copy_single_page(num_heads, page_size, head_dim,
dtype, target):
):
for t in T.thread_binding(tx, thread="threadIdx.x"):
with T.block("copy"):
+ T.where(b * tx + t < copy_length * num_heads * head_dim)
vh = T.axis.spatial(
num_heads,
T.Cast("int32", (b * tx + t) // (copy_length *
head_dim)),
@@ -662,6 +663,16 @@ def
test_paged_attention_kv_cache_fork_sequence(kv_cache_and_rope_mode):
cached_v.pop(i)
verify_cached_kv(kv_cache, seq_ids=list(range(i)),
expected_k=cached_k, expected_v=cached_v)
+ # Test fork after page recycle
+ apply_attention(kv_cache, rope_mode, [(0, 7), (1, 24)], cached_k, cached_v)
+ apply_attention(kv_cache, rope_mode, [((2, 1, -1), 10)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [((3, 0, -1), 20)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [(2, 1), (3, 1)], cached_k, cached_v)
+
+ apply_attention(kv_cache, rope_mode, [(10, 7), (11, 24)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [((12, 11, -1), 200)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [(10, 1), (12, 1)], cached_k,
cached_v)
+
@pytest.mark.skip(reason="Require FlashInfer enabled")
def test_paged_attention_kv_cache_popn(kv_cache_and_rope_mode):
diff --git
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index 9192bb901f..3c85a13e4c 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -33,7 +33,7 @@ from tvm.script import tir as T
from tvm.target import Target
reserved_nseq = 32
-maximum_total_seq_length = 1024
+maximum_total_seq_length = 2048
prefill_chunk_size = 512
page_size = 16
num_layers = 4
@@ -615,6 +615,16 @@ def
test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config):
assert fis_empty(kv_cache), "The KV cache is not empty after removing all
sequences"
+ # Test fork after page recycle
+ apply_attention(kv_cache, rope_mode, [(0, 7), (1, 24)], cached_k, cached_v)
+ apply_attention(kv_cache, rope_mode, [((2, 1, -1), 10)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [((3, 0, -1), 20)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [(2, 1), (3, 1)], cached_k, cached_v)
+
+ apply_attention(kv_cache, rope_mode, [(10, 7), (11, 24)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [((12, 11, -1), 200)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [(10, 1), (12, 1)], cached_k,
cached_v)
+
@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
@@ -2547,6 +2557,7 @@ def _copy_single_page(num_heads, page_size, head_dim,
dtype, target: Target):
):
for t in T.thread_binding(tx, thread="threadIdx.x"):
with T.block("copy"):
+ T.where(b * tx + t < copy_length * num_heads * head_dim)
vh = T.axis.spatial(
num_heads,
T.Cast("int32", (b * tx + t) // (copy_length *
head_dim)),