This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b8f64c21c5 [Builtin] Sliding window and sink support for PagedKVCache 
(#16729)
b8f64c21c5 is described below

commit b8f64c21c5cce8ce0fa00e341f1e169f1fc59891
Author: Ruihang Lai <[email protected]>
AuthorDate: Sat Mar 16 12:16:32 2024 -0400

    [Builtin] Sliding window and sink support for PagedKVCache (#16729)
    
    This PR supports sliding window attention and attention sink for
    PagedKVCache, so that PagedKVCache can back models such as Mistral.
    
    Meanwhile, this PR removes the "Attention" function (without
    fused-qkv) from the AttentionKVCache interface, since its usage is now
    completely covered by the "AttentionWithFusedQKV" function.
    Considering the cost of maintenance, we decided to remove it for now.
    If this function is needed again in the future, we will add it
    back.
    
    This PR also unifies the global function names of the PagedKVCache
    with those of the KVState introduced earlier, and introduces a new
    KV cache raw-info query function that returns the current total
    sequence length in the KV cache.
---
 src/runtime/relax_vm/kv_state.cc                   |  11 +-
 src/runtime/relax_vm/kv_state.h                    |  40 +-
 src/runtime/relax_vm/paged_kv_cache.cc             | 626 +++++++++++++--------
 ..._builtin_paged_attention_kv_cache_flashinfer.py | 110 ++--
 ...runtime_builtin_paged_attention_kv_cache_tir.py | 557 +++++++++++-------
 5 files changed, 802 insertions(+), 542 deletions(-)

diff --git a/src/runtime/relax_vm/kv_state.cc b/src/runtime/relax_vm/kv_state.cc
index 7c86e96ec6..05ba7c9650 100644
--- a/src/runtime/relax_vm/kv_state.cc
+++ b/src/runtime/relax_vm/kv_state.cc
@@ -45,19 +45,16 @@ TVM_REGISTER_GLOBAL("vm.builtin.kv_state_end_forward")
     .set_body_method<KVState>(&KVStateObj::EndForward);
 
 // Attention KV Cache methods
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_enable_sliding_window_for_seq")
+    
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::EnableSlidingWindowForSeq);
 TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_num_available_pages")
     
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::GetNumAvailablePages);
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_total_sequence_length")
+    
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::GetTotalSequenceLength);
 TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_query_positions")
     
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::GetQueryPositions);
 TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_debug_get_kv")
     .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::DebugGetKV);
-TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_attention")
-    .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
-                       double attn_score_scaling_factor, NDArray q_data, 
NDArray k_data,
-                       NDArray v_data, NDArray o_data) {
-      kv_cache->Attention(layer_id, std::move(q_data), std::move(k_data), 
std::move(v_data),
-                          NullOpt, std::move(o_data), 
attn_score_scaling_factor);
-    });
 TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_attention_with_fused_qkv")
     .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
                        double attn_score_scaling_factor, NDArray qkv_data, 
NDArray o_data) {
diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h
index 5f824a84b1..2227944b86 100644
--- a/src/runtime/relax_vm/kv_state.h
+++ b/src/runtime/relax_vm/kv_state.h
@@ -122,36 +122,22 @@ class AttentionKVCacheObj : public KVStateObj {
    */
   virtual int32_t GetNumAvailablePages() const = 0;
 
-  /************** Attention **************/
+  /*! \brief Get the current total sequence length in the KV cache. */
+  virtual int32_t GetTotalSequenceLength() const = 0;
+
+  /************** Sequence Management **************/
 
   /*!
-   * \brief Compute attention with the given Q/K/V data at the specified
-   * layer with regard to the previously reserved append lengths.
-   * Q/K/V data are in layout `(total_length, num_heads, head_dim)`,
-   * where `total_length` is the sum of reserved append lengths.
-   * The returned attention result has the same layout as well.
-   * For example, say the KV cache contains 5 sequences. Before
-   * the current model forward, BeginForward is invoked for seq_ids
-   * `[3, 2]` and append_lengths [10, 20]. Then the leading dim of Q/K/V
-   * is 30, where [0, 10) corresponds to seq 3, and [10, 30)
-   * corresponds to seq 2.
-   * This method typically performs the following operations:
-   * - apply positional embeddings to Q/K data,
-   * - append K/V data to cache,
-   * - compute attention with the given Q and all history K/V
-   * for the corresponding sequences.
-   * The function writes attention output to `o_data`, conforming to
-   * the destination-passing style.
-   * \param layer_id The model layer where the attention compute happens.
-   * \param q_data The input Q data, in layout `(total_length, num_qo_heads, 
head_dim)`.
-   * \param k_data The input K data, in layout `(total_length, num_kv_heads, 
head_dim)`.
-   * \param v_data The input V data, in layout `(total_length, num_kv_heads, 
head_dim)`.
-   * \param mask The input mask data, in layout `(total_sqr_length)`.
-   * \param o_data The output O data, in layout `(total_length, num_qo_heads, 
head_dim)`.
+   * \brief Enable sliding window attention for the given sequence.
+   * An error will be thrown when the KV cache does not support sliding window.
+   * \param seq_id The id of the sequence to enable sliding window for.
+   * \param sliding_window_size The sliding window size for the sequence.
+   * \param attn_sink_size The attention sink size for the sequence.
    */
-  virtual void Attention(int64_t layer_id, NDArray q_data, NDArray k_data, 
NDArray v_data,
-                         Optional<NDArray> mask, NDArray o_data,
-                         double attn_score_scaling_factor) = 0;
+  virtual void EnableSlidingWindowForSeq(int64_t seq_id, int32_t 
sliding_window_size,
+                                         int32_t attn_sink_size) = 0;
+
+  /************** Attention **************/
 
   /*!
    * \brief Compute attention with Q/K/V data which are concatenated along
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc 
b/src/runtime/relax_vm/paged_kv_cache.cc
index 651fd4964c..0c64800cec 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -50,6 +50,8 @@ namespace relax_vm {
 constexpr const int kPagedKVCacheMaxBlockDepth = 5;
 /*! \brief The 8MB workspace size for attention auxiliary data. */
 constexpr const int kAttnWorkspaceByte = 8 * 1024 * 1024;
+/*! \brief The id of the temporary logical page, which is useful for sliding 
window. */
+constexpr const int kPagedKVCacheTempPageId = -1;
 
 /*!
  * \brief The block structure in paged KV cache with common prefix support.
@@ -72,8 +74,22 @@ struct Block {
   std::vector<int32_t> page_ids;
   /*! \brief The total sequence length in the block. */
   int32_t seq_length = 0;
-  /*! \brief The start position in sequence of this block. */
+  /*!
+   * \brief The start position in sequence of this block.
+   * This is the absolute position in the sequence for RoPE computation.
+   */
   int32_t start_pos = 0;
+  /*!
+   * \brief The current attention sink length of the block.
+   * It means the **first** sink size elements will be pinned
+   * in the KV cache even when sliding window is enabled.
+   */
+  int32_t sink_length = 0;
+  /*!
+   * \brief The start offset of the sliding window in the block.
+   * It is always 0 when sliding window attn is not enabled.
+   */
+  int32_t sliding_window_offset = 0;
 
   /*! \brief The global index of the block. */
   const int32_t index;
@@ -115,6 +131,17 @@ struct Sequence {
    * It is the sum of lengths of all its blocks.
    */
   int32_t seq_length = 0;
+  /*!
+   * \brief The sliding window size of the sequence, or -1 if sliding window 
is not enabled.
+   * When a sequence is enabled for sliding window, it can no longer be forked.
+   */
+  int sliding_window_size = -1;
+  /*!
+   * \brief The attention sink size of the last block of the sequence.
+   * The **first** sink size elements of the last block will be pinned
+   * in the KV cache even when sliding window is enabled.
+   */
+  int last_block_attn_sink_size = 0;
 
   explicit Sequence(const std::vector<Block>& global_block_pool, int32_t 
last_block_idx) {
     this->last_block_idx = last_block_idx;
@@ -201,6 +228,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj 
{
   const int64_t num_total_pages_;
   /*! \brief The maximum total sequence length in a prefill. */
   const int64_t prefill_chunk_size_;
+  /*! \brief A boolean flag indicating if the KV cache supports sliding 
window. */
+  const bool support_sliding_window_;
 
   /*! \brief The RoPE application mode of KV cache.*/
   const RoPEMode rope_mode_;
@@ -255,8 +284,17 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
   std::vector<NDArray> page_indptr_on_depths_device_;
   /*! \brief The indices array of page table. */
   std::vector<NDArray> page_indices_on_depths_device_;
-  /*! \brief The number of KV slots used in the last page of sequences. */
-  std::vector<NDArray> last_page_len_on_depths_device_;
+  /*!
+   * \brief The length information of the sequences.
+   * Each NDArray is in shape `(3, n)`. "n" is the number of sequences.
+   * For a sequence "i", location
+   * - "(0, i)" is the number of KV slots used in the last page of the seq 
("last_page_len"),
+   * - "(1, i)" is the starting offset of the sliding window in the seq,
+   * - "(2, i)" is the attn sink length of the sequence.
+   * \note When sliding window is not enabled, only the
+   * "last_page_len" (a.k.a., the first "n" elements) will be effectively used.
+   */
+  std::vector<NDArray> length_info_on_depths_device_;
   /*! \brief The k position offset of applying RoPE for each sequence. */
   std::vector<NDArray> k_rope_pos_offset_device_;
   /*!
@@ -293,6 +331,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj 
{
   std::vector<std::vector<int32_t>> page_indptr_on_depths_host_;
   std::vector<std::vector<int32_t>> page_indices_on_depths_host_;
   std::vector<std::vector<int32_t>> last_page_len_on_depths_host_;
+  std::vector<std::vector<int32_t>> sliding_window_offset_on_depths_host_;
+  std::vector<std::vector<int32_t>> sink_size_on_depths_host_;
   std::vector<std::vector<int32_t>> k_rope_pos_offset_on_depths_host_;
   std::vector<int32_t> k_ragged_rope_pos_offset_host_;
   std::vector<int32_t> q_rope_position_map_host_;
@@ -316,22 +356,23 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
   std::vector<NDArray> qo_indptr_on_depths_view_;
   std::vector<NDArray> page_indptr_on_depths_view_;
   std::vector<NDArray> page_indices_on_depths_view_;
-  std::vector<NDArray> last_page_len_on_depths_view_;
+  std::vector<NDArray> length_info_on_depths_view_;
   std::vector<NDArray> k_rope_pos_offset_view_;
 
   PackedFunc f_transpose_append_;
   PackedFunc f_attention_prefill_;
   PackedFunc f_attention_decode_;
-  Optional<PackedFunc> f_attention_prefill_ragged_;
+  PackedFunc f_attention_prefill_sliding_window_;
+  PackedFunc f_attention_decode_sliding_window_;
+  PackedFunc f_attention_prefill_ragged_;
   Optional<PackedFunc> f_attention_prefill_ragged_begin_forward_;
   Optional<PackedFunc> f_attention_prefill_ragged_end_forward_;
   Optional<PackedFunc> f_attention_prefill_begin_forward_;
   Optional<PackedFunc> f_attention_prefill_end_forward_;
   Optional<PackedFunc> f_attention_decode_begin_forward_;
   Optional<PackedFunc> f_attention_decode_end_forward_;
-  Optional<PackedFunc> f_merge_inplace_;
+  PackedFunc f_merge_inplace_;
   PackedFunc f_split_rotary_;
-  PackedFunc f_rotary_inplace_;
   Optional<PackedFunc> f_debug_get_kv_;
 
   /*! \brief Number of fork depth in the current round of forward. */
@@ -354,18 +395,19 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
   explicit PagedAttentionKVCacheObj(
       int64_t page_size,  //
       int64_t num_layers, int64_t num_qo_heads, int64_t num_kv_heads, int64_t 
head_dim,
-      int64_t reserved_num_seqs, int64_t num_total_pages, int64_t 
prefill_chunk_size,  //
-      RoPEMode rope_mode, double rotary_scale, double rotary_theta,            
        //
+      int64_t reserved_num_seqs, int64_t num_total_pages, int64_t 
prefill_chunk_size,
+      bool support_sliding_window, RoPEMode rope_mode, double rotary_scale, 
double rotary_theta,
       DLDataType dtype, DLDevice device, PackedFunc f_transpose_append,
       PackedFunc f_attention_prefill, PackedFunc f_attention_decode,
-      Optional<PackedFunc> f_attention_prefill_ragged,
+      PackedFunc f_attention_prefill_sliding_window, PackedFunc 
f_attention_decode_sliding_window,
+      PackedFunc f_attention_prefill_ragged,
       Optional<PackedFunc> f_attention_prefill_ragged_begin_forward,
       Optional<PackedFunc> f_attention_prefill_ragged_end_forward,
       Optional<PackedFunc> f_attention_prefill_begin_forward,
       Optional<PackedFunc> f_attention_prefill_end_forward,
       Optional<PackedFunc> f_attention_decode_begin_forward,
-      Optional<PackedFunc> f_attention_decode_end_forward, 
Optional<PackedFunc> f_merge_inplace,
-      PackedFunc f_split_rotary, PackedFunc f_rotary_inplace, 
Optional<PackedFunc> f_debug_get_kv)
+      Optional<PackedFunc> f_attention_decode_end_forward, PackedFunc 
f_merge_inplace,
+      PackedFunc f_split_rotary, Optional<PackedFunc> f_debug_get_kv)
       : page_size_(page_size),
         num_layers_(num_layers),
         num_qo_heads_(num_qo_heads),
@@ -373,12 +415,16 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
         head_dim_(head_dim),
         num_total_pages_(num_total_pages),
         prefill_chunk_size_(prefill_chunk_size),
-        rope_mode_(rope_mode),
+        support_sliding_window_(support_sliding_window),
+        rope_mode_(support_sliding_window && rope_mode != RoPEMode::kNone ? 
RoPEMode::kInline
+                                                                          : 
rope_mode),
         rotary_scale_(rotary_scale),
         rotary_theta_(rotary_theta),
         f_transpose_append_(std::move(f_transpose_append)),
         f_attention_prefill_(std::move(f_attention_prefill)),
         f_attention_decode_(std::move(f_attention_decode)),
+        
f_attention_prefill_sliding_window_(std::move(f_attention_prefill_sliding_window)),
+        
f_attention_decode_sliding_window_(std::move(f_attention_decode_sliding_window)),
         f_attention_prefill_ragged_(std::move(f_attention_prefill_ragged)),
         f_attention_prefill_ragged_begin_forward_(
             std::move(f_attention_prefill_ragged_begin_forward)),
@@ -389,7 +435,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj 
{
         
f_attention_decode_end_forward_(std::move(f_attention_decode_end_forward)),
         f_merge_inplace_(std::move(f_merge_inplace)),
         f_split_rotary_(std::move(f_split_rotary)),
-        f_rotary_inplace_(std::move(f_rotary_inplace)),
         f_debug_get_kv_(std::move(f_debug_get_kv)),
         device_(device) {
     pages_.reserve(num_layers);
@@ -404,15 +449,15 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
           NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device));
       page_indices_on_depths_device_.push_back(
           NDArray::Empty({num_total_pages}, dtype_aux_, device));
-      last_page_len_on_depths_device_.push_back(
-          NDArray::Empty({reserved_num_seqs}, dtype_aux_, device));
+      length_info_on_depths_device_.push_back(
+          NDArray::Empty({3, reserved_num_seqs}, dtype_aux_, device));
       k_rope_pos_offset_device_.push_back(NDArray::Empty({reserved_num_seqs}, 
dtype_aux_, device));
       temp_attn_workspace_.push_back(
           NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), 
device));
       qo_indptr_on_depths_view_.push_back(NDArray());
       page_indptr_on_depths_view_.push_back(NDArray());
       page_indices_on_depths_view_.push_back(NDArray());
-      last_page_len_on_depths_view_.push_back(NDArray());
+      length_info_on_depths_view_.push_back(NDArray());
       k_rope_pos_offset_view_.push_back(NDArray());
     }
     // Additional workspace for the "prefill with ragged kv" kernel.
@@ -508,8 +553,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj 
{
         << "The parent sequence \"" << parent_seq_id << "\" cannot be found in 
KV cache.";
     CHECK(seq_map_.find(child_seq_id) == seq_map_.end())
         << "The child sequence \"" << child_seq_id << "\" is already in the KV 
cache.";
-    CHECK(f_merge_inplace_.defined() && f_attention_prefill_ragged_.defined())
-        << "Attention merge-score function not available. ForkSequence is 
thereby not supported.";
+    CHECK_EQ(parent_it->second.sliding_window_size, -1)
+        << "The parent sequence \"" << parent_seq_id
+        << "\" is enabled with sliding window and thus cannot be forked.";
 
     int32_t parent_block_idx = parent_it->second.last_block_idx;
     ++global_block_pool_[parent_block_idx].external_ref_cnt;
@@ -522,6 +568,33 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     dirty_aux_data_device_ = true;
   }
 
+  void EnableSlidingWindowForSeq(int64_t seq_id, int32_t sliding_window_size,
+                                 int32_t attn_sink_size) final {
+    CHECK(support_sliding_window_) << "The KV cache does not support sliding 
window.";
+    auto it = seq_map_.find(seq_id);
+    CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot 
be found in KV cache.";
+    CHECK_GE(attn_sink_size, 0)
+        << "The specified attention sink size is expected to be non negative";
+    CHECK_GT(sliding_window_size, 0) << "The specified sliding window size 
should be positive.";
+    CHECK_LT(attn_sink_size, sliding_window_size)
+        << "The attn sink size should be less than the sliding window size.";
+
+    // Set the sliding window flag of the sequence.
+    CHECK_EQ(it->second.sliding_window_size, -1)
+        << "A sequence cannot be enabled twice for sliding window.";
+
+    // Compute the total length of the prefix blocks of this sequence.
+    Block& last_block = global_block_pool_[it->second.last_block_idx];
+    int32_t prefix_length = it->second.seq_length - last_block.seq_length;
+    ICHECK_GE(prefix_length, 0);
+    // Since the prefix blocks cannot slide, they are natural
+    // attention sinks here. When the prefix length is already
+    // larger than the specified attn sink size, we do not want to
+    // introduce more sink. Therefore, we update the given attn sink size.
+    it->second.last_block_attn_sink_size = std::max(attn_sink_size - 
prefix_length, 0);
+    it->second.sliding_window_size = sliding_window_size;
+  }
+
   void PopN(int64_t seq_id, int32_t n) final {
     auto it = seq_map_.find(seq_id);
     CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot 
be found in KV cache.";
@@ -546,7 +619,15 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
 
   /************** Raw Info Query **************/
 
-  int GetNumAvailablePages() const final { return free_page_ids_.size(); }
+  int32_t GetNumAvailablePages() const final { return free_page_ids_.size(); }
+
+  int32_t GetTotalSequenceLength() const final {
+    int32_t total_seq_len = 0;
+    for (const auto& it : seq_map_) {
+      total_seq_len += it.second.seq_length;
+    }
+    return total_seq_len;
+  }
 
   /************** Attention **************/
 
@@ -558,15 +639,19 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     cur_append_lengths_ = append_lengths;
 
     // - Collect sequence/block/page information for attention.
-    std::vector<const Sequence*> sequences;
+    std::vector<Sequence*> sequences;
+    std::vector<int32_t> last_block_length_before_append;
     is_decode_request_ = true;
     sequences.reserve(cur_batch_size_);
+    last_block_length_before_append.reserve(cur_batch_size_);
     k_ragged_rope_pos_offset_host_.resize(cur_batch_size_);
     for (int i = 0; i < cur_batch_size_; ++i) {
       auto it = seq_map_.find(seq_ids[i]);
       CHECK(it != seq_map_.end()) << "The sequence \"" << seq_ids[i]
                                   << "\" cannot be found in KV cache.";
       sequences.push_back(&it->second);
+      last_block_length_before_append.push_back(
+          global_block_pool_[it->second.last_block_idx].seq_length);
       k_ragged_rope_pos_offset_host_[i] = it->second.seq_length;
       it->second.seq_length += append_lengths[i];
       if (append_lengths[i] != 1) {
@@ -587,13 +672,13 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
       use_decode_kernel_.push_back(use_decode_kernel);
     }
 
-    append_before_attn_ = num_depths_ == 1 && use_decode_kernel_[0];
+    append_before_attn_ = !support_sliding_window_ && num_depths_ == 1 && 
use_decode_kernel_[0];
     if (append_before_attn_) {
       // Right now we use different kernels when depth is 1 or not 1.
       // For the case where maximum depth is 1, we create the auxiliary
       // data structure with regard to the page table after appending.
       for (int i = 0; i < cur_batch_size_; ++i) {
-        ReserveAppendLengthInBlock(sequences[i]->last_block_idx, 
append_lengths[i]);
+        ReserveAppendLengthInSeq(sequences[i], append_lengths[i]);
       }
     }
 
@@ -601,6 +686,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj 
{
     page_indptr_on_depths_host_.resize(num_depths_);
     page_indices_on_depths_host_.resize(num_depths_);
     last_page_len_on_depths_host_.resize(num_depths_);
+    sliding_window_offset_on_depths_host_.resize(num_depths_);
+    sink_size_on_depths_host_.resize(num_depths_);
     k_rope_pos_offset_on_depths_host_.resize(num_depths_);
 
     for (int d = 0; d < num_depths_; ++d) {
@@ -608,11 +695,15 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
       std::vector<int32_t>& page_indptr_h = page_indptr_on_depths_host_[d];
       std::vector<int32_t>& page_indices_h = page_indices_on_depths_host_[d];
       std::vector<int32_t>& last_page_len_h = last_page_len_on_depths_host_[d];
+      std::vector<int32_t>& sliding_window_offset_h = 
sliding_window_offset_on_depths_host_[d];
+      std::vector<int32_t>& sink_size_h = sink_size_on_depths_host_[d];
       std::vector<int32_t>& k_rope_pos_offset_h = 
k_rope_pos_offset_on_depths_host_[d];
       qo_indptr_h.clear();
       page_indptr_h.clear();
       page_indices_h.clear();
       last_page_len_h.clear();
+      sliding_window_offset_h.clear();
+      sink_size_h.clear();
       k_rope_pos_offset_h.clear();
       qo_indptr_h.push_back(0);
       page_indptr_h.push_back(0);
@@ -621,13 +712,20 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
         if (block_id == -1) {
           page_indptr_h.push_back(page_indptr_h.back());
           last_page_len_h.push_back(0);
+          sliding_window_offset_h.push_back(0);
+          sink_size_h.push_back(0);
           k_rope_pos_offset_h.push_back(0);
         } else {
           const Block& block = global_block_pool_[block_id];
           page_indptr_h.push_back(page_indptr_h.back() + 
block.page_ids.size());
           page_indices_h.insert(page_indices_h.end(), block.page_ids.begin(), 
block.page_ids.end());
-          last_page_len_h.push_back(
-              block.seq_length == 0 ? 0 : (block.seq_length - 1) % page_size_ 
+ 1);
+          last_page_len_h.push_back(block.seq_length == 0 ? 0
+                                                          : (block.seq_length 
- block.sink_length +
+                                                             
block.sliding_window_offset - 1) %
+                                                                    page_size_ 
+
+                                                                1);
+          sliding_window_offset_h.push_back(block.sliding_window_offset);
+          sink_size_h.push_back(block.sink_length);
           k_rope_pos_offset_h.push_back(block.start_pos);
         }
       }
@@ -638,7 +736,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj 
{
       // For the case where maximum depth is not 1, we create the auxiliary
       // data structure with regard to the page table before appending.
       for (int i = 0; i < cur_batch_size_; ++i) {
-        ReserveAppendLengthInBlock(sequences[i]->last_block_idx, 
append_lengths[i]);
+        ReserveAppendLengthInSeq(sequences[i], append_lengths[i]);
       }
     }
 
@@ -650,10 +748,26 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
       int64_t append_length = append_lengths[i];
       const Block& block = global_block_pool_[sequences[i]->last_block_idx];
       for (int64_t pos = 0; pos < append_length; ++pos) {
-        int64_t pos_in_block = block.seq_length - append_length + pos;
-        q_rope_position_map_host_.push_back(sequences[i]->seq_length - 
append_length + pos);
-        append_position_map_host_.push_back(block.page_ids[pos_in_block / 
page_size_] * page_size_ +
-                                            pos_in_block % page_size_);
+        q_rope_position_map_host_.push_back(k_ragged_rope_pos_offset_host_[i] 
+ pos);
+
+        int32_t pos_in_block = block.seq_length - append_length + pos;
+        if (last_block_length_before_append[i] + pos < block.sink_length) {
+          // The location to write is part of the attention sink.
+          int32_t offset_in_block = last_block_length_before_append[i] + pos;
+          append_position_map_host_.push_back(block.page_ids[offset_in_block / 
page_size_] *
+                                                  page_size_ +
+                                              offset_in_block % page_size_);
+        } else if (pos_in_block < block.sink_length) {
+          // The location to write is pinned by attn sink before the append.
+          // Therefore we cannot write into the location.
+          append_position_map_host_.push_back(-1);
+        } else {
+          // The location to write is in the sliding window.
+          int32_t offset_in_block = pos_in_block - block.sink_length + 
block.sliding_window_offset;
+          append_position_map_host_.push_back(block.page_ids[offset_in_block / 
page_size_] *
+                                                  page_size_ +
+                                              offset_in_block % page_size_);
+        }
       }
     }
   }
@@ -670,60 +784,6 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     }
   }
 
