tqchen commented on code in PR #15910: URL: https://github.com/apache/tvm/pull/15910#discussion_r1356882128
########## src/runtime/relax_vm/paged_kv_cache.cc: ########## @@ -0,0 +1,633 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/runtime/relax_vm/paged_kv_cache.cc + * \brief Runtime paged KV cache object for language models. + */ +#include <tvm/runtime/device_api.h> +#include <tvm/runtime/logging.h> +#include <tvm/runtime/ndarray.h> +#include <tvm/runtime/registry.h> + +namespace tvm { +namespace runtime { +namespace relax_vm { + +//------------------------------------------- +// We keep the implementation private as +// it may be subject to future changes. +// +// Users can interact with it through the +// runtime API function calls +//------------------------------------------- + +/*! + * \brief The paged KV cache for attention. + * - It supports managing the K/V data of **multiple sequences**. + * - It manages K/V values by doing paging along the sequence-length + * dimension with a configured page size. + * - The basic example use of the paged KV cache after initialization + * in each round of model forwarding is the following: + * - step 1. use `ResetAppendLengths` to reset the appending information + * for preparation, + * - step 2. 
use `ReserveExtraLengthForAppend` to specify the length + * of K/V data to be appended for each sequence, + * - step 3. use `SyncAuxArrayToDevice` to synchronize auxiliary arrays + * to device for append/attention computation, + * - step 4. for each layer, use `Append` to append the K/V data to the + * cache, and then use `Attention` to compute attention results with + * Q data. + */ +class PagedAttentionKVCacheObj : public Object { + private: + /*! \brief The total number of sequences managed in the KV cache. */ + int64_t num_total_seqs_ = 0; + /*! \brief The number of pages that are in use by the sequences. */ + int64_t num_pages_in_use_ = 0; + /*! + * \brief The number of allocated pages, including the in-use pages + * and the pages released due to sequence removal. + */ + int64_t num_pages_allocated_ = 0; + + /********************* Configuration *********************/ + + /*! \brief The page size (the sequence length each page manages) of the cache. */ + const int64_t page_size_; + /*! \brief The number of layers in the model. */ + const int64_t num_layers_; + /*! \brief The number of heads in the model. */ + const int64_t num_heads_; + /*! \brief The number of features each head has. */ + const int64_t head_dim_; + + /*! \brief We fix int32 to be the index dtype of auxiliary data. */ + const DLDataType dtype_aux_ = DataType::Int(32, 1).operator DLDataType(); + + /********************* Page Structures *********************/ + + /*! + * \brief The KV data managed by the KV cache. + * It has layout (num_pages, num_layers, 2, num_heads, page_size, head_dim). + * Along the "2" dimension, index 0 stands for K and 1 stands for V. + */ + NDArray pages_; + /*! \brief The list of ids of released pages for page reuse. */ + std::vector<int32_t> free_page_ids_; + + /*! \brief The list of page ids assigned for each sequence in the cache. */ + std::vector<std::vector<int32_t>> page_table_; + /*! \brief The lengths of each sequence in the cache. 
*/ + std::vector<int32_t> seq_lengths_; + + /********************* Current Batch Info *********************/ + + /*! + * \brief The current lengths to append for each sequence. + * - The new K/V data appended to the cache must have the same length + * as stored in this array. + * - The Q data passed in for attention must also have the same length + * as stored. + * \note Invoke "ResetAppendLengths" to reset this array to all-zero. + */ + std::vector<int64_t> cur_append_lengths_; + + /********************* Auxiliary Arrays on Device *********************/ + //------------------------------------------- + // The following fields are auxiliary arrays on device. + // All of them are directly derivable from the fields above. + // We store them for efficient execution of attentions, + // cache append, etc. + //------------------------------------------- + /*! + * \brief A boolean flag indicating if the auxiliary arrays are dirty. + * If it is dirty, an explicit "SyncAuxArrayToDevice" should be invoked. + */ + bool dirty_aux_data_device_ = false; + /*! + * \brief The page table indptr array on device. + * \note Since page table is a ragged data structure, we represent it + * in CSR format (which uses indptr and values below) on device. + */ + NDArray page_table_indptr_device_; + /*! \brief The page table value array on device. */ + NDArray page_table_values_device_; + /*! + * \brief The array storing "the number of used slots in the last page + * of each sequence" on device. Its values range in (0, page_size_]. + */ + NDArray last_page_offset_device_; + /*! + * \brief The append_length indptr array on device. + * \note Since the Q/K/V data may have raggedness in terms of lengths, + * we represent the append lengths in CSR format. + */ + NDArray cur_append_length_indptr_device_; + /*! + * \brief The corresponding sequence id for each position along the + * length dimension of K/V data. It is used for efficient computation. 
+ */ + NDArray cur_pos2seqid_device_; + + public: + /*! \brief Constructor. Take the cache configuration and initialize the NDArrays. */ + explicit PagedAttentionKVCacheObj(int64_t page_size, int64_t num_layers, int64_t num_heads, + int64_t head_dim, int64_t reserved_num_seqs, + int64_t reserved_num_pages, DLDataType dtype, DLDevice device) + : page_size_(page_size), num_layers_(num_layers), num_heads_(num_heads), head_dim_(head_dim) { + pages_ = NDArray::Empty({reserved_num_pages, num_layers, 2, num_heads, page_size, head_dim}, + dtype, device); + page_table_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device); + page_table_values_device_ = NDArray::Empty({reserved_num_pages}, dtype_aux_, device); + last_page_offset_device_ = NDArray::Empty({reserved_num_seqs}, dtype_aux_, device); + cur_append_length_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device); + cur_pos2seqid_device_ = NDArray::Empty({reserved_num_pages * page_size}, dtype_aux_, device); + } + + /*! + * \brief Given a sequence id and a required extra length, allocate + * new pages for the sequence until the total capacity can cover the + * current sequence length plus the required extra length. + * \note By reserving extra length for the sequence, the subsequent appended + * K/V data for the sequence must have the same length as the input length here. + * \param seq_id The id of the sequence to process. + * \param extra_length The extra length to reserve for the sequence. + */ + void ReserveExtraLengthForAppend(int64_t seq_id, int64_t extra_length) { + CHECK_GE(seq_id, 0) << "Input sequence id should be positive"; + CHECK_LE(seq_id, num_total_seqs_) + << "Invalid input sequence id " << seq_id + << ", which is larger than the current total number of sequences."; + CHECK_GT(extra_length, 0) << "The input length should be positive."; + + if (seq_id == num_total_seqs_) { + // Initialize if it is a new sequence. 
+ page_table_.push_back({}); + seq_lengths_.push_back(0); + cur_append_lengths_.push_back(0); + ++num_total_seqs_; + } + + // The reservation is based on the current sequence length. + // If "current sequence + input extra length" does not exceed the + // current capacity (number of pages * page size), no action is taken. + int64_t cur_npage = page_table_[seq_id].size(); + int64_t tgt_npage = (seq_lengths_[seq_id] + extra_length + page_size_ - 1) / page_size_; + for (int64_t page_idx = cur_npage; page_idx < tgt_npage; ++page_idx) { + AllocatePageForSequence(seq_id); + } + seq_lengths_[seq_id] += extra_length; + cur_append_lengths_[seq_id] = extra_length; + dirty_aux_data_device_ = true; + } + + /*! + * \brief The entrance of attention. It takes an attention compute + * function together with the query data, invokes the attention + * function to complete the computation. + * \param f_attention The input attention compute function. + * \param q_data The query data. We support the following two layout settings: + * - in batch decode settings, q_data has layout (num_total_seqs, 1, num_heads, head_dim) + * where num_total_seqs should exactly equal to the number of sequences in the cache. + * - in other settings (single-sequence prefill, batch prefill, or speculation + * verification), q_data has the **flattened layout** to handle raggedness: + * (1, total query length, num_heads, head_dim). + * \param layer_id The model layer index of the current attention. + * \param output The attention output array. + * \param apply_rotary A boolean flag indicating if to apply RoPE. + * \param rotary_scale The RoPE scale if applicable. + * \param rotary_theta The RoPE theta if applicable. + */ + void Attention(PackedFunc f_attention, NDArray q_data, int64_t layer_id, NDArray output, + bool apply_rotary = false, double rotary_scale = 1.0f, double rotary_theta = 1e4) { + // Check q_data shape validity. 
+ CHECK_EQ(q_data->ndim, 4); + CHECK_GT(q_data->shape[1], 0); + CHECK_EQ(q_data->shape[2], num_heads_); + CHECK_EQ(q_data->shape[3], head_dim_); + CHECK(q_data.DataType() == pages_.DataType()); + + if (q_data->shape[0] > 1) { + CHECK_EQ(q_data->shape[0], num_total_seqs_); + CHECK_EQ(q_data->shape[1], 1); + } + int64_t ntoken = 0; + for (int64_t seq_id = 0; seq_id < num_total_seqs_; ++seq_id) { + ntoken += cur_append_lengths_[seq_id]; + CHECK_LE(cur_append_lengths_[seq_id], seq_lengths_[seq_id]); + if (q_data->shape[0] > 1) { + CHECK_EQ(cur_append_lengths_[seq_id], 1); + } + } + CHECK_EQ(ntoken, q_data->shape[0] * q_data->shape[1]); + + // The auxiliary data structure on device must have been synchronized. + CHECK(!dirty_aux_data_device_) + << "The auxiliary arrays are not synchronized to device. Please call " + "`SyncAuxArrayToDevice` to synchronize before calling `Attention`."; + + // Provide a 8MB size float temp buffer for the attention kernel. + static NDArray tmp_buffer = + NDArray::Empty({8 * 1024 * 1024}, DLDataType(DataType::Float(32)), pages_->device); + + f_attention(q_data, pages_, // + page_table_indptr_device_.CreateView({num_total_seqs_ + 1}, dtype_aux_), // + page_table_values_device_.CreateView({num_pages_in_use_}, dtype_aux_), // + last_page_offset_device_.CreateView({num_total_seqs_}, dtype_aux_), // + cur_append_length_indptr_device_.CreateView({num_total_seqs_ + 1}, dtype_aux_), // + layer_id, tmp_buffer, output, apply_rotary, rotary_scale, rotary_theta); + } + + /*! + * \brief Append the k/v data to the cache on the given layer. + * \param f_transpose_append The function that copies the input data to the + * cache data. It does a transpose-copy due to the layout difference between + * K/V and the cache data. + * \param k_data The input k data. + * \param v_data The input v data. + * \param layer_id The model layer index of the current append. 
+ * \note We support the following two layout settings for K/V data: + * - in batch decode settings, k/v_data has layout (num_total_seqs, 1, num_heads, head_dim) + * where num_total_seqs should exactly equal to the number of sequences in the cache. + * - in other settings (single-sequence prefill, batch prefill, or speculation + * verification), k/v_data has the **flattened layout** to handle raggedness: + * (1, total query length, num_heads, head_dim). + */ + void Append(PackedFunc f_transpose_append, NDArray k_data, NDArray v_data, int64_t layer_id) { + // Check k/v_data shape validity + CHECK_EQ(k_data->ndim, 4); + CHECK_GT(k_data->shape[1], 0); + CHECK_EQ(k_data->shape[2], num_heads_); + CHECK_EQ(k_data->shape[3], head_dim_); + for (int i = 0; i < 4; ++i) { + CHECK_EQ(k_data->shape[i], v_data->shape[i]); + } + CHECK(k_data.DataType() == pages_.DataType()); + CHECK(v_data.DataType() == pages_.DataType()); + + if (k_data->shape[0] > 1) { + CHECK_EQ(k_data->shape[0], num_total_seqs_); + CHECK_EQ(k_data->shape[1], 1); + } + int64_t ntoken = 0; + for (int64_t seq_id = 0; seq_id < num_total_seqs_; ++seq_id) { + ntoken += cur_append_lengths_[seq_id]; + CHECK_LE(cur_append_lengths_[seq_id], seq_lengths_[seq_id]); + if (k_data->shape[0] > 1) { + CHECK_EQ(cur_append_lengths_[seq_id], 1); + } + } + CHECK_EQ(ntoken, k_data->shape[0] * k_data->shape[1]); + + // The auxiliary data structure on device must have been synchronized. + CHECK(!dirty_aux_data_device_) + << "The auxiliary arrays are not synchronized to device. 
Please call " + "`SyncAuxArrayToDevice` to synchronize before calling `Append`."; + + // Copy data + f_transpose_append( + pages_, // + k_data.CreateView({ntoken, num_heads_, head_dim_}, k_data->dtype), + v_data.CreateView({ntoken, num_heads_, head_dim_}, v_data->dtype), + page_table_indptr_device_.CreateView({num_total_seqs_ + 1}, dtype_aux_), + page_table_values_device_.CreateView({num_pages_in_use_}, dtype_aux_), + last_page_offset_device_.CreateView({num_total_seqs_}, dtype_aux_), + cur_append_length_indptr_device_.CreateView({num_total_seqs_ + 1}, dtype_aux_), + cur_pos2seqid_device_.CreateView({ntoken}, dtype_aux_), // + layer_id); + } + + /*! + * \brief Remove the given sequence from the cache. + * The id of all sequences behind it will be decreased by 1. + * \param seq_id The sequence to remove. + */ + void Remove(int64_t seq_id) { + CHECK_LT(seq_id, num_total_seqs_); + for (int32_t page_id : page_table_[seq_id]) { + FreePage(page_id); + } + page_table_.erase(page_table_.begin() + seq_id); + seq_lengths_.erase(seq_lengths_.begin() + seq_id); + cur_append_lengths_.erase(cur_append_lengths_.begin() + seq_id); + --num_total_seqs_; + dirty_aux_data_device_ = true; + } + + /*! + * \brief Pop the last `n` slots of K/V values for the given sequence. + * \param seq_id The sequence to be processed. + * \param n The length to pop. + */ + void PopN(int64_t seq_id, int64_t n) { + CHECK_LT(seq_id, num_total_seqs_); + CHECK_GE(n, 0); + CHECK_LE(n, seq_lengths_[seq_id]); + + // NOTE: this method does not free pages. + seq_lengths_[seq_id] -= n; + dirty_aux_data_device_ = true; + } + + /*! + * \brief Returning the cached K/V values in the form of "an array of NDArray". + * Each returned NDArray has layout (num_layers, 2, seqlen, num_heads, head_dim), + * where along the "2" dimension, index 0 stands for K and 1 stands for V. + * \param f_view The function used for copying data out from cache. + * \return The cached K/V values, one NDArray per sequence. 
+ * \note This method is mainly for testing purposes. + */ + Array<NDArray> View(PackedFunc f_view) { Review Comment: This is no longer a view since it copies data out. Let us do DebugGetKV -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
