This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new cd09ab64b5 [Runtime] Reorganize PagedKVCache attn kernel invocation (#17237)
cd09ab64b5 is described below

commit cd09ab64b5ccf6ff0a96d887a968acd4602188a8
Author: Ruihang Lai <[email protected]>
AuthorDate: Sat Aug 3 20:01:01 2024 -0400

    [Runtime] Reorganize PagedKVCache attn kernel invocation (#17237)
    
    This PR reorganizes the attention kernel invocation logic in the
    PagedKVCache, so that in cases of sequence fork, we can effectively
    merge one ragged-prefill kernel and a decode kernel into a single
    decode kernel.
---
 src/relax/transform/fuse_ops.cc        |   2 +-
 src/runtime/relax_vm/paged_kv_cache.cc | 127 +++++++++++++++++----------------
 2 files changed, 65 insertions(+), 64 deletions(-)

diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index e791aeab06..85c739e083 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -646,7 +646,7 @@ class FunctionCreator : public ExprMutator {
       return tvm::tir::UndefinedVars(prim_value->value).empty();
     } else if (const auto* shape_expr = expr.as<ShapeExprNode>()) {
       return std::all_of(shape_expr->values.begin(), shape_expr->values.end(),
-                         [this](const PrimExpr& e) { return 
tvm::tir::UndefinedVars(e).empty(); });
+                         [](const PrimExpr& e) { return 
tvm::tir::UndefinedVars(e).empty(); });
     }
     return false;
   }
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc 
b/src/runtime/relax_vm/paged_kv_cache.cc
index 5aa1411ec1..cf5de97202 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -1535,7 +1535,7 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
       CHECK_EQ(chunked_block_ids_arr[num_depths_ - 1].size(), cur_batch_size_);
     }
 
-    append_before_attn_ = !support_sliding_window_ && num_depths_ == 1 && 
use_decode_kernel_[0];
+    append_before_attn_ = !support_sliding_window_ && 
use_decode_kernel_.back();
     if (NeedKernelBeginForward() && num_qo_heads_ / num_kv_heads_ >= 4) {
       // When GQA group size is at least 4 and FlashInfer is enabled,
       // we always use prefill kernel for better performance.
@@ -2220,39 +2220,33 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
       return;
     }
 