-  void Attention(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray 
v_data,
-                 Optional<NDArray> mask, NDArray o_data, double 
attn_score_scaling_factor) final {
-    // Part 1. Shape and dtype check.
-    NDArray pages = pages_[layer_id];
-    CHECK(q_data.DataType() == pages.DataType());
-    CHECK(k_data.DataType() == pages.DataType());
-    CHECK(v_data.DataType() == pages.DataType());
-    CHECK(o_data.DataType() == pages.DataType());
-
-    // q/o_data: (num_total_length, num_qo_heads, head_dim)
-    // k/v_data: (num_total_length, num_kv_heads, head_dim)
-
-    CHECK_EQ(q_data->ndim, 3);
-    CHECK_EQ(k_data->ndim, 3);
-    CHECK_EQ(v_data->ndim, 3);
-    CHECK_EQ(o_data->ndim, 3);
-    for (int dim = 0; dim < 3; ++dim) {
-      if (dim == 1) {
-        CHECK_EQ(q_data->shape[1], num_qo_heads_);
-        CHECK_EQ(k_data->shape[1], num_kv_heads_);
-        CHECK_EQ(v_data->shape[1], num_kv_heads_);
-        CHECK_EQ(o_data->shape[1], num_qo_heads_);
-      } else {
-        CHECK_EQ(k_data->shape[dim], q_data->shape[dim]);
-        CHECK_EQ(v_data->shape[dim], q_data->shape[dim]);
-        CHECK_EQ(o_data->shape[dim], q_data->shape[dim]);
-      }
-    }
-
-    CHECK_GT(q_data->shape[0], 0);
-    CHECK_EQ(q_data->shape[2], head_dim_);
-    int64_t total_seq_length = 0;
-    for (int64_t seq_id = 0; seq_id < cur_batch_size_; ++seq_id) {
-      total_seq_length += cur_append_lengths_[seq_id];
-    }
-    CHECK_EQ(total_seq_length, q_data->shape[0]);
-    // Sync the copy stream and the compute stream.
-    ComputeStreamWaitForCopyStream();
-    // The auxiliary data structure on device must have been synchronized.
-    ICHECK(!dirty_aux_data_device_);
-
-    if (rope_mode_ == RoPEMode::kNormal) {
-      // Apply rotary embedding to q/k data.
-      f_rotary_inplace_(q_data, k_data, cur_append_length_indptr_view_,
-                        k_ragged_rope_pos_offset_view_, cur_batch_size_, 
num_qo_heads_,
-                        num_kv_heads_, head_dim_, rotary_scale_, 
rotary_theta_);
-    }
-
-    // Part 3: append k/v data to kv-cache
-    f_transpose_append_(pages_[layer_id], k_data, v_data, 
append_position_map_view_);
-    // Part 4: perform attention
-    AttentionInternal(layer_id, q_data, k_data, v_data, o_data, 
attn_score_scaling_factor);
-  }
-
   void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, 
Optional<NDArray> mask,
                              NDArray o_data, double attn_score_scaling_factor) 
