This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 42b1e9731e [KVCache] Fix attention prefill kernel for Metal and
Android (#17539)
42b1e9731e is described below
commit 42b1e9731e32edb80afbf9fc6ce27c62039f3452
Author: Ruihang Lai <[email protected]>
AuthorDate: Thu Nov 21 06:38:48 2024 -0500
[KVCache] Fix attention prefill kernel for Metal and Android (#17539)
This PR fixes two bugs of the attention prefill ragged kernel.
* The first bug is the unroll of loop `ki`, which causes a TIR build
failure in the PointerValueTypeRewrite pass due to the vector size.
* The second is that the tile sizes `tile_z` and `tile_y` may violate
the assertion check in `get_tile_size`.
---
python/tvm/relax/frontend/nn/llm/kv_cache.py | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py
b/python/tvm/relax/frontend/nn/llm/kv_cache.py
index 618345d0a5..18f3e19909 100644
--- a/python/tvm/relax/frontend/nn/llm/kv_cache.py
+++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -1579,6 +1579,12 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype,
rope_scaling: Dict[str, Any],
d,
64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1),
)
+ original_tile_y = tile_y
+ original_tile_z = tile_z
+ while (tile_x * tile_z) % (bdx * num_warps) != 0:
+ tile_z += original_tile_z
+ while (tile_x * tile_y) % (bdx * num_warps) != 0:
+ tile_y += original_tile_y
# Otherwise we would exceed maxComputeWorkgroupStorageSize
if (
@@ -1907,7 +1913,6 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype,
rope_scaling: Dict[str, Any],
sch.unroll(yio)
sch.vectorize(yiv)
sch.unroll(xi)
- sch.unroll(ki)
sch.decompose_reduction(block, ty)
def apply_to_md(sch, block):