This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6d92f2a85a [KVCache] Added support for normal MLA kernel (#17624)
6d92f2a85a is described below

commit 6d92f2a85a363e08a6a2d20d7ac22aeb863d099e
Author: Annanya <[email protected]>
AuthorDate: Thu Feb 20 14:12:24 2025 -0500

    [KVCache] Added support for normal MLA kernel (#17624)
    
    * Refactored code to allow the v head dimension to differ from the q/k head dimension
    
    * Made a small fix after the rebase
    
    * Made changes to the runtime to support normal kernel
    
    * Fixed a compilation issue
    
    * Fix lint
    
    ---------
    
    Co-authored-by: Ruihang Lai <[email protected]>
---
 python/tvm/relax/frontend/nn/llm/kv_cache.py | 73 +++++++++++++++++++++-----
 src/runtime/relax_vm/kv_state.cc             | 10 ++++
 src/runtime/relax_vm/paged_kv_cache.cc       | 77 +++++++++++++++++++++++++++-
 3 files changed, 147 insertions(+), 13 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py 
b/python/tvm/relax/frontend/nn/llm/kv_cache.py
index f5ff0105d0..ea6f153316 100644
--- a/python/tvm/relax/frontend/nn/llm/kv_cache.py
+++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -180,6 +180,49 @@ class PagedKVCache(Object):  # pylint: 
disable=too-few-public-methods
             )
         ).reshape(b, s, h_qo, kv_lora_rank)
 
