This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 76b954a09e [3rdparty] Bump FlashInfer (#17236)
76b954a09e is described below

commit 76b954a09e781b7f664b1d345e1494123c19484c
Author: Ruihang Lai <[email protected]>
AuthorDate: Sat Aug 3 04:28:02 2024 -0400

    [3rdparty] Bump FlashInfer (#17236)
    
    This PR bumps FlashInfer and updates PagedKVCache accordingly
    for performance improvement.
    
    Some notes on this bump:
    
    * When the Grouped-Query Attention (GQA) group size is at least 4 and
    FlashInfer is enabled, we always use the prefill attention kernel
    instead of the decode kernel for better performance.
    * We enlarge the temporary attention workspace used by FlashInfer from
    8 MB to 128 MB, as the current version of FlashInfer may consume a much
    larger workspace. We skip allocating the workspace when FlashInfer is
    not enabled.
    * We reduce the maximum block depth from 5 to 2, since cascade inference
    offers limited benefit when the batch size is not large and prompt
    reuse is low.
---
 3rdparty/flashinfer                                |  2 +-
 src/runtime/relax_vm/paged_kv_cache.cc             | 48 +++++++++++++++-------
 ..._builtin_paged_attention_kv_cache_flashinfer.py | 13 +++++-
 ...runtime_builtin_paged_attention_kv_cache_tir.py | 13 +++++-
 4 files changed, 58 insertions(+), 18 deletions(-)

diff --git a/3rdparty/flashinfer b/3rdparty/flashinfer
index 7e9cc7ff42..0dd801d202 160000
--- a/3rdparty/flashinfer
+++ b/3rdparty/flashinfer
@@ -1 +1 @@
-Subproject commit 7e9cc7ff42ca283c317061a877305d09a395fad2
+Subproject commit 0dd801d2027af89f3603cbbf68a76e9503bb2f57
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc 
b/src/runtime/relax_vm/paged_kv_cache.cc
index 2fb8a72f42..5aa1411ec1 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -54,11 +54,11 @@ namespace relax_vm {
  * \brief The maximum allowed block depth (a.k.a. number of common
  * prefixes) in paged KV cache.
  */
-constexpr const int kPagedKVCacheMaxBlockDepth = 5;
+constexpr const int kPagedKVCacheMaxBlockDepth = 2;
 /*! \brief The maximum tree size of a single sequence in tree attention. */
 constexpr const int kTreeAttnMaxTreeSize = 256;
 /*! \brief The 8MB workspace size for attention auxiliary data. */
-constexpr const int kAttnWorkspaceByte = 8 * 1024 * 1024;
+constexpr const int kAttnWorkspaceByte = 128 * 1024 * 1024;
 /*! \brief The id of the temporary logical page, which is useful for sliding 
window. */
 constexpr const int kPagedKVCacheTempPageId = -1;
 
@@ -119,6 +119,9 @@ struct Block {
   void Reset() {
     page_ids.clear();
     seq_length = 0;
+    start_pos = 0;
+    sink_length = 0;
+    sliding_window_offset = 0;
     parent_idx = -1;
     external_ref_cnt = 0;
   }
@@ -169,11 +172,9 @@ struct Sequence {
     this->last_block_idx = last_block_idx;
     int32_t block_ptr = last_block_idx;
     // Go through each block in the sequence, sum up the length.
-    int depth = 0;
     while (true) {
       const Block& block = global_block_pool->at(block_ptr);
       this->seq_length += block.seq_length;
-      ++depth;
       if (block.parent_idx == -1) {
         break;
       }
@@ -1078,8 +1079,10 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
                          dtype_aux_, preferred_host_device);
 
     for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) {
-      temp_attn_workspace_.push_back(
-          NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), 
device));
+      if (NeedKernelBeginForward()) {
+        temp_attn_workspace_.push_back(
+            NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), 
device));
+      }
       qo_indptr_on_depths_view_.push_back(NDArray());
       page_indptr_on_depths_view_.push_back(NDArray());
       page_indices_on_depths_view_.push_back(NDArray());
@@ -1087,8 +1090,10 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
       k_rope_pos_offset_view_.push_back(NDArray());
     }
     // Additional workspace for the "prefill with ragged kv" kernel.
-    temp_attn_workspace_.push_back(
-        NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
+    if (NeedKernelBeginForward()) {
+      temp_attn_workspace_.push_back(
+          NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), 
device));
+    }
 
     temp_attn_q_device_ =
         NDArray::Empty({prefill_chunk_size_, num_qo_heads, head_dim}, dtype, 
device);
@@ -1531,6 +1536,12 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     }
 
     append_before_attn_ = !support_sliding_window_ && num_depths_ == 1 && 
use_decode_kernel_[0];
+    if (NeedKernelBeginForward() && num_qo_heads_ / num_kv_heads_ >= 4) {
+      // When GQA group size is at least 4 and FlashInfer is enabled,
+      // we always use prefill kernel for better performance.
+      std::fill(use_decode_kernel_.begin(), use_decode_kernel_.end(), 
/*value=*/false);
+    }
+
     if (append_before_attn_) {
       // Right now we use different kernels when depth is 1 or not 1.
       // For the case where maximum depth is 1, we create the auxiliary
@@ -2196,11 +2207,16 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
             use_decode_kernel};
   }
 
