This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d93f4ad4b8 [KVCache] Add KV Cache for CPU Runtime (#17615)
d93f4ad4b8 is described below

commit d93f4ad4b88b224523e258f04a4bae450eddb504
Author: Mengshiun Yu <[email protected]>
AuthorDate: Sun Feb 2 16:48:42 2025 -0500

    [KVCache] Add KV Cache for CPU Runtime (#17615)
    
    * [KVCache] Add KV Cache for CPU Runtime
    
    This PR introduces KV Cache support for the CPU runtime:
    * Implementation of KV Cache TIR for CPU-based processing.
    * Updates to the relevant runtime components to integrate KV Cache.
    
    Co-authored-by: HMZ <[email protected]>
    Co-authored-by: ShouChenChiu <[email protected]>
    
    * Add unit test
    
    ---------
    
    Co-authored-by: HMZ <[email protected]>
    Co-authored-by: ShouChenChiu <[email protected]>
---
 python/tvm/relax/frontend/nn/llm/kv_cache.py       | 882 ++++++++++++++++++-
 python/tvm/relax/frontend/nn/llm/tree_attn.py      | 405 +++++++++
 src/runtime/cpu_device_api.cc                      |  48 ++
 src/runtime/relax_vm/paged_kv_cache.cc             |   3 +-
 ...runtime_builtin_paged_attention_kv_cache_cpu.py | 956 +++++++++++++++++++++
 5 files changed, 2278 insertions(+), 16 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py
index 399e418c46..844b237381 100644
--- a/python/tvm/relax/frontend/nn/llm/kv_cache.py
+++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -30,7 +30,16 @@ from tvm.script import tir as T
 from tvm.target import Target
 
 from .position_embedding import llama_rope_with_position_map, 
switch_rope_freq_func
-from .tree_attn import tree_attn, tree_attn_with_paged_kv_cache
+from .tree_attn import (
+    tree_attn,
+    tree_attn_cpu,
+    tree_attn_with_paged_kv_cache,
+    tree_attn_with_paged_kv_cache_cpu,
+)
+
+
+def _var_cpu(dtype):  # 1-element buffer used as a mutable scalar (accessed as x[0]) inside CPU TIR kernels
+    return T.alloc_buffer((1,), dtype)
 
 
 def get_max_num_threads_per_block(target: Target) -> int:
@@ -371,23 +380,230 @@ class TIRPagedKVCache(PagedKVCache):  # pylint: disable=too-few-public-methods
             # pylint: disable=line-too-long
             # fmt: off
             bb.add_func(_kv_cache_transpose_append(num_key_value_heads, 
head_dim, dtype), "kv_cache_transpose_append"),
-            bb.add_func(_attention_prefill(num_key_value_heads, 
num_attention_heads, head_dim, dtype, False, rope_scaling, target), 
"tir_attention_prefill"),
-            bb.add_func(_attention_decode(num_key_value_heads, 
num_attention_heads, head_dim, dtype, False, rope_scaling, target), 
"tir_attention_decode"),
-            bb.add_func(_attention_prefill(num_key_value_heads, 
num_attention_heads, head_dim, dtype, True, rope_scaling, target), 
"tir_attention_prefill_sliding_window"),
-            bb.add_func(_attention_decode(num_key_value_heads, 
num_attention_heads, head_dim, dtype, True, rope_scaling, target), 
"tir_attention_decode_sliding_window"),
-            bb.add_func(_attention_prefill_ragged(num_key_value_heads, 
num_attention_heads, head_dim, dtype, rope_scaling, target), 
"tir_attention_prefill_ragged"),
-            bb.add_func(_merge_state_inplace(num_attention_heads, head_dim, 
dtype, target), "tir_attention_merge_state"),
-            bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, 
head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, 
rotary_dim), "tir_split_rotary"),
-            bb.add_func(_copy_single_page(num_key_value_heads, page_size, 
head_dim, dtype, target), "kv_cache_copy_single_page"),
-            bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, 
num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"),
-            bb.add_func(_compact_kv_copy(num_key_value_heads, head_dim, dtype, 
target), "kv_cache_compact_kv_copy"),
-            bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, 
head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask"),
-            bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, 
num_attention_heads, head_dim, dtype, rope_scaling, target), 
"tir_attention_prefill_with_tree_mask_with_paged_kv_cache"),
-            rope_ext_factors,
-            rx.PrimValue(enable_disaggregation),
             # fmt: on
             # pylint: enable=line-too-long
         ]
