This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d93f4ad4b8 [KVCache] Add KV Cache for CPU Runtime (#17615)
d93f4ad4b8 is described below
commit d93f4ad4b88b224523e258f04a4bae450eddb504
Author: Mengshiun Yu <[email protected]>
AuthorDate: Sun Feb 2 16:48:42 2025 -0500
[KVCache] Add KV Cache for CPU Runtime (#17615)
* [KVCache] Add KV Cache for CPU Runtime
This PR introduces KV Cache support for the CPU runtime:
* Implementation of KV Cache TIR for CPU-based processing.
* Updates to the relevant runtime components to integrate KV Cache.
Co-authored-by: HMZ <[email protected]>
Co-authored-by: ShouChenChiu <[email protected]>
* Add unit test
---------
Co-authored-by: HMZ <[email protected]>
Co-authored-by: ShouChenChiu <[email protected]>
---
python/tvm/relax/frontend/nn/llm/kv_cache.py | 882 ++++++++++++++++++-
python/tvm/relax/frontend/nn/llm/tree_attn.py | 405 +++++++++
src/runtime/cpu_device_api.cc | 48 ++
src/runtime/relax_vm/paged_kv_cache.cc | 3 +-
...runtime_builtin_paged_attention_kv_cache_cpu.py | 956 +++++++++++++++++++++
5 files changed, 2278 insertions(+), 16 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py
index 399e418c46..844b237381 100644
--- a/python/tvm/relax/frontend/nn/llm/kv_cache.py
+++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -30,7 +30,16 @@ from tvm.script import tir as T
from tvm.target import Target
from .position_embedding import llama_rope_with_position_map, switch_rope_freq_func
-from .tree_attn import tree_attn, tree_attn_with_paged_kv_cache
+from .tree_attn import (
+ tree_attn,
+ tree_attn_cpu,
+ tree_attn_with_paged_kv_cache,
+ tree_attn_with_paged_kv_cache_cpu,
+)
+
+
+def _var_cpu(dtype):
+ return T.alloc_buffer((1,), dtype)
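
For context: the CPU kernels added below keep their running softmax state in one-element buffers that act as mutable scalars inside TIR, and _var_cpu is shorthand for allocating one. A minimal standalone sketch of the idiom, with a hypothetical kernel name that is not part of this commit:

from tvm.script import tir as T

@T.prim_func
def running_max(x: T.Buffer((8,), "float32"), out: T.Buffer((1,), "float32")):
    with T.block("scan"):
        acc = T.alloc_buffer((1,), "float32")  # one-element buffer used as a scalar
        acc[0] = -5e4  # same finite stand-in for -inf used by the kernels below
        for i in T.serial(8):
            acc[0] = T.max(acc[0], x[i])
        out[0] = acc[0]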
def get_max_num_threads_per_block(target: Target) -> int:
@@ -371,23 +380,230 @@ class TIRPagedKVCache(PagedKVCache):  # pylint: disable=too-few-public-methods
# pylint: disable=line-too-long
# fmt: off
            bb.add_func(_kv_cache_transpose_append(num_key_value_heads, head_dim, dtype), "kv_cache_transpose_append"),
-            bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, head_dim, dtype, False, rope_scaling, target), "tir_attention_prefill"),
-            bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, head_dim, dtype, False, rope_scaling, target), "tir_attention_decode"),
-            bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, head_dim, dtype, True, rope_scaling, target), "tir_attention_prefill_sliding_window"),
-            bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, head_dim, dtype, True, rope_scaling, target), "tir_attention_decode_sliding_window"),
-            bb.add_func(_attention_prefill_ragged(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged"),
-            bb.add_func(_merge_state_inplace(num_attention_heads, head_dim, dtype, target), "tir_attention_merge_state"),
-            bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"),
-            bb.add_func(_copy_single_page(num_key_value_heads, page_size, head_dim, dtype, target), "kv_cache_copy_single_page"),
-            bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"),
-            bb.add_func(_compact_kv_copy(num_key_value_heads, head_dim, dtype, target), "kv_cache_compact_kv_copy"),
-            bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask"),
-            bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache"),
-            rope_ext_factors,
-            rx.PrimValue(enable_disaggregation),
            # fmt: on
            # pylint: enable=line-too-long
        ]
+
+ if str(target.kind) == "llvm":
+ args.extend(
+ [
+ bb.add_func(
+ _attention_prefill_cpu(
+ num_key_value_heads,
+ num_attention_heads,
+ head_dim,
+ dtype,
+ False,
+ rope_scaling,
+ ),
+ "tir_attention_prefill_cpu",
+ ),
+ bb.add_func(
+ _attention_decode_cpu(
+ num_key_value_heads,
+ num_attention_heads,
+ head_dim,
+ dtype,
+ False,
+ rope_scaling,
+ ),
+ "tir_attention_decode_cpu",
+ ),
+ bb.add_func(
+ _attention_prefill_cpu(
+ num_key_value_heads,
+ num_attention_heads,
+ head_dim,
+ dtype,
+ True,
+ rope_scaling,
+ ),
+ "tir_attention_prefill_cpu_sliding_window",
+ ),
+ bb.add_func(
+ _attention_decode_cpu(
+ num_key_value_heads,
+ num_attention_heads,
+ head_dim,
+ dtype,
+ True,
+ rope_scaling,
+ ),
+ "tir_attention_decode_cpu_sliding_window",
+ ),
+ bb.add_func(
+ _attention_prefill_ragged_cpu(
+                            num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling
+ ),
+ "tir_attention_prefill_ragged_cpu",
+ ),
+ bb.add_func(
+ _merge_state_inplace_cpu(dtype),
+ "tir_attention_merge_state_cpu",
+ ),
+ bb.add_func(
+ llama_rope_with_position_map(
+ rope_theta,
+ rope_scale,
+ head_dim,
+ num_attention_heads,
+ num_key_value_heads,
+ dtype,
+ rope_scaling,
+ rotary_dim,
+ ),
+ "tir_split_rotary",
+ ),
+ bb.add_func(
+                    _copy_single_page_cpu(num_key_value_heads, page_size, head_dim, dtype),
+ "kv_cache_copy_single_page_cpu",
+ ),
+ bb.add_func(
+ _kv_cache_debug_get_kv(
+                        num_hidden_layers, num_key_value_heads, head_dim, dtype
+ ),
+ "kv_cache_debug_get_kv",
+ ),
+ bb.add_func(
+                    _compact_kv_copy_cpu(num_key_value_heads, head_dim, dtype),
+ "kv_cache_compact_kv_copy_cpu",
+ ),
+ bb.add_func(
+ tree_attn_cpu(
+                        num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling
+ ),
+ "tir_attention_prefill_with_tree_mask_cpu",
+ ),
+ bb.add_func(
+ tree_attn_with_paged_kv_cache_cpu(
+                        num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling
+ ),
+                    "tir_attention_prefill_with_tree_mask_with_paged_kv_cache_cpu",
+ ),
+ rope_ext_factors,
+ rx.PrimValue(enable_disaggregation),
+ ]
+ )
+ else:
+ args.extend(
+ [
+ bb.add_func(
+ _attention_prefill(
+ num_key_value_heads,
+ num_attention_heads,
+ head_dim,
+ dtype,
+ False,
+ rope_scaling,
+ target,
+ ),
+ "tir_attention_prefill",
+ ),
+ bb.add_func(
+ _attention_decode(
+ num_key_value_heads,
+ num_attention_heads,
+ head_dim,
+ dtype,
+ False,
+ rope_scaling,
+ target,
+ ),
+ "tir_attention_decode",
+ ),
+ bb.add_func(
+ _attention_prefill(
+ num_key_value_heads,
+ num_attention_heads,
+ head_dim,
+ dtype,
+ True,
+ rope_scaling,
+ target,
+ ),
+ "tir_attention_prefill_sliding_window",
+ ),
+ bb.add_func(
+ _attention_decode(
+ num_key_value_heads,
+ num_attention_heads,
+ head_dim,
+ dtype,
+ True,
+ rope_scaling,
+ target,
+ ),
+ "tir_attention_decode_sliding_window",
+ ),
+ bb.add_func(
+ _attention_prefill_ragged(
+ num_key_value_heads,
+ num_attention_heads,
+ head_dim,
+ dtype,
+ rope_scaling,
+ target,
+ ),
+ "tir_attention_prefill_ragged",
+ ),
+ bb.add_func(
+                    _merge_state_inplace(num_attention_heads, head_dim, dtype, target),
+ "tir_attention_merge_state",
+ ),
+ bb.add_func(
+ llama_rope_with_position_map(
+ rope_theta,
+ rope_scale,
+ head_dim,
+ num_attention_heads,
+ num_key_value_heads,
+ dtype,
+ rope_scaling,
+ rotary_dim,
+ ),
+ "tir_split_rotary",
+ ),
+ bb.add_func(
+                    _copy_single_page(num_key_value_heads, page_size, head_dim, dtype, target),
+ "kv_cache_copy_single_page",
+ ),
+ bb.add_func(
+ _kv_cache_debug_get_kv(
+                        num_hidden_layers, num_key_value_heads, head_dim, dtype
+ ),
+ "kv_cache_debug_get_kv",
+ ),
+ bb.add_func(
+                    _compact_kv_copy(num_key_value_heads, head_dim, dtype, target),
+ "kv_cache_compact_kv_copy",
+ ),
+ bb.add_func(
+ tree_attn(
+ num_key_value_heads,
+ num_attention_heads,
+ head_dim,
+ dtype,
+ rope_scaling,
+ target,
+ ),
+ "tir_attention_prefill_with_tree_mask",
+ ),
+ bb.add_func(
+ tree_attn_with_paged_kv_cache(
+ num_key_value_heads,
+ num_attention_heads,
+ head_dim,
+ dtype,
+ rope_scaling,
+ target,
+ ),
+                    "tir_attention_prefill_with_tree_mask_with_paged_kv_cache",
+ ),
+ rope_ext_factors,
+ rx.PrimValue(enable_disaggregation),
+ ]
+ )
+
super().__init__(
_expr=rx.call_pure_packed(
"vm.builtin.paged_attention_kv_cache_create_reduced",
@@ -553,6 +769,161 @@ def _get_seq_offset(pos, seq_id, length_info, sliding_window):
)
+def _attention_prefill_cpu(h_kv, h_q, d, dtype, sliding_window: bool, rope_scaling: Dict[str, Any]):
+ global_symbol = "batch_prefill_paged_kv_cpu"
+ if sliding_window:
+ global_symbol += "_sliding_window"
+
+ group_size = h_q // h_kv
+ sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1))
+ # pylint: disable=line-too-long,too-many-branches
+ # fmt: off
+ @T.prim_func
+ def batch_prefill_paged_kv_cpu(
+ _0: T.int32, # pylint: disable=unused-argument
+ var_q: T.handle, # [total_len, h_q, d]
+ var_q_indptr: T.handle, # [batch_size + 1]
+ var_pages: T.handle, # [max_num_pages, 2, h_kv, page_size, d]
+ var_page_indptr: T.handle, # [batch_size + 1]
+ var_page_values: T.handle, # [nnz_pages]
+        var_length_info: T.handle, # [b] when sliding window = False, or otherwise [3, b]
+ var_k_rope_pos_offset: T.handle, # [b]
+ var_q_rope_position: T.handle, # [total_len]
+ var_output: T.handle, # [total_len, h_q, d]
+ var_lse: T.handle, # [total_len, h_q]
+ causal: T.int32,
+ rotary_mode: T.int32,
+ rope_scale: T.float32,
+ rope_theta: T.float32,
+ attn_score_scaling_factor: T.float32,
+ ):
+ T.func_attr({"global_symbol": global_symbol})
+ batch_size = T.int32(is_size_var=True)
+ total_len = T.int32(is_size_var=True)
+ nnz_pages = T.int32(is_size_var=True)
+ max_num_pages = T.int32(is_size_var=True)
+ q_indptr_elem_offset = T.int32(is_size_var=True)
+ page_indptr_elem_offset = T.int32(is_size_var=True)
+ page_values_elem_offset = T.int32(is_size_var=True)
+ k_rope_pos_offset_elem_offset = T.int32(is_size_var=True)
+ q_rope_position_elem_offset = T.int32(is_size_var=True)
+ length_info_elem_offset = T.int32(is_size_var=True)
+
+ q = T.match_buffer(var_q, (total_len, h_q, d), dtype)
+        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset)
+        pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), dtype)
+        page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", elem_offset=page_indptr_elem_offset)
+        page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", elem_offset=page_values_elem_offset)
+        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset)
+        q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", elem_offset=q_rope_position_elem_offset)
+        output = T.match_buffer(var_output, (total_len, h_q, d), dtype)
+        lse = T.match_buffer(var_lse, (total_len, h_q), "float32")  # pylint: disable=unused-variable
+ # The length information of the sequences.
+ # - It is in shape `(3, batch_size)` when sliding window is enabled.
+ # For a sequence "i", location
+        # - "(0, i)" is the number of KV slots used in the last page of the seq ("last_page_len"),
+        # - "(1, i)" is the starting offset of the sliding window in the seq,
+        # - "(2, i)" is the attn sink length of the sequence.
+        # - It is in shape `(batch_size,)` when sliding window is disabled,
+        # denoting the "last_page_len".
+        length_info = _declare_length_info(var_length_info, batch_size, sliding_window, length_info_elem_offset)
+
+
+ for h_qo in T.serial(h_q):
+ for b_idx in T.serial(batch_size):
+ with T.block("attn"):
+ O_local = T.alloc_buffer((d, ), "float32")
+ Q_local = T.alloc_buffer((d, ), "float32")
+ K_local = T.alloc_buffer((d, ), "float32")
+ V_local = T.alloc_buffer((d, ), "float32")
+
+ kv_chunk_len = T.alloc_buffer((1, ), "int32")
+
+ m_val = T.alloc_buffer((1, ), "float32")
+ new_m = T.alloc_buffer((1, ), "float32")
+ d_val = T.alloc_buffer((1, ), "float32")
+ S_val = T.alloc_buffer((1, ), "float32")
+ scale_O = T.alloc_buffer((1, ), "float32")
+ factor = T.alloc_buffer((1, ), "float32")
+ cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
+ cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
+ #max_kv_len: T.int32 = max_num_pages * 16
+ kv_chunk_len[0] = T.if_then_else(
+ cur_page_indptr_begin != cur_page_indptr_end,
+                        _get_kv_chunk_len(cur_page_indptr_end - cur_page_indptr_begin, 16, b_idx, length_info, sliding_window),
+ 0
+ )
+
+
+                    for q_idx in T.serial(q_indptr[b_idx + 1] - q_indptr[b_idx]):
+                        # init m, d, O
+ m_val[0] = -5e4
+ d_val[0] = 1.0
+ for d_idx in T.serial(d):
+ O_local[d_idx] = 0.0
+ curl_q: T.int32 = q_indptr[b_idx] + q_idx
+
+ for d_idx in T.serial(d):
+
+ Q_local[d_idx] = T.if_then_else(
+ rotary_mode == 1,
+                                _rope(q, q_rope_position[curl_q], d, rope_theta, rope_scale, (curl_q, h_qo, d_idx), dtype, rope_scaling),
+ q[curl_q, h_qo, d_idx]
+ )
+ for row_idx in T.serial(max_num_pages * 16):
+ if row_idx < kv_chunk_len[0]:
+                                # seq_offset: T.int32(is_size_var=True) = _get_seq_offset(row_idx, b_idx, length_info, sliding_window)
+                                # seq_offset: T.int32(is_size_var=True) = row_idx
+                                page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + (_get_seq_offset(row_idx, b_idx, length_info, sliding_window) // 16)]
+                                page_offset: T.int32(is_size_var=True) = _get_seq_offset(row_idx, b_idx, length_info, sliding_window) % 16
+
+ # Load KV
+ for d_idx in T.serial(d):
+ K_local[d_idx] = T.if_then_else(
+ rotary_mode == 1,
+                                    _rope(pages, k_rope_pos_offset[b_idx] + row_idx, d, rope_theta, rope_scale, (page_no, 0, h_qo // group_size, page_offset, d_idx), dtype, rope_scaling),
+                                    pages[page_no, 0, h_qo // group_size, page_offset, d_idx]
+                                )
+                                V_local[d_idx] = pages[page_no, 1, h_qo // group_size, page_offset, d_idx]
+
+ # Compute S
+ # Q[i] * K[i] * attn_score * sm_scale
+ S_val[0] = 0.0
+ for d_idx in T.serial(d):
+ S_val[0] += Q_local[d_idx] * K_local[d_idx]
+                                S_val[0] *= attn_score_scaling_factor * sm_scale
+
+ # update m_val, d_val , O_local
+ if _causal_mask(causal,
+ row=q_idx,
+ col=row_idx,
+ kv_len=kv_chunk_len[0],
+                                                qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]):
+ new_m[0] = T.max(m_val[0], S_val[0])
+ else:
+ S_val[0] = -5e4
+ # update d_val
+ d_val[0] *= T.exp2(m_val[0] - new_m[0])
+ d_val[0] += T.exp2(S_val[0] - new_m[0])
+
+ # restore O_local then update O_local
+ scale_O[0] = T.exp2(m_val[0] - new_m[0])
+ m_val[0] = new_m[0]
+ factor[0] = T.exp2(S_val[0] - m_val[0])
+ for d_idx in T.serial(d):
+                                O_local[d_idx] = O_local[d_idx] * scale_O[0]
+
+
+ for d_idx in T.serial(d):
+                                    O_local[d_idx] += V_local[d_idx] * factor[0]
+ # Store Output
+ for d_idx in T.serial(d):
+                            O_local[d_idx] = O_local[d_idx] / d_val[0]
+ output[curl_q, h_qo, d_idx] = O_local[d_idx]
+ lse[curl_q, h_qo] = m_val[0] + T.log2(d_val[0])
+ return batch_prefill_paged_kv_cpu
+
+
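
The kernel above implements the flash-attention-style streaming softmax per (query, head): scores stay in the base-2 domain (sm_scale folds in log2(e)), a running max and denominator are updated per KV row, and the partial output is rescaled whenever the max grows. A NumPy sketch of the same recurrence for one query vector (illustrative names, not part of the commit):

import math
import numpy as np

def streaming_attention(q, k, v, sm_scale):
    # One pass over the KV rows with running max m, running denominator
    # acc_d, and an output accumulator o rescaled when m increases.
    m, acc_d = -5e4, 1.0
    o = np.zeros(q.shape[0], dtype="float64")
    for i in range(k.shape[0]):
        s = float(q @ k[i]) * sm_scale              # base-2-domain score
        new_m = max(m, s)
        acc_d = acc_d * 2.0 ** (m - new_m) + 2.0 ** (s - new_m)
        o = o * 2.0 ** (m - new_m) + v[i] * 2.0 ** (s - new_m)
        m = new_m
    return o / acc_d, m + math.log2(acc_d)          # (output, LSE)

d = 8
rng = np.random.default_rng(0)
q, k, v = rng.standard_normal(d), rng.standard_normal((5, d)), rng.standard_normal((5, d))
sm_scale = 1.0 / math.sqrt(d) * math.log2(math.e)   # same scale the kernel uses
out, _ = streaming_attention(q, k, v, sm_scale)
logits = q @ k.T / math.sqrt(d)
ref = (np.exp(logits) / np.exp(logits).sum()) @ v   # plain softmax attention
assert np.allclose(out, ref, atol=1e-6)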
def _attention_prefill(
    h_kv, h_q, d, dtype, sliding_window: bool, rope_scaling: Dict[str, Any], target: Target
):
@@ -920,6 +1291,189 @@ def _attention_prefill(
return sch.mod["main"].with_attr("tir.is_scheduled", 1)
+def _attention_decode_cpu(
+ num_kv_heads,
+ num_qo_heads,
+ head_dim,
+ qkv_dtype,
+ sliding_window: bool,
+ rope_scaling: Dict[str, Any],
+):
+ log2e = math.log2(math.exp(1))
+ H_qo = num_qo_heads
+ H_kv = num_kv_heads
+ D = head_dim
+ group_size = num_qo_heads // num_kv_heads
+
+ global_symbol = "batch_decode_paged_kv_cpu"
+ if sliding_window:
+ global_symbol += "_sliding_window"
+
+ @T.prim_func(check_well_formed=False)
+ def batch_decode_paged_kv(
+ _0: T.int32, # pylint: disable=unused-argument
+ Q_handle: T.handle,
+ pages_handle: T.handle,
+ page_table_indptr_handle: T.handle,
+ page_table_values_handle: T.handle,
+        var_length_info: T.handle, # [b] when sliding window = False, or otherwise [3, b]
+ k_rope_pos_offset_handle: T.handle,
+ q_rope_position_handle: T.handle,
+ output_handle: T.handle,
+ lse_handle: T.handle,
+ rotary_mode: T.int32,
+ rope_scale: T.float32,
+ rope_theta: T.float32,
+ attn_score_scaling_factor: T.float32,
+ ):
+ T.func_attr({"tir.is_scheduled": 1, "global_symbol": global_symbol})
+ B = T.int32(is_size_var=True)
+ nnz_pages = T.int32(is_size_var=True)
+ max_num_pages = T.int32(is_size_var=True)
+ page_indptr_elem_offset = T.int32(is_size_var=True)
+ page_values_elem_offset = T.int32(is_size_var=True)
+ k_rope_pos_offset_elem_offset = T.int32(is_size_var=True)
+ q_rope_position_elem_offset = T.int32(is_size_var=True)
+ length_info_elem_offset = T.int32(is_size_var=True)
+
+        Q = T.match_buffer(Q_handle, (B, H_qo, D), qkv_dtype)  # query values
+        pages = T.match_buffer(pages_handle, (max_num_pages, 2, H_kv, 16, D), qkv_dtype)
+        page_table_indptr = T.match_buffer(
+            page_table_indptr_handle, (B + 1,), "int32", elem_offset=page_indptr_elem_offset
+        )
+        page_table_values = T.match_buffer(
+            page_table_values_handle, (nnz_pages,), "int32", elem_offset=page_values_elem_offset
+        )
+        k_rope_pos_offset = T.match_buffer(
+            k_rope_pos_offset_handle, (B,), "int32", elem_offset=k_rope_pos_offset_elem_offset
+        )
+        q_rope_position = T.match_buffer(
+            q_rope_position_handle, (B,), "int32", elem_offset=q_rope_position_elem_offset
+        )
+        output = T.match_buffer(output_handle, (B, H_qo, D), qkv_dtype)
+        lse = T.match_buffer(lse_handle, (B, H_qo), "float32")  # pylint: disable=unused-variable
+ # The length information of the sequences.
+ # - It is in shape `(3, batch_size)` when sliding window is enabled.
+ # For a sequence "i", location
+        # - "(0, i)" is the number of KV slots used in the last page of the seq ("last_page_len"),
+ # - "(1, i)" is the starting offset of the sliding window in the seq,
+ # - "(2, i)" is the attn sink length of the sequence.
+ # - It is in shape `(batch_size,)` when sliding window is disabled,
+ # denoting the "last_page_len".
+ length_info = _declare_length_info(
+ var_length_info, B, sliding_window, length_info_elem_offset
+ )
+
+ sm_scale = 1.0 / math.sqrt(float(D)) * log2e
+
+ for b in T.serial(B):
+ with T.block("attn"):
+ O_local = T.alloc_buffer((D,), "float32")
+ Q_local = T.alloc_buffer((D,), "float32")
+ K_local = T.alloc_buffer((D,), "float32")
+ V_local = T.alloc_buffer((D,), "float32")
+
+ kv_chunk_len = T.alloc_buffer((1,), "int32")
+
+ m_val = T.alloc_buffer((1,), "float32")
+ new_m = T.alloc_buffer((1,), "float32")
+ d_val = T.alloc_buffer((1,), "float32")
+ S_val = T.alloc_buffer((1,), "float32")
+ scale_O = T.alloc_buffer((1,), "float32")
+ factor = T.alloc_buffer((1,), "float32")
+
+ cur_page_indptr_begin: T.int32 = page_table_indptr[b]
+ cur_page_indptr_end: T.int32 = page_table_indptr[b + 1]
+
+ kv_chunk_len[0] = T.if_then_else(
+ cur_page_indptr_begin != cur_page_indptr_end,
+ _get_kv_chunk_len(
+ cur_page_indptr_end - cur_page_indptr_begin,
+ 16,
+ b,
+ length_info,
+ sliding_window,
+ ),
+ 0,
+ )
+
+ for h_qo in T.serial(H_qo):
+ m_val[0] = -5e4
+ d_val[0] = 1.0
+
+ for d in T.serial(D):
+ O_local[d] = 0.0
+
+ for d in T.serial(D):
+ Q_local[d] = T.if_then_else(
+ rotary_mode == 1,
+ _rope(
+ Q,
+ q_rope_position[b],
+ head_dim,
+ rope_theta,
+ rope_scale,
+ (b, h_qo, d),
+ qkv_dtype,
+ rope_scaling,
+ ),
+ Q[b, h_qo, d],
+ )
+
+ for row_idx in T.serial(kv_chunk_len[0]):
+                    seq_offset: T.int32(is_size_var=True) = _get_seq_offset(
+                        row_idx, b, length_info, sliding_window
+                    )
+                    page_no: T.int32(is_size_var=True) = page_table_values[
+                        cur_page_indptr_begin + (seq_offset // 16)
+                    ]
+                    page_offset: T.int32(is_size_var=True) = seq_offset % 16
+
+ for d in T.serial(D):
+ K_local[d] = T.if_then_else(
+ rotary_mode == 1,
+ _rope(
+ pages,
+ k_rope_pos_offset[b] + row_idx,
+ head_dim,
+ rope_theta,
+ rope_scale,
+                                (page_no, 0, h_qo // group_size, page_offset, d),
+ qkv_dtype,
+ rope_scaling,
+ ),
+                            pages[page_no, 0, h_qo // group_size, page_offset, d],
+ )
+ S_val[0] = 0.0
+ for d in T.serial(D):
+ S_val[0] += Q_local[d] * K_local[d]
+ S_val[0] *= attn_score_scaling_factor * sm_scale
+
+ new_m[0] = T.max(m_val[0], S_val[0])
+                    d_val[0] = (d_val[0] * T.exp2(m_val[0] - new_m[0])) + T.exp2(
+ S_val[0] - new_m[0]
+ )
+
+ scale_O[0] = T.exp2(m_val[0] - new_m[0])
+
+ for d in T.serial(D):
+ O_local[d] = O_local[d] * scale_O[0]
+
+ m_val[0] = new_m[0]
+ for d in T.serial(D):
+                        V_local[d] = pages[page_no, 1, h_qo // group_size, page_offset, d]
+
+ factor[0] = T.exp2(S_val[0] - m_val[0])
+ for d in T.serial(D):
+ O_local[d] = O_local[d] + V_local[d] * factor[0]
+ for d in T.serial(D):
+ O_local[d] = O_local[d] / d_val[0]
+ output[b, h_qo, d] = O_local[d]
+ lse[b, h_qo] = m_val[0] + T.log2(d_val[0])
+
+ return batch_decode_paged_kv
+
+
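
The decode kernel resolves each logical KV position through the page table: the page id comes from page_table_values[cur_page_indptr_begin + pos // 16] and the slot inside the page from pos % 16. A small NumPy sketch of that gather, assuming a hypothetical helper name and the kernel's fixed page size of 16:

import numpy as np

PAGE_SIZE = 16

def gather_keys(pages, page_values, indptr_begin, seq_len, head):
    # pages: (num_pages, 2, h_kv, PAGE_SIZE, d); axis 1 selects K (0) or V (1)
    d = pages.shape[-1]
    out = np.empty((seq_len, d), pages.dtype)
    for pos in range(seq_len):
        page_no = page_values[indptr_begin + pos // PAGE_SIZE]
        page_offset = pos % PAGE_SIZE
        out[pos] = pages[page_no, 0, head, page_offset, :]
    return out

pages = np.arange(3 * 2 * 1 * PAGE_SIZE * 4, dtype="float32").reshape(3, 2, 1, PAGE_SIZE, 4)
page_values = np.array([2, 0], dtype="int32")  # the sequence occupies page 2, then page 0
keys = gather_keys(pages, page_values, indptr_begin=0, seq_len=20, head=0)
assert keys.shape == (20, 4)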
def _attention_decode(
num_kv_heads,
num_qo_heads,
@@ -1179,6 +1733,47 @@ def _attention_decode(
return batch_decode_paged_kv
+def _merge_state_inplace_cpu(v_dtype):
+ @T.prim_func
+ def merge_state_inplace_cpu(
+ v: T.handle,
+ s: T.handle,
+ v_other: T.handle,
+ s_other: T.handle,
+ ):
+ T.func_attr({"tir.is_scheduled": 1})
+ N = T.int32(is_size_var=True)
+ H = T.int32(is_size_var=True)
+ D = T.int32(is_size_var=True)
+
+ V = T.match_buffer(v, (N, H, D), v_dtype)
+ S = T.match_buffer(s, (N, H), "float32")
+ V_other = T.match_buffer(v_other, (N, H, D), v_dtype)
+ S_other = T.match_buffer(s_other, (N, H), "float32")
+
+ for n in T.serial(N):
+ for h in T.serial(H):
+ with T.block("merge"):
+ s_val = _var_cpu("float32")
+ s_other_val = _var_cpu("float32")
+ s_max = _var_cpu("float32")
+ scale = _var_cpu("float32")
+ other_scale = _var_cpu("float32")
+
+ s_val[0] = S[n, h]
+ s_other_val[0] = S_other[n, h]
+ s_max[0] = T.max(s_val[0], s_other_val[0])
+ s_val[0] = T.exp2(s_val[0] - s_max[0])
+ s_other_val[0] = T.exp2(s_other_val[0] - s_max[0])
+ scale[0] = s_val[0] / (s_val[0] + s_other_val[0])
+                    other_scale[0] = s_other_val[0] / (s_val[0] + s_other_val[0])
+                    for d in T.serial(D):
+                        V[n, h, d] = V[n, h, d] * scale[0] + V_other[n, h, d] * other_scale[0]
+ S[n, h] = T.log2(s_val[0] + s_other_val[0]) + s_max[0]
+
+ return merge_state_inplace_cpu
+
+
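
merge_state_inplace_cpu combines two partial attention results computed over disjoint KV subsets, weighting each output by its base-2 log-sum-exp so the merge equals attention over the union. A NumPy check of that identity (illustrative only, not part of the test suite):

import numpy as np

def merge_state(v1, s1, v2, s2):
    # v*: (n, d) partial outputs; s*: (n,) base-2 LSE scores, as in the kernel.
    s_max = np.maximum(s1, s2)
    e1, e2 = 2.0 ** (s1 - s_max), 2.0 ** (s2 - s_max)
    v = v1 * (e1 / (e1 + e2))[:, None] + v2 * (e2 / (e1 + e2))[:, None]
    return v, np.log2(e1 + e2) + s_max

rng = np.random.default_rng(1)
scores = rng.standard_normal(6)                 # attention logits, log2 domain
v_all = rng.standard_normal((6, 4))
w = 2.0 ** (scores - scores.max())
full = (w / w.sum()) @ v_all                    # attention over all six slots

def part(idx):                                  # partial attention over a slice
    p = 2.0 ** (scores[idx] - scores[idx].max())
    return (p / p.sum()) @ v_all[idx], np.log2(p.sum()) + scores[idx].max()

v1, s1 = part(slice(0, 3))
v2, s2 = part(slice(3, 6))
merged, _ = merge_state(v1[None], np.array([s1]), v2[None], np.array([s2]))
assert np.allclose(merged[0], full, atol=1e-6)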
def _merge_state_inplace(num_heads, head_dim, v_dtype, target: Target):
v_dtype_bytes = 2
VEC_SIZE = min(max(8 // v_dtype_bytes, head_dim // 32), 4)
@@ -1577,6 +2172,175 @@ def _attention_sequence_prefill(
return sch.mod["main"].with_attr("tir.is_scheduled", 1)
+def _attention_prefill_ragged_cpu(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any]):
+ group_size = h_q // h_kv
+ sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1))
+
+ @T.prim_func
+ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches
+ var_q: T.handle, # [total_len, h_q, d]
+ var_q_indptr: T.handle, # [batch_size + 1]
+ var_k: T.handle, # [total_len, h_kv, d]
+ var_v: T.handle, # [total_len, h_kv, d]
+ var_kv_indptr: T.handle, # [batch_size + 1]
+ var_q_rope_position: T.handle, # [total_q_len]
+ var_k_rope_pos_offset: T.handle, # [b]
+ var_output: T.handle, # [total_len, h_q, d]
+ var_lse: T.handle, # [total_len, h_q]
+ causal: T.int32,
+ rotary_mode: T.int32,
+ rope_scale: T.float32,
+ rope_theta: T.float32,
+ attn_score_scaling_factor: T.float32,
+ ):
+ batch_size = T.int32(is_size_var=True)
+ qo_len = T.int32(is_size_var=True)
+ kv_len = T.int32(is_size_var=True)
+ q_indptr_elem_offset = T.int32(is_size_var=True)
+ kv_indptr_elem_offset = T.int32(is_size_var=True)
+ q_rope_position_elem_offset = T.int32(is_size_var=True)
+ k_rope_pos_offset_elem_offset = T.int32(is_size_var=True)
+
+ q = T.match_buffer(var_q, (qo_len, h_q, d), dtype)
+        q_indptr = T.match_buffer(
+            var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset
+        )
+        k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype)
+        v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype)
+        kv_indptr = T.match_buffer(
+            var_kv_indptr, (batch_size + 1,), "int32", elem_offset=kv_indptr_elem_offset
+        )
+        q_rope_position = T.match_buffer(
+            var_q_rope_position, (qo_len,), "int32", elem_offset=q_rope_position_elem_offset
+        )
+        k_rope_pos_offset = T.match_buffer(
+            var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset
+        )
+        output = T.match_buffer(var_output, (qo_len, h_q, d), dtype)
+        lse = T.match_buffer(var_lse, (qo_len, h_q), "float32")  # pylint: disable=unused-variable
+
+ for b in T.serial(batch_size):
+ with T.block("attn"):
+ softmax_sum = T.alloc_buffer([h_q], "float32")
+ m_prev = T.alloc_buffer([h_q], "float32")
+ m_new = T.alloc_buffer([h_q], "float32")
+ d_prev = T.alloc_buffer([h_q], "float32")
+ d_new = T.alloc_buffer([h_q], "float32")
+ p_sum = T.alloc_buffer([d], "float32")
+ max_score = T.alloc_buffer([h_q], "float32")
+ attention_scores = T.alloc_buffer([kv_len, h_q], "float32")
+ exp_scores = T.alloc_buffer([kv_len, h_q], "float32")
+ attention_score = T.alloc_buffer(
+ [
+ 1,
+ ],
+ "float32",
+ )
+ query_val = T.alloc_buffer(
+ [
+ 1,
+ ],
+ "float32",
+ )
+ key_val = T.alloc_buffer(
+ [
+ 1,
+ ],
+ "float32",
+ )
+ result = T.alloc_buffer(
+ [
+ 1,
+ ],
+ "float32",
+ )
+
+ for q_idx in T.serial(q_indptr[b + 1] - q_indptr[b]):
+ for i in T.serial(h_q):
+ max_score[i] = -5e4
+ m_prev[i] = -5e4
+ d_prev[i] = 1.0
+
+ for k_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]):
+ for h in T.serial(h_q):
+ h_kv_idx = h // group_size
+
+ if _causal_mask(
+ causal,
+ row=q_idx,
+ col=k_idx,
+ kv_len=kv_indptr[b + 1] - kv_indptr[b],
+ qo_len=q_indptr[b + 1] - q_indptr[b],
+ ):
+ result[0] = 0.0
+ for d_idx in T.serial(d):
+ query_val[0] = T.if_then_else(
+ rotary_mode == 1,
+ _rope(
+ q,
+                                        q_rope_position[q_indptr[b] + q_idx],
+ d,
+ rope_theta,
+ rope_scale,
+ (q_indptr[b] + q_idx, h, d_idx),
+ dtype,
+ rope_scaling,
+ ),
+ q[q_indptr[b] + q_idx, h, d_idx],
+ )
+
+ key_val[0] = T.if_then_else(
+ rotary_mode == 1,
+ _rope(
+ k,
+ k_rope_pos_offset[b] + k_idx,
+ d,
+ rope_theta,
+ rope_scale,
+                                        (kv_indptr[b] + k_idx, h_kv_idx, d_idx),
+ dtype,
+ rope_scaling,
+ ),
+                                    k[kv_indptr[b] + k_idx, h_kv_idx, d_idx],
+ )
+
+ result[0] += query_val[0] * key_val[0]
+ attention_score[0] = (
+                                result[0] * sm_scale * attn_score_scaling_factor
+ )
+ else:
+                            attention_score[0] = -5e4 * sm_scale * attn_score_scaling_factor
+ attention_scores[k_idx, h] = attention_score[0]
+                        max_score[h] = T.max(max_score[h], attention_score[0])
+ m_new[h] = T.max(m_prev[h], max_score[h])
+
+ for h in T.serial(h_q):
+ d_new[h] = d_prev[h] * T.exp2(m_prev[h] - m_new[h])
+
+ for h in T.serial(h_q):
+ softmax_sum[h] = 0.0
+ for k_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]):
+                            exp_scores[k_idx, h] = T.exp2(attention_scores[k_idx, h] - m_new[h])
+ softmax_sum[h] += exp_scores[k_idx, h]
+ d_new[h] += softmax_sum[h]
+ d_prev = d_new
+ m_prev = m_new
+
+ for h in T.serial(h_q):
+ h_kv_idx = h // group_size
+ for i in T.serial(d):
+ p_sum[i] = 0.0
+ for v_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]):
+ weight = exp_scores[v_idx, h] / d_new[h]
+ for i in T.serial(d):
+                                p_sum[i] += v[kv_indptr[b] + v_idx, h_kv_idx, i] * weight
+                        for i in T.serial(d):
+                            output[q_indptr[b] + q_idx, h, i] = p_sum[i]
+                        lse[q_indptr[b] + q_idx, h] = m_prev[h] + T.log2(d_prev[h])
+
+ return batch_prefill_ragged_kv
+
+
def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], target: Target):
# pylint: disable=line-too-long
NUM_BLKS = 16
@@ -1949,6 +2713,45 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any],
return sch.mod["main"].with_attr("tir.is_scheduled", 1)
+def _copy_single_page_cpu(num_heads, page_size, head_dim, dtype):
+ tx = 1
+
+ @T.prim_func
+ def copy_single_page_cpu(
+ var_pages: T.handle,
+ src_page_id: T.int64,
+ tgt_page_id: T.int64,
+ copy_length: T.int64,
+ ):
+ T.func_attr({"tir.is_scheduled": 1})
+ num_pages = T.int32()
+        pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, page_size, head_dim), dtype)
+
+ for b in T.serial((copy_length * num_heads * head_dim + tx - 1) // tx):
+ for t in T.serial(tx):
+ with T.block("copy"):
+ T.where(b * tx + t < copy_length * num_heads * head_dim)
+ vh = T.axis.spatial(
+ num_heads,
+                    T.Cast("int32", (b * tx + t) // (copy_length * head_dim)),
+ )
+ vp = T.axis.spatial(
+ copy_length,
+ (b * tx + t) % (copy_length * head_dim) // head_dim,
+ )
+ vd = T.axis.spatial(
+ head_dim,
+ T.Cast(
+ "int32",
+ (b * tx + t) % head_dim,
+ ),
+ )
+                pages[tgt_page_id, 0, vh, vp, vd] = pages[src_page_id, 0, vh, vp, vd]
+                pages[tgt_page_id, 1, vh, vp, vd] = pages[src_page_id, 1, vh, vp, vd]
+
+ return copy_single_page_cpu
+
+
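
With tx = 1 the copy above is effectively serial, but it keeps the flattened-index structure of the GPU variant: one counter is decomposed into (head, in-page position, feature). A tiny sketch verifying the decomposition covers every slot exactly once (illustrative values):

num_heads, copy_length, head_dim = 2, 3, 4
seen = set()
for idx in range(copy_length * num_heads * head_dim):
    vh = idx // (copy_length * head_dim)             # head
    vp = idx % (copy_length * head_dim) // head_dim  # position within the page
    vd = idx % head_dim                              # feature dimension
    seen.add((vh, vp, vd))
assert len(seen) == copy_length * num_heads * head_dim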
def _copy_single_page(num_heads, page_size, head_dim, dtype, target: Target):
tx = get_max_num_threads_per_block(target)
@@ -1996,6 +2799,55 @@ def _copy_single_page(num_heads, page_size, head_dim, dtype, target: Target):
return copy_single_page
+def _compact_kv_copy_cpu(num_heads, head_dim, dtype):
+ tx = 8
+
+ @T.prim_func
+ def compact_kv_copy_cpu(
+ var_pages: T.handle,
+ var_copy_length_indptr: T.handle,
+ var_copy_src_dst_pos: T.handle,
+ batch_size: T.int32,
+ ):
+ T.func_attr({"tir.is_scheduled": 1})
+ num_pages = T.int32()
+ total_copy_length = T.int32()
+ copy_length_indptr_elem_offset = T.int32()
+ copy_src_dst_pos_elem_offset = T.int32()
+        pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, 16, head_dim), dtype)
+ copy_length_indptr = T.match_buffer(
+ var_copy_length_indptr,
+ (batch_size + 1,),
+ "int32",
+ elem_offset=copy_length_indptr_elem_offset,
+ )
+ copy_src_dst_pos = T.match_buffer(
+ var_copy_src_dst_pos,
+ (2, total_copy_length),
+ "int32",
+ elem_offset=copy_src_dst_pos_elem_offset,
+ )
+
+ with T.block("root"):
+            for bhd_o in T.serial((batch_size * num_heads * head_dim + tx - 1) // tx):
+ for bhd_i in T.serial(tx):
+ b: T.int32 = (bhd_o * tx + bhd_i) // (num_heads * head_dim)
+ h: T.int32 = (bhd_o * tx + bhd_i) // head_dim % num_heads
+ d: T.int32 = (bhd_o * tx + bhd_i) % head_dim
+                    if (bhd_o * tx + bhd_i) < batch_size * num_heads * head_dim:
+                        for i in T.serial(copy_length_indptr[b + 1] - copy_length_indptr[b]):
+                            src_pos: T.int32 = copy_src_dst_pos[0, copy_length_indptr[b] + i]
+                            dst_pos: T.int32 = copy_src_dst_pos[1, copy_length_indptr[b] + i]
+                            pages[dst_pos // 16, 0, h, dst_pos % 16, d] = pages[
+                                src_pos // 16, 0, h, src_pos % 16, d
+                            ]
+                            pages[dst_pos // 16, 1, h, dst_pos % 16, d] = pages[
+                                src_pos // 16, 1, h, src_pos % 16, d
+                            ]
+
+ return compact_kv_copy_cpu
+
+
def _compact_kv_copy(num_heads, head_dim, dtype, target: Target):
tx = get_max_num_threads_per_block(target)
diff --git a/python/tvm/relax/frontend/nn/llm/tree_attn.py b/python/tvm/relax/frontend/nn/llm/tree_attn.py
index 9e4a7ed97e..fa0146afb6 100644
--- a/python/tvm/relax/frontend/nn/llm/tree_attn.py
+++ b/python/tvm/relax/frontend/nn/llm/tree_attn.py
@@ -82,6 +82,213 @@ def _check_tree_order(tree_order_indptr, tree_order, batch, row, col, kv_len, qo
)
+def _declare_length_info(var_length_info, batch_size, sliding_window, elem_offset):
+    return (
+        T.match_buffer(var_length_info, (3, batch_size), "int32", elem_offset=elem_offset)
+        if sliding_window
+        else T.match_buffer(var_length_info, (batch_size,), "int32", elem_offset=elem_offset)
+ )
+
+
+def tree_attn_cpu(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any]):
+ """Generate tree attention kernel for batched tree attention.
+
+ Parameters
+ ----------
+ h_kv : int
+ Number of heads for key and value.
+ h_q : int
+ Number of heads for query.
+ d : int
+ Hidden dimension.
+ dtype : str
+ Data type.
+    rope_scaling : Dict[str, Any]
+        The RoPE scaling configuration.
+
+ Returns
+ -------
+ mod : tvm.IRModule
+ The generated IR module.
+ """
+ group_size = h_q // h_kv
+ sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1))
+
+ # fmt: off
+ @T.prim_func
+ def batch_tree_attn( # pylint: disable=too-many-branches,line-too-long
+ var_q: T.handle, # [total_len, h_q, d]
+ var_q_indptr: T.handle, # [batch_size + 1]
+ var_k: T.handle, # [total_len, h_kv, d]
+ var_v: T.handle, # [total_len, h_kv, d]
+        var_kv_indptr: T.handle, # [batch_size + 1], kv_indptr should be the same as q_indptr in this case
+ var_q_rope_position: T.handle, # [total_q_len]
+ var_mn_indptr: T.handle, # [batch_size + 1]
+ var_mask: T.handle, # [mn_indptr[batch_size]]
+ var_output: T.handle, # [total_len, h_q, d]
+ var_lse: T.handle, # [total_len, h_q]
+ rotary_mode: T.int32,
+ rope_scale: T.float32,
+ rope_theta: T.float32,
+ attn_score_scaling_factor: T.float32,
+ batch_size: T.int32,
+ ):
+ qo_len = T.int32(is_size_var=True)
+ kv_len = T.int32(is_size_var=True)
+ q_indptr_elem_offset = T.int32(is_size_var=True)
+ kv_indptr_elem_offset = T.int32(is_size_var=True)
+ q_rope_position_elem_offset = T.int32(is_size_var=True)
+ mn_indptr_elem_offset = T.int32(is_size_var=True)
+ mask_elem_offset = T.int32(is_size_var=True)
+ tree_size = T.int32(is_size_var=True)
+
+ q = T.match_buffer(var_q, (qo_len, h_q, d), dtype)
+        q_indptr = T.match_buffer(
+            var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset
+        )
+        k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype)
+        v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype)
+        kv_indptr = T.match_buffer(
+            var_kv_indptr, (batch_size + 1,), "int32", elem_offset=kv_indptr_elem_offset
+        )
+        q_rope_position = T.match_buffer(
+            var_q_rope_position, (qo_len,), "int32", elem_offset=q_rope_position_elem_offset
+        )
+        mn_indptr = T.match_buffer(
+            var_mn_indptr, (batch_size + 1,), "int32", elem_offset=mn_indptr_elem_offset
+        )
+        mask = T.match_buffer(var_mask, (tree_size, 2), "int32", elem_offset=mask_elem_offset)
+        output = T.match_buffer(var_output, (qo_len, h_q, d), dtype)
+        lse = T.match_buffer(var_lse, (qo_len, h_q), "float32")  # pylint: disable=unused-variable
+
+ for b in T.serial(batch_size):
+ with T.block("attn"):
+
+ softmax_sum = T.alloc_buffer([h_q], "float32")
+ m_prev = T.alloc_buffer([h_q], "float32")
+ m_new = T.alloc_buffer([h_q], "float32")
+ d_prev = T.alloc_buffer([h_q], "float32")
+ d_new = T.alloc_buffer([h_q], "float32")
+ p_sum = T.alloc_buffer([d], "float32")
+
+ max_score = T.alloc_buffer([h_q], "float32")
+ attention_scores = T.alloc_buffer([kv_len, h_q], "float32")
+ exp_scores = T.alloc_buffer([kv_len, h_q], "float32")
+ attention_score = T.alloc_buffer(
+ [
+ 1,
+ ],
+ "float32",
+ )
+ query_val = T.alloc_buffer(
+ [
+ 1,
+ ],
+ "float32",
+ )
+ key_val = T.alloc_buffer(
+ [
+ 1,
+ ],
+ "float32",
+ )
+ result = T.alloc_buffer(
+ [
+ 1,
+ ],
+ "float32",
+ )
+
+ for q_idx in T.serial(q_indptr[b + 1] - q_indptr[b]):
+ for i in T.serial(h_q):
+ max_score[i] = -5e4
+ m_prev[i] = -5e4
+ d_prev[i] = 1.0
+
+ for k_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]):
+ for h in T.serial(h_q):
+ h_kv_idx = h // group_size
+
+ if _check_tree_order(
+ row=q_idx,
+ col=k_idx,
+ batch=b,
+ tree_order=mask,
+ tree_order_indptr=mn_indptr,
+ kv_len=kv_indptr[b + 1] - kv_indptr[b],
+ qo_len=q_indptr[b + 1] - q_indptr[b],
+ ):
+ result[0] = 0.0
+ for d_idx in T.serial(d):
+ query_val[0] = T.if_then_else(
+ rotary_mode == 1,
+ _rope(
+ q,
+                                        q_rope_position[q_indptr[b] + q_idx],
+ d,
+ rope_theta,
+ rope_scale,
+ (q_indptr[b] + q_idx, h, d_idx),
+ dtype,
+ rope_scaling,
+ ),
+ q[q_indptr[b] + q_idx, h, d_idx],
+ )
+
+ key_val[0] = T.if_then_else(
+ rotary_mode == 1,
+ _rope(
+ k,
+                                        q_rope_position[kv_indptr[b] + k_idx],
+ d,
+ rope_theta,
+ rope_scale,
+                                        (kv_indptr[b] + k_idx, h_kv_idx, d_idx),
+ dtype,
+ rope_scaling,
+ ),
+                                    k[kv_indptr[b] + k_idx, h_kv_idx, d_idx],
+ )
+
+ result[0] += query_val[0] * key_val[0]
+ attention_score[0] = (
+                                result[0] * sm_scale * attn_score_scaling_factor
+ )
+ else:
+                            attention_score[0] = -5e4 * sm_scale * attn_score_scaling_factor
+ attention_scores[k_idx, h] = attention_score[0]
+                        max_score[h] = T.max(max_score[h], attention_score[0])
+ m_new[h] = T.max(m_prev[h], max_score[h])
+
+ for h in T.serial(h_q):
+ d_new[h] = d_prev[h] * T.exp2(m_prev[h] - m_new[h])
+
+ for h in T.serial(h_q):
+ softmax_sum[h] = 0.0
+ for k_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]):
+                            exp_scores[k_idx, h] = T.exp2(attention_scores[k_idx, h] - m_new[h])
+ softmax_sum[h] += exp_scores[k_idx, h]
+ d_new[h] += softmax_sum[h]
+ d_prev = d_new
+ m_prev = m_new
+
+ for h in T.serial(h_q):
+ h_kv_idx = h // group_size
+ for i in T.serial(d):
+ p_sum[i] = 0.0
+ for v_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]):
+ weight = exp_scores[v_idx, h] / d_new[h]
+ for i in T.serial(d):
+                            p_sum[i] += v[kv_indptr[b] + v_idx, h_kv_idx, i] * weight
+                    for i in T.serial(d):
+                        output[q_indptr[b] + q_idx, h, i] = p_sum[i]
+                    lse[q_indptr[b] + q_idx, h] = m_prev[h] + T.log2(d_prev[h])
+
+ # fmt: on
+ # pylint: enable=line-too-long,too-many-branches
+ return batch_tree_attn
+
+
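
The tree kernel differs from causal prefill only in its mask: _check_tree_order admits a (query, key) pair when the key token lies on the query token's ancestor path in the speculative token tree, so sibling draft branches never attend to each other. A hedged sketch of that ancestor mask built from parent pointers (simplified from the two-column tree-order encoding the kernel actually consumes):

import numpy as np

parents = [-1, 0, 0, 1]          # nodes 1 and 2 fork from 0; node 3 extends 1
n = len(parents)
mask = np.zeros((n, n), dtype=bool)
for q_node in range(n):
    a = q_node
    while a != -1:               # walk ancestors, including the node itself
        mask[q_node, a] = True
        a = parents[a]
assert mask[3, 1] and mask[3, 0] and not mask[3, 2]  # 3 cannot see sibling branch 2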
def tree_attn(
h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], target: Target
): # pylint: disable=unused-argument
@@ -437,6 +644,204 @@ def tree_attn(
return sch.mod["main"].with_attr("tir.is_scheduled", 1)
+def tree_attn_with_paged_kv_cache_cpu(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any]):
+    """Generate tree attention kernel for batched tree attention with paged key-value cache.
+
+ Parameters
+ ----------
+ h_kv : int
+ Number of heads for key and value.
+ h_q : int
+ Number of heads for query.
+ d : int
+ Hidden dimension.
+ dtype : str
+ Data type.
+    rope_scaling : Dict[str, Any]
+        The RoPE scaling configuration.
+
+ Returns
+ -------
+ mod : tvm.IRModule
+ The generated IR module.
+ """
+ # pylint: disable=import-outside-toplevel
+ from .kv_cache import (
+ _declare_length_info,
+ _get_kv_chunk_len,
+ _get_seq_offset,
+ )
+
+ global_symbol = "tree_attn_paged_kv_cpu"
+ sliding_window = False
+ group_size = h_q // h_kv
+ sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1))
+ # pylint: disable=line-too-long,too-many-branches
+ # fmt: off
+ @T.prim_func(check_well_formed=False)
+ def tree_attn_paged_kv_cpu(
+ _0: T.int32, # pylint: disable=unused-argument
+ var_q: T.handle, # [total_len, h_q, d]
+ var_q_indptr: T.handle, # [batch_size + 1]
+ var_pages: T.handle, # [max_num_pages, 2, h_kv, page_size, d]
+ var_page_indptr: T.handle, # [batch_size + 1]
+ var_page_values: T.handle, # [nnz_pages]
+        var_length_info: T.handle, # [b] when sliding window = False, or otherwise [3, b]
+ var_k_rope_pos_offset: T.handle, # [b]
+ var_q_rope_position: T.handle, # [total_len]
+ var_output: T.handle, # [total_len, h_q, d]
+ var_lse: T.handle, # [total_len, h_q]
+ rotary_mode: T.int32,
+ rope_scale: T.float32,
+ rope_theta: T.float32,
+ attn_score_scaling_factor: T.float32,
+ tree_order_indptr_handle: T.handle, # [batch_size + 1]
+ tree_order_handle: T.handle, # [total_len, 2]
+ ):
+ T.func_attr({"global_symbol": global_symbol})
+ batch_size = T.int32(is_size_var=True)
+ total_len = T.int32(is_size_var=True)
+ nnz_pages = T.int32(is_size_var=True)
+ max_num_pages = T.int32(is_size_var=True)
+ q_indptr_elem_offset = T.int32(is_size_var=True)
+ page_indptr_elem_offset = T.int32(is_size_var=True)
+ page_values_elem_offset = T.int32(is_size_var=True)
+ k_rope_pos_offset_elem_offset = T.int32(is_size_var=True)
+ q_rope_position_elem_offset = T.int32(is_size_var=True)
+ length_info_elem_offset = T.int32(is_size_var=True)
+ tree_order_elem_offset = T.int32(is_size_var=True)
+ tree_order_indptr_elem_offset = T.int32(is_size_var=True)
+
+ q = T.match_buffer(var_q, (total_len, h_q, d), dtype)
+        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset)
+        pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), dtype)
+        page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", elem_offset=page_indptr_elem_offset)
+        page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", elem_offset=page_values_elem_offset)
+        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset)
+        q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", elem_offset=q_rope_position_elem_offset)
+        output = T.match_buffer(var_output, (total_len, h_q, d), dtype)
+        lse = T.match_buffer(var_lse, (total_len, h_q), "float32")  # pylint: disable=unused-variable
+ tree_order_indptr = T.match_buffer(
+ tree_order_indptr_handle,
+ (batch_size + 1,),
+ "int32",
+ elem_offset=tree_order_indptr_elem_offset,
+ )
+ total_tree_order_len = T.int32(is_size_var=True)
+ tree_order = T.match_buffer(
+ tree_order_handle,
+ (total_tree_order_len, 2),
+ "int32",
+ elem_offset=tree_order_elem_offset,
+ )
+ # The length information of the sequences.
+ # - It is in shape `(3, batch_size)` when sliding window is enabled.
+ # For a sequence "i", location
+        # - "(0, i)" is the number of KV slots used in the last page of the seq ("last_page_len"),
+        # - "(1, i)" is the starting offset of the sliding window in the seq,
+        # - "(2, i)" is the attn sink length of the sequence.
+        # - It is in shape `(batch_size,)` when sliding window is disabled,
+        # denoting the "last_page_len".
+        length_info = _declare_length_info(var_length_info, batch_size, sliding_window, length_info_elem_offset)
+
+
+ T.Assert(
+            rotary_mode == T.int32(0), "Inline rotary mode is not supported in tree attention."
+ )
+
+ for h_qo in T.serial(h_q):
+ for b_idx in T.serial(batch_size):
+ with T.block("attn"):
+ O_local = T.alloc_buffer((d, ), "float32")
+ Q_local = T.alloc_buffer((d, ), "float32")
+ K_local = T.alloc_buffer((d, ), "float32")
+ V_local = T.alloc_buffer((d, ), "float32")
+
+ kv_chunk_len = T.alloc_buffer((1, ), "int32")
+
+ m_val = T.alloc_buffer((1, ), "float32")
+ new_m = T.alloc_buffer((1, ), "float32")
+ d_val = T.alloc_buffer((1, ), "float32")
+ S_val = T.alloc_buffer((1, ), "float32")
+ scale_O = T.alloc_buffer((1, ), "float32")
+ factor = T.alloc_buffer((1, ), "float32")
+ cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
+ cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
+ kv_chunk_len[0] = T.if_then_else(
+ cur_page_indptr_begin != cur_page_indptr_end,
+                        _get_kv_chunk_len(cur_page_indptr_end - cur_page_indptr_begin, 16, b_idx, length_info, sliding_window),
+ 0
+ )
+
+                    for q_idx in T.serial(q_indptr[b_idx + 1] - q_indptr[b_idx]):
+                        # init m, d, O
+ m_val[0] = -5e4
+ d_val[0] = 1.0
+ for d_idx in T.serial(d):
+ O_local[d_idx] = 0.0
+ curl_q: T.int32 = q_indptr[b_idx] + q_idx
+
+ for d_idx in T.serial(d):
+ Q_local[d_idx] = T.if_then_else(
+ rotary_mode == 1,
+                            _rope(q, q_rope_position[curl_q], d, rope_theta, rope_scale, (curl_q, h_qo, d_idx), dtype, rope_scaling),
+ q[curl_q, h_qo, d_idx]
+ )
+ for row_idx in T.serial(max_num_pages * 16):
+ if row_idx < kv_chunk_len[0]:
+                                page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + (_get_seq_offset(row_idx, b_idx, length_info, sliding_window) // 16)]
+                                page_offset: T.int32(is_size_var=True) = _get_seq_offset(row_idx, b_idx, length_info, sliding_window) % 16
+
+ # Load KV
+ for d_idx in T.serial(d):
+ K_local[d_idx] = T.if_then_else(
+ rotary_mode == 1,
+                                    _rope(pages, k_rope_pos_offset[b_idx] + row_idx, d, rope_theta, rope_scale, (page_no, 0, h_qo // group_size, page_offset, d_idx), dtype, rope_scaling),
+                                    pages[page_no, 0, h_qo // group_size, page_offset, d_idx]
+                                )
+                                V_local[d_idx] = pages[page_no, 1, h_qo // group_size, page_offset, d_idx]
+
+ # Compute S
+ S_val[0] = 0.0
+ for d_idx in T.serial(d):
+ S_val[0] += Q_local[d_idx] * K_local[d_idx]
+                                S_val[0] *= attn_score_scaling_factor * sm_scale
+
+ # update m_val, d_val , O_local
+ if _check_tree_order(
+ tree_order_indptr=tree_order_indptr,
+ tree_order=tree_order,
+ batch=b_idx,
+ row=q_idx,
+ col=row_idx,
+ kv_len=kv_chunk_len[0],
+                                qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx],
+ ):
+ new_m[0] = T.max(m_val[0], S_val[0])
+ else:
+ S_val[0] = -5e4
+ # update d_val
+ d_val[0] *= T.exp2(m_val[0] - new_m[0])
+ d_val[0] += T.exp2(S_val[0] - new_m[0])
+
+ # restore O_local then update O_local
+ scale_O[0] = T.exp2(m_val[0] - new_m[0])
+ m_val[0] = new_m[0]
+ factor[0] = T.exp2(S_val[0] - m_val[0])
+ for d_idx in T.serial(d):
+                                O_local[d_idx] = O_local[d_idx] * scale_O[0]
+
+
+ for d_idx in T.serial(d):
+                                    O_local[d_idx] += V_local[d_idx] * factor[0]
+ # Store Output
+ for d_idx in T.serial(d):
+                            O_local[d_idx] = O_local[d_idx] / d_val[0]
+ output[curl_q, h_qo, d_idx] = O_local[d_idx]
+ lse[curl_q, h_qo] = m_val[0] + T.log2(d_val[0])
+ return tree_attn_paged_kv_cpu
+
+
def tree_attn_with_paged_kv_cache(
h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], target: Target
):
diff --git a/src/runtime/cpu_device_api.cc b/src/runtime/cpu_device_api.cc
index ccd726a6ec..5dd470f00d 100644
--- a/src/runtime/cpu_device_api.cc
+++ b/src/runtime/cpu_device_api.cc
@@ -34,6 +34,19 @@
#include <android/api-level.h>
#endif
+#if defined(__linux__) || defined(__ANDROID__)
+#include <sys/sysinfo.h>
+#endif
+
+#ifdef _WIN32
+#include <windows.h>
+#endif
+
+#if defined(__APPLE__)
+#include <TargetConditionals.h>
+#include <sys/sysctl.h>
+#endif
+
namespace tvm {
namespace runtime {
class CPUDeviceAPI final : public DeviceAPI {
@@ -43,6 +56,41 @@ class CPUDeviceAPI final : public DeviceAPI {
if (kind == kExist) {
*rv = 1;
}
+
+ switch (kind) {
+ case kExist:
+ break;
+ case kTotalGlobalMemory: {
+#if defined(__linux__) || defined(__ANDROID__)
+ struct sysinfo info;
+ if (sysinfo(&info) == 0) {
+        *rv = static_cast<int64_t>(info.totalram) * info.mem_unit;  // Convert to bytes
+ } else {
+ *rv = -1;
+ }
+#elif defined(_WIN32)
+ MEMORYSTATUSEX statex;
+ statex.dwLength = sizeof(statex);
+ if (GlobalMemoryStatusEx(&statex)) {
+        *rv = static_cast<int64_t>(statex.ullTotalPhys);  // Total physical memory in bytes
+ } else {
+ *rv = -1;
+ }
+#elif defined(__APPLE__)
+ int64_t mem;
+ size_t size = sizeof(mem);
+ if (sysctlbyname("hw.memsize", &mem, &size, nullptr, 0) == 0) {
+ *rv = mem;
+ } else {
+ *rv = -1;
+ }
+#else
+ *rv = -1;
+#endif
+ }
+ default:
+ break;
+ }
}
void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final {
void* ptr;
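
With this change, querying total physical memory on CPU behaves like the existing GPU device attributes. A sketch of the intended use from Python, assuming the standard Device.total_global_memory binding that reads this attribute:

import tvm

dev = tvm.cpu(0)
total = dev.total_global_memory  # bytes of host RAM, or -1 if the query failed
print(f"host memory: {total} bytes")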
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index 8e5dfb4bd8..0b83bb426d 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -1366,7 +1366,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
// Create the auxiliary data manager for attention.
// We only use the merged aux data for CUDA, since direct pointer
// operations may have issues on other platforms.
- if (device_.device_type == DLDeviceType::kDLCUDA) {
+ if (device_.device_type == DLDeviceType::kDLCUDA ||
+ device_.device_type == DLDeviceType::kDLCPU) {
aux_data_manager_ = std::make_unique<CachedPagedKVCacheAuxDataManager>(
        reserved_num_seqs, num_total_pages, prefill_chunk_size, dtype_aux_, device,
preferred_host_device, copy_stream_);
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py
new file mode 100644
index 0000000000..9487bbf860
--- /dev/null
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py
@@ -0,0 +1,956 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import enum
+import itertools
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import pytest
+import scipy.special
+
+import tvm
+import tvm.testing
+from tvm import dlight as dl
+from tvm.relax.frontend.nn.llm.kv_cache import (
+ _attention_decode_cpu,
+ _attention_prefill_cpu,
+ _attention_prefill_ragged_cpu,
+ _compact_kv_copy_cpu,
+ _copy_single_page_cpu,
+ _kv_cache_debug_get_kv,
+ _kv_cache_transpose_append,
+ _merge_state_inplace_cpu,
+ llama_rope_with_position_map,
+ tree_attn_cpu,
+ tree_attn_with_paged_kv_cache_cpu,
+)
+from tvm.runtime import ShapeTuple
+
+reserved_nseq = 32
+maximum_total_seq_length = 2048
+prefill_chunk_size = 512
+page_size = 16
+num_layers = 4
+num_qo_heads = 32
+num_kv_heads = 4
+head_dim = None
+rope_scale = 1.0
+rope_theta = 1e4
+rope_scaling = {}
+dtype = None
+device = tvm.cpu()
+
+fclear = None
+fadd_sequence = None
+fremove_sequence = None
+ffork_sequence = None
+fenable_sliding_window_for_seq = None
+fpopn = None
+fbegin_forward = None
+fend_forward = None
+fcommit_accepted_token_tree_nodes = None
+fattention_with_fuse_qkv = None
+fis_empty = None
+fdebug_get_kv = None
+
+ftranspose_append = None
+fcopy_cache = None
+fattn_prefill = None
+fattn_decode = None
+fattn_prefill_sliding_window = None
+fattn_decode_sliding_window = None
+fattn_prefill_ragged = None
+fattn_prefill_with_tree_mask = None
+fattn_prefill_with_tree_mask_paged_kv_cache = None
+fmerge_state = None
+fsplit_rotary = None
+fattention_rotary = None
+fcopy_single_page = None
+fcompact_copy = None
+
+
+def set_global_func(head_dim, dtype):
+    global fclear, fadd_sequence, fremove_sequence, ffork_sequence, fenable_sliding_window_for_seq
+    global fpopn, fbegin_forward, fend_forward, fcommit_accepted_token_tree_nodes
+ global fattention_with_fuse_qkv, fis_empty, fdebug_get_kv
+ global ftranspose_append, fcopy_cache, fattn_prefill, fattn_decode
+    global fattn_prefill_ragged, fattn_prefill_with_tree_mask, fattn_prefill_with_tree_mask_paged_kv_cache
+ global fattn_prefill_sliding_window, fattn_decode_sliding_window
+    global fmerge_state, fsplit_rotary, fattention_rotary, fcopy_single_page, fcompact_copy
+
+ fclear = tvm.get_global_func("vm.builtin.kv_state_clear")
+ fadd_sequence = tvm.get_global_func("vm.builtin.kv_state_add_sequence")
+    fremove_sequence = tvm.get_global_func("vm.builtin.kv_state_remove_sequence")
+ ffork_sequence = tvm.get_global_func("vm.builtin.kv_state_fork_sequence")
+ fenable_sliding_window_for_seq = tvm.get_global_func(
+ "vm.builtin.attention_kv_cache_enable_sliding_window_for_seq"
+ )
+ fpopn = tvm.get_global_func("vm.builtin.kv_state_popn")
+ fbegin_forward = tvm.get_global_func("vm.builtin.kv_state_begin_forward")
+ fend_forward = tvm.get_global_func("vm.builtin.kv_state_end_forward")
+ fcommit_accepted_token_tree_nodes = tvm.get_global_func(
+ "vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes"
+ )
+ fattention_with_fuse_qkv = tvm.get_global_func(
+ "vm.builtin.attention_kv_cache_attention_with_fused_qkv"
+ )
+ fis_empty = tvm.get_global_func("vm.builtin.attention_kv_cache_empty")
+    fdebug_get_kv = tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv")
+
+ target = tvm.target.Target.from_device(device)
+ builts = []
+ for tir_func in [
+ _kv_cache_transpose_append(num_kv_heads, head_dim, dtype),
+ _kv_cache_debug_get_kv(num_layers, num_kv_heads, head_dim, dtype),
+        _attention_prefill_cpu(num_kv_heads, num_qo_heads, head_dim, dtype, False, rope_scaling),
+        _attention_decode_cpu(num_kv_heads, num_qo_heads, head_dim, dtype, False, rope_scaling),
+        _attention_prefill_cpu(num_kv_heads, num_qo_heads, head_dim, dtype, True, rope_scaling),
+        _attention_decode_cpu(num_kv_heads, num_qo_heads, head_dim, dtype, True, rope_scaling),
+        _attention_prefill_ragged_cpu(num_kv_heads, num_qo_heads, head_dim, dtype, rope_scaling),
+        tree_attn_cpu(num_kv_heads, num_qo_heads, head_dim, dtype, rope_scaling),
+ tree_attn_with_paged_kv_cache_cpu(
+ num_kv_heads, num_qo_heads, head_dim, dtype, rope_scaling
+ ),
+ _merge_state_inplace_cpu(dtype),
+ llama_rope_with_position_map(
+            rope_theta, rope_scale, head_dim, num_qo_heads, num_kv_heads, dtype, rope_scaling
+ ),
+ _copy_single_page_cpu(num_kv_heads, page_size, head_dim, dtype),
+ _compact_kv_copy_cpu(num_kv_heads, head_dim, dtype),
+ ]:
+ mod = tvm.IRModule({"main": tir_func})
+ with target:
+ mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)
+ f = tvm.build(mod["main"], target=target)
+ builts.append(f.entry_func)
+
+ (
+ ftranspose_append,
+ fcopy_cache,
+ fattn_prefill,
+ fattn_decode,
+ fattn_prefill_sliding_window,
+ fattn_decode_sliding_window,
+ fattn_prefill_ragged,
+ fattn_prefill_with_tree_mask,
+ fattn_prefill_with_tree_mask_paged_kv_cache,
+ fmerge_state,
+ fsplit_rotary,
+ fcopy_single_page,
+ fcompact_copy,
+ ) = builts
+
+
+def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window):
+    fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create_reduced")
+ cache = fcreate(
+ tvm.runtime.ShapeTuple(
+ [
+ reserved_nseq,
+ maximum_total_seq_length,
+ prefill_chunk_size,
+ page_size,
+ int(support_sliding_window),
+ ]
+ ),
+ tvm.runtime.ShapeTuple([0, num_layers]),
+ num_qo_heads,
+ num_kv_heads,
+ head_dim,
+ rope_mode,
+ rope_scale,
+ rope_theta,
+ tvm.nd.empty((), dtype, device=device),
+ ftranspose_append,
+ fattn_prefill,
+ fattn_decode,
+ fattn_prefill_sliding_window,
+ fattn_decode_sliding_window,
+ fattn_prefill_ragged,
+ fmerge_state,
+ fsplit_rotary,
+ fcopy_single_page,
+ fcopy_cache,
+ fcompact_copy,
+ fattn_prefill_with_tree_mask,
+ fattn_prefill_with_tree_mask_paged_kv_cache,
+ None,
+ False,
+ )
+ return cache
+
+
+class RopeMode(enum.IntEnum):
+ """The RoPE mode of the Paged KV cache.
+ If it is none, the KV cache will not apply RoPE to q and k.
+ If it is normal, RoPE will be applied to k before adding k to cache.
+ Otherwise, RoPE will be applied to q/k in attention kernel on-the-fly.
+ """
+
+ NONE = 0
+ NORMAL = 1
+ INLINE = 2
+
+
+@pytest.fixture(
+ params=itertools.chain(
+ itertools.product(
+ [64, 128],
+ ["float32", "float16"],
+ [RopeMode.NORMAL],
+ [False],
+ ),
+ itertools.product(
+ [128],
+ ["float16"],
+ [RopeMode.NONE, RopeMode.INLINE],
+ [False, True],
+ ),
+ )
+)
+def kv_cache_and_config(request):
+ global head_dim, dtype
+ head_dim, dtype, rope_mode, support_sliding_window = request.param
+ set_global_func(head_dim, dtype)
+ return create_kv_cache(*request.param), rope_mode, support_sliding_window
+
+
+def verify_cached_kv(kv_cache, seq_ids, expected_k, expected_v):
+ for seq_id in seq_ids:
+ keys_expected = expected_k[seq_id]
+ values_expected = expected_v[seq_id]
+ assert keys_expected.shape == values_expected.shape
+ seq_length = expected_k[seq_id].shape[1]
+ keys = tvm.nd.empty(keys_expected.shape, dtype=dtype, device=device)
+        values = tvm.nd.empty(values_expected.shape, dtype=dtype, device=device)
+ fdebug_get_kv(kv_cache, seq_id, 0, seq_length, keys, values)
+        tvm.testing.assert_allclose(keys.numpy(), keys_expected, rtol=1e-3, atol=1e-3)
+        tvm.testing.assert_allclose(values.numpy(), values_expected, rtol=1e-3, atol=1e-3)
+
+
+def f_apply_rotary(x, offset, scale, theta, offset_list: Optional[List[int]] = None):
+ # x: (N, H, D)
+ assert len(x.shape) == 3
+ nfeat = x.shape[-1]
+ nfeat_half = x.shape[-1] // 2
+ x = x.astype("float32")
+ y = np.concatenate([-x[:, :, nfeat_half:], x[:, :, :nfeat_half]], axis=-1)
+
+    inv_freq = scale / (theta ** (np.arange(0, nfeat, 2).astype("float32") / nfeat))
+ t = (
+ np.arange(offset, offset + x.shape[0], dtype=inv_freq.dtype)
+ if offset_list is None
+ else (np.array(offset_list, dtype=inv_freq.dtype) + offset)
+ )
+ freqs = np.einsum("i,j->ij", t, inv_freq)
+ emb = np.concatenate((freqs, freqs), axis=-1)
+ cos_values = np.cos(emb)
+ sin_values = np.sin(emb)
+
+    return np.einsum("ij,ikj->ikj", cos_values, x) + np.einsum("ij,ikj->ikj", sin_values, y)
+
+
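
A quick property check of f_apply_rotary above (illustrative, not part of the committed test): rotary embedding rotates each feature pair, so per-vector norms are preserved.

_x = np.random.rand(5, 2, 8).astype("float32")
_y = f_apply_rotary(_x, offset=3, scale=rope_scale, theta=rope_theta)
np.testing.assert_allclose(
    np.linalg.norm(_y, axis=-1), np.linalg.norm(_x, axis=-1), rtol=1e-5
)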
+def apply_attention(
+ kv_cache,
+ rope_mode: RopeMode,
+ batch: List[Tuple[Union[int, Tuple[int, int, int]], int]],
+ cached_k: Dict[int, np.ndarray],
+ cached_v: Dict[int, np.ndarray],
+ sliding_window_sizes: Optional[List[int]] = None,
+ attn_sink_sizes: Optional[List[int]] = None,
+ token_tree_parent_ptr_list: Optional[List[List[int]]] = None,
+ accepted_leaf_indices: Optional[List[int]] = None,
+) -> None:
+ seq_ids = []
+ append_lengths = []
+ for i, (seq_id, append_length) in enumerate(batch):
+ fork_parent_id = None
+ if isinstance(seq_id, tuple):
+ # Fork sequence
+ seq_id, fork_parent_id, fork_pos = seq_id
+ batch[i] = (seq_id, append_length)
+ seq_ids.append(seq_id)
+ append_lengths.append(append_length)
+ if fork_parent_id is not None:
+ assert fork_parent_id in cached_k
+ assert seq_id not in cached_k
+ ffork_sequence(kv_cache, fork_parent_id, seq_id, fork_pos)
+ if fork_pos == -1:
+ cached_k[seq_id] = cached_k[fork_parent_id]
+ cached_v[seq_id] = cached_v[fork_parent_id]
+ else:
+ cached_k[seq_id] = cached_k[fork_parent_id][::, :fork_pos]
+ cached_v[seq_id] = cached_v[fork_parent_id][::, :fork_pos]
+ elif seq_id not in cached_k:
+ fadd_sequence(kv_cache, seq_id)
+            cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype)
+            cached_v[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype)
+
+ flattened_token_tree_parent_ptr = None
+    token_tree_node_depths_list: List[Optional[List[int]]] = [None for _ in batch]
+
+ if token_tree_parent_ptr_list:
+ assert len(token_tree_node_depths_list) == len(seq_ids)
+ if accepted_leaf_indices is not None:
+ assert len(accepted_leaf_indices) == len(seq_ids)
+ flattened_token_tree_parent_ptr = []
+ for i, (token_tree_parent_ptr, append_length) in enumerate(
+ zip(token_tree_parent_ptr_list, append_lengths)
+ ):
+ assert len(token_tree_parent_ptr) >= append_length
+            # parent pointer for the last `append_length` nodes (the new tokens)
+            append_token_tree_parent_ptr = token_tree_parent_ptr[-append_length:]
+ flattened_token_tree_parent_ptr += append_token_tree_parent_ptr
+ token_tree_node_depths = []
+ for parent in token_tree_parent_ptr:
+ token_tree_node_depths.append(
+ 0 if parent == -1 else token_tree_node_depths[parent] + 1
+ )
+            # depth of each node in the tree (this may contain more than the last `append_length` nodes)
+ token_tree_node_depths_list[i] = token_tree_node_depths
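+            # For example, parent pointers [-1, 0, 0, 1, 1, 2, 2] (a complete
+            # binary tree) yield node depths [0, 1, 1, 2, 2, 2, 2].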
+
+ fbegin_forward(
+ kv_cache,
+ ShapeTuple(seq_ids),
+ ShapeTuple(append_lengths),
+ (
+ ShapeTuple(flattened_token_tree_parent_ptr)
+ if flattened_token_tree_parent_ptr is not None
+ else None
+ ),
+ )
+
+ global_new_q = np.zeros((num_layers, 0, num_qo_heads, head_dim), dtype)
+ global_new_k = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype)
+ global_new_v = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype)
+
+ q_array = []
+ for i, (seq_id, append_length) in enumerate(batch):
+        new_q = np.random.rand(num_layers, append_length, num_qo_heads, head_dim).astype(dtype)
+        new_k = np.random.rand(num_layers, append_length, num_kv_heads, head_dim).astype(dtype)
+        new_v = np.random.rand(num_layers, append_length, num_kv_heads, head_dim).astype(dtype)
+ q_array.append(new_q)
+
+ rope_offset = cached_k[seq_id].shape[1]
+ if token_tree_parent_ptr_list is not None:
+ prev_tree_size = len(token_tree_parent_ptr_list[i]) - append_length
+ assert prev_tree_size >= 0
+ rope_offset -= prev_tree_size
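+            # Tree tokens take RoPE position rope_offset + node depth, so rewind
+            # the offset past any tree nodes that are already in the cache.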
+ cached_k[seq_id] = np.concatenate(
+ [
+ cached_k[seq_id],
+ np.stack(
+ [
+ (
+ new_k[l]
+ if rope_mode != RopeMode.NORMAL
+ else f_apply_rotary(
+ new_k[l],
+ rope_offset,
+ rope_scale,
+ rope_theta,
+ (
+                                token_tree_node_depths_list[i][-append_length:]
+                                if token_tree_node_depths_list[i] is not None
+ else None
+ ),
+ )
+ )
+ for l in range(num_layers)
+ ],
+ axis=0,
+ ),
+ ],
+ axis=1,
+ )
+ cached_v[seq_id] = np.concatenate([cached_v[seq_id], new_v], axis=1)
+ global_new_q = np.concatenate([global_new_q, new_q], axis=1)
+ global_new_k = np.concatenate([global_new_k, new_k], axis=1)
+ global_new_v = np.concatenate([global_new_v, new_v], axis=1)
+
+ for layer_id in range(num_layers):
+ queries_np = global_new_q[layer_id]
+ keys_np = global_new_k[layer_id]
+ values_np = global_new_v[layer_id]
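+        # Pack Q, K and V along the head axis; the fused tensor has shape
+        # (total_append_length, num_qo_heads + 2 * num_kv_heads, head_dim).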
+        qkv = tvm.nd.array(np.concatenate([queries_np, keys_np, values_np], axis=1), device)
+ outputs = tvm.nd.empty(queries_np.shape, dtype, device=device)
+ fattention_with_fuse_qkv(kv_cache, layer_id, 1.0, qkv, outputs)
+
+ # Compute attention expected results.
+ outputs = np.expand_dims(outputs.numpy(), axis=0)
+ sum_length = 0
+ for i, (seq_id, append_length) in enumerate(batch):
+            assert cached_k[seq_id].shape[1] == cached_v[seq_id].shape[1] >= append_length
+
+ rope_offset = cached_k[seq_id].shape[1]
+ if token_tree_parent_ptr_list is not None:
+ rope_offset -= len(token_tree_parent_ptr_list[i])
+ else:
+ rope_offset -= append_length
+ q_seq = (
+ q_array[i][layer_id]
+ if rope_mode == RopeMode.NONE
+ else f_apply_rotary(
+ q_array[i][layer_id],
+ rope_offset,
+ rope_scale,
+ rope_theta,
+ (
+ token_tree_node_depths_list[i][-append_length:]
+ if token_tree_node_depths_list[i] is not None
+ else None
+ ),
+ )
+ ).transpose(1, 0, 2)
+ k_seq = (
+ cached_k[seq_id][layer_id]
+ if rope_mode != RopeMode.INLINE
+ else f_apply_rotary(
+ cached_k[seq_id][layer_id],
+ 0,
+ rope_scale,
+ rope_theta,
+ (
+ (
+ list(range(rope_offset))
+                        + [depth + rope_offset for depth in token_tree_node_depths_list[i]]
+ )
+ if token_tree_node_depths_list[i] is not None
+ else None
+ ),
+ )
+ ).transpose(1, 2, 0)
+ v_seq = cached_v[seq_id][layer_id].transpose(1, 0, 2)
+
+ k_seq = np.repeat(k_seq, num_qo_heads // num_kv_heads, axis=0)
+ v_seq = np.repeat(v_seq, num_qo_heads // num_kv_heads, axis=0)
+            softmax_input = (q_seq.astype("float32") @ k_seq.astype("float32")) / np.sqrt(head_dim)
+ softmax_shape = softmax_input.shape
+ assert softmax_shape[-2] == append_length
+ length_diff = softmax_shape[-1] - softmax_shape[-2]
+ assert length_diff >= 0
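+            # Causal mask: float32 max marks visible positions and float32 min
+            # marks masked ones, so np.minimum(softmax_input, mask) below keeps
+            # visible scores and drives masked scores towards -inf.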
+ mask = np.tril(
+                np.full_like(softmax_input, np.finfo("float32").max), k=length_diff
+            ) + np.triu(np.full_like(softmax_input, np.finfo("float32").min), k=length_diff + 1)
+ if token_tree_parent_ptr_list is not None:
+ tree_size = len(token_tree_parent_ptr_list[i])
+ tree_mask = np.full(
+                    (tree_size, tree_size), np.finfo("float32").min, dtype="float32"
+ )
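+                # Each tree node attends to itself and all of its ancestors:
+                # copy the parent's visibility row, then unmask the diagonal.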
+                for j, parent in enumerate(token_tree_parent_ptr_list[i]):
+                    if parent != -1:
+                        tree_mask[j] = tree_mask[parent]
+                    tree_mask[j, j] = np.finfo("float32").max
+                tree_mask = np.broadcast_to(tree_mask, (num_qo_heads, *tree_mask.shape))
+ mask[:, :, -tree_size:] = tree_mask[:, -append_length:, :]
+
+ softmax_input = np.minimum(softmax_input, mask)
+
+ results = np.expand_dims(
+                (scipy.special.softmax(softmax_input, axis=-1) @ v_seq.astype("float32")).transpose(
+ 1, 0, 2
+ ),
+ axis=0,
+ ).astype(dtype)
+
+ tvm.testing.assert_allclose(
+ outputs[:, sum_length : sum_length + append_length, ...],
+ results,
+ rtol=1e-3,
+ atol=1e-3,
+ )
+ sum_length += append_length
+ fend_forward(kv_cache)
+
+ if accepted_leaf_indices is not None:
+ seq_ids = [seq_id for seq_id, _ in batch]
+ fcommit_accepted_token_tree_nodes(
+ kv_cache, ShapeTuple(seq_ids), ShapeTuple(accepted_leaf_indices)
+ )
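+        # Mirror the commit in the reference cache: walk from each accepted leaf
+        # to the tree root, compact the accepted path to the front of the newly
+        # appended region, and pop the rejected tree tokens; a leaf index of -1
+        # rejects the entire appended tree.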
+ for i, (accepted_leaf_idx, (seq_id, append_length)) in enumerate(
+ zip(accepted_leaf_indices, batch)
+ ):
+ tree_path = []
+ node = accepted_leaf_idx
+ while node != -1:
+ tree_path.append(node)
+ node = token_tree_parent_ptr_list[i][node]
+ offset = cached_k[seq_id].shape[1] - append_length
+ length_to_pop = append_length - len(tree_path)
+ assert 0 <= length_to_pop <= append_length
+ for dst_pos, src_pos in enumerate(reversed(tree_path)):
+ if dst_pos == src_pos:
+ continue
+ cached_k[seq_id][:, offset + dst_pos, ...] = cached_k[seq_id][
+ :, offset + src_pos, ...
+ ]
+ cached_v[seq_id][:, offset + dst_pos, ...] = cached_v[seq_id][
+ :, offset + src_pos, ...
+ ]
+ if length_to_pop > 0:
+ cached_k[seq_id] = cached_k[seq_id][:, :-length_to_pop, ...]
+ cached_v[seq_id] = cached_v[seq_id][:, :-length_to_pop, ...]
+
+ for seq_id, _ in batch:
+        if sliding_window_sizes is not None and len(sliding_window_sizes) > seq_id:
+            assert len(sliding_window_sizes) > seq_id and len(attn_sink_sizes) > seq_id
+ sliding_window_size = sliding_window_sizes[seq_id]
+ attn_sink_size = attn_sink_sizes[seq_id]
+ if sliding_window_size == 0:
+ continue
+ if cached_k[seq_id].shape[1] > sliding_window_size:
+ # Apply sliding window and sink to cached kv.
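+                # E.g., with window 20 and sink 6, a 25-token cache keeps the
+                # 6 sink tokens [0:6] plus the latest 14 tokens [11:25].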
+                length_to_slide = cached_k[seq_id].shape[1] - sliding_window_size
+ cached_k[seq_id] = np.concatenate(
+ [
+ cached_k[seq_id][:, :attn_sink_size, ...],
+                        cached_k[seq_id][:, attn_sink_size + length_to_slide :, ...],
+ ],
+ axis=1,
+ )
+ cached_v[seq_id] = np.concatenate(
+ [
+ cached_v[seq_id][:, :attn_sink_size, ...],
+                        cached_v[seq_id][:, attn_sink_size + length_to_slide :, ...],
+ ],
+ axis=1,
+ )
+ assert cached_k[seq_id].shape[1] == sliding_window_size
+
+ # Verify
+ verify_cached_kv(kv_cache, seq_ids, cached_k, cached_v)
+
+
+def test_paged_attention_kv_cache_prefill_and_decode(kv_cache_and_config):
+ kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
+ if support_sliding_window and rope_mode == RopeMode.NORMAL:
+ # Normal RoPE mode under sliding window settings is not supported.
+ return
+ fclear(kv_cache)
+
+ # Prefill.
+    operation_seq = [[(0, 6)], [(1, 8)], [(2, 11)], [(3, 16)], [(4, 19), (5, 20)]]
+    operation_seq += [[(6, 21), (7, 24)], [(2, 5), (4, 7), (8, 24)]]
+    operation_seq += [[(6, 13)], [(8, 19)], [(0, 1)], [(1, 3), (3, 8), (5, 12), (7, 11)]]
+    # Decode
+    operation_seq += [[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1), (7, 1), (8, 1)]]
+    operation_seq += [[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1), (7, 1), (8, 1)]]
+ operation_seq += [[(0, 1), (2, 1), (4, 1), (6, 1), (8, 1)]]
+ operation_seq += [[(4, 1), (5, 1), (6, 1), (7, 1), (8, 1)]]
+
+ cached_k = {}
+ cached_v = {}
+ for batch in operation_seq:
+ apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
+
+
+def test_paged_attention_kv_cache_remove_sequence(kv_cache_and_config):
+ kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
+ if support_sliding_window and rope_mode == RopeMode.NORMAL:
+ # Normal RoPE mode under sliding window settings is not supported.
+ return
+ fclear(kv_cache)
+
+ num_sequences = 5
+ batch = [(seq_id, 1) for seq_id in range(num_sequences)]
+ cached_k = {}
+ cached_v = {}
+ for seq_id_to_remove in range(num_sequences):
+ apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
+ # Remove sequence.
+ fremove_sequence(kv_cache, seq_id_to_remove)
+ cached_k.pop(seq_id_to_remove)
+ cached_v.pop(seq_id_to_remove)
+ verify_cached_kv(
+ kv_cache,
+            seq_ids=[seq_id for seq_id in range(num_sequences) if seq_id != seq_id_to_remove],
+ expected_k=cached_k,
+ expected_v=cached_v,
+ )
+
+
+def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config):
+ kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
+ if support_sliding_window and rope_mode == RopeMode.NORMAL:
+ # Normal RoPE mode under sliding window settings is not supported.
+ return
+ fclear(kv_cache)
+
+ cached_k = {}
+ cached_v = {}
+ batch = [(0, 60), (1, 88), (2, 17), (3, 4)]
+ apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
+ # Fork existing sequences.
+    apply_attention(kv_cache, rope_mode, [((4, 3, -1), 35)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((5, 0, -1), 20)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((6, 5, -1), 102)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((7, 0, -1), 3)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((8, 5, -1), 71), ((9, 5, -1), 20)], cached_k, cached_v)
+ # 0 <- 5 <- 6,8,9
+ # 0 <- 7
+ # 3 <- 4
+ # Mixture of decode and prefill.
+ operation_seq = [
+ [(2, 1), (4, 1), (7, 1), (6, 1), (8, 1), (9, 1)],
+ [(7, 1), (6, 1), (8, 1), (9, 1)],
+ [(7, 1), (1, 1), (6, 1), (2, 1), (8, 1), (4, 1), (9, 1)],
+ [(7, 10), (6, 2), (8, 3), (9, 4)],
+ ]
+ for batch in operation_seq:
+ apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
+
+    apply_attention(kv_cache, rope_mode, [((10, 1, 33), 11)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((11, 0, 60), 45), ((12, 0, 15), 14)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((13, 0, 16), 19), ((14, 0, 17), 19)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((15, 5, 60), 8), ((16, 5, 80), 10)], cached_k, cached_v)
+ apply_attention(
+ kv_cache,
+ rope_mode,
+ [((17, 5, 75), 11), ((18, 5, 76), 45), ((19, 5, 77), 14)],
+ cached_k,
+ cached_v,
+ )
+
+ operation_seq = [
+ [(6, 1), (11, 1), (13, 1), (9, 1)],
+ [(10, 1), (16, 1), (18, 1), (19, 1)],
+ [(8, 1), (15, 1), (17, 1), (12, 1), (14, 1)],
+ [(10, 10), (6, 2), (8, 3), (19, 4)],
+ ]
+ for batch in operation_seq:
+ apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
+
+ num_sequence = 20
+ for i in range(num_sequence):
+ fremove_sequence(kv_cache, i)
+ cached_k.pop(i)
+ cached_v.pop(i)
+ verify_cached_kv(
+ kv_cache,
+ seq_ids=list(range(i + 1, num_sequence)),
+ expected_k=cached_k,
+ expected_v=cached_v,
+ )
+
+    assert fis_empty(kv_cache), "The KV cache is not empty after removing all sequences"
+
+ # Test fork after page recycle
+ apply_attention(kv_cache, rope_mode, [(0, 7), (1, 24)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((2, 1, -1), 10)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((3, 0, -1), 20)], cached_k, cached_v)
+ apply_attention(kv_cache, rope_mode, [(2, 1), (3, 1)], cached_k, cached_v)
+
+    apply_attention(kv_cache, rope_mode, [(10, 7), (11, 24)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((12, 11, -1), 200)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [(10, 1), (12, 1)], cached_k, cached_v)
+
+
+def test_paged_attention_kv_cache_unlimited_depth(kv_cache_and_config):
+ kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
+ if support_sliding_window and rope_mode == RopeMode.NORMAL:
+ # Normal RoPE mode under sliding window settings is not supported.
+ return
+ fclear(kv_cache)
+
+ cached_k = {}
+ cached_v = {}
+ apply_attention(kv_cache, rope_mode, [(0, 30)], cached_k, cached_v)
+ # Fork existing sequences.
+    apply_attention(kv_cache, rope_mode, [((1, 0, -1), 15)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((2, 1, -1), 5)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((3, 2, -1), 20)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((4, 3, -1), 26)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((5, 3, -1), 18)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((6, 5, -1), 22)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((7, 5, -1), 12)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((8, 7, -1), 29)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((9, 7, -1), 9)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((10, 9, -1), 31)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((11, 9, -1), 4)], cached_k, cached_v)
+ # 0 <- 1 <- 2 <- 3 <- 5 <- 7 <- 9 <- 11
+ # | | | |
+ # 4 6 8 10
+ # Decode.
+ operation_seq = [
+ [(3, 1), (6, 1), (9, 1)],
+ [(4, 1), (8, 1), (10, 1)],
+ [(5, 1), (7, 1), (11, 1)],
+ ]
+ for batch in operation_seq:
+ apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
+
+ num_sequence = 12
+ for i in range(num_sequence):
+ fremove_sequence(kv_cache, i)
+ cached_k.pop(i)
+ cached_v.pop(i)
+ verify_cached_kv(
+ kv_cache,
+ seq_ids=list(range(i + 1, num_sequence)),
+ expected_k=cached_k,
+ expected_v=cached_v,
+ )
+
+    assert fis_empty(kv_cache), "The KV cache is not empty after removing all sequences"
+
+
+def test_paged_attention_kv_cache_popn(kv_cache_and_config):
+ kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
+ if support_sliding_window and rope_mode == RopeMode.NORMAL:
+ return
+ fclear(kv_cache)
+
+ cached_k = {}
+ cached_v = {}
+ batch = [(0, 35), (1, 88), (2, 17), (3, 4)]
+ apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((4, 3, -1), 35)], cached_k, cached_v)
+
+ popn_operations = [(0, 17), (1, 57), (2, 16), (3, 0), (4, 37)]
+ for seq_id, pop_length in popn_operations:
+ fpopn(kv_cache, seq_id, pop_length)
+ if pop_length != 0:
+ cached_k[seq_id] = cached_k[seq_id][:, :-pop_length, ...]
+ cached_v[seq_id] = cached_v[seq_id][:, :-pop_length, ...]
+        verify_cached_kv(kv_cache, seq_ids=list(range(4)), expected_k=cached_k, expected_v=cached_v)
+
+ num_sequence = 5
+ for seq_id in range(num_sequence):
+ fremove_sequence(kv_cache, seq_id)
+ verify_cached_kv(
+ kv_cache,
+ seq_ids=list(range(seq_id + 1, num_sequence)),
+ expected_k=cached_k,
+ expected_v=cached_v,
+ )
+
+    assert fis_empty(kv_cache), "The KV cache is not empty after removing all sequences"
+
+
+def test_paged_attention_kv_cache_sliding_window(kv_cache_and_config):
+ kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
+ if not support_sliding_window or rope_mode == RopeMode.NORMAL:
+ return
+ fclear(kv_cache)
+
+ cached_k = {}
+ cached_v = {}
+ sliding_window_sizes = [20, 25, 30, 35, 40]
+ attn_sink_sizes = [6, 4, 8, 3, 7]
+ for seq_id, (sliding_window_size, attn_sink_size) in enumerate(
+ zip(sliding_window_sizes, attn_sink_sizes)
+ ):
+ fadd_sequence(kv_cache, seq_id)
+        fenable_sliding_window_for_seq(kv_cache, seq_id, sliding_window_size, attn_sink_size)
+        cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype)
+        cached_v[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype)
+
+ # Prefill.
+ operation_seq = [[(0, 4)], [(1, 6)], [(2, 6), (3, 7), (4, 7)]]
+ operation_seq += [[(0, 20), (1, 19), (2, 30), (3, 35), (4, 40)]]
+ operation_seq += [[(0, 6), (1, 5), (2, 4), (3, 3), (4, 2)]]
+ for batch in operation_seq:
+ apply_attention(
+ kv_cache,
+ rope_mode,
+ batch,
+ cached_k,
+ cached_v,
+ sliding_window_sizes,
+ attn_sink_sizes,
+ )
+ # Decode
+ batch = [(0, 1), (1, 1), (2, 1), (3, 1), (4, 1)]
+ for _ in range(20):
+ apply_attention(
+ kv_cache,
+ rope_mode,
+ batch,
+ cached_k,
+ cached_v,
+ sliding_window_sizes,
+ attn_sink_sizes,
+ )
+
+
+def test_paged_attention_kv_cache_sliding_window_fork(kv_cache_and_config):
+ kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
+ if not support_sliding_window or rope_mode == RopeMode.NORMAL:
+ return
+ fclear(kv_cache)
+
+ cached_k = {}
+ cached_v = {}
+ sliding_window_sizes = [30, 35, 40]
+ attn_sink_sizes = [15, 20, 25]
+ for seq_id, (sliding_window_size, attn_sink_size) in enumerate(
+ zip(sliding_window_sizes, attn_sink_sizes)
+ ):
+ fadd_sequence(kv_cache, seq_id)
+        fenable_sliding_window_for_seq(kv_cache, seq_id, sliding_window_size, attn_sink_size)
+        cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype)
+        cached_v[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype)
+ apply_attention(
+ kv_cache,
+ rope_mode,
+ [(0, 12), (1, 18), (2, 28)],
+ cached_k,
+ cached_v,
+ sliding_window_sizes,
+ attn_sink_sizes,
+ )
+ # seq_len: [12, 18, 25+3]
+ sliding_window_sizes += [0, 0, 0]
+ attn_sink_sizes += [0, 0, 0]
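+    # A window size of 0 leaves sliding window disabled for the sequences forked below.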
+ apply_attention(
+ kv_cache,
+ rope_mode,
+ [((3, 0, 10), 8), ((4, 1, -1), 20), ((5, 2, 18), 18)],
+ cached_k,
+ cached_v,
+ sliding_window_sizes,
+ attn_sink_sizes,
+ )
+ # seq_len: [12, 18, 25+3, 18, 38, 36]
+ apply_attention(
+ kv_cache,
+ rope_mode,
+ [(0, 9), (1, 15), (2, 4), (3, 10), (4, 3), (5, 7)],
+ cached_k,
+ cached_v,
+ sliding_window_sizes,
+ attn_sink_sizes,
+ )
+ # seq_len: [15+6, 20+13, 25+7, 28, 41, 43]
+ sliding_window_sizes += [25]
+ attn_sink_sizes += [24]
+ ffork_sequence(kv_cache, 3, 6, 18)
+    fenable_sliding_window_for_seq(kv_cache, 6, sliding_window_sizes[-1], attn_sink_sizes[-1])
+ cached_k[6] = cached_k[3][::, :18]
+ cached_v[6] = cached_v[3][::, :18]
+ apply_attention(
+ kv_cache,
+ rope_mode,
+ [(3, 10), (6, 12)],
+ cached_k,
+ cached_v,
+ sliding_window_sizes,
+ attn_sink_sizes,
+ )
+ # seq_len: [15+6, 20+13, 25+7, 38, 41, 43, 24+6]
+
+
+def test_paged_attention_kv_cache_tree_attn(kv_cache_and_config):
+ kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
+    if support_sliding_window:
+        # Sliding window is not supported for tree attention.
+        return
+ if rope_mode == RopeMode.INLINE:
+ # Inline RoPE mode is not supported for tree attention.
+ return
+ fclear(kv_cache)
+
+ cached_k = {}
+ cached_v = {}
+ # Prefill 4 sequences
+    apply_attention(kv_cache, rope_mode, [(0, 10), (1, 20), (2, 30), (3, 40)], cached_k, cached_v)
+ # Tree attention
+ apply_attention(
+ kv_cache,
+ rope_mode,
+ [(0, 7), (1, 15), (2, 10), (3, 14)],
+ cached_k,
+ cached_v,
+ token_tree_parent_ptr_list=[
+ [-1, 0, 0, 1, 1, 2, 2], # complete binary tree of height 3
+            [-1, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6], # complete binary tree of height 4
+            [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8], # chain of length 10
+            [-1, 0, 0, 1, 1, 2, 2, -1, 7, 7, 8, 8, 9, 9], # two complete binary trees of height 3
+ ],
+ accepted_leaf_indices=[6, 11, 6, 13],
+ )
+ # Do 5 rounds of decode.
+ for _ in range(5):
+        apply_attention(kv_cache, rope_mode, [(0, 1), (1, 1), (2, 1), (3, 1)], cached_k, cached_v)
+
+ # Test the cases where all trees are chains.
+ fclear(kv_cache)
+ cached_k = {}
+ cached_v = {}
+ # Prefill 4 sequences
+    apply_attention(kv_cache, rope_mode, [(0, 10), (1, 20), (2, 30), (3, 40)], cached_k, cached_v)
+ # Tree attention
+ apply_attention(
+ kv_cache,
+ rope_mode,
+ [(0, 7), (1, 15), (2, 10), (3, 14)],
+ cached_k,
+ cached_v,
+ token_tree_parent_ptr_list=[
+            [-1, 0, 1, 2, 3, 4, 5], # chain of length 7
+            [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], # chain of length 15
+            [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8], # chain of length 10
+            [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], # chain of length 14
+ ],
+ accepted_leaf_indices=[2, 6, -1, 4],
+ )
+ # Do 5 rounds of decode.
+ for _ in range(5):
+        apply_attention(kv_cache, rope_mode, [(0, 1), (1, 1), (2, 1), (3, 1)], cached_k, cached_v)
+
+ # Test the cases of tree attn with cached kv.
+ fclear(kv_cache)
+ cached_k = {}
+ cached_v = {}
+ # Prefill 4 sequences
+    apply_attention(kv_cache, rope_mode, [(0, 10), (1, 20), (2, 30), (3, 40)], cached_k, cached_v)
+ # Do 5 rounds of tree decode.
+ num_seq = 4
+ for i in range(5):
+ num_leaf_nodes = 2**i
+ parent_ptr = [(k - 1) // 2 for k in range(0, 2 * num_leaf_nodes - 1)]
+ apply_attention(
+ kv_cache,
+ rope_mode,
+ [(seq_id, num_leaf_nodes) for seq_id in range(num_seq)],
+ cached_k,
+ cached_v,
+ token_tree_parent_ptr_list=[parent_ptr for _ in range(num_seq)],
+ accepted_leaf_indices=(
+ None if i != 4 else [2, 6, -1, 4]
+ ), # Leaf nodes are committed all at once at the end.
+ )
+
+
+if __name__ == "__main__":
+ HEAD_DIMS = [64, 128]
+ DTYPES = ["float16", "float32"]
+ ROPE_MODES = [RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE]
+ SUPPORT_SLIDING_WINDOW = [False, True]
+    for head_dim, dtype, rope_mode, support_sliding_window in itertools.product(
+ HEAD_DIMS, DTYPES, ROPE_MODES, SUPPORT_SLIDING_WINDOW
+ ):
+ set_global_func(head_dim, dtype)
+        cache = create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window)
+ cache_and_config = (cache, rope_mode, support_sliding_window)
+ test_paged_attention_kv_cache_prefill_and_decode(cache_and_config)
+ test_paged_attention_kv_cache_remove_sequence(cache_and_config)
+ test_paged_attention_kv_cache_fork_sequence(cache_and_config)
+ test_paged_attention_kv_cache_popn(cache_and_config)
+ test_paged_attention_kv_cache_sliding_window(cache_and_config)
+ test_paged_attention_kv_cache_tree_attn(cache_and_config)
+ test_paged_attention_kv_cache_unlimited_depth(cache_and_config)