This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b87d1f9b01 [Runtime] Stateless interface of PagedKVCache leaf node commit (#17057)
b87d1f9b01 is described below

commit b87d1f9b0124877769d537d4748c63546d2b2d8b
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Jun 2 07:43:09 2024 -0400

    [Runtime] Stateless interface of PagedKVCache leaf node commit (#17057)
    
    This PR changes the interface of the function
    `CommitAcceptedTokenTreeNodes`, introduced recently for
    PagedKVCache, to a stateless one. Previously the interface
    was stateful, which made strong assumptions about the caller
    side. This commit removes those assumptions so that the interface
    becomes less confusing.
---
 src/runtime/relax_vm/kv_state.h                    |   4 +-
 src/runtime/relax_vm/paged_kv_cache.cc             | 177 +++++++++++++--------
 ...runtime_builtin_paged_attention_kv_cache_tir.py |   9 +-
 3 files changed, 119 insertions(+), 71 deletions(-)

diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h
index 8de560f122..f4d6036b96 100644
--- a/src/runtime/relax_vm/kv_state.h
+++ b/src/runtime/relax_vm/kv_state.h
@@ -151,9 +151,11 @@ class AttentionKVCacheObj : public KVStateObj {
    * The commit will update the KV cache, by compacting the KV data and discard
    * the KV data of rejected tokens.
    * This is a mandatory step when the BeginForward is given with a token tree.
+   * \param seq_ids The ids of the sequences to commit.
    * \param leaf_indices The leaf token tree node index of each sequence.
    */
-  virtual void CommitAcceptedTokenTreeNodes(const IntTuple& leaf_indices) = 0;
+  virtual void CommitAcceptedTokenTreeNodes(const IntTuple& seq_ids,
+                                            const IntTuple& leaf_indices) = 0;
 
   /************** Attention **************/
 
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc 
b/src/runtime/relax_vm/paged_kv_cache.cc
index a5b970e817..2fc5da78e9 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -151,6 +151,18 @@ struct Sequence {
    */
   int last_block_attn_sink_size = 0;
 
+  /*! \brief Whether the current appended tokens form a chain (not a tree). */
+  bool is_chain = true;
+  /*! \brief The token tree parent pointer array of the current appended 
tokens. */
+  std::vector<int32_t> token_tree_parent_ptr;
+  /*! \brief The depth of each node in the token tree. */
+  std::vector<int32_t> token_tree_node_depths;
+  /*!
+   * \brief A boolean denoting whether the accepted token tree indices of
+   * this sequence are committed
+   */
+  bool accepted_indices_committed = true;
+
   explicit Sequence(std::vector<Block>* global_block_pool, int32_t 
last_block_idx) {
     ++global_block_pool->at(last_block_idx).external_ref_cnt;
     this->last_block_idx = last_block_idx;
@@ -879,10 +891,6 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
   IntTuple cur_seq_ids_;
   /*! \brief The append lengths of the sequences in the current round of 
forwarding. */
   IntTuple cur_append_lengths_;
-  /*! \brief The token tree parent array of the sequences in the current round 
of forwarding. */
-  IntTuple cur_token_tree_parent_ptr_{nullptr};
-  /*! \brief The depth of each node in the token tree, for the sequences in 
the current batch. */
-  std::vector<std::vector<int32_t>> cur_token_tree_node_depths_;
   /*! \brief Whether the current batch of sequences are token chains (not 
token trees). */
   bool is_chain_;
   /*! \brief Number of fork depth in the current round of forward. */
@@ -1187,6 +1195,9 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
         << "The forked position should be non-negative, or -1 for last 
position as default.";
     CHECK_LE(fork_pos, parent_it->second.seq_length)
         << "The forked position should not exceed the total length of parent 
sequence.";
+    CHECK(parent_it->second.accepted_indices_committed)
+        << "The parent sequence's token tree computed in the last round of 
forward has not been "
+           "committed with accepted nodes.";
 
     int32_t child_block_idx = GetFreeBlock();
     if (fork_pos == -1 || fork_pos == parent_it->second.seq_length) {
@@ -1434,10 +1445,6 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
 
   void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths,
                     const Optional<IntTuple>& opt_token_tree_parent_ptr) final 
{
-    CHECK(!cur_token_tree_parent_ptr_.defined())
-        << "The last round of forward which involves token tree has not been 
committed. Please "
-           "call \"CommitAcceptedTreeNodes\" to commit the accepted tokens.";
-
     CHECK_EQ(seq_ids.size(), append_lengths.size())
         << "The seq_ids size (" << seq_ids.size() << ") and append_lengths 
size ("
         << append_lengths.size() << ") mismatch.";
@@ -1445,14 +1452,6 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     cur_seq_ids_ = seq_ids;
     cur_append_lengths_ = append_lengths;
 
-    // - Check token tree validity and process the token tree.
-    is_chain_ = true;
-    tree_attn_mask_host_.clear();
-    tree_attn_mn_indptr_host_.clear();
-    if (opt_token_tree_parent_ptr.defined()) {
-      is_chain_ = ConstructTokenTreeMask(opt_token_tree_parent_ptr.value());
-    }
-
     // - Collect sequence/block/page information for attention.
     std::vector<Sequence*> sequences;
     std::vector<int32_t> last_block_length_before_append;
@@ -1474,6 +1473,29 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
       }
     }
 
+    // - Check token tree validity and process the token tree.
+    is_chain_ = true;
+    tree_attn_mask_host_.clear();
+    tree_attn_mn_indptr_host_.clear();
+    if (opt_token_tree_parent_ptr.defined()) {
+      is_chain_ = ConstructTokenTreeMask(sequences, 
opt_token_tree_parent_ptr.value());
+    } else {
+      // The input batch does not form trees. So each sequence in the batch
+      // is required to have all past accepted tokens committed.
+      for (int i = 0; i < cur_batch_size_; ++i) {
+        Sequence* sequence = sequences[i];
+        CHECK(sequence->accepted_indices_committed)
+            << "The input batch does not form a tree, in which case the 
sequences in the input "
+               "batch are expected to have their accepted tokens token tree 
nodes committed. "
+               "Please invoke CommitAcceptedTokenTreeNodes for sequence "
+            << seq_ids[i];
+        sequence->is_chain = true;
+        sequence->token_tree_parent_ptr.clear();
+        sequence->token_tree_node_depths.clear();
+      }
+      is_chain_ = true;
+    }
+
     std::vector<std::vector<int32_t>> block_ids_on_depths = 
GetBlockIdsOnDepth(sequences);
     num_depths_ = block_ids_on_depths.size();
     ICHECK_LE(num_depths_, kPagedKVCacheMaxBlockDepth);
@@ -1559,7 +1581,7 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
       for (int64_t pos = 0; pos < append_length; ++pos) {
         q_rope_position_map_host_.push_back(
             k_ragged_rope_pos_offset_host_[i] +
-            (is_chain_ ? pos : cur_token_tree_node_depths_[i][pos]));
+            (is_chain_ ? pos : sequences[i]->token_tree_node_depths[pos]));
 
         int32_t pos_in_block = block.seq_length - append_length + pos;
         if (last_block_length_before_append[i] + pos < block.sink_length) {
@@ -1649,19 +1671,26 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     }
   }
 
-  void CommitAcceptedTokenTreeNodes(const IntTuple& leaf_indices) final {
-    CHECK_NE(cur_batch_size_, -1)
-        << "Cannot commit accepted token tree nodes since BeginForward is not 
invoked.";
-    CHECK_EQ(leaf_indices.size(), cur_batch_size_)
-        << "The number of input leaf indices does not equal to the current 
batch size.";
+  void CommitAcceptedTokenTreeNodes(const IntTuple& seq_ids, const IntTuple& 
leaf_indices) final {
+    CHECK_EQ(seq_ids.size(), leaf_indices.size())
+        << "The given seq_ids and leaf_indices have different size.";
+    int num_seq_to_commit = seq_ids.size();
 
-    for (int i = 0; i < cur_batch_size_; ++i) {
-      CHECK_GE(leaf_indices[i], 0)
-          << "Invalid tree index " << leaf_indices[i] << " which is negative";
-      CHECK_LT(leaf_indices[i], cur_append_lengths_[i])
+    std::vector<Sequence*> sequences;
+    sequences.reserve(num_seq_to_commit);
+    for (int i = 0; i < num_seq_to_commit; ++i) {
+      auto it = seq_map_.find(seq_ids[i]);
+      CHECK(it != seq_map_.end()) << "The sequence \"" << seq_ids[i]
+                                  << "\" cannot be found in KV cache.";
+      sequences.push_back(&it->second);
+      CHECK(!it->second.accepted_indices_committed)
+          << "The accepted nodes of sequence " << seq_ids[i] << " are already 
committed.";
+      CHECK_GE(leaf_indices[i], -1)
+          << "Invalid tree index " << leaf_indices[i] << " which is less than 
-1";
+      CHECK_LT(leaf_indices[i], 
static_cast<int64_t>(it->second.token_tree_parent_ptr.size()))
           << "Invalid tree index " << leaf_indices[i]
-          << " which is larger than or equals to the append length " << 
cur_append_lengths_[i]
-          << " of the sequence";
+          << " which is larger than or equals to the append length "
+          << it->second.token_tree_parent_ptr.size() << " of the sequence";
     }
 
     if (!is_chain_) {
@@ -1670,16 +1699,21 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
       commit_copy_dst_pos_in_page_table_host_.clear();
       commit_copy_length_indptr_host_.push_back(0);
 
-      for (int i = 0; i < cur_batch_size_; ++i) {
+      for (int i = 0; i < num_seq_to_commit; ++i) {
+        if (leaf_indices[i] == -1) {
+          // No node is accepted. All nodes in the token tree need to be 
popped.
+          continue;
+        }
+
         // Get the accepted node path on the token tree.
         std::vector<int32_t> path_on_tree;
-        path_on_tree.reserve(cur_token_tree_node_depths_[i][leaf_indices[i]] + 
1);
+        
path_on_tree.reserve(sequences[i]->token_tree_node_depths[leaf_indices[i]] + 1);
         int node = leaf_indices[i];
         while (node != -1) {
           path_on_tree.push_back(node);
-          node = cur_token_tree_parent_ptr_[cur_append_lengths_indptr_host_[i] 
+ node];
+          node = sequences[i]->token_tree_parent_ptr[node];
         }
-        ICHECK_EQ(path_on_tree.size(), 
cur_token_tree_node_depths_[i][leaf_indices[i]] + 1);
+        ICHECK_EQ(path_on_tree.size(), 
sequences[i]->token_tree_node_depths[leaf_indices[i]] + 1);
         // Get the destination array (range [0, path_length - 1)) of KV cache 
copy.
         std::vector<int32_t> copy_dst_pos_in_seq;
         copy_dst_pos_in_seq.resize(path_on_tree.size());
@@ -1714,14 +1748,16 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     //   Note: Function "PopN" only changes the page table structure and does 
not
     //         change the KV cache data. Therefore, we can directly use it, 
since
     //         we have already launched all copies.
-    for (int i = 0; i < cur_batch_size_; ++i) {
+    for (int i = 0; i < num_seq_to_commit; ++i) {
       int64_t length_to_pop =
-          cur_append_lengths_[i] - 
cur_token_tree_node_depths_[i][leaf_indices[i]] - 1;
+          cur_append_lengths_[i] -
+          (leaf_indices[i] != -1 ? 
(sequences[i]->token_tree_node_depths[leaf_indices[i]] + 1) : 0);
       PopN(cur_seq_ids_[i], length_to_pop);
+      // Reset the sequence states.
+      sequences[i]->accepted_indices_committed = true;
+      sequences[i]->token_tree_parent_ptr.clear();
+      sequences[i]->token_tree_node_depths.clear();
     }
-
-    // Reset the token tree.
-    cur_token_tree_parent_ptr_ = IntTuple{nullptr};
   }
 
   NDArray GetQueryPositions() final {
@@ -1814,57 +1850,67 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     return block_idx;
   }
 
-  bool ConstructTokenTreeMask(const IntTuple& token_tree_parent_ptr) {
+  bool ConstructTokenTreeMask(const std::vector<Sequence*>& sequences,
+                              const IntTuple& token_tree_parent_ptr) {
     // We check if the token tree deteriorates to a chain,
     // because chain cases can have simplified attention work flow.
     bool is_chain = true;
-    cur_token_tree_parent_ptr_ = token_tree_parent_ptr;
-    cur_token_tree_node_depths_.clear();
-    cur_token_tree_node_depths_.reserve(cur_batch_size_);
-
-    int64_t sum_append_length = 0;
+    int64_t sum_new_append_length = 0;
     // - Construct the mn indptr array, which is the indptr of the mask size 
of each sequence.
     tree_attn_mn_indptr_host_.push_back(0);
-    for (int64_t append_length : cur_append_lengths_) {
-      sum_append_length += append_length;
+    ICHECK_EQ(sequences.size(), cur_batch_size_);
+    ICHECK_EQ(cur_append_lengths_.size(), cur_batch_size_);
+    for (int i = 0; i < cur_batch_size_; ++i) {
+      int64_t append_length = cur_append_lengths_[i];
+      // Update the token tree parent pointers.
+      sequences[i]->token_tree_parent_ptr = {
+          token_tree_parent_ptr->data + sum_new_append_length,
+          token_tree_parent_ptr->data + sum_new_append_length + 
cur_append_lengths_[i]};
+      sum_new_append_length += cur_append_lengths_[i];
+
+      CHECK_LE(append_length, kTreeAttnMaxTreeSize)
+          << "The tree size is " << append_length << " which exceeds the 
maximum tree size limit "
+          << kTreeAttnMaxTreeSize;
       tree_attn_mn_indptr_host_.push_back(tree_attn_mn_indptr_host_.back() +
-                                          static_cast<int32_t>(append_length * 
append_length));
+                                          append_length * append_length);
     }
-    CHECK_EQ(token_tree_parent_ptr.size(), sum_append_length)
-        << "Invalid token tree size. The sum of \"append_lengths\" is " << 
sum_append_length
+    CHECK_EQ(token_tree_parent_ptr.size(), sum_new_append_length)
+        << "Invalid token tree size. The sum of \"append_lengths\" is " << 
sum_new_append_length
         << " while there are " << token_tree_parent_ptr.size()
         << " elements in \"token_tree_parent_ptr\".";
 
     // - Construct the mask of each sequence.
-    int processed_pos = 0;
     for (int i = 0; i < cur_batch_size_; ++i) {
-      int64_t append_length = cur_append_lengths_[i];
+      int64_t tree_size = sequences[i]->token_tree_parent_ptr.size();
       std::vector<std::vector<int32_t>> mask;
       std::vector<int32_t> depth;
-      mask.reserve(append_length);
-      depth.reserve(append_length);
-      for (int64_t n = 0; n < append_length; ++n) {
-        CHECK_LT(token_tree_parent_ptr[processed_pos], n)
+      mask.reserve(tree_size);
+      depth.reserve(tree_size);
+      sequences[i]->is_chain = true;
+      sequences[i]->accepted_indices_committed = false;
+      for (int64_t n = 0; n < tree_size; ++n) {
+        CHECK_LT(sequences[i]->token_tree_parent_ptr[n], n)
             << "Invalid token tree. The parent of node " << n << " in tree " 
<< i << " is "
-            << token_tree_parent_ptr[processed_pos] << ", which is not smaller 
than " << n;
-        CHECK_GE(token_tree_parent_ptr[processed_pos], -1)
+            << sequences[i]->token_tree_parent_ptr[n] << ", which is not 
smaller than " << n;
+        CHECK_GE(sequences[i]->token_tree_parent_ptr[n], -1)
             << "Invalid token tree. The parent of node " << n << " in tree " 
<< i << " is "
-            << token_tree_parent_ptr[processed_pos];
-        if (token_tree_parent_ptr[processed_pos] != n - 1) {
+            << sequences[i]->token_tree_parent_ptr[n];
+        if (sequences[i]->token_tree_parent_ptr[n] != n - 1) {
           // The parent of the current node is not the last node.
           // Therefore the tree is not a chain.
+          sequences[i]->is_chain = false;
           is_chain = false;
         }
 
         std::vector<int32_t> single_pos_mask;
-        if (token_tree_parent_ptr[processed_pos] != -1) {
+        if (sequences[i]->token_tree_parent_ptr[n] != -1) {
           // The current node has a parent in the token tree.
-          single_pos_mask = 
{mask[token_tree_parent_ptr[processed_pos]].begin(),
-                             mask[token_tree_parent_ptr[processed_pos]].end()};
-          depth.push_back(depth[token_tree_parent_ptr[processed_pos]] + 1);
+          single_pos_mask = 
{mask[sequences[i]->token_tree_parent_ptr[n]].begin(),
+                             
mask[sequences[i]->token_tree_parent_ptr[n]].end()};
+          depth.push_back(depth[sequences[i]->token_tree_parent_ptr[n]] + 1);
         } else {
           // The current node is root in the token tree.
-          single_pos_mask.resize(append_length, /*value=*/0);
+          single_pos_mask.resize(tree_size, /*value=*/0);
           depth.push_back(0);
         }
         single_pos_mask[n] = 1;
@@ -1872,12 +1918,9 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
         for (int32_t mask_val : single_pos_mask) {
           tree_attn_mask_host_.push_back(mask_val);
         }
-
-        ++processed_pos;
       }
-      cur_token_tree_node_depths_.push_back(std::move(depth));
+      sequences[i]->token_tree_node_depths = std::move(depth);
     }
-
     return is_chain;
   }
 
diff --git 
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py 
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index 0a69d184e5..c5c88211ba 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -438,7 +438,10 @@ def apply_attention(
     fend_forward(kv_cache)
 
     if accepted_leaf_indices is not None:
-        fcommit_accepted_token_tree_nodes(kv_cache, 
ShapeTuple(accepted_leaf_indices))
+        seq_ids = [seq_id for seq_id, _ in batch]
+        fcommit_accepted_token_tree_nodes(
+            kv_cache, ShapeTuple(seq_ids), ShapeTuple(accepted_leaf_indices)
+        )
         for i, (accepted_leaf_idx, (seq_id, append_length)) in enumerate(
             zip(accepted_leaf_indices, batch)
         ):
@@ -449,7 +452,7 @@ def apply_attention(
                 node = token_tree_parent_ptr_list[i][node]
             offset = cached_k[seq_id].shape[1] - append_length
             length_to_pop = append_length - len(tree_path)
-            assert 0 <= length_to_pop < append_length
+            assert 0 <= length_to_pop <= append_length
             for dst_pos, src_pos in enumerate(reversed(tree_path)):
                 if dst_pos == src_pos:
                     continue
@@ -773,7 +776,7 @@ def 
test_paged_attention_kv_cache_tree_attn(kv_cache_and_config):
             [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8],  # chain of length 10
             [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],  # chain of length 
14
         ],
-        accepted_leaf_indices=[2, 6, 6, 4],
+        accepted_leaf_indices=[2, 6, -1, 4],
     )
     # Do 5 rounds of decode.
     for _ in range(5):

Reply via email to