This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 35318ab7b4 [KVCache] Support fork in sliding window sink part (#17127)
35318ab7b4 is described below

commit 35318ab7b4d90933a9f0ffb8c5fbc5af50ab2b2f
Author: Yaxing Cai <[email protected]>
AuthorDate: Mon Jul 1 10:21:10 2024 -0700

    [KVCache] Support fork in sliding window sink part (#17127)
    
    This PR adds support for forking within the attention sink part of sliding window attention.
---
 src/runtime/relax_vm/paged_kv_cache.cc             | 23 ++++-
 ...runtime_builtin_paged_attention_kv_cache_tir.py | 97 ++++++++++++++++------
 2 files changed, 90 insertions(+), 30 deletions(-)
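
Background for the diff below: with sliding-window attention, only the attention-sink prefix of a sequence is guaranteed to still be resident in the KV cache (positions after it may already have been slid out), so forking is only valid inside that prefix. Here is a minimal Python sketch of the constraint the C++ change enforces; the flattened field names (seq_length, last_block_seq_length, last_block_attn_sink_size) are illustrative stand-ins for the Sequence/Block bookkeeping in paged_kv_cache.cc, not code from this commit:

    # Illustrative mirror of the new CHECK_LE in paged_kv_cache.cc.
    def max_fork_pos(seq_length, last_block_seq_length, last_block_attn_sink_size):
        # Everything before the last block, plus the sink entries of the last
        # block, is stable; later positions may already have been slid away.
        return seq_length - last_block_seq_length + last_block_attn_sink_size

    def check_fork(fork_pos, seq_length, last_block_seq_length, last_block_attn_sink_size):
        sink_size = max_fork_pos(seq_length, last_block_seq_length, last_block_attn_sink_size)
        if fork_pos > sink_size:
            raise ValueError(f"fork position {fork_pos} exceeds sink size {sink_size}")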

diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index 0162124cab..ec1cc3593a 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -1184,9 +1184,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
         << "The parent sequence \"" << parent_seq_id << "\" cannot be found in KV cache.";
     CHECK(seq_map_.find(child_seq_id) == seq_map_.end())
         << "The child sequence \"" << child_seq_id << "\" is already in the KV cache.";
-    CHECK_EQ(parent_it->second.sliding_window_size, -1)
-        << "The parent sequence \"" << parent_seq_id
-        << "\" is enabled with sliding window and thus cannot be forked.";
     CHECK_GE(fork_pos, -1)
         << "The forked position should be non-negative, or -1 for last 
position as default.";
     CHECK_LE(fork_pos, parent_it->second.seq_length)
@@ -1199,6 +1196,18 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       fork_pos = parent_it->second.seq_length;
     }
 
+    if (parent_it->second.sliding_window_size != -1) {
+      // If the sequence being forked has sliding window enabled, check that the fork
+      // position is within the sliding window sink size.
+      const Sequence& seq = parent_it->second;
+      int32_t sink_size = seq.seq_length - global_block_pool_[seq.last_block_idx].seq_length +
+                          seq.last_block_attn_sink_size;
+      CHECK_LE(fork_pos, sink_size)
+          << "The parent sequence \"" << parent_seq_id
+          << "\" has sliding window enabled and thus can only be forked within sink size = "
+          << sink_size << ". But the forked position = " << fork_pos << ".";
+    }
+
     if (fork_pos == parent_it->second.seq_length && fork_pos % page_size_ == 0 &&
         global_block_pool_[parent_it->second.last_block_idx].seq_length > 0) {
       // To enable the parent sequence to continue decode after the fork,
@@ -1258,6 +1267,14 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
         // Update in-block sequence length per blocks
         global_block_pool_[parent_block_idx].seq_length = moved_offset;
         global_block_pool_[forked_block_idx].seq_length -= moved_offset;
+
+        // Update sliding window sink size if sliding window is enabled and the
+        // forked block is the last block
+        if (parent_it->second.sliding_window_size != -1 &&
+            forked_block_idx == parent_it->second.last_block_idx) {
+          CHECK_LE(moved_offset, parent_it->second.last_block_attn_sink_size);
+          parent_it->second.last_block_attn_sink_size -= moved_offset;
+        }
       }
       global_block_pool_[child_block_idx].start_pos = fork_pos - in_page_offset;
       global_block_pool_[child_block_idx].seq_length = in_page_offset;
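
The second hunk covers the block-split path: when the fork point falls inside a block, the first moved_offset entries stay with the parent block and the rest go to the forked block. If sliding window is enabled and the split block was the parent's last block, the entries that moved out of it are sink entries, so the parent's last_block_attn_sink_size shrinks by the same amount. A hedged numeric sketch with made-up values (not from the commit):

    # Made-up numbers: the parent's last block holds 10 entries, 6 of them sink.
    moved_offset = 4               # entries split off to the parent block
    last_block_attn_sink_size = 6  # sink entries in the last block before split
    assert moved_offset <= last_block_attn_sink_size  # mirrors the new CHECK_LE
    last_block_attn_sink_size -= moved_offset         # 2 sink entries remain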
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index 87256720bd..34680160c8 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -468,8 +468,11 @@ def apply_attention(
 
     for seq_id, _ in batch:
         if sliding_window_sizes is not None and len(sliding_window_sizes) > seq_id:
+            assert len(sliding_window_sizes) > seq_id and len(attn_sink_sizes) > seq_id
             sliding_window_size = sliding_window_sizes[seq_id]
             attn_sink_size = attn_sink_sizes[seq_id]
+            if sliding_window_size == 0:
+                continue
             if cached_k[seq_id].shape[1] > sliding_window_size:
                 # Apply sliding window and sink to cached kv.
                 length_to_slide = cached_k[seq_id].shape[1] - sliding_window_size
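
The hunk above guards the test's reference model: sequences whose sliding window size is 0 (sliding window disabled) are now skipped. For sequences where it is enabled, the reference keeps the sink prefix plus the most recent window entries. A sketch of that slide in NumPy, with made-up sizes (the test's actual slicing sits outside this hunk):

    import numpy as np

    # Fake cache of 40 positions, shaped (layers, seq, kv_heads, head_dim).
    k = np.arange(40, dtype="float32").reshape(1, 40, 1, 1)
    sliding_window_size, attn_sink_size = 30, 15
    length_to_slide = k.shape[1] - sliding_window_size  # 10 entries fall out
    # Keep the sink prefix and the most recent entries after the slid region.
    k = np.concatenate([k[:, :attn_sink_size], k[:, attn_sink_size + length_to_slide:]], axis=1)
    assert k.shape[1] == sliding_window_size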
@@ -746,34 +749,74 @@ def test_paged_attention_kv_cache_sliding_window(kv_cache_and_config):
             attn_sink_sizes,
         )
 
-    # Sliding window with fork
-    sliding_window_sizes += [0, 18]
-    attn_sink_sizes += [0, 12]
-    apply_attention(kv_cache, rope_mode, [(5, 10)], cached_k, cached_v)
-    ffork_sequence(kv_cache, 5, 6, -1)
-    cached_k[6] = cached_k[5]
-    cached_v[6] = cached_v[5]
+
+@tvm.testing.requires_gpu
+@tvm.testing.requires_cuda
+def test_paged_attention_kv_cache_sliding_window_fork(kv_cache_and_config):
+    kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
+    if not support_sliding_window or rope_mode == RopeMode.NORMAL:
+        return
+    fclear(kv_cache)
+
+    cached_k = {}
+    cached_v = {}
+    sliding_window_sizes = [30, 35, 40]
+    attn_sink_sizes = [15, 20, 25]
+    for seq_id, (sliding_window_size, attn_sink_size) in enumerate(
+        zip(sliding_window_sizes, attn_sink_sizes)
+    ):
+        fadd_sequence(kv_cache, seq_id)
+        fenable_sliding_window_for_seq(kv_cache, seq_id, sliding_window_size, attn_sink_size)
+        cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype)
+        cached_v[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype)
+    apply_attention(
+        kv_cache,
+        rope_mode,
+        [(0, 12), (1, 18), (2, 28)],
+        cached_k,
+        cached_v,
+        sliding_window_sizes,
+        attn_sink_sizes,
+    )
+    # seq_len: [12, 18, 25+3]
+    sliding_window_sizes += [0, 0, 0]
+    attn_sink_sizes += [0, 0, 0]
+    apply_attention(
+        kv_cache,
+        rope_mode,
+        [((3, 0, 10), 8), ((4, 1, -1), 20), ((5, 2, 18), 18)],
+        cached_k,
+        cached_v,
+        sliding_window_sizes,
+        attn_sink_sizes,
+    )
+    # seq_len: [12, 18, 25+3, 18, 38, 36]
+    apply_attention(
+        kv_cache,
+        rope_mode,
+        [(0, 9), (1, 15), (2, 4), (3, 10), (4, 3), (5, 7)],
+        cached_k,
+        cached_v,
+        sliding_window_sizes,
+        attn_sink_sizes,
+    )
+    # seq_len: [15+6, 20+13, 25+7, 28, 41, 43]
+    sliding_window_sizes += [25]
+    attn_sink_sizes += [24]
+    ffork_sequence(kv_cache, 3, 6, 18)
     fenable_sliding_window_for_seq(kv_cache, 6, sliding_window_sizes[-1], attn_sink_sizes[-1])
-    for _ in range(2):
-        apply_attention(
-            kv_cache,
-            rope_mode,
-            [(6, 10)],
-            cached_k,
-            cached_v,
-            sliding_window_sizes,
-            attn_sink_sizes,
-        )
-    for _ in range(16):
-        apply_attention(
-            kv_cache,
-            rope_mode,
-            [(6, 1)],
-            cached_k,
-            cached_v,
-            sliding_window_sizes,
-            attn_sink_sizes,
-        )
+    cached_k[6] = cached_k[3][::, :18]
+    cached_v[6] = cached_v[3][::, :18]
+    apply_attention(
+        kv_cache,
+        rope_mode,
+        [(3, 10), (6, 12)],
+        cached_k,
+        cached_v,
+        sliding_window_sizes,
+        attn_sink_sizes,
+    )
+    # seq_len: [15+6, 20+13, 25+7, 38, 41, 43, 24+6]
 
 
 @tvm.testing.requires_gpu
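
For readers wanting the end-to-end flow: the new test forks within the sink region and then enables sliding window on the child sequence, via the test-local packed-function bindings (fadd_sequence, fenable_sliding_window_for_seq, ffork_sequence). A condensed sketch with illustrative window/sink sizes:

    fadd_sequence(kv_cache, 0)
    fenable_sliding_window_for_seq(kv_cache, 0, 30, 15)  # window=30, sink=15
    # ... prefill/decode so sequence 0 grows past its window ...
    ffork_sequence(kv_cache, 0, 1, 10)  # fork child 1 at pos 10, within the sink
    fenable_sliding_window_for_seq(kv_cache, 1, 25, 24)  # child gets its own window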