+
+        if str(target.kind) == "llvm":
+            args.extend(
+                [
+                    bb.add_func(
+                        _attention_prefill_cpu(
+                            num_key_value_heads,
+                            num_attention_heads,
+                            head_dim,
+                            dtype,
+                            False,
+                            rope_scaling,
+                        ),
+                        "tir_attention_prefill_cpu",
+                    ),
+                    bb.add_func(
+                        _attention_decode_cpu(
+                            num_key_value_heads,
+                            num_attention_heads,
+                            head_dim,
+                            dtype,
+                            False,
+                            rope_scaling,
+                        ),
+                        "tir_attention_decode_cpu",
+                    ),
+                    bb.add_func(
+                        _attention_prefill_cpu(
+                            num_key_value_heads,
+                            num_attention_heads,
+                            head_dim,
+                            dtype,
+                            True,
+                            rope_scaling,
+                        ),
+                        "tir_attention_prefill_cpu_sliding_window",
+                    ),
+                    bb.add_func(
+                        _attention_decode_cpu(
+                            num_key_value_heads,
+                            num_attention_heads,
+                            head_dim,
+                            dtype,
+                            True,
+                            rope_scaling,
+                        ),
+                        "tir_attention_decode_cpu_sliding_window",
+                    ),
+                    bb.add_func(
+                        _attention_prefill_ragged_cpu(
+                            num_key_value_heads, num_attention_heads, 
head_dim, dtype, rope_scaling
+                        ),
+                        "tir_attention_prefill_ragged_cpu",
+                    ),
+                    bb.add_func(
+                        _merge_state_inplace_cpu(dtype),
+                        "tir_attention_merge_state_cpu",
+                    ),
+                    bb.add_func(
+                        llama_rope_with_position_map(
+                            rope_theta,
+                            rope_scale,
+                            head_dim,
+                            num_attention_heads,
+                            num_key_value_heads,
+                            dtype,
+                            rope_scaling,
+                            rotary_dim,
+                        ),
+                        "tir_split_rotary",
+                    ),
+                    bb.add_func(
+                        _copy_single_page_cpu(num_key_value_heads, page_size, 
head_dim, dtype),
+                        "kv_cache_copy_single_page_cpu",
+                    ),
+                    bb.add_func(
+                        _kv_cache_debug_get_kv(
+                            num_hidden_layers, num_key_value_heads, head_dim, 
dtype
+                        ),
+                        "kv_cache_debug_get_kv",
+                    ),
+                    bb.add_func(
+                        _compact_kv_copy_cpu(num_key_value_heads, head_dim, 
dtype),
+                        "kv_cache_compact_kv_copy_cpu",
+                    ),
+                    bb.add_func(
+                        tree_attn_cpu(
+                            num_key_value_heads, num_attention_heads, 
head_dim, dtype, rope_scaling
+                        ),
+                        "tir_attention_prefill_with_tree_mask_cpu",
+                    ),
+                    bb.add_func(
+                        tree_attn_with_paged_kv_cache_cpu(
+                            num_key_value_heads, num_attention_heads, 
head_dim, dtype, rope_scaling
+                        ),
+                        
"tir_attention_prefill_with_tree_mask_with_paged_kv_cache_cpu",
+                    ),
+                    rope_ext_factors,
+                    rx.PrimValue(enable_disaggregation),
+                ]
+            )
+        else:
+            args.extend(
+                [
+                    bb.add_func(
+                        _attention_prefill(
+                            num_key_value_heads,
+                            num_attention_heads,
+                            head_dim,
+                            dtype,
+                            False,
+                            rope_scaling,
+                            target,
+                        ),
+                        "tir_attention_prefill",
+                    ),
+                    bb.add_func(
+                        _attention_decode(
+                            num_key_value_heads,
+                            num_attention_heads,
+                            head_dim,
+                            dtype,
+                            False,
+                            rope_scaling,
+                            target,
+                        ),
+                        "tir_attention_decode",
+                    ),
+                    bb.add_func(
+                        _attention_prefill(
+                            num_key_value_heads,
+                            num_attention_heads,
+                            head_dim,
+                            dtype,
+                            True,
+                            rope_scaling,
+                            target,
+                        ),
+                        "tir_attention_prefill_sliding_window",
+                    ),
+                    bb.add_func(
+                        _attention_decode(
+                            num_key_value_heads,
+                            num_attention_heads,
+                            head_dim,
+                            dtype,
+                            True,
+                            rope_scaling,
+                            target,
+                        ),
+                        "tir_attention_decode_sliding_window",
+                    ),
+                    bb.add_func(
+                        _attention_prefill_ragged(
+                            num_key_value_heads,
+                            num_attention_heads,
+                            head_dim,
+                            dtype,
+                            rope_scaling,
+                            target,
+                        ),
+                        "tir_attention_prefill_ragged",
+                    ),
+                    bb.add_func(
+                        _merge_state_inplace(num_attention_heads, head_dim, 
dtype, target),
+                        "tir_attention_merge_state",
+                    ),
+                    bb.add_func(
+                        llama_rope_with_position_map(
+                            rope_theta,
+                            rope_scale,
+                            head_dim,
+                            num_attention_heads,
+                            num_key_value_heads,
+                            dtype,
+                            rope_scaling,
+                            rotary_dim,
+                        ),
+                        "tir_split_rotary",
+                    ),
+                    bb.add_func(
+                        _copy_single_page(num_key_value_heads, page_size, 
head_dim, dtype, target),
+                        "kv_cache_copy_single_page",
+                    ),
+                    bb.add_func(
+                        _kv_cache_debug_get_kv(
+                            num_hidden_layers, num_key_value_heads, head_dim, 
dtype
+                        ),
+                        "kv_cache_debug_get_kv",
+                    ),
+                    bb.add_func(
+                        _compact_kv_copy(num_key_value_heads, head_dim, dtype, 
target),
+                        "kv_cache_compact_kv_copy",
+                    ),
+                    bb.add_func(
+                        tree_attn(
+                            num_key_value_heads,
+                            num_attention_heads,
+                            head_dim,
+                            dtype,
+                            rope_scaling,
+                            target,
+                        ),
+                        "tir_attention_prefill_with_tree_mask",
+                    ),
+                    bb.add_func(
+                        tree_attn_with_paged_kv_cache(
+                            num_key_value_heads,
+                            num_attention_heads,
+                            head_dim,
+                            dtype,
+                            rope_scaling,
+                            target,
+                        ),
+                        
"tir_attention_prefill_with_tree_mask_with_paged_kv_cache",
+                    ),
+                    rope_ext_factors,
+                    rx.PrimValue(enable_disaggregation),
+                ]
+            )
+
         super().__init__(
             _expr=rx.call_pure_packed(
                 "vm.builtin.paged_attention_kv_cache_create_reduced",
@@ -553,6 +769,161 @@ def _get_seq_offset(pos, seq_id, length_info, sliding_window):
     )
 
 
+def _attention_prefill_cpu(h_kv, h_q, d, dtype, sliding_window: bool, rope_scaling: Dict[str, Any]):
+    global_symbol = "batch_prefill_paged_kv_cpu"
+    if sliding_window:
+        global_symbol += "_sliding_window"
+    # CPU (llvm) variant of _attention_prefill: a plain, unscheduled TIR loop nest over (head, sequence, query, kv row).
+    group_size = h_q // h_kv  # number of query heads that share one KV head
+    sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1))  # log2(e) folded in so T.exp2 / T.log2 can be used
+    # pylint: disable=line-too-long,too-many-branches
+    # fmt: off
+    @T.prim_func
+    def batch_prefill_paged_kv_cpu(
+        _0: T.int32,  # pylint: disable=unused-argument
+        var_q: T.handle, # [total_len, h_q, d]
+        var_q_indptr: T.handle, # [batch_size + 1]
+        var_pages: T.handle, # [max_num_pages, 2, h_kv, page_size, d]; page_size is fixed to 16 in this kernel
+        var_page_indptr: T.handle, # [batch_size + 1]
+        var_page_values: T.handle, # [nnz_pages]
+        var_length_info: T.handle, # [b] when sliding window = False, or otherwise [3, b]
+        var_k_rope_pos_offset: T.handle, # [b]
+        var_q_rope_position: T.handle, # [total_len]
+        var_output: T.handle, # [total_len, h_q, d]
+        var_lse: T.handle, # [total_len, h_q]
+        causal: T.int32,
+        rotary_mode: T.int32,
+        rope_scale: T.float32,
+        rope_theta: T.float32,
+        attn_score_scaling_factor: T.float32,
+    ):
+        T.func_attr({"global_symbol": global_symbol})
+        batch_size = T.int32(is_size_var=True)
+        total_len = T.int32(is_size_var=True)
+        nnz_pages = T.int32(is_size_var=True)
+        max_num_pages = T.int32(is_size_var=True)
+        q_indptr_elem_offset = T.int32(is_size_var=True)
+        page_indptr_elem_offset = T.int32(is_size_var=True)
+        page_values_elem_offset = T.int32(is_size_var=True)
+        k_rope_pos_offset_elem_offset = T.int32(is_size_var=True)
+        q_rope_position_elem_offset = T.int32(is_size_var=True)
+        length_info_elem_offset = T.int32(is_size_var=True)
+
+        q = T.match_buffer(var_q, (total_len, h_q, d), dtype)
+        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset)
+        pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), dtype)
+        page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", elem_offset=page_indptr_elem_offset)
+        page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", elem_offset=page_values_elem_offset)
+        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset)
+        q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", elem_offset=q_rope_position_elem_offset)
+        output = T.match_buffer(var_output, (total_len, h_q, d), dtype)
+        lse = T.match_buffer(var_lse, (total_len, h_q), "float32")  # pylint: disable=unused-variable
+        # The length information of the sequences.
+        # - It is in shape `(3, batch_size)` when sliding window is enabled.
+        #   For a sequence "i", location
+        #   - "(0, i)" is the number of KV slots used in the last page of the seq ("last_page_len"),
+        #   - "(1, i)" is the starting offset of the sliding window in the seq,
+        #   - "(2, i)" is the attn sink length of the sequence.
+        # - It is in shape `(batch_size,)` when sliding window is disabled,
+        #   denoting the "last_page_len".
+        length_info = _declare_length_info(var_length_info, batch_size, sliding_window, length_info_elem_offset)
+
+
+        for h_qo in T.serial(h_q):
+            for b_idx in T.serial(batch_size):
+                with T.block("attn"):
+                    O_local = T.alloc_buffer((d, ), "float32")
+                    Q_local = T.alloc_buffer((d, ), "float32")
+                    K_local = T.alloc_buffer((d, ), "float32")
+                    V_local = T.alloc_buffer((d, ), "float32")
+
+                    kv_chunk_len = T.alloc_buffer((1, ), "int32")
+
+                    m_val = T.alloc_buffer((1, ), "float32")
+                    new_m = T.alloc_buffer((1, ), "float32")
+                    d_val = T.alloc_buffer((1, ), "float32")
+                    S_val = T.alloc_buffer((1, ), "float32")
+                    scale_O = T.alloc_buffer((1, ), "float32")
+                    factor = T.alloc_buffer((1, ), "float32")
+                    cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
+                    cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
+                    # NOTE(review): the KV loop below scans all max_num_pages * 16 slots and skips rows >= kv_chunk_len[0]; bounding it by kv_chunk_len[0] would avoid the wasted iterations.
+                    kv_chunk_len[0] = T.if_then_else(
+                        cur_page_indptr_begin != cur_page_indptr_end,
+                        _get_kv_chunk_len(cur_page_indptr_end - cur_page_indptr_begin, 16, b_idx, length_info, sliding_window),
+                        0
+                    )
+
+
+                    for q_idx in T.serial(q_indptr[b_idx + 1] - q_indptr[b_idx]):
+                        # init running max m, denominator d and accumulator O
+                        m_val[0] = -5e4
+                        d_val[0] = 1.0
+                        for d_idx in T.serial(d):
+                            O_local[d_idx] = 0.0
+                        curl_q: T.int32 = q_indptr[b_idx] + q_idx
+
+                        for d_idx in T.serial(d):
+
+                            Q_local[d_idx] = T.if_then_else(
+                                rotary_mode == 1,
+                                _rope(q, q_rope_position[curl_q], d, rope_theta, rope_scale, (curl_q, h_qo, d_idx), dtype, rope_scaling),
+                                q[curl_q, h_qo, d_idx]
+                            )
+                        for row_idx in T.serial(max_num_pages * 16):
+                            if row_idx < kv_chunk_len[0]:
+                                # Locate KV row `row_idx`: page_no is the page id from page_values and
+                                # page_offset the slot within that 16-slot page (via _get_seq_offset).
+                                page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + (_get_seq_offset(row_idx, b_idx, length_info, sliding_window) // 16)]
+                                page_offset: T.int32(is_size_var=True) = _get_seq_offset(row_idx, b_idx, length_info, sliding_window) % 16
+
+                                # Load KV
+                                for d_idx in T.serial(d):
+                                    K_local[d_idx] = T.if_then_else(
+                                        rotary_mode == 1,
+                                        _rope(pages, k_rope_pos_offset[b_idx] + row_idx, d, rope_theta, rope_scale, (page_no, 0, h_qo // group_size, page_offset, d_idx), dtype, rope_scaling),
+                                        pages[page_no, 0, h_qo // group_size, page_offset, d_idx]
+                                    )
+                                    V_local[d_idx] = pages[page_no, 1, h_qo // group_size, page_offset, d_idx]
+
+                                # Compute S
+                                # S = (Q . K) * attn_score_scaling_factor * sm_scale
+                                S_val[0] = 0.0
+                                for d_idx in T.serial(d):
+                                    S_val[0] += Q_local[d_idx] * K_local[d_idx]
+                                S_val[0] *= attn_score_scaling_factor * sm_scale
+
+                                # Causally masked rows get score -5e4 so their exp2 weight is negligible.
+                                S_val[0] = T.if_then_else(
+                                    _causal_mask(causal,
+                                        row=q_idx,
+                                        col=row_idx,
+                                        kv_len=kv_chunk_len[0],
+                                        qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]),
+                                    S_val[0], -5e4)
+                                # Recompute new_m on every row (masked or not) so it never holds a stale value.
+                                new_m[0] = T.max(m_val[0], S_val[0])
+                                d_val[0] *= T.exp2(m_val[0] - new_m[0])
+                                d_val[0] += T.exp2(S_val[0] - new_m[0])
+
+                                # restore O_local then update O_local
+                                scale_O[0] = T.exp2(m_val[0] - new_m[0])
+                                m_val[0] = new_m[0]
+                                factor[0] = T.exp2(S_val[0] - m_val[0])
+                                for d_idx in T.serial(d):
+                                    O_local[d_idx] = O_local[d_idx] * scale_O[0]
+
+
+                                for d_idx in T.serial(d):
+                                    O_local[d_idx] += V_local[d_idx] * factor[0]
+                        # Store Output
+                        for d_idx in T.serial(d):
+                            O_local[d_idx] = O_local[d_idx] /d_val[0]
+                            output[curl_q, h_qo, d_idx] = O_local[d_idx]
+                        lse[curl_q, h_qo] = m_val[0] + T.log2(d_val[0])
+    return batch_prefill_paged_kv_cpu
+
+
 def _attention_prefill(
     h_kv, h_q, d, dtype, sliding_window: bool, rope_scaling: Dict[str, Any], 
target: Target
 ):
@@ -920,6 +1291,189 @@ def _attention_prefill(
     return sch.mod["main"].with_attr("tir.is_scheduled", 1)
 
 
+def _attention_decode_cpu(  # CPU (llvm) decode attention: plain TIR loop nest, one query token per sequence
+    num_kv_heads,
+    num_qo_heads,
+    head_dim,
+    qkv_dtype,
+    sliding_window: bool,
+    rope_scaling: Dict[str, Any],
+):
+    log2e = math.log2(math.exp(1))  # folded into sm_scale so T.exp2 / T.log2 can be used below
+    H_qo = num_qo_heads
+    H_kv = num_kv_heads
+    D = head_dim
+    group_size = num_qo_heads // num_kv_heads  # number of query heads that share one KV head
+
+    global_symbol = "batch_decode_paged_kv_cpu"
+    if sliding_window:
+        global_symbol += "_sliding_window"
+
+    @T.prim_func(check_well_formed=False)
+    def batch_decode_paged_kv(
+        _0: T.int32,  # pylint: disable=unused-argument
+        Q_handle: T.handle,  # [B, H_qo, D]
+        pages_handle: T.handle,  # [max_num_pages, 2, H_kv, 16, D]
+        page_table_indptr_handle: T.handle,  # [B + 1]
+        page_table_values_handle: T.handle,  # [nnz_pages]
+        var_length_info: T.handle,  # [b] when sliding window = False, or otherwise [3, b]
+        k_rope_pos_offset_handle: T.handle,  # [B]
+        q_rope_position_handle: T.handle,  # [B]
+        output_handle: T.handle,  # [B, H_qo, D]
+        lse_handle: T.handle,  # [B, H_qo]
+        rotary_mode: T.int32,  # when 1, Q and K loads go through _rope
+        rope_scale: T.float32,
+        rope_theta: T.float32,
+        attn_score_scaling_factor: T.float32,
+    ):
+        T.func_attr({"tir.is_scheduled": 1, "global_symbol": global_symbol})
+        B = T.int32(is_size_var=True)
+        nnz_pages = T.int32(is_size_var=True)
+        max_num_pages = T.int32(is_size_var=True)
+        page_indptr_elem_offset = T.int32(is_size_var=True)
+        page_values_elem_offset = T.int32(is_size_var=True)
+        k_rope_pos_offset_elem_offset = T.int32(is_size_var=True)
+        q_rope_position_elem_offset = T.int32(is_size_var=True)
+        length_info_elem_offset = T.int32(is_size_var=True)
+
+        Q = T.match_buffer(Q_handle, (B, H_qo, D), qkv_dtype)  # query values
+        pages = T.match_buffer(pages_handle, (max_num_pages, 2, H_kv, 16, D), qkv_dtype)  # page size fixed at 16; axis 1 selects K (0) or V (1)
+        page_table_indptr = T.match_buffer(
+            page_table_indptr_handle, (B + 1,), "int32", elem_offset=page_indptr_elem_offset
+        )
+        page_table_values = T.match_buffer(
+            page_table_values_handle, (nnz_pages,), "int32", elem_offset=page_values_elem_offset
+        )
+        k_rope_pos_offset = T.match_buffer(
+            k_rope_pos_offset_handle, (B,), "int32", elem_offset=k_rope_pos_offset_elem_offset
+        )
+        q_rope_position = T.match_buffer(
+            q_rope_position_handle, (B,), "int32", elem_offset=q_rope_position_elem_offset
+        )
+        output = T.match_buffer(output_handle, (B, H_qo, D), qkv_dtype)
+        lse = T.match_buffer(lse_handle, (B, H_qo), "float32")  # pylint: disable=unused-variable
+        # The length information of the sequences.
+        # - It is in shape `(3, batch_size)` when sliding window is enabled.
+        #   For a sequence "i", location
+        #   - "(0, i)" is the number of KV slots used in the last page of the seq ("last_page_len"),
+        #   - "(1, i)" is the starting offset of the sliding window in the seq,
+        #   - "(2, i)" is the attn sink length of the sequence.
+        # - It is in shape `(batch_size,)` when sliding window is disabled,
+        #   denoting the "last_page_len".
+        length_info = _declare_length_info(
+            var_length_info, B, sliding_window, length_info_elem_offset
+        )
+
+        sm_scale = 1.0 / math.sqrt(float(D)) * log2e
+
+        for b in T.serial(B):
+            with T.block("attn"):
+                O_local = T.alloc_buffer((D,), "float32")  # output accumulator for the current head
+                Q_local = T.alloc_buffer((D,), "float32")
+                K_local = T.alloc_buffer((D,), "float32")
+                V_local = T.alloc_buffer((D,), "float32")
+
+                kv_chunk_len = T.alloc_buffer((1,), "int32")
+
+                m_val = T.alloc_buffer((1,), "float32")  # running max of the scaled scores
+                new_m = T.alloc_buffer((1,), "float32")
+                d_val = T.alloc_buffer((1,), "float32")  # running softmax denominator
+                S_val = T.alloc_buffer((1,), "float32")
+                scale_O = T.alloc_buffer((1,), "float32")
+                factor = T.alloc_buffer((1,), "float32")
+
+                cur_page_indptr_begin: T.int32 = page_table_indptr[b]
+                cur_page_indptr_end: T.int32 = page_table_indptr[b + 1]
+
+                kv_chunk_len[0] = T.if_then_else(  # 0 when the sequence owns no pages
+                    cur_page_indptr_begin != cur_page_indptr_end,
+                    _get_kv_chunk_len(
+                        cur_page_indptr_end - cur_page_indptr_begin,
+                        16,
+                        b,
+                        length_info,
+                        sliding_window,
+                    ),
+                    0,
+                )
+
+                for h_qo in T.serial(H_qo):
+                    m_val[0] = -5e4
+                    d_val[0] = 1.0
+
+                    for d in T.serial(D):
+                        O_local[d] = 0.0
+
+                    for d in T.serial(D):
+                        Q_local[d] = T.if_then_else(
+                            rotary_mode == 1,
+                            _rope(
+                                Q,
+                                q_rope_position[b],
+                                head_dim,
+                                rope_theta,
+                                rope_scale,
+                                (b, h_qo, d),
+                                qkv_dtype,
+                                rope_scaling,
+                            ),
+                            Q[b, h_qo, d],
+                        )
+
+                    for row_idx in T.serial(kv_chunk_len[0]):
+                        seq_offset: T.int32(is_size_var=True) = _get_seq_offset(
+                            row_idx, b, length_info, sliding_window
+                        )
+                        page_no: T.int32(is_size_var=True) = page_table_values[
+                            cur_page_indptr_begin + (seq_offset // 16)
+                        ]
+                        page_offset: T.int32(is_size_var=True) = seq_offset % 16
+
+                        for d in T.serial(D):
+                            K_local[d] = T.if_then_else(
+                                rotary_mode == 1,
+                                _rope(
+                                    pages,
+                                    k_rope_pos_offset[b] + row_idx,
+                                    head_dim,
+                                    rope_theta,
+                                    rope_scale,
+                                    (page_no, 0, h_qo // group_size, page_offset, d),
+                                    qkv_dtype,
+                                    rope_scaling,
+                                ),
+                                pages[page_no, 0, h_qo // group_size, page_offset, d],
+                            )
+                        S_val[0] = 0.0
+                        for d in T.serial(D):
+                            S_val[0] += Q_local[d] * K_local[d]
+                        S_val[0] *= attn_score_scaling_factor * sm_scale
+
+                        new_m[0] = T.max(m_val[0], S_val[0])  # online softmax: update the running max
+                        d_val[0] = (d_val[0] * T.exp2(m_val[0] - new_m[0])) + T.exp2(
+                            S_val[0] - new_m[0]
+                        )
+
+                        scale_O[0] = T.exp2(m_val[0] - new_m[0])  # rescale the accumulator to the new max
+
+                        for d in T.serial(D):
+                            O_local[d] = O_local[d] * scale_O[0]
+
+                        m_val[0] = new_m[0]
+                        for d in T.serial(D):
+                            V_local[d] = pages[page_no, 1, h_qo // group_size, page_offset, d]
+
+                        factor[0] = T.exp2(S_val[0] - m_val[0])  # weight of this KV row
+                        for d in T.serial(D):
+                            O_local[d] = O_local[d] + V_local[d] * factor[0]
+                    for d in T.serial(D):
+                        O_local[d] = O_local[d] / d_val[0]
+                        output[b, h_qo, d] = O_local[d]
+                    lse[b, h_qo] = m_val[0] + T.log2(d_val[0])  # log2-sum-exp of the scaled scores
+
+    return batch_decode_paged_kv
+
+
 def _attention_decode(
     num_kv_heads,
     num_qo_heads,
@@ -1179,6 +1733,47 @@ def _attention_decode(
     return batch_decode_paged_kv
 
 
+def _merge_state_inplace_cpu(v_dtype):
+    """Build a CPU kernel that merges two partial attention states in place.
+
+    A state is an attention output ``V`` of shape ``(N, H, D)`` plus its base-2
+    log-sum-exp ``S`` of shape ``(N, H)``. The kernel overwrites ``V``/``S`` with the
+    combination of ``(V, S)`` and ``(V_other, S_other)``.
+
+    Parameters
+    ----------
+    v_dtype : str
+        Data type of the ``V`` and ``V_other`` buffers.
+
+    Returns
+    -------
+    merge_state_inplace_cpu : tvm.tir.PrimFunc
+        The generated kernel (marked ``tir.is_scheduled``).
+    """
+
+    @T.prim_func
+    def merge_state_inplace_cpu(
+        v: T.handle,
+        s: T.handle,
+        v_other: T.handle,
+        s_other: T.handle,
+    ):
+        T.func_attr({"tir.is_scheduled": 1})
+        N = T.int32(is_size_var=True)
+        H = T.int32(is_size_var=True)
+        D = T.int32(is_size_var=True)
+
+        V = T.match_buffer(v, (N, H, D), v_dtype)
+        S = T.match_buffer(s, (N, H), "float32")
+        V_other = T.match_buffer(v_other, (N, H, D), v_dtype)
+        S_other = T.match_buffer(s_other, (N, H), "float32")
+
+        for n in T.serial(N):
+            for h in T.serial(H):
+                with T.block("merge"):
+                    s_val = _var_cpu("float32")
+                    s_other_val = _var_cpu("float32")
+                    s_max = _var_cpu("float32")
+                    scale = _var_cpu("float32")
+                    other_scale = _var_cpu("float32")
+
+                    s_val[0] = S[n, h]
+                    s_other_val[0] = S_other[n, h]
+                    # Subtract the larger log-sum-exp before exponentiating so that
+                    # exp2 never overflows; the shift cancels out in the ratios below.
+                    s_max[0] = T.max(s_val[0], s_other_val[0])
+                    s_val[0] = T.exp2(s_val[0] - s_max[0])
+                    s_other_val[0] = T.exp2(s_other_val[0] - s_max[0])
+                    # Each side's weight is its share of the combined softmax mass.
+                    scale[0] = s_val[0] / (s_val[0] + s_other_val[0])
+                    other_scale[0] = s_other_val[0] / (s_val[0] + s_other_val[0])
+                    for d in T.serial(D):
+                        V[n, h, d] = V[n, h, d] * scale[0] + V_other[n, h, d] * other_scale[0]
+                    # Undo the shift to get the merged base-2 log-sum-exp.
+                    S[n, h] = T.log2(s_val[0] + s_other_val[0]) + s_max[0]
+
+    return merge_state_inplace_cpu
+
+
 def _merge_state_inplace(num_heads, head_dim, v_dtype, target: Target):
     v_dtype_bytes = 2
     VEC_SIZE = min(max(8 // v_dtype_bytes, head_dim // 32), 4)
@@ -1577,6 +2172,175 @@ def _attention_sequence_prefill(
     return sch.mod["main"].with_attr("tir.is_scheduled", 1)
 
 
+def _attention_prefill_ragged_cpu(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any]):
+    """Build a CPU kernel for batched prefill attention over ragged (unpaged) K/V.
+
+    For every sequence ``b``, queries ``q_indptr[b] .. q_indptr[b + 1]`` attend to
+    keys/values ``kv_indptr[b] .. kv_indptr[b + 1]``, optionally under a causal mask.
+    The attention output and its base-2 log-sum-exp are written per query and head.
+
+    Parameters
+    ----------
+    h_kv : int
+        Number of key/value heads.
+    h_q : int
+        Number of query heads. Query head ``h`` reads KV head ``h // (h_q // h_kv)``,
+        so ``h_q`` is expected to be a multiple of ``h_kv``.
+    d : int
+        Head dimension.
+    dtype : str
+        Data type of the query, key, value and output buffers.
+    rope_scaling : Dict[str, Any]
+        RoPE scaling configuration, forwarded to ``_rope``.
+
+    Returns
+    -------
+    batch_prefill_ragged_kv : tvm.tir.PrimFunc
+        The generated attention kernel.
+    """
+    group_size = h_q // h_kv
+    # 1/sqrt(d) pre-multiplied by log2(e): scores are exponentiated with exp2 below.
+    sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1))
+
+    @T.prim_func
+    def batch_prefill_ragged_kv(  # pylint: disable=too-many-branches
+        var_q: T.handle,  # [total_len, h_q, d]
+        var_q_indptr: T.handle,  # [batch_size + 1]
+        var_k: T.handle,  # [total_len, h_kv, d]
+        var_v: T.handle,  # [total_len, h_kv, d]
+        var_kv_indptr: T.handle,  # [batch_size + 1]
+        var_q_rope_position: T.handle,  # [total_q_len]
+        var_k_rope_pos_offset: T.handle,  # [b]
+        var_output: T.handle,  # [total_len, h_q, d]
+        var_lse: T.handle,  # [total_len, h_q]
+        causal: T.int32,
+        rotary_mode: T.int32,
+        rope_scale: T.float32,
+        rope_theta: T.float32,
+        attn_score_scaling_factor: T.float32,
+    ):
+        batch_size = T.int32(is_size_var=True)
+        qo_len = T.int32(is_size_var=True)
+        kv_len = T.int32(is_size_var=True)
+        q_indptr_elem_offset = T.int32(is_size_var=True)
+        kv_indptr_elem_offset = T.int32(is_size_var=True)
+        q_rope_position_elem_offset = T.int32(is_size_var=True)
+        k_rope_pos_offset_elem_offset = T.int32(is_size_var=True)
+
+        q = T.match_buffer(var_q, (qo_len, h_q, d), dtype)
+        q_indptr = T.match_buffer(
+            var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset
+        )
+        k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype)
+        v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype)
+        kv_indptr = T.match_buffer(
+            var_kv_indptr, (batch_size + 1,), "int32", elem_offset=kv_indptr_elem_offset
+        )
+        q_rope_position = T.match_buffer(
+            var_q_rope_position, (qo_len,), "int32", elem_offset=q_rope_position_elem_offset
+        )
+        k_rope_pos_offset = T.match_buffer(
+            var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset
+        )
+        output = T.match_buffer(var_output, (qo_len, h_q, d), dtype)
+        lse = T.match_buffer(var_lse, (qo_len, h_q), "float32")  # pylint: disable=unused-variable
+
+        for b in T.serial(batch_size):
+            with T.block("attn"):
+                softmax_sum = T.alloc_buffer([h_q], "float32")
+                m_prev = T.alloc_buffer([h_q], "float32")
+                m_new = T.alloc_buffer([h_q], "float32")
+                d_prev = T.alloc_buffer([h_q], "float32")
+                d_new = T.alloc_buffer([h_q], "float32")
+                p_sum = T.alloc_buffer([d], "float32")
+                max_score = T.alloc_buffer([h_q], "float32")
+                attention_scores = T.alloc_buffer([kv_len, h_q], "float32")
+                exp_scores = T.alloc_buffer([kv_len, h_q], "float32")
+                attention_score = T.alloc_buffer([1], "float32")
+                query_val = T.alloc_buffer([1], "float32")
+                key_val = T.alloc_buffer([1], "float32")
+                result = T.alloc_buffer([1], "float32")
+
+                for q_idx in T.serial(q_indptr[b + 1] - q_indptr[b]):
+                    # Softmax statistics are rebuilt from scratch for every query row.
+                    # m_new is initialized too so that an empty KV segment still
+                    # yields a defined lse.
+                    for i in T.serial(h_q):
+                        max_score[i] = -5e4
+                        m_prev[i] = -5e4
+                        m_new[i] = -5e4
+                        d_prev[i] = 1.0
+
+                    # Pass 1: scaled (log2-domain) scores for every key and query head.
+                    for k_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]):
+                        for h in T.serial(h_q):
+                            h_kv_idx = h // group_size
+
+                            if _causal_mask(
+                                causal,
+                                row=q_idx,
+                                col=k_idx,
+                                kv_len=kv_indptr[b + 1] - kv_indptr[b],
+                                qo_len=q_indptr[b + 1] - q_indptr[b],
+                            ):
+                                result[0] = 0.0
+                                for d_idx in T.serial(d):
+                                    query_val[0] = T.if_then_else(
+                                        rotary_mode == 1,
+                                        _rope(
+                                            q,
+                                            q_rope_position[q_indptr[b] + q_idx],
+                                            d,
+                                            rope_theta,
+                                            rope_scale,
+                                            (q_indptr[b] + q_idx, h, d_idx),
+                                            dtype,
+                                            rope_scaling,
+                                        ),
+                                        q[q_indptr[b] + q_idx, h, d_idx],
+                                    )
+
+                                    key_val[0] = T.if_then_else(
+                                        rotary_mode == 1,
+                                        _rope(
+                                            k,
+                                            k_rope_pos_offset[b] + k_idx,
+                                            d,
+                                            rope_theta,
+                                            rope_scale,
+                                            (kv_indptr[b] + k_idx, h_kv_idx, d_idx),
+                                            dtype,
+                                            rope_scaling,
+                                        ),
+                                        k[kv_indptr[b] + k_idx, h_kv_idx, d_idx],
+                                    )
+
+                                    result[0] += query_val[0] * key_val[0]
+                                attention_score[0] = (
+                                    result[0] * sm_scale * attn_score_scaling_factor
+                                )
+                            else:
+                                attention_score[0] = -5e4 * sm_scale * attn_score_scaling_factor
+                            attention_scores[k_idx, h] = attention_score[0]
+                            max_score[h] = T.max(max_score[h], attention_score[0])
+                            m_new[h] = T.max(m_prev[h], max_score[h])
+
+                    # Pass 2: softmax normalizer relative to the row maximum m_new.
+                    for h in T.serial(h_q):
+                        d_new[h] = d_prev[h] * T.exp2(m_prev[h] - m_new[h])
+
+                    for h in T.serial(h_q):
+                        softmax_sum[h] = 0.0
+                        for k_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]):
+                            exp_scores[k_idx, h] = T.exp2(attention_scores[k_idx, h] - m_new[h])
+                            softmax_sum[h] += exp_scores[k_idx, h]
+                        d_new[h] += softmax_sum[h]
+
+                    # Pass 3: probability-weighted sum of V, plus the base-2 lse. The
+                    # final statistics live in m_new/d_new (rebinding Python names to
+                    # buffers in TVMScript only aliases them; it does not copy data).
+                    for h in T.serial(h_q):
+                        h_kv_idx = h // group_size
+                        for i in T.serial(d):
+                            p_sum[i] = 0.0
+                        for v_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]):
+                            weight = exp_scores[v_idx, h] / d_new[h]
+                            for i in T.serial(d):
+                                p_sum[i] += v[kv_indptr[b] + v_idx, h_kv_idx, i] * weight
+                        for i in T.serial(d):
+                            output[q_indptr[b] + q_idx, h, i] = p_sum[i]
+                        lse[q_indptr[b] + q_idx, h] = m_new[h] + T.log2(d_new[h])
+
+    return batch_prefill_ragged_kv
+
+
 def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, 
Any], target: Target):
     # pylint: disable=line-too-long
     NUM_BLKS = 16
@@ -1949,6 +2713,45 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype, 
rope_scaling: Dict[str, Any],
     return sch.mod["main"].with_attr("tir.is_scheduled", 1)
 
 
+def _copy_single_page_cpu(num_heads, page_size, head_dim, dtype):
+    """Build a CPU kernel that copies the first ``copy_length`` slots of one KV page.
+
+    Both the key half (index 0) and the value half (index 1) of page ``src_page_id``
+    are copied into page ``tgt_page_id`` of the same ``pages`` buffer, for every head
+    and every head-dim element.
+
+    Parameters
+    ----------
+    num_heads : int
+        Number of KV heads stored in each page.
+    page_size : int
+        Number of token slots per page.
+    head_dim : int
+        Head dimension.
+    dtype : str
+        Data type of the page buffer.
+
+    Returns
+    -------
+    copy_single_page_cpu : tvm.tir.PrimFunc
+        The generated kernel (marked ``tir.is_scheduled``).
+    """
+    # Inner-loop width over the flattened (head, slot, dim) index. With 1, elements are
+    # visited one at a time; the two-level loop presumably mirrors the GPU variant's
+    # block/thread split.
+    tx = 1
+
+    @T.prim_func
+    def copy_single_page_cpu(
+        var_pages: T.handle,
+        src_page_id: T.int64,
+        tgt_page_id: T.int64,
+        copy_length: T.int64,
+    ):
+        T.func_attr({"tir.is_scheduled": 1})
+        num_pages = T.int32()
+        pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, page_size, head_dim), dtype)
+
+        for b in T.serial((copy_length * num_heads * head_dim + tx - 1) // tx):
+            for t in T.serial(tx):
+                with T.block("copy"):
+                    # Skip the padding iterations of the last outer step.
+                    T.where(b * tx + t < copy_length * num_heads * head_dim)
+                    # Decompose the flat index as (head, slot, dim), row-major.
+                    vh = T.axis.spatial(
+                        num_heads,
+                        T.Cast("int32", (b * tx + t) // (copy_length * head_dim)),
+                    )
+                    vp = T.axis.spatial(
+                        copy_length,
+                        (b * tx + t) % (copy_length * head_dim) // head_dim,
+                    )
+                    vd = T.axis.spatial(
+                        head_dim,
+                        T.Cast(
+                            "int32",
+                            (b * tx + t) % head_dim,
+                        ),
+                    )
+                    pages[tgt_page_id, 0, vh, vp, vd] = pages[src_page_id, 0, vh, vp, vd]
+                    pages[tgt_page_id, 1, vh, vp, vd] = pages[src_page_id, 1, vh, vp, vd]
+
+    return copy_single_page_cpu
+
+
 def _copy_single_page(num_heads, page_size, head_dim, dtype, target: Target):
     tx = get_max_num_threads_per_block(target)
 
@@ -1996,6 +2799,55 @@ def _copy_single_page(num_heads, page_size, head_dim, 
dtype, target: Target):
     return copy_single_page
 
 
+def _compact_kv_copy_cpu(num_heads, head_dim, dtype):
+    """Build a CPU kernel that moves KV entries between slots of the paged cache.
+
+    For sequence ``b``, columns ``copy_length_indptr[b] .. copy_length_indptr[b + 1]``
+    of ``copy_src_dst_pos`` hold (row 0) source and (row 1) destination slot positions.
+    Key and value data at each source slot are copied to the destination slot, for
+    every head and head-dim element.
+
+    NOTE(review): the page size is hard-coded to 16, both in the ``pages`` buffer shape
+    and in the slot -> (page, offset) arithmetic. Confirm that callers only use 16-slot
+    pages.
+
+    Parameters
+    ----------
+    num_heads : int
+        Number of KV heads stored in each page.
+    head_dim : int
+        Head dimension.
+    dtype : str
+        Data type of the page buffer.
+
+    Returns
+    -------
+    compact_kv_copy_cpu : tvm.tir.PrimFunc
+        The generated kernel (marked ``tir.is_scheduled``).
+    """
+    # Inner-loop width over the flattened (batch, head, dim) index.
+    tx = 8
+
+    @T.prim_func
+    def compact_kv_copy_cpu(
+        var_pages: T.handle,
+        var_copy_length_indptr: T.handle,
+        var_copy_src_dst_pos: T.handle,
+        batch_size: T.int32,
+    ):
+        T.func_attr({"tir.is_scheduled": 1})
+        num_pages = T.int32()
+        total_copy_length = T.int32()
+        copy_length_indptr_elem_offset = T.int32()
+        copy_src_dst_pos_elem_offset = T.int32()
+        pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, 16, head_dim), dtype)
+        copy_length_indptr = T.match_buffer(
+            var_copy_length_indptr,
+            (batch_size + 1,),
+            "int32",
+            elem_offset=copy_length_indptr_elem_offset,
+        )
+        copy_src_dst_pos = T.match_buffer(
+            var_copy_src_dst_pos,
+            (2, total_copy_length),
+            "int32",
+            elem_offset=copy_src_dst_pos_elem_offset,
+        )
+
+        with T.block("root"):
+            for bhd_o in T.serial((batch_size * num_heads * head_dim + tx - 1) // tx):
+                for bhd_i in T.serial(tx):
+                    # Decompose the flat index as (batch, head, dim), row-major.
+                    b: T.int32 = (bhd_o * tx + bhd_i) // (num_heads * head_dim)
+                    h: T.int32 = (bhd_o * tx + bhd_i) // head_dim % num_heads
+                    d: T.int32 = (bhd_o * tx + bhd_i) % head_dim
+                    # Skip the padding iterations of the last outer step.
+                    if (bhd_o * tx + bhd_i) < batch_size * num_heads * head_dim:
+                        for i in T.serial(copy_length_indptr[b + 1] - copy_length_indptr[b]):
+                            src_pos: T.int32 = copy_src_dst_pos[0, copy_length_indptr[b] + i]
+                            dst_pos: T.int32 = copy_src_dst_pos[1, copy_length_indptr[b] + i]
+                            # Slot position -> (page = pos // 16, offset = pos % 16).
+                            pages[dst_pos // 16, 0, h, dst_pos % 16, d] = pages[
+                                src_pos // 16, 0, h, src_pos % 16, d
+                            ]
+                            pages[dst_pos // 16, 1, h, dst_pos % 16, d] = pages[
+                                src_pos // 16, 1, h, src_pos % 16, d
+                            ]
+
+    return compact_kv_copy_cpu
+
+
 def _compact_kv_copy(num_heads, head_dim, dtype, target: Target):
     tx = get_max_num_threads_per_block(target)
 
diff --git a/python/tvm/relax/frontend/nn/llm/tree_attn.py 
b/python/tvm/relax/frontend/nn/llm/tree_attn.py
index 9e4a7ed97e..fa0146afb6 100644
--- a/python/tvm/relax/frontend/nn/llm/tree_attn.py
+++ b/python/tvm/relax/frontend/nn/llm/tree_attn.py
@@ -82,6 +82,213 @@ def _check_tree_order(tree_order_indptr, tree_order, batch, 
row, col, kv_len, qo
     )
 
 
+def _declare_length_info(var_length_info, batch_size, sliding_window, elem_offset):
+    """Declare the per-sequence length-info buffer of a paged-KV attention kernel.
+
+    Intended to be called from inside a ``T.prim_func`` body, because it issues
+    ``T.match_buffer``. The buffer is ``(3, batch_size)`` int32 when ``sliding_window``
+    is truthy and ``(batch_size,)`` int32 otherwise.
+
+    NOTE(review): ``tree_attn_with_paged_kv_cache_cpu`` imports a helper with the same
+    name from ``kv_cache``. Consider reusing that one rather than keeping two copies.
+    """
+    return (
+        T.match_buffer(var_length_info, (3, batch_size), "int32", elem_offset=elem_offset)
+        if sliding_window
+        else T.match_buffer(var_length_info, (batch_size,), "int32", elem_offset=elem_offset)
+    )
+
+
+def tree_attn_cpu(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any]):
+    """Generate a CPU kernel for batched tree attention over ragged (unpaged) K/V.
+
+    Query row ``q_idx`` of sequence ``b`` may attend to key ``k_idx`` only when
+    ``_check_tree_order`` admits the pair under the per-sequence tree mask.
+
+    Parameters
+    ----------
+    h_kv : int
+        Number of heads for key and value.
+    h_q : int
+        Number of heads for query.
+    d : int
+        Head dimension.
+    dtype : str
+        Data type.
+    rope_scaling : Dict[str, Any]
+        RoPE scaling configuration, forwarded to ``_rope``.
+
+    Returns
+    -------
+    batch_tree_attn : tvm.tir.PrimFunc
+        The generated kernel.
+    """
+    group_size = h_q // h_kv
+    # 1/sqrt(d) pre-multiplied by log2(e): scores are exponentiated with exp2 below.
+    sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1))
+
+    # fmt: off
+    @T.prim_func
+    def batch_tree_attn(  # pylint: disable=too-many-branches,line-too-long
+        var_q: T.handle,  # [total_len, h_q, d]
+        var_q_indptr: T.handle,  # [batch_size + 1]
+        var_k: T.handle,  # [total_len, h_kv, d]
+        var_v: T.handle,  # [total_len, h_kv, d]
+        var_kv_indptr: T.handle,  # [batch_size + 1], kv_indptr should be the same as q_indptr in this case
+        var_q_rope_position: T.handle,  # [total_q_len]
+        var_mn_indptr: T.handle,  # [batch_size + 1]
+        var_mask: T.handle,  # [mn_indptr[batch_size]]
+        var_output: T.handle,  # [total_len, h_q, d]
+        var_lse: T.handle,  # [total_len, h_q]
+        rotary_mode: T.int32,
+        rope_scale: T.float32,
+        rope_theta: T.float32,
+        attn_score_scaling_factor: T.float32,
+        batch_size: T.int32,
+    ):
+        qo_len = T.int32(is_size_var=True)
+        kv_len = T.int32(is_size_var=True)
+        q_indptr_elem_offset = T.int32(is_size_var=True)
+        kv_indptr_elem_offset = T.int32(is_size_var=True)
+        q_rope_position_elem_offset = T.int32(is_size_var=True)
+        mn_indptr_elem_offset = T.int32(is_size_var=True)
+        mask_elem_offset = T.int32(is_size_var=True)
+        tree_size = T.int32(is_size_var=True)
+
+        q = T.match_buffer(var_q, (qo_len, h_q, d), dtype)
+        q_indptr = T.match_buffer(
+            var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset
+        )
+        k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype)
+        v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype)
+        kv_indptr = T.match_buffer(
+            var_kv_indptr, (batch_size + 1,), "int32", elem_offset=kv_indptr_elem_offset
+        )
+        q_rope_position = T.match_buffer(
+            var_q_rope_position, (qo_len,), "int32", elem_offset=q_rope_position_elem_offset
+        )
+        mn_indptr = T.match_buffer(
+            var_mn_indptr, (batch_size + 1,), "int32", elem_offset=mn_indptr_elem_offset
+        )
+        mask = T.match_buffer(var_mask, (tree_size, 2), "int32", elem_offset=mask_elem_offset)
+        output = T.match_buffer(var_output, (qo_len, h_q, d), dtype)
+        lse = T.match_buffer(var_lse, (qo_len, h_q), "float32")  # pylint: disable=unused-variable
+
+        for b in T.serial(batch_size):
+            with T.block("attn"):
+                softmax_sum = T.alloc_buffer([h_q], "float32")
+                m_prev = T.alloc_buffer([h_q], "float32")
+                m_new = T.alloc_buffer([h_q], "float32")
+                d_prev = T.alloc_buffer([h_q], "float32")
+                d_new = T.alloc_buffer([h_q], "float32")
+                p_sum = T.alloc_buffer([d], "float32")
+
+                max_score = T.alloc_buffer([h_q], "float32")
+                attention_scores = T.alloc_buffer([kv_len, h_q], "float32")
+                exp_scores = T.alloc_buffer([kv_len, h_q], "float32")
+                attention_score = T.alloc_buffer([1], "float32")
+                query_val = T.alloc_buffer([1], "float32")
+                key_val = T.alloc_buffer([1], "float32")
+                result = T.alloc_buffer([1], "float32")
+
+                for q_idx in T.serial(q_indptr[b + 1] - q_indptr[b]):
+                    # Softmax statistics are rebuilt from scratch for every query row.
+                    # m_new is initialized too so that an empty KV segment still
+                    # yields a defined lse.
+                    for i in T.serial(h_q):
+                        max_score[i] = -5e4
+                        m_prev[i] = -5e4
+                        m_new[i] = -5e4
+                        d_prev[i] = 1.0
+
+                    # Pass 1: scaled (log2-domain) scores; pairs rejected by the tree
+                    # mask get a large negative score.
+                    for k_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]):
+                        for h in T.serial(h_q):
+                            h_kv_idx = h // group_size
+
+                            if _check_tree_order(
+                                row=q_idx,
+                                col=k_idx,
+                                batch=b,
+                                tree_order=mask,
+                                tree_order_indptr=mn_indptr,
+                                kv_len=kv_indptr[b + 1] - kv_indptr[b],
+                                qo_len=q_indptr[b + 1] - q_indptr[b],
+                            ):
+                                result[0] = 0.0
+                                for d_idx in T.serial(d):
+                                    query_val[0] = T.if_then_else(
+                                        rotary_mode == 1,
+                                        _rope(
+                                            q,
+                                            q_rope_position[q_indptr[b] + q_idx],
+                                            d,
+                                            rope_theta,
+                                            rope_scale,
+                                            (q_indptr[b] + q_idx, h, d_idx),
+                                            dtype,
+                                            rope_scaling,
+                                        ),
+                                        q[q_indptr[b] + q_idx, h, d_idx],
+                                    )
+
+                                    # Keys share q_rope_position because kv_indptr is
+                                    # expected to equal q_indptr here (see signature).
+                                    key_val[0] = T.if_then_else(
+                                        rotary_mode == 1,
+                                        _rope(
+                                            k,
+                                            q_rope_position[kv_indptr[b] + k_idx],
+                                            d,
+                                            rope_theta,
+                                            rope_scale,
+                                            (kv_indptr[b] + k_idx, h_kv_idx, d_idx),
+                                            dtype,
+                                            rope_scaling,
+                                        ),
+                                        k[kv_indptr[b] + k_idx, h_kv_idx, d_idx],
+                                    )
+
+                                    result[0] += query_val[0] * key_val[0]
+                                attention_score[0] = (
+                                    result[0] * sm_scale * attn_score_scaling_factor
+                                )
+                            else:
+                                attention_score[0] = -5e4 * sm_scale * attn_score_scaling_factor
+                            attention_scores[k_idx, h] = attention_score[0]
+                            max_score[h] = T.max(max_score[h], attention_score[0])
+                            m_new[h] = T.max(m_prev[h], max_score[h])
+
+                    # Pass 2: softmax normalizer relative to the row maximum m_new.
+                    for h in T.serial(h_q):
+                        d_new[h] = d_prev[h] * T.exp2(m_prev[h] - m_new[h])
+
+                    for h in T.serial(h_q):
+                        softmax_sum[h] = 0.0
+                        for k_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]):
+                            exp_scores[k_idx, h] = T.exp2(attention_scores[k_idx, h] - m_new[h])
+                            softmax_sum[h] += exp_scores[k_idx, h]
+                        d_new[h] += softmax_sum[h]
+
+                    # Pass 3: probability-weighted sum of V, plus the base-2 lse. The
+                    # final statistics live in m_new/d_new (rebinding Python names to
+                    # buffers in TVMScript only aliases them; it does not copy data).
+                    for h in T.serial(h_q):
+                        h_kv_idx = h // group_size
+                        for i in T.serial(d):
+                            p_sum[i] = 0.0
+                        for v_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]):
+                            weight = exp_scores[v_idx, h] / d_new[h]
+                            for i in T.serial(d):
+                                p_sum[i] += v[kv_indptr[b] + v_idx, h_kv_idx, i] * weight
+                        for i in T.serial(d):
+                            output[q_indptr[b] + q_idx, h, i] = p_sum[i]
+                        lse[q_indptr[b] + q_idx, h] = m_new[h] + T.log2(d_new[h])
+
+    # fmt: on
+    # pylint: enable=line-too-long,too-many-branches
+    return batch_tree_attn
+
+
 def tree_attn(
     h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], target: Target
 ):  # pylint: disable=unused-argument
