This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b87d1f9b01 [Runtime] Stateless interface of PagedKVCache leaf node
commit (#17057)
b87d1f9b01 is described below
commit b87d1f9b0124877769d537d4748c63546d2b2d8b
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Jun 2 07:43:09 2024 -0400
[Runtime] Stateless interface of PagedKVCache leaf node commit (#17057)
This PR changes the interface of the function
`CommitAcceptedTokenTreeNodeToKVCache` introduced recently for
PagedKVCache to a stateless interface. Previously the interface
was a stateful one, which made strong assumptions on the caller
side. This commit removes the assumption so that the interface
becomes less confusing.
---
src/runtime/relax_vm/kv_state.h | 4 +-
src/runtime/relax_vm/paged_kv_cache.cc | 177 +++++++++++++--------
...runtime_builtin_paged_attention_kv_cache_tir.py | 9 +-
3 files changed, 119 insertions(+), 71 deletions(-)
diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h
index 8de560f122..f4d6036b96 100644
--- a/src/runtime/relax_vm/kv_state.h
+++ b/src/runtime/relax_vm/kv_state.h
@@ -151,9 +151,11 @@ class AttentionKVCacheObj : public KVStateObj {
* The commit will update the KV cache, by compacting the KV data and discard
* the KV data of rejected tokens.
* This is a mandatory step when the BeginForward is given with a token tree.
+ * \param seq_ids The ids of the sequences to commit.
* \param leaf_indices The leaf token tree node index of each sequence.
*/
- virtual void CommitAcceptedTokenTreeNodes(const IntTuple& leaf_indices) = 0;
+ virtual void CommitAcceptedTokenTreeNodes(const IntTuple& seq_ids,
+ const IntTuple& leaf_indices) = 0;
/************** Attention **************/
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc
b/src/runtime/relax_vm/paged_kv_cache.cc
index a5b970e817..2fc5da78e9 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -151,6 +151,18 @@ struct Sequence {
*/
int last_block_attn_sink_size = 0;
+ /*! \brief Whether the current appended tokens form a chain (not a tree). */
+ bool is_chain = true;
+ /*! \brief The token tree parent pointer array of the current appended
tokens. */
+ std::vector<int32_t> token_tree_parent_ptr;
+ /*! \brief The depth of each node in the token tree. */
+ std::vector<int32_t> token_tree_node_depths;
+ /*!
+ * \brief A boolean denoting whether the accepted token tree indices of
+ * this sequence are committed
+ */
+ bool accepted_indices_committed = true;
+
explicit Sequence(std::vector<Block>* global_block_pool, int32_t
last_block_idx) {
++global_block_pool->at(last_block_idx).external_ref_cnt;
this->last_block_idx = last_block_idx;
@@ -879,10 +891,6 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
IntTuple cur_seq_ids_;
/*! \brief The append lengths of the sequences in the current round of
forwarding. */
IntTuple cur_append_lengths_;
- /*! \brief The token tree parent array of the sequences in the current round
of forwarding. */
- IntTuple cur_token_tree_parent_ptr_{nullptr};
- /*! \brief The depth of each node in the token tree, for the sequences in
the current batch. */
- std::vector<std::vector<int32_t>> cur_token_tree_node_depths_;
/*! \brief Whether the current batch of sequences are token chains (not
token trees). */
bool is_chain_;
/*! \brief Number of fork depth in the current round of forward. */
@@ -1187,6 +1195,9 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
<< "The forked position should be non-negative, or -1 for last
position as default.";
CHECK_LE(fork_pos, parent_it->second.seq_length)
<< "The forked position should not exceed the total length of parent
sequence.";
+ CHECK(parent_it->second.accepted_indices_committed)
+ << "The parent sequence's token tree computed in the last round of
forward has not been "
+ "committed with accepted nodes.";
int32_t child_block_idx = GetFreeBlock();
if (fork_pos == -1 || fork_pos == parent_it->second.seq_length) {
@@ -1434,10 +1445,6 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths,
const Optional<IntTuple>& opt_token_tree_parent_ptr) final
{
- CHECK(!cur_token_tree_parent_ptr_.defined())
- << "The last round of forward which involves token tree has not been
committed. Please "
- "call \"CommitAcceptedTreeNodes\" to commit the accepted tokens.";
-
CHECK_EQ(seq_ids.size(), append_lengths.size())
<< "The seq_ids size (" << seq_ids.size() << ") and append_lengths
size ("
<< append_lengths.size() << ") mismatch.";
@@ -1445,14 +1452,6 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
cur_seq_ids_ = seq_ids;
cur_append_lengths_ = append_lengths;
- // - Check token tree validity and process the token tree.
- is_chain_ = true;
- tree_attn_mask_host_.clear();
- tree_attn_mn_indptr_host_.clear();
- if (opt_token_tree_parent_ptr.defined()) {
- is_chain_ = ConstructTokenTreeMask(opt_token_tree_parent_ptr.value());
- }
-
// - Collect sequence/block/page information for attention.
std::vector<Sequence*> sequences;
std::vector<int32_t> last_block_length_before_append;
@@ -1474,6 +1473,29 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
}
}
+ // - Check token tree validity and process the token tree.
+ is_chain_ = true;
+ tree_attn_mask_host_.clear();
+ tree_attn_mn_indptr_host_.clear();
+ if (opt_token_tree_parent_ptr.defined()) {
+ is_chain_ = ConstructTokenTreeMask(sequences,
opt_token_tree_parent_ptr.value());
+ } else {
+ // The input batch does not form trees. So each sequence in the batch
+ // is required to have all past accepted tokens committed.
+ for (int i = 0; i < cur_batch_size_; ++i) {
+ Sequence* sequence = sequences[i];
+ CHECK(sequence->accepted_indices_committed)
+ << "The input batch does not form a tree, in which case the
sequences in the input "
+ "batch are expected to have their accepted tokens token tree
nodes committed. "
+ "Please invoke CommitAcceptedTokenTreeNodes for sequence "
+ << seq_ids[i];
+ sequence->is_chain = true;
+ sequence->token_tree_parent_ptr.clear();
+ sequence->token_tree_node_depths.clear();
+ }
+ is_chain_ = true;
+ }
+
std::vector<std::vector<int32_t>> block_ids_on_depths =
GetBlockIdsOnDepth(sequences);
num_depths_ = block_ids_on_depths.size();
ICHECK_LE(num_depths_, kPagedKVCacheMaxBlockDepth);
@@ -1559,7 +1581,7 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
for (int64_t pos = 0; pos < append_length; ++pos) {
q_rope_position_map_host_.push_back(
k_ragged_rope_pos_offset_host_[i] +
- (is_chain_ ? pos : cur_token_tree_node_depths_[i][pos]));
+ (is_chain_ ? pos : sequences[i]->token_tree_node_depths[pos]));
int32_t pos_in_block = block.seq_length - append_length + pos;
if (last_block_length_before_append[i] + pos < block.sink_length) {
@@ -1649,19 +1671,26 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
}
}
- void CommitAcceptedTokenTreeNodes(const IntTuple& leaf_indices) final {
- CHECK_NE(cur_batch_size_, -1)
- << "Cannot commit accepted token tree nodes since BeginForward is not
invoked.";
- CHECK_EQ(leaf_indices.size(), cur_batch_size_)
- << "The number of input leaf indices does not equal to the current
batch size.";
+ void CommitAcceptedTokenTreeNodes(const IntTuple& seq_ids, const IntTuple&
leaf_indices) final {
+ CHECK_EQ(seq_ids.size(), leaf_indices.size())
+ << "The given seq_ids and leaf_indices have different size.";
+ int num_seq_to_commit = seq_ids.size();
- for (int i = 0; i < cur_batch_size_; ++i) {
- CHECK_GE(leaf_indices[i], 0)
- << "Invalid tree index " << leaf_indices[i] << " which is negative";
- CHECK_LT(leaf_indices[i], cur_append_lengths_[i])
+ std::vector<Sequence*> sequences;
+ sequences.reserve(num_seq_to_commit);
+ for (int i = 0; i < num_seq_to_commit; ++i) {
+ auto it = seq_map_.find(seq_ids[i]);
+ CHECK(it != seq_map_.end()) << "The sequence \"" << seq_ids[i]
+ << "\" cannot be found in KV cache.";
+ sequences.push_back(&it->second);
+ CHECK(!it->second.accepted_indices_committed)
+ << "The accepted nodes of sequence " << seq_ids[i] << " are already
committed.";
+ CHECK_GE(leaf_indices[i], -1)
+ << "Invalid tree index " << leaf_indices[i] << " which is less than
-1";
+ CHECK_LT(leaf_indices[i],
static_cast<int64_t>(it->second.token_tree_parent_ptr.size()))
<< "Invalid tree index " << leaf_indices[i]
- << " which is larger than or equals to the append length " <<
cur_append_lengths_[i]
- << " of the sequence";
+ << " which is larger than or equals to the append length "
+ << it->second.token_tree_parent_ptr.size() << " of the sequence";
}
if (!is_chain_) {
@@ -1670,16 +1699,21 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
commit_copy_dst_pos_in_page_table_host_.clear();
commit_copy_length_indptr_host_.push_back(0);
- for (int i = 0; i < cur_batch_size_; ++i) {
+ for (int i = 0; i < num_seq_to_commit; ++i) {
+ if (leaf_indices[i] == -1) {
+ // No node is accepted. All nodes in the token tree need to be
popped.
+ continue;
+ }
+
// Get the accepted node path on the token tree.
std::vector<int32_t> path_on_tree;
- path_on_tree.reserve(cur_token_tree_node_depths_[i][leaf_indices[i]] +
1);
+
path_on_tree.reserve(sequences[i]->token_tree_node_depths[leaf_indices[i]] + 1);
int node = leaf_indices[i];
while (node != -1) {
path_on_tree.push_back(node);
- node = cur_token_tree_parent_ptr_[cur_append_lengths_indptr_host_[i]
+ node];
+ node = sequences[i]->token_tree_parent_ptr[node];
}
- ICHECK_EQ(path_on_tree.size(),
cur_token_tree_node_depths_[i][leaf_indices[i]] + 1);
+ ICHECK_EQ(path_on_tree.size(),
sequences[i]->token_tree_node_depths[leaf_indices[i]] + 1);
// Get the destination array (range [0, path_length - 1)) of KV cache
copy.
std::vector<int32_t> copy_dst_pos_in_seq;
copy_dst_pos_in_seq.resize(path_on_tree.size());
@@ -1714,14 +1748,16 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
// Note: Function "PopN" only changes the page table structure and does
not
// change the KV cache data. Therefore, we can directly use it,
since
// we have already launched all copies.
- for (int i = 0; i < cur_batch_size_; ++i) {
+ for (int i = 0; i < num_seq_to_commit; ++i) {
int64_t length_to_pop =
- cur_append_lengths_[i] -
cur_token_tree_node_depths_[i][leaf_indices[i]] - 1;
+ cur_append_lengths_[i] -
+ (leaf_indices[i] != -1 ?
(sequences[i]->token_tree_node_depths[leaf_indices[i]] + 1) : 0);
PopN(cur_seq_ids_[i], length_to_pop);
+ // Reset the sequence states.
+ sequences[i]->accepted_indices_committed = true;
+ sequences[i]->token_tree_parent_ptr.clear();
+ sequences[i]->token_tree_node_depths.clear();
}
-
- // Reset the token tree.
- cur_token_tree_parent_ptr_ = IntTuple{nullptr};
}
NDArray GetQueryPositions() final {
@@ -1814,57 +1850,67 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
return block_idx;
}
- bool ConstructTokenTreeMask(const IntTuple& token_tree_parent_ptr) {
+ bool ConstructTokenTreeMask(const std::vector<Sequence*>& sequences,
+ const IntTuple& token_tree_parent_ptr) {
// We check if the token tree deteriorates to a chain,
// because chain cases can have simplified attention work flow.
bool is_chain = true;
- cur_token_tree_parent_ptr_ = token_tree_parent_ptr;
- cur_token_tree_node_depths_.clear();
- cur_token_tree_node_depths_.reserve(cur_batch_size_);
-
- int64_t sum_append_length = 0;
+ int64_t sum_new_append_length = 0;
// - Construct the mn indptr array, which is the indptr of the mask size
of each sequence.
tree_attn_mn_indptr_host_.push_back(0);
- for (int64_t append_length : cur_append_lengths_) {
- sum_append_length += append_length;
+ ICHECK_EQ(sequences.size(), cur_batch_size_);
+ ICHECK_EQ(cur_append_lengths_.size(), cur_batch_size_);
+ for (int i = 0; i < cur_batch_size_; ++i) {
+ int64_t append_length = cur_append_lengths_[i];
+ // Update the token tree parent pointers.
+ sequences[i]->token_tree_parent_ptr = {
+ token_tree_parent_ptr->data + sum_new_append_length,
+ token_tree_parent_ptr->data + sum_new_append_length +
cur_append_lengths_[i]};
+ sum_new_append_length += cur_append_lengths_[i];
+
+ CHECK_LE(append_length, kTreeAttnMaxTreeSize)
+ << "The tree size is " << append_length << " which exceeds the
maximum tree size limit "
+ << kTreeAttnMaxTreeSize;
tree_attn_mn_indptr_host_.push_back(tree_attn_mn_indptr_host_.back() +
- static_cast<int32_t>(append_length *
append_length));
+ append_length * append_length);
}
- CHECK_EQ(token_tree_parent_ptr.size(), sum_append_length)
- << "Invalid token tree size. The sum of \"append_lengths\" is " <<
sum_append_length
+ CHECK_EQ(token_tree_parent_ptr.size(), sum_new_append_length)
+ << "Invalid token tree size. The sum of \"append_lengths\" is " <<
sum_new_append_length
<< " while there are " << token_tree_parent_ptr.size()
<< " elements in \"token_tree_parent_ptr\".";
// - Construct the mask of each sequence.
- int processed_pos = 0;
for (int i = 0; i < cur_batch_size_; ++i) {
- int64_t append_length = cur_append_lengths_[i];
+ int64_t tree_size = sequences[i]->token_tree_parent_ptr.size();
std::vector<std::vector<int32_t>> mask;
std::vector<int32_t> depth;
- mask.reserve(append_length);
- depth.reserve(append_length);
- for (int64_t n = 0; n < append_length; ++n) {
- CHECK_LT(token_tree_parent_ptr[processed_pos], n)
+ mask.reserve(tree_size);
+ depth.reserve(tree_size);
+ sequences[i]->is_chain = true;
+ sequences[i]->accepted_indices_committed = false;
+ for (int64_t n = 0; n < tree_size; ++n) {
+ CHECK_LT(sequences[i]->token_tree_parent_ptr[n], n)
<< "Invalid token tree. The parent of node " << n << " in tree "
<< i << " is "
- << token_tree_parent_ptr[processed_pos] << ", which is not smaller
than " << n;
- CHECK_GE(token_tree_parent_ptr[processed_pos], -1)
+ << sequences[i]->token_tree_parent_ptr[n] << ", which is not
smaller than " << n;
+ CHECK_GE(sequences[i]->token_tree_parent_ptr[n], -1)
<< "Invalid token tree. The parent of node " << n << " in tree "
<< i << " is "
- << token_tree_parent_ptr[processed_pos];
- if (token_tree_parent_ptr[processed_pos] != n - 1) {
+ << sequences[i]->token_tree_parent_ptr[n];
+ if (sequences[i]->token_tree_parent_ptr[n] != n - 1) {
// The parent of the current node is not the last node.
// Therefore the tree is not a chain.
+ sequences[i]->is_chain = false;
is_chain = false;
}
std::vector<int32_t> single_pos_mask;
- if (token_tree_parent_ptr[processed_pos] != -1) {
+ if (sequences[i]->token_tree_parent_ptr[n] != -1) {
// The current node has a parent in the token tree.
- single_pos_mask =
{mask[token_tree_parent_ptr[processed_pos]].begin(),
- mask[token_tree_parent_ptr[processed_pos]].end()};
- depth.push_back(depth[token_tree_parent_ptr[processed_pos]] + 1);
+ single_pos_mask =
{mask[sequences[i]->token_tree_parent_ptr[n]].begin(),
+
mask[sequences[i]->token_tree_parent_ptr[n]].end()};
+ depth.push_back(depth[sequences[i]->token_tree_parent_ptr[n]] + 1);
} else {
// The current node is root in the token tree.
- single_pos_mask.resize(append_length, /*value=*/0);
+ single_pos_mask.resize(tree_size, /*value=*/0);
depth.push_back(0);
}
single_pos_mask[n] = 1;
@@ -1872,12 +1918,9 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
for (int32_t mask_val : single_pos_mask) {
tree_attn_mask_host_.push_back(mask_val);
}
-
- ++processed_pos;
}
- cur_token_tree_node_depths_.push_back(std::move(depth));
+ sequences[i]->token_tree_node_depths = std::move(depth);
}
-
return is_chain;
}
diff --git
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index 0a69d184e5..c5c88211ba 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -438,7 +438,10 @@ def apply_attention(
fend_forward(kv_cache)
if accepted_leaf_indices is not None:
- fcommit_accepted_token_tree_nodes(kv_cache,
ShapeTuple(accepted_leaf_indices))
+ seq_ids = [seq_id for seq_id, _ in batch]
+ fcommit_accepted_token_tree_nodes(
+ kv_cache, ShapeTuple(seq_ids), ShapeTuple(accepted_leaf_indices)
+ )
for i, (accepted_leaf_idx, (seq_id, append_length)) in enumerate(
zip(accepted_leaf_indices, batch)
):
@@ -449,7 +452,7 @@ def apply_attention(
node = token_tree_parent_ptr_list[i][node]
offset = cached_k[seq_id].shape[1] - append_length
length_to_pop = append_length - len(tree_path)
- assert 0 <= length_to_pop < append_length
+ assert 0 <= length_to_pop <= append_length
for dst_pos, src_pos in enumerate(reversed(tree_path)):
if dst_pos == src_pos:
continue
@@ -773,7 +776,7 @@ def
test_paged_attention_kv_cache_tree_attn(kv_cache_and_config):
[-1, 0, 1, 2, 3, 4, 5, 6, 7, 8], # chain of length 10
[-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], # chain of length
14
],
- accepted_leaf_indices=[2, 6, 6, 4],
+ accepted_leaf_indices=[2, 6, -1, 4],
)
# Do 5 rounds of decode.
for _ in range(5):