This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6d92f2a85a [KVCache] Added support for normal MLA kernel (#17624)
6d92f2a85a is described below
commit 6d92f2a85a363e08a6a2d20d7ac22aeb863d099e
Author: Annanya <[email protected]>
AuthorDate: Thu Feb 20 14:12:24 2025 -0500
[KVCache] Added support for normal MLA kernel (#17624)
* Refactored code to allow the v head dimension to differ from the q/k head dimension
* Made a small fix after the rebase
* Made changes to the runtime to support normal kernel
* Fixed a compilation issue
* Fix lint
---------
Co-authored-by: Ruihang Lai <[email protected]>
---
python/tvm/relax/frontend/nn/llm/kv_cache.py | 73 +++++++++++++++++++++-----
src/runtime/relax_vm/kv_state.cc | 10 ++++
src/runtime/relax_vm/paged_kv_cache.cc | 77 +++++++++++++++++++++++++++-
3 files changed, 147 insertions(+), 13 deletions(-)
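Background for the diff below: the "normal" MLA flow runs attention without weight absorption, materializing per-head q/k at the full qk head dim while v (and the output) carries its own, typically smaller, head dim. A minimal NumPy sketch, not TVM code, of why the kernels must thread separate q/k and v dims (the dims 192/128 are illustrative assumptions in the style of DeepSeek heads):

    import numpy as np

    def sdpa_split_dims(q, k, v, scale=1.0):
        # q: (L_q, h, d_qk), k: (L_kv, h, d_qk), v: (L_kv, h, d_v).
        # Scores are formed over d_qk; the output inherits d_v.
        s = np.einsum("qhd,khd->hqk", q, k) * scale
        p = np.exp(s - s.max(axis=-1, keepdims=True))
        p /= p.sum(axis=-1, keepdims=True)  # softmax over keys
        return np.einsum("hqk,khd->qhd", p, v)  # (L_q, h, d_v)

    q = np.random.rand(4, 2, 192).astype("float32")  # d_qk = 192 (assumed)
    k = np.random.rand(4, 2, 192).astype("float32")
    v = np.random.rand(4, 2, 128).astype("float32")  # d_v = 128 != d_qk
    assert sdpa_split_dims(q, k, v, scale=192 ** -0.5).shape == (4, 2, 128)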
diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py
index f5ff0105d0..ea6f153316 100644
--- a/python/tvm/relax/frontend/nn/llm/kv_cache.py
+++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -180,6 +180,49 @@ class PagedKVCache(Object):  # pylint: disable=too-few-public-methods
)
).reshape(b, s, h_qo, kv_lora_rank)
+ def mla_normal(
+ self,
+ layer_id: int,
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ compressed_kv: Tensor,
+ k_pe: Tensor,
+ attn_score_scaling_factor: float = 1.0,
+ ) -> Tensor:
+        """Compute multi-head latent attention with the given data
+        on the specified layer using the normal flow (WITHOUT weight absorption).
+        """
+ # pylint: disable=protected-access
+ b, s, h_qo, d_qk = q._expr.struct_info.shape
+ d_v = v._expr.struct_info.shape[3]
+ kv_lora_rank = compressed_kv._expr.struct_info.shape[3]
+ qk_rope_head_dim = k_pe._expr.struct_info.shape[3]
+ q = q.reshape(b * s, h_qo, d_qk)
+ k = k.reshape(b * s, h_qo, d_qk)
+ v = v.reshape(b * s, h_qo, d_v)
+ compressed_kv = compressed_kv.reshape(b * s, kv_lora_rank)
+ k_pe = k_pe.reshape(b * s, qk_rope_head_dim)
+
+ return Tensor(
+ _expr=rx.BlockBuilder.current().emit(
+ rx.call_dps_packed(
+ "vm.builtin.attention_kv_cache_mla_normal",
+ [
+ self._expr,
+ rx.PrimValue(layer_id), # type: ignore[arg-type]
+ rx.PrimValue(attn_score_scaling_factor),
+ q._expr,
+ k._expr,
+ v._expr,
+ compressed_kv._expr,
+ k_pe._expr,
+ ],
+ out_sinfo=rx.TensorStructInfo((b * s, h_qo, d_v), q.dtype),
+ )
+ )
+ ).reshape(b, s, h_qo, d_v)
+
def get_query_positions(self, total_length: tir.PrimExpr) -> Tensor:
"""Get the in-sequence positions of each slot in the query,
which are needed for applying positional embeddings in some models.
@@ -591,7 +634,7 @@ class TIRPagedKVCache(PagedKVCache):  # pylint: disable=too-few-public-methods
rx.PrimValue(0),
bb.add_func(_attention_prefill_mla(num_attention_heads, kv_lora_rank, qk_rope_head_dim, dtype, False, target), "tir_attention_prefill_mla"),
bb.add_func(_attention_decode_mla(num_attention_heads, kv_lora_rank, qk_rope_head_dim, dtype, False, target), "tir_attention_decode_mla"),
- bb.add_func(_attention_prefill_ragged(num_key_value_heads, num_attention_heads, v_head_dim, dtype, {}, target), "tir_attention_prefill_ragged_mla_normal"),
+ bb.add_func(_attention_prefill_ragged_generic(num_key_value_heads, num_attention_heads, qk_rope_head_dim, v_head_dim, dtype, {}, target), "tir_attention_prefill_ragged_mla_normal"),
bb.add_func(_attention_prefill_ragged_mla_absorbed(num_attention_heads, kv_lora_rank, qk_rope_head_dim, dtype, target), "tir_attention_prefill_ragged_mla_absorbed"),
bb.add_func(_merge_state_inplace(num_attention_heads, kv_lora_rank, dtype, target), "tir_attention_merge_state"),
bb.add_func(llama_rope_with_position_map(10000, 1, qk_rope_head_dim, num_attention_heads, num_key_value_heads, dtype, {}, None), "tir_split_rotary"),
@@ -2420,6 +2463,12 @@ def _attention_prefill_ragged_cpu(h_kv, h_q, d, dtype, rope_scaling: Dict[str, A
def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], target: Target):
+    return _attention_prefill_ragged_generic(h_kv, h_q, d, d, dtype, rope_scaling, target)
+
+
+def _attention_prefill_ragged_generic(
+ h_kv, h_q, d_qk, d_v, dtype, rope_scaling: Dict[str, Any], target: Target
+):
# pylint: disable=line-too-long
(
NUM_BLKS,
@@ -2431,7 +2480,7 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any],
tile_x,
tile_y,
tile_z,
- ) = _get_prefill_kernel_config(h_kv, h_q, d, dtype, target)
+ ) = _get_prefill_kernel_config(h_kv, h_q, d_qk, dtype, target)
# fmt: off
@T.prim_func
@@ -2459,14 +2508,14 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any],
q_rope_position_elem_offset = T.int32(is_size_var=True)
k_rope_pos_offset_elem_offset = T.int32(is_size_var=True)
- q = T.match_buffer(var_q, (qo_len, h_q, d), dtype)
+ q = T.match_buffer(var_q, (qo_len, h_q, d_qk), dtype)
q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset)
- k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype)
- v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype)
+ k = T.match_buffer(var_k, (kv_len, h_kv, d_qk), dtype)
+ v = T.match_buffer(var_v, (kv_len, h_kv, d_v), dtype)
kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", elem_offset=kv_indptr_elem_offset)
q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", elem_offset=q_rope_position_elem_offset)
k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset)
- output = T.match_buffer(var_output, (qo_len, h_q, d), dtype)
+ output = T.match_buffer(var_output, (qo_len, h_q, d_v), dtype)
lse = T.match_buffer(var_lse, (qo_len, h_q), "float32")  # pylint: disable=unused-variable
# kernel code
@@ -2485,13 +2534,13 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any],
iterator = _var("int32")
kv_chunk_len = _var("int32")
- Q_smem = T.alloc_buffer((tile_x, d), dtype, scope="shared")
- K_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared")
- V_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared")
+ Q_smem = T.alloc_buffer((tile_x, d_qk), dtype, scope="shared")
+ K_smem = T.alloc_buffer((tile_z, d_qk), dtype, scope="shared")
+ V_smem = T.alloc_buffer((tile_z, d_v), dtype, scope="shared")
S_smem = T.alloc_buffer((tile_x, tile_z), "float32", scope="shared")
S_local = T.alloc_buffer((tile_x, tile_z), "float32", scope="local")
- O_local = T.alloc_buffer((tile_x, d), "float32", scope="local")
+ O_local = T.alloc_buffer((tile_x, d_v), "float32", scope="local")
m_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared")
m_prev_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared")
@@ -2548,7 +2597,7 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any],
if cur_L < q_indptr[b_idx + 1]:
Q_smem[i, j] = T.if_then_else(
rotary_mode == 1,
- _rope(q, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, cur_H_qo, j), dtype, rope_scaling),
+ _rope(q, q_rope_position[cur_L], d_qk, rope_theta, rope_scale, (cur_L, cur_H_qo, j), dtype, rope_scaling),
q[cur_L, cur_H_qo, j]
)
else:
@@ -2565,7 +2614,7 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype,
rope_scaling: Dict[str, Any],
if cur_L < kv_chunk_len[0]:
K_smem[i, j] = T.if_then_else(
rotary_mode == 1,
- _rope(k, k_rope_pos_offset[b_idx] + cur_L, d, rope_theta, rope_scale, (L_kv_base + cur_L, by, j), dtype, rope_scaling),
+ _rope(k, k_rope_pos_offset[b_idx] + cur_L, d_qk, rope_theta, rope_scale, (L_kv_base + cur_L, by, j), dtype, rope_scaling),
k[L_kv_base + cur_L, by, j]
)
else:
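The mla_normal method in this file only does shape bookkeeping before handing off to the packed builtin: it folds (b, s) into one ragged axis, calls vm.builtin.attention_kv_cache_mla_normal, and unfolds the (b * s, h_qo, d_v) result back to (b, s, h_qo, d_v). A framework-free sketch of that reshaping (function name hypothetical, shapes taken from the diff):

    import numpy as np

    def flatten_mla_inputs(q, k, v, compressed_kv, k_pe):
        # Mirrors mla_normal's reshapes before the packed call.
        b, s, h_qo, d_qk = q.shape
        d_v = v.shape[3]
        kv_lora_rank = compressed_kv.shape[3]
        qk_rope_head_dim = k_pe.shape[3]
        return (
            q.reshape(b * s, h_qo, d_qk),
            k.reshape(b * s, h_qo, d_qk),
            v.reshape(b * s, h_qo, d_v),
            compressed_kv.reshape(b * s, kv_lora_rank),
            k_pe.reshape(b * s, qk_rope_head_dim),
        )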
diff --git a/src/runtime/relax_vm/kv_state.cc b/src/runtime/relax_vm/kv_state.cc
index c78ada58e6..1b1867f060 100644
--- a/src/runtime/relax_vm/kv_state.cc
+++ b/src/runtime/relax_vm/kv_state.cc
@@ -90,6 +90,16 @@ TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_mla_absorbed")
std::move(k_pe_data), std::move(o_data), attn_score_scaling_factor);
});
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_mla_normal")
+    .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
+                       double attn_score_scaling_factor, NDArray q_data, NDArray k_data,
+                       NDArray v_data, NDArray compressed_kv_data, NDArray k_pe_data,
+                       NDArray o_data) {
+      kv_cache->MLANormal(layer_id, std::move(q_data), std::move(k_data), std::move(v_data),
+                          std::move(compressed_kv_data), std::move(k_pe_data), std::move(o_data),
+                          attn_score_scaling_factor);
+    });
+
// RNN State methods
TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_get").set_body_method<RNNState>(&RNNStateObj::Get);
TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_set")
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index 075ff0b944..a936f429ee 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -2241,7 +2241,82 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
void MLANormal(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data,
               NDArray compressed_kv_data, NDArray k_pe_data, NDArray o_data,
               double attn_score_scaling_factor) {
- // Todo(ruihang): implement it
+ // Part 1: Basic Checks and Setup.
+ int64_t local_layer_id = layer_id - layer_id_begin_offset_;
+ CHECK_GE(local_layer_id, 0);
+ CHECK_LT(local_layer_id, num_layers_);
+ NDArray pages = pages_[local_layer_id];
+ CHECK(q_data.DataType() == pages.DataType());
+ CHECK(k_data.DataType() == pages.DataType());
+ CHECK(v_data.DataType() == pages.DataType());
+ CHECK(compressed_kv_data.DataType() == pages.DataType());
+ CHECK(k_pe_data.DataType() == pages.DataType());
+ CHECK(o_data.DataType() == pages.DataType());
+ CHECK(attn_kinds_[layer_id] == AttnKind::kMLA);
+
+ // Expected shapes:
+ // q_data: (num_total_length, num_qo_heads, qk_head_dim)
+ // k_data: (num_total_length, num_qo_heads, qk_head_dim)
+ // v_data: (num_total_length, num_qo_heads, v_head_dim)
+ // compressed_kv_data: (num_total_length, qk_head_dim - qk_rope_head_dim)
+ // k_pe_data: (num_total_length, qk_rope_head_dim)
+ // o_data: (num_total_length, num_qo_heads, v_head_dim)
+ CHECK_EQ(q_data->ndim, 3);
+ CHECK_EQ(k_data->ndim, 3);
+ CHECK_EQ(v_data->ndim, 3);
+ CHECK_EQ(compressed_kv_data->ndim, 2);
+ CHECK_EQ(k_pe_data->ndim, 2);
+ CHECK_EQ(o_data->ndim, 3);
+
+ int64_t total_seq_length = 0;
+ for (int64_t i = 0; i < cur_batch_size_; ++i) {
+ total_seq_length += cur_append_lengths_[i];
+ }
+ CHECK_LE(q_data->shape[0], total_seq_length);
+ CHECK_LE(k_data->shape[0], total_seq_length);
+ CHECK_LE(v_data->shape[0], total_seq_length);
+ CHECK_LE(compressed_kv_data->shape[0], total_seq_length);
+ CHECK_LE(k_pe_data->shape[0], total_seq_length);
+ CHECK_EQ(k_pe_data->shape[1], qk_rope_head_dim_);
+ CHECK_LE(o_data->shape[0], total_seq_length);
+ CHECK_EQ(q_data->shape[1], num_qo_heads_);
+ CHECK_EQ(o_data->shape[1], num_qo_heads_);
+ CHECK_EQ(k_data->shape[1], num_qo_heads_);
+ CHECK_EQ(v_data->shape[1], num_qo_heads_);
+ CHECK_EQ(q_data->shape[2], qk_head_dim_);
+ CHECK_EQ(k_data->shape[2], qk_head_dim_);
+ CHECK_EQ(v_data->shape[2], v_head_dim_);
+ CHECK_EQ(o_data->shape[2], v_head_dim_);
+
+ // Part 2: Synchronize streams and update auxiliary data.
+ ComputeStreamWaitForCopyStream();
+ ICHECK(!dirty_aux_data_device_);
+
+    // Part 3: Append the compressed k/v data to the KV cache if the
+    // "append_before_attn" flag is set.
+    if (append_before_attn_) {
+      f_transpose_append_mla_(pages_[local_layer_id], compressed_kv_data, k_pe_data,
+                              append_position_map_view_);
+    }
+
+    // Part 4: Call the ragged attention kernel.
+    // f_mla_prefill_ragged_normal_ is designed to handle both the decode and
+    // the normal prefill case. A flag such as `use_decode_kernel_[0]` could be
+    // consulted to pick specialized parameters; here the kernel is assumed to
+    // cover both cases internally.
+    f_mla_prefill_ragged_normal_(q_data, cur_append_length_indptr_view_, k_data, v_data,
+                                 cur_append_length_indptr_view_, q_rope_position_map_view_,
+                                 k_ragged_rope_pos_offset_view_,
+                                 o_data,  // output tensor
+                                 merged_attn_scores_view_,
+                                 /*causal=*/1, static_cast<int>(RoPEMode::kNone),
+                                 0,  // RoPE scale, unused since RoPE mode is kNone
+                                 0,  // RoPE theta, unused since RoPE mode is kNone
+                                 attn_score_scaling_factor);
+
+    // Part 5: If appending is to occur after attention, call the append kernel.
+    if (!append_before_attn_) {
+      f_transpose_append_mla_(pages_[local_layer_id], compressed_kv_data, k_pe_data,
+                              append_position_map_view_);
+    }
}
void LinearAttention(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data,
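For readers tracking the shape contract, the Part 1 checks in MLANormal reduce to the invariants below. A hedged Python rendering (helper name hypothetical, covering only the checks visible in this diff), usable with NumPy arrays:

    def check_mla_normal_shapes(q, k, v, compressed_kv, k_pe, o,
                                num_qo_heads, qk_head_dim, v_head_dim,
                                qk_rope_head_dim, total_seq_length):
        # q/k/v/o are 3-D; the tensors that get cached are 2-D.
        assert q.ndim == k.ndim == v.ndim == o.ndim == 3
        assert compressed_kv.ndim == k_pe.ndim == 2
        # Ragged first axis: bounded by the sum of per-sequence append lengths.
        for t in (q, k, v, compressed_kv, k_pe, o):
            assert t.shape[0] <= total_seq_length
        # Per-head dims: q/k use the full qk head dim, v/o the value head dim.
        assert q.shape[1:] == (num_qo_heads, qk_head_dim)
        assert k.shape[1:] == (num_qo_heads, qk_head_dim)
        assert v.shape[1:] == (num_qo_heads, v_head_dim)
        assert o.shape[1:] == (num_qo_heads, v_head_dim)
        assert k_pe.shape[1] == qk_rope_head_dim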