@@ -437,6 +644,204 @@ def tree_attn(
     return sch.mod["main"].with_attr("tir.is_scheduled", 1)
 
 
+def tree_attn_with_paged_kv_cache_cpu(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any]):
+    """Generate a CPU tree-attention kernel that reads keys/values from a paged KV cache.
+
+    For each query head, sequence and query row, the kernel runs an online softmax over
+    the sequence's KV positions admitted by the tree mask (``_check_tree_order``). It
+    writes the attention output and its base-2 log-sum-exp.
+
+    Parameters
+    ----------
+    h_kv : int
+        Number of heads for key and value.
+    h_q : int
+        Number of heads for query.
+    d : int
+        Head dimension.
+    dtype : str
+        Data type of the query, KV pages and output.
+    rope_scaling : Dict[str, Any]
+        RoPE scaling configuration, forwarded to ``_rope``.
+
+    Returns
+    -------
+    tree_attn_paged_kv_cpu : tvm.tir.PrimFunc
+        The generated kernel, with global symbol ``tree_attn_paged_kv_cpu``.
+    """
+    # pylint: disable=import-outside-toplevel
+    from .kv_cache import (
+        _declare_length_info,
+        _get_kv_chunk_len,
+        _get_seq_offset,
+    )
+
+    global_symbol = "tree_attn_paged_kv_cpu"
+    sliding_window = False
+    group_size = h_q // h_kv
+    # 1/sqrt(d) pre-multiplied by log2(e): scores are exponentiated with exp2 below.
+    sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1))
+    # pylint: disable=line-too-long,too-many-branches
+    # fmt: off
+    @T.prim_func(check_well_formed=False)
+    def tree_attn_paged_kv_cpu(
+        _0: T.int32,  # pylint: disable=unused-argument
+        var_q: T.handle, # [total_len, h_q, d]
+        var_q_indptr: T.handle, # [batch_size + 1]
+        var_pages: T.handle, # [max_num_pages, 2, h_kv, page_size, d]
+        var_page_indptr: T.handle, # [batch_size + 1]
+        var_page_values: T.handle, # [nnz_pages]
+        var_length_info: T.handle, # [b] when sliding window = False, or otherwise [3, b]
+        var_k_rope_pos_offset: T.handle, # [b]
+        var_q_rope_position: T.handle, # [total_len]
+        var_output: T.handle, # [total_len, h_q, d]
+        var_lse: T.handle, # [total_len, h_q]
+        rotary_mode: T.int32,
+        rope_scale: T.float32,
+        rope_theta: T.float32,
+        attn_score_scaling_factor: T.float32,
+        tree_order_indptr_handle: T.handle,  # [batch_size + 1]
+        tree_order_handle: T.handle,  # [total_len, 2]
+    ):
+        T.func_attr({"global_symbol": global_symbol})
+        batch_size = T.int32(is_size_var=True)
+        total_len = T.int32(is_size_var=True)
+        nnz_pages = T.int32(is_size_var=True)
+        max_num_pages = T.int32(is_size_var=True)
+        q_indptr_elem_offset = T.int32(is_size_var=True)
+        page_indptr_elem_offset = T.int32(is_size_var=True)
+        page_values_elem_offset = T.int32(is_size_var=True)
+        k_rope_pos_offset_elem_offset = T.int32(is_size_var=True)
+        q_rope_position_elem_offset = T.int32(is_size_var=True)
+        length_info_elem_offset = T.int32(is_size_var=True)
+        tree_order_elem_offset = T.int32(is_size_var=True)
+        tree_order_indptr_elem_offset = T.int32(is_size_var=True)
+
+        q = T.match_buffer(var_q, (total_len, h_q, d), dtype)
+        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset)
+        pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), dtype)
+        page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", elem_offset=page_indptr_elem_offset)
+        page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", elem_offset=page_values_elem_offset)
+        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset)
+        q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", elem_offset=q_rope_position_elem_offset)
+        output = T.match_buffer(var_output, (total_len, h_q, d), dtype)
+        lse = T.match_buffer(var_lse, (total_len, h_q), "float32")  # pylint: disable=unused-variable
+        tree_order_indptr = T.match_buffer(
+            tree_order_indptr_handle,
+            (batch_size + 1,),
+            "int32",
+            elem_offset=tree_order_indptr_elem_offset,
+        )
+        total_tree_order_len = T.int32(is_size_var=True)
+        tree_order = T.match_buffer(
+            tree_order_handle,
+            (total_tree_order_len, 2),
+            "int32",
+            elem_offset=tree_order_elem_offset,
+        )
+        # The length information of the sequences.
+        # - It is in shape `(3, batch_size)` when sliding window is enabled.
+        #   For a sequence "i", location
+        #   - "(0, i)" is the number of KV slots used in the last page of the seq ("last_page_len"),
+        #   - "(1, i)" is the starting offset of the sliding window in the seq,
+        #   - "(2, i)" is the attn sink length of the sequence.
+        # - It is in shape `(batch_size,)` when sliding window is disabled,
+        #   denoting the "last_page_len".
+        length_info = _declare_length_info(var_length_info, batch_size, sliding_window, length_info_elem_offset)
+
+        T.Assert(
+            rotary_mode == T.int32(0), "Inline rotary mode is not supported in tree attention."
+        )
+
+        for h_qo in T.serial(h_q):
+            for b_idx in T.serial(batch_size):
+                with T.block("attn"):
+                    O_local = T.alloc_buffer((d, ), "float32")
+                    Q_local = T.alloc_buffer((d, ), "float32")
+                    K_local = T.alloc_buffer((d, ), "float32")
+                    V_local = T.alloc_buffer((d, ), "float32")
+
+                    kv_chunk_len = T.alloc_buffer((1, ), "int32")
+
+                    m_val = T.alloc_buffer((1, ), "float32")
+                    new_m = T.alloc_buffer((1, ), "float32")
+                    d_val = T.alloc_buffer((1, ), "float32")
+                    S_val = T.alloc_buffer((1, ), "float32")
+                    scale_O = T.alloc_buffer((1, ), "float32")
+                    factor = T.alloc_buffer((1, ), "float32")
+                    cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
+                    cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
+                    # Number of KV positions held by this sequence (0 if it owns no pages).
+                    kv_chunk_len[0] = T.if_then_else(
+                        cur_page_indptr_begin != cur_page_indptr_end,
+                        _get_kv_chunk_len(cur_page_indptr_end - cur_page_indptr_begin, 16, b_idx, length_info, sliding_window),
+                        0
+                    )
+
+                    for q_idx in T.serial(q_indptr[b_idx + 1] - q_indptr[b_idx]):
+                        # Reset the running max, normalizer and output accumulator.
+                        m_val[0] = -5e4
+                        d_val[0] = 1.0
+                        for d_idx in T.serial(d):
+                            O_local[d_idx] = 0.0
+                        curl_q: T.int32 = q_indptr[b_idx] + q_idx
+
+                        for d_idx in T.serial(d):
+                            Q_local[d_idx] = T.if_then_else(
+                                rotary_mode == 1,
+                                _rope(q, q_rope_position[curl_q], d, rope_theta, rope_scale, (curl_q, h_qo, d_idx), dtype, rope_scaling),
+                                q[curl_q, h_qo, d_idx]
+                            )
+                        # Visit only this sequence's KV positions instead of every slot
+                        # of the whole page pool (max_num_pages * 16).
+                        for row_idx in T.serial(kv_chunk_len[0]):
+                            # Positions outside the query's tree ancestry are skipped
+                            # entirely, so they contribute nothing to the softmax.
+                            if _check_tree_order(
+                                tree_order_indptr=tree_order_indptr,
+                                tree_order=tree_order,
+                                batch=b_idx,
+                                row=q_idx,
+                                col=row_idx,
+                                kv_len=kv_chunk_len[0],
+                                qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx],
+                            ):
+                                page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + (_get_seq_offset(row_idx, b_idx, length_info, sliding_window) // 16)]
+                                page_offset: T.int32(is_size_var=True) = _get_seq_offset(row_idx, b_idx, length_info, sliding_window) % 16
+
+                                # Load K (optionally rotated) and V for this position.
+                                for d_idx in T.serial(d):
+                                    K_local[d_idx] = T.if_then_else(
+                                        rotary_mode == 1,
+                                        _rope(pages, k_rope_pos_offset[b_idx] + row_idx, d, rope_theta, rope_scale, (page_no, 0, h_qo // group_size, page_offset, d_idx), dtype, rope_scaling),
+                                        pages[page_no, 0, h_qo // group_size, page_offset, d_idx]
+                                    )
+                                    V_local[d_idx] = pages[page_no, 1, h_qo // group_size, page_offset, d_idx]
+
+                                # Scaled dot-product score in the log2 domain.
+                                S_val[0] = 0.0
+                                for d_idx in T.serial(d):
+                                    S_val[0] += Q_local[d_idx] * K_local[d_idx]
+                                S_val[0] *= attn_score_scaling_factor * sm_scale
+
+                                # Online softmax: rescale the accumulated state to the
+                                # new running max, then add this position's share.
+                                # scale_O and factor are single-element buffers.
+                                new_m[0] = T.max(m_val[0], S_val[0])
+                                scale_O[0] = T.exp2(m_val[0] - new_m[0])
+                                factor[0] = T.exp2(S_val[0] - new_m[0])
+                                d_val[0] = d_val[0] * scale_O[0] + factor[0]
+                                for d_idx in T.serial(d):
+                                    O_local[d_idx] = O_local[d_idx] * scale_O[0] + V_local[d_idx] * factor[0]
+                                m_val[0] = new_m[0]
+                        # Normalize and store the output and its base-2 log-sum-exp.
+                        for d_idx in T.serial(d):
+                            O_local[d_idx] = O_local[d_idx] / d_val[0]
+                            output[curl_q, h_qo, d_idx] = O_local[d_idx]
+                        lse[curl_q, h_qo] = m_val[0] + T.log2(d_val[0])
+    # fmt: on
+    # pylint: enable=line-too-long,too-many-branches
+    return tree_attn_paged_kv_cpu
+
+
 def tree_attn_with_paged_kv_cache(
     h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], target: Target
 ):
diff --git a/src/runtime/cpu_device_api.cc b/src/runtime/cpu_device_api.cc
index ccd726a6ec..5dd470f00d 100644
--- a/src/runtime/cpu_device_api.cc
+++ b/src/runtime/cpu_device_api.cc
@@ -34,6 +34,19 @@
 #include <android/api-level.h>
 #endif
 
+#if defined(__linux__) || defined(__ANDROID__)
+#include <sys/sysinfo.h>
+#endif
+
+#ifdef _WIN32
+#include <windows.h>
+#endif
+
+#if defined(__APPLE__)
+#include <TargetConditionals.h>
+#include <sys/sysctl.h>
+#endif
+
 namespace tvm {
 namespace runtime {
 class CPUDeviceAPI final : public DeviceAPI {
@@ -43,6 +56,41 @@ class CPUDeviceAPI final : public DeviceAPI {
     if (kind == kExist) {
       *rv = 1;
     }
+
+    switch (kind) {
+      case kExist:
+        break;
+      case kTotalGlobalMemory: {
+#if defined(__linux__) || defined(__ANDROID__)
+        struct sysinfo info;
+        if (sysinfo(&info) == 0) {
+          *rv = static_cast<int64_t>(info.totalram) * info.mem_unit;  // 
Convert to bytes
+        } else {
+          *rv = -1;
+        }
+#elif defined(_WIN32)
+        MEMORYSTATUSEX statex;
+        statex.dwLength = sizeof(statex);
+        if (GlobalMemoryStatusEx(&statex)) {
+          *rv = static_cast<int64_t>(statex.ullTotalPhys);  // Total physical 
memory in bytes
+        } else {
+          *rv = -1;
+        }
+#elif defined(__APPLE__)
+        int64_t mem;
+        size_t size = sizeof(mem);
+        if (sysctlbyname("hw.memsize", &mem, &size, nullptr, 0) == 0) {
+          *rv = mem;
+        } else {
+          *rv = -1;
+        }
+#else
+        *rv = -1;
+#endif
+      }
+      default:
+        break;
+    }
   }
   void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType 
type_hint) final {
     void* ptr;
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index 8e5dfb4bd8..0b83bb426d 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -1366,7 +1366,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     // Create the auxiliary data manager for attention.
     // We only use the merged aux data for CUDA, since direct pointer
     // operations may have issues on other platforms.
-    if (device_.device_type == DLDeviceType::kDLCUDA) {
+    if (device_.device_type == DLDeviceType::kDLCUDA ||
+        device_.device_type == DLDeviceType::kDLCPU) {
       aux_data_manager_ = std::make_unique<CachedPagedKVCacheAuxDataManager>(
           reserved_num_seqs, num_total_pages, prefill_chunk_size, dtype_aux_, 
device,
           preferred_host_device, copy_stream_);
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py
new file mode 100644
index 0000000000..9487bbf860
--- /dev/null
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py
@@ -0,0 +1,956 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import enum
+import itertools
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import pytest
+import scipy.special
+
+import tvm
+import tvm.testing
+from tvm import dlight as dl
+from tvm.relax.frontend.nn.llm.kv_cache import (
+    _attention_decode_cpu,
+    _attention_prefill_cpu,
+    _attention_prefill_ragged_cpu,
+    _compact_kv_copy_cpu,
+    _copy_single_page_cpu,
+    _kv_cache_debug_get_kv,
+    _kv_cache_transpose_append,
+    _merge_state_inplace_cpu,
+    llama_rope_with_position_map,
+    tree_attn_cpu,
+    tree_attn_with_paged_kv_cache_cpu,
+)
+from tvm.runtime import ShapeTuple
+
+# Configuration of the paged KV cache under test.
+reserved_nseq = 32
+maximum_total_seq_length = 2048
+prefill_chunk_size = 512
+page_size = 16
+num_layers = 4
+num_qo_heads = 32
+num_kv_heads = 4
+# `head_dim` and `dtype` are rebound for each parameterization by the
+# `kv_cache_and_config` fixture.
+head_dim = None
+rope_scale = 1.0
+rope_theta = 1e4
+rope_scaling = {}
+dtype = None
+device = tvm.cpu()
+
+# Runtime builtins, resolved via `tvm.get_global_func` in `set_global_func`.
+fclear = None
+fadd_sequence = None
+fremove_sequence = None
+ffork_sequence = None
+fenable_sliding_window_for_seq = None
+fpopn = None
+fbegin_forward = None
+fend_forward = None
+fcommit_accepted_token_tree_nodes = None
+fattention_with_fuse_qkv = None
+fis_empty = None
+fdebug_get_kv = None
+
+# Compiled TIR kernels, built by `set_global_func` and handed to the cache.
+# NOTE(review): `fattention_rotary` is declared global in `set_global_func`
+# but is never bound there.
+ftranspose_append = None
+fcopy_cache = None
+fattn_prefill = None
+fattn_decode = None
+fattn_prefill_sliding_window = None
+fattn_decode_sliding_window = None
+fattn_prefill_ragged = None
+fattn_prefill_with_tree_mask = None
+fattn_prefill_with_tree_mask_paged_kv_cache = None
+fmerge_state = None
+fsplit_rotary = None
+fattention_rotary = None
+fcopy_single_page = None
+fcompact_copy = None
+
+
+def set_global_func(head_dim, dtype):
+    global fclear, fadd_sequence, fremove_sequence, ffork_sequence, 
fenable_sliding_window_for_seq
+    global fpopn, fbegin_forward, fend_forward, 
fcommit_accepted_token_tree_nodes
+    global fattention_with_fuse_qkv, fis_empty, fdebug_get_kv
+    global ftranspose_append, fcopy_cache, fattn_prefill, fattn_decode
+    global fattn_prefill_ragged, fattn_prefill_with_tree_mask, 
fattn_prefill_with_tree_mask_paged_kv_cache
+    global fattn_prefill_sliding_window, fattn_decode_sliding_window
+    global fmerge_state, fsplit_rotary, fattention_rotary, fcopy_single_page, 
fcompact_copy
+
+    fclear = tvm.get_global_func("vm.builtin.kv_state_clear")
+    fadd_sequence = tvm.get_global_func("vm.builtin.kv_state_add_sequence")
+    fremove_sequence = 
tvm.get_global_func("vm.builtin.kv_state_remove_sequence")
+    ffork_sequence = tvm.get_global_func("vm.builtin.kv_state_fork_sequence")
+    fenable_sliding_window_for_seq = tvm.get_global_func(
+        "vm.builtin.attention_kv_cache_enable_sliding_window_for_seq"
+    )
+    fpopn = tvm.get_global_func("vm.builtin.kv_state_popn")
+    fbegin_forward = tvm.get_global_func("vm.builtin.kv_state_begin_forward")
+    fend_forward = tvm.get_global_func("vm.builtin.kv_state_end_forward")
+    fcommit_accepted_token_tree_nodes = tvm.get_global_func(
+        "vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes"
+    )
+    fattention_with_fuse_qkv = tvm.get_global_func(
+        "vm.builtin.attention_kv_cache_attention_with_fused_qkv"
+    )
+    fis_empty = tvm.get_global_func("vm.builtin.attention_kv_cache_empty")
+    fdebug_get_kv = 
tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv")
+
+    target = tvm.target.Target.from_device(device)
+    builts = []
+    for tir_func in [
+        _kv_cache_transpose_append(num_kv_heads, head_dim, dtype),
+        _kv_cache_debug_get_kv(num_layers, num_kv_heads, head_dim, dtype),
+        _attention_prefill_cpu(num_kv_heads, num_qo_heads, head_dim, dtype, 
False, rope_scaling),
+        _attention_decode_cpu(num_kv_heads, num_qo_heads, head_dim, dtype, 
False, rope_scaling),
+        _attention_prefill_cpu(num_kv_heads, num_qo_heads, head_dim, dtype, 
True, rope_scaling),
+        _attention_decode_cpu(num_kv_heads, num_qo_heads, head_dim, dtype, 
True, rope_scaling),
+        _attention_prefill_ragged_cpu(num_kv_heads, num_qo_heads, head_dim, 
dtype, rope_scaling),
+        tree_attn_cpu(num_kv_heads, num_qo_heads, head_dim, dtype, 
rope_scaling),
+        tree_attn_with_paged_kv_cache_cpu(
+            num_kv_heads, num_qo_heads, head_dim, dtype, rope_scaling
+        ),
+        _merge_state_inplace_cpu(dtype),
+        llama_rope_with_position_map(
+            rope_theta, rope_scale, head_dim, num_qo_heads, num_kv_heads, 
dtype, rope_scaling
+        ),
+        _copy_single_page_cpu(num_kv_heads, page_size, head_dim, dtype),
+        _compact_kv_copy_cpu(num_kv_heads, head_dim, dtype),
+    ]:
+        mod = tvm.IRModule({"main": tir_func})
+        with target:
+            mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)
+        f = tvm.build(mod["main"], target=target)
+        builts.append(f.entry_func)
+
+    (
+        ftranspose_append,
+        fcopy_cache,
+        fattn_prefill,
+        fattn_decode,
+        fattn_prefill_sliding_window,
+        fattn_decode_sliding_window,
+        fattn_prefill_ragged,
+        fattn_prefill_with_tree_mask,
+        fattn_prefill_with_tree_mask_paged_kv_cache,
+        fmerge_state,
+        fsplit_rotary,
+        fcopy_single_page,
+        fcompact_copy,
+    ) = builts
+
+
+def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window):
+    fcreate = 
tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create_reduced")
+    cache = fcreate(
+        tvm.runtime.ShapeTuple(
+            [
+                reserved_nseq,
+                maximum_total_seq_length,
+                prefill_chunk_size,
+                page_size,
+                int(support_sliding_window),
+            ]
+        ),
+        tvm.runtime.ShapeTuple([0, num_layers]),
+        num_qo_heads,
+        num_kv_heads,
+        head_dim,
+        rope_mode,
+        rope_scale,
+        rope_theta,
+        tvm.nd.empty((), dtype, device=device),
+        ftranspose_append,
+        fattn_prefill,
+        fattn_decode,
+        fattn_prefill_sliding_window,
+        fattn_decode_sliding_window,
+        fattn_prefill_ragged,
+        fmerge_state,
+        fsplit_rotary,
+        fcopy_single_page,
+        fcopy_cache,
+        fcompact_copy,
+        fattn_prefill_with_tree_mask,
+        fattn_prefill_with_tree_mask_paged_kv_cache,
+        None,
+        False,
+    )
+    return cache
+
+
+class RopeMode(enum.IntEnum):
+    """The RoPE mode of the Paged KV cache.
+
+    - NONE: the KV cache does not apply RoPE to q or k.
+    - NORMAL: RoPE is applied to k before k is written into the cache.
+    - INLINE: RoPE is applied to q/k on the fly inside the attention kernel.
+    """
+
+    # The integer values are passed to the cache constructor as `rope_mode`.
+    NONE = 0
+    NORMAL = 1
+    INLINE = 2
+
+
[email protected](
+    params=itertools.chain(
+        itertools.product(
+            [64, 128],
+            ["float32", "float16"],
+            [RopeMode.NORMAL],
+            [False],
+        ),
+        itertools.product(
+            [128],
+            ["float16"],
+            [RopeMode.NONE, RopeMode.INLINE],
+            [False, True],
+        ),
+    )
+)
+def kv_cache_and_config(request):
+    global head_dim, dtype
+    head_dim, dtype, rope_mode, support_sliding_window = request.param
+    set_global_func(head_dim, dtype)
+    return create_kv_cache(*request.param), rope_mode, support_sliding_window
+
+
+def verify_cached_kv(kv_cache, seq_ids, expected_k, expected_v):
+    for seq_id in seq_ids:
+        keys_expected = expected_k[seq_id]
+        values_expected = expected_v[seq_id]
+        assert keys_expected.shape == values_expected.shape
+        seq_length = expected_k[seq_id].shape[1]
+        keys = tvm.nd.empty(keys_expected.shape, dtype=dtype, device=device)
+        values = tvm.nd.empty(values_expected.shape, dtype=dtype, 
device=device)
+        fdebug_get_kv(kv_cache, seq_id, 0, seq_length, keys, values)
+        tvm.testing.assert_allclose(keys.numpy(), keys_expected, rtol=1e-3, 
atol=1e-3)
+        tvm.testing.assert_allclose(values.numpy(), values_expected, 
rtol=1e-3, atol=1e-3)
+
+
+def f_apply_rotary(x, offset, scale, theta, offset_list: Optional[List[int]] = 
None):
+    # x: (N, H, D)
+    assert len(x.shape) == 3
+    nfeat = x.shape[-1]
+    nfeat_half = x.shape[-1] // 2
+    x = x.astype("float32")
+    y = np.concatenate([-x[:, :, nfeat_half:], x[:, :, :nfeat_half]], axis=-1)
+
+    inv_freq = scale / (theta ** (np.arange(0, nfeat, 2).astype("float32") / 
nfeat))
+    t = (
+        np.arange(offset, offset + x.shape[0], dtype=inv_freq.dtype)
+        if offset_list is None
+        else (np.array(offset_list, dtype=inv_freq.dtype) + offset)
+    )
+    freqs = np.einsum("i,j->ij", t, inv_freq)
+    emb = np.concatenate((freqs, freqs), axis=-1)
+    cos_values = np.cos(emb)
+    sin_values = np.sin(emb)
+
+    return np.einsum("ij,ikj->ikj", cos_values, x) + np.einsum("ij,ikj->ikj", 
sin_values, y)
+
+
+def apply_attention(
+    kv_cache,
+    rope_mode: RopeMode,
+    batch: List[Tuple[Union[int, Tuple[int, int, int]], int]],
+    cached_k: Dict[int, np.ndarray],
+    cached_v: Dict[int, np.ndarray],
+    sliding_window_sizes: Optional[List[int]] = None,
+    attn_sink_sizes: Optional[List[int]] = None,
+    token_tree_parent_ptr_list: Optional[List[List[int]]] = None,
+    accepted_leaf_indices: Optional[List[int]] = None,
+) -> None:
+    seq_ids = []
+    append_lengths = []
+    for i, (seq_id, append_length) in enumerate(batch):
+        fork_parent_id = None
+        if isinstance(seq_id, tuple):
+            # Fork sequence
+            seq_id, fork_parent_id, fork_pos = seq_id
+            batch[i] = (seq_id, append_length)
+        seq_ids.append(seq_id)
+        append_lengths.append(append_length)
+        if fork_parent_id is not None:
+            assert fork_parent_id in cached_k
+            assert seq_id not in cached_k
+            ffork_sequence(kv_cache, fork_parent_id, seq_id, fork_pos)
+            if fork_pos == -1:
+                cached_k[seq_id] = cached_k[fork_parent_id]
+                cached_v[seq_id] = cached_v[fork_parent_id]
+            else:
+                cached_k[seq_id] = cached_k[fork_parent_id][::, :fork_pos]
+                cached_v[seq_id] = cached_v[fork_parent_id][::, :fork_pos]
+        elif seq_id not in cached_k:
+            fadd_sequence(kv_cache, seq_id)
+            cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, 
head_dim), dtype)
+            cached_v[seq_id] = np.zeros((num_layers, 0, num_kv_heads, 
head_dim), dtype)
+
+    flattened_token_tree_parent_ptr = None
+    token_tree_node_depths_list: List[Optional[List[int]]] = [None for _ in 
batch]
+
+    if token_tree_parent_ptr_list:
+        assert len(token_tree_node_depths_list) == len(seq_ids)
+        if accepted_leaf_indices is not None:
+            assert len(accepted_leaf_indices) == len(seq_ids)
+        flattened_token_tree_parent_ptr = []
+        for i, (token_tree_parent_ptr, append_length) in enumerate(
+            zip(token_tree_parent_ptr_list, append_lengths)
+        ):
+            assert len(token_tree_parent_ptr) >= append_length
+            # parent pointer for the last `append_length` nodes (the new 
tokens)
+            append_token_tree_parent_ptr = 
token_tree_parent_ptr[-append_length:]
+            flattened_token_tree_parent_ptr += append_token_tree_parent_ptr
+            token_tree_node_depths = []
+            for parent in token_tree_parent_ptr:
+                token_tree_node_depths.append(
+                    0 if parent == -1 else token_tree_node_depths[parent] + 1
+                )
+            # depth of each node in the tree (this contains more than the last 
`append_length` nodes)
+            token_tree_node_depths_list[i] = token_tree_node_depths
+
+    fbegin_forward(
+        kv_cache,
+        ShapeTuple(seq_ids),
+        ShapeTuple(append_lengths),
+        (
+            ShapeTuple(flattened_token_tree_parent_ptr)
+            if flattened_token_tree_parent_ptr is not None
+            else None
+        ),
+    )
+
+    global_new_q = np.zeros((num_layers, 0, num_qo_heads, head_dim), dtype)
+    global_new_k = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype)
+    global_new_v = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype)
+
+    q_array = []
+    for i, (seq_id, append_length) in enumerate(batch):
+        new_q = np.random.rand(num_layers, append_length, num_qo_heads, 
head_dim).astype(dtype)
+        new_k = np.random.rand(num_layers, append_length, num_kv_heads, 
head_dim).astype(dtype)
+        new_v = np.random.rand(num_layers, append_length, num_kv_heads, 
head_dim).astype(dtype)
+        q_array.append(new_q)
+
+        rope_offset = cached_k[seq_id].shape[1]
+        if token_tree_parent_ptr_list is not None:
+            prev_tree_size = len(token_tree_parent_ptr_list[i]) - append_length
+            assert prev_tree_size >= 0
+            rope_offset -= prev_tree_size
+        cached_k[seq_id] = np.concatenate(
+            [
+                cached_k[seq_id],
+                np.stack(
+                    [
+                        (
+                            new_k[l]
+                            if rope_mode != RopeMode.NORMAL
+                            else f_apply_rotary(
+                                new_k[l],
+                                rope_offset,
+                                rope_scale,
+                                rope_theta,
+                                (
+                                    
token_tree_node_depths_list[i][-append_length:]
+                                    if token_tree_node_depths_list[i] is not 
None
+                                    else None
+                                ),
+                            )
+                        )
+                        for l in range(num_layers)
+                    ],
+                    axis=0,
+                ),
+            ],
+            axis=1,
+        )
+        cached_v[seq_id] = np.concatenate([cached_v[seq_id], new_v], axis=1)
+        global_new_q = np.concatenate([global_new_q, new_q], axis=1)
+        global_new_k = np.concatenate([global_new_k, new_k], axis=1)
+        global_new_v = np.concatenate([global_new_v, new_v], axis=1)
+
+    for layer_id in range(num_layers):
+        queries_np = global_new_q[layer_id]
+        keys_np = global_new_k[layer_id]
+        values_np = global_new_v[layer_id]
+        qkv = tvm.nd.array(np.concatenate([queries_np, keys_np, values_np], 
axis=1), device)
+        outputs = tvm.nd.empty(queries_np.shape, dtype, device=device)
+        fattention_with_fuse_qkv(kv_cache, layer_id, 1.0, qkv, outputs)
+
+        # Compute attention expected results.
+        outputs = np.expand_dims(outputs.numpy(), axis=0)
+        sum_length = 0
+        for i, (seq_id, append_length) in enumerate(batch):
+            assert cached_k[seq_id].shape[1] == cached_v[seq_id].shape[1] >= 
append_length
+
+            rope_offset = cached_k[seq_id].shape[1]
+            if token_tree_parent_ptr_list is not None:
+                rope_offset -= len(token_tree_parent_ptr_list[i])
+            else:
+                rope_offset -= append_length
+            q_seq = (
+                q_array[i][layer_id]
+                if rope_mode == RopeMode.NONE
+                else f_apply_rotary(
+                    q_array[i][layer_id],
+                    rope_offset,
+                    rope_scale,
+                    rope_theta,
+                    (
+                        token_tree_node_depths_list[i][-append_length:]
+                        if token_tree_node_depths_list[i] is not None
+                        else None
+                    ),
+                )
+            ).transpose(1, 0, 2)
+            k_seq = (
+                cached_k[seq_id][layer_id]
+                if rope_mode != RopeMode.INLINE
+                else f_apply_rotary(
+                    cached_k[seq_id][layer_id],
+                    0,
+                    rope_scale,
+                    rope_theta,
+                    (
+                        (
+                            list(range(rope_offset))
+                            + [depth + rope_offset for depth in 
token_tree_node_depths_list[i]]
+                        )
+                        if token_tree_node_depths_list[i] is not None
+                        else None
+                    ),
+                )
+            ).transpose(1, 2, 0)
+            v_seq = cached_v[seq_id][layer_id].transpose(1, 0, 2)
+
+            k_seq = np.repeat(k_seq, num_qo_heads // num_kv_heads, axis=0)
+            v_seq = np.repeat(v_seq, num_qo_heads // num_kv_heads, axis=0)
+            softmax_input = (q_seq.astype("float32") @ 
k_seq.astype("float32")) / np.sqrt(head_dim)
+            softmax_shape = softmax_input.shape
+            assert softmax_shape[-2] == append_length
+            length_diff = softmax_shape[-1] - softmax_shape[-2]
+            assert length_diff >= 0
+            mask = np.tril(
+                np.full_like(softmax_input, np.finfo("float32").max), 
k=length_diff
+            ) + np.triu(np.full_like(softmax_input, np.finfo("float32").min), 
k=length_diff + 1)
+            if token_tree_parent_ptr_list is not None:
+                tree_size = len(token_tree_parent_ptr_list[i])
+                tree_mask = np.full(
+                    (tree_size, tree_size), np.finfo("float32").min, 
dtype="float32"
+                )
+                for i, parent in enumerate(token_tree_parent_ptr_list[i]):
+                    if parent != -1:
+                        tree_mask[i] = tree_mask[parent]
+                    tree_mask[i, i] = np.finfo("float32").max
+                tree_mask = np.broadcast_to(tree_mask, (num_qo_heads, 
*tree_mask.shape))
+                mask[:, :, -tree_size:] = tree_mask[:, -append_length:, :]
+
+            softmax_input = np.minimum(softmax_input, mask)
+
+            results = np.expand_dims(
+                (scipy.special.softmax(softmax_input, axis=-1) @ 
v_seq.astype("float32")).transpose(
+                    1, 0, 2
+                ),
+                axis=0,
+            ).astype(dtype)
+
+            tvm.testing.assert_allclose(
+                outputs[:, sum_length : sum_length + append_length, ...],
+                results,
+                rtol=1e-3,
+                atol=1e-3,
+            )
+            sum_length += append_length
+    fend_forward(kv_cache)
+
+    if accepted_leaf_indices is not None:
+        seq_ids = [seq_id for seq_id, _ in batch]
+        fcommit_accepted_token_tree_nodes(
+            kv_cache, ShapeTuple(seq_ids), ShapeTuple(accepted_leaf_indices)
+        )
+        for i, (accepted_leaf_idx, (seq_id, append_length)) in enumerate(
+            zip(accepted_leaf_indices, batch)
+        ):
+            tree_path = []
+            node = accepted_leaf_idx
+            while node != -1:
+                tree_path.append(node)
+                node = token_tree_parent_ptr_list[i][node]
+            offset = cached_k[seq_id].shape[1] - append_length
+            length_to_pop = append_length - len(tree_path)
+            assert 0 <= length_to_pop <= append_length
+            for dst_pos, src_pos in enumerate(reversed(tree_path)):
+                if dst_pos == src_pos:
+                    continue
+                cached_k[seq_id][:, offset + dst_pos, ...] = cached_k[seq_id][
+                    :, offset + src_pos, ...
+                ]
+                cached_v[seq_id][:, offset + dst_pos, ...] = cached_v[seq_id][
+                    :, offset + src_pos, ...
+                ]
+            if length_to_pop > 0:
+                cached_k[seq_id] = cached_k[seq_id][:, :-length_to_pop, ...]
+                cached_v[seq_id] = cached_v[seq_id][:, :-length_to_pop, ...]
+
+    for seq_id, _ in batch:
+        if sliding_window_sizes is not None and len(sliding_window_sizes) > 
seq_id:
+            assert len(sliding_window_sizes) > seq_id and len(attn_sink_sizes) 
> seq_id
+            sliding_window_size = sliding_window_sizes[seq_id]
+            attn_sink_size = attn_sink_sizes[seq_id]
+            if sliding_window_size == 0:
+                continue
+            if cached_k[seq_id].shape[1] > sliding_window_size:
+                # Apply sliding window and sink to cached kv.
+                length_to_slide = cached_k[seq_id].shape[1] - 
sliding_window_size
+                cached_k[seq_id] = np.concatenate(
+                    [
+                        cached_k[seq_id][:, :attn_sink_size, ...],
+                        cached_k[seq_id][:, attn_sink_size + length_to_slide 
:, ...],
+                    ],
+                    axis=1,
+                )
+                cached_v[seq_id] = np.concatenate(
+                    [
+                        cached_v[seq_id][:, :attn_sink_size, ...],
+                        cached_v[seq_id][:, attn_sink_size + length_to_slide 
:, ...],
+                    ],
+                    axis=1,
+                )
+                assert cached_k[seq_id].shape[1] == sliding_window_size
+
+    # Verify
+    verify_cached_kv(kv_cache, seq_ids, cached_k, cached_v)
+
+
+def test_paged_attention_kv_cache_prefill_and_decode(kv_cache_and_config):
+    kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
+    if support_sliding_window and rope_mode == RopeMode.NORMAL:
+        # Normal RoPE mode under sliding window settings is not supported.
+        return
+    fclear(kv_cache)
+
+    # Prefill.
+    operation_seq = [[(0, 6)], [(1, 8)], [(2, 11)], [(3, 16)], [(4, 19), (5, 
20)]]
+    operation_seq += [[(6, 21), (7, 24)], [(2, 5), (4, 7), (8, 24)]]
+    operation_seq += [[(6, 13)], [(8, 19)], [(0, 1)], [(1, 3), (3, 8), (5, 
12), (7, 11)]]
+    # Decode
+    operation_seq += [[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1), 
(7, 1), (8, 1)]]
+    operation_seq += [[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1), 
(7, 1), (8, 1)]]
+    operation_seq += [[(0, 1), (2, 1), (4, 1), (6, 1), (8, 1)]]
+    operation_seq += [[(4, 1), (5, 1), (6, 1), (7, 1), (8, 1)]]
+
+    cached_k = {}
+    cached_v = {}
+    for batch in operation_seq:
+        apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
+
+
+def test_paged_attention_kv_cache_remove_sequence(kv_cache_and_config):
+    kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
+    if support_sliding_window and rope_mode == RopeMode.NORMAL:
+        # Normal RoPE mode under sliding window settings is not supported.
+        return
+    fclear(kv_cache)
+
+    num_sequences = 5
+    batch = [(seq_id, 1) for seq_id in range(num_sequences)]
+    cached_k = {}
+    cached_v = {}
+    for seq_id_to_remove in range(num_sequences):
+        apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
+        # Remove sequence.
+        fremove_sequence(kv_cache, seq_id_to_remove)
+        cached_k.pop(seq_id_to_remove)
+        cached_v.pop(seq_id_to_remove)
+        verify_cached_kv(
+            kv_cache,
+            seq_ids=[seq_id for seq_id in range(num_sequences) if seq_id != 
seq_id_to_remove],
+            expected_k=cached_k,
+            expected_v=cached_v,
+        )
+
+
+def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config):
+    kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
+    if support_sliding_window and rope_mode == RopeMode.NORMAL:
+        # Normal RoPE mode under sliding window settings is not supported.
+        return
+    fclear(kv_cache)
+
+    cached_k = {}
+    cached_v = {}
+    batch = [(0, 60), (1, 88), (2, 17), (3, 4)]
+    apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
+    # Fork existing sequences.
+    apply_attention(kv_cache, rope_mode, [((4, 3, -1), 35)], cached_k, 
cached_v)
+    apply_attention(kv_cache, rope_mode, [((5, 0, -1), 20)], cached_k, 
cached_v)
+    apply_attention(kv_cache, rope_mode, [((6, 5, -1), 102)], cached_k, 
cached_v)
+    apply_attention(kv_cache, rope_mode, [((7, 0, -1), 3)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((8, 5, -1), 71), ((9, 5, -1), 20)], 
cached_k, cached_v)
+    # 0 <- 5 <- 6,8,9
+    # 0 <- 7
+    # 3 <- 4
+    # Mixture of decode and prefill.
+    operation_seq = [
+        [(2, 1), (4, 1), (7, 1), (6, 1), (8, 1), (9, 1)],
+        [(7, 1), (6, 1), (8, 1), (9, 1)],
+        [(7, 1), (1, 1), (6, 1), (2, 1), (8, 1), (4, 1), (9, 1)],
+        [(7, 10), (6, 2), (8, 3), (9, 4)],
+    ]
+    for batch in operation_seq:
+        apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
+
+    apply_attention(kv_cache, rope_mode, [((10, 1, 33), 11)], cached_k, 
cached_v)
+    apply_attention(kv_cache, rope_mode, [((11, 0, 60), 45), ((12, 0, 15), 
14)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((13, 0, 16), 19), ((14, 0, 17), 
19)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((15, 5, 60), 8), ((16, 5, 80), 
10)], cached_k, cached_v)
+    apply_attention(
+        kv_cache,
+        rope_mode,
+        [((17, 5, 75), 11), ((18, 5, 76), 45), ((19, 5, 77), 14)],
+        cached_k,
+        cached_v,
+    )
+
+    operation_seq = [
+        [(6, 1), (11, 1), (13, 1), (9, 1)],
+        [(10, 1), (16, 1), (18, 1), (19, 1)],
+        [(8, 1), (15, 1), (17, 1), (12, 1), (14, 1)],
+        [(10, 10), (6, 2), (8, 3), (19, 4)],
+    ]
+    for batch in operation_seq:
+        apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
+
+    num_sequence = 20
+    for i in range(num_sequence):
+        fremove_sequence(kv_cache, i)
+        cached_k.pop(i)
+        cached_v.pop(i)
+        verify_cached_kv(
+            kv_cache,
+            seq_ids=list(range(i + 1, num_sequence)),
+            expected_k=cached_k,
+            expected_v=cached_v,
+        )
+
+    assert fis_empty(kv_cache), "The KV cache is not empty after removing all 
sequences"
+
+    # Test fork after page recycle
+    apply_attention(kv_cache, rope_mode, [(0, 7), (1, 24)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((2, 1, -1), 10)], cached_k, 
cached_v)
+    apply_attention(kv_cache, rope_mode, [((3, 0, -1), 20)], cached_k, 
cached_v)
+    apply_attention(kv_cache, rope_mode, [(2, 1), (3, 1)], cached_k, cached_v)
+
+    apply_attention(kv_cache, rope_mode, [(10, 7), (11, 24)], cached_k, 
cached_v)
+    apply_attention(kv_cache, rope_mode, [((12, 11, -1), 200)], cached_k, 
cached_v)
+    apply_attention(kv_cache, rope_mode, [(10, 1), (12, 1)], cached_k, 
cached_v)
+
+
def test_paged_attention_kv_cache_unlimited_depth(kv_cache_and_config):
    """Build a deep tree of forked sequences, decode on it, then tear it down.

    Sequence 0 is prefilled, then sequences 1..11 are each forked from the
    end of an existing sequence and extended. After three decode batches,
    every sequence is removed in id order and the cache contents of the
    remaining sequences are re-verified after each removal.
    """
    kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
    if rope_mode == RopeMode.NORMAL and support_sliding_window:
        # Normal RoPE mode under sliding window settings is not supported.
        return
    fclear(kv_cache)

    cached_k = {}
    cached_v = {}
    apply_attention(kv_cache, rope_mode, [(0, 30)], cached_k, cached_v)

    # Fork existing sequences. Each entry is (child_id, parent_id, new_tokens);
    # the fork position -1 forks from the end of the parent.
    # 0 <- 1 <- 2 <- 3 <- 5 <- 7 <- 9 <- 11
    #                |    |    |    |
    #                4    6    8    10
    fork_plan = [
        (1, 0, 15),
        (2, 1, 5),
        (3, 2, 20),
        (4, 3, 26),
        (5, 3, 18),
        (6, 5, 22),
        (7, 5, 12),
        (8, 7, 29),
        (9, 7, 9),
        (10, 9, 31),
        (11, 9, 4),
    ]
    for child_id, parent_id, new_tokens in fork_plan:
        apply_attention(
            kv_cache,
            rope_mode,
            [((child_id, parent_id, -1), new_tokens)],
            cached_k,
            cached_v,
        )

    # Decode: each batch mixes sequences from different depths of the tree.
    for decode_batch in (
        [(3, 1), (6, 1), (9, 1)],
        [(4, 1), (8, 1), (10, 1)],
        [(5, 1), (7, 1), (11, 1)],
    ):
        apply_attention(kv_cache, rope_mode, decode_batch, cached_k, cached_v)

    total_sequences = 12
    for removed_id in range(total_sequences):
        fremove_sequence(kv_cache, removed_id)
        del cached_k[removed_id]
        del cached_v[removed_id]
        verify_cached_kv(
            kv_cache,
            seq_ids=list(range(removed_id + 1, total_sequences)),
            expected_k=cached_k,
            expected_v=cached_v,
        )

    assert fis_empty(kv_cache), "The KV cache is not empty after removing all sequences"
