This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6333d86105 [KVCache] Support mode "None" for Rotary Embedding (#16580)
6333d86105 is described below
commit 6333d86105e28435070976058f7830a15a751fa3
Author: Ruihang Lai <[email protected]>
AuthorDate: Fri Feb 16 07:54:18 2024 -0500
[KVCache] Support mode "None" for Rotary Embedding (#16580)
This PR supports a "None" Rotary Embedding mode in
PagedKVCache. When the mode is None, the rotary embedding
is not applied to q and k when computing attention.
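For illustration, the three modes can be sketched in a few lines of
Python mirroring the reference logic in the updated tests below;
rope(x, offset) here is a hypothetical stand-in for the
rotary-embedding computation, not an actual TVM API:

    # Hedged sketch of where rotation happens for each mode.
    if rope_mode == RopeMode.NORMAL:
        # NORMAL: a standalone kernel rotates k once, at append time.
        k_stored = rope(new_k, offset=num_cached_tokens)
    else:
        # NONE and INLINE both store k unrotated.
        k_stored = new_k
    # q is rotated unless the mode is NONE.
    q_used = q if rope_mode == RopeMode.NONE else rope(q, offset=rope_offset)
    # INLINE rotates the cached k on-the-fly inside the attention kernel.
    k_used = rope(k_stored, offset=0) if rope_mode == RopeMode.INLINE else k_stored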
---
src/runtime/relax_vm/paged_kv_cache.cc | 6 ++-
..._builtin_paged_attention_kv_cache_flashinfer.py | 45 +++++++++++++++-------
...runtime_builtin_paged_attention_kv_cache_tir.py | 43 ++++++++++++++++-----
3 files changed, 69 insertions(+), 25 deletions(-)
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index 7417d90e02..d5ddef7527 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -153,12 +153,14 @@ struct Sequence {
/*!
* \brief The rotary embedding mode adopted by the paged KV cache
* when computing attention.
+ * "None" means RoPE is never applied to q and k.
* "Normal" means RoPE is computed in a standalone kernel.
* "Inline" means RoPE is computed on-the-fly in attention kernels.
*/
enum class RoPEMode : int {
- kNormal = 0,
- kInline = 1,
+ kNone = 0,
+ kNormal = 1,
+ kInline = 2,
};
/*!
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
index 8a40f43ab7..fccef312c1 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import enum
from typing import Dict, List, Tuple, Union
import numpy as np
@@ -322,7 +323,19 @@ def create_kv_cache(rope_mode):
return cache
-@pytest.fixture(params=[0, 1])
+class RopeMode(enum.IntEnum):
+ """The RoPE mode of the Paged KV cache.
+ If it is none, the KV cache will not apply RoPE to q and k.
+ If it is normal, RoPE will be applied to k before adding k to cache.
+ Otherwise, RoPE will be applied to q/k in attention kernel on-the-fly.
+ """
+
+ NONE = 0
+ NORMAL = 1
+ INLINE = 2
+
+
+@pytest.fixture(params=[RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE])
def kv_cache_and_rope_mode(request):
set_global_func()
return create_kv_cache(request.param), request.param
@@ -361,7 +374,7 @@ def f_apply_rotary(x, offset, scale, theta):
def apply_attention(
kv_cache,
- rope_mode: int,
+ rope_mode: RopeMode,
batch: List[Tuple[Union[int, Tuple[int, int]], int]],
cached_k: Dict[int, np.ndarray],
cached_v: Dict[int, np.ndarray],
@@ -406,10 +419,12 @@ def apply_attention(
cached_k[seq_id],
np.stack(
[
- new_k[l]
- if rope_mode == 1
- else f_apply_rotary(
- new_k[l], cached_k[seq_id].shape[1], rope_scale, rope_theta
+ (
+ new_k[l]
+ if rope_mode != RopeMode.NORMAL
+ else f_apply_rotary(
+ new_k[l], cached_k[seq_id].shape[1], rope_scale, rope_theta
+ )
)
for l in range(num_layers)
],
@@ -445,15 +460,19 @@ def apply_attention(
assert cached_k[seq_id].shape[1] == cached_v[seq_id].shape[1] >= append_length
rope_offset = cached_k[seq_id].shape[1] - append_length
- q_seq = f_apply_rotary(
- q_array[i][layer_id],
- rope_offset,
- rope_scale,
- rope_theta,
+ q_seq = (
+ q_array[i][layer_id]
+ if rope_mode == RopeMode.NONE
+ else f_apply_rotary(
+ q_array[i][layer_id],
+ rope_offset,
+ rope_scale,
+ rope_theta,
+ )
).transpose(1, 0, 2)
k_seq = (
cached_k[seq_id][layer_id]
- if rope_mode == 0
+ if rope_mode != RopeMode.INLINE
else f_apply_rotary(cached_k[seq_id][layer_id], 0, rope_scale, rope_theta)
).transpose(1, 2, 0)
v_seq = cached_v[seq_id][layer_id].transpose(1, 0, 2)
@@ -586,7 +605,7 @@ def test_paged_attention_kv_cache_popn(kv_cache_and_rope_mode, fuse_qkv):
if __name__ == "__main__":
set_global_func()
- for rope_mode in [0, 1]:
+ for rope_mode in [RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE]:
cache = create_kv_cache(rope_mode)
for fuse_qkv in [False, True]:
test_paged_attention_kv_cache_prefill_and_decode((cache, rope_mode), fuse_qkv)
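As a usage note, the parametrized fixture yields a (cache, rope_mode)
pair, so pytest runs each consumer once per mode. A minimal
hypothetical consumer (the test name is invented for illustration):

    def test_rope_mode_smoke(kv_cache_and_rope_mode):
        cache, rope_mode = kv_cache_and_rope_mode
        # One invocation per RopeMode.{NONE, NORMAL, INLINE}.
        assert rope_mode in (RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE)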
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index dc4d4082f1..8bd9da3bbb 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import enum
import itertools
import math
from typing import Dict, List, Tuple, Union
@@ -140,7 +141,25 @@ def create_kv_cache(head_dim, dtype, rope_mode):
return cache
-@pytest.fixture(params=itertools.product([64, 128], ["float16", "float32"], [0, 1]))
+class RopeMode(enum.IntEnum):
+ """The RoPE mode of the Paged KV cache.
+ If it is none, the KV cache will not apply RoPE to q and k.
+ If it is normal, RoPE will be applied to k before adding k to cache.
+ Otherwise, RoPE will be applied to q/k in attention kernel on-the-fly.
+ """
+
+ NONE = 0
+ NORMAL = 1
+ INLINE = 2
+
+
+@pytest.fixture(
+ params=itertools.product(
+ [64, 128],
+ ["float16", "float32"],
+ [RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE],
+ )
+)
def kv_cache_and_rope_mode(request):
global head_dim, dtype
head_dim, dtype, rope_mode = request.param
@@ -181,7 +200,7 @@ def f_apply_rotary(x, offset, scale, theta):
def apply_attention(
kv_cache,
- rope_mode: int,
+ rope_mode: RopeMode,
batch: List[Tuple[Union[int, Tuple[int, int]], int]],
cached_k: Dict[int, np.ndarray],
cached_v: Dict[int, np.ndarray],
@@ -228,7 +247,7 @@ def apply_attention(
[
(
new_k[l]
- if rope_mode == 1
+ if rope_mode != RopeMode.NORMAL
else f_apply_rotary(
new_k[l], cached_k[seq_id].shape[1], rope_scale, rope_theta
)
@@ -267,15 +286,19 @@ def apply_attention(
assert cached_k[seq_id].shape[1] == cached_v[seq_id].shape[1] >= append_length
rope_offset = cached_k[seq_id].shape[1] - append_length
- q_seq = f_apply_rotary(
- q_array[i][layer_id],
- rope_offset,
- rope_scale,
- rope_theta,
+ q_seq = (
+ q_array[i][layer_id]
+ if rope_mode == RopeMode.NONE
+ else f_apply_rotary(
+ q_array[i][layer_id],
+ rope_offset,
+ rope_scale,
+ rope_theta,
+ )
).transpose(1, 0, 2)
k_seq = (
cached_k[seq_id][layer_id]
- if rope_mode == 0
+ if rope_mode != RopeMode.INLINE
else f_apply_rotary(cached_k[seq_id][layer_id], 0, rope_scale, rope_theta)
).transpose(1, 2, 0)
v_seq = cached_v[seq_id][layer_id].transpose(1, 0, 2)
@@ -1639,7 +1662,7 @@ def _merge_state_inplace(num_heads, head_dim, v_dtype):
if __name__ == "__main__":
for head_dim in [64, 128]:
for dtype in ["float16", "float32"]:
- for rope_mode in [0, 1]:
+ for rope_mode in [RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE]:
set_global_func(head_dim, dtype)
cache = create_kv_cache(head_dim, dtype, rope_mode)
for fuse_qkv in [False, True]:
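For reference, a rotate-half (GPT-NeoX style) RoPE implementation
consistent with the f_apply_rotary(x, offset, scale, theta) call sites
above might look as follows. The (seq_len, num_heads, head_dim) input
layout and the way rope_scale enters are assumptions, not the test's
actual code:

    import numpy as np

    def f_apply_rotary_sketch(x, offset, scale, theta):
        # x: (seq_len, num_heads, head_dim) array of q or k values.
        seq_len, _, head_dim = x.shape
        half = head_dim // 2
        # Per-pair inverse frequencies: 1 / theta^(2i / head_dim).
        inv_freq = 1.0 / (
            theta ** (np.arange(0, head_dim, 2, dtype=np.float64) / head_dim)
        )
        # Assumption: rope_scale rescales positions (position interpolation).
        pos = np.arange(offset, offset + seq_len, dtype=np.float64) / scale
        freqs = np.einsum("s,f->sf", pos, inv_freq)                # (seq_len, half)
        emb = np.concatenate([freqs, freqs], axis=-1)[:, None, :]  # broadcast heads
        rotated = np.concatenate([-x[..., half:], x[..., :half]], axis=-1)
        return (x * np.cos(emb) + rotated * np.sin(emb)).astype(x.dtype)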