This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new cd9445d63b [Unity][lm_support] window kvcache sink (#16240)
cd9445d63b is described below

commit cd9445d63b286fc3cb1aedce9837717d1c36f575
Author: David Pissarra <[email protected]>
AuthorDate: Thu Dec 14 20:08:05 2023 -0500

    [Unity][lm_support] window kvcache sink (#16240)
    
    * attention sinks with correctness test
    
    * fix override sink
---
 src/runtime/relax_vm/lm_support.cc         | 44 ++++++++++++++++++++----------
 tests/python/relax/test_runtime_builtin.py | 41 ++++++++++++++++++++++++++++
 2 files changed, 71 insertions(+), 14 deletions(-)

diff --git a/src/runtime/relax_vm/lm_support.cc b/src/runtime/relax_vm/lm_support.cc
index 6301245dac..706e2c3d5f 100644
--- a/src/runtime/relax_vm/lm_support.cc
+++ b/src/runtime/relax_vm/lm_support.cc
@@ -116,10 +116,12 @@ class AttentionKVCacheObj : public Object {
   /*!
    * \brief Append value to the cache, overrides if full.
    * \param value The value to override previous elements.
+   * \param max_cache_size max size of the cache.
+   * \param num_attention_sinks number of sinks to store (https://arxiv.org/abs/2309.17453).
    */
-  void WindowOverride(NDArray value, int64_t max_cache_size) {
+  void WindowOverride(NDArray value, int64_t max_cache_size, int64_t 
num_attention_sinks = 0) {
     CHECK(data.DataType() == value.DataType()) << "dtype mismatch";
-    CHECK_LE(value->shape[0], max_cache_size) << "dim 0 of value too large";
+    CHECK_LE(value->shape[0], max_cache_size - num_attention_sinks) << "dim 0 
of value too large";
     // reallocate cache
     if (fill_count + value->shape[0] <= max_cache_size) {
       int64_t reserved_slots = data->shape[0];
@@ -148,20 +150,22 @@ class AttentionKVCacheObj : public Object {
       shape.push_back(data->shape[i]);
     }
     int64_t num_filled_elements = window_attention_current_pos * 
num_elements_p_entry;
-
-    DLTensor copy_dst = *(data.operator->());
-    copy_dst.byte_offset = num_filled_elements * ((data->dtype.bits * 
data->dtype.lanes + 7) / 8);
-    copy_dst.shape = &shape[0];
-
-    DLTensor copy_src = *(value.operator->());
-    copy_src.byte_offset = 0;
-    copy_src.shape = &shape[0];
-
-    NDArray::CopyFromTo(&copy_src, &copy_dst);
     this->fill_count = std::min(this->fill_count + value->shape[0], 
max_cache_size);
     this->window_attention_current_pos =
         std::min(this->window_attention_current_pos + value->shape[0], 
max_cache_size);
 
+    if (num_elements_to_copy > 0) {
+      DLTensor copy_dst = *(data.operator->());
+      copy_dst.byte_offset = num_filled_elements * ((data->dtype.bits * 
data->dtype.lanes + 7) / 8);
+      copy_dst.shape = &shape[0];
+
+      DLTensor copy_src = *(value.operator->());
+      copy_src.byte_offset = 0;
+      copy_src.shape = &shape[0];
+
+      NDArray::CopyFromTo(&copy_src, &copy_dst);
+    }
+
     // copy the remainder to the beginning of the cache
     if (num_elements_to_copy < value->shape[0]) {
       ICHECK_EQ(this->fill_count, max_cache_size);
@@ -171,7 +175,8 @@ class AttentionKVCacheObj : public Object {
       num_filled_elements = num_elements_to_copy * num_elements_p_entry;
 
       DLTensor copy_dst = *(data.operator->());
-      copy_dst.byte_offset = 0;
+      copy_dst.byte_offset = (num_attention_sinks * num_elements_p_entry) *
+                             ((data->dtype.bits * data->dtype.lanes + 7) / 8);
       copy_dst.shape = &shape[0];
 
       DLTensor copy_src = *(value.operator->());
@@ -180,7 +185,8 @@ class AttentionKVCacheObj : public Object {
       copy_src.shape = &shape[0];
 
       NDArray::CopyFromTo(&copy_src, &copy_dst);
-      this->window_attention_current_pos = value->shape[0] - 
num_elements_to_copy;
+      this->window_attention_current_pos =
+          value->shape[0] - num_elements_to_copy + num_attention_sinks;
     }
   }
 
@@ -277,6 +283,16 @@ AttentionKVCache AttentionKVCacheWindowOverride(AttentionKVCache cache, NDArray
 TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_window_override")
     .set_body_typed(AttentionKVCacheWindowOverride);
 
+AttentionKVCache AttentionKVCacheWindowOverrideWithSinks(AttentionKVCache cache, NDArray value,
+                                                         int64_t max_cache_size,
+                                                         int64_t num_attention_sinks) {
+  cache->WindowOverride(value, max_cache_size, num_attention_sinks);  // updates cache in place
+  return cache;  // hand the same cache object back to the caller
+}
+
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_window_override_with_sinks")
+    .set_body_typed(AttentionKVCacheWindowOverrideWithSinks);
+
 NDArray AttentionKVCacheView(AttentionKVCache cache, ShapeTuple shape) {
   return cache->View(shape);
 }
diff --git a/tests/python/relax/test_runtime_builtin.py b/tests/python/relax/test_runtime_builtin.py
index 0417f99233..614d32ce0c 100644
--- a/tests/python/relax/test_runtime_builtin.py
+++ b/tests/python/relax/test_runtime_builtin.py
@@ -217,5 +217,46 @@ def test_attention_kv_cache_window_override():
     ).all()
 
 
+def test_attention_kv_cache_window_override_with_sinks():
+    fcreate = tvm.get_global_func("vm.builtin.attention_kv_cache_create")
+    foverride = tvm.get_global_func("vm.builtin.attention_kv_cache_window_override_with_sinks")
+    fview = tvm.get_global_func("vm.builtin.attention_kv_cache_view")
+
+    num_attention_sinks = 2
+    has_sink = False
+    current_pos = 0
+
+    cache = fcreate(
+        tvm.nd.array(np.full((16, 2), -1).astype("int32")),
+        tvm.runtime.ShapeTuple([16, 2]),
+        current_pos,
+    )
+    np_all_arrays = np.zeros((0, 2)).astype("int32")
+
+    num_steps = 40
+    for i in range(num_steps):
+        np_array = i * np.ones((1, 2)).astype("int32")
+        np_all_arrays = np.concatenate((np_all_arrays, np_array), axis=0)
+        cache = foverride(cache, tvm.nd.array(np_array), 16, num_attention_sinks)
+
+        if has_sink:
+            current_pos = max((current_pos + 1) % 16, num_attention_sinks)
+        else:
+            current_pos += 1
+            has_sink = current_pos >= num_attention_sinks
+
+    res = fview(cache, tvm.runtime.ShapeTuple((16, 2))).numpy()
+
+    # unrotate the cache: the sinks stay in front, then the latest 16 - num_attention_sinks entries
+    assert (
+        np.concatenate(
+            (np_all_arrays[:num_attention_sinks, :], np_all_arrays[-16 + num_attention_sinks :, :])
+        )
+        == np.concatenate(
+            (res[:num_attention_sinks], res[current_pos:], res[num_attention_sinks:current_pos])
+        )
+    ).all()
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to