This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit bbb8f40d2592b5beec42eb8ea76abe635060ed25
Author: Masahiro Masuda <[email protected]>
AuthorDate: Mon Dec 11 06:49:28 2023 +0000

    add kernel for copying cache blocks
---
 src/runtime/contrib/vllm/cache_kernels.cu | 69 +++++++++++++++++++++++++++++++
 1 file changed, 69 insertions(+)

diff --git a/src/runtime/contrib/vllm/cache_kernels.cu b/src/runtime/contrib/vllm/cache_kernels.cu
index 29ab9bfa2e..349b8e8e50 100644
--- a/src/runtime/contrib/vllm/cache_kernels.cu
+++ b/src/runtime/contrib/vllm/cache_kernels.cu
@@ -69,6 +69,51 @@ __global__ void reshape_and_cache_kernel(
   }
 }
 
+// Copies whole cache blocks from a source block index to a destination block
+// index, in both the key cache and the value cache of every layer.
+//
+// Grid: (num_layers, num_pairs) -- blockIdx.x selects the layer, blockIdx.y
+//   selects the (src, dst) pair.
+// Block: any 1-D size; the threads of one CTA stride over the
+//   numel_per_block elements of a single cache block.
+//
+// key_cache_ptrs / value_cache_ptrs: device arrays holding each layer's cache
+//   base address as an int64_t; reinterpreted here as scalar_t*.
+// block_mapping: device array of 2 * num_pairs entries,
+//   [src_0, dst_0, src_1, dst_1, ...].
+// numel_per_block: scalar_t elements per cache block; block b occupies
+//   elements [b * numel_per_block, (b + 1) * numel_per_block).
+//
+// Elements are only copied, never computed on, so scalar_t just has to match
+// the caches' element width.
+// NOTE(review): pairs run in separate CTAs with no ordering between them, so
+// the result is only well defined if no pair's dst block is another pair's
+// src block -- confirm callers guarantee this.
+template<typename scalar_t>
+__global__ void copy_blocks_kernel(
+  const int64_t* __restrict__ key_cache_ptrs,
+  const int64_t* __restrict__ value_cache_ptrs,
+  const int64_t* __restrict__ block_mapping,
+  const int numel_per_block) {
+  const int layer_idx = blockIdx.x;
+  const int pair_idx = blockIdx.y;
+
+  scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
+  scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
+  const int64_t src_block_number = block_mapping[2 * pair_idx];
+  const int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
+
+  // 64-bit offsets: block_number * numel_per_block can exceed INT_MAX.
+  const int64_t src_block_offset = src_block_number * numel_per_block;
+  const int64_t dst_block_offset = dst_block_number * numel_per_block;
+  // Key and value share one loop: each iteration issues two independent
+  // copies, and adjacent threads touch adjacent elements.
+  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
+    key_cache[dst_block_offset + i] = key_cache[src_block_offset + i];
+    value_cache[dst_block_offset + i] = value_cache[src_block_offset + i];
+  }
+}
+
 }  // namespace vllm
 
 namespace tvm {
@@ -105,5 +134,114 @@ TVM_REGISTER_GLOBAL("tvm.contrib.vllm.reshape_and_cache")
 
       return Array{key_cache, value_cache};
     });
