This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
     new 7a0c3f9e05 [Unity][Support] PagedKVCache support growth control (#16112)
7a0c3f9e05 is described below
commit 7a0c3f9e056a771b6854d91be6656b31e871e622
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Nov 12 12:54:19 2023 -0500
[Unity][Support] PagedKVCache support growth control (#16112)
This PR supports controlling, through a constructor parameter, whether
the KV cache is allowed to grow automatically. Previously we always
allowed the KV cache to grow whenever it was full and more capacity
was demanded.
Although automatic growth can be convenient, in practice we often want
the pre-allocated memory to be static and large enough, rather than
resizable, which makes memory management more predictable. Hence, this
PR supports specifying whether growth is allowed, and throws an error
when the cache would grow in disallowed cases.
This PR also adds an auxiliary function to the KV cache to query the
number of available pages.
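
For illustration, here is a minimal usage sketch (not part of this
commit) of the updated creation API, driven through the two registered
globals that appear in the diff below; the model dimensions, dtype, and
config values are illustrative assumptions:

    import tvm

    fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create")
    favailable = tvm.get_global_func(
        "vm.builtin.paged_attention_kv_cache_get_num_available_pages"
    )

    reserved_num_seqs, total_token_capacity, page_size = 8, 4096, 16
    num_layers, num_heads, head_dim = 32, 32, 128  # illustrative model dims
    cache = fcreate(
        tvm.runtime.ShapeTuple([reserved_num_seqs, total_token_capacity, page_size]),
        num_layers,
        num_heads,
        head_dim,
        tvm.nd.empty((), "float16"),  # dtype/device template, as in the tests below
        False,  # allow_growth=False: error out instead of reallocating when full
    )
    # All reserved pages start out available: ceil(4096 / 16) = 256.
    assert favailable(cache) == (total_token_capacity + page_size - 1) // page_size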
---
src/runtime/relax_vm/paged_kv_cache.cc | 46 ++++++++++++++++------
...est_runtime_builtin_paged_attention_kv_cache.py | 5 +++
2 files changed, 39 insertions(+), 12 deletions(-)
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index 4c61af9018..6d2444ea64 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -77,6 +77,8 @@ class PagedAttentionKVCacheObj : public Object {
const int64_t num_heads_;
/*! \brief The number of features each head has. */
const int64_t head_dim_;
+ /*! \brief A boolean denoting if cache automatic growth is allowed. */
+ const bool allow_growth_;
/*! \brief We fix int32 to be the index dtype of auxiliary data. */
const DLDataType dtype_aux_ = DLDataType(DataType::Int(32, 1));
@@ -145,8 +147,6 @@ class PagedAttentionKVCacheObj : public Object {
* length dimension of K/V data. It is used for efficient computation.
*/
NDArray cur_pos2seqid_device_;
- /*! \brief A temporary buffer for efficient attention computation. */
- NDArray attn_tmp_buffer_;
//-------------------------------------------
// For efficient memory management, the actual sizes of the arrays
@@ -165,8 +165,13 @@ class PagedAttentionKVCacheObj : public Object {
  /*! \brief Constructor. Take the cache configuration and initialize the NDArrays. */
  explicit PagedAttentionKVCacheObj(int64_t page_size, int64_t num_layers, int64_t num_heads,
                                    int64_t head_dim, int64_t reserved_num_seqs,
-                                   int64_t reserved_num_pages, DLDataType dtype, DLDevice device)
-      : page_size_(page_size), num_layers_(num_layers), num_heads_(num_heads), head_dim_(head_dim) {
+                                   int64_t reserved_num_pages, DLDataType dtype, DLDevice device,
+                                   bool allow_growth)
+ : page_size_(page_size),
+ num_layers_(num_layers),
+ num_heads_(num_heads),
+ head_dim_(head_dim),
+ allow_growth_(allow_growth) {
    pages_ = NDArray::Empty({reserved_num_pages, num_layers, 2, num_heads, page_size, head_dim},
                            dtype, device);
    page_table_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device);
@@ -174,7 +179,6 @@ class PagedAttentionKVCacheObj : public Object {
    last_page_offset_device_ = NDArray::Empty({reserved_num_seqs}, dtype_aux_, device);
    cur_append_length_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device);
    cur_pos2seqid_device_ = NDArray::Empty({reserved_num_pages * page_size}, dtype_aux_, device);
-   attn_tmp_buffer_ = NDArray::Empty({8 * 1024 * 1024}, DLDataType(DataType::Float(32)), device);
}
/*!
@@ -272,7 +276,7 @@ class PagedAttentionKVCacheObj : public Object {
f_attention(q_data, pages_, //
page_table_indptr_view_, page_table_values_view_, //
last_page_offset_view_, cur_append_length_indptr_view_, //
-               layer_id, attn_tmp_buffer_, output, apply_rotary, rotary_scale, rotary_theta);
+ layer_id, output, apply_rotary, rotary_scale, rotary_theta);
}
/*!
@@ -486,6 +490,12 @@ class PagedAttentionKVCacheObj : public Object {
dirty_aux_data_device_ = false;
}
+ /*! \brief Return the number of remaining pages. */
+ int GetNumAvailablePages() {
+ ICHECK_EQ(num_pages_allocated_, free_page_ids_.size() + num_pages_in_use_);
+ return pages_->shape[0] - num_pages_in_use_;
+ }
+
/*! \brief Reset the KV cache. */
void Clear() {
num_total_seqs_ = 0;
@@ -530,6 +540,9 @@ class PagedAttentionKVCacheObj : public Object {
if (num_pages_allocated_ < reserved_num_pages) {
return num_pages_allocated_++;
}
+    CHECK(allow_growth_)
+        << "The page KV cache is full and growth is not allowed. Please set a larger "
+           "total token capacity when initialization.";
ICHECK_EQ(num_pages_allocated_, reserved_num_pages);
// Grow the `pages` array by doubling its size.
@@ -563,12 +576,18 @@ class PagedAttentionKVCacheObj : public Object {
DLDevice device = page_table_indptr_device_->device;
if (reserved_nseq != page_table_indptr_device_->shape[0] - 1) {
+      CHECK(allow_growth_)
+          << "The page KV cache is full and growth is not allowed. Please set a larger "
+             "sequence capacity when initialization.";
      page_table_indptr_device_ = NDArray::Empty({reserved_nseq + 1}, dtype_aux_, device);
      last_page_offset_device_ = NDArray::Empty({reserved_nseq}, dtype_aux_, device);
      cur_append_length_indptr_device_ = NDArray::Empty({reserved_nseq + 1}, dtype_aux_, device);
}
if (pages_->shape[0] > page_table_values_device_->shape[0]) {
+      CHECK(allow_growth_)
+          << "The page KV cache is full and growth is not allowed. Please set a larger "
+             "total token capacity when initialization.";
      page_table_values_device_ = NDArray::Empty({pages_->shape[0]}, dtype_aux_, device);
}
}
@@ -576,13 +595,13 @@ class PagedAttentionKVCacheObj : public Object {
class PagedAttentionKVCache : public ObjectRef {
public:
-  static PagedAttentionKVCache Create(int64_t reserved_num_seqs, int64_t total_sequence_length,
+  static PagedAttentionKVCache Create(int64_t reserved_num_seqs, int64_t total_token_capacity,
                                       int64_t page_size, int64_t num_layers, int64_t num_heads,
-                                      int64_t head_dim, NDArray init) {
-    int64_t reserved_num_pages = (total_sequence_length + page_size - 1) / page_size;
+                                      int64_t head_dim, NDArray init, bool allow_growth) {
+    int64_t reserved_num_pages = (total_token_capacity + page_size - 1) / page_size;
     auto n = make_object<PagedAttentionKVCacheObj>(page_size, num_layers, num_heads, head_dim,
                                                    reserved_num_seqs, reserved_num_pages,
-                                                   init->dtype, init->device);
+                                                   init->dtype, init->device, allow_growth);
return PagedAttentionKVCache(n);
}
@@ -597,10 +616,10 @@ TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj);
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
    .set_body_typed([](ShapeTuple cache_config, int64_t num_layers_, int64_t num_heads_,
-                      int64_t head_dim_, NDArray init) {
+                      int64_t head_dim_, NDArray init, bool allow_growth) {
CHECK_EQ(cache_config.size(), 3);
      return PagedAttentionKVCache::Create(cache_config[0], cache_config[1], cache_config[2],
-                                          num_layers_, num_heads_, head_dim_, init);
+                                          num_layers_, num_heads_, head_dim_, init, allow_growth);
});
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_add_sequence")
@@ -640,6 +659,9 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_sync_aux_array_to_devic
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_clear")
.set_body_typed([](PagedAttentionKVCache cache) { cache->Clear(); });
+TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_get_num_available_pages")
+    .set_body_typed([](PagedAttentionKVCache cache) { return cache->GetNumAvailablePages(); });
+
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_attention")
.set_body([](TVMArgs args, TVMRetValue* rv) {
CHECK(args.size() == 5 || args.size() == 8);
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache.py
index a57a905550..685ee262f1 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache.py
@@ -160,6 +160,7 @@ def test_paged_attention_kv_cache_append_prefill():
nhead,
nfeat,
tvm.nd.empty((), dtype),
+ True,
)
    operation_seq = [[(0, 6)], [(1, 8)], [(2, 11)], [(3, 16)], [(4, 19), (5, 20)]]
@@ -207,6 +208,7 @@ def test_paged_attention_kv_cache_append_decode():
nhead,
nfeat,
tvm.nd.empty((), dtype),
+ True,
)
cached_values = []
@@ -260,6 +262,7 @@ def test_paged_attention_kv_cache_remove():
nhead,
nfeat,
tvm.nd.empty((), dtype),
+ True,
)
cached_values = []
@@ -319,6 +322,7 @@ def test_paged_attention_kv_cache_popn():
nhead,
nfeat,
tvm.nd.empty((), dtype),
+ True,
)
cached_values = []
@@ -381,6 +385,7 @@ def test_paged_attention_kv_cache_clear():
nhead,
nfeat,
tvm.nd.empty((), dtype),
+ True,
)
cached_values = []
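
As a closing illustration (not part of this commit), a test could assert
the new failure mode when growth is disabled. Here fill_beyond_capacity
is a hypothetical helper that appends more tokens than the reserved
capacity; only the error text is taken from the diff above:

    import pytest
    import tvm

    # With allow_growth=False, exhausting the reserved pages should raise
    # rather than silently doubling the `pages` array.
    with pytest.raises(tvm.error.TVMError, match="growth is not allowed"):
        fill_beyond_capacity(cache)  # hypothetical helper, not in the test file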