This is an automated email from the ASF dual-hosted git repository. masahi pushed a commit to branch unity in repository https://gitbox.apache.org/repos/asf/tvm.git
commit bbb8f40d2592b5beec42eb8ea76abe635060ed25
Author: Masahiro Masuda <[email protected]>
AuthorDate: Mon Dec 11 06:49:28 2023 +0000

    add kernel for copying cache blocks
---
 src/runtime/contrib/vllm/cache_kernels.cu | 102 ++++++++++++++++++++++++++++++
 1 file changed, 102 insertions(+)

diff --git a/src/runtime/contrib/vllm/cache_kernels.cu b/src/runtime/contrib/vllm/cache_kernels.cu
index 29ab9bfa2e..349b8e8e50 100644
--- a/src/runtime/contrib/vllm/cache_kernels.cu
+++ b/src/runtime/contrib/vllm/cache_kernels.cu
@@ -69,6 +69,42 @@ __global__ void reshape_and_cache_kernel(
   }
 }
 
+// Copy whole cache blocks (both key and value caches) for every layer at
+// once, following a list of (src_block, dst_block) pairs.
+//
+// Grid: (num_layers, num_pairs)
+//   blockIdx.x selects the layer, blockIdx.y selects the copy pair.
+// key_cache_ptrs / value_cache_ptrs: one device pointer per layer, stored
+//   as int64_t and reinterpreted here.
+// block_mapping: flat device array [src0, dst0, src1, dst1, ...] of length
+//   2 * num_pairs.
+// numel_per_block: number of scalar elements in one cache block; the
+//   threads of each thread block stride over it, so consecutive threads
+//   touch consecutive elements (coalesced global accesses).
+template<typename scalar_t>
+__global__ void copy_blocks_kernel(
+    int64_t* key_cache_ptrs,
+    int64_t* value_cache_ptrs,
+    const int64_t* __restrict__ block_mapping,
+    const int numel_per_block) {
+  const int layer_idx = blockIdx.x;
+  const int pair_idx = blockIdx.y;
+
+  scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
+  scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
+  const int64_t src_block_number = block_mapping[2 * pair_idx];
+  const int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
+
+  const int64_t src_block_offset = src_block_number * numel_per_block;
+  const int64_t dst_block_offset = dst_block_number * numel_per_block;
+  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
+    key_cache[dst_block_offset + i] = key_cache[src_block_offset + i];
+  }
+  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
+    value_cache[dst_block_offset + i] = value_cache[src_block_offset + i];
+  }
+}
+
 } // namespace vllm
 
 namespace tvm {
@@ -105,5 +141,71 @@ TVM_REGISTER_GLOBAL("tvm.contrib.vllm.reshape_and_cache")
       return Array{key_cache, value_cache};
     });
 
+
+TVM_REGISTER_GLOBAL("tvm.contrib.vllm.copy_blocks")
+    .set_body_typed([](Array<NDArray> key_value_caches, NDArray block_mapping) {
+      // key_value_caches interleaves per-layer caches:
+      // [key_0, value_0, key_1, value_1, ...].
+      auto num_layers = key_value_caches.size() / 2;
+      // block_mapping is a flat int64 tensor [src0, dst0, src1, dst1, ...].
+      auto num_pairs = block_mapping->shape[0] / 2;
+
+      // Nothing to copy; this also avoids launching with a zero-sized grid,
+      // which is an invalid launch configuration.
+      if (num_layers == 0 || num_pairs == 0) {
+        return;
+      }
+
+      // Collect the raw device pointer of every cache so a single kernel
+      // launch can copy blocks in all layers.
+      std::vector<int64_t> key_cache_ptrs(num_layers);
+      std::vector<int64_t> value_cache_ptrs(num_layers);
+      for (size_t layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
+        key_cache_ptrs[layer_idx] =
+            reinterpret_cast<int64_t>(key_value_caches[2 * layer_idx]->data);
+        value_cache_ptrs[layer_idx] =
+            reinterpret_cast<int64_t>(key_value_caches[2 * layer_idx + 1]->data);
+      }
+
+      // key_value_caches[1] is the first *value* cache,
+      // [num_blocks, num_heads, head_size, block_size]; use it to derive the
+      // device and the per-block element count (each cache block holds the
+      // same number of elements regardless of the key cache's layout).
+      NDArray value_cache = key_value_caches[1];
+      DLDevice dev = value_cache->device;
+
+      // Stage the pointer tables and the block mapping on the device.
+      NDArray key_cache_ptrs_gpu = NDArray::Empty(
+          {static_cast<int64_t>(num_layers)}, runtime::DataType::Int(64), dev);
+      NDArray value_cache_ptrs_gpu = NDArray::Empty(
+          {static_cast<int64_t>(num_layers)}, runtime::DataType::Int(64), dev);
+      key_cache_ptrs_gpu.CopyFromBytes(key_cache_ptrs.data(),
+                                       sizeof(int64_t) * key_cache_ptrs.size());
+      value_cache_ptrs_gpu.CopyFromBytes(
+          value_cache_ptrs.data(), sizeof(int64_t) * value_cache_ptrs.size());
+
+      NDArray block_mapping_gpu = NDArray::Empty(
+          block_mapping.Shape(), runtime::DataType::Int(64), dev);
+      block_mapping_gpu.CopyFromBytes(block_mapping->data,
+                                      sizeof(int64_t) * block_mapping->shape[0]);
+
+      const int numel_per_block =
+          value_cache->shape[1] * value_cache->shape[2] * value_cache->shape[3];
+      // One thread block per (layer, pair); threads stride over the block.
+      dim3 grid(num_layers, num_pairs);
+      dim3 block(std::min(1024, numel_per_block));
+
+      // NOTE(review): the element type is hard-coded to a 16-bit scalar
+      // (fp16 caches); generalize when other cache dtypes are needed.
+      using scalar_t = uint16_t;
+      // NOTE(review): launches on the default stream with no subsequent
+      // cudaGetLastError() check; consider routing through TVM's CUDA
+      // stream/error helpers.
+      vllm::copy_blocks_kernel<scalar_t><<<grid, block>>>(
+          static_cast<int64_t*>(key_cache_ptrs_gpu->data),
+          static_cast<int64_t*>(value_cache_ptrs_gpu->data),
+          static_cast<int64_t*>(block_mapping_gpu->data),
+          numel_per_block);
+    });
+
 } // namespace runtime
 } // namespace tvm