+
+
def test_paged_attention_kv_cache_popn(kv_cache_and_config):
    """Pop trailing tokens from several sequences, then remove them all.

    Four sequences are prefilled and a fifth (id 4) is forked from the end of
    sequence 3. After every pop, the cache contents of all five sequences are
    compared with the reference K/V kept in ``cached_k`` / ``cached_v``.
    """
    kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
    if support_sliding_window and rope_mode == RopeMode.NORMAL:
        # Normal RoPE mode under sliding window settings is not supported.
        return
    fclear(kv_cache)

    cached_k = {}
    cached_v = {}
    batch = [(0, 35), (1, 88), (2, 17), (3, 4)]
    apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
    # Sequence 4 = the 4 tokens of sequence 3 + 35 new ones (39 total), so
    # popping 37 below reaches back into the part inherited from sequence 3.
    apply_attention(kv_cache, rope_mode, [((4, 3, -1), 35)], cached_k, cached_v)
    num_sequence = 5

    popn_operations = [(0, 17), (1, 57), (2, 16), (3, 0), (4, 37)]
    for seq_id, pop_length in popn_operations:
        fpopn(kv_cache, seq_id, pop_length)
        if pop_length != 0:
            # Guarded because ``x[:, :-0]`` would be an empty slice.
            cached_k[seq_id] = cached_k[seq_id][:, :-pop_length, ...]
            cached_v[seq_id] = cached_v[seq_id][:, :-pop_length, ...]
        # Check every live sequence, including the forked sequence 4.
        verify_cached_kv(
            kv_cache,
            seq_ids=list(range(num_sequence)),
            expected_k=cached_k,
            expected_v=cached_v,
        )

    for seq_id in range(num_sequence):
        fremove_sequence(kv_cache, seq_id)
        verify_cached_kv(
            kv_cache,
            seq_ids=list(range(seq_id + 1, num_sequence)),
            expected_k=cached_k,
            expected_v=cached_v,
        )

    assert fis_empty(kv_cache), "The KV cache is not empty after removing all sequences"
