tqchen commented on code in PR #15910:
URL: https://github.com/apache/tvm/pull/15910#discussion_r1356863129


##########
src/runtime/relax_vm/paged_kv_cache.cc:
##########
@@ -0,0 +1,633 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/runtime/relax_vm/paged_kv_cache.cc
+ * \brief Runtime paged KV cache object for language models.
+ */
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/logging.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/registry.h>
+
+namespace tvm {
+namespace runtime {
+namespace relax_vm {
+
+//-------------------------------------------
+// We keep the implementation private as
+// it may be subject to future changes.
+//
+// Users can interact with it through the
+// runtime API function calls
+//-------------------------------------------
+
+/*!
+ * \brief The paged KV cache for attention.
+ * - It supports managing the K/V data of **multiple sequences**.
+ * - It manages K/V values by doing paging along the sequence-length
+ * dimension with a configured page size.
+ * - The basic example use of the paged KV cache after initialization
+ * in each round of model forwarding is the following:
+ *   - step 1. use `ResetAppendLengths` to reset the appending information
+ *     for preparation,
+ *   - step 2. use `ReserveExtraLengthForAppend` to specify the length
+ *     of K/V data to be appended for each sequence,
+ *   - step 3. use `SyncAuxArrayToDevice` to synchronize auxiliary arrays
+ *     to device for append/attention computation,
+ *   - step 4. for each layer, use `Append` to append the K/V data to the
+ *     cache, and then use `Attention` to compute attention results with
+ *     Q data.
+ */
+class PagedAttentionKVCacheObj : public Object {
+ private:
+  /*! \brief The total number of sequences managed in the KV cache. */
+  int64_t num_total_seqs_ = 0;
+  /*! \brief The number of pages that are in use by the sequences. */
+  int64_t num_pages_in_use_ = 0;
+  /*!
+   * \brief The number of allocated pages, including the in-use pages
+   * and the pages released due to sequence removal.
+   */
+  int64_t num_pages_allocated_ = 0;
+
+  /********************* Configuration *********************/
+
+  /*! \brief The page size (the sequence length each page manages) of the 
cache. */
+  const int64_t page_size_;
+  /*! \brief The number of layers in the model. */
+  const int64_t num_layers_;
+  /*! \brief The number of heads in the model. */
+  const int64_t num_heads_;
+  /*! \brief The number of features each head has. */
+  const int64_t head_dim_;
+
+  /*! \brief We fix int32 to be the index dtype of auxiliary data. */
+  const DLDataType dtype_aux_ = DataType::Int(32, 1).operator DLDataType();
+
+  /********************* Page Structures *********************/
+
+  /*!
+   * \brief The KV data managed by the KV cache.
+   * It has layout (num_pages, num_layers, 2, num_heads, page_size, head_dim).
+ * Along the "2" dimension, index 0 stands for K and 1 stands for V.
+   */
+  NDArray pages_;
+  /*! \brief The list of ids of released pages for page reuse. */
+  std::vector<int32_t> free_page_ids_;
+
+  /*! \brief The list of page ids assigned for each sequence in the cache. */
+  std::vector<std::vector<int32_t>> page_table_;
+  /*! \brief The lengths of each sequence in the cache. */
+  std::vector<int32_t> seq_lengths_;
+
+  /********************* Current Batch Info *********************/
+
+  /*!
+   * \brief The current lengths to append for each sequence.
+   * - The new K/V data appended to the cache must have the same length
+   * as stored in this array.
+   * - The Q data passed in for attention must also have the same length
+   * as stored.
+   * \note Invoke "ResetAppendLengths" to reset this array to all-zero.
+   */
+  std::vector<int64_t> cur_append_lengths_;
+
+  /********************* Auxiliary Arrays on Device *********************/
+  //-------------------------------------------
+  // The following fields are auxiliary arrays on device.
+  // All of them are directly derivable from the fields above.
+  // We store them for efficient execution of attentions,
+  // cache append, etc.
+  //-------------------------------------------
+  /*!
+   * \brief A boolean flag indicating if the auxiliary arrays are dirty.
+   * If it is dirty, an explicit "SyncAuxArrayToDevice" should be invoked.
+   */
+  bool dirty_aux_data_device_ = false;
+  /*!
+   * \brief The page table indptr array on device.
+   * \note Since page table is a ragged data structure, we represent it
+   * in CSR format (which uses indptr and values below) on device.
+   */
+  NDArray page_table_indptr_device_;
+  /*! \brief The page table value array on device. */
+  NDArray page_table_values_device_;
+  /*!
+   * \brief The array storing "the number of used slots in the last page
+   * of each sequence" on device. Its values range in (0, page_size_].
+   */
+  NDArray last_page_offset_device_;
+  /*!
+   * \brief The append_length indptr array on device.
+   * \note Since the Q/K/V data may have raggedness in terms of lengths,
+ * we represent the append lengths in CSR format.
+   */
+  NDArray cur_append_length_indptr_device_;
+  /*!
+   * \brief The corresponding sequence id for each position along the
+   * length dimension of K/V data. It is used for efficient computation.
+   */
+  NDArray cur_pos2seqid_device_;
+
+ public:
+  /*! \brief Constructor. Take the cache configuration and initialize the 
NDArrays. */
+  explicit PagedAttentionKVCacheObj(int64_t page_size, int64_t num_layers, 
int64_t num_heads,
+                                    int64_t head_dim, int64_t 
reserved_num_seqs,
+                                    int64_t reserved_num_pages, DLDataType 
dtype, DLDevice device)
+      : page_size_(page_size), num_layers_(num_layers), num_heads_(num_heads), 
head_dim_(head_dim) {
+    pages_ = NDArray::Empty({reserved_num_pages, num_layers, 2, num_heads, 
page_size, head_dim},
+                            dtype, device);
+    page_table_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, 
dtype_aux_, device);
+    page_table_values_device_ = NDArray::Empty({reserved_num_pages}, 
dtype_aux_, device);
+    last_page_offset_device_ = NDArray::Empty({reserved_num_seqs}, dtype_aux_, 
device);
+    cur_append_length_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, 
dtype_aux_, device);
+    cur_pos2seqid_device_ = NDArray::Empty({reserved_num_pages * page_size}, 
dtype_aux_, device);
+  }
+
+  /*!
+   * \brief Given a sequence id and a required extra length, allocate
+   * new pages for the sequence until the total capacity can cover the
+   * current sequence length plus the required extra length.
+   * \note By reserving extra length for the sequence, the subsequent appended
+   * K/V data for the sequence must have the same length as the input length 
here.
+   * \param seq_id The id of the sequence to process.
+   * \param extra_length The extra length to reserve for the sequence.
+   */
+  void ReserveExtraLengthForAppend(int64_t seq_id, int64_t extra_length) {
+    CHECK_GE(seq_id, 0) << "Input sequence id should be positive";
+    CHECK_LE(seq_id, num_total_seqs_)
+        << "Invalid input sequence id " << seq_id
+        << ", which is larger than the current total number of sequences.";
+    CHECK_GT(extra_length, 0) << "The input length should be positive.";
+
+    if (seq_id == num_total_seqs_) {

Review Comment:
   I think it is useful to have a separate function for adding a new sequence, 
so that the concern of growing the number of sequences is kept separate from 
the concern of growing an existing sequence's length



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to