This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 7a0c3f9e05 [Unity][Support] PagedKVCache support growth control 
(#16112)
7a0c3f9e05 is described below

commit 7a0c3f9e056a771b6854d91be6656b31e871e622
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Nov 12 12:54:19 2023 -0500

    [Unity][Support] PagedKVCache support growth control (#16112)
    
    This PR adds a constructor parameter that controls whether
    automatic growth of the KV cache is allowed. Previously, the
    KV cache always grew whenever it was full and more capacity
    was demanded.
    
    Although automatic growth can be useful, in practice we often
    want the pre-allocated memory to be static, large enough and
    fixed, which makes memory management more controllable. Hence,
    this PR makes it possible to specify whether growth is allowed,
    and raises an error when growth is required but not allowed.
    
    This PR also adds an auxiliary function to the KV cache that
    queries the number of available pages.
---
 src/runtime/relax_vm/paged_kv_cache.cc             | 46 ++++++++++++++++------
 ...est_runtime_builtin_paged_attention_kv_cache.py |  5 +++
 2 files changed, 39 insertions(+), 12 deletions(-)

diff --git a/src/runtime/relax_vm/paged_kv_cache.cc 
b/src/runtime/relax_vm/paged_kv_cache.cc
index 4c61af9018..6d2444ea64 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -77,6 +77,8 @@ class PagedAttentionKVCacheObj : public Object {
   const int64_t num_heads_;
   /*! \brief The number of features each head has. */
   const int64_t head_dim_;
+  /*! \brief A boolean denoting if cache automatic growth is allowed. */
+  const bool allow_growth_;
 
   /*! \brief We fix int32 to be the index dtype of auxiliary data. */
   const DLDataType dtype_aux_ = DLDataType(DataType::Int(32, 1));
@@ -145,8 +147,6 @@ class PagedAttentionKVCacheObj : public Object {
    * length dimension of K/V data. It is used for efficient computation.
    */
   NDArray cur_pos2seqid_device_;
-  /*! \brief A temporary buffer for efficient attention computation. */
-  NDArray attn_tmp_buffer_;
 
   //-------------------------------------------
   // For efficient memory management, the actual sizes of the arrays
@@ -165,8 +165,13 @@ class PagedAttentionKVCacheObj : public Object {
   /*! \brief Constructor. Take the cache configuration and initialize the 
NDArrays. */
   explicit PagedAttentionKVCacheObj(int64_t page_size, int64_t num_layers, 
int64_t num_heads,
                                     int64_t head_dim, int64_t 
reserved_num_seqs,
-                                    int64_t reserved_num_pages, DLDataType 
dtype, DLDevice device)
-      : page_size_(page_size), num_layers_(num_layers), num_heads_(num_heads), 
head_dim_(head_dim) {
+                                    int64_t reserved_num_pages, DLDataType 
dtype, DLDevice device,
+                                    bool allow_growth)
+      : page_size_(page_size),
+        num_layers_(num_layers),
+        num_heads_(num_heads),
+        head_dim_(head_dim),
+        allow_growth_(allow_growth) {
     pages_ = NDArray::Empty({reserved_num_pages, num_layers, 2, num_heads, 
page_size, head_dim},
                             dtype, device);
     page_table_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, 
dtype_aux_, device);
@@ -174,7 +179,6 @@ class PagedAttentionKVCacheObj : public Object {
     last_page_offset_device_ = NDArray::Empty({reserved_num_seqs}, dtype_aux_, 
device);
     cur_append_length_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, 
dtype_aux_, device);
     cur_pos2seqid_device_ = NDArray::Empty({reserved_num_pages * page_size}, 
dtype_aux_, device);
-    attn_tmp_buffer_ = NDArray::Empty({8 * 1024 * 1024}, 
DLDataType(DataType::Float(32)), device);
   }
 
   /*!
@@ -272,7 +276,7 @@ class PagedAttentionKVCacheObj : public Object {
     f_attention(q_data, pages_,                                          //
                 page_table_indptr_view_, page_table_values_view_,        //
                 last_page_offset_view_, cur_append_length_indptr_view_,  //
-                layer_id, attn_tmp_buffer_, output, apply_rotary, 
rotary_scale, rotary_theta);
+                layer_id, output, apply_rotary, rotary_scale, rotary_theta);
   }
 
   /*!
@@ -486,6 +490,12 @@ class PagedAttentionKVCacheObj : public Object {
     dirty_aux_data_device_ = false;
   }
 
+  /*! \brief Return the number of remaining pages. */
+  int GetNumAvailablePages() {
+    ICHECK_EQ(num_pages_allocated_, free_page_ids_.size() + num_pages_in_use_);
+    return pages_->shape[0] - num_pages_in_use_;
+  }
+
   /*! \brief Reset the KV cache. */
   void Clear() {
     num_total_seqs_ = 0;
@@ -530,6 +540,9 @@ class PagedAttentionKVCacheObj : public Object {
     if (num_pages_allocated_ < reserved_num_pages) {
       return num_pages_allocated_++;
     }
+    CHECK(allow_growth_)
+        << "The page KV cache is full and growth is not allowed. Please set a 
larger "
+           "total token capacity when initialization.";
     ICHECK_EQ(num_pages_allocated_, reserved_num_pages);
 
     // Grow the `pages` array by doubling its size.
@@ -563,12 +576,18 @@ class PagedAttentionKVCacheObj : public Object {
 
     DLDevice device = page_table_indptr_device_->device;
     if (reserved_nseq != page_table_indptr_device_->shape[0] - 1) {
+      CHECK(allow_growth_)
+          << "The page KV cache is full and growth is not allowed. Please set 
a larger "
+             "sequence capacity when initialization.";
       page_table_indptr_device_ = NDArray::Empty({reserved_nseq + 1}, 
dtype_aux_, device);
       last_page_offset_device_ = NDArray::Empty({reserved_nseq}, dtype_aux_, 
device);
       cur_append_length_indptr_device_ = NDArray::Empty({reserved_nseq + 1}, 
dtype_aux_, device);
     }
 
     if (pages_->shape[0] > page_table_values_device_->shape[0]) {
+      CHECK(allow_growth_)
+          << "The page KV cache is full and growth is not allowed. Please set 
a larger "
+             "total token capacity when initialization.";
       page_table_values_device_ = NDArray::Empty({pages_->shape[0]}, 
dtype_aux_, device);
     }
   }
@@ -576,13 +595,13 @@ class PagedAttentionKVCacheObj : public Object {
 
 class PagedAttentionKVCache : public ObjectRef {
  public:
-  static PagedAttentionKVCache Create(int64_t reserved_num_seqs, int64_t 
total_sequence_length,
+  static PagedAttentionKVCache Create(int64_t reserved_num_seqs, int64_t 
total_token_capacity,
                                       int64_t page_size, int64_t num_layers, 
int64_t num_heads,
-                                      int64_t head_dim, NDArray init) {
-    int64_t reserved_num_pages = (total_sequence_length + page_size - 1) / 
page_size;
+                                      int64_t head_dim, NDArray init, bool 
allow_growth) {
+    int64_t reserved_num_pages = (total_token_capacity + page_size - 1) / 
page_size;
     auto n = make_object<PagedAttentionKVCacheObj>(page_size, num_layers, 
num_heads, head_dim,
                                                    reserved_num_seqs, 
reserved_num_pages,
-                                                   init->dtype, init->device);
+                                                   init->dtype, init->device, 
allow_growth);
     return PagedAttentionKVCache(n);
   }
 
@@ -597,10 +616,10 @@ TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj);
 
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
     .set_body_typed([](ShapeTuple cache_config, int64_t num_layers_, int64_t 
num_heads_,
-                       int64_t head_dim_, NDArray init) {
+                       int64_t head_dim_, NDArray init, bool allow_growth) {
       CHECK_EQ(cache_config.size(), 3);
       return PagedAttentionKVCache::Create(cache_config[0], cache_config[1], 
cache_config[2],
-                                           num_layers_, num_heads_, head_dim_, 
init);
+                                           num_layers_, num_heads_, head_dim_, 
init, allow_growth);
     });
 
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_add_sequence")
@@ -640,6 +659,9 @@ 
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_sync_aux_array_to_devic
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_clear")
     .set_body_typed([](PagedAttentionKVCache cache) { cache->Clear(); });
 
+TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_get_num_available_pages")
+    .set_body_typed([](PagedAttentionKVCache cache) { return 
cache->GetNumAvailablePages(); });
+
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_attention")
     .set_body([](TVMArgs args, TVMRetValue* rv) {
       CHECK(args.size() == 5 || args.size() == 8);
diff --git 
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache.py 
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache.py
index a57a905550..685ee262f1 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache.py
@@ -160,6 +160,7 @@ def test_paged_attention_kv_cache_append_prefill():
         nhead,
         nfeat,
         tvm.nd.empty((), dtype),
+        True,
     )
 
     operation_seq = [[(0, 6)], [(1, 8)], [(2, 11)], [(3, 16)], [(4, 19), (5, 
20)]]
@@ -207,6 +208,7 @@ def test_paged_attention_kv_cache_append_decode():
         nhead,
         nfeat,
         tvm.nd.empty((), dtype),
+        True,
     )
 
     cached_values = []
@@ -260,6 +262,7 @@ def test_paged_attention_kv_cache_remove():
         nhead,
         nfeat,
         tvm.nd.empty((), dtype),
+        True,
     )
 
     cached_values = []
@@ -319,6 +322,7 @@ def test_paged_attention_kv_cache_popn():
         nhead,
         nfeat,
         tvm.nd.empty((), dtype),
+        True,
     )
 
     cached_values = []
@@ -381,6 +385,7 @@ def test_paged_attention_kv_cache_clear():
         nhead,
         nfeat,
         tvm.nd.empty((), dtype),
+        True,
     )
 
     cached_values = []

Reply via email to