+    def mla_normal(
+        self,
+        layer_id: int,
+        q: Tensor,
+        k: Tensor,
+        v: Tensor,
+        compressed_kv: Tensor,
+        k_pe: Tensor,
+        attn_score_scaling_factor: float = 1.0,
+    ) -> Tensor:
+        """Compute multi-head latent attention with the given data
+        on the specified layer using the normal flow (WITHOUT weight absorption).
+        """
+        # pylint: disable=protected-access
+        b, s, h_qo, d_qk = q._expr.struct_info.shape
+        d_v = v._expr.struct_info.shape[3]
+        kv_lora_rank = compressed_kv._expr.struct_info.shape[3]
+        qk_rope_head_dim = k_pe._expr.struct_info.shape[3]
+        q = q.reshape(b * s, h_qo, d_qk)
+        k = k.reshape(b * s, h_qo, d_qk)
+        v = v.reshape(b * s, h_qo, d_v)
+        compressed_kv = compressed_kv.reshape(b * s, kv_lora_rank)
+        k_pe = k_pe.reshape(b * s, qk_rope_head_dim)
+
+        return Tensor(
+            _expr=rx.BlockBuilder.current().emit(
+                rx.call_dps_packed(
+                    "vm.builtin.attention_kv_cache_mla_normal",
+                    [
+                        self._expr,
+                        rx.PrimValue(layer_id),  # type: ignore[arg-type]
+                        rx.PrimValue(attn_score_scaling_factor),
+                        q._expr,
+                        k._expr,
+                        v._expr,
+                        compressed_kv._expr,
+                        k_pe._expr,
+                    ],
+                    out_sinfo=rx.TensorStructInfo((b * s, h_qo, d_v), q.dtype),
+                )
+            )
+        ).reshape(b, s, h_qo, d_v)
+
     def get_query_positions(self, total_length: tir.PrimExpr) -> Tensor:
         """Get the in-sequence positions of each slot in the query,
         which are needed for applying positional embeddings in some models.
@@ -591,7 +634,7 @@ class TIRPagedKVCache(PagedKVCache):  # pylint: 
disable=too-few-public-methods
             rx.PrimValue(0),
             bb.add_func(_attention_prefill_mla(num_attention_heads, 
kv_lora_rank, qk_rope_head_dim, dtype, False, target), 
"tir_attention_prefill_mla"),
             bb.add_func(_attention_decode_mla(num_attention_heads, 
kv_lora_rank, qk_rope_head_dim, dtype, False, target), 
"tir_attention_decode_mla"),
-            bb.add_func(_attention_prefill_ragged(num_key_value_heads, 
num_attention_heads, v_head_dim, dtype, {}, target), 
"tir_attention_prefill_ragged_mla_normal"),
+            bb.add_func(_attention_prefill_ragged_generic(num_key_value_heads, 
num_attention_heads, qk_rope_head_dim, v_head_dim, dtype, {}, target), 
"tir_attention_prefill_ragged_mla_normal"),
             
bb.add_func(_attention_prefill_ragged_mla_absorbed(num_attention_heads, 
kv_lora_rank, qk_rope_head_dim, dtype, target), 
"tir_attention_prefill_ragged_mla_absorbed"),
             bb.add_func(_merge_state_inplace(num_attention_heads, 
kv_lora_rank, dtype, target), "tir_attention_merge_state"),
             bb.add_func(llama_rope_with_position_map(10000, 1, 
qk_rope_head_dim, num_attention_heads, num_key_value_heads, dtype, {}, None), 
"tir_split_rotary"),
@@ -2420,6 +2463,12 @@ def _attention_prefill_ragged_cpu(h_kv, h_q, d, dtype, 
rope_scaling: Dict[str, A
 
 
 def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, 
Any], target: Target):
+    return _attention_prefill_ragged_generic(h_kv, h_q, d, d, dtype, 
rope_scaling, target)
+
+
+def _attention_prefill_ragged_generic(
+    h_kv, h_q, d_qk, d_v, dtype, rope_scaling: Dict[str, Any], target: Target
+):
     # pylint: disable=line-too-long
     (
         NUM_BLKS,
@@ -2431,7 +2480,7 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype, 
rope_scaling: Dict[str, Any],
         tile_x,
         tile_y,
         tile_z,
-    ) = _get_prefill_kernel_config(h_kv, h_q, d, dtype, target)
+    ) = _get_prefill_kernel_config(h_kv, h_q, d_qk, dtype, target)
 
     # fmt: off
     @T.prim_func
@@ -2459,14 +2508,14 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype, 
rope_scaling: Dict[str, Any],
         q_rope_position_elem_offset = T.int32(is_size_var=True)
         k_rope_pos_offset_elem_offset = T.int32(is_size_var=True)
 
-        q = T.match_buffer(var_q, (qo_len, h_q, d), dtype)
+        q = T.match_buffer(var_q, (qo_len, h_q, d_qk), dtype)
         q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", 
elem_offset=q_indptr_elem_offset)
-        k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype)
-        v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype)
+        k = T.match_buffer(var_k, (kv_len, h_kv, d_qk), dtype)
+        v = T.match_buffer(var_v, (kv_len, h_kv, d_v), dtype)
         kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", 
elem_offset=kv_indptr_elem_offset)
         q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), 
"int32", elem_offset=q_rope_position_elem_offset)
         k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, 
(batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset)
-        output = T.match_buffer(var_output, (qo_len, h_q, d), dtype)
+        output = T.match_buffer(var_output, (qo_len, h_q, d_v), dtype)
         lse = T.match_buffer(var_lse, (qo_len, h_q), "float32")  # pylint: 
disable=unused-variable
 
         # kernel code
@@ -2485,13 +2534,13 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype, 
rope_scaling: Dict[str, Any],
                             iterator = _var("int32")
                             kv_chunk_len = _var("int32")
 
-                            Q_smem = T.alloc_buffer((tile_x, d), dtype, 
scope="shared")
-                            K_smem = T.alloc_buffer((tile_z, d), dtype, 
scope="shared")
-                            V_smem = T.alloc_buffer((tile_z, d), dtype, 
scope="shared")
+                            Q_smem = T.alloc_buffer((tile_x, d_qk), dtype, 
scope="shared")
+                            K_smem = T.alloc_buffer((tile_z, d_qk), dtype, 
scope="shared")
+                            V_smem = T.alloc_buffer((tile_z, d_v), dtype, 
scope="shared")
                             S_smem = T.alloc_buffer((tile_x, tile_z), 
"float32", scope="shared")
 
                             S_local = T.alloc_buffer((tile_x, tile_z), 
"float32", scope="local")
-                            O_local = T.alloc_buffer((tile_x, d), "float32", 
scope="local")
+                            O_local = T.alloc_buffer((tile_x, d_v), "float32", 
scope="local")
 
                             m_smem = T.alloc_buffer((tile_x, ), "float32", 
scope="shared")
                             m_prev_smem = T.alloc_buffer((tile_x, ), 
"float32", scope="shared")
@@ -2548,7 +2597,7 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype, 
rope_scaling: Dict[str, Any],
                                             if cur_L < q_indptr[b_idx + 1]:
                                                 Q_smem[i, j] = T.if_then_else(
                                                     rotary_mode == 1,
-                                                    _rope(q, 
q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, cur_H_qo, j), dtype, 
rope_scaling),
+                                                    _rope(q, 
q_rope_position[cur_L], d_qk, rope_theta, rope_scale, (cur_L, cur_H_qo, j), 
dtype, rope_scaling),
                                                     q[cur_L, cur_H_qo, j]
                                                 )
                                             else:
@@ -2565,7 +2614,7 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype, 
rope_scaling: Dict[str, Any],
                                                 if cur_L < kv_chunk_len[0]:
                                                     K_smem[i, j] = 
T.if_then_else(
                                                         rotary_mode == 1,
-                                                        _rope(k, 
k_rope_pos_offset[b_idx] + cur_L, d, rope_theta, rope_scale, (L_kv_base + 
cur_L, by, j), dtype, rope_scaling),
+                                                        _rope(k, 
k_rope_pos_offset[b_idx] + cur_L, d_qk, rope_theta, rope_scale, (L_kv_base + 
cur_L, by, j), dtype, rope_scaling),
                                                         k[L_kv_base + cur_L, 
by, j]
                                                     )
                                                 else:
diff --git a/src/runtime/relax_vm/kv_state.cc b/src/runtime/relax_vm/kv_state.cc
index c78ada58e6..1b1867f060 100644
--- a/src/runtime/relax_vm/kv_state.cc
+++ b/src/runtime/relax_vm/kv_state.cc
@@ -90,6 +90,16 @@ 
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_mla_absorbed")
                             std::move(k_pe_data), std::move(o_data), 
attn_score_scaling_factor);
     });
 
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_mla_normal")
+    .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
+                       double attn_score_scaling_factor, NDArray q_data, 
NDArray k_data,
+                       NDArray v_data, NDArray compressed_kv_data, NDArray 
k_pe_data,
+                       NDArray o_data) {
+      kv_cache->MLANormal(layer_id, std::move(q_data), std::move(k_data), 
std::move(v_data),
+                          std::move(compressed_kv_data), std::move(k_pe_data), 
std::move(o_data),
+                          attn_score_scaling_factor);
+    });
+
 // RNN State methods
 
TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_get").set_body_method<RNNState>(&RNNStateObj::Get);
 TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_set")
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc 
b/src/runtime/relax_vm/paged_kv_cache.cc
index 075ff0b944..a936f429ee 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -2241,7 +2241,82 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
   void MLANormal(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray 
v_data,
                  NDArray compressed_kv_data, NDArray k_pe_data, NDArray o_data,
                  double attn_score_scaling_factor) {
-    // Todo(ruihang): implement it
+    // Part 1: Basic Checks and Setup.
+    int64_t local_layer_id = layer_id - layer_id_begin_offset_;
+    CHECK_GE(local_layer_id, 0);
+    CHECK_LT(local_layer_id, num_layers_);
+    NDArray pages = pages_[local_layer_id];
+    CHECK(q_data.DataType() == pages.DataType());
+    CHECK(k_data.DataType() == pages.DataType());
+    CHECK(v_data.DataType() == pages.DataType());
+    CHECK(compressed_kv_data.DataType() == pages.DataType());
+    CHECK(k_pe_data.DataType() == pages.DataType());
+    CHECK(o_data.DataType() == pages.DataType());
+    CHECK(attn_kinds_[layer_id] == AttnKind::kMLA);
+
+    // Expected shapes:
+    //   q_data:             (num_total_length, num_qo_heads, qk_head_dim)
+    //   k_data:             (num_total_length, num_qo_heads, qk_head_dim)
+    //   v_data:             (num_total_length, num_qo_heads, v_head_dim)
+    //   compressed_kv_data: (num_total_length, kv_lora_rank)
+    //   k_pe_data:          (num_total_length, qk_rope_head_dim)
+    //   o_data:             (num_total_length, num_qo_heads, v_head_dim)
+    CHECK_EQ(q_data->ndim, 3);
+    CHECK_EQ(k_data->ndim, 3);
+    CHECK_EQ(v_data->ndim, 3);
+    CHECK_EQ(compressed_kv_data->ndim, 2);
+    CHECK_EQ(k_pe_data->ndim, 2);
+    CHECK_EQ(o_data->ndim, 3);
+
+    int64_t total_seq_length = 0;
+    for (int64_t i = 0; i < cur_batch_size_; ++i) {
+      total_seq_length += cur_append_lengths_[i];
+    }
+    CHECK_LE(q_data->shape[0], total_seq_length);
+    CHECK_LE(k_data->shape[0], total_seq_length);
+    CHECK_LE(v_data->shape[0], total_seq_length);
+    CHECK_LE(compressed_kv_data->shape[0], total_seq_length);
+    CHECK_LE(k_pe_data->shape[0], total_seq_length);
+    CHECK_EQ(k_pe_data->shape[1], qk_rope_head_dim_);
+    CHECK_LE(o_data->shape[0], total_seq_length);
+    CHECK_EQ(q_data->shape[1], num_qo_heads_);
+    CHECK_EQ(o_data->shape[1], num_qo_heads_);
+    CHECK_EQ(k_data->shape[1], num_qo_heads_);
+    CHECK_EQ(v_data->shape[1], num_qo_heads_);
+    CHECK_EQ(q_data->shape[2], qk_head_dim_);
+    CHECK_EQ(k_data->shape[2], qk_head_dim_);
+    CHECK_EQ(v_data->shape[2], v_head_dim_);
+    CHECK_EQ(o_data->shape[2], v_head_dim_);
+
+    // Part 2: Synchronize streams and update auxiliary data.
+    ComputeStreamWaitForCopyStream();
+    ICHECK(!dirty_aux_data_device_);
+
+    // Part 3: Append compressed_kv/k_pe data to the KV cache now if "append_before_attn" is set.
+    if (append_before_attn_) {
+      f_transpose_append_mla_(pages_[local_layer_id], compressed_kv_data, 
k_pe_data,
+                              append_position_map_view_);
+    }
+
+    // Part 4: Call the ragged prefill kernel on the uncompressed q/k/v.
+    // Both the query and the key/value indptr arguments are cur_append_length_indptr_view_, so each
+    // query attends only to the k/v of tokens appended in this forward pass, not to earlier cached
+    // tokens. NOTE(review): confirm callers use this path only when no prior context is needed.
+    f_mla_prefill_ragged_normal_(q_data, cur_append_length_indptr_view_, 
k_data, v_data,
+                                 cur_append_length_indptr_view_, 
q_rope_position_map_view_,
+                                 k_ragged_rope_pos_offset_view_,
+                                 o_data,  // output tensor
+                                 merged_attn_scores_view_,
+                                 /*causal=*/1, 
static_cast<int>(RoPEMode::kNone),
+                                 0,  // Rope param, not important
+                                 0,  // Rope param, not important
+                                 attn_score_scaling_factor);
+
+    // Part 5: If appending is to occur after attention, call the append 
kernel.
+    if (!append_before_attn_) {
+      f_transpose_append_mla_(pages_[local_layer_id], compressed_kv_data, 
k_pe_data,
+                              append_position_map_view_);
+    }
   }
 
   void LinearAttention(int64_t layer_id, NDArray q_data, NDArray k_data, 
NDArray v_data,

Reply via email to