This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 89b91e2b11 [KVCache] Partial layers support (#17192)
89b91e2b11 is described below
commit 89b91e2b1195b53bf7e1f6c250bc9a1247367d13
Author: Yaxing Cai <[email protected]>
AuthorDate: Tue Jul 23 21:13:41 2024 -0700
[KVCache] Partial layers support (#17192)
This PR updates the KVCache implementation to support partial layers.
---
include/tvm/runtime/disco/disco_worker.h | 15 ++++
src/runtime/disco/disco_worker.cc | 9 ---
src/runtime/relax_vm/paged_kv_cache.cc | 82 +++++++++++++++-------
..._builtin_paged_attention_kv_cache_flashinfer.py | 2 +-
...runtime_builtin_paged_attention_kv_cache_tir.py | 2 +-
5 files changed, 73 insertions(+), 37 deletions(-)
diff --git a/include/tvm/runtime/disco/disco_worker.h b/include/tvm/runtime/disco/disco_worker.h
index 301b5b8d62..13f94802c8 100644
--- a/include/tvm/runtime/disco/disco_worker.h
+++ b/include/tvm/runtime/disco/disco_worker.h
@@ -93,6 +93,21 @@ class DiscoWorker {
struct Impl;
friend struct DiscoWorker::Impl;
};
+/*!
+ * \brief A threadlocal wrapper of DiscoWorker.
+ */
+struct ThreadLocalDiscoWorker {
+ /*! \brief The Disco worker */
+ DiscoWorker* worker;
+
+ /*!
+ * \brief Get the threadlocal Disco worker.
+ */
+ static ThreadLocalDiscoWorker* Get() {
+ thread_local static ThreadLocalDiscoWorker worker;
+ return &worker;
+ }
+};
} // namespace runtime
} // namespace tvm
diff --git a/src/runtime/disco/disco_worker.cc b/src/runtime/disco/disco_worker.cc
index b281a3aca7..5e6f401054 100644
--- a/src/runtime/disco/disco_worker.cc
+++ b/src/runtime/disco/disco_worker.cc
@@ -28,15 +28,6 @@
namespace tvm {
namespace runtime {
-struct ThreadLocalDiscoWorker {
- DiscoWorker* worker;
-
- static ThreadLocalDiscoWorker* Get() {
- thread_local static ThreadLocalDiscoWorker worker;
- return &worker;
- }
-};
-
TVM_DLL DiscoWorker* DiscoWorker::ThreadLocal() {
DiscoWorker* ret = ThreadLocalDiscoWorker::Get()->worker;
CHECK(ret) << "ValueError: The current thread is not a DiscoWorker thread";
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index ec1cc3593a..2fb8a72f42 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -21,6 +21,7 @@
* \brief Runtime paged KV cache object for language models.
*/
#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/disco/disco_worker.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/memory/memory_manager.h>
#include <tvm/runtime/ndarray.h>
@@ -825,6 +826,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj
{
const int64_t page_size_;
/*! \brief The number of layers in the model. */
const int64_t num_layers_;
+ /*! \brief The beginning layer id offset. */
+ const int64_t layer_id_begin_offset_;
/*! \brief The number of query/output heads in the model. */
const int64_t num_qo_heads_;
/*! \brief The number of key/value heads in the model. */
@@ -981,14 +984,14 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
public:
/*! \brief Constructor. Take the cache configuration and initialize the
NDArrays. */
explicit PagedAttentionKVCacheObj(
- int64_t page_size, //
- int64_t num_layers, int64_t num_qo_heads, int64_t num_kv_heads, int64_t
head_dim,
- int64_t reserved_num_seqs, int64_t num_total_pages, int64_t
prefill_chunk_size,
- bool support_sliding_window, RoPEMode rope_mode, double rotary_scale,
double rotary_theta,
- DLDataType dtype, Device device, PackedFunc f_transpose_append,
PackedFunc f_compact_copy,
- PackedFunc f_attention_prefill, PackedFunc f_attention_decode,
- PackedFunc f_attention_prefill_sliding_window, PackedFunc
f_attention_decode_sliding_window,
- PackedFunc f_attention_prefill_ragged, PackedFunc
f_attention_prefill_with_tree_mask,
+ int64_t page_size, int64_t num_layers, int64_t layer_id_begin_offset, //
+ int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim, int64_t
reserved_num_seqs,
+ int64_t num_total_pages, int64_t prefill_chunk_size, bool
support_sliding_window,
+ RoPEMode rope_mode, double rotary_scale, double rotary_theta, DLDataType
dtype, Device device,
+ PackedFunc f_transpose_append, PackedFunc f_compact_copy, PackedFunc
f_attention_prefill,
+ PackedFunc f_attention_decode, PackedFunc
f_attention_prefill_sliding_window,
+ PackedFunc f_attention_decode_sliding_window, PackedFunc
f_attention_prefill_ragged,
+ PackedFunc f_attention_prefill_with_tree_mask,
Optional<PackedFunc> f_attention_prefill_ragged_begin_forward,
Optional<PackedFunc> f_attention_prefill_ragged_end_forward,
Optional<PackedFunc> f_attention_prefill_begin_forward,
@@ -998,6 +1001,7 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
PackedFunc f_split_rotary, PackedFunc f_copy_single_page,
Optional<PackedFunc> f_debug_get_kv)
: page_size_(page_size),
num_layers_(num_layers),
+ layer_id_begin_offset_(layer_id_begin_offset),
num_qo_heads_(num_qo_heads),
num_kv_heads_(num_kv_heads),
head_dim_(head_dim),
@@ -1672,7 +1676,10 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data,
Optional<NDArray> mask,
NDArray o_data, double attn_score_scaling_factor)
final {
// Part 1. Shape and dtype check.
- NDArray pages = pages_[layer_id];
+ int64_t local_layer_id = layer_id - layer_id_begin_offset_;
+ CHECK_GE(local_layer_id, 0);
+ CHECK_LT(local_layer_id, num_layers_);
+ NDArray pages = pages_[local_layer_id];
CHECK(qkv_data.DataType() == pages.DataType());
CHECK(o_data.DataType() == pages.DataType());
@@ -1713,13 +1720,13 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
// Part 3. Append k/v data to kv-cache if flag "append_before_attn" is set.
if (append_before_attn_) {
- f_transpose_append_(pages_[layer_id], k_data, v_data,
append_position_map_view_);
+ f_transpose_append_(pages_[local_layer_id], k_data, v_data,
append_position_map_view_);
}
// Part 4: perform attention
AttentionInternal(layer_id, q_data, k_data, v_data, o_data,
attn_score_scaling_factor);
// Part 5. Append k/v data to kv-cache if flag "append_before_attn" is not
set.
if (!append_before_attn_) {
- f_transpose_append_(pages_[layer_id], k_data, v_data,
append_position_map_view_);
+ f_transpose_append_(pages_[local_layer_id], k_data, v_data,
append_position_map_view_);
}
}
@@ -2238,6 +2245,9 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
*/
void AttentionInternal(int64_t layer_id, NDArray q_data, NDArray k_data,
NDArray v_data,
NDArray output, double attn_score_scaling_factor) {
+ int64_t local_layer_id = layer_id - layer_id_begin_offset_;
+ CHECK_GE(local_layer_id, 0);
+ CHECK_LT(local_layer_id, num_layers_);
PackedFunc f_prefill =
!support_sliding_window_ ? f_attention_prefill_ :
f_attention_prefill_sliding_window_;
PackedFunc f_decode =
@@ -2245,7 +2255,7 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
CHECK_GE(num_depths_, 1) << "The number of effective depths must be
greater or equal to 1.";
if (append_before_attn_) {
f_decode(
- /*depth=*/0, q_data, pages_[layer_id],
page_indptr_on_depths_view_[0],
+ /*depth=*/0, q_data, pages_[local_layer_id],
page_indptr_on_depths_view_[0],
page_indices_on_depths_view_[0], length_info_on_depths_view_[0],
k_rope_pos_offset_view_[0], q_rope_position_map_view_, output,
merged_attn_scores_view_,
/*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_,
rotary_theta_,
@@ -2280,7 +2290,7 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
}
if (use_decode_kernel_[d]) {
// Use decode kernel for depth d
- f_decode(/*depth=*/d, q_data, pages_[layer_id],
page_indptr_on_depths_view_[d],
+ f_decode(/*depth=*/d, q_data, pages_[local_layer_id],
page_indptr_on_depths_view_[d],
page_indices_on_depths_view_[d],
length_info_on_depths_view_[d],
k_rope_pos_offset_view_[d], q_rope_position_map_view_,
temp_attn_output_view_,
temp_attn_scores_view_,
@@ -2289,7 +2299,7 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
} else {
// Use prefill kernel for depth d
f_prefill(
- /*depth=*/d, q_data, qo_indptr_on_depths_view_[d],
pages_[layer_id],
+ /*depth=*/d, q_data, qo_indptr_on_depths_view_[d],
pages_[local_layer_id],
page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d],
length_info_on_depths_view_[d], k_rope_pos_offset_view_[d],
q_rope_position_map_view_,
temp_attn_output_view_, temp_attn_scores_view_,
@@ -2436,7 +2446,17 @@
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
CHECK(args.size() == 25 || args.size() == 26 || args.size() == 27)
<< "Invalid number of KV cache constructor args.";
ShapeTuple cache_config = args[0];
- int64_t num_layers = args[1];
+ ShapeTuple layer_indptr_tuple = args[1];
+ int num_groups = 1;
+ int group_id = 0;
+ if (DiscoWorker* disco_worker = ThreadLocalDiscoWorker::Get()->worker) {
+ // In the Disco worker thread
+ num_groups = disco_worker->num_groups;
+ group_id = disco_worker->worker_id / (disco_worker->num_workers /
num_groups);
+ }
+ CHECK_EQ(layer_indptr_tuple.size(), num_groups + 1);
+ int64_t num_layers = layer_indptr_tuple[group_id + 1] -
layer_indptr_tuple[group_id];
+ int64_t layer_id_begin_offset = layer_indptr_tuple[group_id];
int64_t num_qo_heads = args[2];
int64_t num_kv_heads = args[3];
int64_t head_dim = args[4];
@@ -2482,11 +2502,11 @@
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
num_total_pages += reserved_num_seqs * 2;
}
ObjectPtr<PagedAttentionKVCacheObj> n =
make_object<PagedAttentionKVCacheObj>(
- page_size, num_layers, num_qo_heads, num_kv_heads, head_dim,
reserved_num_seqs,
- num_total_pages, prefill_chunk_size, support_sliding_window,
RoPEMode(rope_mode),
- rotary_scale, rotary_theta, init->dtype, init->device,
std::move(f_transpose_append),
- std::move(f_compact_copy), std::move(f_attention_prefill),
std::move(f_attention_decode),
- std::move(f_attention_prefill_sliding_window),
+ page_size, num_layers, layer_id_begin_offset, num_qo_heads,
num_kv_heads, head_dim,
+ reserved_num_seqs, num_total_pages, prefill_chunk_size,
support_sliding_window,
+ RoPEMode(rope_mode), rotary_scale, rotary_theta, init->dtype,
init->device,
+ std::move(f_transpose_append), std::move(f_compact_copy),
std::move(f_attention_prefill),
+ std::move(f_attention_decode),
std::move(f_attention_prefill_sliding_window),
std::move(f_attention_decode_sliding_window),
std::move(f_attention_prefill_ragged),
std::move(f_attention_prefill_with_tree_mask),
std::move(f_attention_prefill_ragged_begin_forward),
@@ -2503,7 +2523,17 @@
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
CHECK(args.size() == 19 || args.size() == 20 || args.size() == 21)
<< "Invalid number of KV cache constructor args.";
ShapeTuple cache_config = args[0];
- int64_t num_layers = args[1];
+ ShapeTuple layer_indptr_tuple = args[1];
+ int num_groups = 1;
+ int group_id = 0;
+ if (DiscoWorker* disco_worker = ThreadLocalDiscoWorker::Get()->worker) {
+ // In the Disco worker thread
+ num_groups = disco_worker->num_groups;
+ group_id = disco_worker->worker_id / (disco_worker->num_workers /
num_groups);
+ }
+ CHECK_EQ(layer_indptr_tuple.size(), num_groups + 1);
+ int64_t num_layers = layer_indptr_tuple[group_id + 1] -
layer_indptr_tuple[group_id];
+ int64_t layer_id_begin_offset = layer_indptr_tuple[group_id];
int64_t num_qo_heads = args[2];
int64_t num_kv_heads = args[3];
int64_t head_dim = args[4];
@@ -2543,11 +2573,11 @@
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
num_total_pages += reserved_num_seqs * 2;
}
ObjectPtr<PagedAttentionKVCacheObj> n =
make_object<PagedAttentionKVCacheObj>(
- page_size, num_layers, num_qo_heads, num_kv_heads, head_dim,
reserved_num_seqs,
- num_total_pages, prefill_chunk_size, support_sliding_window,
RoPEMode(rope_mode),
- rotary_scale, rotary_theta, init->dtype, init->device,
std::move(f_transpose_append),
- std::move(f_compact_copy), std::move(f_attention_prefill),
std::move(f_attention_decode),
- std::move(f_attention_prefill_sliding_window),
+ page_size, num_layers, layer_id_begin_offset, num_qo_heads,
num_kv_heads, head_dim,
+ reserved_num_seqs, num_total_pages, prefill_chunk_size,
support_sliding_window,
+ RoPEMode(rope_mode), rotary_scale, rotary_theta, init->dtype,
init->device,
+ std::move(f_transpose_append), std::move(f_compact_copy),
std::move(f_attention_prefill),
+ std::move(f_attention_decode),
std::move(f_attention_prefill_sliding_window),
std::move(f_attention_decode_sliding_window),
std::move(f_attention_prefill_ragged),
std::move(f_attention_prefill_with_tree_mask), //
NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, //
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
index 048cf49806..bade04a7d7 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
@@ -354,7 +354,7 @@ def create_kv_cache(rope_mode):
support_sliding_window,
]
),
- num_layers,
+ tvm.runtime.ShapeTuple([0, num_layers]),
num_qo_heads,
num_kv_heads,
head_dim,
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index 34680160c8..9192bb901f 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -153,7 +153,7 @@ def create_kv_cache(head_dim, dtype, rope_mode,
support_sliding_window):
int(support_sliding_window),
]
),
- num_layers,
+ tvm.runtime.ShapeTuple([0, num_layers]),
num_qo_heads,
num_kv_heads,
head_dim,