+
+
def test_paged_attention_kv_cache_sliding_window(kv_cache_and_config):
    """Prefill and then decode five sequences, each with its own sliding window.

    Sequence ``i`` is configured with window size ``sliding_window_sizes[i]``
    and attention-sink size ``attn_sink_sizes[i]``. The test only runs for
    sliding-window caches whose RoPE mode is not NORMAL.
    """
    kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
    if rope_mode == RopeMode.NORMAL or not support_sliding_window:
        return
    fclear(kv_cache)

    sliding_window_sizes = [20, 25, 30, 35, 40]
    attn_sink_sizes = [6, 4, 8, 3, 7]
    cached_k = {}
    cached_v = {}
    for seq_id, (window, sink) in enumerate(zip(sliding_window_sizes, attn_sink_sizes)):
        fadd_sequence(kv_cache, seq_id)
        fenable_sliding_window_for_seq(kv_cache, seq_id, window, sink)
        # Reference K/V start empty along the token axis.
        cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype)
        cached_v[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype)

    def run(step_batch):
        apply_attention(
            kv_cache,
            rope_mode,
            step_batch,
            cached_k,
            cached_v,
            sliding_window_sizes,
            attn_sink_sizes,
        )

    # Prefill in several rounds with varying chunk sizes.
    prefill_rounds = [
        [(0, 4)],
        [(1, 6)],
        [(2, 6), (3, 7), (4, 7)],
        [(0, 20), (1, 19), (2, 30), (3, 35), (4, 40)],
        [(0, 6), (1, 5), (2, 4), (3, 3), (4, 2)],
    ]
    for prefill_batch in prefill_rounds:
        run(prefill_batch)

    # Decode: 20 steps of one token for every sequence.
    decode_batch = [(0, 1), (1, 1), (2, 1), (3, 1), (4, 1)]
    for _ in range(20):
        run(decode_batch)