final {
     // Part 1. Shape and dtype check.
@@ -766,10 +826,16 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     f_split_rotary_(qkv_data, q_rope_position_map_view_, q_data, k_data, 
v_data,
                     rope_mode_ == RoPEMode::kNormal);
 
-    // Part 3: append k/v data to kv-cache
-    f_transpose_append_(pages_[layer_id], k_data, v_data, 
append_position_map_view_);
+    // Part 3. Append k/v data to kv-cache if flag "append_before_attn" is set.
+    if (append_before_attn_) {
+      f_transpose_append_(pages_[layer_id], k_data, v_data, 
append_position_map_view_);
+    }
     // Part 4: perform attention
     AttentionInternal(layer_id, q_data, k_data, v_data, o_data, 
attn_score_scaling_factor);
+    // Part 5. Append k/v data to kv-cache if flag "append_before_attn" is not 
set.
+    if (!append_before_attn_) {
+      f_transpose_append_(pages_[layer_id], k_data, v_data, 
append_position_map_view_);
+    }
   }
 
   NDArray GetQueryPositions() const final {
@@ -811,13 +877,12 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     append_position_map.reserve(seq.seq_length);
     for (int32_t block_id : trace) {
       const Block& block = global_block_pool_[block_id];
-      for (int i = 0; i < static_cast<int>(block.page_ids.size()); ++i) {
-        int32_t page_offset = i != static_cast<int>(block.page_ids.size()) - 1
-                                  ? page_size_
-                                  : ((block.seq_length - 1) % page_size_ + 1);
-        for (int32_t p = 0; p < page_offset; ++p) {
-          append_position_map.push_back(block.page_ids[i] * page_size_ + p);
-        }
+      for (int i = 0; i < block.seq_length; ++i) {
+        int32_t offset =
+            i < block.sink_length ? i : i - block.sink_length + 
block.sliding_window_offset;
+        int page_id = block.page_ids[offset / page_size_];
+        int page_offset = offset % page_size_;
+        append_position_map.push_back(page_id * page_size_ + page_offset);
       }
     }
     NDArray position_map_device =
@@ -864,30 +929,116 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
   }
 
   /*!
-   * \brief Reserve extra append length in the given block, as
-   * preparation of the incoming KV cache append.
+   * \brief Slide the KV cache window of the given sequence when it has
+   * sliding window enabled and its length exceeds the window size.
+   * \param seq The sequence whose window is slid; it is updated in place.
+   */
+  void SlideWindowForSequence(Sequence* seq) {
+    // - No action when the sequence is not enabled for sliding window.
+    if (seq->sliding_window_size == -1) {
+      return;
+    }
+    // - No action when the sequence length does not exceed the window size.
+    if (seq->seq_length <= seq->sliding_window_size) {
+      return;
+    }
+
+    int32_t length_to_slide = seq->seq_length - seq->sliding_window_size;
+    // - Get the last block of the sequence.
+    Block& block = global_block_pool_[seq->last_block_idx];
+
+    // - If the attention sink exists and the last block has no previous
+    // sink length, it means this is the first time we slide the sequence,
+    // so we set the sink length of the last block and start the sliding
+    // window right after the sink (window offset = sink size).
+    if (seq->last_block_attn_sink_size > 0 && block.sink_length == 0) {
+      ICHECK_EQ(block.sliding_window_offset, 0);
+      block.sink_length = seq->last_block_attn_sink_size;
+      block.sliding_window_offset = seq->last_block_attn_sink_size;
+    }
+
+    // - Sink pages are never slid out, so freeing starts right after them.
+    int32_t num_sink_pages = (block.sink_length + page_size_ - 1) / page_size_;
+
+    // - Compute the first sliding page index and in-page sliding window
+    // start offset in the first sliding page after sliding.
+    int32_t page_idx_after_sliding = (block.sliding_window_offset + 
length_to_slide) / page_size_;
+    int32_t page_start_offset_after_sliding =
+        (block.sliding_window_offset + length_to_slide) % page_size_;
+
+    // - Free the pages fully slid out (temp pages are dropped, not freed).
+    while (page_idx_after_sliding > num_sink_pages) {
+      if (block.page_ids[num_sink_pages] != kPagedKVCacheTempPageId) {
+        free_page_ids_.push_back(block.page_ids[num_sink_pages]);
+      }
+      block.page_ids.erase(block.page_ids.begin() + num_sink_pages);
+      --page_idx_after_sliding;
+    }
+    // - The first sliding page after sliding is either the last sink page,
+    // or the page next to the last sink page.
+    ICHECK(page_idx_after_sliding == num_sink_pages - 1 ||
+           page_idx_after_sliding == num_sink_pages);
+
+    // - Update the lengths and the window offset, then check invariants.
+    seq->seq_length = seq->sliding_window_size;
+    block.seq_length -= length_to_slide;
+    block.sliding_window_offset =
+        page_idx_after_sliding * page_size_ + page_start_offset_after_sliding;
+    ICHECK_GE(block.seq_length, block.sink_length);
+    ICHECK_GE(block.sliding_window_offset, block.sink_length);
+    ICHECK_EQ(
+        (block.sliding_window_offset + (block.seq_length - block.sink_length) 
+ page_size_ - 1) /
+            page_size_,
+        block.page_ids.size());
+  }
+
+  /*!
+   * \brief Reserve extra append length in the last block of the given
+   * sequence, as preparation of the incoming KV cache append.
    * New pages will be allocated to the block until the total
    * capacity can cover the current sequence length (before reservation)
    * plus the required append length.
    * \param block_idx The index of the block to process.
    * \param append_length The extra append length to reserve for the block.
+   * \note Takes `seq` (block = seq->last_block_idx); also slides the window.
    */
-  void ReserveAppendLengthInBlock(int32_t block_idx, int64_t append_length) {
+  void ReserveAppendLengthInSeq(Sequence* seq, int64_t append_length) {
+    int32_t block_idx = seq->last_block_idx;
     Block& block = global_block_pool_[block_idx];
     CHECK_GT(append_length, 0) << "Append with length 0 is not allowed.";
     CHECK_EQ(block.external_ref_cnt, 0)
         << "The block is " << block.external_ref_cnt
         << "-time referenced by other blocks, thus cannot accept new KV 
values.";
 
+    // ==================== Reserve ====================
     // The reservation is based on the current sequence length.
     // If "current sequence + append length" does not exceed the
     // current capacity (number of pages * page size), no action is taken.
     int64_t cur_npage = block.page_ids.size();
-    int64_t tgt_npage = (block.seq_length + append_length + page_size_ - 1) / 
page_size_;
+    int64_t tgt_npage = (block.seq_length - block.sink_length + 
block.sliding_window_offset +
+                         append_length + page_size_ - 1) /
+                        page_size_;
     for (int64_t page_idx = cur_npage; page_idx < tgt_npage; ++page_idx) {
-      block.page_ids.push_back(GetFreePage());
+      // When sliding window is enabled for the seq, we can "borrow temporary 
pages (-1)",
+      // because the pages that will be slid out may not have been released yet.
+      if (free_page_ids_.empty() && seq->sliding_window_size != -1) {
+        block.page_ids.push_back(kPagedKVCacheTempPageId);
+      } else {
+        block.page_ids.push_back(GetFreePage());
+      }
     }
     block.seq_length += append_length;
+
+    // ==================== Slide ====================
+    // Slide the sequence so that the pages beyond the sliding window are 
released.
+    SlideWindowForSequence(seq);
+    for (int i = 0; i < static_cast<int>(block.page_ids.size()); ++i) {
+      if (block.page_ids[i] == kPagedKVCacheTempPageId) {
+        // Re-allocate the temporary pages after sliding window release.
+        block.page_ids[i] = GetFreePage();
+      }
+    }
+
     dirty_aux_data_device_ = true;
   }
 
@@ -901,7 +1052,7 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
    * vectors from the lowest depth to the highest depth.
    */
   std::vector<std::vector<int32_t>> GetBlockIdsOnDepth(
-      const std::vector<const Sequence*>& sequences) const {
+      const std::vector<Sequence*>& sequences) const {
     // - Get the trace of each sequence.
     int64_t num_depths = 0;
     std::vector<std::vector<int32_t>> seq_block_traces;
@@ -987,14 +1138,19 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     }
 
     if (append_before_attn_) {
-      f_attention_decode_begin_forward_.value()(
-          /*depth=*/0, temp_attn_workspace_[1], page_indptr_on_depths_view_[0],
-          last_page_len_on_depths_view_[0], num_qo_heads_, num_kv_heads_, 
head_dim_, page_size_,
-          /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_);
+      if (!support_sliding_window_) {
+        f_attention_decode_begin_forward_.value()(
+            /*depth=*/0, temp_attn_workspace_[1], 
page_indptr_on_depths_view_[0],
+            length_info_on_depths_view_[0], num_qo_heads_, num_kv_heads_, 
head_dim_, page_size_,
+            /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_);
+      }
     } else {
       f_attention_prefill_ragged_begin_forward_.value()(
           temp_attn_workspace_[0], cur_append_length_indptr_view_, 
cur_batch_size_, num_qo_heads_,
           num_kv_heads_, head_dim_, copy_stream_);
+      if (support_sliding_window_) {
+        return;
+      }
       for (int d = 0; d < num_depths_; ++d) {
         if (page_indices_on_depths_view_[d]->shape[0] == 0) {
           continue;
@@ -1002,12 +1158,12 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
         if (use_decode_kernel_[d]) {
           f_attention_decode_begin_forward_.value()(
               d, temp_attn_workspace_[d + 1], page_indptr_on_depths_view_[d],
-              last_page_len_on_depths_view_[d], num_qo_heads_, num_kv_heads_, 
head_dim_, page_size_,
+              length_info_on_depths_view_[d], num_qo_heads_, num_kv_heads_, 
head_dim_, page_size_,
               /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_);
         } else {
           f_attention_prefill_begin_forward_.value()(
               /*depth=*/d, temp_attn_workspace_[d + 1], 
qo_indptr_on_depths_view_[d],
-              last_page_len_on_depths_view_[d]->shape[0], num_qo_heads_, 
num_kv_heads_, head_dim_,
+              length_info_on_depths_view_[d]->shape[0], num_qo_heads_, 
num_kv_heads_, head_dim_,
               copy_stream_);
         }
       }
@@ -1020,23 +1176,26 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
    */
   void AttentionInternal(int64_t layer_id, NDArray q_data, NDArray k_data, 
NDArray v_data,
                          NDArray output, double attn_score_scaling_factor) {
+    PackedFunc f_prefill =
+        !support_sliding_window_ ? f_attention_prefill_ : 
f_attention_prefill_sliding_window_;
+    PackedFunc f_decode =
+        !support_sliding_window_ ? f_attention_decode_ : 
f_attention_decode_sliding_window_;
     CHECK_GE(num_depths_, 1) << "The number of effective depths must be 
greater or equal to 1.";
     if (append_before_attn_) {
-      f_attention_decode_(
+      f_decode(
           /*depth=*/0, q_data, pages_[layer_id], 
page_indptr_on_depths_view_[0],
-          page_indices_on_depths_view_[0], last_page_len_on_depths_view_[0],
+          page_indices_on_depths_view_[0], length_info_on_depths_view_[0],
           k_rope_pos_offset_view_[0], q_rope_position_map_view_, output, 
merged_attn_scores_view_,
           /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, 
rotary_theta_,
           attn_score_scaling_factor);
     } else {
       // Compute appended text self-attention
-      f_attention_prefill_ragged_.value()(q_data, 
cur_append_length_indptr_view_, k_data, v_data,
-                                          cur_append_length_indptr_view_, 
q_rope_position_map_view_,
-                                          k_ragged_rope_pos_offset_view_, 
output,
-                                          merged_attn_scores_view_,
-                                          /*causal=*/1,
-                                          /*rotary_mode=*/rope_mode_ == 
RoPEMode::kInline,
-                                          rotary_scale_, rotary_theta_, 
attn_score_scaling_factor);
+      f_attention_prefill_ragged_(q_data, cur_append_length_indptr_view_, 
k_data, v_data,
+                                  cur_append_length_indptr_view_, 
q_rope_position_map_view_,
+                                  k_ragged_rope_pos_offset_view_, output, 
merged_attn_scores_view_,
+                                  /*causal=*/1,
+                                  /*rotary_mode=*/rope_mode_ == 
RoPEMode::kInline, rotary_scale_,
+                                  rotary_theta_, attn_score_scaling_factor);
 
       for (int d = 0; d < num_depths_; ++d) {
         if (page_indices_on_depths_view_[d]->shape[0] == 0) {
@@ -1044,25 +1203,25 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
         }
         if (use_decode_kernel_[d]) {
           // Use decode kernel for depth d
-          f_attention_decode_(/*depth=*/d, q_data, pages_[layer_id], 
page_indptr_on_depths_view_[d],
-                              page_indices_on_depths_view_[d], 
last_page_len_on_depths_view_[d],
-                              k_rope_pos_offset_view_[d], 
q_rope_position_map_view_,
-                              temp_attn_output_view_, temp_attn_scores_view_,
-                              /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, 
rotary_scale_,
-                              rotary_theta_, attn_score_scaling_factor);
+          f_decode(/*depth=*/d, q_data, pages_[layer_id], 
page_indptr_on_depths_view_[d],
+                   page_indices_on_depths_view_[d], 
length_info_on_depths_view_[d],
+                   k_rope_pos_offset_view_[d], q_rope_position_map_view_, 
temp_attn_output_view_,
+                   temp_attn_scores_view_,
+                   /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, 
rotary_scale_, rotary_theta_,
+                   attn_score_scaling_factor);
         } else {
           // Use prefill kernel for depth d
-          f_attention_prefill_(
+          f_prefill(
               /*depth=*/d, q_data, qo_indptr_on_depths_view_[d], 
pages_[layer_id],
               page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d],
-              last_page_len_on_depths_view_[d], k_rope_pos_offset_view_[d],
-              q_rope_position_map_view_, temp_attn_output_view_, 
temp_attn_scores_view_,
+              length_info_on_depths_view_[d], k_rope_pos_offset_view_[d], 
q_rope_position_map_view_,
+              temp_attn_output_view_, temp_attn_scores_view_,
               /*causal=*/0,
               /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, 
rotary_theta_,
               attn_score_scaling_factor);
         }
-        f_merge_inplace_.value()(output, merged_attn_scores_view_, 
temp_attn_output_view_,
-                                 temp_attn_scores_view_);
+        f_merge_inplace_(output, merged_attn_scores_view_, 
temp_attn_output_view_,
+                         temp_attn_scores_view_);
       }
     }
   }
@@ -1074,10 +1233,7 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
       return;
     }
     // - Sync NDArrays to GPU.
-    SyncAuxArrayToDevice(qo_indptr_on_depths_host_, 
page_indptr_on_depths_host_,
-                         page_indices_on_depths_host_, 
last_page_len_on_depths_host_,
-                         k_rope_pos_offset_on_depths_host_, 
k_ragged_rope_pos_offset_host_,
-                         q_rope_position_map_host_, append_position_map_host_);
+    SyncAuxArrayToDevice();
     KernelBeginForward();
     // - Clear the dirty flag.
     dirty_aux_data_device_ = false;
@@ -1089,24 +1245,43 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     DeviceAPI::Get(device_)->SyncStreamFromTo(device_, copy_stream_, 
compute_stream_);
   }
 
+  /*!
+   * \brief Copy the host int32 buffer `vec_data` into `array` on copy_stream_.
+   * `shape` (1-D) optionally overrides the copied extent; `dst_elem_offset`
+   * sets the dst byte_offset, overwriting (not adding to) array's own offset.
+   */
+  void CopyVecDataToArray(NDArray array, int32_t* vec_data, 
Optional<ShapeTuple> shape = NullOpt,
+                          int dst_elem_offset = 0) {
+    DLTensor copy_dst = *array.operator->();
+    if (shape.defined()) {
+      ICHECK_EQ(shape.value().size(), 1);
+      copy_dst.ndim = 1;
+      copy_dst.shape = shape.value()->data;
+    }
+    copy_dst.byte_offset = dst_elem_offset * sizeof(int32_t);
+
+    DLTensor copy_src;
+    copy_src.data = vec_data;
+    copy_src.device = Device{kDLCPU, 0};
+    copy_src.ndim = 1;
+    copy_src.dtype = array->dtype;
+    copy_src.shape = copy_dst.shape;
+    copy_src.strides = nullptr;
+    copy_src.byte_offset = 0;
+    NDArray::CopyFromTo(&copy_src, &copy_dst, copy_stream_);
+  }
+
   /*!
    * \brief Synchronize auxiliary arrays to device.
    * \note This method resets the dirty flag to false, and needs to be
    * invoked before running attention computation on device.
    */
-  void SyncAuxArrayToDevice(std::vector<std::vector<int32_t>> 
qo_indptr_on_depths,
-                            std::vector<std::vector<int32_t>> 
page_indptr_on_depths,
-                            std::vector<std::vector<int32_t>> 
page_indices_on_depths,
-                            std::vector<std::vector<int32_t>> 
last_page_len_on_depths,
-                            std::vector<std::vector<int32_t>> 
k_rope_pos_offset_on_depths,
-                            std::vector<int32_t> k_ragged_rope_pos_offset,
-                            std::vector<int32_t> q_rope_position_map,
-                            std::vector<int32_t> append_position_map) {
+  void SyncAuxArrayToDevice() {
     ICHECK(dtype_aux_.bits == 32 && dtype_aux_.code == kDLInt);
-    ICHECK_EQ(qo_indptr_on_depths.size(), num_depths_);
-    ICHECK_EQ(page_indptr_on_depths.size(), num_depths_);
-    ICHECK_EQ(page_indices_on_depths.size(), num_depths_);
-    ICHECK_EQ(last_page_len_on_depths.size(), num_depths_);
+    ICHECK_EQ(qo_indptr_on_depths_host_.size(), num_depths_);
+    ICHECK_EQ(page_indptr_on_depths_host_.size(), num_depths_);
+    ICHECK_EQ(page_indices_on_depths_host_.size(), num_depths_);
+    ICHECK_EQ(last_page_len_on_depths_host_.size(), num_depths_);
     int64_t total_append_length = 0;
     int num_sequences = cur_append_lengths_.size();
     cur_append_lengths_indptr_host_.resize(num_sequences + 1);
@@ -1116,83 +1291,92 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
           cur_append_lengths_indptr_host_[i] + cur_append_lengths_[i];
     }
     total_append_length = cur_append_lengths_indptr_host_.back();
-    ICHECK_EQ(total_append_length, append_position_map.size());
-
-    auto fcopy_from_vec = [copy_stream = this->copy_stream_](NDArray array, 
int32_t* vec_data) {
-      DLTensor copy_dst = *array.operator->();
-      DLTensor copy_src;
-      copy_src.data = vec_data;
-      copy_src.device = Device{kDLCPU, 0};
-      copy_src.ndim = 1;
-      copy_src.dtype = array->dtype;
-      copy_src.shape = array->shape;
-      copy_src.strides = nullptr;
-      copy_src.byte_offset = 0;
-      NDArray::CopyFromTo(&copy_src, &copy_dst, copy_stream);
-    };
+    ICHECK_EQ(total_append_length, append_position_map_host_.size());
 
     // 1. qo_indptr_on_depths
     for (int d = 0; d < num_depths_; ++d) {
       qo_indptr_on_depths_view_[d] = qo_indptr_on_depths_device_[d].CreateView(
-          {static_cast<int64_t>(qo_indptr_on_depths[d].size())}, dtype_aux_);
-      fcopy_from_vec(qo_indptr_on_depths_view_[d], 
qo_indptr_on_depths[d].data());
+          {static_cast<int64_t>(qo_indptr_on_depths_host_[d].size())}, 
dtype_aux_);
+      CopyVecDataToArray(qo_indptr_on_depths_view_[d], 
qo_indptr_on_depths_host_[d].data());
     }
 
     // 2. page_indptr_on_depths
     for (int d = 0; d < num_depths_; ++d) {
-      ICHECK_EQ(page_indptr_on_depths[d].size(), 
qo_indptr_on_depths[d].size());
+      ICHECK_EQ(page_indptr_on_depths_host_[d].size(), 
qo_indptr_on_depths_host_[d].size());
       page_indptr_on_depths_view_[d] = 
page_indptr_on_depths_device_[d].CreateView(
-          {static_cast<int64_t>(page_indptr_on_depths[d].size())}, dtype_aux_);
-      fcopy_from_vec(page_indptr_on_depths_view_[d], 
page_indptr_on_depths[d].data());
+          {static_cast<int64_t>(page_indptr_on_depths_host_[d].size())}, 
dtype_aux_);
+      CopyVecDataToArray(page_indptr_on_depths_view_[d], 
page_indptr_on_depths_host_[d].data());
     }
 
     // 3. page_indices_on_depths
     for (int d = 0; d < num_depths_; ++d) {
-      ICHECK_EQ(page_indices_on_depths[d].size(), 
page_indptr_on_depths[d].back());
+      ICHECK_EQ(page_indices_on_depths_host_[d].size(), 
page_indptr_on_depths_host_[d].back());
       page_indices_on_depths_view_[d] = 
page_indices_on_depths_device_[d].CreateView(
-          {static_cast<int64_t>(page_indices_on_depths[d].size())}, 
dtype_aux_);
-      if (!page_indices_on_depths[d].empty()) {
-        fcopy_from_vec(page_indices_on_depths_view_[d], 
page_indices_on_depths[d].data());
+          {static_cast<int64_t>(page_indices_on_depths_host_[d].size())}, 
dtype_aux_);
+      if (!page_indices_on_depths_host_[d].empty()) {
+        CopyVecDataToArray(page_indices_on_depths_view_[d], 
page_indices_on_depths_host_[d].data());
       }
     }
 
-    // 4. last_page_len_on_depths
+    // 4. length_info_on_depths
+    // Row 0: last_page_len_on_depths_host_ (the only row without sliding window);
+    // row 1: sliding_window_offset_on_depths_host_;
+    // row 2: sink_size_on_depths_host_.
     for (int d = 0; d < num_depths_; ++d) {
-      ICHECK_EQ(last_page_len_on_depths[d].size() + 1, 
qo_indptr_on_depths[d].size());
-      last_page_len_on_depths_view_[d] = 
last_page_len_on_depths_device_[d].CreateView(
-          {static_cast<int64_t>(last_page_len_on_depths[d].size())}, 
dtype_aux_);
-      fcopy_from_vec(last_page_len_on_depths_view_[d], 
last_page_len_on_depths[d].data());
+      int num_seq_on_layer = 
static_cast<int>(qo_indptr_on_depths_host_[d].size()) - 1;
+      ICHECK_EQ(last_page_len_on_depths_host_[d].size(), num_seq_on_layer);
+      ICHECK_EQ(sliding_window_offset_on_depths_host_[d].size(), 
num_seq_on_layer);
+      ICHECK_EQ(sink_size_on_depths_host_[d].size(), num_seq_on_layer);
+      if (!support_sliding_window_) {
+        // Sliding window is not enabled, so we first copy "last_page_len".
+        length_info_on_depths_view_[d] =
+            length_info_on_depths_device_[d].CreateView({num_seq_on_layer}, 
dtype_aux_);
+        CopyVecDataToArray(length_info_on_depths_view_[d], 
last_page_len_on_depths_host_[d].data());
+      } else {
+        // Sliding window is enabled, so copy all three rows of length info.
+        length_info_on_depths_view_[d] =
+            length_info_on_depths_device_[d].CreateView({3, num_seq_on_layer}, 
dtype_aux_);
+        ShapeTuple copy_shape{num_seq_on_layer};
+        CopyVecDataToArray(length_info_on_depths_view_[d], 
last_page_len_on_depths_host_[d].data(),
+                           copy_shape);
+        CopyVecDataToArray(length_info_on_depths_view_[d],
+                           sliding_window_offset_on_depths_host_[d].data(), 
copy_shape,
+                           /*dst_elem_offset=*/num_seq_on_layer);
+        CopyVecDataToArray(length_info_on_depths_view_[d], 
sink_size_on_depths_host_[d].data(),
+                           copy_shape, /*dst_elem_offset=*/2 * 
num_seq_on_layer);
+      }
     }
 
-    // 5. k_rope_pos_offset
+    // 5. k_rope_pos_offset_on_depths
     for (int d = 0; d < num_depths_; ++d) {
-      ICHECK_EQ(k_rope_pos_offset_on_depths[d].size() + 1, 
qo_indptr_on_depths[d].size());
+      ICHECK_EQ(k_rope_pos_offset_on_depths_host_[d].size() + 1,
+                qo_indptr_on_depths_host_[d].size());
       k_rope_pos_offset_view_[d] = k_rope_pos_offset_device_[d].CreateView(
-          {static_cast<int64_t>(k_rope_pos_offset_on_depths[d].size())}, 
dtype_aux_);
-      fcopy_from_vec(k_rope_pos_offset_view_[d], 
k_rope_pos_offset_on_depths[d].data());
+          {static_cast<int64_t>(k_rope_pos_offset_on_depths_host_[d].size())}, 
dtype_aux_);
+      CopyVecDataToArray(k_rope_pos_offset_view_[d], 
k_rope_pos_offset_on_depths_host_[d].data());
     }
 
     // 6. cur_append_lengths_indptr
     cur_append_length_indptr_view_ =
         cur_append_length_indptr_device_.CreateView({num_sequences + 1}, 
dtype_aux_);
-    fcopy_from_vec(cur_append_length_indptr_view_, 
cur_append_lengths_indptr_host_.data());
+    CopyVecDataToArray(cur_append_length_indptr_view_, 
cur_append_lengths_indptr_host_.data());
 
     // 7. k_ragged_rope_pos_offset
-    ICHECK_EQ(k_ragged_rope_pos_offset.size(), num_sequences);
+    ICHECK_EQ(k_ragged_rope_pos_offset_host_.size(), num_sequences);
     k_ragged_rope_pos_offset_view_ =
         k_ragged_rope_pos_offset_device_.CreateView({num_sequences}, 
dtype_aux_);
-    fcopy_from_vec(k_ragged_rope_pos_offset_view_, 
k_ragged_rope_pos_offset.data());
+    CopyVecDataToArray(k_ragged_rope_pos_offset_view_, 
k_ragged_rope_pos_offset_host_.data());
 
     // 8. q_rope_position_map
-    ICHECK_EQ(q_rope_position_map.size(), total_append_length);
+    ICHECK_EQ(q_rope_position_map_host_.size(), total_append_length);
     q_rope_position_map_view_ =
         q_rope_position_map_device_.CreateView({total_append_length}, 
dtype_aux_);
-    fcopy_from_vec(q_rope_position_map_view_, q_rope_position_map.data());
+    CopyVecDataToArray(q_rope_position_map_view_, 
q_rope_position_map_host_.data());
 
     // 9. append_position_map
     append_position_map_view_ =
         append_position_map_device_.CreateView({total_append_length}, 
dtype_aux_);
-    fcopy_from_vec(append_position_map_view_, append_position_map.data());
+    CopyVecDataToArray(append_position_map_view_, 
append_position_map_host_.data());
 
     // 10. Create view for temporary arrays for attention computation.
     temp_attn_output_view_ = temp_attn_output_device_.CreateView(
@@ -1218,6 +1402,8 @@ 
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
                        int64_t num_kv_heads, int64_t head_dim, int rope_mode, 
double rotary_scale,
                        double rotary_theta, NDArray init, PackedFunc 
f_transpose_append,
                        PackedFunc f_attention_prefill, PackedFunc 
f_attention_decode,
+                       PackedFunc f_attention_prefill_sliding_window,  //
+                       PackedFunc f_attention_decode_sliding_window,
                        PackedFunc f_attention_prefill_ragged,
                        PackedFunc f_attention_prefill_ragged_begin_forward,
                        PackedFunc f_attention_prefill_ragged_end_forward,
@@ -1225,25 +1411,30 @@ 
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
                        PackedFunc f_attention_prefill_end_forward,
                        PackedFunc f_attention_decode_begin_forward,
                        PackedFunc f_attention_decode_end_forward, PackedFunc 
f_merge_inplace,
-                       PackedFunc f_split_rotary, PackedFunc f_rotary_inplace,
-                       Optional<PackedFunc> f_debug_get_kv) {
-      CHECK_EQ(cache_config.size(), 4);
+                       PackedFunc f_split_rotary, Optional<PackedFunc> 
f_debug_get_kv) {
+      CHECK_EQ(cache_config.size(), 5);
       int64_t reserved_num_seqs = cache_config[0];
       int64_t total_token_capacity = cache_config[1];
       int64_t prefill_chunk_size = cache_config[2];
       int64_t page_size = cache_config[3];
+      bool support_sliding_window = cache_config[4];
       int64_t num_total_pages = (total_token_capacity + page_size - 1) / 
page_size;
+      if (support_sliding_window) {
+        // When sliding window is enabled, each sequence may use two more 
pages at most.
+        num_total_pages += reserved_num_seqs * 2;
+      }
       ObjectPtr<PagedAttentionKVCacheObj> n = 
make_object<PagedAttentionKVCacheObj>(
           page_size, num_layers, num_qo_heads, num_kv_heads, head_dim, 
reserved_num_seqs,
-          num_total_pages, prefill_chunk_size, RoPEMode(rope_mode), 
rotary_scale, rotary_theta,
-          init->dtype, init->device, std::move(f_transpose_append), 
std::move(f_attention_prefill),
-          std::move(f_attention_decode), std::move(f_attention_prefill_ragged),
+          num_total_pages, prefill_chunk_size, support_sliding_window, 
RoPEMode(rope_mode),
+          rotary_scale, rotary_theta, init->dtype, init->device, 
std::move(f_transpose_append),
+          std::move(f_attention_prefill), std::move(f_attention_decode),
+          std::move(f_attention_prefill_sliding_window),
+          std::move(f_attention_decode_sliding_window), 
std::move(f_attention_prefill_ragged),
           std::move(f_attention_prefill_ragged_begin_forward),
           std::move(f_attention_prefill_ragged_end_forward),
           std::move(f_attention_prefill_begin_forward), 
std::move(f_attention_prefill_end_forward),
           std::move(f_attention_decode_begin_forward), 
std::move(f_attention_decode_end_forward),
-          std::move(f_merge_inplace), std::move(f_split_rotary), 
std::move(f_rotary_inplace),
-          std::move(f_debug_get_kv));
+          std::move(f_merge_inplace), std::move(f_split_rotary), 
std::move(f_debug_get_kv));
       return AttentionKVCache(std::move(n));
     });
 
