This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9862c84b9f [KVCache] Reducing CacheAuxDataManager copy size (#16831)
9862c84b9f is described below
commit 9862c84b9f624d842d2a8f79d5a5fa240734afa9
Author: Ruihang Lai <[email protected]>
AuthorDate: Wed Apr 3 09:31:46 2024 -0400
[KVCache] Reducing CacheAuxDataManager copy size (#16831)
The cached KV cache auxiliary data manager turns out to introduce
significant extra copy size due to improper handling of array offsets.
Specifically, prior to this PR, the manager always aligned the start
offset of each array to the largest size the array could possibly take.
As a result, each copy transferred quite a lot of unnecessary elements.
This PR reduces the copy size to the minimum by packing the arrays at
properly aligned offsets, which significantly shrinks each copy.
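To make the new scheme concrete, here is a minimal standalone sketch of
the packing logic (a sketch only: the class and method names below are
illustrative rather than taken from the patch, and it assumes int32
elements with 16-byte alignment):

    #include <cstdint>
    #include <cstring>
    #include <vector>

    // Each array is appended at the current running offset; the offset
    // then advances by the array's size rounded up to the alignment.
    class PackedCopySketch {
     public:
      explicit PackedCopySketch(int64_t capacity) : host_buffer_(capacity) {}

      // Rewind the running offset before a new round of copies.
      void Reset() { copy_offset_ = 0; }

      // Append one array and return its start offset in the merged buffer.
      int64_t Append(const std::vector<int32_t>& data) {
        int64_t start = copy_offset_;
        std::memcpy(host_buffer_.data() + start, data.data(),
                    data.size() * sizeof(int32_t));
        copy_offset_ += AlignUp(static_cast<int64_t>(data.size()));
        return start;
      }

      // Only this many elements need the host-to-device copy, rather
      // than the whole preallocated buffer.
      int64_t CopySize() const { return copy_offset_; }

     private:
      static constexpr int64_t kAlignElems = 16 / sizeof(int32_t);  // 4
      static int64_t AlignUp(int64_t n) {
        return (n + kAlignElems - 1) / kAlignElems * kAlignElems;
      }

      int64_t copy_offset_ = 0;
      std::vector<int32_t> host_buffer_;
    };

For example, appending arrays of 5, 3, and 10 elements advances the
offset to 8, 12, and 24, so the commit transfers 24 elements instead of
a worst-case-sized buffer.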
---
src/runtime/relax_vm/paged_kv_cache.cc | 148 ++++++++++++++++-----------------
1 file changed, 73 insertions(+), 75 deletions(-)
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index 1e674d0ec6..e16d79885e 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -216,6 +216,8 @@ class PagedKVCacheAuxDataManager {
   }
   virtual ~PagedKVCacheAuxDataManager() = default;
+  /*! \brief Reset the status of the copy manager. */
+  virtual void ResetCopy() = 0;
   /*! \brief Copy the indptr array of append lengths after coalescing. (see GetChunkedBlockIds) */
   virtual NDArray CopyQOIndptrOnDepthAsync(std::vector<int32_t>* data, int depth) = 0;
   /*! \brief Copy the indptr array of page table. */
@@ -295,6 +297,8 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
     append_position_map_device_ = NDArray::Empty({prefill_chunk_size}, dtype_aux_, device);
   }
+  // The reset of the plain auxiliary data manager is a no-op.
+  void ResetCopy() final {}
   NDArray CopyQOIndptrOnDepthAsync(std::vector<int32_t>* data, int depth) final {
     NDArray view = qo_indptr_on_depths_device_[depth].CreateView(
         {static_cast<int64_t>(data->size())}, dtype_aux_);
@@ -424,69 +428,69 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
       : PagedKVCacheAuxDataManager(dtype_aux, device, copy_stream),
         elem_byte_size_((dtype_aux.bits * dtype_aux.lanes + 7) / 8),
         offset_alignment_(cuda_byte_alignment_ / elem_byte_size_) {
-    // - Calculate all the starting offsets of the auxiliary arrays in
+    // - Calculate the cache size of all the auxiliary arrays in
     // local cache and the large on-device array.
-    int64_t total_elems =
-        InitializeArrayElemOffset(reserved_num_seqs, num_total_pages, prefill_chunk_size);
-    copy_shape_ = {total_elems};
+    int64_t cache_size = CalculateCacheSize(reserved_num_seqs, num_total_pages, prefill_chunk_size);
     // - Initialize the host auxiliary data buffer.
-    merged_aux_data_host_.resize(total_elems);
+    merged_aux_data_host_.resize(cache_size);
     // - Initialize the device auxiliary data buffer.
     memory::Allocator* allocator =
         memory::MemoryManager::GetOrCreateAllocator(device, memory::AllocatorType::kNaive);
     ICHECK_NOTNULL(allocator);
     merged_aux_data_device_ =
-        memory::Storage(allocator->Alloc(device, {total_elems}, dtype_aux), allocator);
+        memory::Storage(allocator->Alloc(device, {cache_size}, dtype_aux), allocator);
   }
+  void ResetCopy() final { copy_offset_ = 0; }
   NDArray CopyQOIndptrOnDepthAsync(std::vector<int32_t>* data, int depth) final {
-    return CopyVecToCacheAtOffset(data, depth_offsets_[depth] + qo_indptr_in_depth_offset_);
+    return CopyVecToCache(data);
   }
   NDArray CopyPageIndptrOnDepthAsync(std::vector<int32_t>* data, int depth) final {
-    return CopyVecToCacheAtOffset(data, depth_offsets_[depth] + page_indptr_in_depth_offset_);
+    return CopyVecToCache(data);
   }
   NDArray CopyPageIndicesOnDepthAsync(std::vector<int32_t>* data, int depth) final {
-    return CopyVecToCacheAtOffset(data, depth_offsets_[depth] + page_indices_in_depth_offset_);
+    return CopyVecToCache(data);
   }
   NDArray CopyLastPageLenOnDepthAsync(std::vector<int32_t>* data, int depth) final {
-    return CopyVecToCacheAtOffset(data, depth_offsets_[depth] + length_info_in_depth_offset_);
+    return CopyVecToCache(data);
   }
   NDArray CopyKRoPEPosOffsetOnDepthAsync(std::vector<int32_t>* data, int depth) final {
-    return CopyVecToCacheAtOffset(data, depth_offsets_[depth] + k_rope_pos_offset_in_depth_offset_);
+    return CopyVecToCache(data);
   }
   NDArray CopyCurAppendLengthIndptrAsync(std::vector<int32_t>* data) final {
-    return CopyVecToCacheAtOffset(data, cur_append_length_indptr_offset_);
+    return CopyVecToCache(data);
   }
   NDArray CopyKRaggedRoPEPosOffsetAsync(std::vector<int32_t>* data) final {
-    return CopyVecToCacheAtOffset(data, k_ragged_rope_pos_offset_offset_);
-  }
-  NDArray CopyQRoPEPosMapAsync(std::vector<int32_t>* data) final {
-    return CopyVecToCacheAtOffset(data, q_rope_position_map_offset_);
+    return CopyVecToCache(data);
   }
+  NDArray CopyQRoPEPosMapAsync(std::vector<int32_t>* data) final { return CopyVecToCache(data); }
   NDArray CopyAppendPositionMapAsync(std::vector<int32_t>* data) final {
-    return CopyVecToCacheAtOffset(data, append_position_map_offset_);
+    return CopyVecToCache(data);
   }
   NDArray CopyLengthInfoOnDepthAsync(std::vector<int32_t>* last_page_len,
                                      std::vector<int32_t>* sliding_window_offset,
                                      std::vector<int32_t>* sink_size, int depth) final {
-    int64_t offset = depth_offsets_[depth] + length_info_in_depth_offset_;
     int64_t n_elem = last_page_len->size();
-    std::memcpy(merged_aux_data_host_.data() + offset, last_page_len->data(),
+    std::memcpy(merged_aux_data_host_.data() + copy_offset_, last_page_len->data(),
                 n_elem * elem_byte_size_);
-    std::memcpy(merged_aux_data_host_.data() + offset + n_elem, sliding_window_offset->data(),
+    std::memcpy(merged_aux_data_host_.data() + copy_offset_ + n_elem, sliding_window_offset->data(),
                 n_elem * elem_byte_size_);
-    std::memcpy(merged_aux_data_host_.data() + offset + 2 * n_elem, sink_size->data(),
+    std::memcpy(merged_aux_data_host_.data() + copy_offset_ + 2 * n_elem, sink_size->data(),
                 n_elem * elem_byte_size_);
-    return merged_aux_data_device_->AllocNDArray(offset * elem_byte_size_, {3, n_elem}, dtype_aux_);
+    NDArray view = merged_aux_data_device_->AllocNDArray(copy_offset_ * elem_byte_size_,
+                                                         {3, n_elem}, dtype_aux_);
+    copy_offset_ += CeilDivElemAlignment(3 * n_elem);
+    return view;
   }
   void CommitCopy() final {
+    std::vector<int64_t> copy_shape{copy_offset_};
     DLTensor copy_dst;
     copy_dst.data = merged_aux_data_device_->buffer.data;
     copy_dst.device = device_;
     copy_dst.ndim = 1;
     copy_dst.dtype = dtype_aux_;
-    copy_dst.shape = copy_shape_.data();
+    copy_dst.shape = copy_shape.data();
     copy_dst.strides = nullptr;
     copy_dst.byte_offset = 0;
@@ -501,70 +505,61 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
   * \brief Calculate the start element offsets of the auxiliary arrays in the local cache.
   * \return Return the local cache size (total number of elements in the local cache).
   */
-  int64_t InitializeArrayElemOffset(int64_t reserved_num_seqs, int64_t num_total_pages,
-                                    int64_t prefill_chunk_size) {
-    // For safety, we align the start offset of the arrays to `offset_alignment`.
-    auto f_ceil_div_elem_alignment = [this](int n) {
-      return (n + offset_alignment_ - 1) / offset_alignment_ * offset_alignment_;
-    };
-
-    // - Element offsets of the arrays that every depth has.
-    qo_indptr_in_depth_offset_ = 0;
-    page_indptr_in_depth_offset_ =
-        qo_indptr_in_depth_offset_ + f_ceil_div_elem_alignment(reserved_num_seqs + 1);
-    page_indices_in_depth_offset_ =
-        page_indptr_in_depth_offset_ + f_ceil_div_elem_alignment(reserved_num_seqs + 1);
-    length_info_in_depth_offset_ =
-        page_indices_in_depth_offset_ + f_ceil_div_elem_alignment(num_total_pages);
-    k_rope_pos_offset_in_depth_offset_ =
-        length_info_in_depth_offset_ + f_ceil_div_elem_alignment(3 * reserved_num_seqs);
-
-    // - Element offsets of each depth.
-    int64_t elem_per_depth =
-        k_rope_pos_offset_in_depth_offset_ + f_ceil_div_elem_alignment(reserved_num_seqs);
-    for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) {
-      depth_offsets_.push_back(d * elem_per_depth);
-    }
-
-    // - Element offsets of other arrays.
-    cur_append_length_indptr_offset_ = kPagedKVCacheMaxBlockDepth * elem_per_depth;
-    k_ragged_rope_pos_offset_offset_ =
-        cur_append_length_indptr_offset_ + f_ceil_div_elem_alignment(reserved_num_seqs + 1);
-    q_rope_position_map_offset_ =
-        k_ragged_rope_pos_offset_offset_ + f_ceil_div_elem_alignment(reserved_num_seqs);
-    append_position_map_offset_ =
-        q_rope_position_map_offset_ + f_ceil_div_elem_alignment(prefill_chunk_size);
-
-    // - The total number of elements after alignment.
-    return append_position_map_offset_ + f_ceil_div_elem_alignment(prefill_chunk_size);
+  int64_t CalculateCacheSize(int64_t reserved_num_seqs, int64_t num_total_pages,
+                             int64_t prefill_chunk_size) {
+    int64_t cache_size = 0;
+    // - Array size of the arrays that every depth has.
+    // Corresponding to the following arrays respectively
+    //  - qo_indptr_in_depth
+    //  - page_indptr_in_depth
+    //  - page_indices_in_depth
+    //  - length_info_in_depth
+    //  - k_rope_pos_offset_in_depth
+    cache_size += CeilDivElemAlignment(reserved_num_seqs + 1);
+    cache_size += CeilDivElemAlignment(reserved_num_seqs + 1);
+    cache_size += CeilDivElemAlignment(num_total_pages);
+    cache_size += CeilDivElemAlignment(3 * reserved_num_seqs);
+    cache_size += CeilDivElemAlignment(reserved_num_seqs);
+    cache_size *= kPagedKVCacheMaxBlockDepth;
+
+    // - Array size of other arrays.
+    // Corresponding to the following arrays respectively
+    //  - cur_append_length_indptr
+    //  - k_ragged_rope_pos_offset
+    //  - q_rope_position_map
+    //  - append_position_map
+    cache_size += CeilDivElemAlignment(reserved_num_seqs + 1);
+    cache_size += CeilDivElemAlignment(reserved_num_seqs);
+    cache_size += CeilDivElemAlignment(prefill_chunk_size);
+    cache_size += CeilDivElemAlignment(prefill_chunk_size);
+
+    return cache_size;
   }
   /*!
   * \brief Copy the input data to the cache at the given offset.
   * And return the NDArray view of the cache starting at the offset.
   */
-  NDArray CopyVecToCacheAtOffset(std::vector<int32_t>* data, int64_t offset) {
+  NDArray CopyVecToCache(std::vector<int32_t>* data) {
     int64_t n_elem = data->size();
-    std::memcpy(merged_aux_data_host_.data() + offset, data->data(), n_elem * elem_byte_size_);
-    return merged_aux_data_device_->AllocNDArray(offset * elem_byte_size_, {n_elem}, dtype_aux_);
+    std::memcpy(merged_aux_data_host_.data() + copy_offset_, data->data(),
+                n_elem * elem_byte_size_);
+    NDArray view =
+        merged_aux_data_device_->AllocNDArray(copy_offset_ * elem_byte_size_, {n_elem}, dtype_aux_);
+    copy_offset_ += CeilDivElemAlignment(n_elem);
+    return view;
   }
-  const int64_t cuda_byte_alignment_ = 256;
+  /*! \brief For safety, we align the start offset of the arrays to `offset_alignment`. */
+  int64_t CeilDivElemAlignment(int n) {
+    return (n + offset_alignment_ - 1) / offset_alignment_ * offset_alignment_;
+  }
+
+  const int64_t cuda_byte_alignment_ = 16;
   const int64_t elem_byte_size_;
   const int64_t offset_alignment_;
-  int64_t qo_indptr_in_depth_offset_;
-  int64_t page_indptr_in_depth_offset_;
-  int64_t page_indices_in_depth_offset_;
-  int64_t length_info_in_depth_offset_;
-  int64_t k_rope_pos_offset_in_depth_offset_;
-  std::vector<int64_t> depth_offsets_;
-  int64_t cur_append_length_indptr_offset_;
-  int64_t k_ragged_rope_pos_offset_offset_;
-  int64_t q_rope_position_map_offset_;
-  int64_t append_position_map_offset_;
-
-  std::vector<int64_t> copy_shape_;
+  int64_t copy_offset_ = 0;
   std::vector<int32_t> merged_aux_data_host_;
   memory::Storage merged_aux_data_device_;
 };
@@ -1692,6 +1687,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     total_append_length = cur_append_lengths_indptr_host_.back();
     ICHECK_EQ(total_append_length, append_position_map_host_.size());
+    // - Reset the copy.
+    aux_data_manager_->ResetCopy();
+
     // 1. qo_indptr_on_depths
     for (int d = 0; d < num_depths_; ++d) {
       qo_indptr_on_depths_view_[d] =
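
For reference, a hedged sketch of the call order this patch implies;
the driver function below is hypothetical, while ResetCopy, the
Copy*Async methods, and CommitCopy are the manager interface touched by
the patch:

    // Hypothetical driver: all copies must happen between ResetCopy()
    // and CommitCopy(), since each Copy*Async call now appends its
    // array at the shared running offset rather than at a precomputed
    // position.
    void SyncAuxDataSketch(PagedKVCacheAuxDataManager* mgr, int num_depths,
                           std::vector<std::vector<int32_t>>* qo_indptr_on_depths,
                           std::vector<int32_t>* cur_append_lengths_indptr) {
      mgr->ResetCopy();  // rewind the running offset to zero
      for (int d = 0; d < num_depths; ++d) {
        mgr->CopyQOIndptrOnDepthAsync(&(*qo_indptr_on_depths)[d], d);
      }
      mgr->CopyCurAppendLengthIndptrAsync(cur_append_lengths_indptr);
      // ... remaining Copy*Async calls, always in the same order ...
      mgr->CommitCopy();  // one host-to-device copy of the packed prefix
    }

Because array offsets are now implicit in the call order, callers must
issue the Copy*Async calls in a fixed sequence between ResetCopy and
CommitCopy; the new virtual ResetCopy makes that contract explicit.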