+
+
def test_paged_attention_kv_cache_sliding_window_fork(kv_cache_and_config):
    """Fork sliding-window sequences and keep appending to parents and children.

    Sequences 0-2 are given sliding windows. Sequences 3-5 are forked from them,
    and the reference window/sink lists record 0 for those three. Sequence 6 is
    forked from sequence 3 at position 18 and given its own window (25, sink 24).
    The ``seq_len`` comments track the per-sequence lengths after each step.
    """
    kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
    if rope_mode == RopeMode.NORMAL or not support_sliding_window:
        return
    fclear(kv_cache)

    sliding_window_sizes = [30, 35, 40]
    attn_sink_sizes = [15, 20, 25]
    cached_k = {}
    cached_v = {}
    for seq_id, (window, sink) in enumerate(zip(sliding_window_sizes, attn_sink_sizes)):
        fadd_sequence(kv_cache, seq_id)
        fenable_sliding_window_for_seq(kv_cache, seq_id, window, sink)
        cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype)
        cached_v[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype)

    def step(step_batch):
        # The window/sink lists are extended in place below, so each call
        # sees their current contents.
        apply_attention(
            kv_cache,
            rope_mode,
            step_batch,
            cached_k,
            cached_v,
            sliding_window_sizes,
            attn_sink_sizes,
        )

    step([(0, 12), (1, 18), (2, 28)])
    # seq_len: [12, 18, 25+3]
    sliding_window_sizes.extend([0, 0, 0])
    attn_sink_sizes.extend([0, 0, 0])
    step([((3, 0, 10), 8), ((4, 1, -1), 20), ((5, 2, 18), 18)])
    # seq_len: [12, 18, 25+3, 18, 38, 36]
    step([(0, 9), (1, 15), (2, 4), (3, 10), (4, 3), (5, 7)])
    # seq_len: [15+6, 20+13, 25+7, 28, 41, 43]
    sliding_window_sizes.append(25)
    attn_sink_sizes.append(24)
    ffork_sequence(kv_cache, 3, 6, 18)
    fenable_sliding_window_for_seq(kv_cache, 6, sliding_window_sizes[-1], attn_sink_sizes[-1])
    cached_k[6] = cached_k[3][:, :18]
    cached_v[6] = cached_v[3][:, :18]
    step([(3, 10), (6, 12)])
    # seq_len: [15+6, 20+13, 25+7, 38, 41, 43, 24+6]