@@ -1252,62 +1443,33 @@ 
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
                        int64_t num_kv_heads, int64_t head_dim, int rope_mode, 
double rotary_scale,
                        double rotary_theta, NDArray init, PackedFunc 
f_transpose_append,
                        PackedFunc f_attention_prefill, PackedFunc 
f_attention_decode,
+                       PackedFunc f_attention_prefill_sliding_window,
+                       PackedFunc f_attention_decode_sliding_window,
                        PackedFunc f_attention_prefill_ragged, PackedFunc 
f_merge_inplace,
-                       PackedFunc f_split_rotary, PackedFunc f_rotary_inplace,
-                       Optional<PackedFunc> f_debug_get_kv) {
-      CHECK_EQ(cache_config.size(), 4);
+                       PackedFunc f_split_rotary, Optional<PackedFunc> 
f_debug_get_kv) {
+      CHECK_EQ(cache_config.size(), 5);
       int64_t reserved_num_seqs = cache_config[0];
       int64_t total_token_capacity = cache_config[1];
       int64_t prefill_chunk_size = cache_config[2];
       int64_t page_size = cache_config[3];
+      bool support_sliding_window = cache_config[4];
       int64_t num_total_pages = (total_token_capacity + page_size - 1) / 
page_size;
+      if (support_sliding_window) {
+        // When sliding window is enabled, each sequence may use two more 
pages at most.
+        num_total_pages += reserved_num_seqs * 2;
+      }
       ObjectPtr<PagedAttentionKVCacheObj> n = 
make_object<PagedAttentionKVCacheObj>(
           page_size, num_layers, num_qo_heads, num_kv_heads, head_dim, 
reserved_num_seqs,
-          num_total_pages, prefill_chunk_size, RoPEMode(rope_mode), 
rotary_scale, rotary_theta,
-          init->dtype, init->device, std::move(f_transpose_append), 
std::move(f_attention_prefill),
-          std::move(f_attention_decode), 
std::move(f_attention_prefill_ragged),  //
-          NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt,                
  //
-          std::move(f_merge_inplace), std::move(f_split_rotary), 
std::move(f_rotary_inplace),
-          std::move(f_debug_get_kv));
+          num_total_pages, prefill_chunk_size, support_sliding_window, 
RoPEMode(rope_mode),
+          rotary_scale, rotary_theta, init->dtype, init->device, 
std::move(f_transpose_append),
+          std::move(f_attention_prefill), std::move(f_attention_decode),
+          std::move(f_attention_prefill_sliding_window),
+          std::move(f_attention_decode_sliding_window), 
std::move(f_attention_prefill_ragged),  //
+          NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt,                
                 //
+          std::move(f_merge_inplace), std::move(f_split_rotary), 
std::move(f_debug_get_kv));
       return AttentionKVCache(std::move(n));
     });
 
