This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c3be89a407 [KVCache] Support forking sequence at specific position
(#16813)
c3be89a407 is described below
commit c3be89a4070287cb98fded112a48a3d295564dea
Author: Yaxing Cai <[email protected]>
AuthorDate: Fri Mar 29 16:09:37 2024 -0700
[KVCache] Support forking sequence at specific position (#16813)
This PR enables KVCache to fork a sequence at a specific position.
---
src/runtime/relax_vm/kv_state.h | 5 +-
src/runtime/relax_vm/paged_kv_cache.cc | 127 +++++++++++++++++----
src/runtime/relax_vm/rnn_state.cc | 2 +-
..._builtin_paged_attention_kv_cache_flashinfer.py | 102 ++++++++++++++---
...runtime_builtin_paged_attention_kv_cache_tir.py | 101 +++++++++++++---
.../python/relax/test_runtime_builtin_rnn_state.py | 2 +-
6 files changed, 283 insertions(+), 56 deletions(-)
diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h
index f6857a9dce..e3c6e9608c 100644
--- a/src/runtime/relax_vm/kv_state.h
+++ b/src/runtime/relax_vm/kv_state.h
@@ -59,9 +59,12 @@ class KVStateObj : public Object {
* \param parent_seq_id The parent (source) of the fork.
* \param child_seq_id The child (destination) of the fork.
* The child sequence id should not exist in cache prior to fork.
+ * \param fork_pos The parent position to fork at. The legal forking position is within
+ * [0, parent_seq_length], with -1 as the default denoting the last position. If the
+ * forking position is 0, it is equivalent to adding a new sequence with the child sequence id.
* \throws Error if the given sequence ids are not valid.
*/
- virtual void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id) = 0;
+ virtual void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id,
int64_t fork_pos = -1) = 0;
/*!
* \brief Pop out the trailing `n` tokens from the KV cache for the
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc
b/src/runtime/relax_vm/paged_kv_cache.cc
index 9c3ee5d427..3ccab3826d 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -373,6 +373,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj
{
Optional<PackedFunc> f_attention_decode_end_forward_;
PackedFunc f_merge_inplace_;
PackedFunc f_split_rotary_;
+ PackedFunc f_copy_single_page_;
Optional<PackedFunc> f_debug_get_kv_;
/*! \brief Number of fork depth in the current round of forward. */
@@ -407,7 +408,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj
{
Optional<PackedFunc> f_attention_prefill_end_forward,
Optional<PackedFunc> f_attention_decode_begin_forward,
Optional<PackedFunc> f_attention_decode_end_forward, PackedFunc
f_merge_inplace,
- PackedFunc f_split_rotary, Optional<PackedFunc> f_debug_get_kv)
+ PackedFunc f_split_rotary, PackedFunc f_copy_single_page,
Optional<PackedFunc> f_debug_get_kv)
: page_size_(page_size),
num_layers_(num_layers),
num_qo_heads_(num_qo_heads),
@@ -435,6 +436,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj
{
f_attention_decode_end_forward_(std::move(f_attention_decode_end_forward)),
f_merge_inplace_(std::move(f_merge_inplace)),
f_split_rotary_(std::move(f_split_rotary)),
+ f_copy_single_page_(std::move(f_copy_single_page)),
f_debug_get_kv_(std::move(f_debug_get_kv)),
device_(device) {
pages_.reserve(num_layers);
@@ -527,27 +529,27 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
void RemoveSequence(int64_t seq_id) final {
auto it = seq_map_.find(seq_id);
CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot
be found in KV cache.";
- const Block& block = global_block_pool_[it->second.last_block_idx];
- CHECK_EQ(block.external_ref_cnt, 0)
+ int32_t block_idx = it->second.last_block_idx;
+ CHECK_EQ(global_block_pool_[block_idx].external_ref_cnt, 0)
<< "The sequence is currently referenced by other sequence and thus
cannot be removed.";
-
- // - Decrease the external reference of the parent block.
- if (block.parent_idx != -1) {
- Block& parent_block = global_block_pool_[block.parent_idx];
- ICHECK_GT(parent_block.external_ref_cnt, 0);
- --parent_block.external_ref_cnt;
+ while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt
== 0) {
+ // - Free pages in the last block.
+ for (int32_t page_id : global_block_pool_[block_idx].page_ids) {
+ free_page_ids_.push_back(page_id);
+ }
+ free_block_idx_.push_back(block_idx);
+ block_idx = global_block_pool_[block_idx].parent_idx;
}
- // - Free pages in the last block.
- for (int32_t page_id : block.page_ids) {
- free_page_ids_.push_back(page_id);
+ // - Decrease the external reference of the parent block.
+ if (block_idx != -1) {
+ ICHECK_GT(global_block_pool_[block_idx].external_ref_cnt, 0);
+ --global_block_pool_[block_idx].external_ref_cnt;
}
- // - Remove the sequence from seq_map.
- free_block_idx_.push_back(it->second.last_block_idx);
seq_map_.erase(it);
dirty_aux_data_device_ = true;
}
- void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id) final {
+ void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id, int64_t
fork_pos = -1) final {
auto parent_it = seq_map_.find(parent_seq_id);
CHECK(parent_it != seq_map_.end())
<< "The parent sequence \"" << parent_seq_id << "\" cannot be found in
KV cache.";
@@ -556,18 +558,89 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
CHECK_EQ(parent_it->second.sliding_window_size, -1)
<< "The parent sequence \"" << parent_seq_id
<< "\" is enabled with sliding window and thus cannot be forked.";
+ CHECK_GE(fork_pos, -1)
+ << "The forked position should be non-negative, or -1 for last
position as default.";
+ CHECK_LE(fork_pos, parent_it->second.seq_length)
+ << "The forked position should not exceed the total length of parent
sequence.";
- int32_t parent_block_idx = parent_it->second.last_block_idx;
- ++global_block_pool_[parent_block_idx].external_ref_cnt;
- // Create a child block with the parent block pointer.
int32_t child_block_idx = GetFreeBlock();
- global_block_pool_[child_block_idx].start_pos =
parent_it->second.seq_length;
- global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
+ if (fork_pos == -1 || fork_pos == parent_it->second.seq_length) {
+ // Fork at last by appending a new block directly
+ int32_t parent_block_idx = parent_it->second.last_block_idx;
+ ++global_block_pool_[parent_block_idx].external_ref_cnt;
+ // Update child block start position and parent index
+ global_block_pool_[child_block_idx].start_pos =
parent_it->second.seq_length;
+ global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
+ } else {
+ // Locate the block to fork from and calculate in-block offset
+ std::vector<int32_t> trace =
parent_it->second.GetBlockTrace(global_block_pool_);
+ int64_t in_block_offset = fork_pos;
+ int32_t forked_block_idx = -1;
+ for (int32_t block_idx : trace) {
+ if (in_block_offset < global_block_pool_[block_idx].seq_length) {
+ forked_block_idx = block_idx;
+ break;
+ }
+ in_block_offset -= global_block_pool_[block_idx].seq_length;
+ }
+ int32_t in_page_offset = in_block_offset % page_size_;
+ int32_t moved_offset = in_block_offset - in_page_offset;
+ if (moved_offset == 0) {
+ // Forked at the first page in block
+ int32_t parent_block_idx =
global_block_pool_[forked_block_idx].parent_idx;
+ if (parent_block_idx != -1) {
+ ++global_block_pool_[parent_block_idx].external_ref_cnt;
+ }
+ // Update child block start position and parent index
+ global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
+ } else {
+ // Forked at the second or latter page in block
+ int32_t parent_block_idx = GetFreeBlock();
+ // Insert new parent block before forked block and link child block
+ global_block_pool_[parent_block_idx].parent_idx =
+ global_block_pool_[forked_block_idx].parent_idx;
+ global_block_pool_[forked_block_idx].parent_idx = parent_block_idx;
+ global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
+ global_block_pool_[parent_block_idx].external_ref_cnt = 1;
+
+ // Move common leading pages to new parent block
+ auto first_page =
global_block_pool_[forked_block_idx].page_ids.begin();
+ auto last_page =
+ global_block_pool_[forked_block_idx].page_ids.begin() +
moved_offset / page_size_;
+ global_block_pool_[parent_block_idx].page_ids = {first_page,
last_page};
+ global_block_pool_[forked_block_idx].page_ids.erase(first_page,
last_page);
+
+ // Update start position per blocks
+ global_block_pool_[parent_block_idx].start_pos =
+ global_block_pool_[forked_block_idx].start_pos;
+ global_block_pool_[forked_block_idx].start_pos += moved_offset;
+
+ // Update in-block sequence length per blocks
+ global_block_pool_[parent_block_idx].seq_length = moved_offset;
+ global_block_pool_[forked_block_idx].seq_length -= moved_offset;
+ }
+ global_block_pool_[child_block_idx].start_pos = fork_pos -
in_page_offset;
+ global_block_pool_[child_block_idx].seq_length = in_page_offset;
+
+ if (in_page_offset > 0) {
+ // Fork within a page and copy common page to child block partially
+ int32_t src_page_id = global_block_pool_[forked_block_idx].page_ids[0];
+ int32_t tgt_page_id = GetFreePage();
+ global_block_pool_[child_block_idx].page_ids.push_back(tgt_page_id);
+ CopySinglePage(src_page_id, tgt_page_id, in_page_offset);
+ }
+ }
// Create the child sequence with the child block.
seq_map_.insert({child_seq_id, Sequence(global_block_pool_,
child_block_idx)});
dirty_aux_data_device_ = true;
}
+ void CopySinglePage(int32_t src_page_id, int32_t tgt_page_id, int64_t
copy_length) {
+ for (int layer = 0; layer < num_layers_; ++layer) {
+ f_copy_single_page_(pages_[layer], src_page_id, tgt_page_id,
copy_length);
+ }
+ }
+
void EnableSlidingWindowForSeq(int64_t seq_id, int32_t sliding_window_size,
int32_t attn_sink_size) final {
CHECK(support_sliding_window_) << "The KV cache does not support sliding
window.";
@@ -1390,7 +1463,7 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
// - Reset the dirty flag to false.
dirty_aux_data_device_ = false;
}
-};
+}; // namespace relax_vm
TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj);
@@ -1412,7 +1485,8 @@
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
PackedFunc f_attention_prefill_end_forward,
PackedFunc f_attention_decode_begin_forward,
PackedFunc f_attention_decode_end_forward, PackedFunc
f_merge_inplace,
- PackedFunc f_split_rotary, Optional<PackedFunc>
f_debug_get_kv) {
+ PackedFunc f_split_rotary, PackedFunc
f_copy_single_page,
+ Optional<PackedFunc> f_debug_get_kv) {
CHECK_EQ(cache_config.size(), 5);
int64_t reserved_num_seqs = cache_config[0];
int64_t total_token_capacity = cache_config[1];
@@ -1435,7 +1509,8 @@
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
std::move(f_attention_prefill_ragged_end_forward),
std::move(f_attention_prefill_begin_forward),
std::move(f_attention_prefill_end_forward),
std::move(f_attention_decode_begin_forward),
std::move(f_attention_decode_end_forward),
- std::move(f_merge_inplace), std::move(f_split_rotary),
std::move(f_debug_get_kv));
+ std::move(f_merge_inplace), std::move(f_split_rotary),
std::move(f_copy_single_page),
+ std::move(f_debug_get_kv));
return AttentionKVCache(std::move(n));
});
@@ -1447,7 +1522,8 @@
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
PackedFunc f_attention_prefill_sliding_window,
PackedFunc f_attention_decode_sliding_window,
PackedFunc f_attention_prefill_ragged, PackedFunc
f_merge_inplace,
- PackedFunc f_split_rotary, Optional<PackedFunc>
f_debug_get_kv) {
+ PackedFunc f_split_rotary, PackedFunc
f_copy_single_page,
+ Optional<PackedFunc> f_debug_get_kv) {
CHECK_EQ(cache_config.size(), 5);
int64_t reserved_num_seqs = cache_config[0];
int64_t total_token_capacity = cache_config[1];
@@ -1467,7 +1543,8 @@
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
std::move(f_attention_prefill_sliding_window),
std::move(f_attention_decode_sliding_window),
std::move(f_attention_prefill_ragged), //
NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt,
//
- std::move(f_merge_inplace), std::move(f_split_rotary),
std::move(f_debug_get_kv));
+ std::move(f_merge_inplace), std::move(f_split_rotary),
std::move(f_copy_single_page),
+ std::move(f_debug_get_kv));
return AttentionKVCache(std::move(n));
});
diff --git a/src/runtime/relax_vm/rnn_state.cc
b/src/runtime/relax_vm/rnn_state.cc
index 09873ba5f7..69225d6b2c 100644
--- a/src/runtime/relax_vm/rnn_state.cc
+++ b/src/runtime/relax_vm/rnn_state.cc
@@ -319,7 +319,7 @@ class RNNStateImpObj : public RNNStateObj {
dirty_aux_data_device_ = true;
}
- void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id) final {
+ void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id, int64_t
fork_pos = -1) final {
auto parent_it = seq_map_.find(parent_seq_id);
CHECK(parent_it != seq_map_.end()) << "The parent sequence \"" <<
parent_seq_id
<< "\" cannot be found in space state
storage.";
diff --git
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
index d30ccd0224..c71b0dde3e 100644
---
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
+++
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
@@ -66,6 +66,7 @@ fattention_merge_state = None
ftranspose_append = None
fsplit_rotary = None
+fcopy_single_page = None
fcopy_cache = None
@@ -222,6 +223,46 @@ def copy_cache(
]
+def _copy_single_page(num_heads, page_size, head_dim, dtype, target):
+ tx = 256 if str(target.kind) == "webgpu" else 1024
+
+ @T.prim_func
+ def copy_single_page(
+ pages: T.handle,
+ src_page_id: T.int64,
+ tgt_page_id: T.int64,
+ copy_length: T.int64,
+ ):
+ T.func_attr({"tir.is_scheduled": 1})
+ num_pages = T.int32()
+ P = T.match_buffer(pages, (num_pages, 2, num_heads, page_size,
head_dim), dtype)
+
+ for b in T.thread_binding(
+ (copy_length * num_heads * head_dim + tx - 1) // tx,
thread="blockIdx.x"
+ ):
+ for t in T.thread_binding(tx, thread="threadIdx.x"):
+ with T.block("copy"):
+ vh = T.axis.spatial(
+ num_heads,
+ T.Cast("int32", (b * tx + t) // (copy_length *
head_dim)),
+ )
+ vp = T.axis.spatial(
+ copy_length,
+ (b * tx + t) % (copy_length * head_dim) // head_dim,
+ )
+ vd = T.axis.spatial(
+ head_dim,
+ T.Cast(
+ "int32",
+ (b * tx + t) % head_dim,
+ ),
+ )
+ P[tgt_page_id, 0, vh, vp, vd] = P[src_page_id, 0, vh, vp,
vd]
+ P[tgt_page_id, 1, vh, vp, vd] = P[src_page_id, 1, vh, vp,
vd]
+
+ return copy_single_page
+
+
def set_global_func():
global fclear, fcreate, fadd_sequence, fremove_sequence, ffork_sequence,
fpopn
global fbegin_forward, fend_forward, fattention, fattention_with_fuse_qkv,
fdebug_get_kv
@@ -230,7 +271,7 @@ def set_global_func():
global fattention_prefill_ragged
global fattention_prefill_ragged_begin_forward
global fattention_prefill_ragged_end_forward
- global fattention_merge_state, fsplit_rotary
+ global fattention_merge_state, fsplit_rotary, fcopy_single_page
global ftranspose_append, fcopy_cache
fclear = tvm.get_global_func("vm.builtin.kv_state_clear")
@@ -282,6 +323,7 @@ def set_global_func():
llama_rope_with_position_map(
rope_theta, rope_scale, head_dim, num_qo_heads, num_kv_heads, dtype
),
+ _copy_single_page(num_kv_heads, page_size, head_dim, dtype, target),
copy_cache,
]:
mod = tvm.IRModule({"main": tir_func})
@@ -290,7 +332,7 @@ def set_global_func():
f = tvm.build(mod["main"], target=target)
builts.append(f.entry_func)
- ftranspose_append, fsplit_rotary, fcopy_cache = builts
+ ftranspose_append, fsplit_rotary, fcopy_single_page, fcopy_cache = builts
def create_kv_cache(rope_mode):
@@ -327,6 +369,7 @@ def create_kv_cache(rope_mode):
fattention_decode_end_forward,
fattention_merge_state,
fsplit_rotary,
+ fcopy_single_page,
fcopy_cache,
)
return cache
@@ -384,7 +427,7 @@ def f_apply_rotary(x, offset, scale, theta):
def apply_attention(
kv_cache,
rope_mode: RopeMode,
- batch: List[Tuple[Union[int, Tuple[int, int]], int]],
+ batch: List[Tuple[Union[int, Tuple[int, int, int]], int]],
cached_k: Dict[int, np.ndarray],
cached_v: Dict[int, np.ndarray],
) -> None:
@@ -394,16 +437,20 @@ def apply_attention(
fork_parent_id = None
if isinstance(seq_id, tuple):
# Fork sequence
- seq_id, fork_parent_id = seq_id
+ seq_id, fork_parent_id, fork_pos = seq_id
batch[i] = (seq_id, append_length)
seq_ids.append(seq_id)
append_lengths.append(append_length)
if fork_parent_id is not None:
assert fork_parent_id in cached_k
assert seq_id not in cached_k
- ffork_sequence(kv_cache, fork_parent_id, seq_id)
- cached_k[seq_id] = cached_k[fork_parent_id]
- cached_v[seq_id] = cached_v[fork_parent_id]
+ ffork_sequence(kv_cache, fork_parent_id, seq_id, fork_pos)
+ if fork_pos == -1:
+ cached_k[seq_id] = cached_k[fork_parent_id]
+ cached_v[seq_id] = cached_v[fork_parent_id]
+ else:
+ cached_k[seq_id] = cached_k[fork_parent_id][::, :fork_pos]
+ cached_v[seq_id] = cached_v[fork_parent_id][::, :fork_pos]
elif seq_id not in cached_k:
fadd_sequence(kv_cache, seq_id)
cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads,
head_dim), dtype)
@@ -563,12 +610,15 @@ def
test_paged_attention_kv_cache_fork_sequence(kv_cache_and_rope_mode):
batch = [(0, 60), (1, 88), (2, 17), (3, 4)]
apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
# Fork existing sequences.
- apply_attention(kv_cache, rope_mode, [((4, 3), 35)], cached_k, cached_v)
- apply_attention(kv_cache, rope_mode, [((5, 0), 20)], cached_k, cached_v)
- apply_attention(kv_cache, rope_mode, [((6, 5), 102)], cached_k, cached_v)
- apply_attention(kv_cache, rope_mode, [((7, 0), 3)], cached_k, cached_v)
- apply_attention(kv_cache, rope_mode, [((8, 5), 71)], cached_k, cached_v)
- apply_attention(kv_cache, rope_mode, [((9, 5), 20)], cached_k, cached_v)
+ apply_attention(kv_cache, rope_mode, [((4, 3, -1), 35)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [((5, 0, -1), 20)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [((6, 5, -1), 102)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [((7, 0, -1), 3)], cached_k, cached_v)
+ apply_attention(kv_cache, rope_mode, [((8, 5, -1), 71)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [((9, 5, -1), 20)], cached_k,
cached_v)
+ # 0 <- 5 <- 6,8,9
+ # 0 <- 7
+ # 3 <- 4
# Mixture of decode and prefill.
operation_seq = [
[(2, 1), (4, 1), (7, 1), (6, 1), (8, 1), (9, 1)],
@@ -579,6 +629,32 @@ def
test_paged_attention_kv_cache_fork_sequence(kv_cache_and_rope_mode):
for batch in operation_seq:
apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
+ apply_attention(kv_cache, rope_mode, [((10, 1, 33), 11)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [((11, 0, 60), 45)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [((12, 0, 15), 14)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [((13, 0, 16), 19)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [((14, 0, 17), 19)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [((15, 5, 60), 8)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [((16, 5, 80), 10)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [((17, 5, 75), 11)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [((18, 5, 76), 45)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [((19, 5, 77), 14)], cached_k,
cached_v)
+
+ operation_seq = [
+ [(6, 1), (11, 1), (13, 1), (9, 1)],
+ [(10, 1), (16, 1), (18, 1), (19, 1)],
+ [(8, 1), (15, 1), (17, 1), (12, 1), (14, 1)],
+ [(10, 10), (6, 2), (8, 3), (19, 4)],
+ ]
+ for batch in operation_seq:
+ apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
+
+ for i in range(19, -1, -1):
+ fremove_sequence(kv_cache, i)
+ cached_k.pop(i)
+ cached_v.pop(i)
+ verify_cached_kv(kv_cache, seq_ids=list(range(i)),
expected_k=cached_k, expected_v=cached_v)
+
@pytest.mark.skip(reason="Require FlashInfer enabled")
def test_paged_attention_kv_cache_popn(kv_cache_and_rope_mode):
diff --git
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index c33686d16e..3ed89ecd0f 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -66,6 +66,7 @@ fattn_prefill_ragged = None
fmerge_state = None
fsplit_rotary = None
fattention_rotary = None
+fcopy_single_page = None
def set_global_func(head_dim, dtype):
@@ -73,7 +74,7 @@ def set_global_func(head_dim, dtype):
global fpopn, fbegin_forward, fend_forward, fattention_with_fuse_qkv,
fdebug_get_kv
global ftranspose_append, fcopy_cache, fattn_prefill, fattn_decode,
fattn_prefill_ragged
global fattn_prefill_sliding_window, fattn_decode_sliding_window
- global fmerge_state, fsplit_rotary, fattention_rotary
+ global fmerge_state, fsplit_rotary, fattention_rotary, fcopy_single_page
fclear = tvm.get_global_func("vm.builtin.kv_state_clear")
fadd_sequence = tvm.get_global_func("vm.builtin.kv_state_add_sequence")
@@ -104,6 +105,7 @@ def set_global_func(head_dim, dtype):
llama_rope_with_position_map(
rope_theta, rope_scale, head_dim, num_qo_heads, num_kv_heads, dtype
),
+ _copy_single_page(num_kv_heads, page_size, head_dim, dtype, target),
]:
mod = tvm.IRModule({"main": tir_func})
with target:
@@ -121,6 +123,7 @@ def set_global_func(head_dim, dtype):
fattn_prefill_ragged,
fmerge_state,
fsplit_rotary,
+ fcopy_single_page,
) = builts
@@ -152,6 +155,7 @@ def create_kv_cache(head_dim, dtype, rope_mode,
support_sliding_window):
fattn_prefill_ragged,
fmerge_state,
fsplit_rotary,
+ fcopy_single_page,
fcopy_cache,
)
return cache
@@ -226,7 +230,7 @@ def f_apply_rotary(x, offset, scale, theta):
def apply_attention(
kv_cache,
rope_mode: RopeMode,
- batch: List[Tuple[Union[int, Tuple[int, int]], int]],
+ batch: List[Tuple[Union[int, Tuple[int, int, int]], int]],
cached_k: Dict[int, np.ndarray],
cached_v: Dict[int, np.ndarray],
sliding_window_sizes: Optional[List[int]] = None,
@@ -238,16 +242,20 @@ def apply_attention(
fork_parent_id = None
if isinstance(seq_id, tuple):
# Fork sequence
- seq_id, fork_parent_id = seq_id
+ seq_id, fork_parent_id, fork_pos = seq_id
batch[i] = (seq_id, append_length)
seq_ids.append(seq_id)
append_lengths.append(append_length)
if fork_parent_id is not None:
assert fork_parent_id in cached_k
assert seq_id not in cached_k
- ffork_sequence(kv_cache, fork_parent_id, seq_id)
- cached_k[seq_id] = cached_k[fork_parent_id]
- cached_v[seq_id] = cached_v[fork_parent_id]
+ ffork_sequence(kv_cache, fork_parent_id, seq_id, fork_pos)
+ if fork_pos == -1:
+ cached_k[seq_id] = cached_k[fork_parent_id]
+ cached_v[seq_id] = cached_v[fork_parent_id]
+ else:
+ cached_k[seq_id] = cached_k[fork_parent_id][::, :fork_pos]
+ cached_v[seq_id] = cached_v[fork_parent_id][::, :fork_pos]
elif seq_id not in cached_k:
fadd_sequence(kv_cache, seq_id)
cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads,
head_dim), dtype)
@@ -442,12 +450,15 @@ def
test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config):
batch = [(0, 60), (1, 88), (2, 17), (3, 4)]
apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
# Fork existing sequences.
- apply_attention(kv_cache, rope_mode, [((4, 3), 35)], cached_k, cached_v)
- apply_attention(kv_cache, rope_mode, [((5, 0), 20)], cached_k, cached_v)
- apply_attention(kv_cache, rope_mode, [((6, 5), 102)], cached_k, cached_v)
- apply_attention(kv_cache, rope_mode, [((7, 0), 3)], cached_k, cached_v)
- apply_attention(kv_cache, rope_mode, [((8, 5), 71)], cached_k, cached_v)
- apply_attention(kv_cache, rope_mode, [((9, 5), 20)], cached_k, cached_v)
+ apply_attention(kv_cache, rope_mode, [((4, 3, -1), 35)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [((5, 0, -1), 20)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [((6, 5, -1), 102)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [((7, 0, -1), 3)], cached_k, cached_v)
+ apply_attention(kv_cache, rope_mode, [((8, 5, -1), 71)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [((9, 5, -1), 20)], cached_k,
cached_v)
+ # 0 <- 5 <- 6,8,9
+ # 0 <- 7
+ # 3 <- 4
# Mixture of decode and prefill.
operation_seq = [
[(2, 1), (4, 1), (7, 1), (6, 1), (8, 1), (9, 1)],
@@ -458,7 +469,27 @@ def
test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config):
for batch in operation_seq:
apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
- for i in range(9, -1, -1):
+ apply_attention(kv_cache, rope_mode, [((10, 1, 33), 11)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [((11, 0, 60), 45)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [((12, 0, 15), 14)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [((13, 0, 16), 19)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [((14, 0, 17), 19)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [((15, 5, 60), 8)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [((16, 5, 80), 10)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [((17, 5, 75), 11)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [((18, 5, 76), 45)], cached_k,
cached_v)
+ apply_attention(kv_cache, rope_mode, [((19, 5, 77), 14)], cached_k,
cached_v)
+
+ operation_seq = [
+ [(6, 1), (11, 1), (13, 1), (9, 1)],
+ [(10, 1), (16, 1), (18, 1), (19, 1)],
+ [(8, 1), (15, 1), (17, 1), (12, 1), (14, 1)],
+ [(10, 10), (6, 2), (8, 3), (19, 4)],
+ ]
+ for batch in operation_seq:
+ apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
+
+ for i in range(19, -1, -1):
fremove_sequence(kv_cache, i)
cached_k.pop(i)
cached_v.pop(i)
@@ -477,7 +508,7 @@ def test_paged_attention_kv_cache_popn(kv_cache_and_config):
cached_v = {}
batch = [(0, 35), (1, 88), (2, 17), (3, 4)]
apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
- apply_attention(kv_cache, rope_mode, [((4, 3), 35)], cached_k, cached_v)
+ apply_attention(kv_cache, rope_mode, [((4, 3, -1), 35)], cached_k,
cached_v)
popn_operations = [(0, 17), (1, 57), (2, 16), (3, 0)]
for seq_id, pop_length in popn_operations:
@@ -539,7 +570,7 @@ def
test_paged_attention_kv_cache_sliding_window(kv_cache_and_config):
sliding_window_sizes += [0, 18]
attn_sink_sizes += [0, 12]
apply_attention(kv_cache, rope_mode, [(5, 10)], cached_k, cached_v)
- ffork_sequence(kv_cache, 5, 6)
+ ffork_sequence(kv_cache, 5, 6, -1)
cached_k[6] = cached_k[5]
cached_v[6] = cached_v[5]
fenable_sliding_window_for_seq(kv_cache, 6, sliding_window_sizes[-1],
attn_sink_sizes[-1])
@@ -1845,6 +1876,46 @@ def _merge_state_inplace(
return merge_state_inplace
+def _copy_single_page(num_heads, page_size, head_dim, dtype, target: Target):
+ tx = 256 if str(target.kind) == "webgpu" else 1024
+
+ @T.prim_func
+ def copy_single_page(
+ pages: T.handle,
+ src_page_id: T.int64,
+ tgt_page_id: T.int64,
+ copy_length: T.int64,
+ ):
+ T.func_attr({"tir.is_scheduled": 1})
+ num_pages = T.int32()
+ P = T.match_buffer(pages, (num_pages, 2, num_heads, page_size,
head_dim), dtype)
+
+ for b in T.thread_binding(
+ (copy_length * num_heads * head_dim + tx - 1) // tx,
thread="blockIdx.x"
+ ):
+ for t in T.thread_binding(tx, thread="threadIdx.x"):
+ with T.block("copy"):
+ vh = T.axis.spatial(
+ num_heads,
+ T.Cast("int32", (b * tx + t) // (copy_length *
head_dim)),
+ )
+ vp = T.axis.spatial(
+ copy_length,
+ (b * tx + t) % (copy_length * head_dim) // head_dim,
+ )
+ vd = T.axis.spatial(
+ head_dim,
+ T.Cast(
+ "int32",
+ (b * tx + t) % head_dim,
+ ),
+ )
+ P[tgt_page_id, 0, vh, vp, vd] = P[src_page_id, 0, vh, vp,
vd]
+ P[tgt_page_id, 1, vh, vp, vd] = P[src_page_id, 1, vh, vp,
vd]
+
+ return copy_single_page
+
+
if __name__ == "__main__":
HEAD_DIMS = [64, 128]
DTYPES = ["float16", "float32"]
diff --git a/tests/python/relax/test_runtime_builtin_rnn_state.py
b/tests/python/relax/test_runtime_builtin_rnn_state.py
index 28f370bca0..de35ad5d77 100644
--- a/tests/python/relax/test_runtime_builtin_rnn_state.py
+++ b/tests/python/relax/test_runtime_builtin_rnn_state.py
@@ -172,7 +172,7 @@ def test_rnn_state_fork_sequence(rnn_state): # pylint:
disable=redefined-outer-
f_set(state, 0, 0, tvm.nd.array(np_two.reshape(1, 16, 16), device=device))
f_set(state, 0, 1, tvm.nd.array(np_three.reshape(1, 32, 32),
device=device))
f_end_forward(state)
- f_fork_sequence(state, 0, 1)
+ f_fork_sequence(state, 0, 1, -1)
verify_state(state, [0, 1], [[np_two, np_three], [np_two, np_three]])
# Verify popn for the forked sequence
f_popn(state, 1, 1)