tqchen commented on code in PR #15910:
URL: https://github.com/apache/tvm/pull/15910#discussion_r1356878752


##########
src/runtime/relax_vm/paged_kv_cache.cc:
##########
@@ -0,0 +1,633 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/runtime/relax_vm/paged_kv_cache.cc
+ * \brief Runtime paged KV cache object for language models.
+ */
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/logging.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/registry.h>
+
+namespace tvm {
+namespace runtime {
+namespace relax_vm {
+
+//-------------------------------------------
+// We keep the implementation private as
+// it may be subject to future changes.
+//
+// Users can interact with it through the
+// runtime API function calls
+//-------------------------------------------
+
+/*!
+ * \brief The paged KV cache for attention.
+ * - It supports managing the K/V data of **multiple sequences**.
+ * - It manages K/V values by doing paging along the sequence-length
+ * dimension with a configured page size.
+ * - The basic use of the paged KV cache after initialization
+ *   in each round of model forwarding is the following (a code sketch
+ *   of these steps appears at the end of this comment):
+ *   - step 1. use `ResetAppendLengths` to reset the appending information
+ *     for preparation,
+ *   - step 2. use `ReserveExtraLengthForAppend` to specify the length
+ *     of K/V data to be appended for each sequence,
+ *   - step 3. use `SyncAuxArrayToDevice` to synchronize auxiliary arrays
+ *     to device for append/attention computation,
+ *   - step 4. for each layer, use `Append` to append the K/V data to the
+ *     cache, and then use `Attention` to compute attention results with
+ *     Q data.
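+ * - A minimal sketch of the steps above (the caller-side names "cache",
+ *   "f_append", "f_attn", "k_data", "v_data", "q_data" and "output" are
+ *   hypothetical, and the `Append` call assumes a signature not shown
+ *   in this excerpt):
+ * \code
+ *   cache->ResetAppendLengths();                         // step 1
+ *   cache->ReserveExtraLengthForAppend(seq_id, length);  // step 2, per sequence
+ *   cache->SyncAuxArrayToDevice();                       // step 3
+ *   for (int layer_id = 0; layer_id < num_layers; ++layer_id) {
+ *     cache->Append(f_append, k_data, v_data, layer_id);  // step 4
+ *     cache->Attention(f_attn, q_data, layer_id, output);
+ *   }
+ * \endcode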
+ */
+class PagedAttentionKVCacheObj : public Object {
+ private:
+  /*! \brief The total number of sequences managed in the KV cache. */
+  int64_t num_total_seqs_ = 0;
+  /*! \brief The number of pages that are in use by the sequences. */
+  int64_t num_pages_in_use_ = 0;
+  /*!
+   * \brief The number of allocated pages, including the in-use pages
+   * and the pages released due to sequence removal.
+   */
+  int64_t num_pages_allocated_ = 0;
+
+  /********************* Configuration *********************/
+
+  /*! \brief The page size (the sequence length each page manages) of the cache. */
+  const int64_t page_size_;
+  /*! \brief The number of layers in the model. */
+  const int64_t num_layers_;
+  /*! \brief The number of heads in the model. */
+  const int64_t num_heads_;
+  /*! \brief The number of features each head has. */
+  const int64_t head_dim_;
+
+  /*! \brief We fix int32 to be the index dtype of auxiliary data. */
+  const DLDataType dtype_aux_ = DataType::Int(32, 1).operator DLDataType();
+
+  /********************* Page Structures *********************/
+
+  /*!
+   * \brief The KV data managed by the KV cache.
+   * It has layout (num_pages, num_layers, 2, num_heads, page_size, head_dim).
+   * Along the "2" dimension, index 0 stands for K and 1 stands for V.
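+   * For example, the K data of page p at layer l is the slice
+   * pages_[p, l, 0, :, :, :], of shape (num_heads, page_size, head_dim).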
+   */
+  NDArray pages_;
+  /*! \brief The list of ids of released pages for page reuse. */
+  std::vector<int32_t> free_page_ids_;
+
+  /*! \brief The list of page ids assigned for each sequence in the cache. */
+  std::vector<std::vector<int32_t>> page_table_;
+  /*! \brief The lengths of each sequence in the cache. */
+  std::vector<int32_t> seq_lengths_;
+
+  /********************* Current Batch Info *********************/
+
+  /*!
+   * \brief The current lengths to append for each sequence.
+   * - The new K/V data appended to the cache must have the same length
+   * as stored in this array.
+   * - The Q data passed in for attention must also have the same length
+   * as stored.
+   * \note Invoke "ResetAppendLengths" to reset this array to all-zero.
+   */
+  std::vector<int64_t> cur_append_lengths_;
+
+  /********************* Auxiliary Arrays on Device *********************/
+  //-------------------------------------------
+  // The following fields are auxiliary arrays on device.
+  // All of them are directly derivable from the fields above.
+  // We store them for efficient execution of attentions,
+  // cache append, etc.
+  //-------------------------------------------
+  /*!
+   * \brief A boolean flag indicating if the auxiliary arrays are dirty.
+   * If it is dirty, an explicit "SyncAuxArrayToDevice" should be invoked.
+   */
+  bool dirty_aux_data_device_ = false;
+  /*!
+   * \brief The page table indptr array on device.
+   * \note Since the page table is a ragged data structure, we represent it
+   * in CSR format (with the indptr and values arrays below) on device.
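+   * For example, if sequence 0 occupies pages {3, 5} and sequence 1
+   * occupies page {2}, then indptr = [0, 2, 3] and values = [3, 5, 2].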
+   */
+  NDArray page_table_indptr_device_;
+  /*! \brief The page table value array on device. */
+  NDArray page_table_values_device_;
+  /*!
+   * \brief The array storing "the number of used slots in the last page
+   * of each sequence" on device. Its values range in (0, page_size_].
+   */
+  NDArray last_page_offset_device_;
+  /*!
+   * \brief The append_length indptr array on device.
+   * \note Since the Q/K/V data may have raggedness in terms of lengths,
+   * we represent the append lengths in CSR format.
+   */
+  NDArray cur_append_length_indptr_device_;
+  /*!
+   * \brief The corresponding sequence id for each position along the
+   * length dimension of K/V data. It is used for efficient computation.
+   */
+  NDArray cur_pos2seqid_device_;
+
+ public:
+  /*! \brief Constructor. Takes the cache configuration and initializes the NDArrays. */
+  explicit PagedAttentionKVCacheObj(int64_t page_size, int64_t num_layers, int64_t num_heads,
+                                    int64_t head_dim, int64_t reserved_num_seqs,
+                                    int64_t reserved_num_pages, DLDataType dtype, DLDevice device)
+      : page_size_(page_size), num_layers_(num_layers), num_heads_(num_heads), head_dim_(head_dim) {
+    pages_ = NDArray::Empty({reserved_num_pages, num_layers, 2, num_heads, page_size, head_dim},
+                            dtype, device);
+    page_table_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device);
+    page_table_values_device_ = NDArray::Empty({reserved_num_pages}, dtype_aux_, device);
+    last_page_offset_device_ = NDArray::Empty({reserved_num_seqs}, dtype_aux_, device);
+    cur_append_length_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device);
+    cur_pos2seqid_device_ = NDArray::Empty({reserved_num_pages * page_size}, dtype_aux_, device);
+  }
+
+  /*!
+   * \brief Given a sequence id and a required extra length, allocate
+   * new pages for the sequence until the total capacity can cover the
+   * current sequence length plus the required extra length.
+   * \note By reserving extra length for the sequence, the subsequent appended
+   * K/V data for the sequence must have the same length as the input length here.
+   * \param seq_id The id of the sequence to process.
+   * \param extra_length The extra length to reserve for the sequence.
+   */
+  void ReserveExtraLengthForAppend(int64_t seq_id, int64_t extra_length) {
+    CHECK_GE(seq_id, 0) << "The input sequence id should be non-negative.";
+    CHECK_LE(seq_id, num_total_seqs_)
+        << "Invalid input sequence id " << seq_id
+        << ", which is larger than the current total number of sequences.";
+    CHECK_GT(extra_length, 0) << "The input length should be positive.";
+
+    if (seq_id == num_total_seqs_) {
+      // Initialize if it is a new sequence.
+      page_table_.push_back({});
+      seq_lengths_.push_back(0);
+      cur_append_lengths_.push_back(0);
+      ++num_total_seqs_;
+    }
+
+    // The reservation is based on the current sequence length.
+    // If "current sequence length + input extra length" does not exceed the
+    // current capacity (number of pages * page size), no new page is allocated.
+    int64_t cur_npage = page_table_[seq_id].size();
+    int64_t tgt_npage = (seq_lengths_[seq_id] + extra_length + page_size_ - 1) / page_size_;
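+    // For example, with page_size_ = 16, a sequence of length 30 already
+    // holds ceil(30 / 16) = 2 pages. Reserving an extra length of 5 gives
+    // tgt_npage = ceil((30 + 5) / 16) = 3, so one more page is allocated.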
+    for (int64_t page_idx = cur_npage; page_idx < tgt_npage; ++page_idx) {
+      AllocatePageForSequence(seq_id);
+    }
+    seq_lengths_[seq_id] += extra_length;
+    cur_append_lengths_[seq_id] = extra_length;
+    dirty_aux_data_device_ = true;
+  }
+
+  /*!
+   * \brief The entry point of attention. It takes an attention compute
+   * function together with the query data, and invokes the attention
+   * function to complete the computation.
+   * \param f_attention The input attention compute function.
+   * \param q_data The query data. We support the following two layout settings:
+   * - in batch decode settings, q_data has layout (num_total_seqs, 1, num_heads, head_dim),
+   *   where num_total_seqs should exactly equal the number of sequences in the cache;
+   * - in other settings (single-sequence prefill, batch prefill, or speculation
+   *   verification), q_data has the **flattened layout** to handle raggedness:
+   *   (1, total query length, num_heads, head_dim).
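+   *   For example, with three sequences whose append lengths are {4, 1, 2},
+   *   the total query length is 7 and q_data has shape (1, 7, num_heads, head_dim).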
+   * \param layer_id The model layer index of the current attention.
+   * \param output The attention output array.
+   * \param apply_rotary A boolean flag indicating whether to apply RoPE.
+   * \param rotary_scale The RoPE scale if applicable.
+   * \param rotary_theta The RoPE theta if applicable.
+   */
+  void Attention(PackedFunc f_attention, NDArray q_data, int64_t layer_id, NDArray output,
+                 bool apply_rotary = false, double rotary_scale = 1.0, double rotary_theta = 1e4) {
+    // Check q_data shape validity.
+    CHECK_EQ(q_data->ndim, 4);
+    CHECK_GT(q_data->shape[1], 0);
+    CHECK_EQ(q_data->shape[2], num_heads_);
+    CHECK_EQ(q_data->shape[3], head_dim_);
+    CHECK(q_data.DataType() == pages_.DataType());
+
+    if (q_data->shape[0] > 1) {
+      CHECK_EQ(q_data->shape[0], num_total_seqs_);
+      CHECK_EQ(q_data->shape[1], 1);
+    }
+    int64_t ntoken = 0;
+    for (int64_t seq_id = 0; seq_id < num_total_seqs_; ++seq_id) {
+      ntoken += cur_append_lengths_[seq_id];
+      CHECK_LE(cur_append_lengths_[seq_id], seq_lengths_[seq_id]);
+      if (q_data->shape[0] > 1) {
+        CHECK_EQ(cur_append_lengths_[seq_id], 1);
+      }
+    }
+    CHECK_EQ(ntoken, q_data->shape[0] * q_data->shape[1]);
+
+    // The auxiliary data structures on device must have been synchronized.
+    CHECK(!dirty_aux_data_device_)
+        << "The auxiliary arrays are not synchronized to device. Please call "
+           "`SyncAuxArrayToDevice` to synchronize before calling `Attention`.";
+
+    // Provide an 8M-element float32 (32 MB) temporary buffer for the attention kernel.
+    static NDArray tmp_buffer =
+        NDArray::Empty({8 * 1024 * 1024}, DLDataType(DataType::Float(32)), pages_->device);
+
+    f_attention(q_data, pages_,                                                                  //
+                page_table_indptr_device_.CreateView({num_total_seqs_ + 1}, dtype_aux_),         //

Review Comment:
   would be great if these views were also created once during device sync and
   kept there, without having to recreate them on the fly
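
   A rough sketch of that direction (a sketch only: the `*_view_` members are
   hypothetical additions, and the body of `SyncAuxArrayToDevice` is not shown
   in this excerpt):

   ```cpp
   // Hypothetical cached views, refreshed once per sync instead of being
   // recreated on every Attention call.
   NDArray page_table_indptr_view_;
   NDArray cur_append_length_indptr_view_;

   void SyncAuxArrayToDevice() {
     // ... copy the auxiliary host data to device as before ...
     // Create the views once here, sized by the current number of sequences.
     page_table_indptr_view_ =
         page_table_indptr_device_.CreateView({num_total_seqs_ + 1}, dtype_aux_);
     cur_append_length_indptr_view_ =
         cur_append_length_indptr_device_.CreateView({num_total_seqs_ + 1}, dtype_aux_);
     dirty_aux_data_device_ = false;
   }
   ```

   `Attention` could then pass `page_table_indptr_view_` directly instead of
   calling `CreateView` on the fly.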


