This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 3fad0109bf [Unity][Contrib] Add vLLM paged attention kernel (#16350)
3fad0109bf is described below
commit 3fad0109bf5ef2cc0ca37f0c47d207fd2ce99fa8
Author: Wuwei Lin <[email protected]>
AuthorDate: Fri Jan 5 12:44:08 2024 -0700
[Unity][Contrib] Add vLLM paged attention kernel (#16350)
* Add vllm kernels
* Add vllm paged attention v2
* [vllm] remove head_mapping
* revert cmake changes
* add kernel for copying cache blocks
* add kernel for reconstructing past kv tensors from cache
* lint
---------
Co-authored-by: Masahiro Masuda <[email protected]>
---
CMakeLists.txt | 1 +
cmake/modules/contrib/vllm.cmake | 25 +
licenses/LICENSE.vllm.txt | 201 +++++++
src/runtime/contrib/vllm/attention_kernels.cu | 774 ++++++++++++++++++++++++++
src/runtime/contrib/vllm/attention_utils.cuh | 55 ++
src/runtime/contrib/vllm/cache_alloc.cc | 55 ++
src/runtime/contrib/vllm/cache_kernels.cu | 234 ++++++++
src/runtime/contrib/vllm/dtype_float16.h | 697 +++++++++++++++++++++++
tests/lint/check_file_type.py | 1 +
tests/python/relax/test_contrib_vllm.py | 746 +++++++++++++++++++++++++
10 files changed, 2789 insertions(+)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index fcd670429f..410ed439f1 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -572,6 +572,7 @@ include(cmake/modules/contrib/VitisAI.cmake)
include(cmake/modules/contrib/Verilator.cmake)
include(cmake/modules/contrib/UMA.cmake)
include(cmake/modules/contrib/MSC.cmake)
+include(cmake/modules/contrib/vllm.cmake)
include(cmake/modules/Git.cmake)
include(cmake/modules/LibInfo.cmake)
include(cmake/modules/RustExt.cmake)
diff --git a/cmake/modules/contrib/vllm.cmake b/cmake/modules/contrib/vllm.cmake
new file mode 100644
index 0000000000..4a09edd02e
--- /dev/null
+++ b/cmake/modules/contrib/vllm.cmake
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
# When USE_VLLM is ON, compile the vLLM paged-attention contrib kernels
# (CUDA + C++ sources under src/runtime/contrib/vllm) into the TVM runtime.
# Requires a CUDA toolchain; enable_language(CUDA) fails the configure step
# if none is available.
if(USE_VLLM)
  message(STATUS "Build with vllm paged attention kernel.")
  # NOTE(review): directory-scoped include matches the existing TVM contrib
  # module convention; the runtime sources are added to the global RUNTIME_SRCS
  # list rather than a dedicated target.
  include_directories(src/runtime/contrib/vllm)
  enable_language(CUDA)

  tvm_file_glob(GLOB VLLM_CONTRIB_SRC
    src/runtime/contrib/vllm/*.cu
    src/runtime/contrib/vllm/*.cc
  )
  list(APPEND RUNTIME_SRCS ${VLLM_CONTRIB_SRC})
endif()
diff --git a/licenses/LICENSE.vllm.txt b/licenses/LICENSE.vllm.txt
new file mode 100644
index 0000000000..261eeb9e9f
--- /dev/null
+++ b/licenses/LICENSE.vllm.txt
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/src/runtime/contrib/vllm/attention_kernels.cu
b/src/runtime/contrib/vllm/attention_kernels.cu
new file mode 100644
index 0000000000..fe6e974dad
--- /dev/null
+++ b/src/runtime/contrib/vllm/attention_kernels.cu
@@ -0,0 +1,774 @@
+/*
+ * Adapted from
+ *
https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <float.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include <algorithm>
+#include <type_traits>
+
+#include "attention_utils.cuh"
+#include "dtype_float16.h"
+
+#define WARP_SIZE 32
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define DIVIDE_ROUND_UP(a, b) (((a) + (b)-1) / (b))
+
+namespace vllm {
+
// Sum a per-thread float across the whole thread block and hand the total
// back to every thread. Used by the attention softmax to accumulate exp sums.
// `red_smem` must provide at least NUM_WARPS floats of shared-memory scratch.
template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
  const int warp_idx = threadIdx.x / WARP_SIZE;
  const int lane_idx = threadIdx.x % WARP_SIZE;

  // Butterfly reduction: after log2(WARP_SIZE) steps every lane of the warp
  // holds the warp-wide partial sum.
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset >= 1; offset /= 2) {
    sum += __shfl_xor_sync(uint32_t(-1), sum, offset);
  }

  // One lane per warp publishes the partial sum to shared memory.
  if (lane_idx == 0) {
    red_smem[warp_idx] = sum;
  }

  // All partials must be visible before any warp reads them back.
  __syncthreads();

  // The first NUM_WARPS lanes of each warp pick up the per-warp partials...
  if (lane_idx < NUM_WARPS) {
    sum = red_smem[lane_idx];
  }

  // ...and fold them together with a second butterfly reduction.
#pragma unroll
  for (int offset = NUM_WARPS / 2; offset >= 1; offset /= 2) {
    sum += __shfl_xor_sync(uint32_t(-1), sum, offset);
  }

  // Broadcast lane 0's total so every caller thread returns the same value.
  return __shfl_sync(uint32_t(-1), sum, 0);
}
+
// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
//
// Core paged-attention body shared by the v1 and v2 entry kernels: one thread
// block computes attention for one (head, sequence, partition) triple, reading
// K/V from the block-table-indexed paged cache. With PARTITION_SIZE == 0 the
// whole context is handled by a single partition and exp_sums/max_logits may
// be nullptr (v1 path); otherwise per-partition softmax statistics are stored
// for the later reduce kernel (v2 path).
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS,
          int PARTITION_SIZE = 0>  // Zero means no partitioning.
__device__ void paged_attention_kernel(
    float* __restrict__ exp_sums,            // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits,          // [num_seqs, num_heads, max_num_partitions]
    scalar_t* __restrict__ out,              // [num_seqs, num_heads, max_num_partitions, head_size]
    const scalar_t* __restrict__ q,          // [num_seqs, num_heads, head_size]
    const scalar_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
    const scalar_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
    const int num_kv_heads, const float scale,
    const int* __restrict__ block_tables,    // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ context_lens,    // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride) {
  const int seq_idx = blockIdx.y;
  const int partition_idx = blockIdx.z;
  const int max_num_partitions = gridDim.z;
  constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
  const int context_len = context_lens[seq_idx];
  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
    // No work to do. Terminate the thread block.
    return;
  }

  const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
  const int num_blocks_per_partition =
      USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;

  // [start_block_idx, end_block_idx) is the range of blocks to process.
  const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
  const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
  const int num_blocks = end_block_idx - start_block_idx;

  // [start_token_idx, end_token_idx) is the range of tokens to process.
  const int start_token_idx = start_block_idx * BLOCK_SIZE;
  const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
  const int num_tokens = end_token_idx - start_token_idx;

  constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  constexpr int NUM_THREAD_GROUPS =
      NUM_THREADS / THREAD_GROUP_SIZE;  // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
  assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
  constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int thread_idx = threadIdx.x;
  const int warp_idx = thread_idx / WARP_SIZE;
  const int lane = thread_idx % WARP_SIZE;

  const int head_idx = blockIdx.x;
  const int num_heads = gridDim.x;
  // Grouped-query attention: several query heads may share one KV head.
  const int num_queries_per_kv = num_heads / num_kv_heads;
  const int kv_head_idx = head_idx / num_queries_per_kv;
  const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];

  // A vector type to store a part of a key or a query.
  // The vector size is configured in such a way that the threads in a thread group
  // fetch or compute 16 bytes at a time.
  // For example, if the size of a thread group is 4 and the data type is half,
  // then the vector size is 16 / (4 * sizeof(half)) == 2.
  constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
  using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
  using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;

  constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
  constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;

  const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
  const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;

  // Load the query to registers.
  // Each thread in a thread group has a different part of the query.
  // For example, if the thread group size is 4, then the first thread in the group
  // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
  // th vectors of the query, and so on.
  // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
  __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll
  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
    q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
  }
  __syncthreads();  // TODO(naed90): possible speedup if this is replaced with a memory wall right
                    // before we use q_vecs

  // Memory planning.
  extern __shared__ char shared_mem[];
  // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
  float* logits = reinterpret_cast<float*>(shared_mem);
  // Workspace for reduction.
  __shared__ float red_smem[2 * NUM_WARPS];

  // x == THREAD_GROUP_SIZE * VEC_SIZE
  // Each thread group fetches x elements from the key at a time.
  constexpr int x = 16 / sizeof(scalar_t);
  float qk_max = -FLT_MAX;

  // Iterate over the key blocks.
  // Each warp fetches a block of keys for each iteration.
  // Each thread group in a warp fetches a key from the block, and computes
  // dot product with the query.
  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
       block_idx += NUM_WARPS) {
    // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
    // because int32 can lead to overflow when this variable is multiplied by large numbers
    // (e.g., kv_block_stride).
    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);

    // Load a key to registers.
    // Each thread in a thread group has a different part of the key.
    // For example, if the thread group size is 4, then the first thread in the group
    // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
    // vectors of the key, and so on.
    for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
      const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
      const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
      K_vec k_vecs[NUM_VECS_PER_THREAD];

#pragma unroll
      for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
        const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride +
                                kv_head_idx * kv_head_stride + physical_block_offset * x;
        const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
        const int offset1 = (vec_idx * VEC_SIZE) / x;
        const int offset2 = (vec_idx * VEC_SIZE) % x;
        k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
      }

      // Compute dot product.
      // This includes a reduction across the threads in the same thread group.
      float qk =
          scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
      // Add the ALiBi bias if slopes are given.
      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;

      if (thread_group_offset == 0) {
        // Store the partial reductions to shared memory.
        // NOTE(woosuk): It is required to zero out the masked logits.
        const bool mask = token_idx >= context_len;
        logits[token_idx - start_token_idx] = mask ? 0.f : qk;
        // Update the max value.
        qk_max = mask ? qk_max : fmaxf(qk_max, qk);
      }
    }
  }

  // Perform reduction across the threads in the same warp to get the
  // max qk value for each "warp" (not across the thread block yet).
  // The 0-th thread of each thread group already has its max qk value.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
    qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = qk_max;
  }
  __syncthreads();

  // TODO(woosuk): Refactor this part.
  // Get the max qk value for the sequence.
  qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
  }
  // Broadcast the max qk value to all threads.
  qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);

  // Get the sum of the exp values.
  float exp_sum = 0.f;
  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    float val = __expf(logits[i] - qk_max);
    logits[i] = val;
    exp_sum += val;
  }
  exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);

  // Compute softmax.
  const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    logits[i] *= inv_sum;
  }
  __syncthreads();

  // If partitioning is enabled, store the max logit and exp_sum.
  // These per-partition statistics are consumed by the v2 reduce kernel.
  if (USE_PARTITIONING && thread_idx == 0) {
    float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions +
                            head_idx * max_num_partitions + partition_idx;
    *max_logits_ptr = qk_max;
    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
                          head_idx * max_num_partitions + partition_idx;
    *exp_sums_ptr = exp_sum;
  }

  // Each thread will fetch 16 bytes from the value cache at a time.
  constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
  using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
  using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
  using Float_L_vec = typename FloatVec<L_vec>::Type;

  constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
  constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
  constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);

  // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
  float accs[NUM_ROWS_PER_THREAD];
#pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    accs[i] = 0.f;
  }

  scalar_t zero_value;
  zero(zero_value);
  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
       block_idx += NUM_WARPS) {
    // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
    // because int32 can lead to overflow when this variable is multiplied by large numbers
    // (e.g., kv_block_stride).
    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
    L_vec logits_vec;
    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));

    const scalar_t* v_ptr =
        v_cache + physical_block_number * kv_block_stride + kv_head_idx * kv_head_stride;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE) {
        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
        V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
        if (block_idx == num_context_blocks - 1) {
          // NOTE(woosuk): When v_vec contains the tokens that are out of the context,
          // we should explicitly zero out the values since they may contain NaNs.
          // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
          scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
          for (int j = 0; j < V_VEC_SIZE; j++) {
            v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
          }
        }
        accs[i] += dot(logits_vec, v_vec);
      }
    }
  }

  // Perform reduction within each warp.
#pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    float acc = accs[i];
#pragma unroll
    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
      acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
    }
    accs[i] = acc;
  }

  // NOTE(woosuk): A barrier is required because the shared memory space for logits
  // is reused for the output.
  __syncthreads();

  // Perform reduction across warps (tree reduction halving the active warps
  // each iteration).
  float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
  for (int i = NUM_WARPS; i > 1; i /= 2) {
    int mid = i / 2;
    // Upper warps write to shared memory.
    if (warp_idx >= mid && warp_idx < i) {
      float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
      // NOTE(review): the inner loop variable shadows the outer reduction
      // counter `i`; intentional in the upstream code, but easy to misread.
#pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          dst[row_idx] = accs[i];
        }
      }
    }
    __syncthreads();

    // Lower warps update the output.
    if (warp_idx < mid) {
      const float* src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          accs[i] += src[row_idx];
        }
      }
    }
    __syncthreads();
  }

  // Write the final output.
  if (warp_idx == 0) {
    scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
                        head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
        from_float(*(out_ptr + row_idx), accs[i]);
      }
    }
  }
}
+
// Grid: (num_heads, num_seqs, 1).
// Single-pass entry point: instantiates paged_attention_kernel with the
// default PARTITION_SIZE of 0 (no partitioning), so the exp_sums/max_logits
// workspaces are not needed and are passed as nullptr.
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS>
__global__ void paged_attention_v1_kernel(
    scalar_t* __restrict__ out,              // [num_seqs, num_heads, head_size]
    const scalar_t* __restrict__ q,          // [num_seqs, num_heads, head_size]
    const scalar_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
    const scalar_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
    const int num_kv_heads, const float scale,
    const int* __restrict__ block_tables,    // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ context_lens,    // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride) {
  paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
      /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, v_cache, num_kv_heads,
      scale, block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
      kv_block_stride, kv_head_stride);
}
+
// Grid: (num_heads, num_seqs, max_num_partitions).
// Partitioned entry point: each thread block handles one context partition of
// PARTITION_SIZE tokens and writes its partial output to tmp_out together
// with the per-partition softmax statistics (exp_sums, max_logits); the
// partitions are combined afterwards by paged_attention_v2_reduce_kernel.
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS,
          int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel(
    float* __restrict__ exp_sums,            // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits,          // [num_seqs, num_heads, max_num_partitions]
    scalar_t* __restrict__ tmp_out,          // [num_seqs, num_heads, max_num_partitions, head_size]
    const scalar_t* __restrict__ q,          // [num_seqs, num_heads, head_size]
    const scalar_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
    const scalar_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
    const int num_kv_heads, const float scale,
    const int* __restrict__ block_tables,    // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ context_lens,    // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride) {
  paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>(
      exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, block_tables,
      context_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
      kv_head_stride);
}
+
+// Grid: (num_heads, num_seqs).
+// Second pass of v2 paged attention: combines the per-partition partial
+// outputs in tmp_out into the final attention output using the standard
+// log-sum-exp merge (global max logit, rescaled exp sums, weighted average).
+// Dynamic shared memory must hold 2 * num_partitions floats.
+template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
+ int PARTITION_SIZE>
+__global__ void paged_attention_v2_reduce_kernel(
+ scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
+ const float* __restrict__ exp_sums, // [num_seqs, num_heads,
max_num_partitions]
+ const float* __restrict__ max_logits, // [num_seqs, num_heads,
max_num_partitions]
+ const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
max_num_partitions, head_size]
+ const int* __restrict__ context_lens, // [num_seqs]
+ const int max_num_partitions) {
+ const int num_heads = gridDim.x;
+ const int head_idx = blockIdx.x;
+ const int seq_idx = blockIdx.y;
+ const int context_len = context_lens[seq_idx];
+ const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
+ if (num_partitions == 1) {
+ // No need to reduce. Only copy tmp_out to out.
+ scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx *
HEAD_SIZE;
+ const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads *
max_num_partitions * HEAD_SIZE +
+ head_idx * max_num_partitions * HEAD_SIZE;
+ for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
+ out_ptr[i] = tmp_out_ptr[i];
+ }
+ // Terminate the thread block.
+ return;
+ }
+
+ constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+ const int warp_idx = threadIdx.x / WARP_SIZE;
+ const int lane = threadIdx.x % WARP_SIZE;
+
+ // Size: 2 * num_partitions.
+ extern __shared__ char shared_mem[];
+ // Workspace for reduction.
+ __shared__ float red_smem[2 * NUM_WARPS];
+
+ // Load max logits to shared memory.
+ float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
+ const float* max_logits_ptr =
+ max_logits + seq_idx * num_heads * max_num_partitions + head_idx *
max_num_partitions;
+ float max_logit = -FLT_MAX;
+ for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
+ const float l = max_logits_ptr[i];
+ shared_max_logits[i] = l;
+ max_logit = fmaxf(max_logit, l);
+ }
+ __syncthreads();
+
+ // Get the global max logit.
+ // Reduce within the warp.
+#pragma unroll
+ for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+ max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit,
mask));
+ }
+ if (lane == 0) {
+ red_smem[warp_idx] = max_logit;
+ }
+ __syncthreads();
+ // Reduce across warps.
+ max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
+#pragma unroll
+ for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+ max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit,
mask));
+ }
+ // Broadcast the max value to all threads.
+ max_logit = __shfl_sync(uint32_t(-1), max_logit, 0);
+
+ // Load rescaled exp sums to shared memory.
+ // Each partition's exp sum is rescaled to the global max logit so the
+ // per-partition softmax results can be merged exactly.
+ float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float)
* num_partitions);
+ const float* exp_sums_ptr =
+ exp_sums + seq_idx * num_heads * max_num_partitions + head_idx *
max_num_partitions;
+ float global_exp_sum = 0.0f;
+ for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
+ float l = shared_max_logits[i];
+ float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
+ global_exp_sum += rescaled_exp_sum;
+ shared_exp_sums[i] = rescaled_exp_sum;
+ }
+ __syncthreads();
+ global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
+ // 1e-6 guards against division by zero when all exp sums are zero.
+ const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
+
+ // Aggregate tmp_out to out.
+ const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads *
max_num_partitions * HEAD_SIZE +
+ head_idx * max_num_partitions * HEAD_SIZE;
+ scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx *
HEAD_SIZE;
+#pragma unroll
+ for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
+ float acc = 0.0f;
+ for (int j = 0; j < num_partitions; ++j) {
+ acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
inv_global_exp_sum;
+ }
+ from_float(out_ptr[i], acc);
+ }
+}
+
+} // namespace vllm
+
+namespace tvm {
+namespace runtime {
+
+// Launch helper for the v1 kernel. Raises the kernel's dynamic shared memory
+// limit to shared_mem_size (may exceed the 48 KB default for long contexts)
+// and then launches it. Expands against the local variables of
+// paged_attention_v1_launcher (grid, block, stream, *_ptr, strides).
+#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)
\
+ cudaFuncSetAttribute(vllm::paged_attention_v1_kernel<T, HEAD_SIZE,
BLOCK_SIZE, NUM_THREADS>, \
+ cudaFuncAttributeMaxDynamicSharedMemorySize,
shared_mem_size); \
+ vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>
\
+ <<<grid, block, shared_mem_size, stream>>>(
\
+ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads,
scale, \
+ block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq,
alibi_slopes_ptr, q_stride, \
+ kv_block_stride, kv_head_stride);
+
+// TODO(woosuk): Tune NUM_THREADS.
+// Host-side launcher for single-pass paged attention. Derives strides from
+// the DLTensor shapes, sizes the dynamic shared memory, and dispatches on the
+// (compile-time) head size. Launches on the default CUDA stream.
+template <typename T, int BLOCK_SIZE, int NUM_THREADS = 128>
+void paged_attention_v1_launcher(DLTensor* out, const DLTensor* query, const
DLTensor* key_cache,
+ const DLTensor* value_cache, float scale,
+ const DLTensor* block_tables, const DLTensor*
context_lens,
+ int max_context_len) {
+ int num_seqs = query->shape[0];
+ int num_heads = query->shape[1];
+ int head_size = query->shape[2];
+ int max_num_blocks_per_seq = block_tables->shape[1];
+ int q_stride = query->shape[1] * query->shape[2];
+
+ // Key cache layout: [num_blocks, num_kv_heads, head_size/x, block_size, x].
+ int num_kv_heads = key_cache->shape[1];
+ int kv_head_stride = key_cache->shape[2] * key_cache->shape[3] *
key_cache->shape[4];
+ int kv_block_stride = kv_head_stride * key_cache->shape[1];
+
+ // int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
+ // assert(head_size % thread_group_size == 0);
+ // NOTE: alibi_slopes is optional.
+ const float* alibi_slopes_ptr = nullptr;
+
+ T* out_ptr = static_cast<T*>(out->data);
+ T* query_ptr = static_cast<T*>(query->data);
+ T* key_cache_ptr = static_cast<T*>(key_cache->data);
+ T* value_cache_ptr = static_cast<T*>(value_cache->data);
+ int* block_tables_ptr = static_cast<int*>(block_tables->data);
+ int* context_lens_ptr = static_cast<int*>(context_lens->data);
+
+ constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+ // Shared memory is the larger of the logits buffer (one float per context
+ // token, padded to a BLOCK_SIZE multiple) and the output-reduction buffer.
+ int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) *
BLOCK_SIZE;
+ int logits_size = padded_max_context_len * sizeof(float);
+ int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+ // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
+ // Keep that in sync with the logic here!
+ int shared_mem_size = std::max(logits_size, outputs_size);
+
+ dim3 grid(num_heads, num_seqs, 1);
+ dim3 block(NUM_THREADS);
+ const cudaStream_t stream = nullptr;
+ switch (head_size) {
+ // NOTE(woosuk): To reduce the compilation time, we only compile for the
+ // head sizes that we use in the model. However, we can easily extend this
+ // to support any head size which is a multiple of 16.
+ case 64:
+ LAUNCH_PAGED_ATTENTION_V1(64);
+ break;
+ case 80:
+ LAUNCH_PAGED_ATTENTION_V1(80);
+ break;
+ case 96:
+ LAUNCH_PAGED_ATTENTION_V1(96);
+ break;
+ case 112:
+ LAUNCH_PAGED_ATTENTION_V1(112);
+ break;
+ case 128:
+ LAUNCH_PAGED_ATTENTION_V1(128);
+ break;
+ case 256:
+ LAUNCH_PAGED_ATTENTION_V1(256);
+ break;
+ default:
+ // NOTE(review): unsupported head sizes fall through silently here;
+ // consider a LOG(FATAL) like the block-size dispatch below.
+ // TORCH_CHECK(false, "Unsupported head size: ", head_size);
+ break;
+ }
+}
+
+// Forwards the v1 entry point's arguments to paged_attention_v1_launcher
+// with a compile-time block size.
+#define CALL_V1_LAUNCHER(T, BLOCK_SIZE)
\
+ paged_attention_v1_launcher<T, BLOCK_SIZE>(out, query, key_cache,
value_cache, scale, \
+ block_tables, context_lens,
max_context_len);
+
+// Launch helper for the v2 path: runs the partitioned attention kernel and
+// then the reduce kernel that merges the partitions. Expands against the
+// local variables of paged_attention_v2_launcher.
+#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)
\
+ vllm::paged_attention_v2_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
PARTITION_SIZE> \
+ <<<grid, block, shared_mem_size, stream>>>(
\
+ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr,
value_cache_ptr, \
+ num_kv_heads, scale, block_tables_ptr, context_lens_ptr,
max_num_blocks_per_seq, \
+ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride);
\
+ vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS,
PARTITION_SIZE> \
+ <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(
\
+ out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr,
context_lens_ptr, \
+ max_num_partitions);
+
+// Host-side launcher for two-pass paged attention. Sizes the grids and shared
+// memory for both the partitioned kernel and the reduce kernel, then
+// dispatches on the (compile-time) head size. Launches on the default stream.
+template <typename T, int BLOCK_SIZE, int NUM_THREADS = 128, int
PARTITION_SIZE = 512>
+void paged_attention_v2_launcher(
+
+ DLTensor* out, DLTensor* exp_sums, DLTensor* max_logits, DLTensor* tmp_out,
+ const DLTensor* query, const DLTensor* key_cache, const DLTensor*
value_cache, float scale,
+ const DLTensor* block_tables, const DLTensor* context_lens, int
max_context_len) {
+ int num_seqs = query->shape[0];
+ int num_heads = query->shape[1];
+ int head_size = query->shape[2];
+ int max_num_blocks_per_seq = block_tables->shape[1];
+ int q_stride = query->shape[1] * query->shape[2];
+
+ // Key cache layout: [num_blocks, num_kv_heads, head_size/x, block_size, x].
+ int num_kv_heads = key_cache->shape[1];
+ int kv_head_stride = key_cache->shape[2] * key_cache->shape[3] *
key_cache->shape[4];
+ int kv_block_stride = kv_head_stride * key_cache->shape[1];
+
+ // int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
+ // assert(head_size % thread_group_size == 0);
+ // NOTE: alibi_slopes is optional.
+ const float* alibi_slopes_ptr = nullptr;
+
+ T* out_ptr = static_cast<T*>(out->data);
+ float* exp_sums_ptr = static_cast<float*>(exp_sums->data);
+ float* max_logits_ptr = static_cast<float*>(max_logits->data);
+ T* tmp_out_ptr = static_cast<T*>(tmp_out->data);
+ T* query_ptr = static_cast<T*>(query->data);
+ T* key_cache_ptr = static_cast<T*>(key_cache->data);
+ T* value_cache_ptr = static_cast<T*>(value_cache->data);
+ int* block_tables_ptr = static_cast<int*>(block_tables->data);
+ int* context_lens_ptr = static_cast<int*>(context_lens->data);
+
+ constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+ int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
+ // Per-partition logits bound the shared memory here, unlike v1 where the
+ // whole (padded) context length does.
+ int logits_size = PARTITION_SIZE * sizeof(float);
+ int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+
+ // For paged attention v2 kernel.
+ dim3 grid(num_heads, num_seqs, max_num_partitions);
+ int shared_mem_size = std::max(logits_size, outputs_size);
+ // For paged attention v2 reduce kernel.
+ dim3 reduce_grid(num_heads, num_seqs);
+ int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
+
+ dim3 block(NUM_THREADS);
+ const cudaStream_t stream = nullptr;
+ switch (head_size) {
+ // NOTE(woosuk): To reduce the compilation time, we only compile for the
+ // head sizes that we use in the model. However, we can easily extend this
+ // to support any head size which is a multiple of 16.
+ case 64:
+ LAUNCH_PAGED_ATTENTION_V2(64);
+ break;
+ case 80:
+ LAUNCH_PAGED_ATTENTION_V2(80);
+ break;
+ case 96:
+ LAUNCH_PAGED_ATTENTION_V2(96);
+ break;
+ case 112:
+ LAUNCH_PAGED_ATTENTION_V2(112);
+ break;
+ case 128:
+ LAUNCH_PAGED_ATTENTION_V2(128);
+ break;
+ case 256:
+ LAUNCH_PAGED_ATTENTION_V2(256);
+ break;
+ default:
+ // NOTE(review): unsupported head sizes fall through silently here;
+ // consider a LOG(FATAL) like the block-size dispatch below.
+ // TORCH_CHECK(false, "Unsupported head size: ", head_size);
+ break;
+ }
+}
+
+// Forwards the v2 entry point's arguments to paged_attention_v2_launcher
+// with a compile-time block size.
+#define CALL_V2_LAUNCHER(T, BLOCK_SIZE)
\
+ paged_attention_v2_launcher<T, BLOCK_SIZE>(out, exp_sums, max_logits,
tmp_out, query, key_cache, \
+ value_cache, scale, block_tables,
context_lens, \
+ max_context_len);
+
+// Decode-time (single query token per sequence) attention over the paged KV
+// cache, single-pass variant. Dispatches on the runtime block size.
+// Scale is the usual 1/sqrt(head_size) attention scaling.
+void single_query_cached_kv_attention_v1(
+ const DLTensor* query, const DLTensor* key_cache, const DLTensor*
value_cache,
+ const DLTensor* block_tables, const DLTensor* context_lens, int block_size,
+ const DLTensor* max_context_len_tensor, // TODO(masahi): pass integer
+ DLTensor* out) {
+ float scale = 1.0 / sqrt(query->shape[2]);
+ // NOTE(review): reads max_context_len_tensor->data directly on the host —
+ // assumes the tensor lives in host-accessible memory; confirm at call sites.
+ int max_context_len = static_cast<int*>(max_context_len_tensor->data)[0];
+
+ using T = uint16_t; // for half precision
+ if (block_size == 8) {
+ CALL_V1_LAUNCHER(T, 8);
+ } else if (block_size == 16) {
+ CALL_V1_LAUNCHER(T, 16);
+ } else if (block_size == 32) {
+ CALL_V1_LAUNCHER(T, 32);
+ } else {
+ LOG(FATAL) << "Unsupported block size: " << block_size;
+ }
+}
+
+// Decode-time attention over the paged KV cache, two-pass (partitioned)
+// variant. exp_sums / max_logits / tmp_out are caller-provided workspaces
+// sized for [num_seqs, num_heads, max_num_partitions(, head_size)].
+void single_query_cached_kv_attention_v2(
+ const DLTensor* query, const DLTensor* key_cache, const DLTensor*
value_cache,
+ const DLTensor* block_tables, const DLTensor* context_lens, int block_size,
+ const DLTensor* max_context_len_tensor, // TODO(masahi): pass integer
+ DLTensor* exp_sums, DLTensor* max_logits, DLTensor* tmp_out, DLTensor*
out) {
+ float scale = 1.0 / sqrt(query->shape[2]);
+ // NOTE(review): assumes max_context_len_tensor is host-accessible; see v1.
+ int max_context_len = static_cast<int*>(max_context_len_tensor->data)[0];
+
+ using T = uint16_t; // for half precision
+ if (block_size == 8) {
+ CALL_V2_LAUNCHER(T, 8);
+ } else if (block_size == 16) {
+ CALL_V2_LAUNCHER(T, 16);
+ } else if (block_size == 32) {
+ CALL_V2_LAUNCHER(T, 32);
+ } else {
+ LOG(FATAL) << "Unsupported block size: " << block_size;
+ }
+}
+
+// Main entry point: picks the v1 or v2 kernel per call using the vLLM
+// heuristic below, then forwards to the corresponding function.
+TVM_REGISTER_GLOBAL("tvm.contrib.vllm.single_query_cached_kv_attention")
+ .set_body_typed([](const DLTensor* query, const DLTensor* key_cache,
+ const DLTensor* value_cache, const DLTensor*
block_tables,
+ const DLTensor* context_lens, int block_size,
+ const DLTensor* max_context_len_tensor, //
TODO(masahi): pass integer
+ DLTensor* exp_sums, DLTensor* max_logits, DLTensor*
tmp_out, DLTensor* out) {
+ int num_seqs = query->shape[0];
+ int num_heads = query->shape[1];
+ int max_context_len = static_cast<int*>(max_context_len_tensor->data)[0];
+ // Must match the PARTITION_SIZE template default of
+ // paged_attention_v2_launcher.
+ const int PARTITION_SIZE = 512;
+ int max_num_partitions = DIVIDE_ROUND_UP(max_context_len,
PARTITION_SIZE);
+ // Heuristic: v1 is preferred for short contexts, or when there would be
+ // only one partition / enough (seq, head) pairs to saturate the GPU
+ // without partitioning.
+ bool use_v1 =
+ max_context_len <= 8192 && (max_num_partitions == 1 || num_seqs *
num_heads > 512);
+ if (use_v1) {
+ single_query_cached_kv_attention_v1(query, key_cache, value_cache,
block_tables,
+ context_lens, block_size,
max_context_len_tensor, out);
+ } else {
+ single_query_cached_kv_attention_v2(query, key_cache, value_cache,
block_tables,
+ context_lens, block_size,
max_context_len_tensor,
+ exp_sums, max_logits, tmp_out,
out);
+ }
+ });
+
+// Expose for testing
+// These bypass the v1/v2 dispatch heuristic above so each code path can be
+// exercised directly from Python tests.
+TVM_REGISTER_GLOBAL("tvm.contrib.vllm.single_query_cached_kv_attention_v1")
+ .set_body_typed(single_query_cached_kv_attention_v1);
+
+TVM_REGISTER_GLOBAL("tvm.contrib.vllm.single_query_cached_kv_attention_v2")
+ .set_body_typed(single_query_cached_kv_attention_v2);
+
+} // namespace runtime
+} // namespace tvm
+
+#undef WARP_SIZE
+#undef MAX
+#undef MIN
+#undef DIVIDE_ROUND_UP
diff --git a/src/runtime/contrib/vllm/attention_utils.cuh
b/src/runtime/contrib/vllm/attention_utils.cuh
new file mode 100644
index 0000000000..85301e89aa
--- /dev/null
+++ b/src/runtime/contrib/vllm/attention_utils.cuh
@@ -0,0 +1,55 @@
+/*
+ * Adapted from
https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "dtype_float16.h"
+
+#include <float.h>
+#include <type_traits>
+
+namespace vllm {
+
+// Q*K^T operation.
+template<int THREAD_GROUP_SIZE, typename Vec, int N>
+inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
+ using A_vec = typename FloatVec<Vec>::Type;
+ // Compute the parallel products for Q*K^T (treat vector lanes separately).
+ A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
+#pragma unroll
+ for (int ii = 1; ii < N; ++ii) {
+ qk_vec = fma(q[ii], k[ii], qk_vec);
+ }
+
+ // Finalize the reduction across lanes.
+ float qk = sum(qk_vec);
+#pragma unroll
+ for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
+ qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
+ }
+ return qk;
+}
+
+// Dispatch wrapper so callers can write Qk_dot<T, GROUP>::dot(q, k); the
+// element type T allows specializations with type-specific fast paths.
+template<typename T, int THREAD_GROUP_SIZE>
+struct Qk_dot {
+ template<typename Vec, int N>
+ static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
+ return qk_dot_<THREAD_GROUP_SIZE>(q, k);
+ }
+};
+
+} // namespace vllm
diff --git a/src/runtime/contrib/vllm/cache_alloc.cc
b/src/runtime/contrib/vllm/cache_alloc.cc
new file mode 100644
index 0000000000..aea50aa47a
--- /dev/null
+++ b/src/runtime/contrib/vllm/cache_alloc.cc
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <cuda_runtime.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/registry.h>
+
+namespace tvm {
+namespace runtime {
+namespace vllm {
+
+// Allocates the paged KV cache for all layers on the current CUDA device.
+// Returns a flat array [key_0, value_0, key_1, value_1, ...] of FP16 NDArrays.
+// Key layout [num_blocks, num_heads, head_size/x, block_size, x] where
+// x = vec_size = 16 bytes / sizeof(fp16) = 8, matching the "x" dimension the
+// attention and cache kernels index by. Value layout is
+// [num_blocks, num_heads, head_size, block_size].
+Array<NDArray> AllocateKVCache(int head_size, int num_layers, int num_heads,
int block_size,
+ int num_blocks) {
+ Array<NDArray> cache;
+ int element_size = 2; // sizeof(fp16) in bytes
+ int vec_size = 16 / element_size; // elements per 16-byte vector load
+
+ int device_id;
+ cudaGetDevice(&device_id);
+
+ DLDevice dev{DLDeviceType::kDLCUDA, device_id};
+
+ for (int i = 0; i < num_layers; ++i) {
+ NDArray key_blocks =
+ NDArray::Empty({num_blocks, num_heads, head_size / vec_size,
block_size, vec_size},
+ runtime::DataType::Float(16), dev);
+ NDArray value_blocks = NDArray::Empty({num_blocks, num_heads, head_size,
block_size},
+ runtime::DataType::Float(16), dev);
+ cache.push_back(key_blocks);
+ cache.push_back(value_blocks);
+ }
+
+ return cache;
+}
+
+TVM_REGISTER_GLOBAL("tvm.contrib.vllm.allocate_kv_cache").set_body_typed(AllocateKVCache);
+
+} // namespace vllm
+} // namespace runtime
+} // namespace tvm
diff --git a/src/runtime/contrib/vllm/cache_kernels.cu
b/src/runtime/contrib/vllm/cache_kernels.cu
new file mode 100644
index 0000000000..537ff31fd0
--- /dev/null
+++ b/src/runtime/contrib/vllm/cache_kernels.cu
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include <algorithm>
+#include <cassert>
+#include <map>
+#include <vector>
+
+namespace vllm {
+
+// Scatters each token's K/V vectors from contiguous [num_tokens, num_heads,
+// head_size] tensors into the paged cache. One thread block per token;
+// slot_mapping gives the token's flat slot, decomposed into a cache block
+// index and an offset within that block.
+template <typename scalar_t>
+__global__ void reshape_and_cache_kernel(
+ const scalar_t* __restrict__ key, // [num_tokens, num_heads,
head_size]
+ const scalar_t* __restrict__ value, // [num_tokens, num_heads,
head_size]
+ scalar_t* __restrict__ key_cache, // [num_blocks, num_heads,
head_size/x, block_size, x]
+ scalar_t* __restrict__ value_cache, // [num_blocks, num_heads,
head_size, block_size]
+ const int* __restrict__ slot_mapping, // [num_tokens]
+ const int key_stride, const int value_stride, const int num_heads, const
int head_size,
+ const int block_size, const int x) {
+ const int token_idx = blockIdx.x;
+ const int slot_idx = slot_mapping[token_idx];
+ const int block_idx = slot_idx / block_size;
+ const int block_offset = slot_idx % block_size;
+
+ const int n = num_heads * head_size;
+ for (int i = threadIdx.x; i < n; i += blockDim.x) {
+ const int src_key_idx = token_idx * key_stride + i;
+ const int src_value_idx = token_idx * value_stride + i;
+
+ const int head_idx = i / head_size;
+ const int head_offset = i % head_size;
+ const int x_idx = head_offset / x;
+ const int x_offset = head_offset % x;
+
+ // Destination offsets follow the key cache's 5-D layout (split by x)
+ // and the value cache's 4-D layout respectively.
+ const int tgt_key_idx = block_idx * num_heads * (head_size / x) *
block_size * x +
+ head_idx * (head_size / x) * block_size * x +
x_idx * block_size * x +
+ block_offset * x + x_offset;
+ const int tgt_value_idx = block_idx * num_heads * head_size * block_size +
+ head_idx * head_size * block_size + head_offset
* block_size +
+ block_offset;
+ key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]);
+ value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]);
+ }
+}
+
+// Inverse of reshape_and_cache_kernel: gathers each token's K/V vectors from
+// the paged cache back into contiguous [num_tokens, num_heads, head_size]
+// tensors. One thread block per token; slot_mapping gives the token's flat
+// slot, decomposed into a cache block index and an offset within that block.
+template <typename scalar_t>
+__global__ void reconstruct_from_cache_kernel(
+    const scalar_t* __restrict__ key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
+    const scalar_t* __restrict__ value_cache,  // [num_blocks, num_heads, head_size, block_size]
+    const int* __restrict__ slot_mapping,      // [num_tokens]
+    scalar_t* __restrict__ key,                // [num_tokens, num_heads, head_size]
+    scalar_t* __restrict__ value,              // [num_tokens, num_heads, head_size]
+    const int key_stride, const int value_stride, const int num_heads, const int head_size,
+    const int block_size, const int x) {
+  const int token_idx = blockIdx.x;
+  const int slot_idx = slot_mapping[token_idx];
+  const int block_idx = slot_idx / block_size;
+  const int block_offset = slot_idx % block_size;
+
+  const int n = num_heads * head_size;
+  for (int i = threadIdx.x; i < n; i += blockDim.x) {
+    // Destination offsets in the contiguous, token-major output tensors.
+    const int tgt_key_idx = token_idx * key_stride + i;
+    const int tgt_value_idx = token_idx * value_stride + i;
+
+    const int head_idx = i / head_size;
+    const int head_offset = i % head_size;
+    const int x_idx = head_offset / x;
+    const int x_offset = head_offset % x;
+
+    // Source offsets inside the paged cache, mirroring the scatter indices
+    // used by reshape_and_cache_kernel.
+    const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x +
+                            head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
+                            block_offset * x + x_offset;
+    const int src_value_idx = block_idx * num_heads * head_size * block_size +
+                              head_idx * head_size * block_size + head_offset * block_size +
+                              block_offset;
+
+    key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]);
+    // Bug fix: destination and source indices were swapped here
+    // (value[src_value_idx] = value_cache[tgt_value_idx]), which wrote a
+    // cache-space offset into the token-major tensor (out of bounds) and
+    // read from the wrong cache location. Mirror the key line above.
+    value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]);
+  }
+}
+
+// Grid: (num_layers, num_pairs)
+// Copies whole cache blocks (src -> dst within the same cache tensor) for
+// every layer, for both the key and value caches of that layer.
+// block_mapping is a flat list of (src_block, dst_block) pairs; the cache
+// pointers are passed as int64_t device arrays, one entry per layer.
+template <typename scalar_t>
+__global__ void copy_blocks_kernel(int64_t* key_cache_ptrs, int64_t*
value_cache_ptrs,
+ const int64_t* __restrict__ block_mapping,
+ const int numel_per_block) {
+ const int layer_idx = blockIdx.x;
+ const int pair_idx = blockIdx.y;
+
+ scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
+ scalar_t* value_cache =
reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
+ int64_t src_block_number = block_mapping[2 * pair_idx];
+ int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
+
+ const int64_t src_block_offset = src_block_number * numel_per_block;
+ const int64_t dst_block_offset = dst_block_number * numel_per_block;
+ for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
+ int64_t src_offset = src_block_offset + i;
+ int64_t dst_offset = dst_block_offset + i;
+ key_cache[dst_offset] = key_cache[src_offset];
+ }
+ for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
+ int64_t src_offset = src_block_offset + i;
+ int64_t dst_offset = dst_block_offset + i;
+ value_cache[dst_offset] = value_cache[src_offset];
+ }
+}
+
+} // namespace vllm
+
+namespace tvm {
+namespace runtime {
+
+// Host wrapper: writes one batch of token K/V tensors into the paged caches
+// at the positions given by slot_mapping, then returns the (mutated) caches.
+TVM_REGISTER_GLOBAL("tvm.contrib.vllm.reshape_and_cache")
+ .set_body_typed([](NDArray key, NDArray value, NDArray key_cache, NDArray
value_cache,
+ NDArray slot_mapping) {
+ int num_tokens = key->shape[0];
+ int num_heads = key->shape[1];
+ int head_size = key->shape[2];
+ int block_size = key_cache->shape[3];
+ int vec_size = key_cache->shape[4]; // the "x" split of the key layout
+
+ int key_stride = key->shape[1] * key->shape[2];
+ int value_stride = value->shape[1] * value->shape[2];
+
+ // One block per token; cap threads at 512.
+ dim3 grid(num_tokens);
+ dim3 block(std::min(num_heads * head_size, 512));
+
+ using scalar_t = uint16_t; // fp16 payload, moved as opaque 16-bit words
+ vllm::reshape_and_cache_kernel<scalar_t><<<grid, block>>>(
+ static_cast<const scalar_t*>(key->data), static_cast<const
scalar_t*>(value->data),
+ static_cast<scalar_t*>(key_cache->data),
static_cast<scalar_t*>(value_cache->data),
+ static_cast<const int*>(slot_mapping->data), key_stride,
value_stride, num_heads,
+ head_size, block_size, vec_size);
+
+ return Array{key_cache, value_cache};
+ });
+
+// Host wrapper: gathers the K/V entries addressed by slot_mapping out of the
+// paged caches into freshly allocated contiguous tensors and returns them.
+TVM_REGISTER_GLOBAL("tvm.contrib.vllm.reconstruct_from_cache")
+ .set_body_typed([](NDArray key_cache, NDArray value_cache, NDArray
slot_mapping) {
+ int num_tokens = slot_mapping->shape[0];
+ int num_heads = value_cache->shape[1];
+ int head_size = value_cache->shape[2];
+ int block_size = value_cache->shape[3];
+ int vec_size = key_cache->shape[4]; // the "x" split of the key layout
+
+ DLDevice dev = key_cache->device;
+ auto key = NDArray::Empty({num_tokens, num_heads, head_size},
key_cache->dtype, dev);
+ auto value = NDArray::Empty({num_tokens, num_heads, head_size},
key_cache->dtype, dev);
+
+ int key_stride = key->shape[1] * key->shape[2];
+ int value_stride = value->shape[1] * value->shape[2];
+
+ // One block per token; cap threads at 512.
+ dim3 grid(num_tokens);
+ dim3 block(std::min(num_heads * head_size, 512));
+
+ using scalar_t = uint16_t; // fp16 payload, moved as opaque 16-bit words
+ vllm::reconstruct_from_cache_kernel<scalar_t>
+ <<<grid, block>>>(static_cast<const scalar_t*>(key_cache->data),
+ static_cast<const scalar_t*>(value_cache->data),
+ static_cast<const int*>(slot_mapping->data),
+ static_cast<scalar_t*>(key->data),
static_cast<scalar_t*>(value->data),
+ key_stride, value_stride, num_heads, head_size,
block_size, vec_size);
+
+ return Array{key, value};
+ });
+
+// Host wrapper: copies cache blocks (e.g. for sequence forking) according to
+// block_mapping, a host tensor of flattened (src, dst) block-number pairs.
+// key_value_caches interleaves per-layer tensors as [key_0, value_0, ...].
+TVM_REGISTER_GLOBAL("tvm.contrib.vllm.copy_blocks")
+ .set_body_typed([](Array<NDArray> key_value_caches, NDArray block_mapping)
{
+ auto num_layers = key_value_caches.size() / 2;
+ auto num_pairs = block_mapping->shape[0] / 2;
+
+ if (num_layers == 0) {
+ return;
+ }
+
+ // Stage the per-layer device pointers so the kernel can index them.
+ std::vector<int64_t> key_cache_ptrs(num_layers);
+ std::vector<int64_t> value_cache_ptrs(num_layers);
+ for (size_t layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
+ key_cache_ptrs[layer_idx] =
+ reinterpret_cast<int64_t>(key_value_caches[2 * layer_idx]->data);
+ value_cache_ptrs[layer_idx] =
+ reinterpret_cast<int64_t>(key_value_caches[2 * layer_idx +
1]->data);
+ }
+
+ // NOTE(review): despite the name, index 1 is the layer-0 *value* cache;
+ // it is only used for its device and per-block element count, which is
+ // identical for the key layout (num_heads * head_size * block_size).
+ NDArray key_cache = key_value_caches[1]; // [num_blocks, num_heads,
head_size, block_size]
+ DLDevice dev = key_cache->device;
+
+ NDArray key_cache_ptrs_gpu =
+ NDArray::Empty({static_cast<int>(num_layers)},
runtime::DataType::Int(64), dev);
+ NDArray value_cache_ptrs_gpu =
+ NDArray::Empty({static_cast<int>(num_layers)},
runtime::DataType::Int(64), dev);
+ key_cache_ptrs_gpu.CopyFromBytes(key_cache_ptrs.data(),
+ sizeof(int64_t) *
key_cache_ptrs.size());
+ value_cache_ptrs_gpu.CopyFromBytes(value_cache_ptrs.data(),
+ sizeof(int64_t) *
value_cache_ptrs.size());
+
+ NDArray block_mapping_gpu =
+ NDArray::Empty(block_mapping.Shape(), runtime::DataType::Int(64),
dev);
+ block_mapping_gpu.CopyFromBytes(block_mapping->data,
+ sizeof(int64_t) *
block_mapping->shape[0]);
+
+ const int numel_per_block = key_cache->shape[1] * key_cache->shape[2] *
key_cache->shape[3];
+ dim3 grid(num_layers, num_pairs);
+ dim3 block(std::min(1024, numel_per_block));
+
+ using scalar_t = uint16_t; // fp16 payload, moved as opaque 16-bit words
+ vllm::copy_blocks_kernel<scalar_t>
+ <<<grid, block>>>(static_cast<int64_t*>(key_cache_ptrs_gpu->data),
+ static_cast<int64_t*>(value_cache_ptrs_gpu->data),
+ static_cast<int64_t*>(block_mapping_gpu->data),
numel_per_block);
+ });
+
+} // namespace runtime
+} // namespace tvm
diff --git a/src/runtime/contrib/vllm/dtype_float16.h
b/src/runtime/contrib/vllm/dtype_float16.h
new file mode 100644
index 0000000000..4f6f52bb50
--- /dev/null
+++ b/src/runtime/contrib/vllm/dtype_float16.h
@@ -0,0 +1,697 @@
+/*
+ * Adapted from
+ *
https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * and
+ *
https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <stdint.h>
+
+namespace vllm {
+
+// A vector type to store Q, K, V elements.
+template <typename T, int VEC_SIZE>
+struct Vec {};
+
+// A vector type to store FP32 accumulators.
+template <typename T>
+struct FloatVec {};
+
+// Template vector operations.
+// Primary templates are declared here; the per-type implementations below are
+// provided as explicit specializations.
+template <typename Acc, typename A, typename B>
+inline __device__ Acc mul(A a, B b);
+
+// Horizontal sum of all lanes of vector v, returned as fp32.
+template <typename T>
+inline __device__ float sum(T v);
+
+// Dot product with the product accumulated in the element type T itself.
+template <typename T>
+inline __device__ float dot(T a, T b) {
+ return sum(mul<T, T, T>(a, b));
+}
+
+// Dot product with the product accumulated in a wider type A. A is not
+// deducible from the arguments, so callers must spell it explicitly,
+// e.g. dot<float2>(a, b).
+template <typename A, typename T>
+inline __device__ float dot(T a, T b) {
+ return sum(mul<A, T, T>(a, b));
+}
+
+// Zero-initialize dst by writing 32-bit words through a union.
+// Assumes sizeof(T) is a non-zero multiple of 4 bytes, which holds for every
+// Vec/FloatVec type defined in this header.
+template <typename T>
+// NOLINTNEXTLINE(runtime/references)
+inline __device__ void zero(T& dst) {
+ constexpr int WORDS = sizeof(T) / 4;
+ union {
+ T raw;
+ uint32_t words[WORDS];
+ } tmp;
+
+// Fully unrolled: WORDS is a compile-time constant.
+#pragma unroll
+ for (int ii = 0; ii < WORDS; ++ii) {
+ tmp.words[ii] = 0u;
+ }
+ dst = tmp.raw;
+}
+
+// Define custom FP32 vector data types.
+// CUDA has no built-in 8-wide (or second 4-wide) float vector, so these are
+// built from float2 pairs. They serve as the FP32 accumulator types for the
+// packed-half vectors (see the FloatVec<uint2>/FloatVec<uint4> specializations
+// later in this file).
+struct Float4_ {
+ float2 x;
+ float2 y;
+};
+
+struct Float8_ {
+ float2 x;
+ float2 y;
+ float2 z;
+ float2 w;
+};
+
+// FP32 vector types for Q, K, V.
+template <>
+struct Vec<float, 1> {
+ using Type = float;
+};
+template <>
+struct Vec<float, 2> {
+ using Type = float2;
+};
+template <>
+struct Vec<float, 4> {
+ using Type = float4;
+};
+
+// FP32 accumulator vector types corresponding to Vec.
+template <>
+struct FloatVec<float> {
+ using Type = float;
+};
+template <>
+struct FloatVec<float2> {
+ using Type = float2;
+};
+template <>
+struct FloatVec<float4> {
+ using Type = float4;
+};
+
+// Vector addition.
+inline __device__ float add(float a, float b) { return a + b; }
+
+inline __device__ float2 add(float2 a, float2 b) {
+ float2 c;
+ c.x = add(a.x, b.x);
+ c.y = add(a.y, b.y);
+ return c;
+}
+
+inline __device__ float4 add(float4 a, float4 b) {
+ float4 c;
+ c.x = add(a.x, b.x);
+ c.y = add(a.y, b.y);
+ c.z = add(a.z, b.z);
+ c.w = add(a.w, b.w);
+ return c;
+}
+
+// Vector multiplication.
+template <>
+inline __device__ float mul<float, float>(float a, float b) {
+ return a * b;
+}
+
+template <>
+inline __device__ float2 mul(float2 a, float2 b) {
+ float2 c;
+ c.x = a.x * b.x;
+ c.y = a.y * b.y;
+ return c;
+}
+
+template <>
+inline __device__ float2 mul(float a, float2 b) {
+ float2 c;
+ c.x = a * b.x;
+ c.y = a * b.y;
+ return c;
+}
+
+template <>
+inline __device__ float4 mul(float4 a, float4 b) {
+ float4 c;
+ c.x = a.x * b.x;
+ c.y = a.y * b.y;
+ c.z = a.z * b.z;
+ c.w = a.w * b.w;
+ return c;
+}
+
+template <>
+inline __device__ float4 mul(float a, float4 b) {
+ float4 c;
+ c.x = a * b.x;
+ c.y = a * b.y;
+ c.z = a * b.z;
+ c.w = a * b.w;
+ return c;
+}
+
+// Vector fused multiply-add.
+inline __device__ float fma(float a, float b, float c) { return a * b + c; }
+
+inline __device__ float2 fma(float2 a, float2 b, float2 c) {
+ float2 d;
+ d.x = fma(a.x, b.x, c.x);
+ d.y = fma(a.y, b.y, c.y);
+ return d;
+}
+
+inline __device__ float2 fma(float a, float2 b, float2 c) {
+ float2 d;
+ d.x = fma(a, b.x, c.x);
+ d.y = fma(a, b.y, c.y);
+ return d;
+}
+
+inline __device__ float4 fma(float4 a, float4 b, float4 c) {
+ float4 d;
+ d.x = fma(a.x, b.x, c.x);
+ d.y = fma(a.y, b.y, c.y);
+ d.z = fma(a.z, b.z, c.z);
+ d.w = fma(a.w, b.w, c.w);
+ return d;
+}
+
+inline __device__ float4 fma(float a, float4 b, float4 c) {
+ float4 d;
+ d.x = fma(a, b.x, c.x);
+ d.y = fma(a, b.y, c.y);
+ d.z = fma(a, b.z, c.z);
+ d.w = fma(a, b.w, c.w);
+ return d;
+}
+
+inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
+ Float4_ d;
+ d.x = fma(a, b.x, c.x);
+ d.y = fma(a, b.y, c.y);
+ return d;
+}
+
+inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
+ Float8_ d;
+ d.x = fma(a, b.x, c.x);
+ d.y = fma(a, b.y, c.y);
+ d.z = fma(a, b.z, c.z);
+ d.w = fma(a, b.w, c.w);
+ return d;
+}
+
+// Vector sum.
+template <>
+inline __device__ float sum(float v) {
+ return v;
+}
+
+template <>
+inline __device__ float sum(float2 v) {
+ return v.x + v.y;
+}
+
+template <>
+inline __device__ float sum(float4 v) {
+ return v.x + v.y + v.z + v.w;
+}
+
+template <>
+inline __device__ float sum(Float4_ v) {
+ return v.x.x + v.x.y + v.y.x + v.y.y;
+}
+
+template <>
+inline __device__ float sum(Float8_ v) {
+ return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
+}
+
+// Vector dot product.
+inline __device__ float dot(float a, float b) { return a * b; }
+
+inline __device__ float dot(float2 a, float2 b) {
+ float2 c = mul<float2, float2, float2>(a, b);
+ return c.x + c.y;
+}
+
+inline __device__ float dot(Float4_ a, Float4_ b) {
+ float2 acc = mul<float2, float2, float2>(a.x, b.x);
+ acc = fma(a.y, b.y, acc);
+ return acc.x + acc.y;
+}
+
+inline __device__ float dot(Float8_ a, Float8_ b) {
+ float2 acc = mul<float2, float2, float2>(a.x, b.x);
+ acc = fma(a.y, b.y, acc);
+ acc = fma(a.z, b.z, acc);
+ acc = fma(a.w, b.w, acc);
+ return acc.x + acc.y;
+}
+
+// From float to float.
+// NOLINTNEXTLINE(runtime/references)
+inline __device__ void from_float(float& dst, float src) { dst = src; }
+
+// NOLINTNEXTLINE(runtime/references)
+inline __device__ void from_float(float2& dst, float2 src) { dst = src; }
+
+// NOLINTNEXTLINE(runtime/references)
+inline __device__ void from_float(float4& dst, float4 src) { dst = src; }
+
+// From float to float.
+inline __device__ float to_float(float u) { return u; }
+
+inline __device__ float2 to_float(float2 u) { return u; }
+
+inline __device__ float4 to_float(float4 u) { return u; }
+
+inline __device__ Float4_ to_float(Float4_ u) { return u; }
+
+inline __device__ Float8_ to_float(Float8_ u) { return u; }
+
+// FP16 vector types for Q, K, V.
+template <>
+struct Vec<uint16_t, 1> {
+ using Type = uint16_t;
+};
+template <>
+struct Vec<uint16_t, 2> {
+ using Type = uint32_t;
+};
+template <>
+struct Vec<uint16_t, 4> {
+ using Type = uint2;
+};
+template <>
+struct Vec<uint16_t, 8> {
+ using Type = uint4;
+};
+
+// FP32 accumulator vector types corresponding to Vec.
+template <>
+struct FloatVec<uint16_t> {
+ using Type = float;
+};
+template <>
+struct FloatVec<uint32_t> {
+ using Type = float2;
+};
+template <>
+struct FloatVec<uint2> {
+ using Type = Float4_;
+};
+template <>
+struct FloatVec<uint4> {
+ using Type = Float8_;
+};
+
+// Utility functions for type conversions.
+// Broadcast a single fp16 value (bit pattern carried in a uint16_t) into both
+// 16-bit lanes of a 32-bit register: the PTX mov.b32 with a {%1, %1} vector
+// source packs two copies of `a` into `b`.
+inline __device__ uint32_t h0_h0(uint16_t a) {
+ uint32_t b;
+ asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
+ return b;
+}
+
+// Convert one fp16 value (bit pattern in a uint16_t) to fp32 using the PTX
+// cvt.f32.f16 instruction. The conversion is exact (fp32 can represent every
+// fp16 value).
+inline __device__ float half_to_float(uint16_t h) {
+ float f;
+ asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
+ return f;
+}
+
+// Unpack a packed half2 word into two fp16 bit patterns (mov.b32 with a
+// {lo, hi} vector destination), then convert each lane to fp32. Lane 0 of the
+// 32-bit word maps to .x of the result.
+inline __device__ float2 half2_to_float2(uint32_t v) {
+ uint16_t lo, hi;
+ asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
+ return make_float2(half_to_float(lo), half_to_float(hi));
+}
+
+// Convert fp32 to fp16 with round-to-nearest-even (cvt.rn.f16.f32), returning
+// the fp16 bit pattern. Only u16[0] of the union is written and read; the
+// union exists so the same idiom as float2_to_half2 below can be used.
+inline __device__ uint16_t float_to_half(float f) {
+ union {
+ uint32_t u32;
+ uint16_t u16[2];
+ } tmp;
+ asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
+ return tmp.u16[0];
+}
+
+// Convert a float2 to a packed half2 word with round-to-nearest-even.
+// On sm_80+ a single packed cvt.rn.f16x2.f32 is used; f.y is passed as the
+// first source so that f.x lands in the low 16-bit lane, matching the scalar
+// fallback below (which writes f.x into u16[0]).
+// NOTE(review): the asm line below was hard-wrapped by the mail archive (the
+// continuation line lost its '+'); kept byte-identical to the archived patch.
+inline __device__ uint32_t float2_to_half2(float2 f) {
+ union {
+ uint32_t u32;
+ uint16_t u16[2];
+ } tmp;
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+ asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y),
"f"(f.x));
+#else
+ asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
+ asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
+#endif
+ return tmp.u32;
+}
+
+// Vector addition.
+inline __device__ uint16_t add(uint16_t a, uint16_t b) {
+ uint16_t c;
+ asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
+ return c;
+}
+
+inline __device__ uint32_t add(uint32_t a, uint32_t b) {
+ uint32_t c;
+ asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
+ return c;
+}
+
+inline __device__ uint2 add(uint2 a, uint2 b) {
+ uint2 c;
+ c.x = add(a.x, b.x);
+ c.y = add(a.y, b.y);
+ return c;
+}
+
+inline __device__ uint4 add(uint4 a, uint4 b) {
+ uint4 c;
+ c.x = add(a.x, b.x);
+ c.y = add(a.y, b.y);
+ c.z = add(a.z, b.z);
+ c.w = add(a.w, b.w);
+ return c;
+}
+
+inline __device__ float2 add(uint32_t a, float2 fb) {
+ float2 fa = half2_to_float2(a);
+ return add(fa, fb);
+}
+
+inline __device__ Float4_ add(uint2 a, Float4_ fb) {
+ Float4_ fc;
+ fc.x = add(a.x, fb.x);
+ fc.y = add(a.y, fb.y);
+ return fc;
+}
+
+inline __device__ Float8_ add(uint4 a, Float8_ fb) {
+ Float8_ fc;
+ fc.x = add(a.x, fb.x);
+ fc.y = add(a.y, fb.y);
+ fc.z = add(a.z, fb.z);
+ fc.w = add(a.w, fb.w);
+ return fc;
+}
+
+// Vector multiplication.
+template <>
+inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
+ uint16_t c;
+ asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
+ return c;
+}
+
+template <>
+inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
+ uint32_t c;
+ asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
+ return c;
+}
+
+template <>
+inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
+ return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
+}
+
+template <>
+inline __device__ uint2 mul(uint2 a, uint2 b) {
+ uint2 c;
+ c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
+ c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
+ return c;
+}
+
+template <>
+inline __device__ uint2 mul(uint16_t a, uint2 b) {
+ uint32_t s = h0_h0(a);
+ uint2 c;
+ c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
+ c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
+ return c;
+}
+
+template <>
+inline __device__ uint4 mul(uint4 a, uint4 b) {
+ uint4 c;
+ c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
+ c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
+ c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
+ c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
+ return c;
+}
+
+template <>
+inline __device__ uint4 mul(uint16_t a, uint4 b) {
+ uint32_t s = h0_h0(a);
+ uint4 c;
+ c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
+ c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
+ c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
+ c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
+ return c;
+}
+
+template <>
+inline __device__ float mul(uint16_t a, uint16_t b) {
+ float fa = half_to_float(a);
+ float fb = half_to_float(b);
+ return fa * fb;
+}
+
+template <>
+inline __device__ float2 mul(uint32_t a, uint32_t b) {
+ float2 fa = half2_to_float2(a);
+ float2 fb = half2_to_float2(b);
+ return mul<float2, float2, float2>(fa, fb);
+}
+
+template <>
+inline __device__ float2 mul(uint16_t a, uint32_t b) {
+ return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
+}
+
+template <>
+inline __device__ Float4_ mul(uint2 a, uint2 b) {
+ Float4_ fc;
+ fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
+ fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
+ return fc;
+}
+
+template <>
+inline __device__ Float4_ mul(uint16_t a, uint2 b) {
+ uint32_t s = h0_h0(a);
+ Float4_ fc;
+ fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
+ fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
+ return fc;
+}
+
+template <>
+inline __device__ Float8_ mul(uint4 a, uint4 b) {
+ Float8_ fc;
+ fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
+ fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
+ fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
+ fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
+ return fc;
+}
+
+template <>
+inline __device__ Float8_ mul(uint16_t a, uint4 b) {
+ uint32_t s = h0_h0(a);
+ Float8_ fc;
+ fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
+ fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
+ fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
+ fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
+ return fc;
+}
+
+// Vector fused multiply-add.
+// Packed half2 FMA: d = a * b + c performed independently on each 16-bit fp16
+// lane with round-to-nearest-even (fma.rn.f16x2).
+// NOTE(review): the asm line below was hard-wrapped by the mail archive (the
+// continuation line lost its '+'); kept byte-identical to the archived patch.
+inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
+ uint32_t d;
+ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b),
"r"(c));
+ return d;
+}
+
+inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { return
fma(h0_h0(a), b, c); }
+
+inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
+ uint2 d;
+ d.x = fma(a.x, b.x, c.x);
+ d.y = fma(a.y, b.y, c.y);
+ return d;
+}
+
+inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
+ uint32_t s = h0_h0(a);
+ uint2 d;
+ d.x = fma(s, b.x, c.x);
+ d.y = fma(s, b.y, c.y);
+ return d;
+}
+
+inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
+ uint4 d;
+ d.x = fma(a.x, b.x, c.x);
+ d.y = fma(a.y, b.y, c.y);
+ d.z = fma(a.z, b.z, c.z);
+ d.w = fma(a.w, b.w, c.w);
+ return d;
+}
+
+inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
+ uint32_t s = h0_h0(a);
+ uint4 d;
+ d.x = fma(s, b.x, c.x);
+ d.y = fma(s, b.y, c.y);
+ d.z = fma(s, b.z, c.z);
+ d.w = fma(s, b.w, c.w);
+ return d;
+}
+
+inline __device__ float fma(uint16_t a, uint16_t b, float fc) {
+ float fa = half_to_float(a);
+ float fb = half_to_float(b);
+ return fa * fb + fc;
+}
+
+inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) {
+ float2 fa = half2_to_float2(a);
+ float2 fb = half2_to_float2(b);
+ return fma(fa, fb, fc);
+}
+
+inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) { return
fma(h0_h0(a), b, fc); }
+
+inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) {
+ Float4_ fd;
+ fd.x = fma(a.x, b.x, fc.x);
+ fd.y = fma(a.y, b.y, fc.y);
+ return fd;
+}
+
+inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) {
+ uint32_t s = h0_h0(a);
+ Float4_ fd;
+ fd.x = fma(s, b.x, fc.x);
+ fd.y = fma(s, b.y, fc.y);
+ return fd;
+}
+
+inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) {
+ Float8_ fd;
+ fd.x = fma(a.x, b.x, fc.x);
+ fd.y = fma(a.y, b.y, fc.y);
+ fd.z = fma(a.z, b.z, fc.z);
+ fd.w = fma(a.w, b.w, fc.w);
+ return fd;
+}
+
+inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
+ uint32_t s = h0_h0(a);
+ Float8_ fd;
+ fd.x = fma(s, b.x, fc.x);
+ fd.y = fma(s, b.y, fc.y);
+ fd.z = fma(s, b.z, fc.z);
+ fd.w = fma(s, b.w, fc.w);
+ return fd;
+}
+
+// Vector sum.
+template <>
+inline __device__ float sum(uint16_t v) {
+ return half_to_float(v);
+}
+
+template <>
+inline __device__ float sum(uint32_t v) {
+ float2 tmp = half2_to_float2(v);
+ return tmp.x + tmp.y;
+}
+
+template <>
+inline __device__ float sum(uint2 v) {
+ uint32_t c = add(v.x, v.y);
+ return sum(c);
+}
+
+template <>
+inline __device__ float sum(uint4 v) {
+ uint32_t c = add(v.x, v.y);
+ c = add(c, v.z);
+ c = add(c, v.w);
+ return sum(c);
+}
+
+// From float32 to float16.
+// NOLINTNEXTLINE(runtime/references)
+inline __device__ void from_float(uint16_t& dst, float src) { dst =
float_to_half(src); }
+
+// NOLINTNEXTLINE(runtime/references)
+inline __device__ void from_float(uint32_t& dst, float2 src) { dst =
float2_to_half2(src); }
+
+// NOLINTNEXTLINE(runtime/references)
+inline __device__ void from_float(uint2& dst, Float4_ src) {
+ dst.x = float2_to_half2(src.x);
+ dst.y = float2_to_half2(src.y);
+}
+
+// NOLINTNEXTLINE(runtime/references)
+inline __device__ void from_float(uint4& dst, Float8_ src) {
+ dst.x = float2_to_half2(src.x);
+ dst.y = float2_to_half2(src.y);
+ dst.z = float2_to_half2(src.z);
+ dst.w = float2_to_half2(src.w);
+}
+
+// From float16 to float32.
+inline __device__ float to_float(uint16_t u) { return half_to_float(u); }
+
+inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); }
+
+inline __device__ Float4_ to_float(uint2 u) {
+ Float4_ tmp;
+ tmp.x = half2_to_float2(u.x);
+ tmp.y = half2_to_float2(u.y);
+ return tmp;
+}
+
+inline __device__ Float8_ to_float(uint4 u) {
+ Float8_ tmp;
+ tmp.x = half2_to_float2(u.x);
+ tmp.y = half2_to_float2(u.y);
+ tmp.z = half2_to_float2(u.z);
+ tmp.w = half2_to_float2(u.w);
+ return tmp;
+}
+
+// Zero-out a variable.
+// NOLINTNEXTLINE(runtime/references)
+inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); }
+
+} // namespace vllm
diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py
index f3d9b9df32..757c00e0e3 100644
--- a/tests/lint/check_file_type.py
+++ b/tests/lint/check_file_type.py
@@ -46,6 +46,7 @@ ALLOW_EXTENSION = {
"pyd",
"pyx",
"cu",
+ "cuh",
"bat",
# relay text format
"rly",
diff --git a/tests/python/relax/test_contrib_vllm.py
b/tests/python/relax/test_contrib_vllm.py
new file mode 100644
index 0000000000..b674e1c9fb
--- /dev/null
+++ b/tests/python/relax/test_contrib_vllm.py
@@ -0,0 +1,746 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import pytest
+
+import tvm.testing
+import tvm
+from tvm import relax
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
+
+
+has_vllm =
tvm.get_global_func("tvm.contrib.vllm.single_query_cached_kv_attention", True)
+
+vllm_enabled = pytest.mark.skipif(
+ not has_vllm,
+ reason="VLLM not enabled.",
+)
+
+pytestmark = [vllm_enabled]
+
+
+# Compile a Relax IRModule for `target`, run its "main" function on the given
+# numpy inputs, and return the result(s) as numpy arrays.
+# mod: Relax IRModule with a "main" entry; inputs_np: list of numpy arrays;
+# target: TVM target string (tests pass "cuda"); legalize: run LegalizeOps
+# first so high-level Relax ops are lowered before GPU scheduling.
+def build_and_run(mod, inputs_np, target, legalize=True):
+ if legalize:
+ mod = relax.transform.LegalizeOps()(mod)
+
+ # Apply a default thread-binding schedule to the lowered TIR for CUDA.
+ with tvm.target.Target("cuda"):
+ mod = tvm.tir.transform.DefaultGPUSchedule()(mod)
+
+ with tvm.transform.PassContext():
+ ex = relax.build(mod, target)
+
+ dev = tvm.device(target, 0)
+ vm = relax.VirtualMachine(ex, dev)
+ f = vm["main"]
+ inputs = [tvm.nd.array(inp, dev) for inp in inputs_np]
+
+ out = f(*inputs)
+
+ # A tuple-returning main yields a TVM Array; convert each element.
+ if isinstance(out, tvm.ir.container.Array):
+ return [arr.numpy() for arr in out]
+
+ return out.numpy()
+
+
+def test_attention():
+ @I.ir_module
+ class ModulePagedAttentionV1:
+ I.module_global_infos(
+ {
+ "vdevice": [
+ I.vdevice("llvm"),
+ ]
+ }
+ )
+
+ @R.function
+ def main(
+ query: R.Tensor(("num_seqs", 1, 64), dtype="float16"),
+ key_cache: R.Tensor(("num_blocks", 1, 8, 16, 8), dtype="float16"),
+ value_cache: R.Tensor(("num_blocks", 1, 64, 16), dtype="float16"),
+ block_tables: R.Tensor(("num_seqs", "max_num_blocks_per_seq"),
dtype="int32"),
+ context_lens: R.Tensor(("num_seqs",), dtype="int32"),
+ ) -> R.Tensor(("num_seqs", 1, 64), dtype="float16"):
+ with R.dataflow():
+ max_len = R.to_vdevice(R.max(context_lens), "llvm:0")
+ out = R.call_dps_packed(
+ "tvm.contrib.vllm.single_query_cached_kv_attention_v1",
+ [
+ query,
+ key_cache,
+ value_cache,
+ block_tables,
+ context_lens,
+ 16,
+ max_len,
+ ],
+ out_sinfo=query.struct_info,
+ )
+ R.output(out)
+ return out
+
+ @I.ir_module
+ class ModulePagedAttentionV2:
+ I.module_global_infos(
+ {
+ "vdevice": [
+ I.vdevice("llvm"),
+ ]
+ }
+ )
+
+ @R.function
+ def main(
+ query: R.Tensor(("num_seqs", 1, 64), dtype="float16"),
+ key_cache: R.Tensor(("num_blocks", 1, 8, 16, 8), dtype="float16"),
+ value_cache: R.Tensor(("num_blocks", 1, 64, 16), dtype="float16"),
+ block_tables: R.Tensor(("num_seqs", "max_num_blocks_per_seq"),
dtype="int32"),
+ context_lens: R.Tensor(("num_seqs",), dtype="int32"),
+ ) -> R.Tensor(("num_seqs", 1, 64), dtype="float16"):
+ with R.dataflow():
+ num_seqs = T.int64()
+ max_len = R.to_vdevice(R.max(context_lens), "llvm:0")
+ # alloc workspace
+ exp_sums = R.zeros((num_seqs, 1, 1), "float32")
+ max_logits = R.zeros((num_seqs, 1, 1), "float32")
+ tmp_out = R.zeros((num_seqs, 1, 1, 64), "float16")
+
+ out = R.call_dps_packed(
+ "tvm.contrib.vllm.single_query_cached_kv_attention_v2",
+ [
+ query,
+ key_cache,
+ value_cache,
+ block_tables,
+ context_lens,
+ 16,
+ max_len,
+ exp_sums,
+ max_logits,
+ tmp_out,
+ ],
+ out_sinfo=query.struct_info,
+ )
+ R.output(out)
+ return out
+
+ np.random.seed(0)
+ num_heads = 1
+ head_dim = 64
+ vec_size = 8
+ block_size = 16
+ num_seqs = 2
+ num_blocks = 1
+ query = np.random.randn(num_seqs, num_heads, head_dim).astype("float16")
+ key_cache = np.random.randn(
+ num_blocks, num_heads, head_dim // vec_size, block_size, vec_size
+ ).astype("float16")
+ value_cache = np.random.randn(num_blocks, num_heads, head_dim,
block_size).astype("float16")
+ block_tables = np.array([[0], [0]]).astype("int32")
+ context_lens = np.array([3, 5]).astype("int32")
+
+ out_v1 = build_and_run(
+ ModulePagedAttentionV1,
+ [query, key_cache, value_cache, block_tables, context_lens],
+ "cuda",
+ legalize=True,
+ )
+
+ out_v2 = build_and_run(
+ ModulePagedAttentionV2,
+ [query, key_cache, value_cache, block_tables, context_lens],
+ "cuda",
+ legalize=True,
+ )
+
+ ref = np.array(
+ [
+ [
+ [
+ 0.28271484375,
+ 0.197021484375,
+ -0.278564453125,
+ 0.444580078125,
+ -0.47802734375,
+ -0.7548828125,
+ -0.84228515625,
+ -0.80322265625,
+ 0.478759765625,
+ 0.195068359375,
+ -0.59521484375,
+ 0.779296875,
+ 0.35888671875,
+ -0.158935546875,
+ -0.6103515625,
+ 0.188720703125,
+ 0.410400390625,
+ 0.28662109375,
+ 0.40283203125,
+ -1.23046875,
+ -0.01043701171875,
+ -0.0794677734375,
+ -0.0350341796875,
+ 0.12005615234375,
+ 0.63671875,
+ 0.368896484375,
+ -0.58642578125,
+ -0.34228515625,
+ -0.552734375,
+ 0.947265625,
+ -0.079833984375,
+ 0.85302734375,
+ 0.1947021484375,
+ 0.16748046875,
+ -0.083984375,
+ -0.75244140625,
+ -0.568359375,
+ -1.45703125,
+ -1.021484375,
+ -0.2235107421875,
+ -0.98828125,
+ -0.87109375,
+ -0.43359375,
+ -0.3271484375,
+ 0.0557861328125,
+ -0.269287109375,
+ -1.009765625,
+ 0.1387939453125,
+ -0.0831298828125,
+ 0.27978515625,
+ -0.9736328125,
+ 0.7802734375,
+ -0.1329345703125,
+ -0.5927734375,
+ -1.6923828125,
+ 1.1904296875,
+ -1.3759765625,
+ -1.080078125,
+ -0.53173828125,
+ 0.28466796875,
+ -2.02734375,
+ -0.377685546875,
+ -0.81201171875,
+ -0.7412109375,
+ ]
+ ],
+ [
+ [
+ 0.482177734375,
+ 0.114501953125,
+ -0.265869140625,
+ -1.154296875,
+ 0.28857421875,
+ 0.71240234375,
+ -1.1767578125,
+ 0.187744140625,
+ -0.23486328125,
+ 0.07135009765625,
+ -0.34521484375,
+ 0.444091796875,
+ -0.09130859375,
+ 0.900390625,
+ -0.043701171875,
+ 0.61279296875,
+ 0.1201171875,
+ 0.443603515625,
+ -0.4150390625,
+ -0.9560546875,
+ -0.1917724609375,
+ 0.0863037109375,
+ 0.267578125,
+ 0.04931640625,
+ -0.32666015625,
+ 0.5859375,
+ -0.57421875,
+ 0.29541015625,
+ -0.26220703125,
+ 1.177734375,
+ 0.11309814453125,
+ 0.81201171875,
+ 0.346435546875,
+ 0.53271484375,
+ -0.0009765625,
+ -0.35205078125,
+ -0.1298828125,
+ -1.2431640625,
+ -0.2196044921875,
+ 0.31640625,
+ -0.40869140625,
+ 0.25244140625,
+ -0.9853515625,
+ 0.284912109375,
+ 0.399169921875,
+ -1.1435546875,
+ 0.305419921875,
+ 0.300048828125,
+ -0.84521484375,
+ -0.5166015625,
+ -0.787109375,
+ 0.1011962890625,
+ -1.0302734375,
+ -1.35546875,
+ -0.0556640625,
+ 1.0791015625,
+ -0.047607421875,
+ -0.498046875,
+ -0.055999755859375,
+ -0.35009765625,
+ -1.4296875,
+ 0.350341796875,
+ -1.16796875,
+ -0.576171875,
+ ]
+ ],
+ ]
+ ).astype("float16")
+
+ # from vllm import attention_ops
+ # import torch
+ #
+ # def to_torch(arr):
+ # return torch.from_numpy(arr).to("cuda")
+ #
+ # ref = to_torch(np.zeros_like(query))
+ # attention_ops.single_query_cached_kv_attention(
+ # ref,
+ # to_torch(query),
+ # to_torch(key_cache),
+ # to_torch(value_cache),
+ # num_kv_heads,
+ # query.shape[-1] ** -0.5, # scale
+ # to_torch(block_tables),
+ # to_torch(context_lens),
+ # value_cache.shape[-1], # block_size,
+ # np.max(context_lens),
+ # None,
+ # )
+ # ref = ref.cpu().numpy()
+
+ # print(ref.tolist())
+
+ for out in [out_v1, out_v2]:
+ assert np.max(np.abs(ref - out)) == 0.0
+
+
+def test_cache():
+ @I.ir_module
+ class Module:
+ @R.function
+ def main(
+ key: R.Tensor(("num_tokens", 1, 8), dtype="float16"),
+ value: R.Tensor(("num_tokens", 1, 8), dtype="float16"),
+ key_cache: R.Tensor(("num_blocks", 1, 1, 16, 8), dtype="float16"),
+ value_cache: R.Tensor(("num_blocks", 1, 8, 16), dtype="float16"),
+ slot_mapping: R.Tensor(("num_tokens",), dtype="int32"),
+ ) -> R.Tuple(
+ [
+ R.Tensor(("num_blocks", 1, 8, 16, 8), dtype="float16"),
+ R.Tensor(("num_blocks", 1, 8, 16), dtype="float16"),
+ ]
+ ):
+ with R.dataflow():
+ kv = R.call_pure_packed(
+ "tvm.contrib.vllm.reshape_and_cache",
+ key,
+ value,
+ key_cache,
+ value_cache,
+ slot_mapping,
+ sinfo_args=[key_cache.struct_info,
value_cache.struct_info],
+ )
+ out = (kv[0], kv[1])
+ R.output(out)
+ return out
+
+ np.random.seed(0)
+ num_heads = 1
+ head_dim = 8
+ vec_size = 8
+ block_size = 16
+ num_tokens = 8
+ num_blocks = 1
+ key = np.random.randn(num_tokens, num_heads, head_dim).astype("float16")
+ value = np.random.randn(num_tokens, num_heads, head_dim).astype("float16")
+ key_cache_before = np.random.randn(
+ num_blocks, num_heads, head_dim // vec_size, block_size, vec_size
+ ).astype("float16")
+ value_cache_before = np.random.randn(num_blocks, num_heads, head_dim,
block_size).astype(
+ "float16"
+ )
+ slot_mapping = np.arange(num_tokens).astype("int32")
+
+ key_cache = key_cache_before.copy()
+ value_cache = value_cache_before.copy()
+
+ out_key_cache, out_value_cache = build_and_run(
+ Module,
+ [key, value, key_cache, value_cache, slot_mapping],
+ "cuda",
+ )
+
+ ref_key_cache = np.array(
+ [
+ [
+ [
+ [
+ [
+ 1.763671875,
+ 0.400146484375,
+ 0.978515625,
+ 2.240234375,
+ 1.8671875,
+ -0.97705078125,
+ 0.9501953125,
+ -0.1513671875,
+ ],
+ [
+ -0.10321044921875,
+ 0.41064453125,
+ 0.14404296875,
+ 1.4541015625,
+ 0.76123046875,
+ 0.1217041015625,
+ 0.44384765625,
+ 0.333740234375,
+ ],
+ [
+ 1.494140625,
+ -0.2052001953125,
+ 0.31298828125,
+ -0.85400390625,
+ -2.552734375,
+ 0.65380859375,
+ 0.8642578125,
+ -0.7421875,
+ ],
+ [
+ 2.26953125,
+ -1.4541015625,
+ 0.045745849609375,
+ -0.1871337890625,
+ 1.533203125,
+ 1.4697265625,
+ 0.1549072265625,
+ 0.378173828125,
+ ],
+ [
+ -0.8876953125,
+ -1.98046875,
+ -0.347900390625,
+ 0.1563720703125,
+ 1.23046875,
+ 1.2021484375,
+ -0.38720703125,
+ -0.30224609375,
+ ],
+ [
+ -1.048828125,
+ -1.419921875,
+ -1.7060546875,
+ 1.951171875,
+ -0.509765625,
+ -0.43798828125,
+ -1.2529296875,
+ 0.77734375,
+ ],
+ [
+ -1.6142578125,
+ -0.2127685546875,
+ -0.8955078125,
+ 0.386962890625,
+ -0.5107421875,
+ -1.1806640625,
+ -0.0281829833984375,
+ 0.42822265625,
+ ],
+ [
+ 0.0665283203125,
+ 0.302490234375,
+ -0.63427734375,
+ -0.36279296875,
+ -0.67236328125,
+ -0.359619140625,
+ -0.81298828125,
+ -1.7265625,
+ ],
+ [
+ -0.039276123046875,
+ -1.16796875,
+ 0.5234375,
+ -0.1715087890625,
+ 0.77197265625,
+ 0.82373046875,
+ 2.1640625,
+ 1.3369140625,
+ ],
+ [
+ -0.369140625,
+ -0.2393798828125,
+ 1.099609375,
+ 0.6552734375,
+ 0.64013671875,
+ -1.6171875,
+ -0.024322509765625,
+ -0.73779296875,
+ ],
+ [
+ 0.280029296875,
+ -0.09814453125,
+ 0.91015625,
+ 0.317138671875,
+ 0.7861328125,
+ -0.46630859375,
+ -0.9443359375,
+ -0.41015625,
+ ],
+ [
+ -0.0170135498046875,
+ 0.379150390625,
+ 2.259765625,
+ -0.042266845703125,
+ -0.9560546875,
+ -0.345947265625,
+ -0.463623046875,
+ 0.4814453125,
+ ],
+ [
+ -1.541015625,
+ 0.063232421875,
+ 0.156494140625,
+ 0.232177734375,
+ -0.59716796875,
+ -0.2379150390625,
+ -1.423828125,
+ -0.493408203125,
+ ],
+ [
+ -0.54296875,
+ 0.416015625,
+ -1.15625,
+ 0.78125,
+ 1.494140625,
+ -2.0703125,
+ 0.42626953125,
+ 0.6767578125,
+ ],
+ [
+ -0.63720703125,
+ -0.397216796875,
+ -0.1329345703125,
+ -0.2978515625,
+ -0.30908203125,
+ -1.67578125,
+ 1.15234375,
+ 1.080078125,
+ ],
+ [
+ -0.8134765625,
+ -1.466796875,
+ 0.52099609375,
+ -0.57568359375,
+ 0.1419677734375,
+ -0.3193359375,
+ 0.69140625,
+ 0.69482421875,
+ ],
+ ]
+ ]
+ ]
+ ]
+ ).astype("float16")
+
+ ref_value_cache = np.array(
+ [
+ [
+ [
+ [
+ 0.1773681640625,
+ 1.1396484375,
+ -1.1650390625,
+ -1.0703125,
+ 0.010498046875,
+ -1.1728515625,
+ -0.861328125,
+ 0.37646484375,
+ -1.9365234375,
+ 0.188720703125,
+ 0.52392578125,
+ 0.08843994140625,
+ -0.310791015625,
+ 0.097412109375,
+ 0.39892578125,
+ -2.7734375,
+ ],
+ [
+ -0.40185546875,
+ -1.234375,
+ 0.90087890625,
+ 1.0546875,
+ 1.7861328125,
+ 1.943359375,
+ 1.91015625,
+ -1.099609375,
+ -0.11053466796875,
+ 1.0205078125,
+ -0.69189453125,
+ 1.5361328125,
+ 0.286376953125,
+ 0.60888671875,
+ -1.044921875,
+ 1.2109375,
+ ],
+ [
+ -1.6298828125,
+ 0.40234375,
+ 0.465576171875,
+ -0.403076171875,
+ 0.126953125,
+ -0.41357421875,
+ -0.26806640625,
+ 0.29833984375,
+ 0.09771728515625,
+ 0.5830078125,
+ -0.3994140625,
+ 0.3701171875,
+ -1.306640625,
+ 1.658203125,
+ -0.1181640625,
+ -0.68017578125,
+ ],
+ [
+ 0.462890625,
+ -0.6845703125,
+ -1.5361328125,
+ 1.22265625,
+ 0.402099609375,
+ -0.74755859375,
+ 0.80224609375,
+ 1.326171875,
+ -1.126953125,
+ -0.73046875,
+ -0.384765625,
+ 0.0943603515625,
+ -0.04217529296875,
+ -0.286865234375,
+ -0.061614990234375,
+ -0.1072998046875,
+ ],
+ [
+ -0.9072265625,
+ -0.87060546875,
+ 1.48828125,
+ 0.208251953125,
+ 1.8828125,
+ 1.9228515625,
+ 0.947265625,
+ -0.6943359375,
+ -0.70458984375,
+ 0.943359375,
+ 0.7470703125,
+ -1.1884765625,
+ 0.7734375,
+ -1.18359375,
+ -2.658203125,
+ 0.6064453125,
+ ],
+ [
+ 0.05194091796875,
+ -0.57861328125,
+ 1.8955078125,
+ 0.9765625,
+ -1.34765625,
+ 1.48046875,
+ -0.155029296875,
+ -0.149658203125,
+ -0.44091796875,
+ -0.2802734375,
+ -0.36474609375,
+ 0.15673828125,
+ 0.57861328125,
+ 0.349609375,
+ -0.76416015625,
+ -1.4375,
+ ],
+ [
+ 0.72900390625,
+ -0.3115234375,
+ 1.1787109375,
+ 0.3564453125,
+ -1.2705078125,
+ 1.8671875,
+ 0.6142578125,
+ -0.43505859375,
+ 0.6982421875,
+ 0.0037708282470703125,
+ 0.931640625,
+ 0.33984375,
+ -0.01568603515625,
+ 0.160888671875,
+ -0.190673828125,
+ -0.394775390625,
+ ],
+ [
+ 0.1290283203125,
+ 0.05615234375,
+ -0.179931640625,
+ 0.70654296875,
+ 0.96923828125,
+ 0.90625,
+ 0.92236328125,
+ 1.849609375,
+ 0.6435546875,
+ -1.5703125,
+ -0.2069091796875,
+ 0.88037109375,
+ -1.6982421875,
+ 0.38720703125,
+ -2.255859375,
+ -1.0224609375,
+ ],
+ ]
+ ]
+ ]
+ ).astype("float16")
+
+ # from vllm import cache_ops
+ # import torch
+
+ # def to_torch(arr):
+ # return torch.from_numpy(arr).to("cuda")
+
+ # ref_key_cache = to_torch(key_cache_before.copy())
+ # ref_value_cache = to_torch(value_cache_before.copy())
+
+ # cache_ops.reshape_and_cache(
+ # to_torch(key),
+ # to_torch(value),
+ # ref_key_cache,
+ # ref_value_cache,
+ # to_torch(slot_mapping),
+ # )
+
+ # ref_key_cache = ref_key_cache.cpu().numpy()
+ # ref_value_cache = ref_value_cache.cpu().numpy()
+
+ assert np.max(np.abs(out_key_cache - ref_key_cache)) == 0
+ assert np.max(np.abs(out_value_cache - ref_value_cache)) == 0
+
+
+if __name__ == "__main__":
+ tvm.testing.main()