This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 3fad0109bf [Unity][Contrib] Add vLLM paged attention kernel (#16350)
3fad0109bf is described below

commit 3fad0109bf5ef2cc0ca37f0c47d207fd2ce99fa8
Author: Wuwei Lin <[email protected]>
AuthorDate: Fri Jan 5 12:44:08 2024 -0700

    [Unity][Contrib] Add vLLM paged attention kernel (#16350)
    
    * Add vllm kernels
    
    * Add vllm paged attention v2
    
    * [vllm] remove head_mapping
    
    * revert cmake changes
    
    * add kernel for copying cache blocks
    
    * add kernel for reconstructing past kv tensors from cache
    
    * lint
    
    ---------
    
    Co-authored-by: Masahiro Masuda <[email protected]>
---
 CMakeLists.txt                                |   1 +
 cmake/modules/contrib/vllm.cmake              |  25 +
 licenses/LICENSE.vllm.txt                     | 201 +++++++
 src/runtime/contrib/vllm/attention_kernels.cu | 774 ++++++++++++++++++++++++++
 src/runtime/contrib/vllm/attention_utils.cuh  |  55 ++
 src/runtime/contrib/vllm/cache_alloc.cc       |  55 ++
 src/runtime/contrib/vllm/cache_kernels.cu     | 234 ++++++++
 src/runtime/contrib/vllm/dtype_float16.h      | 697 +++++++++++++++++++++++
 tests/lint/check_file_type.py                 |   1 +
 tests/python/relax/test_contrib_vllm.py       | 746 +++++++++++++++++++++++++
 10 files changed, 2789 insertions(+)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index fcd670429f..410ed439f1 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -572,6 +572,7 @@ include(cmake/modules/contrib/VitisAI.cmake)
 include(cmake/modules/contrib/Verilator.cmake)
 include(cmake/modules/contrib/UMA.cmake)
 include(cmake/modules/contrib/MSC.cmake)
+include(cmake/modules/contrib/vllm.cmake)
 include(cmake/modules/Git.cmake)
 include(cmake/modules/LibInfo.cmake)
 include(cmake/modules/RustExt.cmake)
diff --git a/cmake/modules/contrib/vllm.cmake b/cmake/modules/contrib/vllm.cmake
new file mode 100644
index 0000000000..4a09edd02e
--- /dev/null
+++ b/cmake/modules/contrib/vllm.cmake
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+if(USE_VLLM)
+  message(STATUS "Build with vllm paged attention kernel.")
+  include_directories(src/runtime/contrib/vllm)
+  enable_language(CUDA)
+
+  tvm_file_glob(GLOB VLLM_CONTRIB_SRC src/runtime/contrib/vllm/*.cu 
src/runtime/contrib/vllm/*.cc)
+  list(APPEND RUNTIME_SRCS ${VLLM_CONTRIB_SRC})
+endif(USE_VLLM)
diff --git a/licenses/LICENSE.vllm.txt b/licenses/LICENSE.vllm.txt
new file mode 100644
index 0000000000..261eeb9e9f
--- /dev/null
+++ b/licenses/LICENSE.vllm.txt
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
diff --git a/src/runtime/contrib/vllm/attention_kernels.cu 
b/src/runtime/contrib/vllm/attention_kernels.cu
new file mode 100644
index 0000000000..fe6e974dad
--- /dev/null
+++ b/src/runtime/contrib/vllm/attention_kernels.cu
@@ -0,0 +1,774 @@
+/*
+ * Adapted from
+ * 
https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <float.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include <algorithm>
+#include <type_traits>
+
+#include "attention_utils.cuh"
+#include "dtype_float16.h"
+
+#define WARP_SIZE 32
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define DIVIDE_ROUND_UP(a, b) (((a) + (b)-1) / (b))
+
+namespace vllm {
+
+// Utility function for attention softmax.
+template <int NUM_WARPS>
+inline __device__ float block_sum(float* red_smem, float sum) {
+  // Decompose the thread index into warp / lane.
+  int warp = threadIdx.x / WARP_SIZE;
+  int lane = threadIdx.x % WARP_SIZE;
+
+  // Compute the sum per warp.
+#pragma unroll
+  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
+  }
+
+  // Warp leaders store the data to shared memory.
+  if (lane == 0) {
+    red_smem[warp] = sum;
+  }
+
+  // Make sure the data is in shared memory.
+  __syncthreads();
+
+  // The warps compute the final sums.
+  if (lane < NUM_WARPS) {
+    sum = red_smem[lane];
+  }
+
+  // Parallel reduction inside the warp.
+#pragma unroll
+  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
+  }
+
+  // Broadcast to other threads.
+  return __shfl_sync(uint32_t(-1), sum, 0);
+}
+
+// TODO(woosuk): Merge the last two dimensions of the grid.
+// Grid: (num_heads, num_seqs, max_num_partitions).
+template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS,
+          int PARTITION_SIZE = 0>  // Zero means no partitioning.
+__device__ void paged_attention_kernel(
+    float* __restrict__ exp_sums,          // [num_seqs, num_heads, 
max_num_partitions]
+    float* __restrict__ max_logits,        // [num_seqs, num_heads, 
max_num_partitions]
+    scalar_t* __restrict__ out,            // [num_seqs, num_heads, 
max_num_partitions, head_size]
+    const scalar_t* __restrict__ q,        // [num_seqs, num_heads, head_size]
+    const scalar_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads, 
head_size/x, block_size, x]
+    const scalar_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads, 
head_size, block_size]
+    const int num_kv_heads, const float scale,
+    const int* __restrict__ block_tables,  // [num_seqs, 
max_num_blocks_per_seq]
+    const int* __restrict__ context_lens,  // [num_seqs]
+    const int max_num_blocks_per_seq,
+    const float* __restrict__ alibi_slopes,  // [num_heads]
+    const int q_stride, const int kv_block_stride, const int kv_head_stride) {
+  const int seq_idx = blockIdx.y;
+  const int partition_idx = blockIdx.z;
+  const int max_num_partitions = gridDim.z;
+  constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
+  const int context_len = context_lens[seq_idx];
+  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
+    // No work to do. Terminate the thread block.
+    return;
+  }
+
+  const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
+  const int num_blocks_per_partition =
+      USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
+
+  // [start_block_idx, end_block_idx) is the range of blocks to process.
+  const int start_block_idx = USE_PARTITIONING ? partition_idx * 
num_blocks_per_partition : 0;
+  const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, 
num_context_blocks);
+  const int num_blocks = end_block_idx - start_block_idx;
+
+  // [start_token_idx, end_token_idx) is the range of tokens to process.
+  const int start_token_idx = start_block_idx * BLOCK_SIZE;
+  const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, 
context_len);
+  const int num_tokens = end_token_idx - start_token_idx;
+
+  constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
+  constexpr int NUM_THREAD_GROUPS =
+      NUM_THREADS / THREAD_GROUP_SIZE;  // Note: This assumes 
THREAD_GROUP_SIZE divides NUM_THREADS
+  assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
+  constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, 
WARP_SIZE);
+  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  const int thread_idx = threadIdx.x;
+  const int warp_idx = thread_idx / WARP_SIZE;
+  const int lane = thread_idx % WARP_SIZE;
+
+  const int head_idx = blockIdx.x;
+  const int num_heads = gridDim.x;
+  const int num_queries_per_kv = num_heads / num_kv_heads;
+  const int kv_head_idx = head_idx / num_queries_per_kv;
+  const float alibi_slope = alibi_slopes == nullptr ? 0.f : 
alibi_slopes[head_idx];
+
+  // A vector type to store a part of a key or a query.
+  // The vector size is configured in such a way that the threads in a thread 
group
+  // fetch or compute 16 bytes at a time.
+  // For example, if the size of a thread group is 4 and the data type is half,
+  // then the vector size is 16 / (4 * sizeof(half)) == 2.
+  constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
+  using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
+  using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
+
+  constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
+  constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
+
+  const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
+  const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
+
+  // Load the query to registers.
+  // Each thread in a thread group has a different part of the query.
+  // For example, if the thread group size is 4, then the first thread in 
the group
+  // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 
9, ...
+  // th vectors of the query, and so on.
+  // NOTE(woosuk): Because q is split from a qkv tensor, it may not be 
contiguous.
+  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
+  __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
+#pragma unroll
+  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += 
NUM_THREAD_GROUPS) {
+    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
+    q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + 
vec_idx * VEC_SIZE);
+  }
+  __syncthreads();  // TODO(naed90): possible speedup if this is replaced with 
a memory wall right
+                    // before we use q_vecs
+
+  // Memory planning.
+  extern __shared__ char shared_mem[];
+  // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
+  float* logits = reinterpret_cast<float*>(shared_mem);
+  // Workspace for reduction.
+  __shared__ float red_smem[2 * NUM_WARPS];
+
+  // x == THREAD_GROUP_SIZE * VEC_SIZE
+  // Each thread group fetches x elements from the key at a time.
+  constexpr int x = 16 / sizeof(scalar_t);
+  float qk_max = -FLT_MAX;
+
+  // Iterate over the key blocks.
+  // Each warp fetches a block of keys for each iteration.
+  // Each thread group in a warp fetches a key from the block, and computes
+  // dot product with the query.
+  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
+  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
+       block_idx += NUM_WARPS) {
+    // NOTE(woosuk): The block number is stored in int32. However, we cast it 
to int64
+    // because int32 can lead to overflow when this variable is multiplied by 
large numbers
+    // (e.g., kv_block_stride).
+    const int64_t physical_block_number = 
static_cast<int64_t>(block_table[block_idx]);
+
+    // Load a key to registers.
+    // Each thread in a thread group has a different part of the key.
+    // For example, if the thread group size is 4, then the first thread 
in the group
+    // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 
9, ... th
+    // vectors of the key, and so on.
+    for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
+      const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % 
BLOCK_SIZE;
+      const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
+      K_vec k_vecs[NUM_VECS_PER_THREAD];
+
+#pragma unroll
+      for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
+        const scalar_t* k_ptr = k_cache + physical_block_number * 
kv_block_stride +
+                                kv_head_idx * kv_head_stride + 
physical_block_offset * x;
+        const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
+        const int offset1 = (vec_idx * VEC_SIZE) / x;
+        const int offset2 = (vec_idx * VEC_SIZE) % x;
+        k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * 
BLOCK_SIZE * x + offset2);
+      }
+
+      // Compute dot product.
+      // This includes a reduction across the threads in the same thread group.
+      float qk =
+          scale * Qk_dot<scalar_t, 
THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
+      // Add the ALiBi bias if slopes are given.
+      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 
0;
+
+      if (thread_group_offset == 0) {
+        // Store the partial reductions to shared memory.
+        // NOTE(woosuk): It is required to zero out the masked logits.
+        const bool mask = token_idx >= context_len;
+        logits[token_idx - start_token_idx] = mask ? 0.f : qk;
+        // Update the max value.
+        qk_max = mask ? qk_max : fmaxf(qk_max, qk);
+      }
+    }
+  }
+
+  // Perform reduction across the threads in the same warp to get the
+  // max qk value for each "warp" (not across the thread block yet).
+  // The 0-th thread of each thread group already has its max qk value.
+#pragma unroll
+  for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
+    qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
+  }
+  if (lane == 0) {
+    red_smem[warp_idx] = qk_max;
+  }
+  __syncthreads();
+
+  // TODO(woosuk): Refactor this part.
+  // Get the max qk value for the sequence.
+  qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
+#pragma unroll
+  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+    qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
+  }
+  // Broadcast the max qk value to all threads.
+  qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
+
+  // Get the sum of the exp values.
+  float exp_sum = 0.f;
+  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
+    float val = __expf(logits[i] - qk_max);
+    logits[i] = val;
+    exp_sum += val;
+  }
+  exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
+
+  // Compute softmax.
+  const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
+  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
+    logits[i] *= inv_sum;
+  }
+  __syncthreads();
+
+  // If partitioning is enabled, store the max logit and exp_sum.
+  if (USE_PARTITIONING && thread_idx == 0) {
+    float* max_logits_ptr = max_logits + seq_idx * num_heads * 
max_num_partitions +
+                            head_idx * max_num_partitions + partition_idx;
+    *max_logits_ptr = qk_max;
+    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
+                          head_idx * max_num_partitions + partition_idx;
+    *exp_sums_ptr = exp_sum;
+  }
+
+  // Each thread will fetch 16 bytes from the value cache at a time.
+  constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
+  using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
+  using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
+  using Float_L_vec = typename FloatVec<L_vec>::Type;
+
+  constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
+  constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
+  constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, 
NUM_ROWS_PER_ITER);
+
+  // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
+  float accs[NUM_ROWS_PER_THREAD];
+#pragma unroll
+  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+    accs[i] = 0.f;
+  }
+
+  scalar_t zero_value;
+  zero(zero_value);
+  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
+       block_idx += NUM_WARPS) {
+    // NOTE(woosuk): The block number is stored in int32. However, we cast it 
to int64
+    // because int32 can lead to overflow when this variable is multiplied by 
large numbers
+    // (e.g., kv_block_stride).
+    const int64_t physical_block_number = 
static_cast<int64_t>(block_table[block_idx]);
+    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
+    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
+    L_vec logits_vec;
+    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx 
- start_token_idx));
+
+    const scalar_t* v_ptr =
+        v_cache + physical_block_number * kv_block_stride + kv_head_idx * 
kv_head_stride;
+#pragma unroll
+    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+      if (row_idx < HEAD_SIZE) {
+        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
+        V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
+        if (block_idx == num_context_blocks - 1) {
+          // NOTE(woosuk): When v_vec contains the tokens that are out of the 
context,
+          // we should explicitly zero out the values since they may contain 
NaNs.
+          // See 
https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
+          scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
+#pragma unroll
+          for (int j = 0; j < V_VEC_SIZE; j++) {
+            v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : 
zero_value;
+          }
+        }
+        accs[i] += dot(logits_vec, v_vec);
+      }
+    }
+  }
+
+  // Perform reduction within each warp.
+#pragma unroll
+  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+    float acc = accs[i];
+#pragma unroll
+    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
+      acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
+    }
+    accs[i] = acc;
+  }
+
+  // NOTE(woosuk): A barrier is required because the shared memory space for 
logits
+  // is reused for the output.
+  __syncthreads();
+
+  // Perform reduction across warps.
+  float* out_smem = reinterpret_cast<float*>(shared_mem);
+#pragma unroll
+  for (int i = NUM_WARPS; i > 1; i /= 2) {
+    int mid = i / 2;
+    // Upper warps write to shared memory.
+    if (warp_idx >= mid && warp_idx < i) {
+      float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
+#pragma unroll
+      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+          dst[row_idx] = accs[i];
+        }
+      }
+    }
+    __syncthreads();
+
+    // Lower warps update the output.
+    if (warp_idx < mid) {
+      const float* src = &out_smem[warp_idx * HEAD_SIZE];
+#pragma unroll
+      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+          accs[i] += src[row_idx];
+        }
+      }
+    }
+    __syncthreads();
+  }
+
+  // Write the final output.
+  if (warp_idx == 0) {
+    scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * 
HEAD_SIZE +
+                        head_idx * max_num_partitions * HEAD_SIZE + 
partition_idx * HEAD_SIZE;
+#pragma unroll
+    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+        from_float(*(out_ptr + row_idx), accs[i]);
+      }
+    }
+  }
+}
+
// Grid: (num_heads, num_seqs, 1).
//
// Single-pass ("v1") paged attention entry point: forwards to the shared
// paged_attention_kernel implementation with the partition outputs
// (exp_sums / max_logits) disabled, so one thread block handles the entire
// context of one (head, sequence) pair.
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE,
          int NUM_THREADS>
__global__ void paged_attention_v1_kernel(
    scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size]
    const scalar_t* __restrict__ q,        // [num_seqs, num_heads, head_size]
    const scalar_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads, head_size/x, block_size, x]
    const scalar_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads, head_size, block_size]
    const int num_kv_heads, const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ context_lens,  // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride) {
  // exp_sums / max_logits are only needed by the partitioned (v2) variant.
  paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
      /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, v_cache, num_kv_heads,
      scale, block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
      kv_block_stride, kv_head_stride);
}
+
// Grid: (num_heads, num_seqs, max_num_partitions).
//
// Partitioned ("v2") paged attention entry point: each thread block handles
// one PARTITION_SIZE-sized chunk of one sequence's context and writes its
// partial result (plus per-partition max logit and exp sum) to tmp_out /
// max_logits / exp_sums; paged_attention_v2_reduce_kernel combines them.
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS,
          int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel(
    float* __restrict__ exp_sums,          // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits,        // [num_seqs, num_heads, max_num_partitions]
    scalar_t* __restrict__ tmp_out,        // [num_seqs, num_heads, max_num_partitions, head_size]
    const scalar_t* __restrict__ q,        // [num_seqs, num_heads, head_size]
    const scalar_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads, head_size/x, block_size, x]
    const scalar_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads, head_size, block_size]
    const int num_kv_heads, const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ context_lens,  // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride) {
  paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>(
      exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, block_tables,
      context_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
      kv_head_stride);
}
+
// Grid: (num_heads, num_seqs).
//
// Second pass of v2 attention: merges the per-partition partial outputs in
// tmp_out into the final `out`, renormalizing each partition's contribution by
// exp(max_i - global_max) * exp_sum_i / global_exp_sum for a numerically
// stable softmax across partitions.
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
          int PARTITION_SIZE>
__global__ void paged_attention_v2_reduce_kernel(
    scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size]
    const float* __restrict__ exp_sums,    // [num_seqs, num_heads, max_num_partitions]
    const float* __restrict__ max_logits,  // [num_seqs, num_heads, max_num_partitions]
    const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads, max_num_partitions, head_size]
    const int* __restrict__ context_lens,  // [num_seqs]
    const int max_num_partitions) {
  const int num_heads = gridDim.x;
  const int head_idx = blockIdx.x;
  const int seq_idx = blockIdx.y;
  const int context_len = context_lens[seq_idx];
  // Actual number of partitions for this sequence (max_num_partitions is the
  // padded maximum over the batch).
  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
  if (num_partitions == 1) {
    // No need to reduce. Only copy tmp_out to out.
    scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
    const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
                                  head_idx * max_num_partitions * HEAD_SIZE;
    for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
      out_ptr[i] = tmp_out_ptr[i];
    }
    // Terminate the thread block.
    return;
  }

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int warp_idx = threadIdx.x / WARP_SIZE;
  const int lane = threadIdx.x % WARP_SIZE;

  // Size: 2 * num_partitions.
  extern __shared__ char shared_mem[];
  // Workspace for reduction.
  __shared__ float red_smem[2 * NUM_WARPS];

  // Load max logits to shared memory.
  float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
  const float* max_logits_ptr =
      max_logits + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions;
  float max_logit = -FLT_MAX;
  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
    const float l = max_logits_ptr[i];
    shared_max_logits[i] = l;
    max_logit = fmaxf(max_logit, l);
  }
  __syncthreads();

  // Get the global max logit.
  // Reduce within the warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = max_logit;
  }
  __syncthreads();
  // Reduce across warps.
  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
  }
  // Broadcast the max value to all threads.
  max_logit = __shfl_sync(uint32_t(-1), max_logit, 0);

  // Load rescaled exp sums to shared memory.
  float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
  const float* exp_sums_ptr =
      exp_sums + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions;
  float global_exp_sum = 0.0f;
  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
    float l = shared_max_logits[i];
    // Rescale each partition's exp sum from its local max to the global max.
    float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
    global_exp_sum += rescaled_exp_sum;
    shared_exp_sums[i] = rescaled_exp_sum;
  }
  __syncthreads();
  global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
  // 1e-6 guards against division by zero when all exp sums vanish.
  const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);

  // Aggregate tmp_out to out.
  const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
                                head_idx * max_num_partitions * HEAD_SIZE;
  scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
  for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
    float acc = 0.0f;
    for (int j = 0; j < num_partitions; ++j) {
      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
    }
    from_float(out_ptr[i], acc);
  }
}
+
+}  // namespace vllm
+
+namespace tvm {
+namespace runtime {
+
// Launch helper expanded inside paged_attention_v1_launcher. It relies on the
// locals at the expansion site (T, BLOCK_SIZE, NUM_THREADS, grid, block,
// stream, shared_mem_size and the *_ptr / stride variables). The
// cudaFuncSetAttribute call raises the kernel's dynamic shared-memory limit to
// shared_mem_size before launching.
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                                      \
  cudaFuncSetAttribute(vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>,    \
                       cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size);             \
  vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>                          \
      <<<grid, block, shared_mem_size, stream>>>(                                                 \
          out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale,                \
          block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
          kv_block_stride, kv_head_stride);
+
// TODO(woosuk): Tune NUM_THREADS.
//
// Host-side launcher for the single-pass (v1) paged attention kernel.
// Derives launch geometry and tensor strides from the DLTensor shapes, sizes
// the dynamic shared memory, and dispatches on head_size (only the
// instantiated head sizes below are supported).
template <typename T, int BLOCK_SIZE, int NUM_THREADS = 128>
void paged_attention_v1_launcher(DLTensor* out, const DLTensor* query, const DLTensor* key_cache,
                                 const DLTensor* value_cache, float scale,
                                 const DLTensor* block_tables, const DLTensor* context_lens,
                                 int max_context_len) {
  int num_seqs = query->shape[0];
  int num_heads = query->shape[1];
  int head_size = query->shape[2];
  int max_num_blocks_per_seq = block_tables->shape[1];
  int q_stride = query->shape[1] * query->shape[2];

  // key_cache layout: [num_blocks, num_kv_heads, head_size/x, block_size, x].
  int num_kv_heads = key_cache->shape[1];
  int kv_head_stride = key_cache->shape[2] * key_cache->shape[3] * key_cache->shape[4];
  int kv_block_stride = kv_head_stride * key_cache->shape[1];

  // NOTE: alibi_slopes is optional and currently not wired through.
  const float* alibi_slopes_ptr = nullptr;

  T* out_ptr = static_cast<T*>(out->data);
  T* query_ptr = static_cast<T*>(query->data);
  T* key_cache_ptr = static_cast<T*>(key_cache->data);
  T* value_cache_ptr = static_cast<T*>(value_cache->data);
  int* block_tables_ptr = static_cast<int*>(block_tables->data);
  int* context_lens_ptr = static_cast<int*>(context_lens->data);

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
  int logits_size = padded_max_context_len * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
  // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
  // Keep that in sync with the logic here!
  int shared_mem_size = std::max(logits_size, outputs_size);

  dim3 grid(num_heads, num_seqs, 1);
  dim3 block(NUM_THREADS);
  const cudaStream_t stream = nullptr;
  switch (head_size) {
    // NOTE(woosuk): To reduce the compilation time, we only compile for the
    // head sizes that we use in the model. However, we can easily extend this
    // to support any head size which is a multiple of 16.
    case 64:
      LAUNCH_PAGED_ATTENTION_V1(64);
      break;
    case 80:
      LAUNCH_PAGED_ATTENTION_V1(80);
      break;
    case 96:
      LAUNCH_PAGED_ATTENTION_V1(96);
      break;
    case 112:
      LAUNCH_PAGED_ATTENTION_V1(112);
      break;
    case 128:
      LAUNCH_PAGED_ATTENTION_V1(128);
      break;
    case 256:
      LAUNCH_PAGED_ATTENTION_V1(256);
      break;
    default:
      // Previously this fell through silently: no kernel was launched and
      // `out` was left uninitialized. Fail loudly instead, consistent with
      // the unsupported-block-size handling in the callers.
      LOG(FATAL) << "Unsupported head size: " << head_size;
      break;
  }
}
+
// Instantiates and calls the v1 launcher for a compile-time block size; relies
// on the caller's locals (out, query, key_cache, value_cache, scale,
// block_tables, context_lens, max_context_len) at the expansion site.
#define CALL_V1_LAUNCHER(T, BLOCK_SIZE)                                                  \
  paged_attention_v1_launcher<T, BLOCK_SIZE>(out, query, key_cache, value_cache, scale, \
                                             block_tables, context_lens, max_context_len);
+
// Launch helper expanded inside paged_attention_v2_launcher: runs the
// partitioned attention kernel and then the cross-partition reduce kernel on
// the same stream. Relies on locals at the expansion site (grid, reduce_grid,
// block, stream, both shared-memory sizes, and the pointer/stride variables).
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                                     \
  vllm::paged_attention_v2_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>        \
      <<<grid, block, shared_mem_size, stream>>>(                                                \
          exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, \
          num_kv_heads, scale, block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq,      \
          alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride);                         \
  vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE>             \
      <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(                                  \
          out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, context_lens_ptr,                 \
          max_num_partitions);
+
// Host-side launcher for the partitioned (v2) paged attention. Each sequence's
// context is split into PARTITION_SIZE chunks handled by independent thread
// blocks; a second (reduce) kernel combines the partial results.
template <typename T, int BLOCK_SIZE, int NUM_THREADS = 128, int PARTITION_SIZE = 512>
void paged_attention_v2_launcher(
    DLTensor* out, DLTensor* exp_sums, DLTensor* max_logits, DLTensor* tmp_out,
    const DLTensor* query, const DLTensor* key_cache, const DLTensor* value_cache, float scale,
    const DLTensor* block_tables, const DLTensor* context_lens, int max_context_len) {
  int num_seqs = query->shape[0];
  int num_heads = query->shape[1];
  int head_size = query->shape[2];
  int max_num_blocks_per_seq = block_tables->shape[1];
  int q_stride = query->shape[1] * query->shape[2];

  // key_cache layout: [num_blocks, num_kv_heads, head_size/x, block_size, x].
  int num_kv_heads = key_cache->shape[1];
  int kv_head_stride = key_cache->shape[2] * key_cache->shape[3] * key_cache->shape[4];
  int kv_block_stride = kv_head_stride * key_cache->shape[1];

  // NOTE: alibi_slopes is optional and currently not wired through.
  const float* alibi_slopes_ptr = nullptr;

  T* out_ptr = static_cast<T*>(out->data);
  float* exp_sums_ptr = static_cast<float*>(exp_sums->data);
  float* max_logits_ptr = static_cast<float*>(max_logits->data);
  T* tmp_out_ptr = static_cast<T*>(tmp_out->data);
  T* query_ptr = static_cast<T*>(query->data);
  T* key_cache_ptr = static_cast<T*>(key_cache->data);
  T* value_cache_ptr = static_cast<T*>(value_cache->data);
  int* block_tables_ptr = static_cast<int*>(block_tables->data);
  int* context_lens_ptr = static_cast<int*>(context_lens->data);

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
  int logits_size = PARTITION_SIZE * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);

  // For paged attention v2 kernel.
  dim3 grid(num_heads, num_seqs, max_num_partitions);
  int shared_mem_size = std::max(logits_size, outputs_size);
  // For paged attention v2 reduce kernel.
  dim3 reduce_grid(num_heads, num_seqs);
  int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);

  dim3 block(NUM_THREADS);
  const cudaStream_t stream = nullptr;
  switch (head_size) {
    // NOTE(woosuk): To reduce the compilation time, we only compile for the
    // head sizes that we use in the model. However, we can easily extend this
    // to support any head size which is a multiple of 16.
    case 64:
      LAUNCH_PAGED_ATTENTION_V2(64);
      break;
    case 80:
      LAUNCH_PAGED_ATTENTION_V2(80);
      break;
    case 96:
      LAUNCH_PAGED_ATTENTION_V2(96);
      break;
    case 112:
      LAUNCH_PAGED_ATTENTION_V2(112);
      break;
    case 128:
      LAUNCH_PAGED_ATTENTION_V2(128);
      break;
    case 256:
      LAUNCH_PAGED_ATTENTION_V2(256);
      break;
    default:
      // Previously this fell through silently: no kernel was launched and
      // `out` was left uninitialized. Fail loudly instead, consistent with
      // the unsupported-block-size handling in the callers.
      LOG(FATAL) << "Unsupported head size: " << head_size;
      break;
  }
}
+
// Instantiates and calls the v2 launcher for a compile-time block size; relies
// on the caller's locals (out, exp_sums, max_logits, tmp_out, query,
// key_cache, value_cache, scale, block_tables, context_lens, max_context_len).
#define CALL_V2_LAUNCHER(T, BLOCK_SIZE)                                                            \
  paged_attention_v2_launcher<T, BLOCK_SIZE>(out, exp_sums, max_logits, tmp_out, query, key_cache, \
                                             value_cache, scale, block_tables, context_lens,       \
                                             max_context_len);
+
// Decode-time attention over the paged KV cache, single-pass (v1) variant.
// Dispatches to the templated launcher for the supported cache block sizes.
void single_query_cached_kv_attention_v1(
    const DLTensor* query, const DLTensor* key_cache, const DLTensor* value_cache,
    const DLTensor* block_tables, const DLTensor* context_lens, int block_size,
    const DLTensor* max_context_len_tensor,  // TODO(masahi): pass integer
    DLTensor* out) {
  // Softmax scaling 1/sqrt(head_size); query->shape[2] is head_size.
  float scale = 1.0 / sqrt(query->shape[2]);
  int max_context_len = static_cast<int*>(max_context_len_tensor->data)[0];

  using T = uint16_t;  // for half precision
  switch (block_size) {
    case 8:
      CALL_V1_LAUNCHER(T, 8);
      break;
    case 16:
      CALL_V1_LAUNCHER(T, 16);
      break;
    case 32:
      CALL_V1_LAUNCHER(T, 32);
      break;
    default:
      LOG(FATAL) << "Unsupported block size: " << block_size;
  }
}
+
// Decode-time attention over the paged KV cache, partitioned (v2) variant.
// exp_sums / max_logits / tmp_out are caller-provided workspaces for the
// per-partition partial results. Dispatches on the cache block size.
void single_query_cached_kv_attention_v2(
    const DLTensor* query, const DLTensor* key_cache, const DLTensor* value_cache,
    const DLTensor* block_tables, const DLTensor* context_lens, int block_size,
    const DLTensor* max_context_len_tensor,  // TODO(masahi): pass integer
    DLTensor* exp_sums, DLTensor* max_logits, DLTensor* tmp_out, DLTensor* out) {
  // Softmax scaling 1/sqrt(head_size); query->shape[2] is head_size.
  float scale = 1.0 / sqrt(query->shape[2]);
  int max_context_len = static_cast<int*>(max_context_len_tensor->data)[0];

  using T = uint16_t;  // for half precision
  switch (block_size) {
    case 8:
      CALL_V2_LAUNCHER(T, 8);
      break;
    case 16:
      CALL_V2_LAUNCHER(T, 16);
      break;
    case 32:
      CALL_V2_LAUNCHER(T, 32);
      break;
    default:
      LOG(FATAL) << "Unsupported block size: " << block_size;
  }
}
+
TVM_REGISTER_GLOBAL("tvm.contrib.vllm.single_query_cached_kv_attention")
    .set_body_typed([](const DLTensor* query, const DLTensor* key_cache,
                       const DLTensor* value_cache, const DLTensor* block_tables,
                       const DLTensor* context_lens, int block_size,
                       const DLTensor* max_context_len_tensor,  // TODO(masahi): pass integer
                       DLTensor* exp_sums, DLTensor* max_logits, DLTensor* tmp_out, DLTensor* out) {
      // Choose between the single-pass (v1) and partitioned (v2) kernels.
      const int num_seqs = query->shape[0];
      const int num_heads = query->shape[1];
      const int max_context_len = static_cast<int*>(max_context_len_tensor->data)[0];
      const int PARTITION_SIZE = 512;
      const int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
      // v1 keeps the whole context in one thread block per (seq, head); prefer
      // it when the context fits in one partition or when there are already
      // enough (seq, head) pairs to occupy the GPU.
      const bool fits_one_partition = max_num_partitions == 1;
      const bool enough_parallelism = num_seqs * num_heads > 512;
      if (max_context_len <= 8192 && (fits_one_partition || enough_parallelism)) {
        single_query_cached_kv_attention_v1(query, key_cache, value_cache, block_tables,
                                            context_lens, block_size, max_context_len_tensor, out);
      } else {
        single_query_cached_kv_attention_v2(query, key_cache, value_cache, block_tables,
                                            context_lens, block_size, max_context_len_tensor,
                                            exp_sums, max_logits, tmp_out, out);
      }
    });
+
// Expose for testing: registering v1 and v2 individually lets unit tests
// exercise a specific kernel path instead of the runtime heuristic above.
TVM_REGISTER_GLOBAL("tvm.contrib.vllm.single_query_cached_kv_attention_v1")
    .set_body_typed(single_query_cached_kv_attention_v1);

TVM_REGISTER_GLOBAL("tvm.contrib.vllm.single_query_cached_kv_attention_v2")
    .set_body_typed(single_query_cached_kv_attention_v2);
+
+}  // namespace runtime
+}  // namespace tvm
+
+#undef WARP_SIZE
+#undef MAX
+#undef MIN
+#undef DIVIDE_ROUND_UP
diff --git a/src/runtime/contrib/vllm/attention_utils.cuh 
b/src/runtime/contrib/vllm/attention_utils.cuh
new file mode 100644
index 0000000000..85301e89aa
--- /dev/null
+++ b/src/runtime/contrib/vllm/attention_utils.cuh
@@ -0,0 +1,55 @@
+/*
+ * Adapted from 
https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "dtype_float16.h"
+
+#include <float.h>
+#include <type_traits>
+
+namespace vllm {
+
// Q*K^T operation.
//
// Computes the dot product of the vectorized q and k fragments held by the
// calling lane, then sums the scalar partial results across the
// THREAD_GROUP_SIZE lanes that share one query row. Every lane in the group
// ends up with the full dot product.
template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
  using A_vec = typename FloatVec<Vec>::Type;
  // Accumulate the element-wise products in an FP32 vector.
  A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
#pragma unroll
  for (int i = 1; i < N; ++i) {
    qk_vec = fma(q[i], k[i], qk_vec);
  }

  // Collapse the vector lanes to a scalar, then butterfly-reduce across the
  // thread group with XOR shuffles.
  float qk = sum(qk_vec);
#pragma unroll
  for (int stride = THREAD_GROUP_SIZE / 2; stride >= 1; stride /= 2) {
    qk += __shfl_xor_sync(uint32_t(-1), qk, stride);
  }
  return qk;
}
+
// Dispatcher for the Q*K^T dot product, templated on element type so that
// specializations can override the generic implementation.
template<typename T, int THREAD_GROUP_SIZE>
struct Qk_dot {
  template<typename Vec, int N>
  static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
    return qk_dot_<THREAD_GROUP_SIZE>(q, k);
  }
};
+
+} // namespace vllm
diff --git a/src/runtime/contrib/vllm/cache_alloc.cc 
b/src/runtime/contrib/vllm/cache_alloc.cc
new file mode 100644
index 0000000000..aea50aa47a
--- /dev/null
+++ b/src/runtime/contrib/vllm/cache_alloc.cc
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <cuda_runtime.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/registry.h>
+
+namespace tvm {
+namespace runtime {
+namespace vllm {
+
+Array<NDArray> AllocateKVCache(int head_size, int num_layers, int num_heads, 
int block_size,
+                               int num_blocks) {
+  Array<NDArray> cache;
+  int element_size = 2;
+  int vec_size = 16 / element_size;
+
+  int device_id;
+  cudaGetDevice(&device_id);
+
+  DLDevice dev{DLDeviceType::kDLCUDA, device_id};
+
+  for (int i = 0; i < num_layers; ++i) {
+    NDArray key_blocks =
+        NDArray::Empty({num_blocks, num_heads, head_size / vec_size, 
block_size, vec_size},
+                       runtime::DataType::Float(16), dev);
+    NDArray value_blocks = NDArray::Empty({num_blocks, num_heads, head_size, 
block_size},
+                                          runtime::DataType::Float(16), dev);
+    cache.push_back(key_blocks);
+    cache.push_back(value_blocks);
+  }
+
+  return cache;
+}
+
+TVM_REGISTER_GLOBAL("tvm.contrib.vllm.allocate_kv_cache").set_body_typed(AllocateKVCache);
+
+}  // namespace vllm
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/contrib/vllm/cache_kernels.cu 
b/src/runtime/contrib/vllm/cache_kernels.cu
new file mode 100644
index 0000000000..537ff31fd0
--- /dev/null
+++ b/src/runtime/contrib/vllm/cache_kernels.cu
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include <algorithm>
+#include <cassert>
+#include <map>
+#include <vector>
+
+namespace vllm {
+
// Scatters each token's newly computed K/V vectors into the paged caches at
// the slot given by slot_mapping. Grid: one thread block per token; threads
// stride over the num_heads * head_size elements of that token.
template <typename scalar_t>
__global__ void reshape_and_cache_kernel(
    const scalar_t* __restrict__ key,      // [num_tokens, num_heads, head_size]
    const scalar_t* __restrict__ value,    // [num_tokens, num_heads, head_size]
    scalar_t* __restrict__ key_cache,      // [num_blocks, num_heads, head_size/x, block_size, x]
    scalar_t* __restrict__ value_cache,    // [num_blocks, num_heads, head_size, block_size]
    const int* __restrict__ slot_mapping,  // [num_tokens]
    const int key_stride, const int value_stride, const int num_heads, const int head_size,
    const int block_size, const int x) {
  const int token_idx = blockIdx.x;
  // Global slot -> (cache block, offset within the block).
  const int slot_idx = slot_mapping[token_idx];
  const int block_idx = slot_idx / block_size;
  const int block_offset = slot_idx % block_size;

  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int src_key_idx = token_idx * key_stride + i;
    const int src_value_idx = token_idx * value_stride + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    // The key cache packs x consecutive head-dim elements together.
    const int x_idx = head_offset / x;
    const int x_offset = head_offset % x;

    const int tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x +
                            head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
                            block_offset * x + x_offset;
    const int tgt_value_idx = block_idx * num_heads * head_size * block_size +
                              head_idx * head_size * block_size + head_offset * block_size +
                              block_offset;
    key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]);
    value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]);
  }
}
+
// Inverse of reshape_and_cache_kernel: gathers each token's K/V vectors out of
// the paged caches (at the slot given by slot_mapping) back into dense
// [num_tokens, num_heads, head_size] tensors. Grid: one thread block per
// token; threads stride over that token's num_heads * head_size elements.
template <typename scalar_t>
__global__ void reconstruct_from_cache_kernel(
    const scalar_t* __restrict__ key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
    const scalar_t* __restrict__ value_cache,  // [num_blocks, num_heads, head_size, block_size]
    const int* __restrict__ slot_mapping,      // [num_tokens]
    scalar_t* __restrict__ key,                // [num_tokens, num_heads, head_size]
    scalar_t* __restrict__ value,              // [num_tokens, num_heads, head_size]
    const int key_stride, const int value_stride, const int num_heads, const int head_size,
    const int block_size, const int x) {
  const int token_idx = blockIdx.x;
  // Global slot -> (cache block, offset within the block).
  const int slot_idx = slot_mapping[token_idx];
  const int block_idx = slot_idx / block_size;
  const int block_offset = slot_idx % block_size;

  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    // tgt_* index the dense output tensors; src_* index the cache layouts.
    const int tgt_key_idx = token_idx * key_stride + i;
    const int tgt_value_idx = token_idx * value_stride + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    // The key cache packs x consecutive head-dim elements together.
    const int x_idx = head_offset / x;
    const int x_offset = head_offset % x;

    const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x +
                            head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
                            block_offset * x + x_offset;
    const int src_value_idx = block_idx * num_heads * head_size * block_size +
                              head_idx * head_size * block_size + head_offset * block_size +
                              block_offset;

    key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]);
    // BUGFIX: source/destination were swapped here (the code read
    // `value[src_value_idx] = __ldg(&value_cache[tgt_value_idx])`), indexing
    // the dense tensor with a cache offset and vice versa — out-of-bounds for
    // mismatched sizes and inconsistent with the key line above.
    value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]);
  }
}
+
// Grid: (num_layers, num_pairs)
//
// Copies whole cache blocks within each layer's key and value caches. Each
// thread block handles one (src, dst) pair of one layer; block_mapping stores
// the pairs flattened as [src0, dst0, src1, dst1, ...].
template <typename scalar_t>
__global__ void copy_blocks_kernel(int64_t* key_cache_ptrs, int64_t* value_cache_ptrs,
                                   const int64_t* __restrict__ block_mapping,
                                   const int numel_per_block) {
  const int layer_idx = blockIdx.x;
  const int pair_idx = blockIdx.y;

  scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
  scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
  const int64_t src_base = block_mapping[2 * pair_idx] * numel_per_block;
  const int64_t dst_base = block_mapping[2 * pair_idx + 1] * numel_per_block;

  // Threads stride over the block's elements; key and value caches have the
  // same element count per block, so the same bounds apply to both copies.
  for (int elem = threadIdx.x; elem < numel_per_block; elem += blockDim.x) {
    key_cache[dst_base + elem] = key_cache[src_base + elem];
  }
  for (int elem = threadIdx.x; elem < numel_per_block; elem += blockDim.x) {
    value_cache[dst_base + elem] = value_cache[src_base + elem];
  }
}
+
+}  // namespace vllm
+
+namespace tvm {
+namespace runtime {
+
TVM_REGISTER_GLOBAL("tvm.contrib.vllm.reshape_and_cache")
    .set_body_typed([](NDArray key, NDArray value, NDArray key_cache, NDArray value_cache,
                       NDArray slot_mapping) {
      // Scatter the K/V of each token into the paged caches at the slots in
      // slot_mapping; one CUDA block per token.
      const int num_tokens = key->shape[0];
      const int num_heads = key->shape[1];
      const int head_size = key->shape[2];
      const int block_size = key_cache->shape[3];
      const int vec_size = key_cache->shape[4];  // "x" of the packed key layout

      const int key_stride = key->shape[1] * key->shape[2];
      const int value_stride = value->shape[1] * value->shape[2];

      using scalar_t = uint16_t;  // fp16 payload handled as raw 16-bit words
      dim3 grid(num_tokens);
      dim3 block(std::min(num_heads * head_size, 512));
      vllm::reshape_and_cache_kernel<scalar_t><<<grid, block>>>(
          static_cast<const scalar_t*>(key->data), static_cast<const scalar_t*>(value->data),
          static_cast<scalar_t*>(key_cache->data), static_cast<scalar_t*>(value_cache->data),
          static_cast<const int*>(slot_mapping->data), key_stride, value_stride, num_heads,
          head_size, block_size, vec_size);

      return Array{key_cache, value_cache};
    });
+
TVM_REGISTER_GLOBAL("tvm.contrib.vllm.reconstruct_from_cache")
    .set_body_typed([](NDArray key_cache, NDArray value_cache, NDArray slot_mapping) {
      // Rebuild dense K/V tensors for the tokens in slot_mapping by gathering
      // from the paged caches; one CUDA block per token.
      const int num_tokens = slot_mapping->shape[0];
      const int num_heads = value_cache->shape[1];
      const int head_size = value_cache->shape[2];
      const int block_size = value_cache->shape[3];
      const int vec_size = key_cache->shape[4];  // "x" of the packed key layout

      DLDevice dev = key_cache->device;
      auto key = NDArray::Empty({num_tokens, num_heads, head_size}, key_cache->dtype, dev);
      auto value = NDArray::Empty({num_tokens, num_heads, head_size}, key_cache->dtype, dev);

      const int key_stride = key->shape[1] * key->shape[2];
      const int value_stride = value->shape[1] * value->shape[2];

      using scalar_t = uint16_t;  // fp16 payload handled as raw 16-bit words
      dim3 grid(num_tokens);
      dim3 block(std::min(num_heads * head_size, 512));
      vllm::reconstruct_from_cache_kernel<scalar_t>
          <<<grid, block>>>(static_cast<const scalar_t*>(key_cache->data),
                            static_cast<const scalar_t*>(value_cache->data),
                            static_cast<const int*>(slot_mapping->data),
                            static_cast<scalar_t*>(key->data), static_cast<scalar_t*>(value->data),
                            key_stride, value_stride, num_heads, head_size, block_size, vec_size);

      return Array{key, value};
    });
+
TVM_REGISTER_GLOBAL("tvm.contrib.vllm.copy_blocks")
    .set_body_typed([](Array<NDArray> key_value_caches, NDArray block_mapping) {
      // key_value_caches interleaves per-layer caches: [k_0, v_0, k_1, v_1, ...].
      auto num_layers = key_value_caches.size() / 2;
      auto num_pairs = block_mapping->shape[0] / 2;

      if (num_layers == 0) {
        return;
      }

      // Collect raw device pointers so one kernel launch can touch all layers.
      std::vector<int64_t> key_cache_ptrs(num_layers);
      std::vector<int64_t> value_cache_ptrs(num_layers);
      for (size_t layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
        key_cache_ptrs[layer_idx] =
            reinterpret_cast<int64_t>(key_value_caches[2 * layer_idx]->data);
        value_cache_ptrs[layer_idx] =
            reinterpret_cast<int64_t>(key_value_caches[2 * layer_idx + 1]->data);
      }

      // Use the first *value* cache, [num_blocks, num_heads, head_size,
      // block_size], to size one block: its trailing dims multiply to
      // num_heads * head_size * block_size, which equals the element count of
      // a key-cache block in the packed 5-D layout as well.
      NDArray value_cache0 = key_value_caches[1];
      DLDevice dev = value_cache0->device;

      NDArray key_cache_ptrs_gpu =
          NDArray::Empty({static_cast<int>(num_layers)}, runtime::DataType::Int(64), dev);
      NDArray value_cache_ptrs_gpu =
          NDArray::Empty({static_cast<int>(num_layers)}, runtime::DataType::Int(64), dev);
      key_cache_ptrs_gpu.CopyFromBytes(key_cache_ptrs.data(),
                                       sizeof(int64_t) * key_cache_ptrs.size());
      value_cache_ptrs_gpu.CopyFromBytes(value_cache_ptrs.data(),
                                         sizeof(int64_t) * value_cache_ptrs.size());

      // NOTE(review): assumes block_mapping already holds int64 data — the
      // byte count below copies shape[0] int64 values. Confirm with callers.
      NDArray block_mapping_gpu =
          NDArray::Empty(block_mapping.Shape(), runtime::DataType::Int(64), dev);
      block_mapping_gpu.CopyFromBytes(block_mapping->data,
                                      sizeof(int64_t) * block_mapping->shape[0]);

      const int numel_per_block =
          value_cache0->shape[1] * value_cache0->shape[2] * value_cache0->shape[3];
      dim3 grid(num_layers, num_pairs);
      dim3 block(std::min(1024, numel_per_block));

      using scalar_t = uint16_t;
      vllm::copy_blocks_kernel<scalar_t>
          <<<grid, block>>>(static_cast<int64_t*>(key_cache_ptrs_gpu->data),
                            static_cast<int64_t*>(value_cache_ptrs_gpu->data),
                            static_cast<int64_t*>(block_mapping_gpu->data), numel_per_block);
    });
+
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/contrib/vllm/dtype_float16.h 
b/src/runtime/contrib/vllm/dtype_float16.h
new file mode 100644
index 0000000000..4f6f52bb50
--- /dev/null
+++ b/src/runtime/contrib/vllm/dtype_float16.h
@@ -0,0 +1,697 @@
+/*
+ * Adapted from
+ * 
https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * and
+ * 
https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <stdint.h>
+
+namespace vllm {
+
// A vector type to store Q, K, V elements. The primary template is
// intentionally empty; usable (T, VEC_SIZE) combinations are specialized below.
template <typename T, int VEC_SIZE>
struct Vec {};

// A vector type to store FP32 accumulators, keyed by the Q/K/V vector type.
template <typename T>
struct FloatVec {};

// Template vector operations, specialized per type below.
template <typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b);

template <typename T>
inline __device__ float sum(T v);

// Dot product with the multiply performed in the input element type T.
template <typename T>
inline __device__ float dot(T a, T b) {
  return sum(mul<T, T, T>(a, b));
}

// Dot product with the multiply performed in an explicitly supplied type A
// (e.g. an FP32 vector for packed FP16 inputs). A is not deducible from the
// arguments, so callers must invoke this as dot<A>(a, b).
template <typename A, typename T>
inline __device__ float dot(T a, T b) {
  return sum(mul<A, T, T>(a, b));
}

// Zero-initialize dst by writing 32-bit words; assumes sizeof(T) is a
// multiple of 4 (true for all vector types defined in this header).
template <typename T>
// NOLINTNEXTLINE(runtime/references)
inline __device__ void zero(T& dst) {
  constexpr int WORDS = sizeof(T) / 4;
  union {
    T raw;
    uint32_t words[WORDS];
  } tmp;

#pragma unroll
  for (int ii = 0; ii < WORDS; ++ii) {
    tmp.words[ii] = 0u;
  }
  dst = tmp.raw;
}

// Define custom FP32 vector data types: 4 and 8 floats stored as float2 pairs.
struct Float4_ {
  float2 x;
  float2 y;
};

struct Float8_ {
  float2 x;
  float2 y;
  float2 z;
  float2 w;
};

// FP32 vector types for Q, K, V.
template <>
struct Vec<float, 1> {
  using Type = float;
};
template <>
struct Vec<float, 2> {
  using Type = float2;
};
template <>
struct Vec<float, 4> {
  using Type = float4;
};

// FP32 accumulator vector types corresponding to Vec.
template <>
struct FloatVec<float> {
  using Type = float;
};
template <>
struct FloatVec<float2> {
  using Type = float2;
};
template <>
struct FloatVec<float4> {
  using Type = float4;
};
+
// Vector addition (elementwise).
inline __device__ float add(float a, float b) { return a + b; }

inline __device__ float2 add(float2 a, float2 b) {
  float2 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}

inline __device__ float4 add(float4 a, float4 b) {
  float4 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  c.z = add(a.z, b.z);
  c.w = add(a.w, b.w);
  return c;
}

// Vector multiplication (elementwise). Specializations of the mul<Acc, A, B>
// template declared above; scalar-times-vector overloads broadcast `a`.
template <>
inline __device__ float mul<float, float>(float a, float b) {
  return a * b;
}

template <>
inline __device__ float2 mul(float2 a, float2 b) {
  float2 c;
  c.x = a.x * b.x;
  c.y = a.y * b.y;
  return c;
}

template <>
inline __device__ float2 mul(float a, float2 b) {
  float2 c;
  c.x = a * b.x;
  c.y = a * b.y;
  return c;
}

template <>
inline __device__ float4 mul(float4 a, float4 b) {
  float4 c;
  c.x = a.x * b.x;
  c.y = a.y * b.y;
  c.z = a.z * b.z;
  c.w = a.w * b.w;
  return c;
}

template <>
inline __device__ float4 mul(float a, float4 b) {
  float4 c;
  c.x = a * b.x;
  c.y = a * b.y;
  c.z = a * b.z;
  c.w = a * b.w;
  return c;
}

// Vector fused multiply-add: returns a * b + c elementwise; scalar `a`
// overloads broadcast the multiplier across all lanes.
inline __device__ float fma(float a, float b, float c) { return a * b + c; }

inline __device__ float2 fma(float2 a, float2 b, float2 c) {
  float2 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

inline __device__ float2 fma(float a, float2 b, float2 c) {
  float2 d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  return d;
}

inline __device__ float4 fma(float4 a, float4 b, float4 c) {
  float4 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

inline __device__ float4 fma(float a, float4 b, float4 c) {
  float4 d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  d.z = fma(a, b.z, c.z);
  d.w = fma(a, b.w, c.w);
  return d;
}

inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
  Float4_ d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  return d;
}

inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
  Float8_ d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  d.z = fma(a, b.z, c.z);
  d.w = fma(a, b.w, c.w);
  return d;
}
+
// Vector sum: horizontal reduction of all lanes to a single float.
template <>
inline __device__ float sum(float v) {
  return v;
}

template <>
inline __device__ float sum(float2 v) {
  return v.x + v.y;
}

template <>
inline __device__ float sum(float4 v) {
  return v.x + v.y + v.z + v.w;
}

template <>
inline __device__ float sum(Float4_ v) {
  return v.x.x + v.x.y + v.y.x + v.y.y;
}

template <>
inline __device__ float sum(Float8_ v) {
  return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
}

// Vector dot product (non-template overloads for FP32 vector types).
inline __device__ float dot(float a, float b) { return a * b; }

inline __device__ float dot(float2 a, float2 b) {
  float2 c = mul<float2, float2, float2>(a, b);
  return c.x + c.y;
}

inline __device__ float dot(Float4_ a, Float4_ b) {
  // Accumulate pairwise in float2, then reduce to a scalar.
  float2 acc = mul<float2, float2, float2>(a.x, b.x);
  acc = fma(a.y, b.y, acc);
  return acc.x + acc.y;
}

inline __device__ float dot(Float8_ a, Float8_ b) {
  float2 acc = mul<float2, float2, float2>(a.x, b.x);
  acc = fma(a.y, b.y, acc);
  acc = fma(a.z, b.z, acc);
  acc = fma(a.w, b.w, acc);
  return acc.x + acc.y;
}

// From float to float (identity conversions for the FP32 path).
// NOLINTNEXTLINE(runtime/references)
inline __device__ void from_float(float& dst, float src) { dst = src; }

// NOLINTNEXTLINE(runtime/references)
inline __device__ void from_float(float2& dst, float2 src) { dst = src; }

// NOLINTNEXTLINE(runtime/references)
inline __device__ void from_float(float4& dst, float4 src) { dst = src; }

// From float to float (identity conversions for the FP32 path).
inline __device__ float to_float(float u) { return u; }

inline __device__ float2 to_float(float2 u) { return u; }

inline __device__ float4 to_float(float4 u) { return u; }

inline __device__ Float4_ to_float(Float4_ u) { return u; }

inline __device__ Float8_ to_float(Float8_ u) { return u; }
+
// FP16 vector types for Q, K, V. Halves are stored as raw bit patterns in
// unsigned integer types: one half per uint16_t, two packed per uint32_t.
template <>
struct Vec<uint16_t, 1> {
  using Type = uint16_t;
};
template <>
struct Vec<uint16_t, 2> {
  using Type = uint32_t;
};
template <>
struct Vec<uint16_t, 4> {
  using Type = uint2;
};
template <>
struct Vec<uint16_t, 8> {
  using Type = uint4;
};

// FP32 accumulator vector types corresponding to Vec.
template <>
struct FloatVec<uint16_t> {
  using Type = float;
};
template <>
struct FloatVec<uint32_t> {
  using Type = float2;
};
template <>
struct FloatVec<uint2> {
  using Type = Float4_;
};
template <>
struct FloatVec<uint4> {
  using Type = Float8_;
};

// Utility functions for type conversions.
// Replicate one half into both 16-bit lanes of a 32-bit register.
inline __device__ uint32_t h0_h0(uint16_t a) {
  uint32_t b;
  asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
  return b;
}

// Convert one half (raw bits) to float via PTX cvt.f32.f16.
inline __device__ float half_to_float(uint16_t h) {
  float f;
  asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
  return f;
}

// Unpack a half2 (two halves in a uint32_t) into a float2.
inline __device__ float2 half2_to_float2(uint32_t v) {
  uint16_t lo, hi;
  asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
  return make_float2(half_to_float(lo), half_to_float(hi));
}

// Convert a float to half bits with round-to-nearest (cvt.rn.f16.f32).
inline __device__ uint16_t float_to_half(float f) {
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
  return tmp.u16[0];
}

// Convert a float2 to packed half2 bits; uses the single f16x2 instruction
// on sm_80+ and two scalar conversions otherwise.
inline __device__ uint32_t float2_to_half2(float2 f) {
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
#else
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
#endif
  return tmp.u32;
}
+
// Vector addition. Operands are raw fp16 / packed half2 bit patterns; the
// arithmetic is done in half precision via PTX add.f16 / add.f16x2.
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
  uint16_t c;
  asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
  return c;
}

inline __device__ uint32_t add(uint32_t a, uint32_t b) {
  uint32_t c;
  asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
  return c;
}

inline __device__ uint2 add(uint2 a, uint2 b) {
  uint2 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}

inline __device__ uint4 add(uint4 a, uint4 b) {
  uint4 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  c.z = add(a.z, b.z);
  c.w = add(a.w, b.w);
  return c;
}

// Mixed-precision adds: fp16 operand is widened to fp32 before the add.
inline __device__ float2 add(uint32_t a, float2 fb) {
  float2 fa = half2_to_float2(a);
  return add(fa, fb);
}

inline __device__ Float4_ add(uint2 a, Float4_ fb) {
  Float4_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  return fc;
}

inline __device__ Float8_ add(uint4 a, Float8_ fb) {
  Float8_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  fc.z = add(a.z, fb.z);
  fc.w = add(a.w, fb.w);
  return fc;
}

// Vector multiplication in half precision (mul.f16 / mul.f16x2). The
// uint16_t-first overloads broadcast a scalar half across all lanes.
template <>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
  uint16_t c;
  asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
  return c;
}

template <>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
  uint32_t c;
  asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
  return c;
}

template <>
inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
  return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
}

template <>
inline __device__ uint2 mul(uint2 a, uint2 b) {
  uint2 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
  return c;
}

template <>
inline __device__ uint2 mul(uint16_t a, uint2 b) {
  uint32_t s = h0_h0(a);
  uint2 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
  return c;
}

template <>
inline __device__ uint4 mul(uint4 a, uint4 b) {
  uint4 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
  c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
  c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
  return c;
}

template <>
inline __device__ uint4 mul(uint16_t a, uint4 b) {
  uint32_t s = h0_h0(a);
  uint4 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
  c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
  c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
  return c;
}

// Multiplications that convert to fp32 first and return FP32 results.
template <>
inline __device__ float mul(uint16_t a, uint16_t b) {
  float fa = half_to_float(a);
  float fb = half_to_float(b);
  return fa * fb;
}

template <>
inline __device__ float2 mul(uint32_t a, uint32_t b) {
  float2 fa = half2_to_float2(a);
  float2 fb = half2_to_float2(b);
  return mul<float2, float2, float2>(fa, fb);
}

template <>
inline __device__ float2 mul(uint16_t a, uint32_t b) {
  return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
}

template <>
inline __device__ Float4_ mul(uint2 a, uint2 b) {
  Float4_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
  return fc;
}

template <>
inline __device__ Float4_ mul(uint16_t a, uint2 b) {
  uint32_t s = h0_h0(a);
  Float4_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
  return fc;
}

template <>
inline __device__ Float8_ mul(uint4 a, uint4 b) {
  Float8_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
  fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
  fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
  return fc;
}

template <>
inline __device__ Float8_ mul(uint16_t a, uint4 b) {
  uint32_t s = h0_h0(a);
  Float8_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
  fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
  fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
  return fc;
}
+
// Vector fused multiply-add in half precision (fma.rn.f16x2); the uint16_t
// overloads broadcast a scalar half multiplier via h0_h0.
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
  uint32_t d;
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
  return d;
}

inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { return fma(h0_h0(a), b, c); }

inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
  uint2 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
  uint32_t s = h0_h0(a);
  uint2 d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  return d;
}

inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
  uint4 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
  uint32_t s = h0_h0(a);
  uint4 d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  d.z = fma(s, b.z, c.z);
  d.w = fma(s, b.w, c.w);
  return d;
}

// Mixed-precision FMA: fp16 inputs are widened and accumulated in fp32.
inline __device__ float fma(uint16_t a, uint16_t b, float fc) {
  float fa = half_to_float(a);
  float fb = half_to_float(b);
  return fa * fb + fc;
}

inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) {
  float2 fa = half2_to_float2(a);
  float2 fb = half2_to_float2(b);
  return fma(fa, fb, fc);
}

inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) { return fma(h0_h0(a), b, fc); }

inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) {
  Float4_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  return fd;
}

inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) {
  uint32_t s = h0_h0(a);
  Float4_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  return fd;
}

inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) {
  Float8_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  fd.z = fma(a.z, b.z, fc.z);
  fd.w = fma(a.w, b.w, fc.w);
  return fd;
}

inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
  uint32_t s = h0_h0(a);
  Float8_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  fd.z = fma(s, b.z, fc.z);
  fd.w = fma(s, b.w, fc.w);
  return fd;
}

// Vector sum: horizontal reduction of half lanes to a single float. The
// wider overloads pairwise-add in fp16 first, then widen to fp32.
template <>
inline __device__ float sum(uint16_t v) {
  return half_to_float(v);
}

template <>
inline __device__ float sum(uint32_t v) {
  float2 tmp = half2_to_float2(v);
  return tmp.x + tmp.y;
}

template <>
inline __device__ float sum(uint2 v) {
  uint32_t c = add(v.x, v.y);
  return sum(c);
}

template <>
inline __device__ float sum(uint4 v) {
  uint32_t c = add(v.x, v.y);
  c = add(c, v.z);
  c = add(c, v.w);
  return sum(c);
}

// From float32 to float16.
// NOLINTNEXTLINE(runtime/references)
inline __device__ void from_float(uint16_t& dst, float src) { dst = float_to_half(src); }

// NOLINTNEXTLINE(runtime/references)
inline __device__ void from_float(uint32_t& dst, float2 src) { dst = float2_to_half2(src); }

// NOLINTNEXTLINE(runtime/references)
inline __device__ void from_float(uint2& dst, Float4_ src) {
  dst.x = float2_to_half2(src.x);
  dst.y = float2_to_half2(src.y);
}

// NOLINTNEXTLINE(runtime/references)
inline __device__ void from_float(uint4& dst, Float8_ src) {
  dst.x = float2_to_half2(src.x);
  dst.y = float2_to_half2(src.y);
  dst.z = float2_to_half2(src.z);
  dst.w = float2_to_half2(src.w);
}

// From float16 to float32.
inline __device__ float to_float(uint16_t u) { return half_to_float(u); }

inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); }

inline __device__ Float4_ to_float(uint2 u) {
  Float4_ tmp;
  tmp.x = half2_to_float2(u.x);
  tmp.y = half2_to_float2(u.y);
  return tmp;
}

inline __device__ Float8_ to_float(uint4 u) {
  Float8_ tmp;
  tmp.x = half2_to_float2(u.x);
  tmp.y = half2_to_float2(u.y);
  tmp.z = half2_to_float2(u.z);
  tmp.w = half2_to_float2(u.w);
  return tmp;
}

// Zero-out a variable (fp16 scalar overload of the template above).
// NOLINTNEXTLINE(runtime/references)
inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); }
+
+}  // namespace vllm
diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py
index f3d9b9df32..757c00e0e3 100644
--- a/tests/lint/check_file_type.py
+++ b/tests/lint/check_file_type.py
@@ -46,6 +46,7 @@ ALLOW_EXTENSION = {
     "pyd",
     "pyx",
     "cu",
+    "cuh",
     "bat",
     # relay text format
     "rly",
diff --git a/tests/python/relax/test_contrib_vllm.py 
b/tests/python/relax/test_contrib_vllm.py
new file mode 100644
index 0000000000..b674e1c9fb
--- /dev/null
+++ b/tests/python/relax/test_contrib_vllm.py
@@ -0,0 +1,746 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import pytest
+
+import tvm.testing
+import tvm
+from tvm import relax
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
+
+
# Probe for the vLLM attention packed function; presumably the second argument
# enables allow-missing lookup (falsy result instead of an error when TVM was
# built without the vLLM kernels) — consistent with the skipif below.
has_vllm = tvm.get_global_func("tvm.contrib.vllm.single_query_cached_kv_attention", True)

vllm_enabled = pytest.mark.skipif(
    not has_vllm,
    reason="VLLM not enabled.",
)

# Skip every test in this module when the vLLM kernels are unavailable.
pytestmark = [vllm_enabled]
+
+
def build_and_run(mod, inputs_np, target, legalize=True):
    """Compile a Relax module for ``target`` and run its "main" function.

    Parameters
    ----------
    mod : tvm.IRModule
        The module to compile.
    inputs_np : list of numpy.ndarray
        Inputs, copied to the target device before the call.
    target : str
        Build target (e.g. "cuda").
    legalize : bool
        When True, legalize high-level ops and apply the default GPU schedule
        before building.

    Returns
    -------
    numpy.ndarray or list of numpy.ndarray
        The result(s) of "main", copied back to host memory.
    """
    if legalize:
        mod = relax.transform.LegalizeOps()(mod)
        with tvm.target.Target("cuda"):
            mod = tvm.tir.transform.DefaultGPUSchedule()(mod)

    with tvm.transform.PassContext():
        executable = relax.build(mod, target)

    device = tvm.device(target, 0)
    vm = relax.VirtualMachine(executable, device)
    device_args = [tvm.nd.array(arr, device) for arr in inputs_np]
    result = vm["main"](*device_args)

    # A tuple-valued result comes back as a TVM Array; unpack each element.
    if isinstance(result, tvm.ir.container.Array):
        return [item.numpy() for item in result]
    return result.numpy()
+
+
def test_attention():
    """Check the vLLM paged-attention v1 and v2 kernels against a recorded
    reference output (expected to match bit-exactly)."""

    @I.ir_module
    class ModulePagedAttentionV1:
        I.module_global_infos(
            {
                "vdevice": [
                    I.vdevice("llvm"),
                ]
            }
        )

        @R.function
        def main(
            query: R.Tensor(("num_seqs", 1, 64), dtype="float16"),
            key_cache: R.Tensor(("num_blocks", 1, 8, 16, 8), dtype="float16"),
            value_cache: R.Tensor(("num_blocks", 1, 64, 16), dtype="float16"),
            block_tables: R.Tensor(("num_seqs", "max_num_blocks_per_seq"), dtype="int32"),
            context_lens: R.Tensor(("num_seqs",), dtype="int32"),
        ) -> R.Tensor(("num_seqs", 1, 64), dtype="float16"):
            with R.dataflow():
                # Max context length is reduced on device and moved to the CPU
                # vdevice ("llvm:0"), since the kernel takes it as a host value.
                max_len = R.to_vdevice(R.max(context_lens), "llvm:0")
                out = R.call_dps_packed(
                    "tvm.contrib.vllm.single_query_cached_kv_attention_v1",
                    [
                        query,
                        key_cache,
                        value_cache,
                        block_tables,
                        context_lens,
                        16,  # matches block_size below
                        max_len,
                    ],
                    out_sinfo=query.struct_info,
                )
                R.output(out)
            return out

    @I.ir_module
    class ModulePagedAttentionV2:
        I.module_global_infos(
            {
                "vdevice": [
                    I.vdevice("llvm"),
                ]
            }
        )

        @R.function
        def main(
            query: R.Tensor(("num_seqs", 1, 64), dtype="float16"),
            key_cache: R.Tensor(("num_blocks", 1, 8, 16, 8), dtype="float16"),
            value_cache: R.Tensor(("num_blocks", 1, 64, 16), dtype="float16"),
            block_tables: R.Tensor(("num_seqs", "max_num_blocks_per_seq"), dtype="int32"),
            context_lens: R.Tensor(("num_seqs",), dtype="int32"),
        ) -> R.Tensor(("num_seqs", 1, 64), dtype="float16"):
            with R.dataflow():
                num_seqs = T.int64()
                max_len = R.to_vdevice(R.max(context_lens), "llvm:0")
                # alloc workspace: v2 needs per-sequence partial results
                # (exp_sums, max_logits, tmp_out) in addition to the output.
                exp_sums = R.zeros((num_seqs, 1, 1), "float32")
                max_logits = R.zeros((num_seqs, 1, 1), "float32")
                tmp_out = R.zeros((num_seqs, 1, 1, 64), "float16")

                out = R.call_dps_packed(
                    "tvm.contrib.vllm.single_query_cached_kv_attention_v2",
                    [
                        query,
                        key_cache,
                        value_cache,
                        block_tables,
                        context_lens,
                        16,
                        max_len,
                        exp_sums,
                        max_logits,
                        tmp_out,
                    ],
                    out_sinfo=query.struct_info,
                )
                R.output(out)
            return out

    # Fixed seed so the random Q/K/V match the recorded reference below.
    np.random.seed(0)
    num_heads = 1
    head_dim = 64
    vec_size = 8
    block_size = 16
    num_seqs = 2
    num_blocks = 1
    query = np.random.randn(num_seqs, num_heads, head_dim).astype("float16")
    # key_cache layout: [num_blocks, num_heads, head_size/x, block_size, x]
    key_cache = np.random.randn(
        num_blocks, num_heads, head_dim // vec_size, block_size, vec_size
    ).astype("float16")
    # value_cache layout: [num_blocks, num_heads, head_size, block_size]
    value_cache = np.random.randn(num_blocks, num_heads, head_dim, block_size).astype("float16")
    # Both sequences read from physical block 0, with different context lengths.
    block_tables = np.array([[0], [0]]).astype("int32")
    context_lens = np.array([3, 5]).astype("int32")

    out_v1 = build_and_run(
        ModulePagedAttentionV1,
        [query, key_cache, value_cache, block_tables, context_lens],
        "cuda",
        legalize=True,
    )

    out_v2 = build_and_run(
        ModulePagedAttentionV2,
        [query, key_cache, value_cache, block_tables, context_lens],
        "cuda",
        legalize=True,
    )

    # Golden output recorded from the original vLLM kernel (see the
    # commented-out generator below); shape (num_seqs, num_heads, head_dim).
    ref = np.array(
        [
            [
                [
                    0.28271484375,
                    0.197021484375,
                    -0.278564453125,
                    0.444580078125,
                    -0.47802734375,
                    -0.7548828125,
                    -0.84228515625,
                    -0.80322265625,
                    0.478759765625,
                    0.195068359375,
                    -0.59521484375,
                    0.779296875,
                    0.35888671875,
                    -0.158935546875,
                    -0.6103515625,
                    0.188720703125,
                    0.410400390625,
                    0.28662109375,
                    0.40283203125,
                    -1.23046875,
                    -0.01043701171875,
                    -0.0794677734375,
                    -0.0350341796875,
                    0.12005615234375,
                    0.63671875,
                    0.368896484375,
                    -0.58642578125,
                    -0.34228515625,
                    -0.552734375,
                    0.947265625,
                    -0.079833984375,
                    0.85302734375,
                    0.1947021484375,
                    0.16748046875,
                    -0.083984375,
                    -0.75244140625,
                    -0.568359375,
                    -1.45703125,
                    -1.021484375,
                    -0.2235107421875,
                    -0.98828125,
                    -0.87109375,
                    -0.43359375,
                    -0.3271484375,
                    0.0557861328125,
                    -0.269287109375,
                    -1.009765625,
                    0.1387939453125,
                    -0.0831298828125,
                    0.27978515625,
                    -0.9736328125,
                    0.7802734375,
                    -0.1329345703125,
                    -0.5927734375,
                    -1.6923828125,
                    1.1904296875,
                    -1.3759765625,
                    -1.080078125,
                    -0.53173828125,
                    0.28466796875,
                    -2.02734375,
                    -0.377685546875,
                    -0.81201171875,
                    -0.7412109375,
                ]
            ],
            [
                [
                    0.482177734375,
                    0.114501953125,
                    -0.265869140625,
                    -1.154296875,
                    0.28857421875,
                    0.71240234375,
                    -1.1767578125,
                    0.187744140625,
                    -0.23486328125,
                    0.07135009765625,
                    -0.34521484375,
                    0.444091796875,
                    -0.09130859375,
                    0.900390625,
                    -0.043701171875,
                    0.61279296875,
                    0.1201171875,
                    0.443603515625,
                    -0.4150390625,
                    -0.9560546875,
                    -0.1917724609375,
                    0.0863037109375,
                    0.267578125,
                    0.04931640625,
                    -0.32666015625,
                    0.5859375,
                    -0.57421875,
                    0.29541015625,
                    -0.26220703125,
                    1.177734375,
                    0.11309814453125,
                    0.81201171875,
                    0.346435546875,
                    0.53271484375,
                    -0.0009765625,
                    -0.35205078125,
                    -0.1298828125,
                    -1.2431640625,
                    -0.2196044921875,
                    0.31640625,
                    -0.40869140625,
                    0.25244140625,
                    -0.9853515625,
                    0.284912109375,
                    0.399169921875,
                    -1.1435546875,
                    0.305419921875,
                    0.300048828125,
                    -0.84521484375,
                    -0.5166015625,
                    -0.787109375,
                    0.1011962890625,
                    -1.0302734375,
                    -1.35546875,
                    -0.0556640625,
                    1.0791015625,
                    -0.047607421875,
                    -0.498046875,
                    -0.055999755859375,
                    -0.35009765625,
                    -1.4296875,
                    0.350341796875,
                    -1.16796875,
                    -0.576171875,
                ]
            ],
        ]
    ).astype("float16")

    # NOTE(review): reference generator, kept for reproducibility. It refers
    # to `num_kv_heads`, which is not defined in this function — adjust before
    # re-running against the vllm Python package.
    # from vllm import attention_ops
    # import torch
    #
    # def to_torch(arr):
    #     return torch.from_numpy(arr).to("cuda")
    #
    # ref = to_torch(np.zeros_like(query))
    # attention_ops.single_query_cached_kv_attention(
    #     ref,
    #     to_torch(query),
    #     to_torch(key_cache),
    #     to_torch(value_cache),
    #     num_kv_heads,
    #     query.shape[-1] ** -0.5,  # scale
    #     to_torch(block_tables),
    #     to_torch(context_lens),
    #     value_cache.shape[-1],  # block_size,
    #     np.max(context_lens),
    #     None,
    #     )
    # ref = ref.cpu().numpy()

    # print(ref.tolist())

    # Both kernel versions must reproduce the reference bit-exactly.
    for out in [out_v1, out_v2]:
        assert np.max(np.abs(ref - out)) == 0.0
+
+
+def test_cache():
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(
+            key: R.Tensor(("num_tokens", 1, 8), dtype="float16"),
+            value: R.Tensor(("num_tokens", 1, 8), dtype="float16"),
+            key_cache: R.Tensor(("num_blocks", 1, 1, 16, 8), dtype="float16"),
+            value_cache: R.Tensor(("num_blocks", 1, 8, 16), dtype="float16"),
+            slot_mapping: R.Tensor(("num_tokens",), dtype="int32"),
+        ) -> R.Tuple(
+            [
+                R.Tensor(("num_blocks", 1, 8, 16, 8), dtype="float16"),
+                R.Tensor(("num_blocks", 1, 8, 16), dtype="float16"),
+            ]
+        ):
+            with R.dataflow():
+                kv = R.call_pure_packed(
+                    "tvm.contrib.vllm.reshape_and_cache",
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    slot_mapping,
+                    sinfo_args=[key_cache.struct_info, 
value_cache.struct_info],
+                )
+                out = (kv[0], kv[1])
+                R.output(out)
+            return out
+
+    np.random.seed(0)
+    num_heads = 1
+    head_dim = 8
+    vec_size = 8
+    block_size = 16
+    num_tokens = 8
+    num_blocks = 1
+    key = np.random.randn(num_tokens, num_heads, head_dim).astype("float16")
+    value = np.random.randn(num_tokens, num_heads, head_dim).astype("float16")
+    key_cache_before = np.random.randn(
+        num_blocks, num_heads, head_dim // vec_size, block_size, vec_size
+    ).astype("float16")
+    value_cache_before = np.random.randn(num_blocks, num_heads, head_dim, 
block_size).astype(
+        "float16"
+    )
+    slot_mapping = np.arange(num_tokens).astype("int32")
+
+    key_cache = key_cache_before.copy()
+    value_cache = value_cache_before.copy()
+
+    out_key_cache, out_value_cache = build_and_run(
+        Module,
+        [key, value, key_cache, value_cache, slot_mapping],
+        "cuda",
+    )
+
+    ref_key_cache = np.array(
+        [
+            [
+                [
+                    [
+                        [
+                            1.763671875,
+                            0.400146484375,
+                            0.978515625,
+                            2.240234375,
+                            1.8671875,
+                            -0.97705078125,
+                            0.9501953125,
+                            -0.1513671875,
+                        ],
+                        [
+                            -0.10321044921875,
+                            0.41064453125,
+                            0.14404296875,
+                            1.4541015625,
+                            0.76123046875,
+                            0.1217041015625,
+                            0.44384765625,
+                            0.333740234375,
+                        ],
+                        [
+                            1.494140625,
+                            -0.2052001953125,
+                            0.31298828125,
+                            -0.85400390625,
+                            -2.552734375,
+                            0.65380859375,
+                            0.8642578125,
+                            -0.7421875,
+                        ],
+                        [
+                            2.26953125,
+                            -1.4541015625,
+                            0.045745849609375,
+                            -0.1871337890625,
+                            1.533203125,
+                            1.4697265625,
+                            0.1549072265625,
+                            0.378173828125,
+                        ],
+                        [
+                            -0.8876953125,
+                            -1.98046875,
+                            -0.347900390625,
+                            0.1563720703125,
+                            1.23046875,
+                            1.2021484375,
+                            -0.38720703125,
+                            -0.30224609375,
+                        ],
+                        [
+                            -1.048828125,
+                            -1.419921875,
+                            -1.7060546875,
+                            1.951171875,
+                            -0.509765625,
+                            -0.43798828125,
+                            -1.2529296875,
+                            0.77734375,
+                        ],
+                        [
+                            -1.6142578125,
+                            -0.2127685546875,
+                            -0.8955078125,
+                            0.386962890625,
+                            -0.5107421875,
+                            -1.1806640625,
+                            -0.0281829833984375,
+                            0.42822265625,
+                        ],
+                        [
+                            0.0665283203125,
+                            0.302490234375,
+                            -0.63427734375,
+                            -0.36279296875,
+                            -0.67236328125,
+                            -0.359619140625,
+                            -0.81298828125,
+                            -1.7265625,
+                        ],
+                        [
+                            -0.039276123046875,
+                            -1.16796875,
+                            0.5234375,
+                            -0.1715087890625,
+                            0.77197265625,
+                            0.82373046875,
+                            2.1640625,
+                            1.3369140625,
+                        ],
+                        [
+                            -0.369140625,
+                            -0.2393798828125,
+                            1.099609375,
+                            0.6552734375,
+                            0.64013671875,
+                            -1.6171875,
+                            -0.024322509765625,
+                            -0.73779296875,
+                        ],
+                        [
+                            0.280029296875,
+                            -0.09814453125,
+                            0.91015625,
+                            0.317138671875,
+                            0.7861328125,
+                            -0.46630859375,
+                            -0.9443359375,
+                            -0.41015625,
+                        ],
+                        [
+                            -0.0170135498046875,
+                            0.379150390625,
+                            2.259765625,
+                            -0.042266845703125,
+                            -0.9560546875,
+                            -0.345947265625,
+                            -0.463623046875,
+                            0.4814453125,
+                        ],
+                        [
+                            -1.541015625,
+                            0.063232421875,
+                            0.156494140625,
+                            0.232177734375,
+                            -0.59716796875,
+                            -0.2379150390625,
+                            -1.423828125,
+                            -0.493408203125,
+                        ],
+                        [
+                            -0.54296875,
+                            0.416015625,
+                            -1.15625,
+                            0.78125,
+                            1.494140625,
+                            -2.0703125,
+                            0.42626953125,
+                            0.6767578125,
+                        ],
+                        [
+                            -0.63720703125,
+                            -0.397216796875,
+                            -0.1329345703125,
+                            -0.2978515625,
+                            -0.30908203125,
+                            -1.67578125,
+                            1.15234375,
+                            1.080078125,
+                        ],
+                        [
+                            -0.8134765625,
+                            -1.466796875,
+                            0.52099609375,
+                            -0.57568359375,
+                            0.1419677734375,
+                            -0.3193359375,
+                            0.69140625,
+                            0.69482421875,
+                        ],
+                    ]
+                ]
+            ]
+        ]
+    ).astype("float16")
+
+    ref_value_cache = np.array(
+        [
+            [
+                [
+                    [
+                        0.1773681640625,
+                        1.1396484375,
+                        -1.1650390625,
+                        -1.0703125,
+                        0.010498046875,
+                        -1.1728515625,
+                        -0.861328125,
+                        0.37646484375,
+                        -1.9365234375,
+                        0.188720703125,
+                        0.52392578125,
+                        0.08843994140625,
+                        -0.310791015625,
+                        0.097412109375,
+                        0.39892578125,
+                        -2.7734375,
+                    ],
+                    [
+                        -0.40185546875,
+                        -1.234375,
+                        0.90087890625,
+                        1.0546875,
+                        1.7861328125,
+                        1.943359375,
+                        1.91015625,
+                        -1.099609375,
+                        -0.11053466796875,
+                        1.0205078125,
+                        -0.69189453125,
+                        1.5361328125,
+                        0.286376953125,
+                        0.60888671875,
+                        -1.044921875,
+                        1.2109375,
+                    ],
+                    [
+                        -1.6298828125,
+                        0.40234375,
+                        0.465576171875,
+                        -0.403076171875,
+                        0.126953125,
+                        -0.41357421875,
+                        -0.26806640625,
+                        0.29833984375,
+                        0.09771728515625,
+                        0.5830078125,
+                        -0.3994140625,
+                        0.3701171875,
+                        -1.306640625,
+                        1.658203125,
+                        -0.1181640625,
+                        -0.68017578125,
+                    ],
+                    [
+                        0.462890625,
+                        -0.6845703125,
+                        -1.5361328125,
+                        1.22265625,
+                        0.402099609375,
+                        -0.74755859375,
+                        0.80224609375,
+                        1.326171875,
+                        -1.126953125,
+                        -0.73046875,
+                        -0.384765625,
+                        0.0943603515625,
+                        -0.04217529296875,
+                        -0.286865234375,
+                        -0.061614990234375,
+                        -0.1072998046875,
+                    ],
+                    [
+                        -0.9072265625,
+                        -0.87060546875,
+                        1.48828125,
+                        0.208251953125,
+                        1.8828125,
+                        1.9228515625,
+                        0.947265625,
+                        -0.6943359375,
+                        -0.70458984375,
+                        0.943359375,
+                        0.7470703125,
+                        -1.1884765625,
+                        0.7734375,
+                        -1.18359375,
+                        -2.658203125,
+                        0.6064453125,
+                    ],
+                    [
+                        0.05194091796875,
+                        -0.57861328125,
+                        1.8955078125,
+                        0.9765625,
+                        -1.34765625,
+                        1.48046875,
+                        -0.155029296875,
+                        -0.149658203125,
+                        -0.44091796875,
+                        -0.2802734375,
+                        -0.36474609375,
+                        0.15673828125,
+                        0.57861328125,
+                        0.349609375,
+                        -0.76416015625,
+                        -1.4375,
+                    ],
+                    [
+                        0.72900390625,
+                        -0.3115234375,
+                        1.1787109375,
+                        0.3564453125,
+                        -1.2705078125,
+                        1.8671875,
+                        0.6142578125,
+                        -0.43505859375,
+                        0.6982421875,
+                        0.0037708282470703125,
+                        0.931640625,
+                        0.33984375,
+                        -0.01568603515625,
+                        0.160888671875,
+                        -0.190673828125,
+                        -0.394775390625,
+                    ],
+                    [
+                        0.1290283203125,
+                        0.05615234375,
+                        -0.179931640625,
+                        0.70654296875,
+                        0.96923828125,
+                        0.90625,
+                        0.92236328125,
+                        1.849609375,
+                        0.6435546875,
+                        -1.5703125,
+                        -0.2069091796875,
+                        0.88037109375,
+                        -1.6982421875,
+                        0.38720703125,
+                        -2.255859375,
+                        -1.0224609375,
+                    ],
+                ]
+            ]
+        ]
+    ).astype("float16")
+
+    # from vllm import cache_ops
+    # import torch
+
+    # def to_torch(arr):
+    #     return torch.from_numpy(arr).to("cuda")
+
+    # ref_key_cache = to_torch(key_cache_before.copy())
+    # ref_value_cache = to_torch(value_cache_before.copy())
+
+    # cache_ops.reshape_and_cache(
+    #     to_torch(key),
+    #     to_torch(value),
+    #     ref_key_cache,
+    #     ref_value_cache,
+    #     to_torch(slot_mapping),
+    # )
+
+    # ref_key_cache = ref_key_cache.cpu().numpy()
+    # ref_value_cache = ref_value_cache.cpu().numpy()
+
+    assert np.max(np.abs(out_key_cache - ref_key_cache)) == 0
+    assert np.max(np.abs(out_value_cache - ref_value_cache)) == 0
+
+
+# Entry point: when this test file is executed directly, let TVM's test
+# runner discover and run all test functions defined above.
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to