This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 35318ab7b4 [KVCache] Support fork in sliding window sink part (#17127)
35318ab7b4 is described below
commit 35318ab7b4d90933a9f0ffb8c5fbc5af50ab2b2f
Author: Yaxing Cai <[email protected]>
AuthorDate: Mon Jul 1 10:21:10 2024 -0700
[KVCache] Support fork in sliding window sink part (#17127)
This PR adds support for forking in the attention-sink portion of sliding window attention.
---
src/runtime/relax_vm/paged_kv_cache.cc | 23 ++++-
...runtime_builtin_paged_attention_kv_cache_tir.py | 97 ++++++++++++++++------
2 files changed, 90 insertions(+), 30 deletions(-)
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc
b/src/runtime/relax_vm/paged_kv_cache.cc
index 0162124cab..ec1cc3593a 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -1184,9 +1184,6 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
<< "The parent sequence \"" << parent_seq_id << "\" cannot be found in
KV cache.";
CHECK(seq_map_.find(child_seq_id) == seq_map_.end())
<< "The child sequence \"" << child_seq_id << "\" is already in the KV
cache.";
- CHECK_EQ(parent_it->second.sliding_window_size, -1)
- << "The parent sequence \"" << parent_seq_id
- << "\" is enabled with sliding window and thus cannot be forked.";
CHECK_GE(fork_pos, -1)
<< "The forked position should be non-negative, or -1 for last
position as default.";
CHECK_LE(fork_pos, parent_it->second.seq_length)
@@ -1199,6 +1196,18 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
fork_pos = parent_it->second.seq_length;
}
+ if (parent_it->second.sliding_window_size != -1) {
+ // If forked sequence has been enabled sliding window, check the forked
position is within
+ // sliding window sink size.
+ const Sequence& seq = parent_it->second;
+ int32_t sink_size = seq.seq_length -
global_block_pool_[seq.last_block_idx].seq_length +
+ seq.last_block_attn_sink_size;
+ CHECK_LE(fork_pos, sink_size)
+ << "The parent sequence \"" << parent_seq_id
+ << "\" is enabled with sliding window and thus only can be forked
within sink size = "
+ << sink_size << ". But the forked position = " << fork_pos << ".";
+ }
+
if (fork_pos == parent_it->second.seq_length && fork_pos % page_size_ == 0
&&
global_block_pool_[parent_it->second.last_block_idx].seq_length > 0) {
// To enable the parent sequence to continue decode after the fork,
@@ -1258,6 +1267,14 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
// Update in-block sequence length per blocks
global_block_pool_[parent_block_idx].seq_length = moved_offset;
global_block_pool_[forked_block_idx].seq_length -= moved_offset;
+
+ // Update sliding window sink size if sliding window is enabled and
the forked block is the
+ // last block
+ if (parent_it->second.sliding_window_size != -1 &&
+ forked_block_idx == parent_it->second.last_block_idx) {
+ CHECK_LE(moved_offset, parent_it->second.last_block_attn_sink_size);
+ parent_it->second.last_block_attn_sink_size -= moved_offset;
+ }
}
global_block_pool_[child_block_idx].start_pos = fork_pos -
in_page_offset;
global_block_pool_[child_block_idx].seq_length = in_page_offset;
diff --git
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index 87256720bd..34680160c8 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -468,8 +468,11 @@ def apply_attention(
for seq_id, _ in batch:
if sliding_window_sizes is not None and len(sliding_window_sizes) >
seq_id:
+ assert len(sliding_window_sizes) > seq_id and len(attn_sink_sizes)
> seq_id
sliding_window_size = sliding_window_sizes[seq_id]
attn_sink_size = attn_sink_sizes[seq_id]
+ if sliding_window_size == 0:
+ continue
if cached_k[seq_id].shape[1] > sliding_window_size:
# Apply sliding window and sink to cached kv.
length_to_slide = cached_k[seq_id].shape[1] -
sliding_window_size
@@ -746,34 +749,74 @@ def
test_paged_attention_kv_cache_sliding_window(kv_cache_and_config):
attn_sink_sizes,
)
- # Sliding window with fork
- sliding_window_sizes += [0, 18]
- attn_sink_sizes += [0, 12]
- apply_attention(kv_cache, rope_mode, [(5, 10)], cached_k, cached_v)
- ffork_sequence(kv_cache, 5, 6, -1)
- cached_k[6] = cached_k[5]
- cached_v[6] = cached_v[5]
+
[email protected]_gpu
[email protected]_cuda
+def test_paged_attention_kv_cache_sliding_window_fork(kv_cache_and_config):
+ kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
+ if not support_sliding_window or rope_mode == RopeMode.NORMAL:
+ return
+ fclear(kv_cache)
+
+ cached_k = {}
+ cached_v = {}
+ sliding_window_sizes = [30, 35, 40]
+ attn_sink_sizes = [15, 20, 25]
+ for seq_id, (sliding_window_size, attn_sink_size) in enumerate(
+ zip(sliding_window_sizes, attn_sink_sizes)
+ ):
+ fadd_sequence(kv_cache, seq_id)
+ fenable_sliding_window_for_seq(kv_cache, seq_id, sliding_window_size,
attn_sink_size)
+ cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim),
dtype)
+ cached_v[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim),
dtype)
+ apply_attention(
+ kv_cache,
+ rope_mode,
+ [(0, 12), (1, 18), (2, 28)],
+ cached_k,
+ cached_v,
+ sliding_window_sizes,
+ attn_sink_sizes,
+ )
+ # seq_len: [12, 18, 25+3]
+ sliding_window_sizes += [0, 0, 0]
+ attn_sink_sizes += [0, 0, 0]
+ apply_attention(
+ kv_cache,
+ rope_mode,
+ [((3, 0, 10), 8), ((4, 1, -1), 20), ((5, 2, 18), 18)],
+ cached_k,
+ cached_v,
+ sliding_window_sizes,
+ attn_sink_sizes,
+ )
+ # seq_len: [12, 18, 25+3, 18, 38, 36]
+ apply_attention(
+ kv_cache,
+ rope_mode,
+ [(0, 9), (1, 15), (2, 4), (3, 10), (4, 3), (5, 7)],
+ cached_k,
+ cached_v,
+ sliding_window_sizes,
+ attn_sink_sizes,
+ )
+ # seq_len: [15+6, 20+13, 25+7, 28, 41, 43]
+ sliding_window_sizes += [25]
+ attn_sink_sizes += [24]
+ ffork_sequence(kv_cache, 3, 6, 18)
fenable_sliding_window_for_seq(kv_cache, 6, sliding_window_sizes[-1],
attn_sink_sizes[-1])
- for _ in range(2):
- apply_attention(
- kv_cache,
- rope_mode,
- [(6, 10)],
- cached_k,
- cached_v,
- sliding_window_sizes,
- attn_sink_sizes,
- )
- for _ in range(16):
- apply_attention(
- kv_cache,
- rope_mode,
- [(6, 1)],
- cached_k,
- cached_v,
- sliding_window_sizes,
- attn_sink_sizes,
- )
+ cached_k[6] = cached_k[3][::, :18]
+ cached_v[6] = cached_v[3][::, :18]
+ apply_attention(
+ kv_cache,
+ rope_mode,
+ [(3, 10), (6, 12)],
+ cached_k,
+ cached_v,
+ sliding_window_sizes,
+ attn_sink_sizes,
+ )
+ # seq_len: [15+6, 20+13, 25+7, 38, 41, 43, 24+6]
@tvm.testing.requires_gpu