-// Keep the following global functions for backward compatibility.
-// TODO(tvm-team): Remove these global functions in the future.
-TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_clear")
-    .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::Clear);
-TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_add_sequence")
-    .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::AddSequence);
-TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_remove_sequence")
-    .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::RemoveSequence);
-TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_fork_sequence")
-    .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::ForkSequence);
-TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_popn")
-    .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::PopN);
-TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_get_num_available_pages")
-    
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::GetNumAvailablePages);
-TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_begin_forward")
-    .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::BeginForward);
-TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_end_forward")
-    .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::EndForward);
-TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_get_query_positions")
-    
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::GetQueryPositions);
-TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_debug_get_kv")
-    .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::DebugGetKV);
-TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_attention")
-    .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
-                       double attn_score_scaling_factor, NDArray q_data, 
NDArray k_data,
-                       NDArray v_data, NDArray o_data) {
-      kv_cache->Attention(layer_id, std::move(q_data), std::move(k_data), 
std::move(v_data),
-                          NullOpt, std::move(o_data), 
attn_score_scaling_factor);
-    });
-TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_attention_with_fused_qkv")
-    .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
-                       double attn_score_scaling_factor, NDArray qkv_data, 
NDArray o_data) {
-      kv_cache->AttentionWithFusedQKV(layer_id, std::move(qkv_data), NullOpt, 
std::move(o_data),
-                                      attn_score_scaling_factor);
-    });
-
 }  // namespace relax_vm
 }  // namespace runtime
 }  // namespace tvm
diff --git 
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
 
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
index 967e71ecd3..d30ccd0224 100644
--- 
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
+++ 
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
@@ -63,7 +63,6 @@ fattention_decode_end_forward = None
 fattention_prefill_ragged_begin_forward = None
 fattention_prefill_ragged_end_forward = None
 fattention_merge_state = None
-fattention_rotary = None
 
 ftranspose_append = None
 fsplit_rotary = None
@@ -231,39 +230,42 @@ def set_global_func():
     global fattention_prefill_ragged
     global fattention_prefill_ragged_begin_forward
     global fattention_prefill_ragged_end_forward
-    global fattention_merge_state, fsplit_rotary, fattention_rotary
+    global fattention_merge_state, fsplit_rotary
     global ftranspose_append, fcopy_cache
 
-    fclear = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_clear")
+    fclear = tvm.get_global_func("vm.builtin.kv_state_clear")
     fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create")
-    fadd_sequence = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_add_sequence")
-    fremove_sequence = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_remove_sequence")
-    ffork_sequence = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_fork_sequence")
-    fpopn = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_popn")
-    fbegin_forward = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_begin_forward")
-    fend_forward = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_end_forward")
-    fattention = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_attention")
+    fadd_sequence = tvm.get_global_func("vm.builtin.kv_state_add_sequence")
+    fremove_sequence = tvm.get_global_func("vm.builtin.kv_state_remove_sequence")
+    ffork_sequence = tvm.get_global_func("vm.builtin.kv_state_fork_sequence")
+    fpopn = tvm.get_global_func("vm.builtin.kv_state_popn")
+    fbegin_forward = tvm.get_global_func("vm.builtin.kv_state_begin_forward")
+    fend_forward = tvm.get_global_func("vm.builtin.kv_state_end_forward")
     fattention_with_fuse_qkv = tvm.get_global_func(
-        "vm.builtin.paged_attention_kv_cache_attention_with_fused_qkv"
+        "vm.builtin.attention_kv_cache_attention_with_fused_qkv"
     )
-    fdebug_get_kv = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_debug_get_kv")
+    fdebug_get_kv = tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv")
 
-    fattention_prefill = tvm.get_global_func("paged_kv_cache.attention_kernel_prefill")
-    fattention_decode = tvm.get_global_func("paged_kv_cache.attention_kernel_decode")
+    fattention_prefill = tvm.get_global_func(
+        "flashinfer.attention_kernel_prefill_with_paged_kv_cache"
+    )
+    fattention_decode = tvm.get_global_func(
+        "flashinfer.attention_kernel_decode_with_paged_kv_cache"
+    )
     fattention_prefill_ragged = tvm.get_global_func(
         "flashinfer.attention_kernel_prefill_with_ragged_kv_cache"
     )
     fattention_prefill_begin_forward = tvm.get_global_func(
-        "paged_kv_cache.attention_kernel_prefill_begin_forward"
+        "flashinfer.attention_kernel_prefill_with_paged_kv_cache_begin_forward"
     )
     fattention_prefill_end_forward = tvm.get_global_func(
-        "paged_kv_cache.attention_kernel_prefill_end_forward"
+        "flashinfer.attention_kernel_prefill_with_paged_kv_cache_end_forward"
     )
     fattention_decode_begin_forward = tvm.get_global_func(
-        "paged_kv_cache.attention_kernel_decode_begin_forward"
+        "flashinfer.attention_kernel_decode_with_paged_kv_cache_begin_forward"
     )
     fattention_decode_end_forward = tvm.get_global_func(
-        "paged_kv_cache.attention_kernel_decode_end_forward"
+        "flashinfer.attention_kernel_decode_with_paged_kv_cache_end_forward"
     )
     fattention_prefill_ragged_begin_forward = tvm.get_global_func(
         "flashinfer.attention_kernel_prefill_with_ragged_kv_cache_begin_forward"
@@ -272,7 +274,6 @@ def set_global_func():
         "flashinfer.attention_kernel_prefill_with_ragged_kv_cache_end_forward"
     )
     fattention_merge_state = tvm.get_global_func("flashinfer.merge_state_in_place")
-    fattention_rotary = tvm.get_global_func("flashinfer.batch_qk_apply_rotary_in_place")
 
     target = tvm.target.Target("nvidia/geforce-rtx-3090-ti")
     builts = []
@@ -293,9 +294,16 @@ def set_global_func():
 
 
 def create_kv_cache(rope_mode):
+    support_sliding_window = 0
     cache = fcreate(
         tvm.runtime.ShapeTuple(
-            [reserved_nseq, maximum_total_seq_length, prefill_chunk_size, page_size]
+            [
+                reserved_nseq,
+                maximum_total_seq_length,
+                prefill_chunk_size,
+                page_size,
+                support_sliding_window,
+            ]
         ),
         num_layers,
         num_qo_heads,
@@ -308,6 +316,8 @@ def create_kv_cache(rope_mode):
         ftranspose_append,
         fattention_prefill,
         fattention_decode,
+        fattention_prefill,
+        fattention_decode,
         fattention_prefill_ragged,
         fattention_prefill_ragged_begin_forward,
         fattention_prefill_ragged_end_forward,
@@ -317,7 +327,6 @@ def create_kv_cache(rope_mode):
         fattention_decode_end_forward,
         fattention_merge_state,
         fsplit_rotary,
-        fattention_rotary,
         fcopy_cache,
     )
     return cache
@@ -378,7 +387,6 @@ def apply_attention(
     batch: List[Tuple[Union[int, Tuple[int, int]], int]],
     cached_k: Dict[int, np.ndarray],
     cached_v: Dict[int, np.ndarray],
-    fuse_qkv: bool,
 ) -> None:
     seq_ids = []
     append_lengths = []
@@ -442,16 +450,9 @@ def apply_attention(
         queries_np = global_new_q[layer_id]
         keys_np = global_new_k[layer_id]
         values_np = global_new_v[layer_id]
-        if not fuse_qkv:
-            queries = tvm.nd.array(queries_np, device=device)
-            keys = tvm.nd.array(keys_np, device=device)
-            values = tvm.nd.array(values_np, device=device)
-            outputs = tvm.nd.empty(queries.shape, dtype, device=device)
-            fattention(kv_cache, layer_id, 1.0, queries, keys, values, outputs)
-        else:
-            qkv = tvm.nd.array(np.concatenate([queries_np, keys_np, values_np], axis=1), device)
-            outputs = tvm.nd.empty(queries_np.shape, dtype, device=device)
-            fattention_with_fuse_qkv(kv_cache, layer_id, 1.0, qkv, outputs)
+        qkv = tvm.nd.array(np.concatenate([queries_np, keys_np, values_np], axis=1), device)
+        outputs = tvm.nd.empty(queries_np.shape, dtype, device=device)
+        fattention_with_fuse_qkv(kv_cache, layer_id, 1.0, qkv, outputs)
 
         # Compute attention expected results.
         outputs = np.expand_dims(outputs.numpy(), axis=0)
@@ -509,8 +510,7 @@ def apply_attention(
 
 
 @pytest.mark.skip(reason="Require FlashInfer enabled")
-@pytest.mark.parametrize("fuse_qkv", [False, True])
-def test_paged_attention_kv_cache_prefill_and_decode(kv_cache_and_rope_mode, fuse_qkv):
+def test_paged_attention_kv_cache_prefill_and_decode(kv_cache_and_rope_mode):
     kv_cache, rope_mode = kv_cache_and_rope_mode
     fclear(kv_cache)
 
@@ -527,12 +527,11 @@ def test_paged_attention_kv_cache_prefill_and_decode(kv_cache_and_rope_mode, fus
     cached_k = {}
     cached_v = {}
     for batch in operation_seq:
-        apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v, fuse_qkv)
+        apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
 
 
 @pytest.mark.skip(reason="Require FlashInfer enabled")
-@pytest.mark.parametrize("fuse_qkv", [False, True])
-def test_paged_attention_kv_cache_remove_sequence(kv_cache_and_rope_mode, fuse_qkv):
+def test_paged_attention_kv_cache_remove_sequence(kv_cache_and_rope_mode):
     kv_cache, rope_mode = kv_cache_and_rope_mode
     fclear(kv_cache)
 
@@ -541,7 +540,7 @@ def test_paged_attention_kv_cache_remove_sequence(kv_cache_and_rope_mode, fuse_q
     cached_k = {}
     cached_v = {}
     for seq_id_to_remove in range(num_sequences):
-        apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v, fuse_qkv)
+        apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
         # Remove sequence.
         fremove_sequence(kv_cache, seq_id_to_remove)
         cached_k.pop(seq_id_to_remove)
@@ -555,22 +554,21 @@ def test_paged_attention_kv_cache_remove_sequence(kv_cache_and_rope_mode, fuse_q
 
 
 @pytest.mark.skip(reason="Require FlashInfer enabled")
-@pytest.mark.parametrize("fuse_qkv", [False, True])
-def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_rope_mode, fuse_qkv):
+def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_rope_mode):
     kv_cache, rope_mode = kv_cache_and_rope_mode
     fclear(kv_cache)
 
     cached_k = {}
     cached_v = {}
     batch = [(0, 60), (1, 88), (2, 17), (3, 4)]
-    apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v, fuse_qkv)
+    apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
     # Fork existing sequences.
-    apply_attention(kv_cache, rope_mode, [((4, 3), 35)], cached_k, cached_v, fuse_qkv)
-    apply_attention(kv_cache, rope_mode, [((5, 0), 20)], cached_k, cached_v, fuse_qkv)
-    apply_attention(kv_cache, rope_mode, [((6, 5), 102)], cached_k, cached_v, fuse_qkv)
-    apply_attention(kv_cache, rope_mode, [((7, 0), 3)], cached_k, cached_v, fuse_qkv)
-    apply_attention(kv_cache, rope_mode, [((8, 5), 71)], cached_k, cached_v, fuse_qkv)
-    apply_attention(kv_cache, rope_mode, [((9, 5), 20)], cached_k, cached_v, fuse_qkv)
+    apply_attention(kv_cache, rope_mode, [((4, 3), 35)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((5, 0), 20)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((6, 5), 102)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((7, 0), 3)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((8, 5), 71)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((9, 5), 20)], cached_k, cached_v)
     # Mixture of decode and prefill.
     operation_seq = [
         [(2, 1), (4, 1), (7, 1), (6, 1), (8, 1), (9, 1)],
@@ -579,20 +577,19 @@ def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_rope_mode, fuse_qkv
         [(7, 10), (6, 2), (8, 3), (9, 4)],
     ]
     for batch in operation_seq:
-        apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v, fuse_qkv)
+        apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
 
 
 @pytest.mark.skip(reason="Require FlashInfer enabled")
-@pytest.mark.parametrize("fuse_qkv", [False, True])
-def test_paged_attention_kv_cache_popn(kv_cache_and_rope_mode, fuse_qkv):
+def test_paged_attention_kv_cache_popn(kv_cache_and_rope_mode):
     kv_cache, rope_mode = kv_cache_and_rope_mode
     fclear(kv_cache)
 
     cached_k = {}
     cached_v = {}
     batch = [(0, 35), (1, 88), (2, 17), (3, 4)]
-    apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v, fuse_qkv)
-    apply_attention(kv_cache, rope_mode, [((4, 3), 35)], cached_k, cached_v, fuse_qkv)
+    apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((4, 3), 35)], cached_k, cached_v)
 
     popn_operations = [(0, 17), (1, 57), (2, 16), (3, 0), (4, 19)]
     for seq_id, pop_length in popn_operations:
@@ -607,8 +604,7 @@ if __name__ == "__main__":
     set_global_func()
     for rope_mode in [RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE]:
         cache = create_kv_cache(rope_mode)
-        for fuse_qkv in [False, True]:
-            test_paged_attention_kv_cache_prefill_and_decode((cache, rope_mode), fuse_qkv)
-            test_paged_attention_kv_cache_remove_sequence((cache, rope_mode), fuse_qkv)
-            test_paged_attention_kv_cache_fork_sequence((cache, rope_mode), fuse_qkv)
-            test_paged_attention_kv_cache_popn((cache, rope_mode), fuse_qkv)
+        test_paged_attention_kv_cache_prefill_and_decode((cache, rope_mode))
+        test_paged_attention_kv_cache_remove_sequence((cache, rope_mode))
+        test_paged_attention_kv_cache_fork_sequence((cache, rope_mode))
+        test_paged_attention_kv_cache_popn((cache, rope_mode))
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index 34e9d51715..64887ca5b6 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -17,7 +17,7 @@
 import enum
 import itertools
 import math
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import pytest
@@ -49,10 +49,10 @@ fclear = None
 fadd_sequence = None
 fremove_sequence = None
 ffork_sequence = None
+fenable_sliding_window_for_seq = None
 fpopn = None
 fbegin_forward = None
 fend_forward = None
-fattention = None
 fattention_with_fuse_qkv = None
 fdebug_get_kv = None
 
@@ -60,6 +60,8 @@ ftranspose_append = None
 fcopy_cache = None
 fattn_prefill = None
 fattn_decode = None
+fattn_prefill_sliding_window = None
+fattn_decode_sliding_window = None
 fattn_prefill_ragged = None
 fmerge_state = None
 fsplit_rotary = None
@@ -67,37 +69,41 @@ fattention_rotary = None
 
 
 def set_global_func(head_dim, dtype):
