This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 89b91e2b11 [KVCache] Partial layers support (#17192)
89b91e2b11 is described below

commit 89b91e2b1195b53bf7e1f6c250bc9a1247367d13
Author: Yaxing Cai <[email protected]>
AuthorDate: Tue Jul 23 21:13:41 2024 -0700

    [KVCache] Partial layers support (#17192)
    
    This PR updates the KVCache implementation to support partial layers.
---
 include/tvm/runtime/disco/disco_worker.h           | 15 ++++
 src/runtime/disco/disco_worker.cc                  |  9 ---
 src/runtime/relax_vm/paged_kv_cache.cc             | 82 +++++++++++++++-------
 ..._builtin_paged_attention_kv_cache_flashinfer.py |  2 +-
 ...runtime_builtin_paged_attention_kv_cache_tir.py |  2 +-
 5 files changed, 73 insertions(+), 37 deletions(-)

diff --git a/include/tvm/runtime/disco/disco_worker.h b/include/tvm/runtime/disco/disco_worker.h
index 301b5b8d62..13f94802c8 100644
--- a/include/tvm/runtime/disco/disco_worker.h
+++ b/include/tvm/runtime/disco/disco_worker.h
@@ -93,6 +93,21 @@ class DiscoWorker {
   struct Impl;
   friend struct DiscoWorker::Impl;
 };
+/*!
+ * \brief A threadlocal wrapper of DiscoWorker.
+ */
+struct ThreadLocalDiscoWorker {
+  /*! \brief The Disco worker */
+  DiscoWorker* worker;
+
+  /*!
+   * \brief Get the threadlocal Disco worker.
+   */
+  static ThreadLocalDiscoWorker* Get() {
+    thread_local static ThreadLocalDiscoWorker worker;
+    return &worker;
+  }
+};
 
 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/runtime/disco/disco_worker.cc b/src/runtime/disco/disco_worker.cc
index b281a3aca7..5e6f401054 100644
--- a/src/runtime/disco/disco_worker.cc
+++ b/src/runtime/disco/disco_worker.cc
@@ -28,15 +28,6 @@
 namespace tvm {
 namespace runtime {
 
-struct ThreadLocalDiscoWorker {
-  DiscoWorker* worker;
-
-  static ThreadLocalDiscoWorker* Get() {
-    thread_local static ThreadLocalDiscoWorker worker;
-    return &worker;
-  }
-};
-
 TVM_DLL DiscoWorker* DiscoWorker::ThreadLocal() {
   DiscoWorker* ret = ThreadLocalDiscoWorker::Get()->worker;
   CHECK(ret) << "ValueError: The current thread is not a DiscoWorker thread";
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index ec1cc3593a..2fb8a72f42 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -21,6 +21,7 @@
  * \brief Runtime paged KV cache object for language models.
  */
 #include <tvm/runtime/device_api.h>
+#include <tvm/runtime/disco/disco_worker.h>
 #include <tvm/runtime/logging.h>
 #include <tvm/runtime/memory/memory_manager.h>
 #include <tvm/runtime/ndarray.h>
@@ -825,6 +826,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj 
{
   const int64_t page_size_;
   /*! \brief The number of layers in the model. */
   const int64_t num_layers_;
+  /*! \brief The beginning layer id offset. */
+  const int64_t layer_id_begin_offset_;
   /*! \brief The number of query/output heads in the model. */
   const int64_t num_qo_heads_;
   /*! \brief The number of key/value heads in the model. */
@@ -981,14 +984,14 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
  public:
   /*! \brief Constructor. Take the cache configuration and initialize the 
NDArrays. */
   explicit PagedAttentionKVCacheObj(
-      int64_t page_size,  //
-      int64_t num_layers, int64_t num_qo_heads, int64_t num_kv_heads, int64_t 
head_dim,
-      int64_t reserved_num_seqs, int64_t num_total_pages, int64_t 
prefill_chunk_size,
-      bool support_sliding_window, RoPEMode rope_mode, double rotary_scale, 
double rotary_theta,
-      DLDataType dtype, Device device, PackedFunc f_transpose_append, 
PackedFunc f_compact_copy,
-      PackedFunc f_attention_prefill, PackedFunc f_attention_decode,
-      PackedFunc f_attention_prefill_sliding_window, PackedFunc 
f_attention_decode_sliding_window,
-      PackedFunc f_attention_prefill_ragged, PackedFunc 
f_attention_prefill_with_tree_mask,
+      int64_t page_size, int64_t num_layers, int64_t layer_id_begin_offset,  //
+      int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim, int64_t 
reserved_num_seqs,
+      int64_t num_total_pages, int64_t prefill_chunk_size, bool 
support_sliding_window,
+      RoPEMode rope_mode, double rotary_scale, double rotary_theta, DLDataType 
dtype, Device device,
+      PackedFunc f_transpose_append, PackedFunc f_compact_copy, PackedFunc 
f_attention_prefill,
+      PackedFunc f_attention_decode, PackedFunc 
f_attention_prefill_sliding_window,
+      PackedFunc f_attention_decode_sliding_window, PackedFunc 
f_attention_prefill_ragged,
+      PackedFunc f_attention_prefill_with_tree_mask,
       Optional<PackedFunc> f_attention_prefill_ragged_begin_forward,
       Optional<PackedFunc> f_attention_prefill_ragged_end_forward,
       Optional<PackedFunc> f_attention_prefill_begin_forward,
@@ -998,6 +1001,7 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
       PackedFunc f_split_rotary, PackedFunc f_copy_single_page, 
Optional<PackedFunc> f_debug_get_kv)
       : page_size_(page_size),
         num_layers_(num_layers),
+        layer_id_begin_offset_(layer_id_begin_offset),
         num_qo_heads_(num_qo_heads),
         num_kv_heads_(num_kv_heads),
         head_dim_(head_dim),
@@ -1672,7 +1676,10 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
   void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, 
Optional<NDArray> mask,
                              NDArray o_data, double attn_score_scaling_factor) 
final {
     // Part 1. Shape and dtype check.
-    NDArray pages = pages_[layer_id];
+    int64_t local_layer_id = layer_id - layer_id_begin_offset_;
+    CHECK_GE(local_layer_id, 0);
+    CHECK_LT(local_layer_id, num_layers_);
+    NDArray pages = pages_[local_layer_id];
     CHECK(qkv_data.DataType() == pages.DataType());
     CHECK(o_data.DataType() == pages.DataType());
 
@@ -1713,13 +1720,13 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
 
     // Part 3. Append k/v data to kv-cache if flag "append_before_attn" is set.
     if (append_before_attn_) {
-      f_transpose_append_(pages_[layer_id], k_data, v_data, 
append_position_map_view_);
+      f_transpose_append_(pages_[local_layer_id], k_data, v_data, 
append_position_map_view_);
     }
     // Part 4: perform attention
     AttentionInternal(layer_id, q_data, k_data, v_data, o_data, 
attn_score_scaling_factor);
     // Part 5. Append k/v data to kv-cache if flag "append_before_attn" is not 
set.
     if (!append_before_attn_) {
-      f_transpose_append_(pages_[layer_id], k_data, v_data, 
append_position_map_view_);
+      f_transpose_append_(pages_[local_layer_id], k_data, v_data, 
append_position_map_view_);
     }
   }
 
@@ -2238,6 +2245,9 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
    */
   void AttentionInternal(int64_t layer_id, NDArray q_data, NDArray k_data, 
NDArray v_data,
                          NDArray output, double attn_score_scaling_factor) {
+    int64_t local_layer_id = layer_id - layer_id_begin_offset_;
+    CHECK_GE(local_layer_id, 0);
+    CHECK_LT(local_layer_id, num_layers_);
     PackedFunc f_prefill =
         !support_sliding_window_ ? f_attention_prefill_ : 
f_attention_prefill_sliding_window_;
     PackedFunc f_decode =
@@ -2245,7 +2255,7 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     CHECK_GE(num_depths_, 1) << "The number of effective depths must be 
greater or equal to 1.";
     if (append_before_attn_) {
       f_decode(
-          /*depth=*/0, q_data, pages_[layer_id], 
page_indptr_on_depths_view_[0],
+          /*depth=*/0, q_data, pages_[local_layer_id], 
page_indptr_on_depths_view_[0],
           page_indices_on_depths_view_[0], length_info_on_depths_view_[0],
           k_rope_pos_offset_view_[0], q_rope_position_map_view_, output, 
merged_attn_scores_view_,
           /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, 
rotary_theta_,
@@ -2280,7 +2290,7 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
         }
         if (use_decode_kernel_[d]) {
           // Use decode kernel for depth d
-          f_decode(/*depth=*/d, q_data, pages_[layer_id], 
page_indptr_on_depths_view_[d],
+          f_decode(/*depth=*/d, q_data, pages_[local_layer_id], 
page_indptr_on_depths_view_[d],
                    page_indices_on_depths_view_[d], 
length_info_on_depths_view_[d],
                    k_rope_pos_offset_view_[d], q_rope_position_map_view_, 
temp_attn_output_view_,
                    temp_attn_scores_view_,
@@ -2289,7 +2299,7 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
         } else {
           // Use prefill kernel for depth d
           f_prefill(
-              /*depth=*/d, q_data, qo_indptr_on_depths_view_[d], 
pages_[layer_id],
+              /*depth=*/d, q_data, qo_indptr_on_depths_view_[d], 
pages_[local_layer_id],
               page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d],
               length_info_on_depths_view_[d], k_rope_pos_offset_view_[d], 
q_rope_position_map_view_,
               temp_attn_output_view_, temp_attn_scores_view_,
@@ -2436,7 +2446,17 @@ 
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
       CHECK(args.size() == 25 || args.size() == 26 || args.size() == 27)
           << "Invalid number of KV cache constructor args.";
       ShapeTuple cache_config = args[0];
-      int64_t num_layers = args[1];
+      ShapeTuple layer_indptr_tuple = args[1];
+      int num_groups = 1;
+      int group_id = 0;
+      if (DiscoWorker* disco_worker = ThreadLocalDiscoWorker::Get()->worker) {
+        // In the Disco worker thread
+        num_groups = disco_worker->num_groups;
+        group_id = disco_worker->worker_id / (disco_worker->num_workers / 
num_groups);
+      }
+      CHECK_EQ(layer_indptr_tuple.size(), num_groups + 1);
+      int64_t num_layers = layer_indptr_tuple[group_id + 1] - 
layer_indptr_tuple[group_id];
+      int64_t layer_id_begin_offset = layer_indptr_tuple[group_id];
       int64_t num_qo_heads = args[2];
       int64_t num_kv_heads = args[3];
       int64_t head_dim = args[4];
@@ -2482,11 +2502,11 @@ 
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
         num_total_pages += reserved_num_seqs * 2;
       }
       ObjectPtr<PagedAttentionKVCacheObj> n = 
make_object<PagedAttentionKVCacheObj>(
-          page_size, num_layers, num_qo_heads, num_kv_heads, head_dim, 
reserved_num_seqs,
-          num_total_pages, prefill_chunk_size, support_sliding_window, 
RoPEMode(rope_mode),
-          rotary_scale, rotary_theta, init->dtype, init->device, 
std::move(f_transpose_append),
-          std::move(f_compact_copy), std::move(f_attention_prefill), 
std::move(f_attention_decode),
-          std::move(f_attention_prefill_sliding_window),
+          page_size, num_layers, layer_id_begin_offset, num_qo_heads, 
num_kv_heads, head_dim,
+          reserved_num_seqs, num_total_pages, prefill_chunk_size, 
support_sliding_window,
+          RoPEMode(rope_mode), rotary_scale, rotary_theta, init->dtype, 
init->device,
+          std::move(f_transpose_append), std::move(f_compact_copy), 
std::move(f_attention_prefill),
+          std::move(f_attention_decode), 
std::move(f_attention_prefill_sliding_window),
           std::move(f_attention_decode_sliding_window), 
std::move(f_attention_prefill_ragged),
           std::move(f_attention_prefill_with_tree_mask),
           std::move(f_attention_prefill_ragged_begin_forward),
@@ -2503,7 +2523,17 @@ 
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
       CHECK(args.size() == 19 || args.size() == 20 || args.size() == 21)
           << "Invalid number of KV cache constructor args.";
       ShapeTuple cache_config = args[0];
-      int64_t num_layers = args[1];
+      ShapeTuple layer_indptr_tuple = args[1];
+      int num_groups = 1;
+      int group_id = 0;
+      if (DiscoWorker* disco_worker = ThreadLocalDiscoWorker::Get()->worker) {
+        // In the Disco worker thread
+        num_groups = disco_worker->num_groups;
+        group_id = disco_worker->worker_id / (disco_worker->num_workers / 
num_groups);
+      }
+      CHECK_EQ(layer_indptr_tuple.size(), num_groups + 1);
+      int64_t num_layers = layer_indptr_tuple[group_id + 1] - 
layer_indptr_tuple[group_id];
+      int64_t layer_id_begin_offset = layer_indptr_tuple[group_id];
       int64_t num_qo_heads = args[2];
       int64_t num_kv_heads = args[3];
       int64_t head_dim = args[4];
@@ -2543,11 +2573,11 @@ 
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
         num_total_pages += reserved_num_seqs * 2;
       }
       ObjectPtr<PagedAttentionKVCacheObj> n = 
make_object<PagedAttentionKVCacheObj>(
-          page_size, num_layers, num_qo_heads, num_kv_heads, head_dim, 
reserved_num_seqs,
-          num_total_pages, prefill_chunk_size, support_sliding_window, 
RoPEMode(rope_mode),
-          rotary_scale, rotary_theta, init->dtype, init->device, 
std::move(f_transpose_append),
-          std::move(f_compact_copy), std::move(f_attention_prefill), 
std::move(f_attention_decode),
-          std::move(f_attention_prefill_sliding_window),
+          page_size, num_layers, layer_id_begin_offset, num_qo_heads, 
num_kv_heads, head_dim,
+          reserved_num_seqs, num_total_pages, prefill_chunk_size, 
support_sliding_window,
+          RoPEMode(rope_mode), rotary_scale, rotary_theta, init->dtype, 
init->device,
+          std::move(f_transpose_append), std::move(f_compact_copy), 
std::move(f_attention_prefill),
+          std::move(f_attention_decode), 
std::move(f_attention_prefill_sliding_window),
           std::move(f_attention_decode_sliding_window), 
std::move(f_attention_prefill_ragged),
           std::move(f_attention_prefill_with_tree_mask),         //
           NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt,  //
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
index 048cf49806..bade04a7d7 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
@@ -354,7 +354,7 @@ def create_kv_cache(rope_mode):
                 support_sliding_window,
             ]
         ),
-        num_layers,
+        tvm.runtime.ShapeTuple([0, num_layers]),
         num_qo_heads,
         num_kv_heads,
         head_dim,
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index 34680160c8..9192bb901f 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -153,7 +153,7 @@ def create_kv_cache(head_dim, dtype, rope_mode, 
support_sliding_window):
                 int(support_sliding_window),
             ]
         ),
-        num_layers,
+        tvm.runtime.ShapeTuple([0, num_layers]),
         num_qo_heads,
         num_kv_heads,
         head_dim,

Reply via email to