This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 605a9ad8ce8b81710e626ad7ebf8a8167fbbfaff
Author: Masahiro Masuda <[email protected]>
AuthorDate: Mon Oct 2 05:00:38 2023 +0900

    Add vllm kernels
---
 CMakeLists.txt                                |   1 +
 cmake/config.cmake                            |   4 +
 cmake/modules/CUDA.cmake                      |   6 +
 cmake/modules/contrib/vllm.cmake              |  30 ++
 licenses/LICENSE.vllm.txt                     | 201 ++++++++
 src/runtime/contrib/vllm/attention_kernels.cu | 509 +++++++++++++++++++
 src/runtime/contrib/vllm/cache_alloc.cc       |  55 ++
 src/runtime/contrib/vllm/cache_kernels.cu     | 109 ++++
 src/runtime/contrib/vllm/dtype_float16.h      | 688 +++++++++++++++++++++++++
 tests/python/relax/test_contrib_vllm.py       | 695 ++++++++++++++++++++++++++
 10 files changed, 2298 insertions(+)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 29aa865a59..e5b3305d32 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -572,6 +572,7 @@ include(cmake/modules/contrib/VitisAI.cmake)
 include(cmake/modules/contrib/Verilator.cmake)
 include(cmake/modules/contrib/UMA.cmake)
 include(cmake/modules/contrib/MSC.cmake)
+include(cmake/modules/contrib/vllm.cmake)
 include(cmake/modules/Git.cmake)
 include(cmake/modules/LibInfo.cmake)
 include(cmake/modules/RustExt.cmake)
diff --git a/cmake/config.cmake b/cmake/config.cmake
index bf0a49b1aa..0ef8952ea4 100644
--- a/cmake/config.cmake
+++ b/cmake/config.cmake
@@ -442,3 +442,7 @@ set(USE_UMA OFF)
 
 # Set custom Alloc Alignment for device allocated memory ndarray points to
 set(USE_KALLOC_ALIGNMENT 64)
+
+# Semicolon-separated list of CUDA SM architectures to emit device code for;
+# used only when compiling the bundled Thrust and vLLM CUDA sources.
+set(CMAKE_CUDA_ARCHITECTURES 80 75)
diff --git a/cmake/modules/CUDA.cmake b/cmake/modules/CUDA.cmake
index ce561c66a6..f2370648e1 100644
--- a/cmake/modules/CUDA.cmake
+++ b/cmake/modules/CUDA.cmake
@@ -64,6 +64,12 @@ if(USE_CUDA)
     message(STATUS "Build with Thrust support")
     cmake_minimum_required(VERSION 3.13) # to compile CUDA code