-    global fclear, fadd_sequence, fremove_sequence, ffork_sequence, fpopn
-    global fbegin_forward, fend_forward, fattention, fattention_with_fuse_qkv, fdebug_get_kv
+    global fclear, fadd_sequence, fremove_sequence, ffork_sequence, fenable_sliding_window_for_seq
+    global fpopn, fbegin_forward, fend_forward, fattention_with_fuse_qkv, fdebug_get_kv
     global ftranspose_append, fcopy_cache, fattn_prefill, fattn_decode, fattn_prefill_ragged
+    global fattn_prefill_sliding_window, fattn_decode_sliding_window
     global fmerge_state, fsplit_rotary, fattention_rotary
 
-    fclear = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_clear")
-    fadd_sequence = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_add_sequence")
-    fremove_sequence = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_remove_sequence")
-    ffork_sequence = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_fork_sequence")
-    fpopn = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_popn")
-    fbegin_forward = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_begin_forward")
-    fend_forward = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_end_forward")
-    fattention = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_attention")
+    fclear = tvm.get_global_func("vm.builtin.kv_state_clear")
+    fadd_sequence = tvm.get_global_func("vm.builtin.kv_state_add_sequence")
+    fremove_sequence = tvm.get_global_func("vm.builtin.kv_state_remove_sequence")
+    ffork_sequence = tvm.get_global_func("vm.builtin.kv_state_fork_sequence")
+    fenable_sliding_window_for_seq = tvm.get_global_func(
+        "vm.builtin.attention_kv_cache_enable_sliding_window_for_seq"
+    )
+    fpopn = tvm.get_global_func("vm.builtin.kv_state_popn")
+    fbegin_forward = tvm.get_global_func("vm.builtin.kv_state_begin_forward")
+    fend_forward = tvm.get_global_func("vm.builtin.kv_state_end_forward")
     fattention_with_fuse_qkv = tvm.get_global_func(
-        "vm.builtin.paged_attention_kv_cache_attention_with_fused_qkv"
+        "vm.builtin.attention_kv_cache_attention_with_fused_qkv"
     )
-    fdebug_get_kv = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_debug_get_kv")
+    fdebug_get_kv = tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv")
 
     target = tvm.target.Target("cuda")
     builts = []
     for tir_func in [
         kv_cache_transpose_append(head_dim, dtype),
         copy_cache(head_dim, dtype),
-        _attention_prefill(num_kv_heads, num_qo_heads, head_dim, dtype, target),
-        _attention_decode(num_kv_heads, num_qo_heads, head_dim, dtype, target),
+        _attention_prefill(num_kv_heads, num_qo_heads, head_dim, dtype, False, target),
+        _attention_decode(num_kv_heads, num_qo_heads, head_dim, dtype, False, target),
+        _attention_prefill(num_kv_heads, num_qo_heads, head_dim, dtype, True, target),
+        _attention_decode(num_kv_heads, num_qo_heads, head_dim, dtype, True, target),
         _attention_prefill_ragged(num_kv_heads, num_qo_heads, head_dim, dtype, target),
         _merge_state_inplace(num_qo_heads, head_dim, dtype, target),
         llama_rope_with_position_map(
             rope_theta, rope_scale, head_dim, num_qo_heads, num_kv_heads, dtype
         ),
-        _inplace_rope(rope_theta, rope_scale, head_dim, num_qo_heads, num_kv_heads, dtype),
     ]:
         mod = tvm.IRModule({"main": tir_func})
         with target:
@@ -110,18 +116,25 @@ def set_global_func(head_dim, dtype):
         fcopy_cache,
         fattn_prefill,
         fattn_decode,
+        fattn_prefill_sliding_window,
+        fattn_decode_sliding_window,
         fattn_prefill_ragged,
         fmerge_state,
         fsplit_rotary,
-        fattention_rotary,
     ) = builts
 
 
-def create_kv_cache(head_dim, dtype, rope_mode):
+def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window):
     fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create_reduced")
     cache = fcreate(
         tvm.runtime.ShapeTuple(
-            [reserved_nseq, maximum_total_seq_length, prefill_chunk_size, page_size]
+            [
+                reserved_nseq,
+                maximum_total_seq_length,
+                prefill_chunk_size,
+                page_size,
+                int(support_sliding_window),
+            ]
         ),
         num_layers,
         num_qo_heads,
@@ -134,10 +147,11 @@ def create_kv_cache(head_dim, dtype, rope_mode):
         ftranspose_append,
         fattn_prefill,
         fattn_decode,
+        fattn_prefill_sliding_window,
+        fattn_decode_sliding_window,
         fattn_prefill_ragged,
         fmerge_state,
         fsplit_rotary,
-        fattention_rotary,
         fcopy_cache,
     )
     return cache
@@ -156,17 +170,26 @@ class RopeMode(enum.IntEnum):
 
 
 @pytest.fixture(
-    params=itertools.product(
-        [64, 128],
-        ["float16", "float32"],
-        [RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE],
+    params=itertools.chain(
+        itertools.product(
+            [64, 128],
+            ["float16", "float32"],
+            [RopeMode.NORMAL],
+            [False],
+        ),
+        itertools.product(
+            [128],
+            ["float16"],
+            [RopeMode.NONE, RopeMode.INLINE],
+            [False, True],
+        ),
     )
 )
-def kv_cache_and_rope_mode(request):
+def kv_cache_and_config(request):
     global head_dim, dtype
-    head_dim, dtype, rope_mode = request.param
+    head_dim, dtype, rope_mode, support_sliding_window = request.param
     set_global_func(head_dim, dtype)
-    return create_kv_cache(*request.param), rope_mode
+    return create_kv_cache(*request.param), rope_mode, support_sliding_window
 
 
 def verify_cached_kv(kv_cache, seq_ids, expected_k, expected_v):
@@ -206,7 +229,8 @@ def apply_attention(
     batch: List[Tuple[Union[int, Tuple[int, int]], int]],
     cached_k: Dict[int, np.ndarray],
     cached_v: Dict[int, np.ndarray],
-    fuse_qkv: bool,
+    sliding_window_sizes: Optional[List[int]] = None,
+    attn_sink_sizes: Optional[List[int]] = None,
 ) -> None:
     seq_ids = []
     append_lengths = []
@@ -270,16 +294,9 @@ def apply_attention(
         queries_np = global_new_q[layer_id]
         keys_np = global_new_k[layer_id]
         values_np = global_new_v[layer_id]
-        if not fuse_qkv:
-            queries = tvm.nd.array(queries_np, device=device)
-            keys = tvm.nd.array(keys_np, device=device)
-            values = tvm.nd.array(values_np, device=device)
-            outputs = tvm.nd.empty(queries.shape, dtype, device=device)
-            fattention(kv_cache, layer_id, 1.0, queries, keys, values, outputs)
-        else:
-            qkv = tvm.nd.array(np.concatenate([queries_np, keys_np, values_np], axis=1), device)
-            outputs = tvm.nd.empty(queries_np.shape, dtype, device=device)
-            fattention_with_fuse_qkv(kv_cache, layer_id, 1.0, qkv, outputs)
+        qkv = tvm.nd.array(np.concatenate([queries_np, keys_np, values_np], axis=1), device)
+        outputs = tvm.nd.empty(queries_np.shape, dtype, device=device)
+        fattention_with_fuse_qkv(kv_cache, layer_id, 1.0, qkv, outputs)
 
         # Compute attention expected results.
         outputs = np.expand_dims(outputs.numpy(), axis=0)
@@ -332,15 +349,40 @@ def apply_attention(
             sum_length += append_length
     fend_forward(kv_cache)
 
+    for seq_id, _ in batch:
+        if sliding_window_sizes is not None and len(sliding_window_sizes) > seq_id:
+            sliding_window_size = sliding_window_sizes[seq_id]
+            attn_sink_size = attn_sink_sizes[seq_id]
+            if cached_k[seq_id].shape[1] > sliding_window_size:
+                # Apply sliding window and sink to cached kv.
+                length_to_slide = cached_k[seq_id].shape[1] - sliding_window_size
+                cached_k[seq_id] = np.concatenate(
+                    [
+                        cached_k[seq_id][:, :attn_sink_size, ...],
+                        cached_k[seq_id][:, attn_sink_size + length_to_slide :, ...],
+                    ],
+                    axis=1,
+                )
+                cached_v[seq_id] = np.concatenate(
+                    [
+                        cached_v[seq_id][:, :attn_sink_size, ...],
+                        cached_v[seq_id][:, attn_sink_size + length_to_slide :, ...],
+                    ],
+                    axis=1,
+                )
+                assert cached_k[seq_id].shape[1] == sliding_window_size
+
     # Verify
     verify_cached_kv(kv_cache, seq_ids, cached_k, cached_v)
 
 
 @tvm.testing.requires_gpu
 @tvm.testing.requires_cuda
-@pytest.mark.parametrize("fuse_qkv", [False, True])
-def test_paged_attention_kv_cache_prefill_and_decode(kv_cache_and_rope_mode, fuse_qkv):
-    kv_cache, rope_mode = kv_cache_and_rope_mode
+def test_paged_attention_kv_cache_prefill_and_decode(kv_cache_and_config):
+    kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
+    if support_sliding_window and rope_mode == RopeMode.NORMAL:
+        # Normal RoPE mode under sliding window settings is not supported.
+        return
     fclear(kv_cache)
 
     # Prefill.
@@ -356,14 +398,16 @@ def test_paged_attention_kv_cache_prefill_and_decode(kv_cache_and_rope_mode, fus
     cached_k = {}
     cached_v = {}
     for batch in operation_seq:
-        apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v, fuse_qkv)
+        apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
 
 
 @tvm.testing.requires_gpu
 @tvm.testing.requires_cuda
-@pytest.mark.parametrize("fuse_qkv", [False, True])
-def test_paged_attention_kv_cache_remove_sequence(kv_cache_and_rope_mode, fuse_qkv):
-    kv_cache, rope_mode = kv_cache_and_rope_mode
+def test_paged_attention_kv_cache_remove_sequence(kv_cache_and_config):
+    kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
+    if support_sliding_window and rope_mode == RopeMode.NORMAL:
+        # Normal RoPE mode under sliding window settings is not supported.
+        return
     fclear(kv_cache)
 
     num_sequences = 5
@@ -371,7 +415,7 @@ def test_paged_attention_kv_cache_remove_sequence(kv_cache_and_rope_mode, fuse_q
     cached_k = {}
     cached_v = {}
     for seq_id_to_remove in range(num_sequences):
-        apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v, fuse_qkv)
+        apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
         # Remove sequence.
         fremove_sequence(kv_cache, seq_id_to_remove)
         cached_k.pop(seq_id_to_remove)
@@ -386,22 +430,24 @@ def test_paged_attention_kv_cache_remove_sequence(kv_cache_and_rope_mode, fuse_q
 
 @tvm.testing.requires_gpu
 @tvm.testing.requires_cuda
-@pytest.mark.parametrize("fuse_qkv", [False, True])
-def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_rope_mode, fuse_qkv):
-    kv_cache, rope_mode = kv_cache_and_rope_mode
+def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config):
+    kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
+    if support_sliding_window and rope_mode == RopeMode.NORMAL:
+        # Normal RoPE mode under sliding window settings is not supported.
+        return
     fclear(kv_cache)
 
     cached_k = {}
     cached_v = {}
     batch = [(0, 60), (1, 88), (2, 17), (3, 4)]
-    apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v, fuse_qkv)
+    apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
     # Fork existing sequences.
-    apply_attention(kv_cache, rope_mode, [((4, 3), 35)], cached_k, cached_v, fuse_qkv)
-    apply_attention(kv_cache, rope_mode, [((5, 0), 20)], cached_k, cached_v, fuse_qkv)
-    apply_attention(kv_cache, rope_mode, [((6, 5), 102)], cached_k, cached_v, fuse_qkv)
-    apply_attention(kv_cache, rope_mode, [((7, 0), 3)], cached_k, cached_v, fuse_qkv)
-    apply_attention(kv_cache, rope_mode, [((8, 5), 71)], cached_k, cached_v, fuse_qkv)
-    apply_attention(kv_cache, rope_mode, [((9, 5), 20)], cached_k, cached_v, fuse_qkv)
+    apply_attention(kv_cache, rope_mode, [((4, 3), 35)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((5, 0), 20)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((6, 5), 102)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((7, 0), 3)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((8, 5), 71)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((9, 5), 20)], cached_k, cached_v)
     # Mixture of decode and prefill.
     operation_seq = [
         [(2, 1), (4, 1), (7, 1), (6, 1), (8, 1), (9, 1)],
@@ -410,7 +456,7 @@ def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_rope_mode, fuse_qkv
         [(7, 10), (6, 2), (8, 3), (9, 4)],
     ]
     for batch in operation_seq:
-        apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v, fuse_qkv)
+        apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
 
     for i in range(9, -1, -1):
         fremove_sequence(kv_cache, i)
@@ -421,16 +467,17 @@ def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_rope_mode, fuse_qkv
 
 @tvm.testing.requires_gpu
 @tvm.testing.requires_cuda
-@pytest.mark.parametrize("fuse_qkv", [False, True])
-def test_paged_attention_kv_cache_popn(kv_cache_and_rope_mode, fuse_qkv):
-    kv_cache, rope_mode = kv_cache_and_rope_mode
+def test_paged_attention_kv_cache_popn(kv_cache_and_config):
+    kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
+    if support_sliding_window and rope_mode == RopeMode.NORMAL:
+        return
     fclear(kv_cache)
 
     cached_k = {}
     cached_v = {}
     batch = [(0, 35), (1, 88), (2, 17), (3, 4)]
-    apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v, fuse_qkv)
-    apply_attention(kv_cache, rope_mode, [((4, 3), 35)], cached_k, cached_v, fuse_qkv)
+    apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((4, 3), 35)], cached_k, cached_v)
 
     popn_operations = [(0, 17), (1, 57), (2, 16), (3, 0)]
     for seq_id, pop_length in popn_operations:
@@ -441,6 +488,83 @@ def test_paged_attention_kv_cache_popn(kv_cache_and_rope_mode, fuse_qkv):
         verify_cached_kv(kv_cache, seq_ids=list(range(4)), expected_k=cached_k, expected_v=cached_v)
 
 