+
+
def test_paged_attention_kv_cache_tree_attn(kv_cache_and_config):
    """Exercise tree attention with token trees of several shapes.

    Three scenarios run in turn. Each starts from a cleared cache with four
    sequences prefilled with 10, 20, 30 and 40 tokens:

    1. one round of tree attention over mixed tree shapes, then 5 decode rounds;
    2. the same, but every token tree is a chain;
    3. 5 rounds of tree decode over heap-ordered binary trees, committing
       accepted leaves only in the final round.
    """
    kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
    if support_sliding_window:
        # Tree attention is skipped entirely for sliding-window caches.
        return
    if rope_mode == RopeMode.INLINE:
        # Inline RoPE mode is not supported for tree attention.
        return
    fclear(kv_cache)

    cached_k = {}
    cached_v = {}
    # Prefill 4 sequences
    apply_attention(kv_cache, rope_mode, [(0, 10), (1, 20), (2, 30), (3, 40)], cached_k, cached_v)
    # Tree attention
    apply_attention(
        kv_cache,
        rope_mode,
        [(0, 7), (1, 15), (2, 10), (3, 14)],
        cached_k,
        cached_v,
        token_tree_parent_ptr_list=[
            [-1, 0, 0, 1, 1, 2, 2],  # complete binary tree of height 3
            [-1, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6],  # complete binary tree of height 4
            [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8],  # chain of length 10
            [-1, 0, 0, 1, 1, 2, 2, -1, 7, 7, 8, 8, 9, 9],  # two complete binary trees of height 3
        ],
        accepted_leaf_indices=[6, 11, 6, 13],
    )
    # Do 5 rounds of decode.
    for _ in range(5):
        apply_attention(kv_cache, rope_mode, [(0, 1), (1, 1), (2, 1), (3, 1)], cached_k, cached_v)

    # Test the cases where all trees are chains.
    fclear(kv_cache)
    cached_k = {}
    cached_v = {}
    # Prefill 4 sequences
    apply_attention(kv_cache, rope_mode, [(0, 10), (1, 20), (2, 30), (3, 40)], cached_k, cached_v)
    # Tree attention
    apply_attention(
        kv_cache,
        rope_mode,
        [(0, 7), (1, 15), (2, 10), (3, 14)],
        cached_k,
        cached_v,
        token_tree_parent_ptr_list=[
            [-1, 0, 1, 2, 3, 4, 5],  # chain of length 7
            [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],  # chain of length 15
            [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8],  # chain of length 10
            [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],  # chain of length 14
        ],
        accepted_leaf_indices=[2, 6, -1, 4],
    )
    # Do 5 rounds of decode.
    for _ in range(5):
        apply_attention(kv_cache, rope_mode, [(0, 1), (1, 1), (2, 1), (3, 1)], cached_k, cached_v)

    # Test the cases of tree attn with cached kv.
    fclear(kv_cache)
    cached_k = {}
    cached_v = {}
    # Prefill 4 sequences
    apply_attention(kv_cache, rope_mode, [(0, 10), (1, 20), (2, 30), (3, 40)], cached_k, cached_v)
    # Do 5 rounds of tree decode.
    num_seq = 4
    for i in range(5):
        num_leaf_nodes = 2**i
        # Heap-ordered binary-tree parent pointers (node k's parent is (k-1)//2)
        # with 2 * num_leaf_nodes - 1 nodes.
        # NOTE(review): only num_leaf_nodes tokens are appended per sequence, so
        # presumably apply_attention uses just the leading entries -- confirm.
        parent_ptr = [(k - 1) // 2 for k in range(0, 2 * num_leaf_nodes - 1)]
        apply_attention(
            kv_cache,
            rope_mode,
            [(seq_id, num_leaf_nodes) for seq_id in range(num_seq)],
            cached_k,
            cached_v,
            token_tree_parent_ptr_list=[parent_ptr for _ in range(num_seq)],
            accepted_leaf_indices=(
                None if i != 4 else [2, 6, -1, 4]
            ),  # Leaf nodes are committed all at once at the end.
        )
+
+
if __name__ == "__main__":
    # Manual driver: run every test in this file for each configuration.
    HEAD_DIMS = [64, 128]
    DTYPES = ["float16", "float32"]
    ROPE_MODES = [RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE]
    SUPPORT_SLIDING_WINDOW = [False, True]
    # NOTE: this loop runs at module level, so ``head_dim`` and ``dtype`` rebind
    # the module-level names that the test bodies read (e.g. in np.zeros).
    for head_dim, dtype, rope_mode, support_sliding_window in itertools.product(
        HEAD_DIMS, DTYPES, ROPE_MODES, SUPPORT_SLIDING_WINDOW
    ):
        set_global_func(head_dim, dtype)
        cache = create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window)
        cache_and_config = (cache, rope_mode, support_sliding_window)
        test_paged_attention_kv_cache_prefill_and_decode(cache_and_config)
        test_paged_attention_kv_cache_remove_sequence(cache_and_config)
        test_paged_attention_kv_cache_fork_sequence(cache_and_config)
        test_paged_attention_kv_cache_popn(cache_and_config)
        test_paged_attention_kv_cache_sliding_window(cache_and_config)
        # Previously defined in this file but never run by the manual driver.
        test_paged_attention_kv_cache_sliding_window_fork(cache_and_config)
        test_paged_attention_kv_cache_tree_attn(cache_and_config)
        test_paged_attention_kv_cache_unlimited_depth(cache_and_config)

Reply via email to