-    if (append_before_attn_) {
-      if (!support_sliding_window_) {
+    if (!append_before_attn_) {
+      if (is_chain_) {
+        f_attention_prefill_ragged_begin_forward_.value()(
+            temp_attn_workspace_[0], 
cur_append_lengths_indptr_host_.as_ndarray(),
+            cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_, 
num_qo_heads_,
+            num_kv_heads_, head_dim_, copy_stream_);
+      } else {
+        LOG(FATAL) << "Kernel BeginForward doesn't support tree attn.";
+      }
+    }
+    for (int d = 0; d < num_depths_; ++d) {
+      if (page_indices_on_depths_view_[d]->shape[0] == 0) {
+        continue;
+      }
+      CHECK(!support_sliding_window_) << "Kernel BeginForward doesn't support 
sliding window.";
+      if (use_decode_kernel_[d]) {
         f_attention_decode_begin_forward_.value()(
-            /*depth=*/0, temp_attn_workspace_[1], 
page_indptr_on_depths_host_[0].as_ndarray(),
-            last_page_len_on_depths_host_[0].as_ndarray(), num_qo_heads_, 
num_kv_heads_, head_dim_,
+            d, temp_attn_workspace_[d + 1], 
page_indptr_on_depths_host_[d].as_ndarray(),
+            last_page_len_on_depths_host_[d].as_ndarray(), num_qo_heads_, 
num_kv_heads_, head_dim_,
             page_size_,
             /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_);
-      }
-    } else {
-      f_attention_prefill_ragged_begin_forward_.value()(
-          temp_attn_workspace_[0], 
cur_append_lengths_indptr_host_.as_ndarray(),
-          cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_, 
num_qo_heads_,
-          num_kv_heads_, head_dim_, copy_stream_);
-      if (support_sliding_window_) {
-        return;
-      }
-      for (int d = 0; d < num_depths_; ++d) {
-        if (page_indices_on_depths_view_[d]->shape[0] == 0) {
-          continue;
-        }
-        if (use_decode_kernel_[d]) {
-          f_attention_decode_begin_forward_.value()(
-              d, temp_attn_workspace_[d + 1], 
page_indptr_on_depths_host_[d].as_ndarray(),
-              last_page_len_on_depths_host_[d].as_ndarray(), num_qo_heads_, 
num_kv_heads_,
-              head_dim_, page_size_,
-              /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_);
-        } else {
-          f_attention_prefill_begin_forward_.value()(
-              /*depth=*/d, temp_attn_workspace_[d + 1], 
qo_indptr_on_depths_host_[d].as_ndarray(),
-              page_indptr_on_depths_host_[d].as_ndarray(),
-              static_cast<int>(page_indptr_on_depths_host_[d].size()) - 1, 
num_qo_heads_,
-              num_kv_heads_, head_dim_, page_size_, copy_stream_);
-        }
+      } else {
+        f_attention_prefill_begin_forward_.value()(
+            /*depth=*/d, temp_attn_workspace_[d + 1], 
qo_indptr_on_depths_host_[d].as_ndarray(),
+            page_indptr_on_depths_host_[d].as_ndarray(),
+            static_cast<int>(page_indptr_on_depths_host_[d].size()) - 1, 
num_qo_heads_,
+            num_kv_heads_, head_dim_, page_size_, copy_stream_);
       }
     }
   }
@@ -2271,15 +2265,11 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     PackedFunc f_decode =
         !support_sliding_window_ ? f_attention_decode_ : 
f_attention_decode_sliding_window_;
     CHECK_GE(num_depths_, 1) << "The number of effective depths must be 
greater or equal to 1.";
-    if (append_before_attn_) {
-      f_decode(
-          /*depth=*/0, q_data, pages_[local_layer_id], 
page_indptr_on_depths_view_[0],
-          page_indices_on_depths_view_[0], length_info_on_depths_view_[0],
-          k_rope_pos_offset_view_[0], q_rope_position_map_view_, output, 
merged_attn_scores_view_,
-          /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, 
rotary_theta_,
-          attn_score_scaling_factor);
-    } else {
-      // Compute appended text self-attention
+
+    bool is_first_kernel = true;
+    if (!append_before_attn_) {
+      // The first part of attention, which only involves the q and the newly 
appended k/v.
+      is_first_kernel = false;
       if (is_chain_) {
         // If the batch does not form a tree, use raggedness prefill kernel.
         f_attention_prefill_ragged_(q_data, cur_append_length_indptr_view_, 
k_data, v_data,
@@ -2301,32 +2291,43 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
             merged_attn_scores_view_, /*rotary_mode=*/rope_mode_ == 
RoPEMode::kInline,
             rotary_scale_, rotary_theta_, attn_score_scaling_factor, 
cur_batch_size_);
       }
+    }
 
-      for (int d = 0; d < num_depths_; ++d) {
-        if (page_indices_on_depths_view_[d]->shape[0] == 0) {
-          continue;
-        }
-        if (use_decode_kernel_[d]) {
-          // Use decode kernel for depth d
-          f_decode(/*depth=*/d, q_data, pages_[local_layer_id], 
page_indptr_on_depths_view_[d],
-                   page_indices_on_depths_view_[d], 
length_info_on_depths_view_[d],
-                   k_rope_pos_offset_view_[d], q_rope_position_map_view_, 
temp_attn_output_view_,
-                   temp_attn_scores_view_,
-                   /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, 
rotary_scale_, rotary_theta_,
-                   attn_score_scaling_factor);
-        } else {
-          // Use prefill kernel for depth d
-          f_prefill(
-              /*depth=*/d, q_data, qo_indptr_on_depths_view_[d], 
pages_[local_layer_id],
-              page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d],
-              length_info_on_depths_view_[d], k_rope_pos_offset_view_[d], 
q_rope_position_map_view_,
-              temp_attn_output_view_, temp_attn_scores_view_,
-              /*causal=*/0,
-              /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, 
rotary_theta_,
-              attn_score_scaling_factor);
-        }
+    for (int d = 0; d < num_depths_; ++d) {
+      if (page_indices_on_depths_view_[d]->shape[0] == 0) {
+        continue;
+      }
+      NDArray attn_output;
+      NDArray attn_scores;
+      if (is_first_kernel) {
+        attn_output = output;
+        attn_scores = merged_attn_scores_view_;
+      } else {
+        attn_output = temp_attn_output_view_;
+        attn_scores = temp_attn_scores_view_;
+      }
+      if (use_decode_kernel_[d]) {
+        // Use decode kernel for depth d
+        f_decode(/*depth=*/d, q_data, pages_[local_layer_id], 
page_indptr_on_depths_view_[d],
+                 page_indices_on_depths_view_[d], 
length_info_on_depths_view_[d],
+                 k_rope_pos_offset_view_[d], q_rope_position_map_view_, 
attn_output, attn_scores,
+                 /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, 
rotary_scale_, rotary_theta_,
+                 attn_score_scaling_factor);
+      } else {
+        // Use prefill kernel for depth d
+        f_prefill(/*depth=*/d, q_data, qo_indptr_on_depths_view_[d], 
pages_[local_layer_id],
+                  page_indptr_on_depths_view_[d], 
page_indices_on_depths_view_[d],
+                  length_info_on_depths_view_[d], k_rope_pos_offset_view_[d],
+                  q_rope_position_map_view_, attn_output, attn_scores, 
/*causal=*/0,
+                  /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, 
rotary_scale_, rotary_theta_,
+                  attn_score_scaling_factor);
+      }
+
+      if (!is_first_kernel) {
         f_merge_inplace_(output, merged_attn_scores_view_, 
temp_attn_output_view_,
                          temp_attn_scores_view_);
+      } else {
+        is_first_kernel = false;
       }
     }
   }

Reply via email to