This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 98d5153918 [Unity] PagedKVCache supporting on-the-fly RoPE calculation (#16396)
98d5153918 is described below

commit 98d5153918616668a84b451cb23a61d3b7fe0839
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Jan 14 22:33:39 2024 -0500

    [Unity] PagedKVCache supporting on-the-fly RoPE calculation (#16396)
    
    This PR enhances PagedKVCache with inline RoPE computation,
    which unblocks the move towards sliding window and attention
    sink.
    
    Both FlashInfer and TIR kernels are updated in this PR with
    the RoPE calculation. Note that FlashInfer is bumped in order
    to include the RoPE update.
    
    The previous standalone kernel used for RoPE application
    is thereby removed.
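
    To illustrate what "on-the-fly" means here: the cache keeps
    unrotated k/v, and the attention kernels rotate q and k at
    attention time from each token's absolute position. Below is a
    minimal NumPy sketch of that rotation (illustrative only, not
    code from this commit; it assumes the half-rotation layout used
    by the `rope_freq`/`_rope` helpers in the updated tests):

        import numpy as np

        def apply_rope(x, pos, scale=1.0, theta=1e4):
            # Illustrative sketch, not part of this commit.
            # Rotate one head vector x of shape [head_dim] for absolute position pos.
            d = x.shape[0]
            # One frequency per dimension pair (x[i] pairs with x[i + d // 2]).
            freq = (pos * scale) / np.power(theta, 2.0 * np.arange(d // 2) / d)
            cos, sin = np.cos(freq), np.sin(freq)
            lower, upper = x[: d // 2], x[d // 2:]
            return np.concatenate([lower * cos - upper * sin,
                                   upper * cos + lower * sin])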
    
    ---
    
    Co-authored-by: Bohan Hou <[email protected]>
    Co-authored-by: Hongyi Jin <[email protected]>
---
 3rdparty/flashinfer                                |   2 +-
 src/runtime/relax_vm/paged_kv_cache.cc             | 131 ++--
 ...builtin_paged_attention_kv_cache_flashinfer.py} |  22 +-
 ...runtime_builtin_paged_attention_kv_cache_tir.py | 700 ++++++++++++++++-----
 4 files changed, 635 insertions(+), 220 deletions(-)

diff --git a/3rdparty/flashinfer b/3rdparty/flashinfer
index 7d3a47310a..9cd1f42e96 160000
--- a/3rdparty/flashinfer
+++ b/3rdparty/flashinfer
@@ -1 +1 @@
-Subproject commit 7d3a47310af1ac0795e0d8e8435e42c882c96a13
+Subproject commit 9cd1f42e968a8de7d3af2c7567072e0ad6c8ffed
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index 8823467873..20e68a9d33 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -70,6 +70,8 @@ struct Block {
   std::vector<int32_t> page_ids;
   /*! \brief The total sequence length in the block. */
   int32_t seq_length = 0;
+  /*! \brief The start position in sequence of this block. */
+  int32_t start_pos = 0;
 
   /*! \brief The global index of the block. */
   const int32_t index;
@@ -236,14 +238,18 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
   std::vector<NDArray> page_indices_on_depths_device_;
   /*! \brief The number of KV slots used in the last page of sequences. */
   std::vector<NDArray> last_page_len_on_depths_device_;
+  /*! \brief The k position offset of applying RoPE for each sequence. */
+  std::vector<NDArray> k_rope_pos_offset_device_;
   /*!
    * \brief The append length indptr array on device.
    * \note Since the Q/K/V data may have raggedness in terms of lengths,
   * we represent the append lengths in CSR format.
    */
   NDArray cur_append_length_indptr_device_;
-  /*! \brief The position offset of applying RoPE for each sequence. */
-  NDArray cur_rope_offset_device_;
+  /*! \brief The k position offset of applying RoPE for each sequence. */
+  NDArray k_ragged_rope_pos_offset_device_;
+  /*! \brief The q position mapping of applying RoPE for each sequence. */
+  NDArray q_rope_position_map_device_;
   /*!
    * \brief The corresponding position in global KV cache (pages)
    * for each position along the length dimension of K/V data when
@@ -264,7 +270,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
   // attention/append.
   //-------------------------------------------
   NDArray cur_append_length_indptr_view_;
-  NDArray cur_rope_offset_view_;
+  NDArray k_ragged_rope_pos_offset_view_;
+  NDArray q_rope_position_map_view_;
   NDArray append_position_map_view_;
   NDArray temp_attn_output_view_;
   NDArray temp_attn_scores_view_;
@@ -273,6 +280,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
   std::vector<NDArray> page_indptr_on_depths_view_;
   std::vector<NDArray> page_indices_on_depths_view_;
   std::vector<NDArray> last_page_len_on_depths_view_;
+  std::vector<NDArray> k_rope_pos_offset_view_;
 
   PackedFunc f_transpose_append_;
   PackedFunc f_attention_prefill_;
@@ -284,7 +292,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
   Optional<PackedFunc> f_attention_prefill_end_forward_;
   Optional<PackedFunc> f_attention_decode_begin_forward_;
   Optional<PackedFunc> f_attention_decode_end_forward_;
-  PackedFunc f_rotary_;
   Optional<PackedFunc> f_merge_inplace_;
   Optional<PackedFunc> f_debug_get_kv_;
 
@@ -312,7 +319,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
                                    Optional<PackedFunc> f_attention_prefill_end_forward,
                                    Optional<PackedFunc> f_attention_decode_begin_forward,
                                    Optional<PackedFunc> f_attention_decode_end_forward,
-                                    PackedFunc f_rotary, Optional<PackedFunc> f_merge_inplace,
+                                    Optional<PackedFunc> f_merge_inplace,
                                     Optional<PackedFunc> f_debug_get_kv)
       : page_size_(page_size),
         num_layers_(num_layers),
@@ -333,7 +340,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
         f_attention_prefill_end_forward_(std::move(f_attention_prefill_end_forward)),
         f_attention_decode_begin_forward_(std::move(f_attention_decode_begin_forward)),
         f_attention_decode_end_forward_(std::move(f_attention_decode_end_forward)),
-        f_rotary_(std::move(f_rotary)),
         f_merge_inplace_(std::move(f_merge_inplace)),
         f_debug_get_kv_(std::move(f_debug_get_kv)) {
     pages_.reserve(num_layers);
@@ -350,13 +356,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
           NDArray::Empty({num_total_pages}, dtype_aux_, device));
       last_page_len_on_depths_device_.push_back(
           NDArray::Empty({reserved_num_seqs}, dtype_aux_, device));
+      k_rope_pos_offset_device_.push_back(NDArray::Empty({reserved_num_seqs}, dtype_aux_, device));
       qo_indptr_on_depths_view_.push_back(NDArray());
       page_indptr_on_depths_view_.push_back(NDArray());
       page_indices_on_depths_view_.push_back(NDArray());
       last_page_len_on_depths_view_.push_back(NDArray());
+      k_rope_pos_offset_view_.push_back(NDArray());
     }
     cur_append_length_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device);
-    cur_rope_offset_device_ = NDArray::Empty({reserved_num_seqs}, dtype_aux_, device);
+    k_ragged_rope_pos_offset_device_ = NDArray::Empty({reserved_num_seqs}, dtype_aux_, device);
+    q_rope_position_map_device_ = NDArray::Empty({num_total_pages * page_size}, dtype_aux_, device);
     append_position_map_device_ = NDArray::Empty({num_total_pages * page_size}, dtype_aux_, device);
 
     temp_attn_output_device_ =
@@ -428,6 +437,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
     int32_t parent_block_idx = parent_it->second.last_block_idx;
     // Create a child block with the parent block pointer.
     int32_t child_block_idx = GetFreeBlock();
+    global_block_pool_[child_block_idx].start_pos = parent_it->second.seq_length;
     global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
     // Create the child sequence with the child block.
     seq_map_.insert({child_seq_id, Sequence(global_block_pool_, child_block_idx)});
@@ -471,16 +481,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
 
     // - Collect sequence/block/page information for attention.
     std::vector<const Sequence*> sequences;
-    std::vector<int32_t> rope_offset;
+    std::vector<int32_t> k_ragged_rope_pos_offset;
     is_decode_request_ = true;
     sequences.reserve(cur_batch_size_);
-    rope_offset.reserve(cur_batch_size_);
+    k_ragged_rope_pos_offset.reserve(cur_batch_size_);
     for (int i = 0; i < cur_batch_size_; ++i) {
       auto it = seq_map_.find(seq_ids[i]);
       CHECK(it != seq_map_.end()) << "The sequence \"" << seq_ids[i]
                                   << "\" cannot be found in KV cache.";
       sequences.push_back(&it->second);
-      rope_offset.push_back(it->second.seq_length);
+      k_ragged_rope_pos_offset.push_back(it->second.seq_length);
       it->second.seq_length += append_lengths[i];
       if (append_lengths[i] != 1) {
         is_decode_request_ = false;
@@ -504,6 +514,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
     std::vector<std::vector<int32_t>> page_indptr_on_depths;
     std::vector<std::vector<int32_t>> page_indices_on_depths;
     std::vector<std::vector<int32_t>> last_page_len_on_depths;
+    std::vector<std::vector<int32_t>> k_rope_pos_offset_on_depths;
     use_decode_kernel_.clear();
     for (int d = 0; d < num_depths_; ++d) {
       auto [chunked_block_ids, use_decode_kernel] = GetChunkedBlockIds(block_ids_on_depths[d]);
@@ -513,23 +524,27 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
       std::vector<int32_t> page_indptr_h{0};
       std::vector<int32_t> page_indices_h;
       std::vector<int32_t> last_page_len_h;
+      std::vector<int32_t> k_rope_pos_offset_h;
       for (const auto& [block_id, chunk_append_length] : chunked_block_ids) {
         qo_indptr_h.push_back(qo_indptr_h.back() + chunk_append_length);
         if (block_id == -1) {
           page_indptr_h.push_back(page_indptr_h.back());
           last_page_len_h.push_back(0);
+          k_rope_pos_offset_h.push_back(0);
         } else {
           const Block& block = global_block_pool_[block_id];
           page_indptr_h.push_back(page_indptr_h.back() + block.page_ids.size());
           page_indices_h.insert(page_indices_h.end(), block.page_ids.begin(), block.page_ids.end());
           last_page_len_h.push_back(
               block.seq_length == 0 ? 0 : (block.seq_length - 1) % page_size_ + 1);
+          k_rope_pos_offset_h.push_back(block.start_pos);
         }
       }
       qo_indptr_on_depths.push_back(qo_indptr_h);
       page_indptr_on_depths.push_back(page_indptr_h);
       page_indices_on_depths.push_back(page_indices_h);
       last_page_len_on_depths.push_back(last_page_len_h);
+      k_rope_pos_offset_on_depths.push_back(k_rope_pos_offset_h);
     }
 
     if (num_depths_ > 1) {
@@ -543,21 +558,25 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
 
     // Map each token position in the input batch to its position
     // in the global KV cache. The mapping is used when appending k/v values.
+    std::vector<int32_t> q_rope_position_map;
     std::vector<int32_t> append_position_map;
     for (int i = 0; i < cur_batch_size_; ++i) {
       int64_t append_length = append_lengths[i];
       const Block& block = global_block_pool_[sequences[i]->last_block_idx];
       for (int64_t pos = 0; pos < append_length; ++pos) {
-        int64_t pos_in_seq = block.seq_length - append_length + pos;
-        append_position_map.push_back(block.page_ids[pos_in_seq / page_size_] * page_size_ +
-                                      pos_in_seq % page_size_);
+        int64_t pos_in_block = block.seq_length - append_length + pos;
+        q_rope_position_map.push_back(sequences[i]->seq_length - append_length + pos);
+        append_position_map.push_back(block.page_ids[pos_in_block / page_size_] * page_size_ +
+                                      pos_in_block % page_size_);
       }
     }
 
     // - Sync NDArrays to GPU.
     SyncAuxArrayToDevice(std::move(qo_indptr_on_depths), std::move(page_indptr_on_depths),
                          std::move(page_indices_on_depths), std::move(last_page_len_on_depths),
-                         std::move(rope_offset), std::move(append_position_map));
+                         std::move(k_rope_pos_offset_on_depths),
+                         std::move(k_ragged_rope_pos_offset), std::move(q_rope_position_map),
+                         std::move(append_position_map));
 
     // NOTE(Zihao): This logic is problematic ATM because we need a unique split per depth
     KernelBeginForward();
@@ -643,14 +662,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
         << "The auxiliary arrays are not synchronized to device. Please call "
            "`BeginForward` to synchronize before calling `Attention`.";
 
-    // Part 2: apply rotary embedding to q/k data.
-    f_rotary_(q_data, k_data, cur_append_length_indptr_view_, cur_rope_offset_view_,
-              cur_batch_size_, num_qo_heads_, num_kv_heads_, head_dim_, /*qkv_layout=*/0,
-              rotary_scale_, rotary_theta_);
-
-    // Part 3: append k/v data to kv-cache
+    // Part 2: append k/v data to kv-cache
     f_transpose_append_(pages_[layer_id], k_data, v_data, append_position_map_view_);
-    // Part 4: perform attention
+    // Part 3: perform attention
     AttentionInternal(layer_id, q_data, k_data, v_data, o_data);
   }
 
@@ -865,7 +879,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
       if (use_decode_kernel_[0]) {
         f_attention_decode_begin_forward_.value()(
             /*depth=*/0, page_indptr_on_depths_view_[0], last_page_len_on_depths_view_[0],
-            num_qo_heads_, num_kv_heads_, head_dim_, page_size_, /*rotary_mode=*/true);
+            num_qo_heads_, num_kv_heads_, head_dim_, page_size_, /*rotary_mode=*/1);
       } else {
         f_attention_prefill_begin_forward_.value()(/*depth=*/0, qo_indptr_on_depths_view_[0],
                                                    cur_batch_size_, num_qo_heads_, num_kv_heads_);
@@ -880,7 +894,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
         if (use_decode_kernel_[d]) {
           f_attention_decode_begin_forward_.value()(
               d, page_indptr_on_depths_view_[d], last_page_len_on_depths_view_[d], num_qo_heads_,
-              num_kv_heads_, head_dim_, page_size_, /*rotary_mode=*/false);
+              num_kv_heads_, head_dim_, page_size_, /*rotary_mode=*/1);
         } else {
           f_attention_prefill_begin_forward_.value()(/*depth=*/d, qo_indptr_on_depths_view_[d],
                                                      last_page_len_on_depths_view_[d]->shape[0],
@@ -901,22 +915,25 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
       if (use_decode_kernel_[0]) {
         f_attention_decode_(/*depth=*/0, q_data, pages_[layer_id], page_indptr_on_depths_view_[0],
                             page_indices_on_depths_view_[0], last_page_len_on_depths_view_[0],
-                            output, merged_attn_scores_view_,
-                            /*rotary_mode=*/0, rotary_scale_, rotary_theta_);
+                            k_rope_pos_offset_view_[0], q_rope_position_map_view_, output,
+                            merged_attn_scores_view_,
+                            /*rotary_mode=*/1, rotary_scale_, rotary_theta_);
       } else {
         f_attention_prefill_(/*depth=*/0, q_data, qo_indptr_on_depths_view_[0], pages_[layer_id],
                              page_indptr_on_depths_view_[0], page_indices_on_depths_view_[0],
-                             last_page_len_on_depths_view_[0], output, merged_attn_scores_view_,
+                             last_page_len_on_depths_view_[0], k_rope_pos_offset_view_[0],
+                             q_rope_position_map_view_, output, merged_attn_scores_view_,
                              /*causal=*/1,
-                             /*rotary_mode=*/0, rotary_scale_, rotary_theta_);
+                             /*rotary_mode=*/1, rotary_scale_, rotary_theta_);
       }
     } else {
       // Compute appended text self-attention
       f_attention_prefill_ragged_.value()(q_data, cur_append_length_indptr_view_, k_data, v_data,
-                                          cur_append_length_indptr_view_, output,
+                                          cur_append_length_indptr_view_, q_rope_position_map_view_,
+                                          k_ragged_rope_pos_offset_view_, output,
                                           merged_attn_scores_view_,
                                           /*causal=*/1,
-                                          /*rotary_mode=*/0, rotary_scale_, rotary_theta_);
+                                          /*rotary_mode=*/1, rotary_scale_, rotary_theta_);
 
       for (int d = 0; d < num_depths_; ++d) {
         if (page_indices_on_depths_view_[d]->shape[0] == 0) {
@@ -926,16 +943,18 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
           // Use decode kernel for depth d
           f_attention_decode_(/*depth=*/d, q_data, pages_[layer_id], page_indptr_on_depths_view_[d],
                               page_indices_on_depths_view_[d], last_page_len_on_depths_view_[d],
+                              k_rope_pos_offset_view_[d], q_rope_position_map_view_,
                               temp_attn_output_view_, temp_attn_scores_view_,
-                              /*rotary_mode=*/0, rotary_scale_, rotary_theta_);
+                              /*rotary_mode=*/1, rotary_scale_, rotary_theta_);
         } else {
           // Use prefill kernel for depth d
           f_attention_prefill_(/*depth=*/d, q_data, qo_indptr_on_depths_view_[d], pages_[layer_id],
                                page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d],
-                               last_page_len_on_depths_view_[d], temp_attn_output_view_,
+                               last_page_len_on_depths_view_[d], k_rope_pos_offset_view_[d],
+                               q_rope_position_map_view_, temp_attn_output_view_,
                                temp_attn_scores_view_,
                                /*causal=*/0,
-                               /*rotary_mode=*/0, rotary_scale_, rotary_theta_);
+                               /*rotary_mode=*/1, rotary_scale_, rotary_theta_);
         }
         f_merge_inplace_.value()(output, merged_attn_scores_view_, temp_attn_output_view_,
                                  temp_attn_scores_view_);
@@ -952,7 +971,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
                             std::vector<std::vector<int32_t>> page_indptr_on_depths,
                             std::vector<std::vector<int32_t>> page_indices_on_depths,
                             std::vector<std::vector<int32_t>> last_page_len_on_depths,
-                            std::vector<int32_t> rope_offset,
+                            std::vector<std::vector<int32_t>> k_rope_pos_offset_on_depths,
+                            std::vector<int32_t> k_ragged_rope_pos_offset,
+                            std::vector<int32_t> q_rope_position_map,
                             std::vector<int32_t> append_position_map) {
     ICHECK(dtype_aux_.bits == 32 && dtype_aux_.code == kDLInt);
     ICHECK_EQ(qo_indptr_on_depths.size(), num_depths_);
@@ -1015,22 +1036,37 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
       fcopy_from_vec(last_page_len_on_depths_view_[d], last_page_len_on_depths[d].data());
     }
 
-    // 5. cur_append_lengths_indptr
+    // 5. k_rope_pos_offset
+    for (int d = 0; d < num_depths_; ++d) {
+      ICHECK_EQ(k_rope_pos_offset_on_depths[d].size() + 1, qo_indptr_on_depths[d].size());
+      k_rope_pos_offset_view_[d] = k_rope_pos_offset_device_[d].CreateView(
+          {static_cast<int64_t>(k_rope_pos_offset_on_depths[d].size())}, dtype_aux_);
+      fcopy_from_vec(k_rope_pos_offset_view_[d], k_rope_pos_offset_on_depths[d].data());
+    }
+
+    // 6. cur_append_lengths_indptr
     cur_append_length_indptr_view_ =
         cur_append_length_indptr_device_.CreateView({num_sequences + 1}, dtype_aux_);
     fcopy_from_vec(cur_append_length_indptr_view_, cur_append_lengths_indptr.data());
 
-    // 6. cur_rope_offset
-    ICHECK_EQ(rope_offset.size(), num_sequences);
-    cur_rope_offset_view_ = cur_rope_offset_device_.CreateView({num_sequences}, dtype_aux_);
-    fcopy_from_vec(cur_rope_offset_view_, rope_offset.data());
+    // 7. k_ragged_rope_pos_offset
+    ICHECK_EQ(k_ragged_rope_pos_offset.size(), num_sequences);
+    k_ragged_rope_pos_offset_view_ =
+        k_ragged_rope_pos_offset_device_.CreateView({num_sequences}, dtype_aux_);
+    fcopy_from_vec(k_ragged_rope_pos_offset_view_, k_ragged_rope_pos_offset.data());
+
+    // 8. q_rope_position_map
+    ICHECK_EQ(q_rope_position_map.size(), total_append_length);
+    q_rope_position_map_view_ =
+        q_rope_position_map_device_.CreateView({total_append_length}, dtype_aux_);
+    fcopy_from_vec(q_rope_position_map_view_, q_rope_position_map.data());
 
-    // 7. append_position_map
+    // 9. append_position_map
     append_position_map_view_ =
         append_position_map_device_.CreateView({total_append_length}, dtype_aux_);
     fcopy_from_vec(append_position_map_view_, append_position_map.data());
 
-    // 8. Create view for temporary arrays for attention computation.
+    // 10. Create view for temporary arrays for attention computation.
     temp_attn_output_view_ = temp_attn_output_device_.CreateView(
         {total_append_length, num_qo_heads_, head_dim_}, temp_attn_output_device_->dtype);
     temp_attn_scores_view_ = temp_attn_scores_device_.CreateView(
@@ -1065,8 +1101,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
                        PackedFunc f_attention_prefill_begin_forward,
                        PackedFunc f_attention_prefill_end_forward,
                        PackedFunc f_attention_decode_begin_forward,
-                       PackedFunc f_attention_decode_end_forward, PackedFunc f_rotary,
-                       PackedFunc f_merge_inplace, Optional<PackedFunc> f_debug_get_kv) {
+                       PackedFunc f_attention_decode_end_forward, PackedFunc f_merge_inplace,
+                       Optional<PackedFunc> f_debug_get_kv) {
       CHECK_EQ(cache_config.size(), 3);
       int64_t reserved_num_seqs = cache_config[0];
       int64_t total_token_capacity = cache_config[1];
@@ -1081,7 +1117,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
           std::move(f_attention_prefill_ragged_end_forward),
           std::move(f_attention_prefill_begin_forward), std::move(f_attention_prefill_end_forward),
           std::move(f_attention_decode_begin_forward), std::move(f_attention_decode_end_forward),
-          std::move(f_rotary), std::move(f_merge_inplace), std::move(f_debug_get_kv));
+          std::move(f_merge_inplace), std::move(f_debug_get_kv));
+          std::move(f_merge_inplace), std::move(f_debug_get_kv));
       return PagedAttentionKVCache(std::move(n));
     });
 
@@ -1090,7 +1126,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
                        int64_t num_kv_heads, int64_t head_dim, double rotary_scale,
                        double rotary_theta, NDArray init, PackedFunc f_transpose_append,
                        PackedFunc f_attention_prefill, PackedFunc f_attention_decode,
-                       PackedFunc f_rotary, Optional<PackedFunc> f_debug_get_kv) {
+                       PackedFunc f_attention_prefill_ragged, PackedFunc f_merge_inplace,
+                       Optional<PackedFunc> f_debug_get_kv) {
       CHECK_EQ(cache_config.size(), 3);
       int64_t reserved_num_seqs = cache_config[0];
       int64_t total_token_capacity = cache_config[1];
@@ -1100,9 +1137,9 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
           page_size, num_layers, num_qo_heads, num_kv_heads, head_dim, reserved_num_seqs,
           num_total_pages, rotary_scale, rotary_theta, init->dtype, init->device,
           std::move(f_transpose_append), std::move(f_attention_prefill),
-          std::move(f_attention_decode),                                  //
-          NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt,  //
-          std::move(f_rotary), NullOpt, std::move(f_debug_get_kv));
+          std::move(f_attention_decode), std::move(f_attention_prefill_ragged),  //
+          NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt,                  //
+          std::move(f_merge_inplace), std::move(f_debug_get_kv));
       return PagedAttentionKVCache(std::move(n));
     });
 
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
similarity index 95%
rename from tests/python/relax/test_runtime_builtin_paged_attention_kv_cache.py
rename to tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
index 94507d3931..69b7a15793 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
@@ -59,7 +59,6 @@ fattention_decode_end_forward = None
 fattention_prefill_ragged_begin_forward = None
 fattention_prefill_ragged_end_forward = None
 fattention_merge_state = None
-fattention_rotary = None
 
 
 @T.prim_func
@@ -145,7 +144,7 @@ def set_global_func():
     global fattention_prefill_ragged
     global fattention_prefill_ragged_begin_forward
     global fattention_prefill_ragged_end_forward
-    global fattention_merge_state, fattention_rotary
+    global fattention_merge_state
 
     fclear = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_clear")
     fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create")
@@ -182,7 +181,6 @@ def set_global_func():
         "flashinfer.attention_kernel_prefill_with_ragged_kv_cache_end_forward"
     )
     fattention_merge_state = tvm.get_global_func("flashinfer.merge_state_in_place")
-    fattention_rotary = tvm.get_global_func("flashinfer.batch_qk_apply_rotary_in_place")
 
 
 def create_kv_cache():
@@ -216,7 +214,6 @@ def create_kv_cache():
         fattention_prefill_end_forward,
         fattention_decode_begin_forward,
         fattention_decode_end_forward,
-        fattention_rotary,
         fattention_merge_state,
         fcopy_cache,
     )
@@ -303,13 +300,7 @@ def apply_attention(
         cached_k[seq_id] = np.concatenate(
             [
                 cached_k[seq_id],
-                np.stack(
-                    [
-                        f_apply_rotary(new_k[l], cached_k[seq_id].shape[1], rope_scale, rope_theta)
-                        for l in range(num_layers)
-                    ],
-                    axis=0,
-                ),
+                np.stack([new_k[l] for l in range(num_layers)], axis=0),
             ],
             axis=1,
         )
@@ -347,12 +338,9 @@ def apply_attention(
                 rope_scale,
                 rope_theta,
             ).transpose(1, 0, 2)
-            # Todo(Zihao, Ruihang): fold RoPE into flashinfer attn kernel in multi-level cases.
-            # so that k/v values in cache does not have RoPE applied.
-            # k_seq = f_apply_rotary(cached_k[seq_id][layer_id], 0, rope_scale, rope_theta).transpose(
-            #     1, 2, 0
-            # )
-            k_seq = cached_k[seq_id][layer_id].transpose(1, 2, 0)
+            k_seq = f_apply_rotary(cached_k[seq_id][layer_id], 0, rope_scale, rope_theta).transpose(
+                1, 2, 0
+            )
             v_seq = cached_v[seq_id][layer_id].transpose(1, 0, 2)
 
             k_seq = np.repeat(k_seq, num_qo_heads // num_kv_heads, axis=0)
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index 280ad7e0ea..bd667292ea 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -58,11 +58,10 @@ def kv_cache_transpose_append(
     var_v_data: T.handle,
     var_position_map: T.handle,
 ):
-    ntoken = T.SizeVar("ntoken", "int64")
-    page_size = T.SizeVar("page_size", "int64")
-    num_pages = T.int64()
+    ntoken = T.SizeVar("ntoken", "int32")
+    num_pages = T.int32()
 
-    pages = T.match_buffer(var_pages, (num_pages, 2, num_kv_heads, page_size, head_dim), dtype)
+    pages = T.match_buffer(var_pages, (num_pages, 2, num_kv_heads, 16, head_dim), dtype)
     k_data = T.match_buffer(var_k_data, (ntoken, num_kv_heads, head_dim), dtype)
     v_data = T.match_buffer(var_v_data, (ntoken, num_kv_heads, head_dim), dtype)
     position_map = T.match_buffer(var_position_map, (ntoken,), "int32")
@@ -71,23 +70,19 @@ def kv_cache_transpose_append(
         with T.block("k_transpose_append"):
             vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
             T.reads(position_map[vgpos], k_data[vgpos, vh, vf])
-            T.writes(
-                pages[position_map[vgpos] // page_size, 0, vh, position_map[vgpos] % page_size, vf]
-            )
+            T.writes(pages[position_map[vgpos] // 16, 0, vh, position_map[vgpos] % 16, vf])
             position: T.int64 = T.Cast("int64", position_map[vgpos])
-            pages[
-                T.floordiv(position, page_size), 0, vh, T.floormod(position, page_size), vf
-            ] = k_data[vgpos, vh, vf]
+            pages[T.floordiv(position, 16), 0, vh, T.floormod(position, 16), vf] = k_data[
+                vgpos, vh, vf
+            ]
         with T.block("v_transpose_append"):
             vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
             T.reads(position_map[vgpos], k_data[vgpos, vh, vf])
-            T.writes(
-                pages[position_map[vgpos] // page_size, 1, vh, position_map[vgpos] % page_size, vf]
-            )
+            T.writes(pages[position_map[vgpos] // 16, 1, vh, position_map[vgpos] % 16, vf])
             position: T.int64 = T.Cast("int64", position_map[vgpos])
-            pages[
-                T.floordiv(position, page_size), 1, vh, T.floormod(position, page_size), vf
-            ] = v_data[vgpos, vh, vf]
+            pages[T.floordiv(position, 16), 1, vh, T.floormod(position, 16), vf] = v_data[
+                vgpos, vh, vf
+            ]
 
 
 @T.prim_func
@@ -150,7 +145,8 @@ def create_kv_cache():
         copy_cache,
         _attention_prefill(num_kv_heads, num_qo_heads, head_dim, dtype),
         _attention_decode(num_kv_heads, num_qo_heads, head_dim, dtype),
-        _inplace_rope(rope_theta, rope_scale, head_dim, num_qo_heads, num_kv_heads, dtype),
+        _attention_prefill_ragged(num_kv_heads, num_qo_heads, head_dim, dtype),
+        _merge_state_inplace(num_qo_heads, head_dim, dtype),
     ]:
         mod = tvm.IRModule({"main": tir_func})
         with target:
@@ -158,7 +154,14 @@ def create_kv_cache():
         f = tvm.build(mod["main"], target=target)
         builts.append(f.entry_func)
 
-    ftranspose_append, fcopy_cache, fattn_prefill, fattn_decode, fbatch_rotary = builts
+    (
+        ftranspose_append,
+        fcopy_cache,
+        fattn_prefill,
+        fattn_decode,
+        fattn_prefill_ragged,
+        fmerge_state,
+    ) = builts
     fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create_reduced")
     cache = fcreate(
         tvm.runtime.ShapeTuple([reserved_nseq, maximum_total_seq_length, page_size]),
@@ -172,7 +175,8 @@ def create_kv_cache():
         ftranspose_append,
         fattn_prefill,
         fattn_decode,
-        fbatch_rotary,
+        fattn_prefill_ragged,
+        fmerge_state,
         fcopy_cache,
     )
     return cache
@@ -258,13 +262,7 @@ def apply_attention(
         cached_k[seq_id] = np.concatenate(
             [
                 cached_k[seq_id],
-                np.stack(
-                    [
-                        f_apply_rotary(new_k[l], cached_k[seq_id].shape[1], rope_scale, rope_theta)
-                        for l in range(num_layers)
-                    ],
-                    axis=0,
-                ),
+                np.stack([new_k[l] for l in range(num_layers)], axis=0),
             ],
             axis=1,
         )
@@ -302,12 +300,9 @@ def apply_attention(
                 rope_scale,
                 rope_theta,
             ).transpose(1, 0, 2)
-            # Todo(Zihao, Ruihang): fold RoPE into flashinfer attn kernel in multi-level cases.
-            # so that k/v values in cache does not have RoPE applied.
-            # k_seq = f_apply_rotary(cached_k[seq_id][layer_id], 0, rope_scale, rope_theta).transpose(
-            #     1, 2, 0
-            # )
-            k_seq = cached_k[seq_id][layer_id].transpose(1, 2, 0)
+            k_seq = f_apply_rotary(cached_k[seq_id][layer_id], 0, rope_scale, rope_theta).transpose(
+                1, 2, 0
+            )
             v_seq = cached_v[seq_id][layer_id].transpose(1, 0, 2)
 
             k_seq = np.repeat(k_seq, num_qo_heads // num_kv_heads, axis=0)
@@ -385,6 +380,33 @@ def test_paged_attention_kv_cache_remove_sequence(kv_cache):
         )
 
 
+@tvm.testing.requires_gpu
+@tvm.testing.requires_cuda
+def test_paged_attention_kv_cache_fork_sequence(kv_cache):
+    fclear(kv_cache)
+
+    cached_k = {}
+    cached_v = {}
+    batch = [(0, 60), (1, 88), (2, 17), (3, 4)]
+    apply_attention(kv_cache, batch, cached_k, cached_v)
+    # Fork existing sequences.
+    apply_attention(kv_cache, [((4, 3), 35)], cached_k, cached_v)
+    apply_attention(kv_cache, [((5, 0), 20)], cached_k, cached_v)
+    apply_attention(kv_cache, [((6, 5), 102)], cached_k, cached_v)
+    apply_attention(kv_cache, [((7, 0), 3)], cached_k, cached_v)
+    apply_attention(kv_cache, [((8, 5), 71)], cached_k, cached_v)
+    apply_attention(kv_cache, [((9, 5), 20)], cached_k, cached_v)
+    # Mixture of decode and prefill.
+    operation_seq = [
+        [(2, 1), (4, 1), (7, 1), (6, 1), (8, 1), (9, 1)],
+        [(7, 1), (6, 1), (8, 1), (9, 1)],
+        [(7, 1), (1, 1), (6, 1), (2, 1), (8, 1), (4, 1), (9, 1)],
+        [(7, 10), (6, 2), (8, 3), (9, 4)],
+    ]
+    for batch in operation_seq:
+        apply_attention(kv_cache, batch, cached_k, cached_v)
+
+
 @tvm.testing.requires_gpu
 @tvm.testing.requires_cuda
 def test_paged_attention_kv_cache_popn(kv_cache):
@@ -404,76 +426,6 @@ def test_paged_attention_kv_cache_popn(kv_cache):
         verify_cached_kv(kv_cache, seq_ids=list(range(4)), expected_k=cached_k, expected_v=cached_v)
 
 
-def _inplace_rope(
-    theta: float,
-    scale: float,
-    head_dim: int,
-    num_q_heads: int,
-    num_kv_heads: int,
-    dtype: str,
-):
-    assert head_dim <= 128, "Rotary embedding currently only supports head_dim <= 128"
-    rotary_dim = head_dim
-
-    def _rope(
-        x: T.Buffer,
-        s: tir.Var,
-        h: tir.Var,
-        d: tir.Var,
-        rope_offset: tir.Var,
-        instance_offset: tir.Var,
-    ):
-        cos_freq, sin_freq = rope_freq((s + rope_offset) * scale, d, rotary_dim, theta, dtype)
-        cos = cos_freq * x[s + instance_offset, h, d]
-        sin = sin_freq * tir.if_then_else(
-            d < rotary_dim // 2,
-            -x[s + instance_offset, h, d + rotary_dim // 2],
-            x[s + instance_offset, h, d - rotary_dim // 2],
-        )
-        return cos + sin
-
-    # fmt: off
-    @T.prim_func
-    def tir_rotary(
-        var_q: T.handle,
-        var_k: T.handle,
-        var_append_len_indptr: T.handle,
-        var_rope_offsets: T.handle,
-        _0: T.int32,
-        _1: T.int32,
-        _2: T.int32,
-        _3: T.int32,
-        _4: T.int32,
-        _5: T.float32,
-        _6: T.float32,
-    ):
-        T.func_attr({"tir.is_scheduled": 1})
-        total_len = T.int32()
-        batch_size = T.int32()
-        q = T.match_buffer(var_q, (total_len, num_q_heads, head_dim), dtype)
-        k = T.match_buffer(var_k, (total_len, num_kv_heads, head_dim), dtype)
-        rope_offsets = T.match_buffer(var_rope_offsets, (batch_size,), "int32")
-        append_len_indptr = T.match_buffer(var_append_len_indptr, (batch_size + 1,), "int32")
-        for b_h in T.thread_binding(batch_size * (num_q_heads + num_kv_heads), thread="blockIdx.x"):
-            b: T.int32 = b_h // (num_q_heads + num_kv_heads)
-            h: T.int32 = b_h % (num_q_heads + num_kv_heads)
-            instance_offset: T.int32 = append_len_indptr[b]
-            rope_offset: T.int32 = rope_offsets[b]
-            append_len: T.int32 = append_len_indptr[b + 1] - append_len_indptr[b]
-            for s0 in range(T.ceildiv(append_len, 32)):
-                for s1 in T.thread_binding(32, thread="threadIdx.y"):
-                    for d0 in T.thread_binding(T.ceildiv(head_dim, 4), thread="threadIdx.x"):
-                        for d1 in T.vectorized(4):
-                            s: T.int32 = s0 * 32 + s1
-                            d: T.int32 = d0 * 4 + d1
-                            if s < append_len and d < head_dim:
-                                if h < num_q_heads:
-                                    q[s + instance_offset, h, d] = _rope(q, s, h, d, rope_offset, instance_offset)
-                                else:
-                                    k[s + instance_offset, h - num_q_heads, d] = _rope(k, s, h - num_q_heads, d, rope_offset, instance_offset)
-    return tir_rotary
-
-
 def rope_freq(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype: str):
     """Compute the inverse frequency of RoPE and then return the cosine and 
sine of it.
 
@@ -562,25 +514,28 @@ def _attention_prefill(h_kv, h_q, d, dtype):
         var_page_indptr: T.handle, # [batch_size + 1]
         var_page_values: T.handle, # [nnz_pages]
         var_last_page_len: T.handle, # [b]
+        var_k_rope_pos_offset: T.handle, # [b]
+        var_q_rope_position: T.handle, # [total_q_len]
         var_output: T.handle, # [total_len, h_q, d]
         var_lse: T.handle, # [total_len, h_q]
         causal: T.int32,
-        _1: T.int32,
-        _2: T.float32,
-        _3: T.float32,
+        rotary_mode: T.int32,
+        rope_scale: T.float32,
+        rope_theta: T.float32,
     ):
         batch_size = T.int32(is_size_var=True)
         total_len = T.int32(is_size_var=True)
         nnz_pages = T.int32(is_size_var=True)
         max_num_pages = T.int32(is_size_var=True)
-        page_size = T.int32(is_size_var=True)
 
         q = T.match_buffer(var_q, (total_len, h_q, d), dtype)
         q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32")
-        pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, page_size, d), dtype)
+        pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), dtype)
         page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32")
         page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32")
         last_page_len = T.match_buffer(var_last_page_len, (batch_size,), "int32")
+        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32")
+        q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32")
         output = T.match_buffer(var_output, (total_len, h_q, d), dtype)
         lse = T.match_buffer(var_lse, (total_len, h_q), "float32")  # pylint: disable=unused-variable
 
@@ -599,9 +554,6 @@ def _attention_prefill(h_kv, h_q, d, dtype):
                             batch_rows = _var("int32")
                             iterator = _var("int32")
                             kv_chunk_len = _var("int32")
-                            m_new = _var("float32")
-                            m_prev = _var("float32")
-                            d_new = _var("float32")
 
                             Q_smem = T.alloc_buffer((tile_x, d), dtype, scope="shared")
                             K_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared")
@@ -615,6 +567,10 @@ def _attention_prefill(h_kv, h_q, d, dtype):
                             m_prev_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared")
                             d_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared")
 
+                            m_new = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local")
+                            m_prev = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local")
+                            d_new = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local")
+
                             ## get tile_no, batch_idx, batch_tiles, batch_rows
                             tile_id[0] = bx
                             batch_idx[0] = 0
@@ -640,7 +596,7 @@ def _attention_prefill(h_kv, h_q, d, dtype):
                                     cur_last_page_len: T.int32 = last_page_len[b_idx]
                                     kv_chunk_len[0] = T.if_then_else(
                                         cur_page_indptr_begin != cur_page_indptr_end,
-                                        (cur_page_indptr_end - cur_page_indptr_begin - 1) * page_size + cur_last_page_len,
+                                        (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + cur_last_page_len,
                                         0
                                     )
                                     T.tvm_storage_sync("shared")
@@ -667,7 +623,11 @@ def _attention_prefill(h_kv, h_q, d, dtype):
                                             cur_L = L_start + i // group_size
                                             cur_H_qo = H_qo_start + i % group_size
                                             if cur_L < q_indptr[b_idx + 1]:
-                                                Q_smem[i, j] = q[cur_L, cur_H_qo, j]
+                                                Q_smem[i, j] = T.if_then_else(
+                                                    rotary_mode == 1,
+                                                    _rope(q, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, cur_H_qo, j)),
+                                                    q[cur_L, cur_H_qo, j]
+                                                )
                                             else:
                                                 Q_smem[i, j] = 0.0
                                     T.tvm_storage_sync("shared")
@@ -681,9 +641,13 @@ def _attention_prefill(h_kv, h_q, d, dtype):
                                                 T.writes()
                                                 cur_L = L_kv_start + i
                                                 if cur_L < kv_chunk_len[0]:
-                                                    page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(cur_L, page_size)]
-                                                    page_offset: T.int32(is_size_var=True) = T.floormod(cur_L, page_size)
-                                                    K_smem[i, j] = pages[page_no, 0, by, page_offset, j]
+                                                    page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(cur_L, 16)]
+                                                    page_offset: T.int32(is_size_var=True) = T.floormod(cur_L, 16)
+                                                    K_smem[i, j] = T.if_then_else(
+                                                        rotary_mode == 1,
+                                                        _rope(pages, k_rope_pos_offset[b_idx] + cur_L, d, rope_theta, rope_scale, (page_no, 0, by, page_offset, j)),
+                                                        pages[page_no, 0, by, page_offset, j]
+                                                    )
                                                 else:
                                                     K_smem[i, j] = 0.0
                                         T.tvm_storage_sync("shared")
@@ -694,8 +658,8 @@ def _attention_prefill(h_kv, h_q, d, dtype):
                                                 T.writes()
                                                 cur_L = L_kv_start + i
                                                 if cur_L < kv_chunk_len[0]:
-                                                    page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(cur_L, page_size)]
-                                                    page_offset: T.int32(is_size_var=True) = T.floormod(cur_L, page_size)
+                                                    page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(cur_L, 16)]
+                                                    page_offset: T.int32(is_size_var=True) = T.floormod(cur_L, 16)
                                                     V_smem[i, j] = pages[page_no, 1, by, page_offset, j]
                                                 else:
                                                     V_smem[i, j] = 0.0
@@ -721,8 +685,8 @@ def _attention_prefill(h_kv, h_q, d, dtype):
                                            row: T.int32 = i * 32 * num_warps + ty * 32 + tx
                                            if row < tile_x:
                                                with T.block("update1"):
-                                                    m_prev[0] = m_smem[row]
-                                                    m_new[0] = m_smem[row]
+                                                    m_prev[i] = m_smem[row]
+                                                    m_new[i] = m_smem[row]
                                                    # mask out of kv_chunk_len S
                                                    for j in T.serial(tile_z):
                                                        if mask(causal,
@@ -730,8 +694,8 @@ def _attention_prefill(h_kv, h_q, d, dtype):
                                                                col=L_kv_start + j,
                                                                kv_len=kv_chunk_len[0],
                                                                qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]):
-                                                            m_new[0] = T.max(m_new[0], S_smem[row, j])
-                                                    d_new[0] = d_smem[row] * T.exp2(m_prev[0] - m_new[0])
+                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
+                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
 
                                        for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)):
                                            row: T.int32 = i * 32 * num_warps + ty * 32 + tx
@@ -744,19 +708,19 @@ def _attention_prefill(h_kv, h_q, d, dtype):
                                                                col=L_kv_start + j,
                                                                kv_len=kv_chunk_len[0],
                                                                qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]):
-                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[0])
+                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
-                                                            S_smem[row, j] = T.exp2(-5e4 - m_new[0])
+                                                            S_smem[row, j] = T.exp2(-5e4 - m_new[i])
 
                                        for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)):
                                            row: T.int32 = i * 32 * num_warps + ty * 32 + tx
                                            if row < tile_x:
                                                with T.block("update"):
                                                    for j in T.serial(tile_z):
-                                                        d_new[0] += S_smem[row, j]
-                                                    m_smem[row] = m_new[0]
-                                                    d_smem[row] = d_new[0]
-                                                    m_prev_smem[row] = m_prev[0]
+                                                        d_new[i] += S_smem[row, j]
+                                                    m_smem[row] = m_new[i]
+                                                    d_smem[row] = d_new[i]
+                                                    m_prev_smem[row] = m_prev[i]
                                         T.tvm_storage_sync("shared")
 
                                         # Update O
@@ -775,6 +739,13 @@ def _attention_prefill(h_kv, h_q, d, dtype):
                                            if L_start + i // group_size < q_indptr[b_idx + 1]:
                                                output[L_start + i // group_size, H_qo_start + i % group_size, j] = O_local[i, j] / d_smem[i]
 
+                                    # Store LSE to gmem
+                                    for li in T.grid(tile_x):
+                                        with T.block("lse_store"):
+                                            i = T.axis.remap("S", [li])
+                                            if L_start + i // group_size < q_indptr[b_idx + 1]:
+                                                lse[L_start + i // group_size, H_qo_start + i % group_size] = m_smem[i] + T.log2(d_smem[i])
+
                                     # move to next tile
                                     tile_id[0] += NUM_BLKS
     # fmt: on
@@ -835,6 +806,12 @@ def _attention_prefill(h_kv, h_q, d, dtype):
             sch.reorder(ko, ki, xi, yi)
         sch.decompose_reduction(block, ty)
 
+    def apply_to_md(sch, block):
+        loop = sch.get_loops(block)[-1]
+        _, ty, tx = sch.split(loop, factors=[None, num_warps, 32])
+        sch.bind(ty, "threadIdx.y")
+        sch.bind(tx, "threadIdx.x")
+
     tile_s = get_tile_size(tile_x, tile_z, 32 * num_warps)
     tile_o = get_tile_size(tile_x, tile_y, 32 * num_warps)
     apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True)
@@ -845,6 +822,7 @@ def _attention_prefill(h_kv, h_q, d, dtype):
     apply_to_qkv_load(sch, sch.get_block("Q_load"))
     apply_to_qkv_load(sch, sch.get_block("K_load"))
     apply_to_qkv_load(sch, sch.get_block("V_load"))
+    apply_to_md(sch, sch.get_block("lse_store"))
     return sch.mod["main"].with_attr("tir.is_scheduled", 1)
 
 
@@ -861,6 +839,7 @@ def _attention_decode(num_kv_heads, num_qo_heads, head_dim, qkv_dtype):
     GROUP_SIZE = H_qo // H_kv
     VEC_SIZE = max(8 // qkv_dtype_bytes, D // 32)
     bdx = D // VEC_SIZE
+    assert bdx == 32
     bdy = GROUP_SIZE
     threads_per_CTA = max(128, bdx * bdy)
     bdz = threads_per_CTA // (bdx * bdy)
@@ -871,12 +850,14 @@ def _attention_decode(num_kv_heads, num_qo_heads, head_dim, qkv_dtype):
     # fmt: off
     @T.prim_func
     def batch_decode_paged_kv(
-        handler_id: T.int32,  # pylint: disable=unused-argument
+        _0: T.int32,  # pylint: disable=unused-argument
         Q_handle: T.handle,
         pages_handle: T.handle,
         page_table_indptr_handle: T.handle,
         page_table_values_handle: T.handle,
         last_page_len_handle: T.handle,
+        k_rope_pos_offset_handle: T.handle,
+        q_rope_position_handle: T.handle,
         output_handle: T.handle,
         lse_handle: T.handle,
         rotary_mode: T.int32,
@@ -885,16 +866,17 @@ def _attention_decode(num_kv_heads, num_qo_heads, head_dim, qkv_dtype):
     ):
         T.func_attr({"tir.is_scheduled": 1})
         B = T.int32(is_size_var=True)
-        page_size = T.int32(is_size_var=True)
         nnz_pages = T.int32(is_size_var=True)
         max_num_pages = T.int32(is_size_var=True)
 
         Q = T.match_buffer(Q_handle, (B, H_qo, D), qkv_dtype)
         pages = T.match_buffer(
-            pages_handle, (max_num_pages, 2, H_kv, page_size, D), qkv_dtype
+            pages_handle, (max_num_pages, 2, H_kv, 16, D), qkv_dtype
         )
         page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32")
         page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32")
+        k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32")
+        q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32")
         last_page_len = T.match_buffer(last_page_len_handle, (B,), "int32")
         output = T.match_buffer(output_handle, (B, H_qo, D), qkv_dtype)
         lse = T.match_buffer(lse_handle, (B, H_qo), "float32")  # pylint: disable=unused-variable
@@ -911,14 +893,15 @@ def _attention_decode(num_kv_heads, num_qo_heads, head_dim, qkv_dtype):
                                kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                                K_smem = T.alloc_buffer((bdz * bdy * tile_size_per_bdx, D), qkv_dtype, scope="shared")
                                V_smem = T.alloc_buffer((bdz * bdy * tile_size_per_bdx, D), qkv_dtype, scope="shared")
-                                S_allreduce = T.alloc_buffer((bdz, bdy, bdx), "float32", scope="shared")
                                O_allreduce = T.alloc_buffer((bdz, bdy, D), "float32", scope="shared")
                                md_allreduce = T.alloc_buffer((bdz, bdy, 2), "float32", scope="shared")
+                                S_reduce_local = T.alloc_buffer((1,), "float32", scope="local")
+                                mask = T.alloc_buffer((1,), "uint32", scope="local")
+                                t0 = T.alloc_buffer((1,), "float32", scope="local")
 
                                S_local = T.alloc_buffer((bdy * tile_size_per_bdx), "float32", scope="local")
                                K_local = T.alloc_buffer((VEC_SIZE,), qkv_dtype, scope="local")
                                V_local = T.alloc_buffer((VEC_SIZE,), qkv_dtype, scope="local")
-                                offset = T.alloc_buffer((1,), "int32", scope="local")
                                m_prev = T.alloc_buffer((1,), "float32", scope="local")
                                d_prev = T.alloc_buffer((1,), "float32", scope="local")
                                other_m = T.alloc_buffer((1,), "float32", scope="local")
@@ -934,7 +917,7 @@ def _attention_decode(num_kv_heads, num_qo_heads, head_dim, qkv_dtype):
                                cur_last_page_len: T.int32 = last_page_len[batch_idx]
                                kv_chunk_len[0] = T.if_then_else(
                                    cur_page_indptr_begin != cur_page_indptr_end,
-                                    (cur_page_indptr_end - cur_page_indptr_begin - 1) * page_size + cur_last_page_len,
+                                    (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + cur_last_page_len,
                                     0
                                 )
 
@@ -946,9 +929,11 @@ def _attention_decode(num_kv_heads, num_qo_heads, head_dim, qkv_dtype):
 
                                # load q
                                for vec in T.vectorized(VEC_SIZE):
-                                    Q_local[vec] = T.if_then_else(rotary_mode == 1,
-                                                                  _rope(Q, kv_chunk_len[0]-1, head_dim, rope_theta, rope_scale, (bx, by * GROUP_SIZE + ty, tx * VEC_SIZE + vec)),
-                                                                  Q[bx, by * GROUP_SIZE + ty, tx * VEC_SIZE + vec])
+                                    Q_local[vec] = T.if_then_else(
+                                        rotary_mode == 1,
+                                        _rope(Q, q_rope_position[batch_idx], head_dim, rope_theta, rope_scale, (bx, by * GROUP_SIZE + ty, tx * VEC_SIZE + vec)),
+                                        Q[bx, by * GROUP_SIZE + ty, tx * VEC_SIZE + vec]
+                                    )
 
                                for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_size_per_bdx * bdy * bdz)):
                                    tile_start_s: T.int32(is_size_var=True) = (tz * bdy + ty) * tile_size_per_bdx
@@ -957,12 +942,14 @@ def _attention_decode(num_kv_heads, num_qo_heads, 
head_dim, qkv_dtype):
                                     for j in T.serial(tile_size_per_bdx):
                                         row_g: T.int32(is_size_var=True) = tile_start_g + j
                                         if row_g < kv_chunk_len[0]:
-                                            page_no: T.int32(is_size_var=True) = page_table_values[cur_page_indptr_begin + T.floordiv(row_g, page_size)]
-                                            page_offset: T.int32(is_size_var=True) = T.floormod(row_g, page_size)
+                                            page_no: T.int32(is_size_var=True) = page_table_values[cur_page_indptr_begin + T.floordiv(row_g, 16)]
+                                            page_offset: T.int32(is_size_var=True) = T.floormod(row_g, 16)
                                             for vec in T.vectorized(VEC_SIZE):
-                                                K_smem[tile_start_s + j, tx * VEC_SIZE + vec] = T.if_then_else(rotary_mode == 1,
-                                                                                _rope(pages, row_g, head_dim, rope_theta, rope_scale, (page_no, 0, by, page_offset, tx * VEC_SIZE + vec)),
-                                                                                pages[page_no, 0, by, page_offset, tx * VEC_SIZE + vec])
+                                                K_smem[tile_start_s + j, tx * VEC_SIZE + vec] = T.if_then_else(
+                                                    rotary_mode == 1,
+                                                    _rope(pages, k_rope_pos_offset[batch_idx] + row_g, head_dim, rope_theta, rope_scale, (page_no, 0, by, page_offset, tx * VEC_SIZE + vec)),
+                                                    pages[page_no, 0, by, page_offset, tx * VEC_SIZE + vec]
+                                                )
                                         else:
                                             for vec in T.vectorized(VEC_SIZE):
                                                 K_smem[tile_start_s + j, tx * VEC_SIZE + vec] = 0.0
@@ -971,8 +958,8 @@ def _attention_decode(num_kv_heads, num_qo_heads, head_dim, qkv_dtype):
                                     for j in T.serial(tile_size_per_bdx):
                                         row_g: T.int32(is_size_var=True) = tile_start_g + j
                                         if row_g < kv_chunk_len[0]:
-                                            page_no: T.int32(is_size_var=True) = page_table_values[cur_page_indptr_begin + T.floordiv(row_g, page_size)]
-                                            page_offset: T.int32(is_size_var=True) = T.floormod(row_g, page_size)
+                                            page_no: T.int32(is_size_var=True) = page_table_values[cur_page_indptr_begin + T.floordiv(row_g, 16)]
+                                            page_offset: T.int32(is_size_var=True) = T.floormod(row_g, 16)
                                             for vec in T.vectorized(VEC_SIZE):
                                                 V_smem[tile_start_s + j, tx * VEC_SIZE + vec] = pages[page_no, 1, by, page_offset, tx * VEC_SIZE + vec]
                                         else:
@@ -989,19 +976,23 @@ def _attention_decode(num_kv_heads, num_qo_heads, head_dim, qkv_dtype):
                                             for vec in T.vectorized(VEC_SIZE):
                                                 K_local[vec] = K_smem[tz * bdy * tile_size_per_bdx + j, tx * VEC_SIZE + vec]
                                             # compute S = Q * K * sm_scale
-                                            S_local[j] = 0
+                                            S_reduce_local[0] = 0
                                             for vec in T.serial(VEC_SIZE):
-                                                S_local[j] += Q_local[vec] * K_local[vec] * sm_scale
-                                            # allreduce over bdx
-                                            S_allreduce[tz, ty, tx] = S_local[j]
-                                            T.tvm_storage_sync("shared")
-                                            offset[0] = bdx // 2
-                                            while offset[0] > 0:
-                                                if tx < offset[0]:
-                                                    S_allreduce[tz, ty, tx] += S_allreduce[tz, ty, tx + offset[0]]
-                                                T.tvm_storage_sync("shared")
-                                                offset[0] = offset[0] >> 1
-                                            S_local[j] = S_allreduce[tz, ty, 0]
+                                                S_reduce_local[0] += Q_local[vec] * K_local[vec] * sm_scale
+
+                                            t0[0] = T.tvm_warp_shuffle_down(mask[0], S_reduce_local[0], 16, 32, 32)
+                                            S_reduce_local[0] = S_reduce_local[0] + t0[0]
+                                            t0[0] = T.tvm_warp_shuffle_down(mask[0], S_reduce_local[0], 8, 32, 32)
+                                            S_reduce_local[0] = S_reduce_local[0] + t0[0]
+                                            t0[0] = T.tvm_warp_shuffle_down(mask[0], S_reduce_local[0], 4, 32, 32)
+                                            S_reduce_local[0] = S_reduce_local[0] + t0[0]
+                                            t0[0] = T.tvm_warp_shuffle_down(mask[0], S_reduce_local[0], 2, 32, 32)
+                                            S_reduce_local[0] = S_reduce_local[0] + t0[0]
+                                            t0[0] = T.tvm_warp_shuffle_down(mask[0], S_reduce_local[0], 1, 32, 32)
+                                            S_reduce_local[0] = S_reduce_local[0] + t0[0]
+                                            S_reduce_local[0] = T.tvm_warp_shuffle(mask[0], S_reduce_local[0], 0, 32, 32)
+
+                                            S_local[j] = S_reduce_local[0]
                                         # update st_m
                                         st_m[0] = T.max(st_m[0], S_local[j])
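
The hunk above replaces the shared-memory allreduce (the `S_allreduce` buffer
with explicit storage syncs) by a register-level warp reduction: five
`tvm_warp_shuffle_down` steps with strides 16, 8, 4, 2, 1 fold the 32 per-lane
partial sums into lane 0, and the final `tvm_warp_shuffle` broadcasts the
result back to every lane. A minimal pure-Python model of this reduction,
assuming the 32-lane warp width hardcoded above:

    # Pure-Python sketch of the shuffle-down tree reduction.
    def warp_reduce(vals):  # vals: 32 per-lane partial sums
        n = len(vals)
        for delta in (16, 8, 4, 2, 1):
            # shuffle_down(delta): lane i reads the value held by lane i + delta
            vals = [vals[i] + (vals[i + delta] if i + delta < n else 0.0)
                    for i in range(n)]
        return [vals[0]] * n  # the final shuffle broadcasts lane 0's sum

    assert warp_reduce([1.0] * 32) == [32.0] * 32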
 
@@ -1054,13 +1045,412 @@ def _attention_decode(num_kv_heads, num_qo_heads, head_dim, qkv_dtype):
                                 # store O to global memory
                                 for vec in T.vectorized(VEC_SIZE):
                                     output[batch_idx, by * GROUP_SIZE + ty, tx * VEC_SIZE + vec] = O_local[vec]
+
+                                # store lse to global memory
+                                lse[batch_idx, by * GROUP_SIZE + ty] = st_m[0] + T.log2(st_d[0])
     # fmt: on
    # pylint: enable=line-too-long,invalid-name,too-many-arguments,too-many-branches
     return batch_decode_paged_kv
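
Both the decode kernel above and the ragged prefill kernel below now rotate Q
and K on the fly through the file's `_rope` helper, with positions supplied by
the new `q_rope_position` / `k_rope_pos_offset` arrays rather than recomputed
from the KV length. For reference, a minimal NumPy sketch of the rotate-half
RoPE convention this family of kernels implements (the actual `_rope` is
defined earlier in this file; the sketch below is illustrative, not its exact
definition):

    import numpy as np

    def rope_reference(x, pos, rope_scale, rope_theta):
        # x: one head's (head_dim,) vector; pos: absolute token position.
        d = x.shape[0]
        half = d // 2
        inv_freq = rope_theta ** (-np.arange(half) * 2.0 / d)
        angle = pos * rope_scale * inv_freq
        cos, sin = np.cos(angle), np.sin(angle)
        lo, hi = x[:half], x[half:]
        # rotate-half: dimension i is paired with dimension i + d/2
        return np.concatenate([lo * cos - hi * sin, hi * cos + lo * sin])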
 
 
+def _attention_prefill_ragged(h_kv, h_q, d, dtype):
+    assert dtype == "float16", f"TIR attention kernel does not support dtype {dtype} right now"
+    # pylint: disable=invalid-name
+    NUM_BLKS = 16
+    LOAD_VEC = 8 // ((tvm.DataType(dtype).bits + 7) // 8)  # 8 bytes
+    group_size = h_q // h_kv
+    sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1))
+
+    num_warps = 4
+    tile_x, tile_y, tile_z = 32, d, 16
+    L_per_cta = tile_x // group_size
+
+    def mask(causal, row, col, kv_len, qo_len):
+        return T.if_then_else(
+            causal > 0,
+            col < kv_len - qo_len + row + 1,
+            col < kv_len,
+        )
+
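+
The `mask` closure above encodes causal visibility for ragged batches: query
row `row` (in the sequence's local coordinates) may attend to KV column `col`
only while `col < kv_len - qo_len + row + 1`; with `causal == 0`, every column
below `kv_len` is visible. A quick sanity check of the predicate with
illustrative sizes:

    # With kv_len=7 and qo_len=3, query rows 0/1/2 see 5/6/7 KV columns.
    def visible(causal, row, col, kv_len, qo_len):
        return col < kv_len - qo_len + row + 1 if causal else col < kv_len

    assert [sum(visible(1, r, c, 7, 3) for c in range(7)) for r in range(3)] == [5, 6, 7]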
+    # fmt: off
+    @T.prim_func
+    def batch_prefill_ragged_kv(
+        var_q: T.handle, # [total_len, h_q, d]
+        var_q_indptr: T.handle, # [batch_size + 1]
+        var_k: T.handle, # [total_len, h_kv, d]
+        var_v: T.handle, # [total_len, h_kv, d]
+        var_kv_indptr: T.handle, # [batch_size + 1]
+        var_q_rope_position: T.handle, # [total_q_len]
+        var_k_rope_pos_offset: T.handle, # [batch_size]
+        var_output: T.handle, # [total_len, h_q, d]
+        var_lse: T.handle, # [total_len, h_q]
+        causal: T.int32,
+        rotary_mode: T.int32,
+        rope_scale: T.float32,
+        rope_theta: T.float32,
+    ):
+        batch_size = T.int32(is_size_var=True)
+        qo_len = T.int32(is_size_var=True)
+        kv_len = T.int32(is_size_var=True)
+
+        q = T.match_buffer(var_q, (qo_len, h_q, d), dtype)
+        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32")
+        k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype)
+        v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype)
+        kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32")
+        q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32")
+        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32")
+        output = T.match_buffer(var_output, (qo_len, h_q, d), dtype)
+        lse = T.match_buffer(var_lse, (qo_len, h_q), "float32")  # pylint: disable=unused-variable
+
+        # kernel code
+        for lbx in T.thread_binding(NUM_BLKS, thread="blockIdx.x"):
+            for lby in T.thread_binding(h_kv, thread="blockIdx.y"):
+                for lty in T.thread_binding(num_warps, thread="threadIdx.y"):
+                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
+                        with T.block("attn"):
+                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
+                            T.reads()
+                            T.writes()
+                            tile_id = _var("int32")
+                            batch_idx = _var("int32")
+                            batch_tiles = _var("int32")
+                            batch_rows = _var("int32")
+                            iterator = _var("int32")
+                            kv_chunk_len = _var("int32")
+
+                            Q_smem = T.alloc_buffer((tile_x, d), dtype, scope="shared")
+                            K_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared")
+                            V_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared")
+                            S_smem = T.alloc_buffer((tile_x, tile_z), "float32", scope="shared")
+
+                            S_local = T.alloc_buffer((tile_x, tile_z), "float32", scope="local")
+                            O_local = T.alloc_buffer((tile_x, d), "float32", scope="local")
+
+                            m_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared")
+                            m_prev_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared")
+                            d_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared")
+
+                            m_new = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local")
+                            m_prev = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local")
+                            d_new = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local")
+
+                            ## get tile_no, batch_idx, batch_tiles, batch_rows
+                            tile_id[0] = bx
+                            batch_idx[0] = 0
+                            batch_rows[0] = (q_indptr[1] - q_indptr[0]) * group_size
+                            batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x)
+                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
+                                # advance to next tile
+                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
+                                    tile_id[0] -= batch_tiles[0]
+                                    batch_idx[0] += 1
+                                    if batch_idx[0] < batch_size:
+                                        b_idx: T.int32 = batch_idx[0]
+                                        batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * group_size
+                                        batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x)
+
+                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
+                                    b_idx: T.int32 = batch_idx[0]
+                                    L_start: T.int32 = q_indptr[b_idx] + tile_id[0] * L_per_cta
+                                    H_qo_start: T.int32 = by * group_size
+
+                                    kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx]
+                                    T.tvm_storage_sync("shared")
+
+                                    # init states
+                                    for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)):
+                                        row: T.int32 = i * 32 * num_warps + ty * 32 + tx
+                                        if row < tile_x:
+                                            m_smem[row] = -5e4
+                                            d_smem[row] = 1.0
+
+                                    for li, lj in T.grid(tile_x, tile_y):
+                                        with T.block("O_init"):
+                                            i, j = T.axis.remap("SS", [li, lj])
+                                            O_local[i, j] = 0.0
+                                    T.tvm_storage_sync("shared")
+
+                                    # Load Q from gmem to smem
+                                    for li, lj in T.grid(tile_x, tile_y):
+                                        with T.block("Q_load"):
+                                            i, j = T.axis.remap("SS", [li, lj])
+                                            T.reads()
+                                            T.writes()
+                                            cur_L = L_start + i // group_size
+                                            cur_H_qo = H_qo_start + i % group_size
+                                            if cur_L < q_indptr[b_idx + 1]:
+                                                Q_smem[i, j] = T.if_then_else(
+                                                    rotary_mode == 1,
+                                                    _rope(q, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, cur_H_qo, j)),
+                                                    q[cur_L, cur_H_qo, j]
+                                                )
+                                            else:
+                                                Q_smem[i, j] = 0.0
+                                    T.tvm_storage_sync("shared")
+
+                                    for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_z)):
+                                        L_kv_start: T.int32 = iterator * tile_z
+                                        L_kv_base: T.int32 = kv_indptr[b_idx]
+                                        for lz, ly in T.grid(tile_z, tile_y):
+                                            with T.block("K_load"):
+                                                i, j = T.axis.remap("SS", [lz, ly])
+                                                T.reads()
+                                                T.writes()
+                                                cur_L = L_kv_start + i
+                                                if cur_L < kv_chunk_len[0]:
+                                                    K_smem[i, j] = T.if_then_else(
+                                                        rotary_mode == 1,
+                                                        _rope(k, k_rope_pos_offset[b_idx] + cur_L, d, rope_theta, rope_scale, (L_kv_base + cur_L, by, j)),
+                                                        k[L_kv_base + cur_L, by, j]
+                                                    )
+                                                else:
+                                                    K_smem[i, j] = 0.0
+                                        T.tvm_storage_sync("shared")
+                                        for lz, ly in T.grid(tile_z, tile_y):
+                                            with T.block("V_load"):
+                                                i, j = T.axis.remap("SS", [lz, ly])
+                                                T.reads()
+                                                T.writes()
+                                                cur_L = L_kv_start + i
+                                                if cur_L < kv_chunk_len[0]:
+                                                    V_smem[i, j] = v[L_kv_base + cur_L, by, j]
+                                                else:
+                                                    V_smem[i, j] = 0.0
+                                        T.tvm_storage_sync("shared")
+
+                                        # Compute S
+                                        with T.block():
+                                            for li, lj, lk in T.grid(tile_x, tile_z, tile_y):
+                                                with T.block("S_gemm"):
+                                                    i, j, k = T.axis.remap("SSR", [li, lj, lk])
+                                                    with T.init():
+                                                        S_local[i, j] = 0.0
+                                                    S_local[i, j] += Q_smem[i, k] * K_smem[j, k] * sm_scale
+                                        T.tvm_storage_sync("shared")
+                                        for li, lj in T.grid(tile_x, tile_z):
+                                            with T.block("S_store"):
+                                                i, j = T.axis.remap("SS", [li, lj])
+                                                S_smem[i, j] = S_local[i, j]
+                                        T.tvm_storage_sync("shared")
+
+                                        # Update S, m, d
+                                        for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)):
+                                            row: T.int32 = i * 32 * num_warps + ty * 32 + tx
+                                            if row < tile_x:
+                                                with T.block("update1"):
+                                                    m_prev[i] = m_smem[row]
+                                                    m_new[i] = m_smem[row]
+                                                    # mask out S beyond kv_chunk_len
+                                                    for j in T.serial(tile_z):
+                                                        if mask(causal,
+                                                                row=tile_id[0] * L_per_cta + row // group_size,
+                                                                col=L_kv_start + j,
+                                                                kv_len=kv_chunk_len[0],
+                                                                qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]):
+                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
+                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
+
+                                        for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)):
+                                            row: T.int32 = i * 32 * num_warps + ty * 32 + tx
+                                            with T.block("update"):
+                                                for j in T.serial(tile_z):
+                                                    # check the row inside the loop to avoid a sync inside a conditional branch
+                                                    if row < tile_x:
+                                                        if mask(causal,
+                                                                row=tile_id[0] * L_per_cta + row // group_size,
+                                                                col=L_kv_start + j,
+                                                                kv_len=kv_chunk_len[0],
+                                                                qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]):
+                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
+                                                        else:
+                                                            S_smem[row, j] = T.exp2(-5e4 - m_new[i])
+
+                                        for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)):
+                                            row: T.int32 = i * 32 * num_warps + ty * 32 + tx
+                                            if row < tile_x:
+                                                with T.block("update"):
+                                                    for j in T.serial(tile_z):
+                                                        d_new[i] += S_smem[row, j]
+                                                    m_smem[row] = m_new[i]
+                                                    d_smem[row] = d_new[i]
+                                                    m_prev_smem[row] = m_prev[i]
+                                        T.tvm_storage_sync("shared")
+
+                                        # Update O
+                                        with T.block():
+                                            for li, lj, lk in T.grid(tile_x, tile_y, tile_z):
+                                                with T.block("O_gemm"):
+                                                    i, j, k = T.axis.remap("SSR", [li, lj, lk])
+                                                    with T.init():
+                                                        O_local[i, j] *= T.exp2(m_prev_smem[i] - m_smem[i])
+                                                    O_local[i, j] += S_smem[i, k] * V_smem[k, j]
+
+                                    # Store O from smem to gmem
+                                    for li, lj in T.grid(tile_x, tile_y):
+                                        with T.block("O_store"):
+                                            i, j = T.axis.remap("SS", [li, lj])
+                                            if L_start + i // group_size < q_indptr[b_idx + 1]:
+                                                output[L_start + i // group_size, H_qo_start + i % group_size, j] = O_local[i, j] / d_smem[i]
+
+                                    # Store LSE to gmem
+                                    for li in T.grid(tile_x):
+                                        with T.block("lse_store"):
+                                            i = T.axis.remap("S", [li])
+                                            if L_start + i // group_size < q_indptr[b_idx + 1]:
+                                                lse[L_start + i // group_size, H_qo_start + i % group_size] = m_smem[i] + T.log2(d_smem[i])
+
+                                    # move to next tile
+                                    tile_id[0] += NUM_BLKS
+    # fmt: on
+    # pylint: enable=line-too-long,invalid-name,too-many-arguments,too-many-branches
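
One structural note on the prim func above: it is a persistent-block kernel.
`NUM_BLKS` (16) thread blocks are launched, and each walks the flattened list
of (batch, tile) work items with stride `NUM_BLKS` via the `tile_id` /
`batch_idx` bookkeeping at the top of the while loop. A small Python model of
that mapping (names are illustrative):

    def tiles_for_block(bx, q_lens, group_size, tile_x, num_blks=16):
        # Flatten (batch, local_tile) pairs in order; block bx handles every
        # num_blks-th item, mirroring the kernel's `tile_id[0] += NUM_BLKS`.
        flat = []
        for b, qlen in enumerate(q_lens):
            ntiles = -(-(qlen * group_size) // tile_x)  # ceildiv
            flat += [(b, t) for t in range(ntiles)]
        return flat[bx::num_blks]

    # Two sequences of 24 and 40 query tokens with group_size=4, tile_x=32
    # yield 3 + 5 tiles spread round-robin across the 16 blocks:
    assert tiles_for_block(0, [24, 40], 4, 32) == [(0, 0)]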
+    sch = tir.Schedule(batch_prefill_ragged_kv)
+
+    def get_tile_size(x, y, t):
+        cnt = (x * y) // t
+        assert (x * y) % t == 0
+        tile_y = (int)(math.ceil(math.sqrt(cnt)))
+        while cnt % tile_y != 0 and y % tile_y != 0 and tile_y <= cnt:
+            tile_y += 1
+        assert tile_y <= cnt
+        tile_x = cnt // tile_y
+        return tile_x, tile_y
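
`get_tile_size` factors a tile's element count across the 32 * num_warps = 128
threads so that each thread owns a small `tile_x x tile_y` sub-tile. For
example, the 32 x 16 S tile has 512 elements, so cnt = 512 // 128 = 4 and each
thread gets a 2 x 2 sub-tile: `get_tile_size(32, 16, 128) == (2, 2)`.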
+
+    def apply_to_qkv_load(sch: tir.Schedule, block):
+        loop_x, loop_y = sch.get_loops(block)[-2:]
+        loop = sch.fuse(loop_x, loop_y)
+        _, ty, tx, vec = sch.split(
+            loop, factors=[None, num_warps, 32, LOAD_VEC], preserve_unit_iters=True
+        )
+        sch.bind(ty, "threadIdx.y")
+        sch.bind(tx, "threadIdx.x")
+        sch.vectorize(vec)
+
+    def apply_to_so_ewise(sch: tir.Schedule, block, tile, vec_len=4):
+        loop_x, loop_y = sch.get_loops(block)[-2:]
+        xo, xi = sch.split(loop_x, factors=[None, tile[0]])
+        yo, yi = sch.split(loop_y, factors=[None, tile[1]])
+        sch.reorder(xo, yo, xi, yi)
+        t = sch.fuse(xo, yo)
+        ty, tx = sch.split(t, factors=[num_warps, 32])
+        sch.bind(ty, "threadIdx.y")
+        sch.bind(tx, "threadIdx.x")
+        if tile[1] % vec_len == 0:
+            yi, vec = sch.split(yi, factors=[None, vec_len])
+            sch.vectorize(vec)
+        elif tile[1] in [2, 4]:
+            sch.vectorize(yi)
+
+    def apply_to_gemm(  # pylint: disable=too-many-arguments,unused-argument
+        sch: tir.Schedule, block, tile, read_0, read_1, r_len=8, k_major=False
+    ):
+        loop_x, loop_y, loop_z = sch.get_loops(block)[-3:]
+        xo, xi = sch.split(loop_x, factors=[None, tile[0]])
+        yo, yi = sch.split(loop_y, factors=[None, tile[1]])
+        sch.reorder(xo, yo, xi, yi)
+        t = sch.fuse(xo, yo)
+        ty, tx = sch.split(t, factors=[num_warps, 32])
+        sch.bind(ty, "threadIdx.y")
+        sch.bind(tx, "threadIdx.x")
+
+        ko, ki = sch.split(loop_z, factors=[None, r_len])
+        if k_major:
+            sch.reorder(ko, xi, yi, ki)
+        else:
+            sch.reorder(ko, ki, xi, yi)
+        sch.decompose_reduction(block, ty)
+
+    def apply_to_md(sch, block):
+        loop = sch.get_loops(block)[-1]
+        _, ty, tx = sch.split(loop, factors=[None, num_warps, 32])
+        sch.bind(ty, "threadIdx.y")
+        sch.bind(tx, "threadIdx.x")
+
+    tile_s = get_tile_size(tile_x, tile_z, 32 * num_warps)
+    tile_o = get_tile_size(tile_x, tile_y, 32 * num_warps)
+    apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True)
+    apply_to_gemm(sch, sch.get_block("O_gemm"), tile_o, 2, 3, k_major=False)
+    apply_to_so_ewise(sch, sch.get_block("S_store"), tile_s)
+    apply_to_so_ewise(sch, sch.get_block("O_init"), tile_o)
+    apply_to_so_ewise(sch, sch.get_block("O_store"), tile_o)
+    apply_to_qkv_load(sch, sch.get_block("Q_load"))
+    apply_to_qkv_load(sch, sch.get_block("K_load"))
+    apply_to_qkv_load(sch, sch.get_block("V_load"))
+
+    apply_to_md(sch, sch.get_block("lse_store"))
+    return sch.mod["main"].with_attr("tir.is_scheduled", 1)
+
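+
A note on the base-2 arithmetic used throughout these kernels: `sm_scale`
folds in log2(e) (see `math.log2(math.exp(1))` above), so `T.exp2(score *
sm_scale)` equals exp(q.k / sqrt(d)) while mapping to the cheaper exp2
hardware instruction, and the stored LSE values (`m + log2(d)`) are base-2
logs as well. A quick numeric check of the identity, with illustrative values:

    import math

    d = 128                  # hypothetical head_dim
    sm_scale = 1.0 / math.sqrt(d) * math.log2(math.exp(1))
    qk = 3.7                 # hypothetical raw q.k dot product
    assert abs(2 ** (qk * sm_scale) - math.exp(qk / math.sqrt(d))) < 1e-12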
+
+def _merge_state_inplace(num_heads, head_dim, v_dtype):
+    # pylint: disable=invalid-name
+    v_dtype_bytes = 2
+    VEC_SIZE = max(8 // v_dtype_bytes, head_dim // 32)
+    bdx = head_dim // VEC_SIZE
+    bdy = num_heads
+
+    @T.prim_func
+    def merge_state_inplace(
+        v: T.handle,
+        s: T.handle,
+        v_other: T.handle,
+        s_other: T.handle,
+    ):
+        T.func_attr({"tir.is_scheduled": 1})
+        N = T.int32(is_size_var=True)
+        H = T.int32(is_size_var=True)
+        D = T.int32(is_size_var=True)
+
+        V = T.match_buffer(v, (N, H, D), v_dtype)
+        S = T.match_buffer(s, (N, H), "float32")
+        V_other = T.match_buffer(v_other, (N, H, D), v_dtype)
+        S_other = T.match_buffer(s_other, (N, H), "float32")
+
+        for bx in T.thread_binding(N, thread="blockIdx.x"):
+            for ty in T.thread_binding(bdy, thread="threadIdx.y"):
+                for tx in T.thread_binding(bdx, thread="threadIdx.x"):
+                    with T.block("merge"):
+                        s_val = _var("float32")
+                        s_other_val = _var("float32")
+                        s_max = _var("float32")
+                        scale = _var("float32")
+                        other_scale = _var("float32")
+
+                        v_vec = T.alloc_buffer((VEC_SIZE,), v_dtype, scope="local")
+                        v_other_vec = T.alloc_buffer((VEC_SIZE,), v_dtype, scope="local")
+
+                        s_val[0] = S[bx, ty]
+                        s_other_val[0] = S_other[bx, ty]
+                        s_max[0] = T.max(s_val[0], s_other_val[0])
+                        s_val[0] = T.exp2(s_val[0] - s_max[0])
+                        s_other_val[0] = T.exp2(s_other_val[0] - s_max[0])
+                        scale[0] = s_val[0] / (s_val[0] + s_other_val[0])
+                        other_scale[0] = s_other_val[0] / (s_val[0] + s_other_val[0])
+
+                        # load v
+                        for vec in T.vectorized(VEC_SIZE):
+                            v_vec[vec] = V[bx, ty, tx * VEC_SIZE + vec]
+                        # load v_other
+                        for vec in T.vectorized(VEC_SIZE):
+                            v_other_vec[vec] = V_other[bx, ty, tx * VEC_SIZE + vec]
+
+                        # merge
+                        for vec in T.serial(VEC_SIZE):
+                            v_vec[vec] = v_vec[vec] * scale[0] + v_other_vec[vec] * other_scale[0]
+
+                        # store v
+                        for vec in T.vectorized(VEC_SIZE):
+                            V[bx, ty, tx * VEC_SIZE + vec] = v_vec[vec]
+
+                        # store s
+                        S[bx, ty] = T.log2(s_val[0] + s_other_val[0]) + s_max[0]
+
+    # pylint: enable=invalid-name
+    return merge_state_inplace
+
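+
`merge_state_inplace` combines two partial attention results computed over
disjoint KV ranges, using their LSEs as weights: with m = max(s, s_other), the
weights exp2(s - m) and exp2(s_other - m) are normalized into `scale` /
`other_scale`, V is blended in place, and the merged LSE becomes m plus the
base-2 log of the weight sum. A minimal NumPy model of the same update
(illustrative, not the kernel itself):

    import numpy as np

    def merge_reference(v, s, v_other, s_other):
        # v, v_other: (N, H, D) partial outputs; s, s_other: (N, H) base-2 LSEs.
        m = np.maximum(s, s_other)
        a = np.exp2(s - m)
        b = np.exp2(s_other - m)
        v_merged = (v * a[..., None] + v_other * b[..., None]) / (a + b)[..., None]
        s_merged = m + np.log2(a + b)
        return v_merged, s_merged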
+
 if __name__ == "__main__":
     cache = create_kv_cache()
     test_paged_attention_kv_cache_prefill_and_decode(cache)
     test_paged_attention_kv_cache_remove_sequence(cache)
+    test_paged_attention_kv_cache_fork_sequence(cache)
     test_paged_attention_kv_cache_popn(cache)