+
+// tvm.contrib.vllm.copy_blocks(key_value_caches, block_mapping)
+//
+// Copies whole cache blocks in place, for every layer's key and value cache.
+//
+// key_value_caches: per-layer caches interleaved as
+//   [key_0, value_0, key_1, value_1, ...].
+// block_mapping: int64 (src, dst) block-index pairs laid out row-major as
+//   [src_0, dst_0, src_1, dst_1, ...] (e.g. shape [2 * num_pairs] or
+//   [num_pairs, 2]). It may live on the host or on the caches' device; it is
+//   copied to the caches' device before the launch.
+//
+// NOTE(review): every cache is assumed to share the device, element width and
+// per-block size (product of dims 1..ndim-1) of layer 0's value cache --
+// confirm that the key-cache layout satisfies this.
+TVM_REGISTER_GLOBAL("tvm.contrib.vllm.copy_blocks")
+    .set_body_typed([](Array<NDArray> key_value_caches, NDArray block_mapping) {
+      ICHECK_EQ(key_value_caches.size() % 2, 0U)
+          << "copy_blocks: key_value_caches must hold (key, value) pairs, got "
+          << key_value_caches.size() << " arrays";
+      ICHECK(block_mapping->dtype.code == kDLInt && block_mapping->dtype.bits == 64)
+          << "copy_blocks: block_mapping must be int64";
+      int64_t mapping_numel = 1;
+      for (int d = 0; d < block_mapping->ndim; ++d) {
+        mapping_numel *= block_mapping->shape[d];
+      }
+      ICHECK_EQ(mapping_numel % 2, 0)
+          << "copy_blocks: block_mapping must hold (src, dst) pairs, got "
+          << mapping_numel << " entries";
+
+      const size_t num_layers = key_value_caches.size() / 2;
+      const int64_t num_pairs = mapping_numel / 2;
+
+      // Nothing to copy. Returning early also avoids a launch with a
+      // zero-sized grid, which CUDA rejects as an invalid configuration.
+      if (num_layers == 0 || num_pairs == 0) {
+        return;
+      }
+      // One CTA per pair along gridDim.y, which CUDA caps at 65535.
+      ICHECK_LE(num_pairs, 65535) << "copy_blocks: too many block pairs: " << num_pairs;
+
+      // Layer 0's value cache is the reference tensor. Dim 0 indexes blocks, so
+      // one block spans the product of the remaining dims; for a
+      // [num_blocks, num_heads, head_size, block_size] cache that is
+      // num_heads * head_size * block_size.
+      NDArray value_cache = key_value_caches[1];
+      DLDevice dev = value_cache->device;
+      int64_t numel_per_block = 1;
+      for (int d = 1; d < value_cache->ndim; ++d) {
+        numel_per_block *= value_cache->shape[d];
+      }
+      if (numel_per_block == 0) {
+        return;
+      }
+
+      // Per-layer base addresses, shipped to the device as int64 values for
+      // the kernel to reinterpret as element pointers.
+      std::vector<int64_t> key_cache_ptrs(num_layers);
+      std::vector<int64_t> value_cache_ptrs(num_layers);
+      for (size_t layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
+        key_cache_ptrs[layer_idx] =
+            reinterpret_cast<int64_t>(key_value_caches[2 * layer_idx]->data);
+        value_cache_ptrs[layer_idx] =
+            reinterpret_cast<int64_t>(key_value_caches[2 * layer_idx + 1]->data);
+      }
+      NDArray key_cache_ptrs_gpu = NDArray::Empty(
+          {static_cast<int64_t>(num_layers)}, runtime::DataType::Int(64), dev);
+      NDArray value_cache_ptrs_gpu = NDArray::Empty(
+          {static_cast<int64_t>(num_layers)}, runtime::DataType::Int(64), dev);
+      key_cache_ptrs_gpu.CopyFromBytes(key_cache_ptrs.data(),
+                                       sizeof(int64_t) * key_cache_ptrs.size());
+      value_cache_ptrs_gpu.CopyFromBytes(value_cache_ptrs.data(),
+                                         sizeof(int64_t) * value_cache_ptrs.size());
+
+      // CopyTo accepts a host- or device-resident block_mapping, whereas a raw
+      // CopyFromBytes(block_mapping->data, ...) assumes a host pointer.
+      NDArray block_mapping_gpu = block_mapping.CopyTo(dev);
+
+      const int numel = static_cast<int>(numel_per_block);
+      dim3 grid(static_cast<unsigned int>(num_layers), static_cast<unsigned int>(num_pairs));
+      dim3 block(std::min(1024, numel));
+
+      // The kernel only copies elements, so dispatch on element width rather
+      // than on the exact dtype (e.g. fp16/bf16 -> uint16_t, fp32 -> uint32_t).
+      auto launch = [&](auto type_tag) {
+        using scalar_t = decltype(type_tag);
+        vllm::copy_blocks_kernel<scalar_t><<<grid, block>>>(
+            static_cast<int64_t*>(key_cache_ptrs_gpu->data),
+            static_cast<int64_t*>(value_cache_ptrs_gpu->data),
+            static_cast<int64_t*>(block_mapping_gpu->data),
+            numel);
+      };
+      const int elem_bits = value_cache->dtype.bits * value_cache->dtype.lanes;
+      switch (elem_bits) {
+        case 8: launch(uint8_t{}); break;
+        case 16: launch(uint16_t{}); break;
+        case 32: launch(uint32_t{}); break;
+        case 64: launch(uint64_t{}); break;
+        default:
+          LOG(FATAL) << "copy_blocks: unsupported cache element width: " << elem_bits
+                     << " bits";
+      }
+      // Kernel launches do not return errors; surface a bad launch here rather
+      // than at some later, unrelated CUDA call.
+      cudaError_t err = cudaGetLastError();
+      ICHECK(err == cudaSuccess) << "copy_blocks_kernel launch failed: "
+                                 << cudaGetErrorString(err);
+    });
+
 }  // namespace runtime
 }  // namespace tvm

Reply via email to