+    # Must precede enable_language(CUDA), which may define a default itself (CMP0104).
+    if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
+      message(WARNING "CMAKE_CUDA_ARCHITECTURES not set, compiling Thrust for sm80 and sm75.")
+      set(CMAKE_CUDA_ARCHITECTURES "80;75")
+    endif()
+
     enable_language(CUDA)
     set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda")
     tvm_file_glob(GLOB CONTRIB_THRUST_SRC src/runtime/contrib/thrust/*.cu)
     list(APPEND RUNTIME_SRCS ${CONTRIB_THRUST_SRC})
diff --git a/cmake/modules/contrib/vllm.cmake b/cmake/modules/contrib/vllm.cmake
new file mode 100644
index 0000000000..ae9474cb8b
--- /dev/null
+++ b/cmake/modules/contrib/vllm.cmake
@@ -0,0 +1,30 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+if(USE_VLLM)
+  message(STATUS "Build with vllm paged attention kernel.")
+  include_directories(src/runtime/contrib/vllm)
+  # Must precede enable_language(CUDA), which may define a default itself (CMP0104).
+  if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
+    message(WARNING "CMAKE_CUDA_ARCHITECTURES not set, compiling vLLM kernels for sm80 and sm75.")
+    set(CMAKE_CUDA_ARCHITECTURES "80;75")
+  endif()
+  enable_language(CUDA)
+  # Add every vLLM .cu/.cc source to RUNTIME_SRCS.
+  tvm_file_glob(GLOB VLLM_CONTRIB_SRC src/runtime/contrib/vllm/*.cu src/runtime/contrib/vllm/*.cc)
+  list(APPEND RUNTIME_SRCS ${VLLM_CONTRIB_SRC})
+endif()
diff --git a/licenses/LICENSE.vllm.txt b/licenses/LICENSE.vllm.txt
new file mode 100644
index 0000000000..261eeb9e9f
--- /dev/null
+++ b/licenses/LICENSE.vllm.txt
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
diff --git a/src/runtime/contrib/vllm/attention_kernels.cu b/src/runtime/contrib/vllm/attention_kernels.cu
new file mode 100644
index 0000000000..8176aea7c1
--- /dev/null
+++ b/src/runtime/contrib/vllm/attention_kernels.cu
@@ -0,0 +1,509 @@
+/*
+ * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <float.h>
+
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include <algorithm>
+#include <type_traits>
+
+#include "dtype_float16.h"
+
+#define WARP_SIZE 32  // Threads per warp on NVIDIA GPUs.
+#define MAX(a, b) ((a) > (b) ? (a) : (b))  // NB: the selected argument is evaluated twice.
+#define MIN(a, b) ((a) < (b) ? (a) : (b))  // NB: the selected argument is evaluated twice.
+
+namespace vllm {
+
+// Q*K^T: dot product of a query slice with a key slice, reduced over the thread group.
+template<int THREAD_GROUP_SIZE, typename Vec, int N>
+inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
+  // Per-lane partial products are accumulated in Vec's float accumulator type.
+  using A_vec = typename FloatVec<Vec>::Type;
+  A_vec partial = mul<A_vec, Vec, Vec>(q[0], k[0]);
+#pragma unroll
+  for (int n = 1; n < N; n++) {
+    partial = fma(q[n], k[n], partial);
+  }
+  // Collapse the vector lanes, then xor-butterfly across the thread group so
+  // every thread of the group holds the full dot product.
+  float dot_val = sum(partial);
+#pragma unroll
+  for (int offset = THREAD_GROUP_SIZE >> 1; offset > 0; offset >>= 1) {
+    dot_val += __shfl_xor_sync(uint32_t(-1), dot_val, offset);
+  }
+  return dot_val;
+}
+
+template<typename T, int THREAD_GROUP_SIZE>
+struct Qk_dot {  // Generic Q*K^T wrapper; the scalar type T is unused by this version.
+  template<typename Vec, int N>
+  static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
+    return qk_dot_<THREAD_GROUP_SIZE, Vec, N>(q, k);
+  }
+};
+
+// Block-wide sum used by the attention softmax: every thread passes its partial
+// `sum` and every thread receives the total. red_smem needs one float per warp.
+// NOTE(review): the xor reductions assume NUM_WARPS is a power of two <= WARP_SIZE.
+template<int NUM_WARPS>
+inline __device__ float block_sum(float* red_smem, float sum) {
+  constexpr uint32_t kAllLanes = uint32_t(-1);
+  const int lane = threadIdx.x % WARP_SIZE;
+  const int warp = threadIdx.x / WARP_SIZE;
+
+  // 1) Butterfly reduction inside each warp: every lane ends with its warp total.
+#pragma unroll
+  for (int offset = WARP_SIZE >> 1; offset > 0; offset >>= 1) {
+    sum += __shfl_xor_sync(kAllLanes, sum, offset);
+  }
+
+  // 2) One lane per warp publishes the warp total to shared memory; the
+  //    barrier makes all totals visible before they are read back.
+  if (lane == 0) {
+    red_smem[warp] = sum;
+  }
+  __syncthreads();
+
+  // 3) In every warp, the first NUM_WARPS lanes reload the per-warp totals
+  //    and reduce them again, leaving the block total in lane 0.
+  if (lane < NUM_WARPS) {
+    sum = red_smem[lane];
+  }
+#pragma unroll
+  for (int offset = NUM_WARPS >> 1; offset > 0; offset >>= 1) {
+    sum += __shfl_xor_sync(kAllLanes, sum, offset);
+  }
+
+  // 4) Hand lane 0's value to every lane of the warp.
+  return __shfl_sync(kAllLanes, sum, 0);
+}
+
+// Grid: (num_heads, num_seqs); each thread block computes one (head, sequence) output.
+template<
+  typename scalar_t,
+  int HEAD_SIZE,
+  int BLOCK_SIZE,
+  int NUM_THREADS>
+__global__ void single_query_cached_kv_attention_kernel(
+  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
+  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
+  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
+  const int* __restrict__ head_mapping,   // [num_heads]
+  const float scale,
+  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
+  const int* __restrict__ context_lens,   // [num_seqs]
+  const int max_num_blocks_per_seq,
+  const float* __restrict__ alibi_slopes,  // [num_heads]
+  const int q_stride,
+  const int kv_block_stride,
+  const int kv_head_stride) {
+  constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
+  // THREAD_GROUP_SIZE must divide NUM_THREADS; enforced at compile time below.
+  constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE;
+  static_assert(NUM_THREADS % THREAD_GROUP_SIZE == 0, "THREAD_GROUP_SIZE must divide NUM_THREADS");
+  constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
+  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  const int thread_idx = threadIdx.x;
+  const int warp_idx = thread_idx / WARP_SIZE;
+  const int lane = thread_idx % WARP_SIZE;
+
+  const int head_idx = blockIdx.x;
+  const int num_heads = gridDim.x;
+  const int kv_head_idx = head_mapping[head_idx];
+  const int seq_idx = blockIdx.y;
+  const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
+
+  // A vector type to store a part of a key or a query.
+  // The vector size is configured in such a way that the threads in a thread group
+  // fetch or compute 16 bytes at a time.
+  // For example, if the size of a thread group is 4 and the data type is half,
+  // then the vector size is 16 / (4 * sizeof(half)) == 2.
+  constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
+  using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
+  using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
+
+  constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
+  constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
+
+  const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
+  const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
+
+  // Load the query into shared memory (q_vecs).
+  // Each thread in a thread group has a different part of the query.
+  // For example, if the thread group size is 4, then the first thread in the group
+  // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
+  // th vectors of the query, and so on.
+  // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
+  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
+  __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
+#pragma unroll
+  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
+    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
+    q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
+  }
+  __syncthreads();
+
+  // Memory planning.
+  extern __shared__ char shared_mem[];
+  // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
+  float* logits = reinterpret_cast<float*>(shared_mem);
+  // Workspace for reduction.
+  __shared__ float red_smem[2 * NUM_WARPS];
+
+  // x == THREAD_GROUP_SIZE * VEC_SIZE
+  // Each thread group fetches x elements from the key at a time.
+  constexpr int x = 16 / sizeof(scalar_t);
+  float qk_max = -FLT_MAX;
+
+  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
+  const int context_len = context_lens[seq_idx];
+  const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
+
+  // Iterate over the key blocks.
+  // Each warp fetches a block of keys for each iteration.
+  // Each thread group in a warp fetches a key from the block, and computes
+  // dot product with the query.
+  for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
+    const int physical_block_number = block_table[block_idx];
+
+    // Load a key to registers.
+    // Each thread in a thread group has a different part of the key.
+    // For example, if the thread group size is 4, then the first thread in the group
+    // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
+    // vectors of the key, and so on.
+    for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
+      const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
+      const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
+      K_vec k_vecs[NUM_VECS_PER_THREAD];
+
+#pragma unroll
+      for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
+        const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+                                        + kv_head_idx * kv_head_stride
+                                        + physical_block_offset * x;
+        const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
+        const int offset1 = (vec_idx * VEC_SIZE) / x;
+        const int offset2 = (vec_idx * VEC_SIZE) % x;
+        k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+      }
+
+      // Compute dot product.
+      // This includes a reduction across the threads in the same thread group.
+      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset],
+                                                                  k_vecs);
+      // Add the ALiBi bias if slopes are given.
+      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
+
+      if (thread_group_offset == 0) {
+        // Store the partial reductions to shared memory.
+        // NOTE(woosuk): It is required to zero out the masked logits.
+        const bool mask = token_idx >= context_len;
+        logits[token_idx] = mask ? 0.f : qk;
+        // Update the max value.
+        qk_max = mask ? qk_max : fmaxf(qk_max, qk);
+      }
+    }
+  }
+
+  // Perform reduction across the threads in the same warp to get the
+  // max qk value for each "warp" (not across the thread block yet).
+  // The 0-th thread of each thread group already has its max qk value.
+#pragma unroll
+  for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
+    qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
+  }
+  if (lane == 0) {
+    red_smem[warp_idx] = qk_max;
+  }
+  __syncthreads();
+
+  // TODO(woosuk): Refactor this part.
+  // Get the max qk value for the sequence.
+  qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
+#pragma unroll
+  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+    qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
+  }
+  // Broadcast the max qk value to all threads.
+  qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
+
+  // Get the sum of the exp values.
+  float exp_sum = 0.f;
+  for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
+    float val = __expf(logits[i] - qk_max);
+    logits[i] = val;
+    exp_sum += val;
+  }
+  exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
+
+  // Compute softmax.
+  const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
+  for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
+    logits[i] *= inv_sum;
+  }
+  __syncthreads();
+
+  // Each thread will fetch 16 bytes from the value cache at a time.
+  constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
+  using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
+  using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
+  using Float_L_vec = typename FloatVec<L_vec>::Type;
+
+  constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
+  constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
+  constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER;
+
+  // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
+  float accs[NUM_ROWS_PER_THREAD];
+#pragma unroll
+  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+    accs[i] = 0.f;
+  }
+
+  scalar_t zero_value;
+  zero(zero_value);
+  for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
+    const int physical_block_number = block_table[block_idx];
+    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
+    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
+    L_vec logits_vec;
+    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx));
+
+    const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+                                    + kv_head_idx * kv_head_stride;
+#pragma unroll
+    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+      if (row_idx < HEAD_SIZE) {
+        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
+        V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
+        if (block_idx == num_blocks - 1) {
+          // NOTE(woosuk): When v_vec contains the tokens that are out of the context,
+          // we should explicitly zero out the values since they may contain NaNs.
+          // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
+          scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
+#pragma unroll
+          for (int j = 0; j < V_VEC_SIZE; j++) {  // v_vec holds exactly V_VEC_SIZE elements.
+            v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
+          }
+        }
+        accs[i] += dot(logits_vec, v_vec);
+      }
+    }
+  }
+
+  // Perform reduction within each warp.
+#pragma unroll
+  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+    float acc = accs[i];
+#pragma unroll
+    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
+      acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
+    }
+    accs[i] = acc;
+  }
+
+  // NOTE(woosuk): A barrier is required because the shared memory space for logits
+  // is reused for the output.
+  __syncthreads();
+
+  // Perform reduction across warps.
+  float* out_smem = reinterpret_cast<float*>(shared_mem);
+#pragma unroll
+  for (int i = NUM_WARPS; i > 1; i /= 2) {
+    int mid = i / 2;
+    // Upper warps write to shared memory.
+    if (warp_idx >= mid && warp_idx < i) {
+      float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
+#pragma unroll
+      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+          dst[row_idx] = accs[i];
+        }
+      }
+    }
+    __syncthreads();
+
+    // Lower warps update the output.
+    if (warp_idx < mid) {
+      const float* src = &out_smem[warp_idx * HEAD_SIZE];
+#pragma unroll
+      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+          accs[i] += src[row_idx];
+        }
+      }
+    }
+    __syncthreads();
+  }
+
+  // Write the final output.
+  if (warp_idx == 0) {
+    scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+#pragma unroll
+    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+        from_float(*(out_ptr + row_idx), accs[i]);
+      }
+    }
+  }
+}
+
+}  // namespace vllm
+
+namespace tvm {
+namespace runtime {
+// Raises the kernel's dynamic shared-memory limit to shared_mem_size, then launches it. Expects
+// grid, block, shared_mem_size, scale, the *_ptr locals and the stride locals in scope at expansion.
+#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS)                        \
+  cudaFuncSetAttribute(                                                                       \
+      vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>,   \
+      cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size);                          \
+  vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>        \
+  <<<grid, block, shared_mem_size>>>(                                                         \
+    out_ptr,                                                                                  \
+    query_ptr,                                                                                \
+    key_cache_ptr,                                                                            \
+    value_cache_ptr,                                                                          \
+    head_mapping_ptr,                                                                         \
+    scale,                                                                                    \
+    block_tables_ptr,                                                                         \
+    context_lens_ptr,                                                                         \
+    max_num_blocks_per_seq,                                                                   \
+    alibi_slopes_ptr,                                                                         \
+    q_stride,                                                                                 \
+    kv_block_stride,                                                                          \
+    kv_head_stride);
+
+/*!
+ * \brief Host-side launcher for vllm::single_query_cached_kv_attention_kernel.
+ *
+ * Reads sizes and strides from the tensor shapes, sizes the kernel's dynamic shared memory
+ * from max_context_len, and dispatches on the runtime head size to a kernel instantiation
+ * whose HEAD_SIZE is a compile-time constant.
+ *
+ * \tparam T Element storage type (the registered entry point uses uint16_t for fp16).
+ * \tparam BLOCK_SIZE Number of tokens per KV-cache block.
+ * \tparam NUM_THREADS Threads per CUDA block.
+ * \param out Output tensor written by the kernel.
+ * \param query Query tensor: shape[0] is the number of sequences, shape[1] the number of
+ *        heads and shape[2] the head size; its per-sequence stride is shape[1] * shape[2].
+ * \param key_cache Key cache; shape[2] * shape[3] * shape[4] elements per head and
+ *        shape[1] heads per block.
+ * \param value_cache Value cache; the kernel receives the strides computed from key_cache,
+ *        so its per-block / per-head strides are assumed to match — NOTE(review): confirm.
+ * \param head_mapping int32 tensor (read as int*), passed through to the kernel.
+ * \param scale Attention scale, passed through to the kernel.
+ * \param block_tables int32 tensor; shape[1] is the maximum number of blocks per sequence.
+ * \param context_lens int32 tensor (read as int*), passed through to the kernel.
+ * \param max_context_len Maximum context length; used to size shared memory.
+ *
+ * Unsupported head sizes abort with LOG(FATAL) instead of silently skipping the launch.
+ */
+template <typename T, int BLOCK_SIZE, int NUM_THREADS = 128>
+void single_query_cached_kv_attention_launcher(DLTensor* out, const DLTensor* query,
+                                               const DLTensor* key_cache,
+                                               const DLTensor* value_cache,
+                                               const DLTensor* head_mapping, float scale,
+                                               const DLTensor* block_tables,
+                                               const DLTensor* context_lens,
+                                               int max_context_len) {
+  int num_seqs = query->shape[0];
+  int num_heads = query->shape[1];
+  int head_size = query->shape[2];
+  int max_num_blocks_per_seq = block_tables->shape[1];
+  int q_stride = query->shape[1] * query->shape[2];
+
+  int kv_head_stride = key_cache->shape[2] * key_cache->shape[3] * key_cache->shape[4];
+  int kv_block_stride = kv_head_stride * key_cache->shape[1];
+
+  // ALiBi slopes are not wired through this integration; the kernel always gets nullptr.
+  const float* alibi_slopes_ptr = nullptr;
+
+  T* out_ptr = static_cast<T*>(out->data);
+  T* query_ptr = static_cast<T*>(query->data);
+  T* key_cache_ptr = static_cast<T*>(key_cache->data);
+  T* value_cache_ptr = static_cast<T*>(value_cache->data);
+  int* head_mapping_ptr = static_cast<int*>(head_mapping->data);
+  int* block_tables_ptr = static_cast<int*>(block_tables->data);
+  int* context_lens_ptr = static_cast<int*>(context_lens->data);
+
+  // Dynamic shared memory must hold the larger of: one float per (block-padded) context
+  // position, or the partial outputs (head_size floats each) of half of the warps.
+  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
+  int logits_size = padded_max_context_len * sizeof(float);
+  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+  int shared_mem_size = std::max(logits_size, outputs_size);
+
+  // One CUDA block per (head, sequence) pair.
+  dim3 grid(num_heads, num_seqs);
+  dim3 block(NUM_THREADS);
+  // HEAD_SIZE is a template parameter of the kernel, so only these sizes are instantiated.
+  switch (head_size) {
+    case 64:
+      LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS);
+      break;
+    case 80:
+      LAUNCH_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS);
+      break;
+    case 96:
+      LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS);
+      break;
+    case 112:
+      LAUNCH_ATTENTION_KERNEL(T, 112, BLOCK_SIZE, NUM_THREADS);
+      break;
+    case 128:
+      LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
+      break;
+    case 256:
+      LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
+      break;
+    default:
+      // Fail loudly rather than returning with `out` left unwritten.
+      LOG(FATAL) << "Unsupported head size: " << head_size;
+  }
+}
+
+// Calls single_query_cached_kv_attention_launcher for uint16_t (fp16) storage with a
+// compile-time KV-cache BLOCK_SIZE. The arguments are taken by name from the enclosing
+// scope: out, query, key_cache, value_cache, head_mapping, scale, block_tables,
+// context_lens and max_context_len.
+#define CALL_KERNEL_LAUNCHER(BLOCK_SIZE)                                                 \
+  single_query_cached_kv_attention_launcher<uint16_t, BLOCK_SIZE>(                       \
+      out, query, key_cache, value_cache, head_mapping, scale, block_tables, context_lens, \
+      max_context_len);
+
+// Packed function "tvm.contrib.vllm.single_query_cached_kv_attention".
+// Writes the attention result into `out` (destination-passing style). The softmax scale is
+// 1 / sqrt(head_size), and the runtime KV-cache block size (8, 16 or 32) selects a
+// compile-time launcher specialization; any other block size aborts.
+// max_context_len is read by dereferencing max_context_len_tensor->data on the host.
+TVM_REGISTER_GLOBAL("tvm.contrib.vllm.single_query_cached_kv_attention")
+    .set_body_typed([](const DLTensor* query, const DLTensor* key_cache,
+                       const DLTensor* value_cache, const DLTensor* head_mapping,
+                       const DLTensor* block_tables, const DLTensor* context_lens,
+                       int block_size,
+                       const DLTensor* max_context_len_tensor,  // TODO(masahi): pass integer
+                       DLTensor* out) {
+      const float scale = 1.0 / sqrt(query->shape[2]);
+      const int max_context_len = static_cast<int*>(max_context_len_tensor->data)[0];
+
+      switch (block_size) {
+        case 8:
+          CALL_KERNEL_LAUNCHER(8);
+          break;
+        case 16:
+          CALL_KERNEL_LAUNCHER(16);
+          break;
+        case 32:
+          CALL_KERNEL_LAUNCHER(32);
+          break;
+        default:
+          LOG(FATAL) << "Unsupported block size: " << block_size;
+      }
+    });
+}  // namespace runtime
+}  // namespace tvm
+
+// Drop the helper macros used by this file so they cannot collide with other code that
+// ends up in the same translation unit (e.g. unity builds).
+#undef WARP_SIZE
+#undef MAX
+#undef MIN
+#undef LAUNCH_ATTENTION_KERNEL
+#undef CALL_KERNEL_LAUNCHER
diff --git a/src/runtime/contrib/vllm/cache_alloc.cc 
b/src/runtime/contrib/vllm/cache_alloc.cc
new file mode 100644
index 0000000000..aea50aa47a
--- /dev/null
+++ b/src/runtime/contrib/vllm/cache_alloc.cc
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <cuda_runtime.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/registry.h>
+
+namespace tvm {
+namespace runtime {
+namespace vllm {
+
+/*!
+ * \brief Allocate uninitialized float16 paged KV-cache tensors on the current CUDA device.
+ *
+ * For each of num_layers layers two tensors are created with NDArray::Empty:
+ *  - key cache:   [num_blocks, num_heads, head_size / vec_size, block_size, vec_size]
+ *  - value cache: [num_blocks, num_heads, head_size, block_size]
+ * where vec_size = 16 / sizeof(fp16) = 8 elements.
+ *
+ * \return Flat array [key_0, value_0, key_1, value_1, ...] of 2 * num_layers tensors.
+ * Aborts (ICHECK) if head_size is not a multiple of vec_size or if the current CUDA device
+ * cannot be queried.
+ */
+Array<NDArray> AllocateKVCache(int head_size, int num_layers, int num_heads, int block_size,
+                               int num_blocks) {
+  // Size in bytes of one cache element (float16).
+  const int element_size = 2;
+  // The key cache splits the head dimension into 16-byte vectors of vec_size elements.
+  const int vec_size = 16 / element_size;
+  // Otherwise head_size / vec_size truncates and the key cache is silently too small.
+  ICHECK_EQ(head_size % vec_size, 0)
+      << "head_size (" << head_size << ") must be a multiple of " << vec_size;
+
+  int device_id = 0;
+  cudaError_t err = cudaGetDevice(&device_id);
+  ICHECK(err == cudaSuccess) << "cudaGetDevice failed: " << cudaGetErrorString(err);
+  DLDevice dev{DLDeviceType::kDLCUDA, device_id};
+
+  Array<NDArray> cache;
+  for (int i = 0; i < num_layers; ++i) {
+    NDArray key_blocks =
+        NDArray::Empty({num_blocks, num_heads, head_size / vec_size, block_size, vec_size},
+                       runtime::DataType::Float(16), dev);
+    NDArray value_blocks = NDArray::Empty({num_blocks, num_heads, head_size, block_size},
+                                          runtime::DataType::Float(16), dev);
+    cache.push_back(key_blocks);
+    cache.push_back(value_blocks);
+  }
+
+  return cache;
+}
+
+TVM_REGISTER_GLOBAL("tvm.contrib.vllm.allocate_kv_cache").set_body_typed(AllocateKVCache);
+
+}  // namespace vllm
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/contrib/vllm/cache_kernels.cu 
b/src/runtime/contrib/vllm/cache_kernels.cu
new file mode 100644
index 0000000000..29ab9bfa2e
--- /dev/null
+++ b/src/runtime/contrib/vllm/cache_kernels.cu
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include <algorithm>
+#include <cassert>
+#include <map>
+#include <vector>
+
+namespace vllm {
+
+// Scatter each token's key/value vectors into the paged KV cache.
+//
+// Each CUDA block handles one token (blockIdx.x is the token index). Its threads stride
+// over the num_heads * head_size elements of that token. slot_mapping[token] is a flat slot:
+// slot / block_size selects the cache block and slot % block_size the position inside it.
+// The key cache splits the head dimension into chunks of x elements; the value cache keeps
+// head_size as its own axis (see the per-argument layout comments below).
+//
+// Cache offsets are computed in 64-bit arithmetic. A per-layer cache of
+// num_blocks * num_heads * head_size * block_size elements can exceed INT_MAX, which would
+// overflow 32-bit index math.
+template <typename scalar_t>
+__global__ void reshape_and_cache_kernel(
+    const scalar_t* __restrict__ key,      // [num_tokens, num_heads, head_size]
+    const scalar_t* __restrict__ value,    // [num_tokens, num_heads, head_size]
+    scalar_t* __restrict__ key_cache,      // [num_blocks, num_heads, head_size/x, block_size, x]
+    scalar_t* __restrict__ value_cache,    // [num_blocks, num_heads, head_size, block_size]
+    const int* __restrict__ slot_mapping,  // [num_tokens]
+    const int key_stride,
+    const int value_stride,
+    const int num_heads,
+    const int head_size,
+    const int block_size,
+    const int x) {
+  const int token_idx = blockIdx.x;
+  const int64_t slot_idx = slot_mapping[token_idx];
+  const int64_t block_idx = slot_idx / block_size;
+  const int64_t block_offset = slot_idx % block_size;
+
+  const int n = num_heads * head_size;
+  for (int i = threadIdx.x; i < n; i += blockDim.x) {
+    const int64_t src_key_idx = static_cast<int64_t>(token_idx) * key_stride + i;
+    const int64_t src_value_idx = static_cast<int64_t>(token_idx) * value_stride + i;
+
+    const int head_idx = i / head_size;
+    const int head_offset = i % head_size;
+    const int x_idx = head_offset / x;
+    const int x_offset = head_offset % x;
+
+    const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x +
+                                static_cast<int64_t>(head_idx) * (head_size / x) * block_size * x +
+                                static_cast<int64_t>(x_idx) * block_size * x +
+                                block_offset * x + x_offset;
+    const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size +
+                                  static_cast<int64_t>(head_idx) * head_size * block_size +
+                                  static_cast<int64_t>(head_offset) * block_size + block_offset;
+    key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]);
+    value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]);
+  }
+}
+
+}  // namespace vllm
+
+namespace tvm {
+namespace runtime {
+
+// Packed function "tvm.contrib.vllm.reshape_and_cache".
+// Copies each token's new key/value vectors into key_cache / value_cache (modified in place)
+// at the slots given by slot_mapping, then returns the two caches.
+// Elements are moved as raw uint16_t, so the tensors are presumably float16 — TODO confirm
+// that callers guarantee the dtype.
+TVM_REGISTER_GLOBAL("tvm.contrib.vllm.reshape_and_cache")
+    .set_body_typed([](NDArray key, NDArray value, NDArray key_cache, NDArray value_cache,
+                       NDArray slot_mapping) {
+      using scalar_t = uint16_t;
+
+      // key is indexed as [num_tokens, num_heads, head_size].
+      const int num_tokens = key->shape[0];
+      const int num_heads = key->shape[1];
+      const int head_size = key->shape[2];
+      // key_cache is indexed as [num_blocks, num_heads, head_size / x, block_size, x].
+      const int block_size = key_cache->shape[3];
+      const int vec_size = key_cache->shape[4];
+
+      const int key_stride = key->shape[1] * key->shape[2];
+      const int value_stride = value->shape[1] * value->shape[2];
+
+      // One CUDA block per token, at most 512 threads each.
+      const int threads_per_block = std::min(num_heads * head_size, 512);
+      dim3 grid(num_tokens);
+      dim3 block(threads_per_block);
+
+      const scalar_t* key_ptr = static_cast<const scalar_t*>(key->data);
+      const scalar_t* value_ptr = static_cast<const scalar_t*>(value->data);
+      scalar_t* key_cache_ptr = static_cast<scalar_t*>(key_cache->data);
+      scalar_t* value_cache_ptr = static_cast<scalar_t*>(value_cache->data);
+      const int* slot_mapping_ptr = static_cast<const int*>(slot_mapping->data);
+
+      vllm::reshape_and_cache_kernel<scalar_t><<<grid, block>>>(
+          key_ptr, value_ptr, key_cache_ptr, value_cache_ptr, slot_mapping_ptr, key_stride,
+          value_stride, num_heads, head_size, block_size, vec_size);
+
+      return Array<NDArray>{key_cache, value_cache};
+    });
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/contrib/vllm/dtype_float16.h 
b/src/runtime/contrib/vllm/dtype_float16.h
new file mode 100644
index 0000000000..e16c10468b
--- /dev/null
+++ b/src/runtime/contrib/vllm/dtype_float16.h
@@ -0,0 +1,688 @@
+/*
+ * Adapted from
+ * 
https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * and
+ * 
https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <stdint.h>
+
+namespace vllm {
+
+// A vector type to store Q, K, V elements.
+// Vec<T, VEC_SIZE>::Type names the type that holds VEC_SIZE elements of scalar type T.
+// The primary template is empty; each supported (T, VEC_SIZE) pair is specialized below.
+template <typename T, int VEC_SIZE>
+struct Vec {};
+
+// A vector type to store FP32 accumulators.
+// FloatVec<V>::Type names the fp32 type used to accumulate a vector of type V.
+template <typename T>
+struct FloatVec {};
+
+// Template vector operations.
+// mul<Acc, A, B>(a, b): element-wise product of a and b returned as type Acc. Only declared
+// here; each supported combination is an explicit specialization below.
+template <typename Acc, typename A, typename B>
+inline __device__ Acc mul(A a, B b);
+
+// sum(v): sum of all lanes of v as a float. Specialized per vector type below.
+template <typename T>
+inline __device__ float sum(T v);
+
+// Dot product where the element-wise product is formed in the input type T itself.
+template <typename T>
+inline __device__ float dot(T a, T b) {
+  return sum(mul<T, T, T>(a, b));
+}
+
+// Dot product where the element-wise product is formed in accumulator type A
+// (e.g. float2 for a uint32_t half2 input), then summed.
+template <typename A, typename T>
+inline __device__ float dot(T a, T b) {
+  return sum(mul<A, T, T>(a, b));
+}
+
+// Set dst to all-zero bits by clearing it as an array of 32-bit words.
+// WORDS = sizeof(T) / 4, so this is only meaningful for types whose size is a multiple of
+// 4 bytes; 2-byte uint16_t has a dedicated non-template zero() overload.
+template <typename T>
+inline __device__ void zero(T& dst) {
+  constexpr int WORDS = sizeof(T) / 4;
+  union {
+    T raw;
+    uint32_t words[WORDS];
+  } tmp;
+
+#pragma unroll
+  for (int ii = 0; ii < WORDS; ++ii) {
+    tmp.words[ii] = 0u;
+  }
+  dst = tmp.raw;
+}
+
+// Custom FP32 vector types made of float2 pairs. They are the fp32 accumulator types for
+// 4-wide (uint2) and 8-wide (uint4) fp16 vectors. Members are laid out in x, y[, z, w] order.
+struct Float4_ {
+  float2 x, y;
+};
+
+struct Float8_ {
+  float2 x, y, z, w;
+};
+
+// FP32 vector types for Q, K, V: 1, 2 and 4 floats map to float, float2 and float4.
+template <>
+struct Vec<float, 1> {
+  typedef float Type;
+};
+template <>
+struct Vec<float, 2> {
+  typedef float2 Type;
+};
+template <>
+struct Vec<float, 4> {
+  typedef float4 Type;
+};
+
+// FP32 accumulator types for FP32 vectors are the vector types themselves.
+template <>
+struct FloatVec<float> {
+  typedef float Type;
+};
+template <>
+struct FloatVec<float2> {
+  typedef float2 Type;
+};
+template <>
+struct FloatVec<float4> {
+  typedef float4 Type;
+};
+
+// Vector addition (fp32), lane by lane.
+inline __device__ float add(float a, float b) { return a + b; }
+
+inline __device__ float2 add(float2 a, float2 b) {
+  return make_float2(add(a.x, b.x), add(a.y, b.y));
+}
+
+inline __device__ float4 add(float4 a, float4 b) {
+  return make_float4(add(a.x, b.x), add(a.y, b.y), add(a.z, b.z), add(a.w, b.w));
+}
+
+// Vector multiplication (fp32): explicit specializations of mul<Acc, A, B>. Template
+// arguments that are not spelled out are deduced from each signature.
+template <>
+inline __device__ float mul<float, float>(float a, float b) {
+  return a * b;
+}
+
+// Lane-wise product of two float2 values.
+template <>
+inline __device__ float2 mul(float2 a, float2 b) {
+  return make_float2(a.x * b.x, a.y * b.y);
+}
+
+// Scalar `a` times every lane of float2 `b`.
+template <>
+inline __device__ float2 mul(float a, float2 b) {
+  return make_float2(a * b.x, a * b.y);
+}
+
+// Lane-wise product of two float4 values.
+template <>
+inline __device__ float4 mul(float4 a, float4 b) {
+  return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
+}
+
+// Scalar `a` times every lane of float4 `b`.
+template <>
+inline __device__ float4 mul(float a, float4 b) {
+  return make_float4(a * b.x, a * b.y, a * b.z, a * b.w);
+}
+
+// Vector multiply-add (fp32): d = a * b + c, lane by lane. Overloads whose first argument
+// is a scalar float apply that scalar to every lane of b.
+inline __device__ float fma(float a, float b, float c) { return a * b + c; }
+
+inline __device__ float2 fma(float2 a, float2 b, float2 c) {
+  return make_float2(fma(a.x, b.x, c.x), fma(a.y, b.y, c.y));
+}
+
+inline __device__ float2 fma(float a, float2 b, float2 c) {
+  return make_float2(fma(a, b.x, c.x), fma(a, b.y, c.y));
+}
+
+inline __device__ float4 fma(float4 a, float4 b, float4 c) {
+  return make_float4(fma(a.x, b.x, c.x), fma(a.y, b.y, c.y), fma(a.z, b.z, c.z),
+                     fma(a.w, b.w, c.w));
+}
+
+inline __device__ float4 fma(float a, float4 b, float4 c) {
+  return make_float4(fma(a, b.x, c.x), fma(a, b.y, c.y), fma(a, b.z, c.z),
+                     fma(a, b.w, c.w));
+}
+
+// Float4_ / Float8_ are processed one float2 member at a time.
+inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
+  Float4_ result = {fma(a, b.x, c.x), fma(a, b.y, c.y)};
+  return result;
+}
+
+inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
+  Float8_ result = {fma(a, b.x, c.x), fma(a, b.y, c.y), fma(a, b.z, c.z), fma(a, b.w, c.w)};
+  return result;
+}
+
+// Vector sum (fp32): add all lanes into one float. The additions run left to right in
+// member order (x.x, x.y, y.x, ...).
+template <>
+inline __device__ float sum(float v) {
+  return v;
+}
+
+template <>
+inline __device__ float sum(float2 v) {
+  return v.x + v.y;
+}
+
+template <>
+inline __device__ float sum(float4 v) {
+  return v.x + v.y + v.z + v.w;
+}
+
+template <>
+inline __device__ float sum(Float4_ v) {
+  return v.x.x + v.x.y + v.y.x + v.y.y;
+}
+
+template <>
+inline __device__ float sum(Float8_ v) {
+  return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
+}
+
+// Vector dot product (fp32).
+// These are non-template overloads, so for exactly matching argument types they are
+// preferred over the generic dot<T> templates.
+inline __device__ float dot(float a, float b) { return a * b; }
+
+inline __device__ float dot(float2 a, float2 b) {
+  float2 c = mul<float2, float2, float2>(a, b);
+  return c.x + c.y;
+}
+
+// Float4_ / Float8_: accumulate the lane-wise products of all float2 members into one
+// float2 with fma, then add its two lanes.
+inline __device__ float dot(Float4_ a, Float4_ b) {
+  float2 acc = mul<float2, float2, float2>(a.x, b.x);
+  acc = fma(a.y, b.y, acc);
+  return acc.x + acc.y;
+}
+
+inline __device__ float dot(Float8_ a, Float8_ b) {
+  float2 acc = mul<float2, float2, float2>(a.x, b.x);
+  acc = fma(a.y, b.y, acc);
+  acc = fma(a.z, b.z, acc);
+  acc = fma(a.w, b.w, acc);
+  return acc.x + acc.y;
+}
+
+// From float to float: identity conversions, so that code templated on the storage type
+// can call from_float() for fp32 storage as well as fp16 storage.
+inline __device__ void from_float(float& dst, float src) { dst = src; }
+
+inline __device__ void from_float(float2& dst, float2 src) { dst = src; }
+
+inline __device__ void from_float(float4& dst, float4 src) { dst = src; }
+
+// To float from float: identity conversions, the fp32 counterparts of the fp16 to_float
+// overloads.
+inline __device__ float to_float(float u) { return u; }
+
+inline __device__ float2 to_float(float2 u) { return u; }
+
+inline __device__ float4 to_float(float4 u) { return u; }
+
+inline __device__ Float4_ to_float(Float4_ u) { return u; }
+
+inline __device__ Float8_ to_float(Float8_ u) { return u; }
+
+// FP16 vector types for Q, K, V.
+// fp16 values are carried as raw 16-bit patterns in unsigned integer types: 1 half in a
+// uint16_t, 2 in a uint32_t, 4 in a uint2 and 8 in a uint4.
+template <>
+struct Vec<uint16_t, 1> {
+  using Type = uint16_t;
+};
+template <>
+struct Vec<uint16_t, 2> {
+  using Type = uint32_t;
+};
+template <>
+struct Vec<uint16_t, 4> {
+  using Type = uint2;
+};
+template <>
+struct Vec<uint16_t, 8> {
+  using Type = uint4;
+};
+
+// FP32 accumulator vector types corresponding to Vec.
+// Each fp16 vector accumulates in an fp32 type with the same number of lanes.
+template <>
+struct FloatVec<uint16_t> {
+  using Type = float;
+};
+template <>
+struct FloatVec<uint32_t> {
+  using Type = float2;
+};
+template <>
+struct FloatVec<uint2> {
+  using Type = Float4_;
+};
+template <>
+struct FloatVec<uint4> {
+  using Type = Float8_;
+};
+
+// Utility functions for type conversions.
+
+// Replicate one fp16 value into both 16-bit halves of a 32-bit register, i.e. half2 {a, a}.
+inline __device__ uint32_t h0_h0(uint16_t a) {
+  uint32_t b;
+  asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
+  return b;
+}
+
+// Convert one fp16 bit pattern to fp32.
+inline __device__ float half_to_float(uint16_t h) {
+  float f;
+  asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
+  return f;
+}
+
+// Unpack a half2: the low 16 bits become .x and the high 16 bits become .y.
+inline __device__ float2 half2_to_float2(uint32_t v) {
+  uint16_t lo, hi;
+  asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
+  return make_float2(half_to_float(lo), half_to_float(hi));
+}
+
+// Convert fp32 to fp16 with round-to-nearest-even (cvt.rn).
+inline __device__ uint16_t float_to_half(float f) {
+  union {
+    uint32_t u32;
+    uint16_t u16[2];
+  } tmp;
+  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
+  return tmp.u16[0];
+}
+
+// Pack a float2 into a half2 with round-to-nearest-even: f.x goes to the low 16 bits and
+// f.y to the high 16 bits. sm_80+ uses the paired cvt.rn.f16x2.f32, which stores its first
+// source operand in the upper half (hence f.y is passed first). Older architectures convert
+// each lane separately into the same layout.
+inline __device__ uint32_t float2_to_half2(float2 f) {
+  union {
+    uint32_t u32;
+    uint16_t u16[2];
+  } tmp;
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
+#else
+  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
+  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
+#endif
+  return tmp.u32;
+}
+
+// Vector addition (fp16).
+// Same-type overloads add in half precision: add.f16 for a single half, add.f16x2 for
+// each packed half2 word.
+inline __device__ uint16_t add(uint16_t a, uint16_t b) {
+  uint16_t c;
+  asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
+  return c;
+}
+
+inline __device__ uint32_t add(uint32_t a, uint32_t b) {
+  uint32_t c;
+  asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
+  return c;
+}
+
+inline __device__ uint2 add(uint2 a, uint2 b) {
+  uint2 c;
+  c.x = add(a.x, b.x);
+  c.y = add(a.y, b.y);
+  return c;
+}
+
+inline __device__ uint4 add(uint4 a, uint4 b) {
+  uint4 c;
+  c.x = add(a.x, b.x);
+  c.y = add(a.y, b.y);
+  c.z = add(a.z, b.z);
+  c.w = add(a.w, b.w);
+  return c;
+}
+
+// Mixed overloads: convert the fp16 operand to fp32 and add it to an fp32 accumulator.
+inline __device__ float2 add(uint32_t a, float2 fb) {
+  float2 fa = half2_to_float2(a);
+  return add(fa, fb);
+}
+
+inline __device__ Float4_ add(uint2 a, Float4_ fb) {
+  Float4_ fc;
+  fc.x = add(a.x, fb.x);
+  fc.y = add(a.y, fb.y);
+  return fc;
+}
+
+inline __device__ Float8_ add(uint4 a, Float8_ fb) {
+  Float8_ fc;
+  fc.x = add(a.x, fb.x);
+  fc.y = add(a.y, fb.y);
+  fc.z = add(a.z, fb.z);
+  fc.w = add(a.w, fb.w);
+  return fc;
+}
+
+// Vector multiplication (fp16), as explicit specializations of mul<Acc, A, B>.
+// When Acc is an fp16 type (uint16_t / uint32_t / uint2 / uint4), the product is computed
+// in half precision with mul.f16 / mul.f16x2. A scalar uint16_t first operand is broadcast
+// to both halves of a half2 with h0_h0 before multiplying.
+template <>
+inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
+  uint16_t c;
+  asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
+  return c;
+}
+
+template <>
+inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
+  uint32_t c;
+  asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
+  return c;
+}
+
+template <>
+inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
+  return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
+}
+
+template <>
+inline __device__ uint2 mul(uint2 a, uint2 b) {
+  uint2 c;
+  c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
+  c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
+  return c;
+}
+
+template <>
+inline __device__ uint2 mul(uint16_t a, uint2 b) {
+  uint32_t s = h0_h0(a);
+  uint2 c;
+  c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
+  c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
+  return c;
+}
+
+template <>
+inline __device__ uint4 mul(uint4 a, uint4 b) {
+  uint4 c;
+  c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
+  c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
+  c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
+  c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
+  return c;
+}
+
+template <>
+inline __device__ uint4 mul(uint16_t a, uint4 b) {
+  uint32_t s = h0_h0(a);
+  uint4 c;
+  c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
+  c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
+  c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
+  c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
+  return c;
+}
+
+// Specializations with an fp32 accumulator type (float / float2 / Float4_ / Float8_):
+// both operands are converted to fp32 first and multiplied in fp32.
+template <>
+inline __device__ float mul(uint16_t a, uint16_t b) {
+  float fa = half_to_float(a);
+  float fb = half_to_float(b);
+  return fa * fb;
+}
+
+template <>
+inline __device__ float2 mul(uint32_t a, uint32_t b) {
+  float2 fa = half2_to_float2(a);
+  float2 fb = half2_to_float2(b);
+  return mul<float2, float2, float2>(fa, fb);
+}
+
+template <>
+inline __device__ float2 mul(uint16_t a, uint32_t b) {
+  return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
+}
+
+template <>
+inline __device__ Float4_ mul(uint2 a, uint2 b) {
+  Float4_ fc;
+  fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
+  fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
+  return fc;
+}
+
+template <>
+inline __device__ Float4_ mul(uint16_t a, uint2 b) {
+  uint32_t s = h0_h0(a);
+  Float4_ fc;
+  fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
+  fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
+  return fc;
+}
+
+template <>
+inline __device__ Float8_ mul(uint4 a, uint4 b) {
+  Float8_ fc;
+  fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
+  fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
+  fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
+  fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
+  return fc;
+}
+
+template <>
+inline __device__ Float8_ mul(uint16_t a, uint4 b) {
+  uint32_t s = h0_h0(a);
+  Float8_ fc;
+  fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
+  fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
+  fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
+  fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
+  return fc;
+}
+
+// Vector fused multiply-add.
+inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
+  uint32_t d;
+  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), 
"r"(c));
+  return d;
+}
+
+inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { return 
fma(h0_h0(a), b, c); }
+
+inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
+  uint2 d;
+  d.x = fma(a.x, b.x, c.x);
+  d.y = fma(a.y, b.y, c.y);
+  return d;
+}
+
+inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
+  uint32_t s = h0_h0(a);
+  uint2 d;
+  d.x = fma(s, b.x, c.x);
+  d.y = fma(s, b.y, c.y);
+  return d;
+}
+
+inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
+  uint4 d;
+  d.x = fma(a.x, b.x, c.x);
+  d.y = fma(a.y, b.y, c.y);
+  d.z = fma(a.z, b.z, c.z);
+  d.w = fma(a.w, b.w, c.w);
+  return d;
+}
+
+inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
+  uint32_t s = h0_h0(a);
+  uint4 d;
+  d.x = fma(s, b.x, c.x);
+  d.y = fma(s, b.y, c.y);
+  d.z = fma(s, b.z, c.z);
+  d.w = fma(s, b.w, c.w);
+  return d;
+}
+
+inline __device__ float fma(uint16_t a, uint16_t b, float fc) {
+  float fa = half_to_float(a);
+  float fb = half_to_float(b);
+  return fa * fb + fc;
+}
+
+inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) {
+  float2 fa = half2_to_float2(a);
+  float2 fb = half2_to_float2(b);
+  return fma(fa, fb, fc);
+}
+
+inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) { return 
fma(h0_h0(a), b, fc); }
+
+inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) {
+  Float4_ fd;
+  fd.x = fma(a.x, b.x, fc.x);
+  fd.y = fma(a.y, b.y, fc.y);
+  return fd;
+}
+
+inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) {
+  uint32_t s = h0_h0(a);
+  Float4_ fd;
+  fd.x = fma(s, b.x, fc.x);
+  fd.y = fma(s, b.y, fc.y);
+  return fd;
+}
+
+inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) {
+  Float8_ fd;
+  fd.x = fma(a.x, b.x, fc.x);
+  fd.y = fma(a.y, b.y, fc.y);
+  fd.z = fma(a.z, b.z, fc.z);
+  fd.w = fma(a.w, b.w, fc.w);
+  return fd;
+}
+
+inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
+  uint32_t s = h0_h0(a);
+  Float8_ fd;
+  fd.x = fma(s, b.x, fc.x);
+  fd.y = fma(s, b.y, fc.y);
+  fd.z = fma(s, b.z, fc.z);
+  fd.w = fma(s, b.w, fc.w);
+  return fd;
+}
+
+// Vector sum (fp16 inputs, fp32 result).
+template <>
+inline __device__ float sum(uint16_t v) {
+  return half_to_float(v);
+}
+
+template <>
+inline __device__ float sum(uint32_t v) {
+  float2 tmp = half2_to_float2(v);
+  return tmp.x + tmp.y;
+}
+
+// Wider vectors first add their half2 words together in half precision (add.f16x2).
+// Only the final half2 is converted to fp32, so the intermediate sums are rounded to fp16.
+template <>
+inline __device__ float sum(uint2 v) {
+  uint32_t c = add(v.x, v.y);
+  return sum(c);
+}
+
+template <>
+inline __device__ float sum(uint4 v) {
+  uint32_t c = add(v.x, v.y);
+  c = add(c, v.z);
+  c = add(c, v.w);
+  return sum(c);
+}
+
+// From float32 to float16: round fp32 values to nearest-even fp16 bit patterns, packing
+// each float2 into one half2 word.
+inline __device__ void from_float(uint16_t& dst, float src) { dst = float_to_half(src); }
+
+inline __device__ void from_float(uint32_t& dst, float2 src) { dst = float2_to_half2(src); }
+
+inline __device__ void from_float(uint2& dst, Float4_ src) {
+  dst.x = float2_to_half2(src.x);
+  dst.y = float2_to_half2(src.y);
+}
+
+inline __device__ void from_float(uint4& dst, Float8_ src) {
+  dst.x = float2_to_half2(src.x);
+  dst.y = float2_to_half2(src.y);
+  dst.z = float2_to_half2(src.z);
+  dst.w = float2_to_half2(src.w);
+}
+
+// From float16 to float32: widen each fp16 lane into the matching fp32 vector type.
+inline __device__ float to_float(uint16_t u) { return half_to_float(u); }
+
+inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); }
+
+inline __device__ Float4_ to_float(uint2 u) {
+  Float4_ tmp;
+  tmp.x = half2_to_float2(u.x);
+  tmp.y = half2_to_float2(u.y);
+  return tmp;
+}
+
+inline __device__ Float8_ to_float(uint4 u) {
+  Float8_ tmp;
+  tmp.x = half2_to_float2(u.x);
+  tmp.y = half2_to_float2(u.y);
+  tmp.z = half2_to_float2(u.z);
+  tmp.w = half2_to_float2(u.w);
+  return tmp;
+}
+
+// Zero-out a variable.
+// uint16_t needs its own overload. The generic zero<T> clears sizeof(T) / 4 32-bit words,
+// which is zero words for a 2-byte type.
+inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); }
+
+}  // namespace vllm
diff --git a/tests/python/relax/test_contrib_vllm.py 
b/tests/python/relax/test_contrib_vllm.py
new file mode 100644
index 0000000000..7ca60bbaca
--- /dev/null
+++ b/tests/python/relax/test_contrib_vllm.py
@@ -0,0 +1,695 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import pytest
+
+import tvm.testing
+import tvm
+from tvm import relax
+from tvm.script import ir as I
+from tvm.script import relax as R
+
+
+has_vllm = 
tvm.get_global_func("tvm.contrib.vllm.single_query_cached_kv_attention", True)
+
+vllm_enabled = pytest.mark.skipif(
+    not has_vllm,
+    reason="VLLM not enabled.",
+)
+
+pytestmark = [vllm_enabled]
+
+
+def build_and_run(mod, inputs_np, target, legalize=True):
+    if legalize:
+        mod = relax.transform.LegalizeOps()(mod)
+
+        with tvm.target.Target("cuda"):
+            mod = tvm.tir.transform.DefaultGPUSchedule()(mod)
+
+    with tvm.transform.PassContext():
+        ex = relax.build(mod, target)
+
+    dev = tvm.device(target, 0)
+    vm = relax.VirtualMachine(ex, dev)
+    f = vm["main"]
+    inputs = [tvm.nd.array(inp, dev) for inp in inputs_np]
+
+    out = f(*inputs)
+
+    if isinstance(out, tvm.ir.container.Array):
+        return [arr.numpy() for arr in out]
+
+    return out.numpy()
+
+
+def test_attention():
+    @I.ir_module
+    class Module:
+        I.module_global_infos(
+            {
+                "vdevice": [
+                    I.vdevice("llvm"),
+                ]
+            }
+        )
+
+        @R.function
+        def main(
+            query: R.Tensor(("num_seqs", 1, 64), dtype="float16"),
+            key_cache: R.Tensor(("num_blocks", 1, 8, 16, 8), dtype="float16"),
+            value_cache: R.Tensor(("num_blocks", 1, 64, 16), dtype="float16"),
+            head_mapping: R.Tensor((1,), dtype="int32"),
+            block_tables: R.Tensor(("num_seqs", "max_num_blocks_per_seq"), 
dtype="int32"),
+            context_lens: R.Tensor(("num_seqs",), dtype="int32"),
+        ) -> R.Tensor(("num_seqs", 1, 64), dtype="float16"):
+            with R.dataflow():
+                max_len = R.to_vdevice(R.max(context_lens), "llvm:0")
+                out = R.call_dps_packed(
+                    "tvm.contrib.vllm.single_query_cached_kv_attention",
+                    [
+                        query,
+                        key_cache,
+                        value_cache,
+                        head_mapping,
+                        block_tables,
+                        context_lens,
+                        16,
+                        max_len,
+                    ],
+                    out_sinfo=query.struct_info,
+                )
+                R.output(out)
+            return out
+
+    np.random.seed(0)
+    num_heads = 1
+    head_dim = 64
+    vec_size = 8
+    block_size = 16
+    num_seqs = 2
+    num_blocks = 1
+    query = np.random.randn(num_seqs, num_heads, head_dim).astype("float16")
+    key_cache = np.random.randn(
+        num_blocks, num_heads, head_dim // vec_size, block_size, vec_size
+    ).astype("float16")
+    value_cache = np.random.randn(num_blocks, num_heads, head_dim, 
block_size).astype("float16")
+    block_tables = np.array([[0], [0]]).astype("int32")
+    head_mapping = np.array([0]).astype("int32")
+    context_lens = np.array([3, 5]).astype("int32")
+
+    out = build_and_run(
+        Module,
+        [query, key_cache, value_cache, head_mapping, block_tables, 
context_lens],
+        "cuda",
+        legalize=True,
+    )
+
+    ref = np.array(
+        [
+            [
+                [
+                    0.28271484375,
+                    0.197021484375,
+                    -0.278564453125,
+                    0.444580078125,
+                    -0.47802734375,
+                    -0.7548828125,
+                    -0.84228515625,
+                    -0.80322265625,
+                    0.478759765625,
+                    0.195068359375,
+                    -0.59521484375,
+                    0.779296875,
+                    0.35888671875,
+                    -0.158935546875,
+                    -0.6103515625,
+                    0.188720703125,
+                    0.410400390625,
+                    0.28662109375,
+                    0.40283203125,
+                    -1.23046875,
+                    -0.01043701171875,
+                    -0.0794677734375,
+                    -0.0350341796875,
+                    0.12005615234375,
+                    0.63671875,
+                    0.368896484375,
+                    -0.58642578125,
+                    -0.34228515625,
+                    -0.552734375,
+                    0.947265625,
+                    -0.079833984375,
+                    0.85302734375,
+                    0.1947021484375,
+                    0.16748046875,
+                    -0.083984375,
+                    -0.75244140625,
+                    -0.568359375,
+                    -1.45703125,
+                    -1.021484375,
+                    -0.2235107421875,
+                    -0.98828125,
+                    -0.87109375,
+                    -0.43359375,
+                    -0.3271484375,
+                    0.0557861328125,
+                    -0.269287109375,
+                    -1.009765625,
+                    0.1387939453125,
+                    -0.0831298828125,
+                    0.27978515625,
+                    -0.9736328125,
+                    0.7802734375,
+                    -0.1329345703125,
+                    -0.5927734375,
+                    -1.6923828125,
+                    1.1904296875,
+                    -1.3759765625,
+                    -1.080078125,
+                    -0.53173828125,
+                    0.28466796875,
+                    -2.02734375,
+                    -0.377685546875,
+                    -0.81201171875,
+                    -0.7412109375,
+                ]
+            ],
+            [
+                [
+                    0.482177734375,
+                    0.114501953125,
+                    -0.265869140625,
+                    -1.154296875,
+                    0.28857421875,
+                    0.71240234375,
+                    -1.1767578125,
+                    0.187744140625,
+                    -0.23486328125,
+                    0.07135009765625,
+                    -0.34521484375,
+                    0.444091796875,
+                    -0.09130859375,
+                    0.900390625,
+                    -0.043701171875,
+                    0.61279296875,
+                    0.1201171875,
+                    0.443603515625,
+                    -0.4150390625,
+                    -0.9560546875,
+                    -0.1917724609375,
+                    0.0863037109375,
+                    0.267578125,
+                    0.04931640625,
+                    -0.32666015625,
+                    0.5859375,
+                    -0.57421875,
+                    0.29541015625,
+                    -0.26220703125,
+                    1.177734375,
+                    0.11309814453125,
+                    0.81201171875,
+                    0.346435546875,
+                    0.53271484375,
+                    -0.0009765625,
+                    -0.35205078125,
+                    -0.1298828125,
+                    -1.2431640625,
+                    -0.2196044921875,
+                    0.31640625,
+                    -0.40869140625,
+                    0.25244140625,
+                    -0.9853515625,
+                    0.284912109375,
+                    0.399169921875,
+                    -1.1435546875,
+                    0.305419921875,
+                    0.300048828125,
+                    -0.84521484375,
+                    -0.5166015625,
+                    -0.787109375,
+                    0.1011962890625,
+                    -1.0302734375,
+                    -1.35546875,
+                    -0.0556640625,
+                    1.0791015625,
+                    -0.047607421875,
+                    -0.498046875,
+                    -0.055999755859375,
+                    -0.35009765625,
+                    -1.4296875,
+                    0.350341796875,
+                    -1.16796875,
+                    -0.576171875,
+                ]
+            ],
+        ]
+    ).astype("float16")
+
+    # from vllm import attention_ops
+    # import torch
+    #
+    # def to_torch(arr):
+    #     return torch.from_numpy(arr).to("cuda")
+    #
+    # ref = to_torch(np.zeros_like(query))
+    # attention_ops.single_query_cached_kv_attention(
+    #     ref,
+    #     to_torch(query),
+    #     to_torch(key_cache),
+    #     to_torch(value_cache),
+    #     to_torch(head_mapping),
+    #     query.shape[-1] ** -0.5,  # scale
+    #     to_torch(block_tables),
+    #     to_torch(context_lens),
+    #     value_cache.shape[-1],  # block_size,
+    #     np.max(context_lens),
+    #     None,
+    #     )
+    # ref = ref.cpu().numpy()
+
+    # print(ref.tolist())
+
+    assert np.max(np.abs(ref - out)) == 0.0
+
+
+def test_cache():
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(
+            key: R.Tensor(("num_tokens", 1, 8), dtype="float16"),
+            value: R.Tensor(("num_tokens", 1, 8), dtype="float16"),
+            key_cache: R.Tensor(("num_blocks", 1, 1, 16, 8), dtype="float16"),
+            value_cache: R.Tensor(("num_blocks", 1, 8, 16), dtype="float16"),
+            slot_mapping: R.Tensor(("num_tokens",), dtype="int32"),
+        ) -> R.Tuple(
+            [
+                R.Tensor(("num_blocks", 1, 8, 16, 8), dtype="float16"),
+                R.Tensor(("num_blocks", 1, 8, 16), dtype="float16"),
+            ]
+        ):
+            with R.dataflow():
+                kv = R.call_pure_packed(
+                    "tvm.contrib.vllm.reshape_and_cache",
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    slot_mapping,
+                    sinfo_args=[key_cache.struct_info, 
value_cache.struct_info],
+                )
+                out = (kv[0], kv[1])
+                R.output(out)
+            return out
+
+    np.random.seed(0)
+    num_heads = 1
+    head_dim = 8
+    vec_size = 8
+    block_size = 16
+    num_tokens = 8
+    num_blocks = 1
+    key = np.random.randn(num_tokens, num_heads, head_dim).astype("float16")
+    value = np.random.randn(num_tokens, num_heads, head_dim).astype("float16")
+    key_cache_before = np.random.randn(
+        num_blocks, num_heads, head_dim // vec_size, block_size, vec_size
+    ).astype("float16")
+    value_cache_before = np.random.randn(num_blocks, num_heads, head_dim, 
block_size).astype(
+        "float16"
+    )
+    slot_mapping = np.arange(num_tokens).astype("int32")
+
+    key_cache = key_cache_before.copy()
+    value_cache = value_cache_before.copy()
+
+    out_key_cache, out_value_cache = build_and_run(
+        Module,
+        [key, value, key_cache, value_cache, slot_mapping],
+        "cuda",
+    )
+
+    ref_key_cache = np.array(
+        [
+            [
+                [
+                    [
+                        [
+                            1.763671875,
+                            0.400146484375,
+                            0.978515625,
+                            2.240234375,
+                            1.8671875,
+                            -0.97705078125,
+                            0.9501953125,
+                            -0.1513671875,
+                        ],
+                        [
+                            -0.10321044921875,
+                            0.41064453125,
+                            0.14404296875,
+                            1.4541015625,
+                            0.76123046875,
+                            0.1217041015625,
+                            0.44384765625,
+                            0.333740234375,
+                        ],
+                        [
+                            1.494140625,
+                            -0.2052001953125,
+                            0.31298828125,
+                            -0.85400390625,
+                            -2.552734375,
+                            0.65380859375,
+                            0.8642578125,
+                            -0.7421875,
+                        ],
+                        [
+                            2.26953125,
+                            -1.4541015625,
+                            0.045745849609375,
+                            -0.1871337890625,
+                            1.533203125,
+                            1.4697265625,
+                            0.1549072265625,
+                            0.378173828125,
+                        ],
+                        [
+                            -0.8876953125,
+                            -1.98046875,
+                            -0.347900390625,
+                            0.1563720703125,
+                            1.23046875,
+                            1.2021484375,
+                            -0.38720703125,
+                            -0.30224609375,
+                        ],
+                        [
+                            -1.048828125,
+                            -1.419921875,
+                            -1.7060546875,
+                            1.951171875,
+                            -0.509765625,
+                            -0.43798828125,
+                            -1.2529296875,
+                            0.77734375,
+                        ],
+                        [
+                            -1.6142578125,
+                            -0.2127685546875,
+                            -0.8955078125,
+                            0.386962890625,
+                            -0.5107421875,
+                            -1.1806640625,
+                            -0.0281829833984375,
+                            0.42822265625,
+                        ],
+                        [
+                            0.0665283203125,
+                            0.302490234375,
+                            -0.63427734375,
+                            -0.36279296875,
+                            -0.67236328125,
+                            -0.359619140625,
+                            -0.81298828125,
+                            -1.7265625,
+                        ],
+                        [
+                            -0.039276123046875,
+                            -1.16796875,
+                            0.5234375,
+                            -0.1715087890625,
+                            0.77197265625,
+                            0.82373046875,
+                            2.1640625,
+                            1.3369140625,
+                        ],
+                        [
+                            -0.369140625,
+                            -0.2393798828125,
+                            1.099609375,
+                            0.6552734375,
+                            0.64013671875,
+                            -1.6171875,
+                            -0.024322509765625,
+                            -0.73779296875,
+                        ],
+                        [
+                            0.280029296875,
+                            -0.09814453125,
+                            0.91015625,
+                            0.317138671875,
+                            0.7861328125,
+                            -0.46630859375,
+                            -0.9443359375,
+                            -0.41015625,
+                        ],
+                        [
+                            -0.0170135498046875,
+                            0.379150390625,
+                            2.259765625,
+                            -0.042266845703125,
+                            -0.9560546875,
+                            -0.345947265625,
+                            -0.463623046875,
+                            0.4814453125,
+                        ],
+                        [
+                            -1.541015625,
+                            0.063232421875,
+                            0.156494140625,
+                            0.232177734375,
+                            -0.59716796875,
+                            -0.2379150390625,
+                            -1.423828125,
+                            -0.493408203125,
+                        ],
+                        [
+                            -0.54296875,
+                            0.416015625,
+                            -1.15625,
+                            0.78125,
+                            1.494140625,
+                            -2.0703125,
+                            0.42626953125,
+                            0.6767578125,
+                        ],
+                        [
+                            -0.63720703125,
+                            -0.397216796875,
+                            -0.1329345703125,
+                            -0.2978515625,
+                            -0.30908203125,
+                            -1.67578125,
+                            1.15234375,
+                            1.080078125,
+                        ],
+                        [
+                            -0.8134765625,
+                            -1.466796875,
+                            0.52099609375,
+                            -0.57568359375,
+                            0.1419677734375,
+                            -0.3193359375,
+                            0.69140625,
+                            0.69482421875,
+                        ],
+                    ]
+                ]
+            ]
+        ]
+    ).astype("float16")
+
+    ref_value_cache = np.array(
+        [
+            [
+                [
+                    [
+                        0.1773681640625,
+                        1.1396484375,
+                        -1.1650390625,
+                        -1.0703125,
+                        0.010498046875,
+                        -1.1728515625,
+                        -0.861328125,
+                        0.37646484375,
+                        -1.9365234375,
+                        0.188720703125,
+                        0.52392578125,
+                        0.08843994140625,
+                        -0.310791015625,
+                        0.097412109375,
+                        0.39892578125,
+                        -2.7734375,
+                    ],
+                    [
+                        -0.40185546875,
+                        -1.234375,
+                        0.90087890625,
+                        1.0546875,
+                        1.7861328125,
+                        1.943359375,
+                        1.91015625,
+                        -1.099609375,
+                        -0.11053466796875,
+                        1.0205078125,
+                        -0.69189453125,
+                        1.5361328125,
+                        0.286376953125,
+                        0.60888671875,
+                        -1.044921875,
+                        1.2109375,
+                    ],
+                    [
+                        -1.6298828125,
+                        0.40234375,
+                        0.465576171875,
+                        -0.403076171875,
+                        0.126953125,
+                        -0.41357421875,
+                        -0.26806640625,
+                        0.29833984375,
+                        0.09771728515625,
+                        0.5830078125,
+                        -0.3994140625,
+                        0.3701171875,
+                        -1.306640625,
+                        1.658203125,
+                        -0.1181640625,
+                        -0.68017578125,
+                    ],
+                    [
+                        0.462890625,
+                        -0.6845703125,
+                        -1.5361328125,
+                        1.22265625,
+                        0.402099609375,
+                        -0.74755859375,
+                        0.80224609375,
+                        1.326171875,
+                        -1.126953125,
+                        -0.73046875,
+                        -0.384765625,
+                        0.0943603515625,
+                        -0.04217529296875,
+                        -0.286865234375,
+                        -0.061614990234375,
+                        -0.1072998046875,
+                    ],
+                    [
+                        -0.9072265625,
+                        -0.87060546875,
+                        1.48828125,
+                        0.208251953125,
+                        1.8828125,
+                        1.9228515625,
+                        0.947265625,
+                        -0.6943359375,
+                        -0.70458984375,
+                        0.943359375,
+                        0.7470703125,
+                        -1.1884765625,
+                        0.7734375,
+                        -1.18359375,
+                        -2.658203125,
+                        0.6064453125,
+                    ],
+                    [
+                        0.05194091796875,
+                        -0.57861328125,
+                        1.8955078125,
+                        0.9765625,
+                        -1.34765625,
+                        1.48046875,
+                        -0.155029296875,
+                        -0.149658203125,
+                        -0.44091796875,
+                        -0.2802734375,
+                        -0.36474609375,
+                        0.15673828125,
+                        0.57861328125,
+                        0.349609375,
+                        -0.76416015625,
+                        -1.4375,
+                    ],
+                    [
+                        0.72900390625,
+                        -0.3115234375,
+                        1.1787109375,
+                        0.3564453125,
+                        -1.2705078125,
+                        1.8671875,
+                        0.6142578125,
+                        -0.43505859375,
+                        0.6982421875,
+                        0.0037708282470703125,
+                        0.931640625,
+                        0.33984375,
+                        -0.01568603515625,
+                        0.160888671875,
+                        -0.190673828125,
+                        -0.394775390625,
+                    ],
+                    [
+                        0.1290283203125,
+                        0.05615234375,
+                        -0.179931640625,
+                        0.70654296875,
+                        0.96923828125,
+                        0.90625,
+                        0.92236328125,
+                        1.849609375,
+                        0.6435546875,
+                        -1.5703125,
+                        -0.2069091796875,
+                        0.88037109375,
+                        -1.6982421875,
+                        0.38720703125,
+                        -2.255859375,
+                        -1.0224609375,
+                    ],
+                ]
+            ]
+        ]
+    ).astype("float16")
+
+    # from vllm import cache_ops
+    # import torch
+
+    # def to_torch(arr):
+    #     return torch.from_numpy(arr).to("cuda")
+
+    # ref_key_cache = to_torch(key_cache_before.copy())
+    # ref_value_cache = to_torch(value_cache_before.copy())
+
+    # cache_ops.reshape_and_cache(
+    #     to_torch(key),
+    #     to_torch(value),
+    #     ref_key_cache,
+    #     ref_value_cache,
+    #     to_torch(slot_mapping),
+    # )
+
+    # ref_key_cache = ref_key_cache.cpu().numpy()
+    # ref_value_cache = ref_value_cache.cpu().numpy()
+
+    assert np.max(np.abs(out_key_cache - ref_key_cache)) == 0
+    assert np.max(np.abs(out_value_cache - ref_value_cache)) == 0
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to