+  /*! \brief Check whether BeginForward for kernels is needed. */
+  bool NeedKernelBeginForward() {
+    return f_attention_prefill_begin_forward_.defined() &&
+           f_attention_decode_begin_forward_.defined() &&
+           f_attention_prefill_ragged_begin_forward_.defined();
+  }
+
   /*! \brief Invoke the "begin forward" functions of underlying kernels. */
   void KernelBeginForward() {
-    if (!f_attention_prefill_begin_forward_.defined() ||
-        !f_attention_decode_begin_forward_.defined() ||
-        !f_attention_prefill_ragged_begin_forward_.defined()) {
+    if (!NeedKernelBeginForward()) {
       return;
     }
 
@@ -2214,8 +2230,9 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
       }
     } else {
       f_attention_prefill_ragged_begin_forward_.value()(
-          temp_attn_workspace_[0], 
cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_,
-          num_qo_heads_, num_kv_heads_, head_dim_, copy_stream_);
+          temp_attn_workspace_[0], 
cur_append_lengths_indptr_host_.as_ndarray(),
+          cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_, 
num_qo_heads_,
+          num_kv_heads_, head_dim_, copy_stream_);
       if (support_sliding_window_) {
         return;
       }
@@ -2232,8 +2249,9 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
         } else {
           f_attention_prefill_begin_forward_.value()(
               /*depth=*/d, temp_attn_workspace_[d + 1], 
qo_indptr_on_depths_host_[d].as_ndarray(),
-              length_info_on_depths_view_[d]->shape[0], num_qo_heads_, 
num_kv_heads_, head_dim_,
-              copy_stream_);
+              page_indptr_on_depths_host_[d].as_ndarray(),
+              static_cast<int>(page_indptr_on_depths_host_[d].size()) - 1, 
num_qo_heads_,
+              num_kv_heads_, head_dim_, page_size_, copy_stream_);
         }
       }
     }
diff --git 
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
 
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
index bade04a7d7..cab10f84cd 100644
--- 
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
+++ 
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
@@ -29,7 +29,7 @@ from tvm.runtime import ShapeTuple
 from tvm.script import tir as T
 
 reserved_nseq = 32
-maximum_total_seq_length = 1024
+maximum_total_seq_length = 2048
 prefill_chunk_size = 512
 page_size = 16
 num_layers = 4
@@ -249,6 +249,7 @@ def _copy_single_page(num_heads, page_size, head_dim, 
dtype, target):
         ):
             for t in T.thread_binding(tx, thread="threadIdx.x"):
                 with T.block("copy"):
+                    T.where(b * tx + t < copy_length * num_heads * head_dim)
                     vh = T.axis.spatial(
                         num_heads,
                         T.Cast("int32", (b * tx + t) // (copy_length * 
head_dim)),
@@ -662,6 +663,16 @@ def 
test_paged_attention_kv_cache_fork_sequence(kv_cache_and_rope_mode):
         cached_v.pop(i)
         verify_cached_kv(kv_cache, seq_ids=list(range(i)), 
expected_k=cached_k, expected_v=cached_v)
 
+    # Test fork after page recycle
+    apply_attention(kv_cache, rope_mode, [(0, 7), (1, 24)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((2, 1, -1), 10)], cached_k, 
cached_v)
+    apply_attention(kv_cache, rope_mode, [((3, 0, -1), 20)], cached_k, 
cached_v)
+    apply_attention(kv_cache, rope_mode, [(2, 1), (3, 1)], cached_k, cached_v)
+
+    apply_attention(kv_cache, rope_mode, [(10, 7), (11, 24)], cached_k, 
cached_v)
+    apply_attention(kv_cache, rope_mode, [((12, 11, -1), 200)], cached_k, 
cached_v)
+    apply_attention(kv_cache, rope_mode, [(10, 1), (12, 1)], cached_k, 
cached_v)
+
 
 @pytest.mark.skip(reason="Require FlashInfer enabled")
 def test_paged_attention_kv_cache_popn(kv_cache_and_rope_mode):
diff --git 
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py 
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index 9192bb901f..3c85a13e4c 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -33,7 +33,7 @@ from tvm.script import tir as T
 from tvm.target import Target
 
 reserved_nseq = 32
-maximum_total_seq_length = 1024
+maximum_total_seq_length = 2048
 prefill_chunk_size = 512
 page_size = 16
 num_layers = 4
@@ -615,6 +615,16 @@ def 
test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config):
 
     assert fis_empty(kv_cache), "The KV cache is not empty after removing all 
sequences"
 
+    # Test fork after page recycle
+    apply_attention(kv_cache, rope_mode, [(0, 7), (1, 24)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((2, 1, -1), 10)], cached_k, 
cached_v)
+    apply_attention(kv_cache, rope_mode, [((3, 0, -1), 20)], cached_k, 
cached_v)
+    apply_attention(kv_cache, rope_mode, [(2, 1), (3, 1)], cached_k, cached_v)
+
+    apply_attention(kv_cache, rope_mode, [(10, 7), (11, 24)], cached_k, 
cached_v)
+    apply_attention(kv_cache, rope_mode, [((12, 11, -1), 200)], cached_k, 
cached_v)
+    apply_attention(kv_cache, rope_mode, [(10, 1), (12, 1)], cached_k, 
cached_v)
+
 
 @tvm.testing.requires_gpu
 @tvm.testing.requires_cuda
@@ -2547,6 +2557,7 @@ def _copy_single_page(num_heads, page_size, head_dim, 
dtype, target: Target):
         ):
             for t in T.thread_binding(tx, thread="threadIdx.x"):
                 with T.block("copy"):
+                    T.where(b * tx + t < copy_length * num_heads * head_dim)
                     vh = T.axis.spatial(
                         num_heads,
                         T.Cast("int32", (b * tx + t) // (copy_length * 
head_dim)),

Reply via email to