This is an automated email from the ASF dual-hosted git repository.
hongyij pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8b4df725b7 [Runtime][KVCache] Initial interface setup for MLA (#17616)
8b4df725b7 is described below
commit 8b4df725b797a05e87d80d165dc7bbb6774aa869
Author: Ruihang Lai <[email protected]>
AuthorDate: Thu Jan 30 20:00:22 2025 -0500
[Runtime][KVCache] Initial interface setup for MLA (#17616)
This PR introduces the initial KV cache interface setup for multi-head
latent attention in DeepSeek models.
Some interface implementations are marked as todo, to be implemented
in the near future.
---
src/runtime/relax_vm/kv_state.h | 63 +++++++
src/runtime/relax_vm/paged_kv_cache.cc | 313 ++++++++++++++++++++++++++++-----
2 files changed, 330 insertions(+), 46 deletions(-)
diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h
index 7df3215d08..77c17d1c55 100644
--- a/src/runtime/relax_vm/kv_state.h
+++ b/src/runtime/relax_vm/kv_state.h
@@ -181,6 +181,69 @@ class AttentionKVCacheObj : public KVStateObj {
virtual void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data,
Optional<NDArray> mask,
NDArray o_data, double
attn_score_scaling_factor) = 0;
+ /*!
+ * \brief Compute attention with Q/K/V data.
+ * \param layer_id The model layer where the attention compute happens.
+ * \param q_data The input Q data, in layout `(total_length, num_qo_heads,
head_dim)`
+ * \param k_data The input K data, in layout `(total_length, num_kv_heads,
head_dim)`
+ * \param v_data The input V data, in layout `(total_length, num_kv_heads,
head_dim)`
+ * \param mask The input mask data, in layout `(total_sqr_length)`.
+ * \param o_data The output O data, in layout `(total_length, num_qo_heads,
head_dim)`.
+ * \param attn_score_scaling_factor The additional attention scaling factor.
+ */
+ virtual void AttentionWithSeparateQKV(int64_t layer_id, NDArray q_data,
NDArray k_data,
+ NDArray v_data, Optional<NDArray>
mask, NDArray o_data,
+ double attn_score_scaling_factor) = 0;
+
+ /*!
+ * \brief Compute multi-head latent attention after applying weight
absorption.
+ * \param layer_id The model layer where the attention compute happens.
+ * \param q_data The input Q data, in layout `(total_length, num_qo_heads,
qk_head_dim)`
+ * \param compressed_kv_data The compressed latent KV data, in layout
+ * `(total_length, num_kv_heads, kv_lora_rank)`
+ * \param k_pe_data The positional embedding part of K data, in layout
+ * `(total_length, num_kv_heads, qk_rope_head_dim)`, where `kv_lora_rank +
qk_rope_head_dim`
+ * equals qk_head_dim
+ * \param o_data The output O data, in layout `(total_length, num_qo_heads,
v_head_dim)`.
+ * \param attn_score_scaling_factor The additional attention scaling factor.
+ */
+ virtual void MLAAbsorbed(int64_t layer_id, NDArray q_data, NDArray
compressed_kv_data,
+ NDArray k_pe_data, NDArray o_data, double
attn_score_scaling_factor) = 0;
+
+ /*!
+ * \brief Compute multi-head latent attention in normal style.
+ * \param layer_id The model layer where the attention compute happens.
+ * \param q_data The input Q data, in layout
+ * `(total_length, num_qo_heads, qk_nope_head_dim + qk_rope_head_dim)`
+ * \param k_data The input K data, in layout
+ * `(total_length, num_qo_heads, qk_nope_head_dim + qk_rope_head_dim)`
+ * \param v_data The input V data, in layout
+ * `(total_length, num_qo_heads, v_head_dim)`
+ * \param compressed_kv_data The compressed latent KV data, in layout
+ * `(total_length, num_kv_heads, kv_lora_rank)`
+ * \param k_pe_data The positional embedding part of K data, in layout
+ * `(total_length, num_kv_heads, qk_rope_head_dim)`, where `kv_lora_rank +
qk_rope_head_dim`
+ * equals qk_head_dim
+ * \param o_data The output O data, in layout `(total_length, num_qo_heads,
v_head_dim)`.
+ * \param attn_score_scaling_factor The additional attention scaling factor.
+ */
+ virtual void MLANormal(int64_t layer_id, NDArray q_data, NDArray k_data,
NDArray v_data,
+ NDArray compressed_kv_data, NDArray k_pe_data,
NDArray o_data,
+ double attn_score_scaling_factor) = 0;
+
+ /*!
+ * \brief Compute linear attention with Q/K/V data.
+ * \param layer_id The model layer where the attention compute happens.
+ * \param q_data The input Q data, in layout `(total_length, num_qo_heads,
head_dim)`.
+ * \param k_data The input K data, in layout `(total_length, num_kv_heads,
head_dim)`.
+ * \param v_data The input V data, in layout `(total_length, num_kv_heads,
head_dim)`.
+ * \param o_data The output O data, in layout `(total_length, num_qo_heads,
head_dim)`.
+ * \param attn_score_scaling_factor The additional attention scaling factor.
+ * \sa AttentionKVCache::Attention
+ */
+ virtual void LinearAttention(int64_t layer_id, NDArray q_data, NDArray
k_data, NDArray v_data,
+ double attn_score_scaling_factor) = 0;
+
/************** Positions **************/
/*!
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc
b/src/runtime/relax_vm/paged_kv_cache.cc
index 81c55bfcb6..8e5dfb4bd8 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -64,6 +64,33 @@ constexpr const int kFloatAttnWorkspaceByte = 768 * 1024 *
1024;
/*! \brief The id of the temporary logical page, which is useful for sliding
window. */
constexpr const int kPagedKVCacheTempPageId = -1;
+/*!
+ * \brief The supported attention kinds in PagedKVCache.
+ * "MHA" means multi-head attention, multi-query attention and grouped query
attention in general.
+ * "MLA" means multi-head latent attention.
+ * "LinearAttn" means linear attention.
+ */
+enum class AttnKind : int {
+ kMHA = 0,
+ kMLA = 1,
+ kLinearAttn = 2,
+};
+
+ShapeTuple GetKVCacheShape(AttnKind attn_kind, int64_t num_total_pages, int
num_sequence,
+ int64_t num_kv_heads, int64_t page_size, int64_t
qk_head_dim,
+ int64_t v_head_dim, int64_t qk_rope_head_dim) {
+ if (attn_kind == AttnKind::kMHA) {
+ // Ignore v_head_dim since multi-head attention requires K/V to have the
same head dim.
+ return {num_total_pages, 2, num_kv_heads, page_size, qk_head_dim};
+ } else if (attn_kind == AttnKind::kMLA) {
+ return {num_total_pages, num_kv_heads, page_size, qk_head_dim +
qk_rope_head_dim};
+ } else if (attn_kind == AttnKind::kLinearAttn) {
+ return {num_sequence, num_kv_heads, qk_head_dim, v_head_dim};
+ }
+ ICHECK(false);
+ throw;
+}
+
/*!
* \brief The block structure in paged KV cache with common prefix support.
* Each block contains a list of pages for cached KV data.
@@ -940,13 +967,25 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
/*! \brief The number of key/value heads in the model. */
const int64_t num_kv_heads_;
/*! \brief The number of features each head has. */
- const int64_t head_dim_;
+ const int64_t qk_head_dim_;
+ /*!
+ * \brief The number of features each head has for V.
+ * For layers that use multi-head attention, this field is overriden by
qk_head_dim.
+ */
+ const int64_t v_head_dim_;
+ /*!
+ * \brief The number of features each head has for RoPE in multi-head latent
attention.
+ * This field is ignored for non-MLA.
+ */
+ const int64_t qk_rope_head_dim_;
/*! \brief The number of total pages allocated in KV cache. */
const int64_t num_total_pages_;
/*! \brief The maximum total sequence length in a prefill. */
const int64_t prefill_chunk_size_;
/*! \brief A boolean flag indicating if the KV cache supports sliding
window. */
const bool support_sliding_window_;
+ /*! \brief The attention kinds for each layer. */
+ const std::vector<AttnKind> attn_kinds_;
/*! \brief The RoPE application mode of KV cache.*/
const RoPEMode rope_mode_;
@@ -967,7 +1006,7 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
* If KV transfer function is specifed, pages_ will be allocated by NVSHMEM
as a whole NDArray.
* pages_ will contain tensor view of each layer.
* Otherwise, pages_ has `num_layers` NDArrays, each of them
- * has layout (num_pages, 2, num_heads, page_size, head_dim).
+ * has layout (num_pages, 2, num_heads, page_size, qk_head_dim).
* Along on the "2" dimension, index 0 stands for K and 1 stands for V.
*/
std::vector<NDArray> pages_;
@@ -1086,6 +1125,7 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
std::vector<NDArray> tree_attn_mn_indptr_view_;
PackedFunc f_transpose_append_;
+ PackedFunc f_transpose_append_mla_;
Optional<PackedFunc> f_transfer_kv_;
Optional<PackedFunc> f_transfer_kv_page_to_page_ = NullOpt;
PackedFunc f_compact_copy_;
@@ -1102,8 +1142,13 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
Optional<PackedFunc> f_attention_prefill_end_forward_;
Optional<PackedFunc> f_attention_decode_begin_forward_;
Optional<PackedFunc> f_attention_decode_end_forward_;
+ PackedFunc f_mla_prefill_;
+ PackedFunc f_mla_decode_;
+ PackedFunc f_mla_prefill_ragged_normal_;
+ PackedFunc f_mla_prefill_ragged_absorbed_;
PackedFunc f_merge_inplace_;
PackedFunc f_split_rotary_;
+ PackedFunc f_separate_rotary_;
PackedFunc f_copy_single_page_;
Optional<PackedFunc> f_debug_get_kv_;
@@ -1120,37 +1165,45 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
/*! \brief Constructor. Take the cache configuration and initialize the
NDArrays. */
explicit PagedAttentionKVCacheObj(
int64_t page_size, int64_t num_layers, int64_t layer_id_begin_offset, //
- int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim, int64_t
reserved_num_seqs,
+ int64_t num_qo_heads, int64_t num_kv_heads, int64_t qk_head_dim, int64_t
v_head_dim,
+ int64_t qk_rope_head_dim, std::vector<AttnKind> attn_kinds, int64_t
reserved_num_seqs,
int64_t num_total_pages, int64_t prefill_chunk_size, bool
support_sliding_window,
RoPEMode rope_mode, double rotary_scale, double rotary_theta,
Optional<NDArray> rope_ext_factors, bool enable_kv_transfer, DLDataType
dtype, Device device,
- PackedFunc f_transpose_append, PackedFunc f_compact_copy, PackedFunc
f_attention_prefill,
- PackedFunc f_attention_decode, PackedFunc
f_attention_prefill_sliding_window,
- PackedFunc f_attention_decode_sliding_window, PackedFunc
f_attention_prefill_ragged,
- PackedFunc f_attention_prefill_with_tree_mask,
+ PackedFunc f_transpose_append, PackedFunc f_transpose_append_mla,
PackedFunc f_compact_copy,
+ PackedFunc f_attention_prefill, PackedFunc f_attention_decode,
+ PackedFunc f_attention_prefill_sliding_window, PackedFunc
f_attention_decode_sliding_window,
+ PackedFunc f_attention_prefill_ragged, PackedFunc
f_attention_prefill_with_tree_mask,
PackedFunc f_attention_prefill_with_tree_mask_paged_kv,
Optional<PackedFunc> f_attention_prefill_ragged_begin_forward,
Optional<PackedFunc> f_attention_prefill_ragged_end_forward,
Optional<PackedFunc> f_attention_prefill_begin_forward,
Optional<PackedFunc> f_attention_prefill_end_forward,
Optional<PackedFunc> f_attention_decode_begin_forward,
- Optional<PackedFunc> f_attention_decode_end_forward, PackedFunc
f_merge_inplace,
- PackedFunc f_split_rotary, PackedFunc f_copy_single_page,
Optional<PackedFunc> f_debug_get_kv)
+ Optional<PackedFunc> f_attention_decode_end_forward, PackedFunc
f_mla_prefill,
+ PackedFunc f_mla_decode, PackedFunc f_mla_prefill_ragged_normal,
+ PackedFunc f_mla_prefill_ragged_absorbed, PackedFunc f_merge_inplace,
+ PackedFunc f_split_rotary, PackedFunc f_separate_rotary, PackedFunc
f_copy_single_page,
+ Optional<PackedFunc> f_debug_get_kv)
: page_size_(page_size),
num_layers_(num_layers),
layer_id_begin_offset_(layer_id_begin_offset),
num_qo_heads_(num_qo_heads),
num_kv_heads_(num_kv_heads),
- head_dim_(head_dim),
+ qk_head_dim_(qk_head_dim),
+ v_head_dim_(v_head_dim),
+ qk_rope_head_dim_(qk_rope_head_dim),
num_total_pages_(num_total_pages),
prefill_chunk_size_(prefill_chunk_size),
support_sliding_window_(support_sliding_window),
+ attn_kinds_(std::move(attn_kinds)),
rope_mode_(support_sliding_window && rope_mode != RoPEMode::kNone ?
RoPEMode::kInline
:
rope_mode),
rotary_scale_(rotary_scale),
rotary_theta_(rotary_theta),
rope_ext_factors_(std::move(rope_ext_factors)),
f_transpose_append_(std::move(f_transpose_append)),
+ f_transpose_append_mla_(std::move(f_transpose_append_mla)),
f_compact_copy_(std::move(f_compact_copy)),
f_attention_prefill_(std::move(f_attention_prefill)),
f_attention_decode_(std::move(f_attention_decode)),
@@ -1167,24 +1220,33 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
f_attention_prefill_end_forward_(std::move(f_attention_prefill_end_forward)),
f_attention_decode_begin_forward_(std::move(f_attention_decode_begin_forward)),
f_attention_decode_end_forward_(std::move(f_attention_decode_end_forward)),
+ f_mla_prefill_(std::move(f_mla_prefill)),
+ f_mla_decode_(std::move(f_mla_decode)),
+ f_mla_prefill_ragged_normal_(std::move(f_mla_prefill_ragged_normal)),
+
f_mla_prefill_ragged_absorbed_(std::move(f_mla_prefill_ragged_absorbed)),
f_merge_inplace_(std::move(f_merge_inplace)),
f_split_rotary_(std::move(f_split_rotary)),
+ f_separate_rotary_(std::move(f_separate_rotary)),
f_copy_single_page_(std::move(f_copy_single_page)),
f_debug_get_kv_(std::move(f_debug_get_kv)),
device_(device) {
pages_.reserve(num_layers);
if (enable_kv_transfer) {
+ // For now, KV transfer only supports MHA.
+ for (AttnKind attn_kind : attn_kinds_) {
+ CHECK(attn_kind == AttnKind::kMHA);
+ }
CHECK(Registry::Get("runtime.disco.nvshmem.init_nvshmem") != nullptr)
<< "NVSHMEM is not enabled. Please make sure NVSHMEM is enabled when
compiling TVM.";
const PackedFunc* f_nvshmem_empty =
runtime::Registry::Get("runtime.disco.nvshmem.empty");
ICHECK_NOTNULL(f_nvshmem_empty);
nvshmem_pages_ = (*f_nvshmem_empty)(
- ShapeTuple({num_layers, num_total_pages, 2, num_kv_heads, page_size,
head_dim}), dtype,
+ ShapeTuple({num_layers, num_total_pages, 2, num_kv_heads, page_size,
qk_head_dim}), dtype,
device);
for (int i = 0; i < num_layers; ++i) {
pages_.push_back(nvshmem_pages_.CreateView(
- {num_total_pages_, 2, num_kv_heads_, page_size_, head_dim_},
nvshmem_pages_->dtype,
- i * num_total_pages_ * 2 * num_kv_heads_ * page_size_ * head_dim_ *
+ {num_total_pages_, 2, num_kv_heads_, page_size_, qk_head_dim_},
nvshmem_pages_->dtype,
+ i * num_total_pages_ * 2 * num_kv_heads_ * page_size_ *
qk_head_dim_ *
nvshmem_pages_.DataType().bytes()));
}
@@ -1197,8 +1259,10 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
f_transfer_kv_page_to_page_ = *f_transfer_kv_page_to_page_ptr;
} else {
for (int i = 0; i < num_layers; ++i) {
- pages_.push_back(
- NDArray::Empty({num_total_pages, 2, num_kv_heads, page_size,
head_dim}, dtype, device));
+ ShapeTuple kv_cache_shape = GetKVCacheShape(
+ attn_kinds_[layer_id_begin_offset_ + i], num_total_pages,
reserved_num_seqs,
+ num_kv_heads, page_size, qk_head_dim, v_head_dim,
qk_rope_head_dim);
+ pages_.push_back(NDArray::Empty(kv_cache_shape, dtype, device));
}
}
@@ -1274,13 +1338,13 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
}
temp_attn_q_device_ =
- NDArray::Empty({prefill_chunk_size_, num_qo_heads, head_dim}, dtype,
device);
+ NDArray::Empty({prefill_chunk_size_, num_qo_heads, qk_head_dim},
dtype, device);
temp_attn_k_device_ =
- NDArray::Empty({prefill_chunk_size_, num_kv_heads, head_dim}, dtype,
device);
+ NDArray::Empty({prefill_chunk_size_, num_kv_heads, qk_head_dim},
dtype, device);
temp_attn_v_device_ =
- NDArray::Empty({prefill_chunk_size_, num_kv_heads, head_dim}, dtype,
device);
+ NDArray::Empty({prefill_chunk_size_, num_kv_heads, v_head_dim}, dtype,
device);
temp_attn_output_device_ =
- NDArray::Empty({prefill_chunk_size_, num_qo_heads, head_dim}, dtype,
device);
+ NDArray::Empty({prefill_chunk_size_, num_qo_heads, qk_head_dim},
dtype, device);
temp_attn_scores_device_ =
NDArray::Empty({prefill_chunk_size_, num_qo_heads},
DataType::Float(32), device);
merged_attn_scores_device_ =
@@ -2019,8 +2083,9 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
CHECK(qkv_data.DataType() == pages.DataType());
CHECK(o_data.DataType() == pages.DataType());
- // qkv_data: (num_total_length, num_qo_heads + 2 * num_kv_heads, head_dim)
- // o_data: (num_total_length, num_qo_heads, head_dim)
+ CHECK(attn_kinds_[layer_id] == AttnKind::kMHA);
+ // qkv_data: (num_total_length, num_qo_heads + 2 * num_kv_heads,
qk_head_dim)
+ // o_data: (num_total_length, num_qo_heads, qk_head_dim)
CHECK_EQ(qkv_data->ndim, 3);
CHECK_EQ(o_data->ndim, 3);
@@ -2033,7 +2098,7 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
}
}
- CHECK_EQ(qkv_data->shape[2], head_dim_);
+ CHECK_EQ(qkv_data->shape[2], qk_head_dim_);
int64_t total_seq_length = 0;
for (int64_t seq_id = 0; seq_id < cur_batch_size_; ++seq_id) {
total_seq_length += cur_append_lengths_[seq_id];
@@ -2044,11 +2109,11 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
// The auxiliary data structure on device must have been synchronized.
ICHECK(!dirty_aux_data_device_);
- NDArray q_data = temp_attn_q_device_.CreateView({total_seq_length,
num_qo_heads_, head_dim_},
+ NDArray q_data = temp_attn_q_device_.CreateView({total_seq_length,
num_qo_heads_, qk_head_dim_},
qkv_data->dtype);
- NDArray k_data = temp_attn_k_device_.CreateView({total_seq_length,
num_kv_heads_, head_dim_},
+ NDArray k_data = temp_attn_k_device_.CreateView({total_seq_length,
num_kv_heads_, qk_head_dim_},
qkv_data->dtype);
- NDArray v_data = temp_attn_v_device_.CreateView({total_seq_length,
num_kv_heads_, head_dim_},
+ NDArray v_data = temp_attn_v_device_.CreateView({total_seq_length,
num_kv_heads_, qk_head_dim_},
qkv_data->dtype);
NDArray qkv_data_view = qkv_data;
@@ -2057,7 +2122,7 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
qkv_data_view = qkv_data.CreateView(
{total_seq_length, qkv_data->shape[1], qkv_data->shape[2]},
qkv_data->dtype);
o_data_view =
- o_data.CreateView({total_seq_length, num_qo_heads_, head_dim_},
qkv_data->dtype);
+ o_data.CreateView({total_seq_length, num_qo_heads_, qk_head_dim_},
qkv_data->dtype);
}
// Part 2. Split fused qkv and apply rotary embedding to q/k data.
if (transfer_kv_) {
@@ -2105,6 +2170,28 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
}
}
+ void AttentionWithSeparateQKV(int64_t layer_id, NDArray q_data, NDArray
k_data, NDArray v_data,
+ Optional<NDArray> mask, NDArray o_data,
+ double attn_score_scaling_factor) final {
+ // Todo(ruihang): implement it
+ }
+
+ void MLAAbsorbed(int64_t layer_id, NDArray q_data, NDArray
compressed_kv_data, NDArray k_pe_data,
+ NDArray o_data, double attn_score_scaling_factor) {
+ // Todo(ruihang): implement it
+ }
+
+ void MLANormal(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray
v_data,
+ NDArray compressed_kv_data, NDArray k_pe_data, NDArray o_data,
+ double attn_score_scaling_factor) {
+ // Todo(ruihang): implement it
+ }
+
+ void LinearAttention(int64_t layer_id, NDArray q_data, NDArray k_data,
NDArray v_data,
+ double attn_score_scaling_factor) {
+ // Todo(ruihang): implement it
+ }
+
void CommitAcceptedTokenTreeNodes(const IntTuple& seq_ids, const IntTuple&
leaf_indices) final {
CHECK_EQ(seq_ids.size(), leaf_indices.size())
<< "The given seq_ids and leaf_indices have different size.";
@@ -2216,9 +2303,10 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
CHECK_LE(end_pos, seq.seq_length) << "DebugGetKV does not accept
out-of-range end_pos";
CHECK_LT(start_pos, end_pos) << "DebugGetKV does not accept \"start_pos >=
end_pos\"";
- // k/v_data: (num_layers, seq_length, num_kv_heads, head_dim)
+ // k/v_data: (num_layers, seq_length, num_kv_heads, qk_head_dim)
static constexpr const char* error_msg =
- "DebugGetKV expects the k_data in layout (num_layers, seq_length,
num_kv_heads, head_dim).";
+ "DebugGetKV expects the k_data in layout (num_layers, seq_length,
num_kv_heads, "
+ "qk_head_dim).";
std::vector<NDArray*> vec_kv_data = {&k_data, &v_data};
for (const NDArray* data_ptr : vec_kv_data) {
CHECK_EQ((*data_ptr)->ndim, 4) << error_msg;
@@ -2228,7 +2316,7 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
<< error_msg << " The sequence length mismatches.";
CHECK_EQ((*data_ptr)->shape[2], num_kv_heads_)
<< error_msg << " The number of heads mismatches.";
- CHECK_EQ((*data_ptr)->shape[3], head_dim_)
+ CHECK_EQ((*data_ptr)->shape[3], qk_head_dim_)
<< error_msg << " The number of head features mismatches.";
}
@@ -2250,6 +2338,7 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
append_position_map.data() + start_pos,
(end_pos - start_pos) * ((dtype_aux_.bits * dtype_aux_.lanes + 7) /
8));
for (int64_t layer_id = 0; layer_id < num_layers_; ++layer_id) {
+ CHECK(attn_kinds_[layer_id] == AttnKind::kMHA) << "Only MHA is supported
for DebugGetKV";
f_debug_get_kv_.value()(pages_[layer_id], position_map_device, k_data,
v_data, layer_id);
}
}
@@ -2649,7 +2738,7 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
temp_float_attn_workspace_, temp_int_attn_workspace_[0],
cur_append_lengths_indptr_host_.as_ndarray(),
cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_,
num_qo_heads_,
- num_kv_heads_, head_dim_, copy_stream_);
+ num_kv_heads_, qk_head_dim_, copy_stream_);
}
}
for (int d = 0; d < num_depths_; ++d) {
@@ -2661,15 +2750,15 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
f_attention_decode_begin_forward_.value()(
d, temp_float_attn_workspace_, temp_int_attn_workspace_[d + 1],
page_indptr_on_depths_host_[d].as_ndarray(),
- last_page_len_on_depths_host_[d].as_ndarray(), num_qo_heads_,
num_kv_heads_, head_dim_,
- page_size_,
+ last_page_len_on_depths_host_[d].as_ndarray(), num_qo_heads_,
num_kv_heads_,
+ qk_head_dim_, page_size_,
/*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_);
} else {
f_attention_prefill_begin_forward_.value()(
/*depth=*/d, temp_float_attn_workspace_,
temp_int_attn_workspace_[d + 1],
qo_indptr_on_depths_host_[d].as_ndarray(),
page_indptr_on_depths_host_[d].as_ndarray(),
static_cast<int>(page_indptr_on_depths_host_[d].size()) - 1,
num_qo_heads_,
- num_kv_heads_, head_dim_, page_size_, copy_stream_);
+ num_kv_heads_, qk_head_dim_, page_size_, copy_stream_);
}
}
}
@@ -2893,7 +2982,7 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
}
// 16. Create view for temporary arrays for attention computation.
temp_attn_output_view_ = temp_attn_output_device_.CreateView(
- {total_append_length, num_qo_heads_, head_dim_},
temp_attn_output_device_->dtype);
+ {total_append_length, num_qo_heads_, qk_head_dim_},
temp_attn_output_device_->dtype);
temp_attn_scores_view_ = temp_attn_scores_device_.CreateView(
{total_append_length, num_qo_heads_}, temp_attn_scores_device_->dtype);
merged_attn_scores_view_ = merged_attn_scores_device_.CreateView(
@@ -2964,6 +3053,9 @@
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
enable_kv_transfer = args[29];
}
+ std::vector<AttnKind> attn_kinds(/*size=*/layer_indptr_tuple[num_groups],
+ /*value=*/AttnKind::kMHA);
+
CHECK_EQ(cache_config.size(), 5);
int64_t reserved_num_seqs = cache_config[0];
int64_t total_token_capacity = cache_config[1];
@@ -2975,13 +3067,18 @@
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
// When sliding window is enabled, each sequence may use two more
pages at most.
num_total_pages += reserved_num_seqs * 2;
}
+ // NOTE: We will remove this legacy construction after finishing the
transition phase.
+ // Some `PackedFunc()` here are placeholders that will be filled.
ObjectPtr<PagedAttentionKVCacheObj> n =
make_object<PagedAttentionKVCacheObj>(
page_size, num_layers, layer_id_begin_offset, num_qo_heads,
num_kv_heads, head_dim,
- reserved_num_seqs, num_total_pages, prefill_chunk_size,
support_sliding_window,
- RoPEMode(rope_mode), rotary_scale, rotary_theta,
std::move(rope_ext_factors), //
- enable_kv_transfer, init->dtype, init->device,
//
- std::move(f_transpose_append), std::move(f_compact_copy),
std::move(f_attention_prefill),
- std::move(f_attention_decode),
std::move(f_attention_prefill_sliding_window),
+ head_dim, /*qk_rope_head_dim=*/0, attn_kinds, reserved_num_seqs,
num_total_pages,
+ prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode),
rotary_scale,
+ rotary_theta,
+ std::move(rope_ext_factors), //
+ enable_kv_transfer, init->dtype, init->device, //
+ std::move(f_transpose_append), PackedFunc(),
std::move(f_compact_copy),
+ std::move(f_attention_prefill), std::move(f_attention_decode),
+ std::move(f_attention_prefill_sliding_window),
std::move(f_attention_decode_sliding_window),
std::move(f_attention_prefill_ragged),
std::move(f_attention_prefill_with_tree_mask),
std::move(f_attention_prefill_with_tree_mask_paged_kv),
@@ -2989,7 +3086,8 @@
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
std::move(f_attention_prefill_ragged_end_forward),
std::move(f_attention_prefill_begin_forward),
std::move(f_attention_prefill_end_forward),
std::move(f_attention_decode_begin_forward),
std::move(f_attention_decode_end_forward),
- std::move(f_merge_inplace), std::move(f_split_rotary),
std::move(f_copy_single_page),
+ PackedFunc(), PackedFunc(), PackedFunc(), PackedFunc(),
std::move(f_merge_inplace),
+ std::move(f_split_rotary), PackedFunc(),
std::move(f_copy_single_page),
std::move(f_debug_get_kv));
*rv = AttentionKVCache(std::move(n));
});
@@ -3040,6 +3138,9 @@
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
enable_kv_transfer = args[23];
}
+ std::vector<AttnKind> attn_kinds(/*size=*/layer_indptr_tuple[num_groups],
+ /*value=*/AttnKind::kMHA);
+
CHECK_EQ(cache_config.size(), 5);
int64_t reserved_num_seqs = cache_config[0];
int64_t total_token_capacity = cache_config[1];
@@ -3051,18 +3152,138 @@
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
// When sliding window is enabled, each sequence may use two more
pages at most.
num_total_pages += reserved_num_seqs * 2;
}
+ // NOTE: We will remove this legacy construction after finishing the
transition phase.
+ // Some `PackedFunc()` here are placeholders that will be filled.
ObjectPtr<PagedAttentionKVCacheObj> n =
make_object<PagedAttentionKVCacheObj>(
page_size, num_layers, layer_id_begin_offset, num_qo_heads,
num_kv_heads, head_dim,
- reserved_num_seqs, num_total_pages, prefill_chunk_size,
support_sliding_window,
- RoPEMode(rope_mode), rotary_scale, rotary_theta,
std::move(rope_ext_factors), //
- enable_kv_transfer, init->dtype, init->device,
//
- std::move(f_transpose_append), std::move(f_compact_copy),
std::move(f_attention_prefill),
- std::move(f_attention_decode),
std::move(f_attention_prefill_sliding_window),
+ head_dim, /*qk_rope_head_dim=*/0, attn_kinds, reserved_num_seqs,
num_total_pages,
+ prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode),
rotary_scale,
+ rotary_theta,
+ std::move(rope_ext_factors), //
+ enable_kv_transfer, init->dtype, init->device, //
+ std::move(f_transpose_append), PackedFunc(),
std::move(f_compact_copy),
+ std::move(f_attention_prefill), std::move(f_attention_decode),
+ std::move(f_attention_prefill_sliding_window),
std::move(f_attention_decode_sliding_window),
std::move(f_attention_prefill_ragged),
std::move(f_attention_prefill_with_tree_mask), //
std::move(f_attention_prefill_with_tree_mask_paged_kv), //
NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, //
- std::move(f_merge_inplace), std::move(f_split_rotary),
std::move(f_copy_single_page),
+ PackedFunc(), PackedFunc(), PackedFunc(), PackedFunc(),
std::move(f_merge_inplace),
+ std::move(f_split_rotary), PackedFunc(),
std::move(f_copy_single_page),
+ std::move(f_debug_get_kv));
+ *rv = AttentionKVCache(std::move(n));
+ });
+
+TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced_mla")
+ .set_body([](TVMArgs args, TVMRetValue* rv) {
+ CHECK(args.size() == 39) << "Invalid number of KV cache constructor
args.";
+ ShapeTuple cache_config = args[0];
+ ShapeTuple layer_indptr_tuple = args[1];
+ int num_groups = 1;
+ int group_id = 0;
+ if (DiscoWorker* disco_worker = ThreadLocalDiscoWorker::Get()->worker) {
+ // In the Disco worker thread
+ num_groups = disco_worker->num_groups;
+ group_id = disco_worker->worker_id / (disco_worker->num_workers /
num_groups);
+ }
+ CHECK_EQ(layer_indptr_tuple.size(), num_groups + 1);
+ int64_t num_layers = layer_indptr_tuple[group_id + 1] -
layer_indptr_tuple[group_id];
+ int64_t layer_id_begin_offset = layer_indptr_tuple[group_id];
+ int64_t num_qo_heads = args[2];
+ int64_t num_kv_heads = args[3];
+ int64_t qk_head_dim = args[4];
+ int64_t v_head_dim = args[5];
+ int64_t qk_rope_head_dim = args[6];
+ IntTuple attn_kinds = args[7];
+ int rope_mode = args[8];
+ double rotary_scale = args[9];
+ double rotary_theta = args[10];
+ NDArray init = args[11];
+ PackedFunc f_transpose_append = args[12];
+ PackedFunc f_transpose_append_mla = args[13];
+ PackedFunc f_attention_prefill = args[14];
+ PackedFunc f_attention_decode = args[15];
+ PackedFunc f_attention_prefill_sliding_window = args[16];
+ PackedFunc f_attention_decode_sliding_window = args[17];
+ PackedFunc f_attention_prefill_ragged = args[18];
+ Optional<PackedFunc> f_attention_prefill_ragged_begin_forward = NullOpt;
+ Optional<PackedFunc> f_attention_prefill_ragged_end_forward = NullOpt;
+ Optional<PackedFunc> f_attention_prefill_begin_forward = NullOpt;
+ Optional<PackedFunc> f_attention_prefill_end_forward = NullOpt;
+ Optional<PackedFunc> f_attention_decode_begin_forward = NullOpt;
+ Optional<PackedFunc> f_attention_decode_end_forward = NullOpt;
+ PackedFunc f_mla_prefill = args[25];
+ PackedFunc f_mla_decode = args[26];
+ PackedFunc f_mla_prefill_ragged_normal = args[27];
+ PackedFunc f_mla_prefill_ragged_absorbed = args[28];
+ PackedFunc f_merge_inplace = args[29];
+ PackedFunc f_split_rotary = args[30];
+ PackedFunc f_separate_rotary = args[31];
+ PackedFunc f_copy_single_page = args[32];
+ Optional<PackedFunc> f_debug_get_kv = args[33];
+ PackedFunc f_compact_copy = args[34];
+ PackedFunc f_attention_prefill_with_tree_mask = args[35];
+ PackedFunc f_attention_prefill_with_tree_mask_paged_kv = args[36];
+ Optional<NDArray> rope_ext_factors = NullOpt;
+ bool enable_kv_transfer = false;
+
+ if (args[37].IsObjectRef<NDArray>()) {
+ rope_ext_factors = args[37].AsObjectRef<NDArray>();
+ }
+ enable_kv_transfer = args[38];
+
+ auto f_convert_optional_packed_func = [&args](int arg_idx) ->
Optional<PackedFunc> {
+ if (args[arg_idx].IsObjectRef<PackedFunc>()) {
+ return args[arg_idx].AsObjectRef<PackedFunc>();
+ }
+ return NullOpt;
+ };
+ f_attention_prefill_ragged_begin_forward =
f_convert_optional_packed_func(19);
+ f_attention_prefill_ragged_end_forward =
f_convert_optional_packed_func(20);
+ f_attention_prefill_begin_forward = f_convert_optional_packed_func(21);
+ f_attention_prefill_end_forward = f_convert_optional_packed_func(22);
+ f_attention_decode_begin_forward = f_convert_optional_packed_func(23);
+ f_attention_decode_end_forward = f_convert_optional_packed_func(24);
+
+ std::vector<AttnKind> attn_kinds_vec;
+ attn_kinds_vec.reserve(attn_kinds.size());
+ for (int64_t attn_kind : attn_kinds) {
+ attn_kinds_vec.push_back(static_cast<AttnKind>(attn_kind));
+ }
+
+ CHECK_EQ(cache_config.size(), 5);
+ int64_t reserved_num_seqs = cache_config[0];
+ int64_t total_token_capacity = cache_config[1];
+ int64_t prefill_chunk_size = cache_config[2];
+ int64_t page_size = cache_config[3];
+ bool support_sliding_window = cache_config[4];
+ int64_t num_total_pages = (total_token_capacity + page_size - 1) /
page_size + 1;
+ if (support_sliding_window) {
+ // When sliding window is enabled, each sequence may use two more
pages at most.
+ num_total_pages += reserved_num_seqs * 2;
+ }
+ // NOTE: We will remove this legacy construction after finishing the
transition phase.
+ // Some `PackedFunc()` here are placeholders that will be filled.
+ ObjectPtr<PagedAttentionKVCacheObj> n =
make_object<PagedAttentionKVCacheObj>(
+ page_size, num_layers, layer_id_begin_offset, num_qo_heads,
num_kv_heads, qk_head_dim,
+ v_head_dim, qk_rope_head_dim, attn_kinds_vec, reserved_num_seqs,
num_total_pages,
+ prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode),
rotary_scale,
+ rotary_theta,
+ std::move(rope_ext_factors), //
+ enable_kv_transfer, init->dtype, init->device, //
+ std::move(f_transpose_append), std::move(f_transpose_append_mla),
+ std::move(f_compact_copy), std::move(f_attention_prefill),
std::move(f_attention_decode),
+ std::move(f_attention_prefill_sliding_window),
+ std::move(f_attention_decode_sliding_window),
std::move(f_attention_prefill_ragged),
+ std::move(f_attention_prefill_with_tree_mask), //
+ std::move(f_attention_prefill_with_tree_mask_paged_kv), //
+ std::move(f_attention_prefill_ragged_begin_forward),
+ std::move(f_attention_prefill_ragged_end_forward),
+ std::move(f_attention_prefill_begin_forward),
std::move(f_attention_prefill_end_forward),
+ std::move(f_attention_decode_begin_forward),
std::move(f_attention_decode_end_forward),
+ std::move(f_mla_prefill), std::move(f_mla_decode),
std::move(f_mla_prefill_ragged_normal),
+ std::move(f_mla_prefill_ragged_absorbed), std::move(f_merge_inplace),
+ std::move(f_split_rotary), std::move(f_separate_rotary),
std::move(f_copy_single_page),
std::move(f_debug_get_kv));
*rv = AttentionKVCache(std::move(n));
});