This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c3be89a407 [KVCache] Support forking sequence at specific position (#16813)
c3be89a407 is described below

commit c3be89a4070287cb98fded112a48a3d295564dea
Author: Yaxing Cai <[email protected]>
AuthorDate: Fri Mar 29 16:09:37 2024 -0700

    [KVCache] Support forking sequence at specific position (#16813)
    
    This PR enables KVCache to fork a sequence at a specific position.
---
 src/runtime/relax_vm/kv_state.h                    |   5 +-
 src/runtime/relax_vm/paged_kv_cache.cc             | 127 +++++++++++++++++----
 src/runtime/relax_vm/rnn_state.cc                  |   2 +-
 ..._builtin_paged_attention_kv_cache_flashinfer.py | 102 ++++++++++++++---
 ...runtime_builtin_paged_attention_kv_cache_tir.py | 101 +++++++++++++---
 .../python/relax/test_runtime_builtin_rnn_state.py |   2 +-
 6 files changed, 283 insertions(+), 56 deletions(-)
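
A minimal usage sketch of the extended API (illustrative, not part of the commit): forking now takes a third argument. The registered function name below follows the "vm.builtin.kv_state_*" pattern used by the tests in this diff, and create_kv_cache is the test fixture defined further down; treat both as assumptions.

    import tvm

    # Assumed lookup, mirroring how the tests obtain ffork_sequence.
    ffork_sequence = tvm.get_global_func("vm.builtin.kv_state_fork_sequence")

    kv_cache = create_kv_cache(rope_mode)  # test fixture, see the tests below
    ffork_sequence(kv_cache, 0, 1, -1)  # fork sequence 0 into 1 at its last position (default)
    ffork_sequence(kv_cache, 0, 2, 15)  # fork sequence 0 into 2 after its first 15 tokens
    ffork_sequence(kv_cache, 0, 3, 0)   # equivalent to adding a fresh, empty sequence 3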

diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h
index f6857a9dce..e3c6e9608c 100644
--- a/src/runtime/relax_vm/kv_state.h
+++ b/src/runtime/relax_vm/kv_state.h
@@ -59,9 +59,12 @@ class KVStateObj : public Object {
    * \param parent_seq_id The parent (source) of the fork.
    * \param child_seq_id The child (destination) of the fork.
    * The child sequence id should not exist in cache prior to fork.
+   * \param fork_pos The position in the parent sequence at which to fork. The legal forking
+   * position is within [0, parent_seq_length], with -1 as the default denoting the last position.
+   * Forking at position 0 is equivalent to adding a new sequence with the child sequence id.
    * \throws Error if the given sequence ids are not valid.
    */
-  virtual void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id) = 0;
+  virtual void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id, int64_t fork_pos = -1) = 0;
 
   /*!
    * \brief Pop out the trailing `n` tokens from the KV cache for the
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index 9c3ee5d427..3ccab3826d 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -373,6 +373,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   Optional<PackedFunc> f_attention_decode_end_forward_;
   PackedFunc f_merge_inplace_;
   PackedFunc f_split_rotary_;
+  PackedFunc f_copy_single_page_;
   Optional<PackedFunc> f_debug_get_kv_;
 
   /*! \brief Number of fork depth in the current round of forward. */
@@ -407,7 +408,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       Optional<PackedFunc> f_attention_prefill_end_forward,
       Optional<PackedFunc> f_attention_decode_begin_forward,
       Optional<PackedFunc> f_attention_decode_end_forward, PackedFunc f_merge_inplace,
-      PackedFunc f_split_rotary, Optional<PackedFunc> f_debug_get_kv)
+      PackedFunc f_split_rotary, PackedFunc f_copy_single_page, Optional<PackedFunc> f_debug_get_kv)
       : page_size_(page_size),
         num_layers_(num_layers),
         num_qo_heads_(num_qo_heads),
@@ -435,6 +436,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
         f_attention_decode_end_forward_(std::move(f_attention_decode_end_forward)),
         f_merge_inplace_(std::move(f_merge_inplace)),
         f_split_rotary_(std::move(f_split_rotary)),
+        f_copy_single_page_(std::move(f_copy_single_page)),
         f_debug_get_kv_(std::move(f_debug_get_kv)),
         device_(device) {
     pages_.reserve(num_layers);
@@ -527,27 +529,27 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   void RemoveSequence(int64_t seq_id) final {
     auto it = seq_map_.find(seq_id);
     CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache.";
-    const Block& block = global_block_pool_[it->second.last_block_idx];
-    CHECK_EQ(block.external_ref_cnt, 0)
+    int32_t block_idx = it->second.last_block_idx;
+    CHECK_EQ(global_block_pool_[block_idx].external_ref_cnt, 0)
         << "The sequence is currently referenced by other sequence and thus 
cannot be removed.";
-
-    // - Decrease the external reference of the parent block.
-    if (block.parent_idx != -1) {
-      Block& parent_block = global_block_pool_[block.parent_idx];
-      ICHECK_GT(parent_block.external_ref_cnt, 0);
-      --parent_block.external_ref_cnt;
+    while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 0) {
+      // - Free pages in the last block.
+      for (int32_t page_id : global_block_pool_[block_idx].page_ids) {
+        free_page_ids_.push_back(page_id);
+      }
+      free_block_idx_.push_back(block_idx);
+      block_idx = global_block_pool_[block_idx].parent_idx;
     }
-    // - Free pages in the last block.
-    for (int32_t page_id : block.page_ids) {
-      free_page_ids_.push_back(page_id);
+    // - Decrease the external reference of the parent block.
+    if (block_idx != -1) {
+      ICHECK_GT(global_block_pool_[block_idx].external_ref_cnt, 0);
+      --global_block_pool_[block_idx].external_ref_cnt;
     }
-    // - Remove the sequence from seq_map.
-    free_block_idx_.push_back(it->second.last_block_idx);
     seq_map_.erase(it);
     dirty_aux_data_device_ = true;
   }
 
-  void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id) final {
+  void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id, int64_t fork_pos = -1) final {
     auto parent_it = seq_map_.find(parent_seq_id);
     CHECK(parent_it != seq_map_.end())
         << "The parent sequence \"" << parent_seq_id << "\" cannot be found in 
KV cache.";
@@ -556,18 +558,89 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     CHECK_EQ(parent_it->second.sliding_window_size, -1)
         << "The parent sequence \"" << parent_seq_id
         << "\" is enabled with sliding window and thus cannot be forked.";
+    CHECK_GE(fork_pos, -1)
+        << "The forked position should be non-negative, or -1 for last position as default.";
+    CHECK_LE(fork_pos, parent_it->second.seq_length)
+        << "The forked position should not exceed the total length of parent sequence.";
 
-    int32_t parent_block_idx = parent_it->second.last_block_idx;
-    ++global_block_pool_[parent_block_idx].external_ref_cnt;
-    // Create a child block with the parent block pointer.
     int32_t child_block_idx = GetFreeBlock();
-    global_block_pool_[child_block_idx].start_pos = parent_it->second.seq_length;
-    global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
+    if (fork_pos == -1 || fork_pos == parent_it->second.seq_length) {
+      // Fork at the last position by directly appending a new block.
+      int32_t parent_block_idx = parent_it->second.last_block_idx;
+      ++global_block_pool_[parent_block_idx].external_ref_cnt;
+      // Update child block start position and parent index
+      global_block_pool_[child_block_idx].start_pos = parent_it->second.seq_length;
+      global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
+    } else {
+      // Locate the block to fork from and calculate in-block offset
+      std::vector<int32_t> trace = parent_it->second.GetBlockTrace(global_block_pool_);
+      int64_t in_block_offset = fork_pos;
+      int32_t forked_block_idx = -1;
+      for (int32_t block_idx : trace) {
+        if (in_block_offset < global_block_pool_[block_idx].seq_length) {
+          forked_block_idx = block_idx;
+          break;
+        }
+        in_block_offset -= global_block_pool_[block_idx].seq_length;
+      }
+      int32_t in_page_offset = in_block_offset % page_size_;
+      int32_t moved_offset = in_block_offset - in_page_offset;
+      if (moved_offset == 0) {
+        // Forked at the first page in the block
+        int32_t parent_block_idx = global_block_pool_[forked_block_idx].parent_idx;
+        if (parent_block_idx != -1) {
+          ++global_block_pool_[parent_block_idx].external_ref_cnt;
+        }
+        // Update child block start position and parent index
+        global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
+      } else {
+        // Forked at the second or a later page in the block
+        int32_t parent_block_idx = GetFreeBlock();
+        // Insert new parent block before forked block and link child block
+        global_block_pool_[parent_block_idx].parent_idx =
+            global_block_pool_[forked_block_idx].parent_idx;
+        global_block_pool_[forked_block_idx].parent_idx = parent_block_idx;
+        global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
+        global_block_pool_[parent_block_idx].external_ref_cnt = 1;
+
+        // Move common leading pages to new parent block
+        auto first_page = global_block_pool_[forked_block_idx].page_ids.begin();
+        auto last_page =
+            global_block_pool_[forked_block_idx].page_ids.begin() + moved_offset / page_size_;
+        global_block_pool_[parent_block_idx].page_ids = {first_page, last_page};
+        global_block_pool_[forked_block_idx].page_ids.erase(first_page, last_page);
+
+        // Update the start positions of both blocks
+        global_block_pool_[parent_block_idx].start_pos =
+            global_block_pool_[forked_block_idx].start_pos;
+        global_block_pool_[forked_block_idx].start_pos += moved_offset;
+
+        // Update the in-block sequence lengths of both blocks
+        global_block_pool_[parent_block_idx].seq_length = moved_offset;
+        global_block_pool_[forked_block_idx].seq_length -= moved_offset;
+      }
+      global_block_pool_[child_block_idx].start_pos = fork_pos - in_page_offset;
+      global_block_pool_[child_block_idx].seq_length = in_page_offset;
+
+      if (in_page_offset > 0) {
+        // Fork within a page: partially copy the shared page into the child block
+        int32_t src_page_id = global_block_pool_[forked_block_idx].page_ids[0];
+        int32_t tgt_page_id = GetFreePage();
+        global_block_pool_[child_block_idx].page_ids.push_back(tgt_page_id);
+        CopySinglePage(src_page_id, tgt_page_id, in_page_offset);
+      }
+    }
     // Create the child sequence with the child block.
     seq_map_.insert({child_seq_id, Sequence(global_block_pool_, child_block_idx)});
     dirty_aux_data_device_ = true;
   }
 
+  void CopySinglePage(int32_t src_page_id, int32_t tgt_page_id, int64_t copy_length) {
+    for (int layer = 0; layer < num_layers_; ++layer) {
+      f_copy_single_page_(pages_[layer], src_page_id, tgt_page_id, copy_length);
+    }
+  }
+
   void EnableSlidingWindowForSeq(int64_t seq_id, int32_t sliding_window_size,
                                  int32_t attn_sink_size) final {
     CHECK(support_sliding_window_) << "The KV cache does not support sliding window.";
@@ -1390,7 +1463,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     // - Reset the dirty flag to false.
     dirty_aux_data_device_ = false;
   }
-};
+};  // class PagedAttentionKVCacheObj
 
 TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj);
 
@@ -1412,7 +1485,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
                        PackedFunc f_attention_prefill_end_forward,
                        PackedFunc f_attention_decode_begin_forward,
                        PackedFunc f_attention_decode_end_forward, PackedFunc f_merge_inplace,
-                       PackedFunc f_split_rotary, Optional<PackedFunc> f_debug_get_kv) {
+                       PackedFunc f_split_rotary, PackedFunc f_copy_single_page,
+                       Optional<PackedFunc> f_debug_get_kv) {
       CHECK_EQ(cache_config.size(), 5);
       int64_t reserved_num_seqs = cache_config[0];
       int64_t total_token_capacity = cache_config[1];
@@ -1435,7 +1509,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
           std::move(f_attention_prefill_ragged_end_forward),
           std::move(f_attention_prefill_begin_forward), std::move(f_attention_prefill_end_forward),
           std::move(f_attention_decode_begin_forward), std::move(f_attention_decode_end_forward),
-          std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_debug_get_kv));
+          std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_copy_single_page),
+          std::move(f_debug_get_kv));
       return AttentionKVCache(std::move(n));
     });
 
@@ -1447,7 +1522,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
                        PackedFunc f_attention_prefill_sliding_window,
                        PackedFunc f_attention_decode_sliding_window,
                        PackedFunc f_attention_prefill_ragged, PackedFunc f_merge_inplace,
-                       PackedFunc f_split_rotary, Optional<PackedFunc> f_debug_get_kv) {
+                       PackedFunc f_split_rotary, PackedFunc f_copy_single_page,
+                       Optional<PackedFunc> f_debug_get_kv) {
       CHECK_EQ(cache_config.size(), 5);
       int64_t reserved_num_seqs = cache_config[0];
       int64_t total_token_capacity = cache_config[1];
@@ -1467,7 +1543,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
           std::move(f_attention_prefill_sliding_window),
           std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged),  //
           NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt,                                  //
-          std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_debug_get_kv));
+          std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_copy_single_page),
+          std::move(f_debug_get_kv));
       return AttentionKVCache(std::move(n));
     });
 
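
To make the block arithmetic in ForkSequence above concrete, here is a minimal Python model of how the fork point is located. The block trace and page size are illustrative stand-ins for global_block_pool_ and the cache configuration; variable names mirror the C++.

    # Simplified model of the fork-point search in ForkSequence.
    page_size = 16
    blocks = [32, 48, 7]   # in-block seq_length of each block on the parent's trace
    fork_pos = 53          # absolute fork position within the parent sequence

    in_block_offset = fork_pos
    for block_idx, seq_length in enumerate(blocks):
        if in_block_offset < seq_length:
            forked_block_idx = block_idx
            break
        in_block_offset -= seq_length

    in_page_offset = in_block_offset % page_size     # tokens copied into the child's own page
    moved_offset = in_block_offset - in_page_offset  # whole pages moved into a new shared parent block

    assert (forked_block_idx, in_block_offset) == (1, 21)
    assert (moved_offset, in_page_offset) == (16, 5)
    # One full page of block 1 moves into a fresh parent block shared by both
    # sequences, and the first 5 tokens of the next page are copied for the
    # child via CopySinglePage.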
diff --git a/src/runtime/relax_vm/rnn_state.cc b/src/runtime/relax_vm/rnn_state.cc
index 09873ba5f7..69225d6b2c 100644
--- a/src/runtime/relax_vm/rnn_state.cc
+++ b/src/runtime/relax_vm/rnn_state.cc
@@ -319,7 +319,7 @@ class RNNStateImpObj : public RNNStateObj {
     dirty_aux_data_device_ = true;
   }
 
-  void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id) final {
+  void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id, int64_t fork_pos = -1) final {
     auto parent_it = seq_map_.find(parent_seq_id);
     CHECK(parent_it != seq_map_.end()) << "The parent sequence \"" << parent_seq_id
                                        << "\" cannot be found in space state storage.";
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
index d30ccd0224..c71b0dde3e 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
@@ -66,6 +66,7 @@ fattention_merge_state = None
 
 ftranspose_append = None
 fsplit_rotary = None
+fcopy_single_page = None
 fcopy_cache = None
 
 
@@ -222,6 +223,46 @@ def copy_cache(
             ]
 
 
+def _copy_single_page(num_heads, page_size, head_dim, dtype, target):
+    tx = 256 if str(target.kind) == "webgpu" else 1024
+
+    @T.prim_func
+    def copy_single_page(
+        pages: T.handle,
+        src_page_id: T.int64,
+        tgt_page_id: T.int64,
+        copy_length: T.int64,
+    ):
+        T.func_attr({"tir.is_scheduled": 1})
+        num_pages = T.int32()
+        P = T.match_buffer(pages, (num_pages, 2, num_heads, page_size, head_dim), dtype)
+
+        for b in T.thread_binding(
+            (copy_length * num_heads * head_dim + tx - 1) // tx, thread="blockIdx.x"
+        ):
+            for t in T.thread_binding(tx, thread="threadIdx.x"):
+                with T.block("copy"):
+                    vh = T.axis.spatial(
+                        num_heads,
+                        T.Cast("int32", (b * tx + t) // (copy_length * 
head_dim)),
+                    )
+                    vp = T.axis.spatial(
+                        copy_length,
+                        (b * tx + t) % (copy_length * head_dim) // head_dim,
+                    )
+                    vd = T.axis.spatial(
+                        head_dim,
+                        T.Cast(
+                            "int32",
+                            (b * tx + t) % head_dim,
+                        ),
+                    )
+                    P[tgt_page_id, 0, vh, vp, vd] = P[src_page_id, 0, vh, vp, vd]
+                    P[tgt_page_id, 1, vh, vp, vd] = P[src_page_id, 1, vh, vp, vd]
+
+    return copy_single_page
+
+
 def set_global_func():
     global fclear, fcreate, fadd_sequence, fremove_sequence, ffork_sequence, fpopn
     global fbegin_forward, fend_forward, fattention, fattention_with_fuse_qkv, fdebug_get_kv
@@ -230,7 +271,7 @@ def set_global_func():
     global fattention_prefill_ragged
     global fattention_prefill_ragged_begin_forward
     global fattention_prefill_ragged_end_forward
-    global fattention_merge_state, fsplit_rotary
+    global fattention_merge_state, fsplit_rotary, fcopy_single_page
     global ftranspose_append, fcopy_cache
 
     fclear = tvm.get_global_func("vm.builtin.kv_state_clear")
@@ -282,6 +323,7 @@ def set_global_func():
         llama_rope_with_position_map(
             rope_theta, rope_scale, head_dim, num_qo_heads, num_kv_heads, dtype
         ),
+        _copy_single_page(num_kv_heads, page_size, head_dim, dtype, target),
         copy_cache,
     ]:
         mod = tvm.IRModule({"main": tir_func})
@@ -290,7 +332,7 @@ def set_global_func():
         f = tvm.build(mod["main"], target=target)
         builts.append(f.entry_func)
 
-    ftranspose_append, fsplit_rotary, fcopy_cache = builts
+    ftranspose_append, fsplit_rotary, fcopy_single_page, fcopy_cache = builts
 
 
 def create_kv_cache(rope_mode):
@@ -327,6 +369,7 @@ def create_kv_cache(rope_mode):
         fattention_decode_end_forward,
         fattention_merge_state,
         fsplit_rotary,
+        fcopy_single_page,
         fcopy_cache,
     )
     return cache
@@ -384,7 +427,7 @@ def f_apply_rotary(x, offset, scale, theta):
 def apply_attention(
     kv_cache,
     rope_mode: RopeMode,
-    batch: List[Tuple[Union[int, Tuple[int, int]], int]],
+    batch: List[Tuple[Union[int, Tuple[int, int, int]], int]],
     cached_k: Dict[int, np.ndarray],
     cached_v: Dict[int, np.ndarray],
 ) -> None:
@@ -394,16 +437,20 @@ def apply_attention(
         fork_parent_id = None
         if isinstance(seq_id, tuple):
             # Fork sequence
-            seq_id, fork_parent_id = seq_id
+            seq_id, fork_parent_id, fork_pos = seq_id
             batch[i] = (seq_id, append_length)
         seq_ids.append(seq_id)
         append_lengths.append(append_length)
         if fork_parent_id is not None:
             assert fork_parent_id in cached_k
             assert seq_id not in cached_k
-            ffork_sequence(kv_cache, fork_parent_id, seq_id)
-            cached_k[seq_id] = cached_k[fork_parent_id]
-            cached_v[seq_id] = cached_v[fork_parent_id]
+            ffork_sequence(kv_cache, fork_parent_id, seq_id, fork_pos)
+            if fork_pos == -1:
+                cached_k[seq_id] = cached_k[fork_parent_id]
+                cached_v[seq_id] = cached_v[fork_parent_id]
+            else:
+                cached_k[seq_id] = cached_k[fork_parent_id][::, :fork_pos]
+                cached_v[seq_id] = cached_v[fork_parent_id][::, :fork_pos]
         elif seq_id not in cached_k:
             fadd_sequence(kv_cache, seq_id)
             cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype)
@@ -563,12 +610,15 @@ def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_rope_mode):
     batch = [(0, 60), (1, 88), (2, 17), (3, 4)]
     apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
     # Fork existing sequences.
-    apply_attention(kv_cache, rope_mode, [((4, 3), 35)], cached_k, cached_v)
-    apply_attention(kv_cache, rope_mode, [((5, 0), 20)], cached_k, cached_v)
-    apply_attention(kv_cache, rope_mode, [((6, 5), 102)], cached_k, cached_v)
-    apply_attention(kv_cache, rope_mode, [((7, 0), 3)], cached_k, cached_v)
-    apply_attention(kv_cache, rope_mode, [((8, 5), 71)], cached_k, cached_v)
-    apply_attention(kv_cache, rope_mode, [((9, 5), 20)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((4, 3, -1), 35)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((5, 0, -1), 20)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((6, 5, -1), 102)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((7, 0, -1), 3)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((8, 5, -1), 71)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((9, 5, -1), 20)], cached_k, cached_v)
+    # 0 <- 5 <- 6,8,9
+    # 0 <- 7
+    # 3 <- 4
     # Mixture of decode and prefill.
     operation_seq = [
         [(2, 1), (4, 1), (7, 1), (6, 1), (8, 1), (9, 1)],
@@ -579,6 +629,32 @@ def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_rope_mode):
     for batch in operation_seq:
         apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
 
+    apply_attention(kv_cache, rope_mode, [((10, 1, 33), 11)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((11, 0, 60), 45)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((12, 0, 15), 14)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((13, 0, 16), 19)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((14, 0, 17), 19)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((15, 5, 60), 8)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((16, 5, 80), 10)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((17, 5, 75), 11)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((18, 5, 76), 45)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((19, 5, 77), 14)], cached_k, cached_v)
+
+    operation_seq = [
+        [(6, 1), (11, 1), (13, 1), (9, 1)],
+        [(10, 1), (16, 1), (18, 1), (19, 1)],
+        [(8, 1), (15, 1), (17, 1), (12, 1), (14, 1)],
+        [(10, 10), (6, 2), (8, 3), (19, 4)],
+    ]
+    for batch in operation_seq:
+        apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
+
+    for i in range(19, -1, -1):
+        fremove_sequence(kv_cache, i)
+        cached_k.pop(i)
+        cached_v.pop(i)
+        verify_cached_kv(kv_cache, seq_ids=list(range(i)), expected_k=cached_k, expected_v=cached_v)
+
 
 @pytest.mark.skip(reason="Require FlashInfer enabled")
 def test_paged_attention_kv_cache_popn(kv_cache_and_rope_mode):
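
The copy_single_page kernel above flattens (head, position, dim) into a single thread index and decomposes it back into the block axes vh, vp, and vd. A small self-check (with made-up sizes) confirms the decomposition touches every element exactly once:

    # Sanity check of the flat-index decomposition used in copy_single_page.
    num_heads, copy_length, head_dim = 4, 5, 8  # illustrative sizes
    seen = set()
    for idx in range(copy_length * num_heads * head_dim):
        vh = idx // (copy_length * head_dim)             # head
        vp = idx % (copy_length * head_dim) // head_dim  # position within the page
        vd = idx % head_dim                              # dimension
        seen.add((vh, vp, vd))
    assert len(seen) == num_heads * copy_length * head_dim  # bijective mapping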
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index c33686d16e..3ed89ecd0f 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -66,6 +66,7 @@ fattn_prefill_ragged = None
 fmerge_state = None
 fsplit_rotary = None
 fattention_rotary = None
+fcopy_single_page = None
 
 
 def set_global_func(head_dim, dtype):
@@ -73,7 +74,7 @@ def set_global_func(head_dim, dtype):
     global fpopn, fbegin_forward, fend_forward, fattention_with_fuse_qkv, fdebug_get_kv
     global ftranspose_append, fcopy_cache, fattn_prefill, fattn_decode, fattn_prefill_ragged
     global fattn_prefill_sliding_window, fattn_decode_sliding_window
-    global fmerge_state, fsplit_rotary, fattention_rotary
+    global fmerge_state, fsplit_rotary, fattention_rotary, fcopy_single_page
 
     fclear = tvm.get_global_func("vm.builtin.kv_state_clear")
     fadd_sequence = tvm.get_global_func("vm.builtin.kv_state_add_sequence")
@@ -104,6 +105,7 @@ def set_global_func(head_dim, dtype):
         llama_rope_with_position_map(
             rope_theta, rope_scale, head_dim, num_qo_heads, num_kv_heads, dtype
         ),
+        _copy_single_page(num_kv_heads, page_size, head_dim, dtype, target),
     ]:
         mod = tvm.IRModule({"main": tir_func})
         with target:
@@ -121,6 +123,7 @@ def set_global_func(head_dim, dtype):
         fattn_prefill_ragged,
         fmerge_state,
         fsplit_rotary,
+        fcopy_single_page,
     ) = builts
 
 
@@ -152,6 +155,7 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window):
         fattn_prefill_ragged,
         fmerge_state,
         fsplit_rotary,
+        fcopy_single_page,
         fcopy_cache,
     )
     return cache
@@ -226,7 +230,7 @@ def f_apply_rotary(x, offset, scale, theta):
 def apply_attention(
     kv_cache,
     rope_mode: RopeMode,
-    batch: List[Tuple[Union[int, Tuple[int, int]], int]],
+    batch: List[Tuple[Union[int, Tuple[int, int, int]], int]],
     cached_k: Dict[int, np.ndarray],
     cached_v: Dict[int, np.ndarray],
     sliding_window_sizes: Optional[List[int]] = None,
@@ -238,16 +242,20 @@ def apply_attention(
         fork_parent_id = None
         if isinstance(seq_id, tuple):
             # Fork sequence
-            seq_id, fork_parent_id = seq_id
+            seq_id, fork_parent_id, fork_pos = seq_id
             batch[i] = (seq_id, append_length)
         seq_ids.append(seq_id)
         append_lengths.append(append_length)
         if fork_parent_id is not None:
             assert fork_parent_id in cached_k
             assert seq_id not in cached_k
-            ffork_sequence(kv_cache, fork_parent_id, seq_id)
-            cached_k[seq_id] = cached_k[fork_parent_id]
-            cached_v[seq_id] = cached_v[fork_parent_id]
+            ffork_sequence(kv_cache, fork_parent_id, seq_id, fork_pos)
+            if fork_pos == -1:
+                cached_k[seq_id] = cached_k[fork_parent_id]
+                cached_v[seq_id] = cached_v[fork_parent_id]
+            else:
+                cached_k[seq_id] = cached_k[fork_parent_id][::, :fork_pos]
+                cached_v[seq_id] = cached_v[fork_parent_id][::, :fork_pos]
         elif seq_id not in cached_k:
             fadd_sequence(kv_cache, seq_id)
             cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype)
@@ -442,12 +450,15 @@ def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config):
     batch = [(0, 60), (1, 88), (2, 17), (3, 4)]
     apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
     # Fork existing sequences.
-    apply_attention(kv_cache, rope_mode, [((4, 3), 35)], cached_k, cached_v)
-    apply_attention(kv_cache, rope_mode, [((5, 0), 20)], cached_k, cached_v)
-    apply_attention(kv_cache, rope_mode, [((6, 5), 102)], cached_k, cached_v)
-    apply_attention(kv_cache, rope_mode, [((7, 0), 3)], cached_k, cached_v)
-    apply_attention(kv_cache, rope_mode, [((8, 5), 71)], cached_k, cached_v)
-    apply_attention(kv_cache, rope_mode, [((9, 5), 20)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((4, 3, -1), 35)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((5, 0, -1), 20)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((6, 5, -1), 102)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((7, 0, -1), 3)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((8, 5, -1), 71)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((9, 5, -1), 20)], cached_k, cached_v)
+    # 0 <- 5 <- 6,8,9
+    # 0 <- 7
+    # 3 <- 4
     # Mixture of decode and prefill.
     operation_seq = [
         [(2, 1), (4, 1), (7, 1), (6, 1), (8, 1), (9, 1)],
@@ -458,7 +469,27 @@ def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config):
     for batch in operation_seq:
         apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
 
-    for i in range(9, -1, -1):
+    apply_attention(kv_cache, rope_mode, [((10, 1, 33), 11)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((11, 0, 60), 45)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((12, 0, 15), 14)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((13, 0, 16), 19)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((14, 0, 17), 19)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((15, 5, 60), 8)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((16, 5, 80), 10)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((17, 5, 75), 11)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((18, 5, 76), 45)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((19, 5, 77), 14)], cached_k, cached_v)
+
+    operation_seq = [
+        [(6, 1), (11, 1), (13, 1), (9, 1)],
+        [(10, 1), (16, 1), (18, 1), (19, 1)],
+        [(8, 1), (15, 1), (17, 1), (12, 1), (14, 1)],
+        [(10, 10), (6, 2), (8, 3), (19, 4)],
+    ]
+    for batch in operation_seq:
+        apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
+
+    for i in range(19, -1, -1):
         fremove_sequence(kv_cache, i)
         cached_k.pop(i)
         cached_v.pop(i)
@@ -477,7 +508,7 @@ def test_paged_attention_kv_cache_popn(kv_cache_and_config):
     cached_v = {}
     batch = [(0, 35), (1, 88), (2, 17), (3, 4)]
     apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
-    apply_attention(kv_cache, rope_mode, [((4, 3), 35)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((4, 3, -1), 35)], cached_k, cached_v)
 
     popn_operations = [(0, 17), (1, 57), (2, 16), (3, 0)]
     for seq_id, pop_length in popn_operations:
@@ -539,7 +570,7 @@ def test_paged_attention_kv_cache_sliding_window(kv_cache_and_config):
     sliding_window_sizes += [0, 18]
     attn_sink_sizes += [0, 12]
     apply_attention(kv_cache, rope_mode, [(5, 10)], cached_k, cached_v)
-    ffork_sequence(kv_cache, 5, 6)
+    ffork_sequence(kv_cache, 5, 6, -1)
     cached_k[6] = cached_k[5]
     cached_v[6] = cached_v[5]
     fenable_sliding_window_for_seq(kv_cache, 6, sliding_window_sizes[-1], attn_sink_sizes[-1])
@@ -1845,6 +1876,46 @@ def _merge_state_inplace(
     return merge_state_inplace
 
 
+def _copy_single_page(num_heads, page_size, head_dim, dtype, target: Target):
+    tx = 256 if str(target.kind) == "webgpu" else 1024
+
+    @T.prim_func
+    def copy_single_page(
+        pages: T.handle,
+        src_page_id: T.int64,
+        tgt_page_id: T.int64,
+        copy_length: T.int64,
+    ):
+        T.func_attr({"tir.is_scheduled": 1})
+        num_pages = T.int32()
+        P = T.match_buffer(pages, (num_pages, 2, num_heads, page_size, head_dim), dtype)
+
+        for b in T.thread_binding(
+            (copy_length * num_heads * head_dim + tx - 1) // tx, thread="blockIdx.x"
+        ):
+            for t in T.thread_binding(tx, thread="threadIdx.x"):
+                with T.block("copy"):
+                    vh = T.axis.spatial(
+                        num_heads,
+                        T.Cast("int32", (b * tx + t) // (copy_length * 
head_dim)),
+                    )
+                    vp = T.axis.spatial(
+                        copy_length,
+                        (b * tx + t) % (copy_length * head_dim) // head_dim,
+                    )
+                    vd = T.axis.spatial(
+                        head_dim,
+                        T.Cast(
+                            "int32",
+                            (b * tx + t) % head_dim,
+                        ),
+                    )
+                    P[tgt_page_id, 0, vh, vp, vd] = P[src_page_id, 0, vh, vp, vd]
+                    P[tgt_page_id, 1, vh, vp, vd] = P[src_page_id, 1, vh, vp, vd]
+
+    return copy_single_page
+
+
 if __name__ == "__main__":
     HEAD_DIMS = [64, 128]
     DTYPES = ["float16", "float32"]
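
On the test side, the expected state of a child forked at fork_pos is simply the parent's cached K/V truncated to fork_pos, which is what apply_attention above maintains. A numpy mirror of that bookkeeping, with illustrative shapes (note that the tests' `[::, :fork_pos]` is equivalent to `[:, :fork_pos]`):

    import numpy as np

    num_layers, num_kv_heads, head_dim = 2, 4, 8  # illustrative sizes
    cached_k = {0: np.random.rand(num_layers, 60, num_kv_heads, head_dim)}
    fork_pos = 33
    cached_k[10] = cached_k[0][:, :fork_pos]  # child 10 forked from parent 0
    assert cached_k[10].shape[1] == fork_pos
    np.testing.assert_array_equal(cached_k[10], cached_k[0][:, :33])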
diff --git a/tests/python/relax/test_runtime_builtin_rnn_state.py b/tests/python/relax/test_runtime_builtin_rnn_state.py
index 28f370bca0..de35ad5d77 100644
--- a/tests/python/relax/test_runtime_builtin_rnn_state.py
+++ b/tests/python/relax/test_runtime_builtin_rnn_state.py
@@ -172,7 +172,7 @@ def test_rnn_state_fork_sequence(rnn_state):  # pylint: disable=redefined-outer-
     f_set(state, 0, 0, tvm.nd.array(np_two.reshape(1, 16, 16), device=device))
     f_set(state, 0, 1, tvm.nd.array(np_three.reshape(1, 32, 32), device=device))
     f_end_forward(state)
-    f_fork_sequence(state, 0, 1)
+    f_fork_sequence(state, 0, 1, -1)
     verify_state(state, [0, 1], [[np_two, np_three], [np_two, np_three]])
     # Verify popn for the forked sequence
     f_popn(state, 1, 1)
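
For the RNN state, only the calling convention changes here; forking at -1 gives the child an identical copy of every per-layer state tensor, which is what verify_state asserts above. A minimal sketch of that expectation (values mirror np_two/np_three in the test):

    import numpy as np

    np_two = np.full((1, 16, 16), 2.0)
    np_three = np.full((1, 32, 32), 3.0)
    parent = [np_two, np_three]
    child = [arr.copy() for arr in parent]  # what a fork at the last position amounts to
    for p, c in zip(parent, child):
        np.testing.assert_array_equal(p, c)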