+@tvm.testing.requires_gpu
+@tvm.testing.requires_cuda
+def test_paged_attention_kv_cache_sliding_window(kv_cache_and_config):
+    kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
+    if not support_sliding_window or rope_mode == RopeMode.NORMAL:
+        return
+    fclear(kv_cache)
+
+    cached_k = {}
+    cached_v = {}
+    sliding_window_sizes = [20, 25, 30, 35, 40]
+    attn_sink_sizes = [6, 4, 8, 3, 7]
+    for seq_id, (sliding_window_size, attn_sink_size) in enumerate(
+        zip(sliding_window_sizes, attn_sink_sizes)
+    ):
+        fadd_sequence(kv_cache, seq_id)
+        fenable_sliding_window_for_seq(kv_cache, seq_id, sliding_window_size, attn_sink_size)
+        cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype)
+        cached_v[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype)
+
+    # Prefill.
+    operation_seq = [[(0, 4)], [(1, 6)], [(2, 6), (3, 7), (4, 7)]]
+    operation_seq += [[(0, 20), (1, 19), (2, 30), (3, 35), (4, 40)]]
+    operation_seq += [[(0, 6), (1, 5), (2, 4), (3, 3), (4, 2)]]
+    for batch in operation_seq:
+        apply_attention(
+            kv_cache,
+            rope_mode,
+            batch,
+            cached_k,
+            cached_v,
+            sliding_window_sizes,
+            attn_sink_sizes,
+        )
+    # Decode
+    batch = [(0, 1), (1, 1), (2, 1), (3, 1), (4, 1)]
+    for _ in range(20):
+        apply_attention(
+            kv_cache,
+            rope_mode,
+            batch,
+            cached_k,
+            cached_v,
+            sliding_window_sizes,
+            attn_sink_sizes,
+        )
+
+    # Sliding window with fork
+    sliding_window_sizes += [0, 18]
+    attn_sink_sizes += [0, 12]
+    apply_attention(kv_cache, rope_mode, [(5, 10)], cached_k, cached_v)
+    ffork_sequence(kv_cache, 5, 6)
+    cached_k[6] = cached_k[5]
+    cached_v[6] = cached_v[5]
+    fenable_sliding_window_for_seq(kv_cache, 6, sliding_window_sizes[-1], attn_sink_sizes[-1])
+    for _ in range(2):
+        apply_attention(
+            kv_cache,
+            rope_mode,
+            [(6, 10)],
+            cached_k,
+            cached_v,
+            sliding_window_sizes,
+            attn_sink_sizes,
+        )
+    for _ in range(16):
+        apply_attention(
+            kv_cache,
+            rope_mode,
+            [(6, 1)],
+            cached_k,
+            cached_v,
+            sliding_window_sizes,
+            attn_sink_sizes,
+        )
+
+
 def kv_cache_transpose_append(head_dim, dtype):
     @T.prim_func
     def _kv_cache_transpose_append(
@@ -458,22 +582,23 @@ def kv_cache_transpose_append(head_dim, dtype):
         position_map = T.match_buffer(var_position_map, (ntoken,), "int32")
 
         for global_pos, h, f in T.grid(ntoken, num_kv_heads, head_dim):
-            with T.block("k_transpose_append"):
-                vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
-                T.reads(position_map[vgpos], k_data[vgpos, vh, vf])
-                T.writes(pages[position_map[vgpos] // 16, 0, vh, position_map[vgpos] % 16, vf])
-                position: T.int64 = T.Cast("int64", position_map[vgpos])
-                pages[T.floordiv(position, 16), 0, vh, T.floormod(position, 16), vf] = k_data[
-                    vgpos, vh, vf
-                ]
-            with T.block("v_transpose_append"):
-                vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
-                T.reads(position_map[vgpos], k_data[vgpos, vh, vf])
-                T.writes(pages[position_map[vgpos] // 16, 1, vh, position_map[vgpos] % 16, vf])
-                position: T.int64 = T.Cast("int64", position_map[vgpos])
-                pages[T.floordiv(position, 16), 1, vh, T.floormod(position, 16), vf] = v_data[
-                    vgpos, vh, vf
-                ]
+            if position_map[global_pos] != T.int32(-1):
+                with T.block("k_transpose_append"):
+                    vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
+                    T.reads(position_map[vgpos], k_data[vgpos, vh, vf])
+                    T.writes(pages[position_map[vgpos] // 16, 0, vh, position_map[vgpos] % 16, vf])
+                    position: T.int64 = T.Cast("int64", position_map[vgpos])
+                    pages[T.floordiv(position, 16), 0, vh, T.floormod(position, 16), vf] = k_data[
+                        vgpos, vh, vf
+                    ]
+                with T.block("v_transpose_append"):
+                    vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
+                    T.reads(position_map[vgpos], k_data[vgpos, vh, vf])
+                    T.writes(pages[position_map[vgpos] // 16, 1, vh, position_map[vgpos] % 16, vf])
+                    position: T.int64 = T.Cast("int64", position_map[vgpos])
+                    pages[T.floordiv(position, 16), 1, vh, T.floormod(position, 16), vf] = v_data[
+                        vgpos, vh, vf
+                    ]
 
     return _kv_cache_transpose_append
 
@@ -488,7 +613,6 @@ def copy_cache(head_dim, dtype):
         layer_id: T.int64,
     ):
         num_kv_heads = T.int64()
-        head_dim = T.int64()
         seqlen = T.SizeVar("seqlen", "int64")
         page_size = T.int64()
         num_pages = T.int64()
@@ -517,74 +641,6 @@ def copy_cache(head_dim, dtype):
     return _copy_cache
 
 
-def _inplace_rope(
-    theta: float,
-    scale: float,
-    head_dim: int,
-    num_q_heads: int,
-    num_kv_heads: int,
-    dtype: str,
-):
-    rotary_dim = head_dim
-
-    def _rope(
-        x: T.Buffer,
-        s: tir.Var,
-        h: tir.Var,
-        d: tir.Var,
-        rope_offset: tir.Var,
-        instance_offset: tir.Var,
-    ):
-        cos_freq, sin_freq = rope_freq((s + rope_offset) * scale, d, 
rotary_dim, theta, dtype)
-        cos = cos_freq * x[s + instance_offset, h, d]
-        sin = sin_freq * tir.if_then_else(
-            d < rotary_dim // 2,
-            -x[s + instance_offset, h, d + rotary_dim // 2],
-            x[s + instance_offset, h, d - rotary_dim // 2],
-        )
-        return cos + sin
-
-    # fmt: off
-    @T.prim_func
-    def tir_rotary(
-        var_q: T.handle,
-        var_k: T.handle,
-        var_append_len_indptr: T.handle,
-        var_rope_offsets: T.handle,
-        _0: T.int32,
-        _1: T.int32,
-        _2: T.int32,
-        _3: T.int32,
-        _4: T.float32,
-        _5: T.float32,
-    ):
-        T.func_attr({"tir.is_scheduled": 1})
-        total_len = T.int32()
-        batch_size = T.int32()
-        q = T.match_buffer(var_q, (total_len, num_q_heads, head_dim), dtype)
-        k = T.match_buffer(var_k, (total_len, num_kv_heads, head_dim), dtype)
-        rope_offsets = T.match_buffer(var_rope_offsets, (batch_size,), "int32")
-        append_len_indptr = T.match_buffer(var_append_len_indptr, (batch_size 
+ 1,), "int32")
-        for b_h in T.thread_binding(batch_size * (num_q_heads + num_kv_heads), 
thread="blockIdx.x"):
-            b: T.int32 = b_h // (num_q_heads + num_kv_heads)
-            h: T.int32 = b_h % (num_q_heads + num_kv_heads)
-            instance_offset: T.int32 = append_len_indptr[b]
-            rope_offset: T.int32 = rope_offsets[b]
-            append_len: T.int32 = append_len_indptr[b + 1] - 
append_len_indptr[b]
-            for s0 in range(T.ceildiv(append_len, 32)):
-                for s1 in T.thread_binding(32, thread="threadIdx.y"):
-                    for d0 in T.thread_binding(T.ceildiv(head_dim, 4), 
thread="threadIdx.x"):
-                        for d1 in T.vectorized(4):
-                            s: T.int32 = s0 * 32 + s1
-                            d: T.int32 = d0 * 4 + d1
-                            if s < append_len and d < head_dim:
-                                if h < num_q_heads:
-                                    q[s + instance_offset, h, d] = _rope(q, s, 
h, d, rope_offset, instance_offset)
-                                else:
-                                    k[s + instance_offset, h - num_q_heads, d] 
= _rope(k, s, h - num_q_heads, d, rope_offset, instance_offset)
-    return tir_rotary
-
-
 def llama_rope_with_position_map(  # pylint: disable=too-many-arguments
     theta: float,
     scale: float,
@@ -721,6 +777,47 @@ def _var(dtype):
     return T.alloc_buffer((1,), dtype, scope="local")
 
 
+def _causal_mask(causal, row, col, kv_len, qo_len):
+    return T.if_then_else(
+        causal > 0,
+        col < kv_len - qo_len + row + 1,
+        col < kv_len,
+    )
+
+
+def _declare_length_info(var_length_info, batch_size, sliding_window):
+    return (
+        T.match_buffer(var_length_info, (3, batch_size), "int32")
+        if sliding_window
+        else T.match_buffer(var_length_info, (batch_size,), "int32")
+    )
+
+
+def _get_kv_chunk_len(num_pages, page_size, seq_id, length_info, 
sliding_window):
+    if not sliding_window:
+        return (num_pages - 1) * page_size + length_info[seq_id]
+    else:
+        # ((num_pages - 1) * page_size + last_page_len) - 
sliding_window_offset + sink_size
+        return (
+            (num_pages - 1) * page_size
+            + length_info[0, seq_id]
+            - length_info[1, seq_id]
+            + length_info[2, seq_id]
+        )
+
+
+def _get_seq_offset(pos, seq_id, length_info, sliding_window):
+    if not sliding_window:
+        return pos
+    else:
+        # pos if pos < sink_size else pos - sink_size + sliding_window_offset
+        return T.if_then_else(
+            pos < length_info[2, seq_id],
+            pos,
+            pos - length_info[2, seq_id] + length_info[1, seq_id],
+        )
+
+
 def get_max_num_threads_per_block(target: Target):
     """
     max(max_num_threads, max_threads_per_block); if latter does not exist, 
return max_num_threads.
@@ -733,7 +830,9 @@ def get_max_num_threads_per_block(target: Target):
     return max(max_num_threads, max_threads_per_block)
 
 
-def _attention_prefill(h_kv, h_q, d, dtype, target: Target):  # pylint: 
disable=unused-argument
+def _attention_prefill(
+    h_kv, h_q, d, dtype, sliding_window: bool, target: Target
+):  # pylint: disable=unused-argument
     # pylint: disable=invalid-name
     NUM_BLKS = 16
     LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8)  # 8 bytes
@@ -753,13 +852,6 @@ def _attention_prefill(h_kv, h_q, d, dtype, target: 
Target):  # pylint: disable=
         tile_z = 8
         num_warps = 2
 
-    def mask(causal, row, col, kv_len, qo_len):
-        return T.if_then_else(
-            causal > 0,
-            col < kv_len - qo_len + row + 1,
-            col < kv_len,
-        )
-
     # pylint: disable=line-too-long,too-many-arguments,too-many-branches
     # fmt: off
     @T.prim_func
@@ -770,7 +862,7 @@ def _attention_prefill(h_kv, h_q, d, dtype, target: 
Target):  # pylint: disable=
         var_pages: T.handle, # [max_num_pages, 2, h_kv, page_size, d]
         var_page_indptr: T.handle, # [batch_size + 1]
         var_page_values: T.handle, # [nnz_pages]
-        var_last_page_len: T.handle, # [b]
+        var_length_info: T.handle, # [b] when sliding window = False, or 
otherwise [3, b]
         var_k_rope_pos_offset: T.handle, # [b]
         var_q_rope_position: T.handle, # [total_len]
         var_output: T.handle, # [total_len, h_q, d]
@@ -791,11 +883,19 @@ def _attention_prefill(h_kv, h_q, d, dtype, target: 
Target):  # pylint: disable=
         pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), 
dtype)
         page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), 
"int32")
         page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32")
-        last_page_len = T.match_buffer(var_last_page_len, (batch_size,), 
"int32")
         k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, 
(batch_size,), "int32")
         q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), 
"int32")
         output = T.match_buffer(var_output, (total_len, h_q, d), dtype)
         lse = T.match_buffer(var_lse, (total_len, h_q), "float32")  # pylint: 
disable=unused-variable
+        # The length information of the sequences.
+        # - It is in shape `(3, batch_size)` when sliding window is enabled.
+        #   For a sequence "i", location
+        #   - "(0, i)" is the number of KV slots used in the last page of the 
seq ("last_page_len"),
+        #   - "(1, i)" is the starting offset of the sliding window in the seq,
+        #   - "(2, i)" is the attn sink length of the sequence.
+        # - It is in shape `(batch_size,)` when sliding window is disabled,
+        #   denoting the "last_page_len".
+        length_info = _declare_length_info(var_length_info, batch_size, 
sliding_window)
 
         # kernel code
         for lbx in T.thread_binding(NUM_BLKS, thread="blockIdx.x"):
@@ -851,10 +951,9 @@ def _attention_prefill(h_kv, h_q, d, dtype, target: 
Target):  # pylint: disable=
 
                                     cur_page_indptr_begin: T.int32 = 
page_indptr[b_idx]
                                     cur_page_indptr_end: T.int32 = 
page_indptr[b_idx + 1]
-                                    cur_last_page_len: T.int32 = 
last_page_len[b_idx]
                                     kv_chunk_len[0] = T.if_then_else(
                                         cur_page_indptr_begin != 
cur_page_indptr_end,
-                                        (cur_page_indptr_end - 
cur_page_indptr_begin - 1) * 16 + cur_last_page_len,
+                                        _get_kv_chunk_len(cur_page_indptr_end 
- cur_page_indptr_begin, 16, b_idx, length_info, sliding_window),
                                         0
                                     )
                                     T.tvm_storage_sync("shared")
@@ -899,8 +998,9 @@ def _attention_prefill(h_kv, h_q, d, dtype, target: 
Target):  # pylint: disable=
                                                 T.writes()
                                                 cur_L = L_kv_start + i
                                                 if cur_L < kv_chunk_len[0]:
-                                                    page_no: 
T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + 
T.floordiv(cur_L, 16)]  # type: ignore
-                                                    page_offset: 
T.int32(is_size_var=True) = T.floormod(cur_L, 16)  # type: ignore
+                                                    seq_offset: 
T.int32(is_size_var=True) = _get_seq_offset(cur_L, b_idx, length_info, 
sliding_window)  # type: ignore
+                                                    page_no: 
T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + 
T.floordiv(seq_offset, 16)]  # type: ignore
+                                                    page_offset: 
T.int32(is_size_var=True) = T.floormod(seq_offset, 16)  # type: ignore
                                                     K_smem[i, j] = 
T.if_then_else(
                                                         rotary_mode == 1,
                                                         _rope(pages, 
k_rope_pos_offset[b_idx] + cur_L, d, rope_theta, rope_scale, (page_no, 0, by, 
page_offset, j), dtype),
@@ -916,8 +1016,9 @@ def _attention_prefill(h_kv, h_q, d, dtype, target: 
Target):  # pylint: disable=
                                                 T.writes()
                                                 cur_L = L_kv_start + i
                                                 if cur_L < kv_chunk_len[0]:
-                                                    page_no: 
T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + 
T.floordiv(cur_L, 16)]  # type: ignore
-                                                    page_offset: 
T.int32(is_size_var=True) = T.floormod(cur_L, 16)  # type: ignore
+                                                    seq_offset: 
T.int32(is_size_var=True) = _get_seq_offset(cur_L, b_idx, length_info, 
sliding_window)  # type: ignore
+                                                    page_no: 
T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + 
T.floordiv(seq_offset, 16)]  # type: ignore
+                                                    page_offset: 
T.int32(is_size_var=True) = T.floormod(seq_offset, 16)  # type: ignore
                                                     V_smem[i, j] = 
pages[page_no, 1, by, page_offset, j]
                                                 else:
                                                     V_smem[i, j] = 0.0
@@ -947,7 +1048,7 @@ def _attention_prefill(h_kv, h_q, d, dtype, target: 
Target):  # pylint: disable=
                                                     m_new[i] = m_smem[row]
                                                     # mask out of kv_chunk_len 
S
                                                     for j in T.serial(tile_z):
-                                                        if mask(causal,
+                                                        if _causal_mask(causal,
                                                                 row=tile_id[0] 
* L_per_cta + row // group_size,
                                                                 col=L_kv_start 
+ j,
                                                                 
kv_len=kv_chunk_len[0],
@@ -961,7 +1062,7 @@ def _attention_prefill(h_kv, h_q, d, dtype, target: 
Target):  # pylint: disable=
                                                 for j in T.serial(tile_z):
                                                     # this is to avoid sync 
inside condition branch
                                                     if row < tile_x:
-                                                        if mask(causal,
+                                                        if _causal_mask(causal,
                                                                 row=tile_id[0] 
* L_per_cta + row // group_size,
                                                                 col=L_kv_start 
+ j,
                                                                 
kv_len=kv_chunk_len[0],
@@ -1036,7 +1137,7 @@ def _attention_prefill(h_kv, h_q, d, dtype, target: 
Target):  # pylint: disable=
         yo, yi = sch.split(loop_y, factors=[None, tile[1]])
         sch.reorder(xo, yo, xi, yi)
         t = sch.fuse(xo, yo)
-        ty, tx = sch.split(t, factors=[num_warps, bdx])
+        ty, tx = sch.split(t, factors=[None, bdx])
         sch.bind(ty, "threadIdx.y")
         sch.bind(tx, "threadIdx.x")
 
@@ -1048,7 +1149,7 @@ def _attention_prefill(h_kv, h_q, d, dtype, target: 
Target):  # pylint: disable=
         yo, yi = sch.split(loop_y, factors=[None, tile[1]])
         sch.reorder(xo, yo, xi, yi)
         t = sch.fuse(xo, yo)
-        ty, tx = sch.split(t, factors=[num_warps, bdx])
+        ty, tx = sch.split(t, factors=[None, bdx])
         sch.bind(ty, "threadIdx.y")
         sch.bind(tx, "threadIdx.x")
 
@@ -1084,6 +1185,7 @@ def _attention_decode(
     num_qo_heads,
     head_dim,
     qkv_dtype,
+    sliding_window: bool,
     target: Target,  # pylint: disable=unused-argument
 ):
     # pylint: disable=invalid-name
@@ -1092,8 +1194,13 @@ def _attention_decode(
     H_kv = num_kv_heads
     D = head_dim
 
+    THREAD_LIMIT = 512
+    TILE_SIZE_PER_BDX = 2
+    if target.kind.name == "opencl" and "android" in str(target.host):
+        THREAD_LIMIT = 64
+        TILE_SIZE_PER_BDX = 1
     max_num_threads_per_block = get_max_num_threads_per_block(target)
-    thread_limit = min(max_num_threads_per_block, 512)
+    thread_limit = min(max_num_threads_per_block, THREAD_LIMIT)
 
     GROUP_SIZE = H_qo // H_kv
     VEC_SIZE = min(max(8 // qkv_dtype_bytes, D // 32), 4)
@@ -1104,7 +1211,7 @@ def _attention_decode(
     gdz = GROUP_SIZE // bdy
     threads_per_CTA = max(thread_limit, bdx * bdy)
     bdz = threads_per_CTA // (bdx * bdy)
-    tile_size_per_bdx = 2 if GROUP_SIZE == 1 else 1
+    tile_size_per_bdx = TILE_SIZE_PER_BDX if GROUP_SIZE == 1 else 1
     log2e = math.log2(math.exp(1))
 
     # pylint: disable=line-too-long,too-many-arguments,too-many-branches
@@ -1116,7 +1223,7 @@ def _attention_decode(
         pages_handle: T.handle,
         page_table_indptr_handle: T.handle,
         page_table_values_handle: T.handle,
-        last_page_len_handle: T.handle,
+        var_length_info: T.handle, # [b] when sliding window = False, or 
otherwise [3, b]
         k_rope_pos_offset_handle: T.handle,
         q_rope_position_handle: T.handle,
         output_handle: T.handle,
@@ -1139,9 +1246,17 @@ def _attention_decode(
         page_table_values = T.match_buffer(page_table_values_handle, 
(nnz_pages,), "int32")
         k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), 
"int32")
         q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32")
-        last_page_len = T.match_buffer(last_page_len_handle, (B,), "int32")
         output = T.match_buffer(output_handle, (B, H_qo, D), qkv_dtype)
         lse = T.match_buffer(lse_handle, (B, H_qo), "float32")  # pylint: 
disable=unused-variable
+        # The length information of the sequences.
+        # - It is in shape `(3, batch_size)` when sliding window is enabled.
+        #   For a sequence "i", location
+        #   - "(0, i)" is the number of KV slots used in the last page of the 
seq ("last_page_len"),
+        #   - "(1, i)" is the starting offset of the sliding window in the seq,
+        #   - "(2, i)" is the attn sink length of the sequence.
+        # - It is in shape `(batch_size,)` when sliding window is disabled,
+        #   denoting the "last_page_len".
+        length_info = _declare_length_info(var_length_info, B, sliding_window)
 
         sm_scale = 1.0 / math.sqrt(float(D)) * log2e
 
@@ -1177,10 +1292,9 @@ def _attention_decode(
                                 batch_idx: T.int32 = bx
                                 cur_page_indptr_begin: T.int32 = 
page_table_indptr[batch_idx]
                                 cur_page_indptr_end: T.int32 = 
page_table_indptr[batch_idx + 1]
-                                cur_last_page_len: T.int32 = 
last_page_len[batch_idx]
                                 kv_chunk_len[0] = T.if_then_else(
                                     cur_page_indptr_begin != 
cur_page_indptr_end,
-                                    (cur_page_indptr_end - 
cur_page_indptr_begin - 1) * 16 + cur_last_page_len,
+                                    _get_kv_chunk_len(cur_page_indptr_end - 
cur_page_indptr_begin, 16, batch_idx, length_info, sliding_window),
                                     0
                                 )
 
@@ -1203,31 +1317,39 @@ def _attention_decode(
                                     tile_start_g: T.int32(is_size_var=True) = 
((iterator * bdz + tz) * bdy + ty) * tile_size_per_bdx  # type: ignore
                                     # load K from global memory to shared 
memory
                                     for j in T.serial(tile_size_per_bdx):
-                                        row_g: T.int32(is_size_var=True) = 
tile_start_g + j  # type: ignore
-                                        if row_g < kv_chunk_len[0]:
-                                            page_no: T.int32(is_size_var=True) 
= page_table_values[cur_page_indptr_begin + T.floordiv(row_g, 16)]  # type: 
ignore
-                                            page_offset: 
T.int32(is_size_var=True) = T.floormod(row_g, 16)  # type: ignore
-                                            for vec in T.vectorized(VEC_SIZE):
-                                                K_smem[tile_start_s + j, tx * 
VEC_SIZE + vec] = T.if_then_else(
-                                                    rotary_mode == 1,
-                                                    _rope(pages, 
k_rope_pos_offset[batch_idx] + row_g, head_dim, rope_theta, rope_scale, 
(page_no, 0, by, page_offset, tx * VEC_SIZE + vec), qkv_dtype),
-                                                    pages[page_no, 0, by, 
page_offset, tx * VEC_SIZE + vec]
-                                                )
-                                        else:
-                                            for vec in T.vectorized(VEC_SIZE):
-                                                K_smem[tile_start_s + j, tx * 
VEC_SIZE + vec] = 0.0
+                                        with T.block("K_load"):
+                                            T.reads()
+                                            T.writes()
+                                            row_g: T.int32(is_size_var=True) = 
tile_start_g + j  # type: ignore
+                                            if row_g < kv_chunk_len[0]:
+                                                seq_offset: 
T.int32(is_size_var=True) = _get_seq_offset(row_g, batch_idx, length_info, 
sliding_window)  # type: ignore
+                                                page_no: 
T.int32(is_size_var=True) = page_table_values[cur_page_indptr_begin + 
T.floordiv(seq_offset, 16)]  # type: ignore
+                                                page_offset: 
T.int32(is_size_var=True) = T.floormod(seq_offset, 16)  # type: ignore
+                                                for vec in 
T.vectorized(VEC_SIZE):
+                                                    K_smem[tile_start_s + j, 
tx * VEC_SIZE + vec] = T.if_then_else(
+                                                        rotary_mode == 1,
+                                                        _rope(pages, 
k_rope_pos_offset[batch_idx] + row_g, head_dim, rope_theta, rope_scale, 
(page_no, 0, by, page_offset, tx * VEC_SIZE + vec), qkv_dtype),
+                                                        pages[page_no, 0, by, 
page_offset, tx * VEC_SIZE + vec]
+                                                    )
+                                            else:
+                                                for vec in 
T.vectorized(VEC_SIZE):
+                                                    K_smem[tile_start_s + j, 
tx * VEC_SIZE + vec] = 0.0
                                     T.tvm_storage_sync("shared")
                                     # load V from global memory to shared 
memory
                                     for j in T.serial(tile_size_per_bdx):
-                                        row_g: T.int32(is_size_var=True) = 
tile_start_g + j  # type: ignore
-                                        if row_g < kv_chunk_len[0]:
-                                            page_no: T.int32(is_size_var=True) 
= page_table_values[cur_page_indptr_begin + T.floordiv(row_g, 16)]  # type: 
ignore
-                                            page_offset: 
T.int32(is_size_var=True) = T.floormod(row_g, 16)  # type: ignore
-                                            for vec in T.vectorized(VEC_SIZE):
-                                                V_smem[tile_start_s + j, tx * 
VEC_SIZE + vec] = pages[page_no, 1, by, page_offset, tx * VEC_SIZE + vec]
-                                        else:
-                                            for vec in T.vectorized(VEC_SIZE):
-                                                V_smem[tile_start_s + j, tx * 
VEC_SIZE + vec] = 0.0
+                                        with T.block("V_load"):
+                                            T.reads()
+                                            T.writes()
+                                            row_g: T.int32(is_size_var=True) = 
tile_start_g + j  # type: ignore
+                                            if row_g < kv_chunk_len[0]:
+                                                seq_offset: 
T.int32(is_size_var=True) = _get_seq_offset(row_g, batch_idx, length_info, 
sliding_window)  # type: ignore
+                                                page_no: 
T.int32(is_size_var=True) = page_table_values[cur_page_indptr_begin + 
T.floordiv(seq_offset, 16)]  # type: ignore
+                                                page_offset: 
T.int32(is_size_var=True) = T.floormod(seq_offset, 16)  # type: ignore
+                                                for vec in 
T.vectorized(VEC_SIZE):
+                                                    V_smem[tile_start_s + j, 
tx * VEC_SIZE + vec] = pages[page_no, 1, by, page_offset, tx * VEC_SIZE + vec]
+                                            else:
+                                                for vec in 
T.vectorized(VEC_SIZE):
+                                                    V_smem[tile_start_s + j, 
tx * VEC_SIZE + vec] = 0.0
                                     T.tvm_storage_sync("shared")
                                     # compute QK
                                     m_prev[0] = st_m[0]
@@ -1250,10 +1372,9 @@ def _attention_decode(
                                             )
                                             
T.tvm_thread_allreduce(T.uint32(1), S_reduce_local[0], True, t0[0], tx, 
dtype="handle")
 
+                                        S_local[j] = -5e4
                                         if (iterator * bdz + tz) * bdy * 
tile_size_per_bdx + j < kv_chunk_len[0]:
                                             S_local[j] = t0[0]
-                                        else:
-                                            S_local[j] = -5e4
                                         # update st_m
                                         st_m[0] = T.max(st_m[0], S_local[j])
 
@@ -1336,13 +1457,6 @@ def _attention_prefill_ragged(
         tile_z = 8
         num_warps = 2
 
-    def mask(causal, row, col, kv_len, qo_len):
-        return T.if_then_else(
-            causal > 0,
-            col < kv_len - qo_len + row + 1,
-            col < kv_len,
-        )
-
     # fmt: off
     @T.prim_func
     def batch_prefill_ragged_kv(  # pylint: 
disable=too-many-arguments,too-many-branches
@@ -1515,7 +1629,7 @@ def _attention_prefill_ragged(
                                                     m_new[i] = m_smem[row]
                                                     # mask out of kv_chunk_len 
S
                                                     for j in T.serial(tile_z):
-                                                        if mask(causal,
+                                                        if _causal_mask(causal,
                                                                 row=tile_id[0] 
* L_per_cta + row // group_size,
                                                                 col=L_kv_start 
+ j,
                                                                 
kv_len=kv_chunk_len[0],
@@ -1529,7 +1643,7 @@ def _attention_prefill_ragged(
                                                 for j in T.serial(tile_z):
                                                     # this is to avoid sync 
inside condition branch
                                                     if row < tile_x:
-                                                        if mask(causal,
+                                                        if _causal_mask(causal,
                                                                 row=tile_id[0] 
* L_per_cta + row // group_size,
                                                                 col=L_kv_start 
+ j,
                                                                 
kv_len=kv_chunk_len[0],
@@ -1604,7 +1718,7 @@ def _attention_prefill_ragged(
         yo, yi = sch.split(loop_y, factors=[None, tile[1]])
         sch.reorder(xo, yo, xi, yi)
         t = sch.fuse(xo, yo)
-        ty, tx = sch.split(t, factors=[num_warps, bdx])
+        ty, tx = sch.split(t, factors=[None, bdx])
         sch.bind(ty, "threadIdx.y")
         sch.bind(tx, "threadIdx.x")
 
@@ -1616,7 +1730,7 @@ def _attention_prefill_ragged(
         yo, yi = sch.split(loop_y, factors=[None, tile[1]])
         sch.reorder(xo, yo, xi, yi)
         t = sch.fuse(xo, yo)
-        ty, tx = sch.split(t, factors=[num_warps, bdx])
+        ty, tx = sch.split(t, factors=[None, bdx])
         sch.bind(ty, "threadIdx.y")
         sch.bind(tx, "threadIdx.x")
 
@@ -1725,13 +1839,18 @@ def _merge_state_inplace(
 
 
 if __name__ == "__main__":
-    for head_dim in [64, 128]:
-        for dtype in ["float16", "float32"]:
-            for rope_mode in [RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE]:
-                set_global_func(head_dim, dtype)
-                cache = create_kv_cache(head_dim, dtype, rope_mode)
-                for fuse_qkv in [False, True]:
-                    test_paged_attention_kv_cache_prefill_and_decode((cache, 
rope_mode), fuse_qkv)
-                    test_paged_attention_kv_cache_remove_sequence((cache, 
rope_mode), fuse_qkv)
-                    test_paged_attention_kv_cache_fork_sequence((cache, 
rope_mode), fuse_qkv)
-                    test_paged_attention_kv_cache_popn((cache, rope_mode), 
fuse_qkv)
+    HEAD_DIMS = [64, 128]
+    DTYPES = ["float16", "float32"]
+    ROPE_MODES = [RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE]
+    SUPPORT_SLIDING_WINDOW = [False, True]
+    for head_dim, dtype, rope_mode, support_sliding_window in 
itertools.product(
+        HEAD_DIMS, DTYPES, ROPE_MODES, SUPPORT_SLIDING_WINDOW
+    ):
+        set_global_func(head_dim, dtype)
+        cache = create_kv_cache(head_dim, dtype, rope_mode, 
support_sliding_window)
+        cache_and_config = (cache, rope_mode, support_sliding_window)
+        test_paged_attention_kv_cache_prefill_and_decode(cache_and_config)
+        test_paged_attention_kv_cache_remove_sequence(cache_and_config)
+        test_paged_attention_kv_cache_fork_sequence(cache_and_config)
+        test_paged_attention_kv_cache_popn(cache_and_config)
+        test_paged_attention_kv_cache_sliding_window(cache_and_config)

Reply via email to