This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new cd9445d63b [Unity][lm_support] window kvcache sink (#16240)
cd9445d63b is described below
commit cd9445d63b286fc3cb1aedce9837717d1c36f575
Author: David Pissarra <[email protected]>
AuthorDate: Thu Dec 14 20:08:05 2023 -0500
[Unity][lm_support] window kvcache sink (#16240)
* attention sinks with correctness test
* fix override sink
---
src/runtime/relax_vm/lm_support.cc | 44 ++++++++++++++++++++----------
tests/python/relax/test_runtime_builtin.py | 41 ++++++++++++++++++++++++++++
2 files changed, 71 insertions(+), 14 deletions(-)
diff --git a/src/runtime/relax_vm/lm_support.cc b/src/runtime/relax_vm/lm_support.cc
index 6301245dac..706e2c3d5f 100644
--- a/src/runtime/relax_vm/lm_support.cc
+++ b/src/runtime/relax_vm/lm_support.cc
@@ -116,10 +116,12 @@ class AttentionKVCacheObj : public Object {
/*!
* \brief Append value to the cache, overrides if full.
* \param value The value to override previous elements.
+ * \param max_cache_size max size of the cache.
+ * \param num_attention_sinks number of sinks to store
(https://arxiv.org/abs/2309.17453).
*/
- void WindowOverride(NDArray value, int64_t max_cache_size) {
+ void WindowOverride(NDArray value, int64_t max_cache_size, int64_t
num_attention_sinks = 0) {
CHECK(data.DataType() == value.DataType()) << "dtype mismatch";
- CHECK_LE(value->shape[0], max_cache_size) << "dim 0 of value too large";
+ CHECK_LE(value->shape[0], max_cache_size - num_attention_sinks) << "dim 0
of value too large";
// reallocate cache
if (fill_count + value->shape[0] <= max_cache_size) {
int64_t reserved_slots = data->shape[0];
@@ -148,20 +150,22 @@ class AttentionKVCacheObj : public Object {
shape.push_back(data->shape[i]);
}
int64_t num_filled_elements = window_attention_current_pos *
num_elements_p_entry;
-
- DLTensor copy_dst = *(data.operator->());
- copy_dst.byte_offset = num_filled_elements * ((data->dtype.bits *
data->dtype.lanes + 7) / 8);
- copy_dst.shape = &shape[0];
-
- DLTensor copy_src = *(value.operator->());
- copy_src.byte_offset = 0;
- copy_src.shape = &shape[0];
-
- NDArray::CopyFromTo(&copy_src, &copy_dst);
this->fill_count = std::min(this->fill_count + value->shape[0],
max_cache_size);
this->window_attention_current_pos =
std::min(this->window_attention_current_pos + value->shape[0],
max_cache_size);
+ if (num_elements_to_copy > 0) {
+ DLTensor copy_dst = *(data.operator->());
+ copy_dst.byte_offset = num_filled_elements * ((data->dtype.bits *
data->dtype.lanes + 7) / 8);
+ copy_dst.shape = &shape[0];
+
+ DLTensor copy_src = *(value.operator->());
+ copy_src.byte_offset = 0;
+ copy_src.shape = &shape[0];
+
+ NDArray::CopyFromTo(&copy_src, &copy_dst);
+ }
+
// copy the remainder to the beginning of the cache
if (num_elements_to_copy < value->shape[0]) {
ICHECK_EQ(this->fill_count, max_cache_size);
@@ -171,7 +175,8 @@ class AttentionKVCacheObj : public Object {
num_filled_elements = num_elements_to_copy * num_elements_p_entry;
DLTensor copy_dst = *(data.operator->());
- copy_dst.byte_offset = 0;
+ copy_dst.byte_offset = (num_attention_sinks * num_elements_p_entry) *
+ ((data->dtype.bits * data->dtype.lanes + 7) / 8);
copy_dst.shape = &shape[0];
DLTensor copy_src = *(value.operator->());
@@ -180,7 +185,8 @@ class AttentionKVCacheObj : public Object {
copy_src.shape = &shape[0];
NDArray::CopyFromTo(&copy_src, &copy_dst);
- this->window_attention_current_pos = value->shape[0] -
num_elements_to_copy;
+ this->window_attention_current_pos =
+ value->shape[0] - num_elements_to_copy + num_attention_sinks;
}
}
@@ -277,6 +283,16 @@ AttentionKVCache AttentionKVCacheWindowOverride(AttentionKVCache cache, NDArray
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_window_override")
.set_body_typed(AttentionKVCacheWindowOverride);
+AttentionKVCache AttentionKVCacheWindowOverrideWithSinks(AttentionKVCache
cache, NDArray value,
+ int64_t
max_cache_size,
+ int64_t
num_attention_sinks) {
+ cache->WindowOverride(value, max_cache_size, num_attention_sinks);
+ return cache;
+}
+
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_window_override_with_sinks")
+ .set_body_typed(AttentionKVCacheWindowOverrideWithSinks);
+
NDArray AttentionKVCacheView(AttentionKVCache cache, ShapeTuple shape) {
return cache->View(shape);
}
diff --git a/tests/python/relax/test_runtime_builtin.py b/tests/python/relax/test_runtime_builtin.py
index 0417f99233..614d32ce0c 100644
--- a/tests/python/relax/test_runtime_builtin.py
+++ b/tests/python/relax/test_runtime_builtin.py
@@ -217,5 +217,46 @@ def test_attention_kv_cache_window_override():
).all()
+def test_attention_kv_cache_window_override_with_sinks():
+ fcreate = tvm.get_global_func("vm.builtin.attention_kv_cache_create")
+ foverride =
tvm.get_global_func("vm.builtin.attention_kv_cache_window_override_with_sinks")
+ fview = tvm.get_global_func("vm.builtin.attention_kv_cache_view")
+
+ num_attention_sinks = 2
+ has_sink = False
+ current_pos = 0
+
+ cache = fcreate(
+ tvm.nd.array(np.full((16, 2), -1).astype("int32")),
+ tvm.runtime.ShapeTuple([16, 2]),
+ current_pos,
+ )
+ np_all_arrays = np.zeros((0, 2)).astype("int32")
+
+ num_steps = 40
+ for i in range(num_steps):
+ np_array = i * np.ones((1, 2)).astype("int32")
+ np_all_arrays = np.concatenate((np_all_arrays, np_array), axis=0)
+ cache = foverride(cache, tvm.nd.array(np_array), 16,
num_attention_sinks)
+
+ if has_sink:
+ current_pos = max((current_pos + 1) % 16, num_attention_sinks)
+ else:
+ current_pos += 1
+ has_sink = current_pos >= num_attention_sinks
+
+ res = fview(cache, tvm.runtime.ShapeTuple((16, 2))).numpy()
+
+ # unrotate cache and assert cache matches last 16 elements
+ assert (
+ np.concatenate(
+ (np_all_arrays[:num_attention_sinks, :], np_all_arrays[-16 +
num_attention_sinks :, :])
+ )
+ == np.concatenate(
+ (res[:num_attention_sinks], res[current_pos:],
res[num_attention_sinks:current_pos])
+ )
+ ).all()
+
+
if __name__ == "__main__":
tvm.testing.main()