This is an automated email from the ASF dual-hosted git repository.
yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ad1da4ee57 [Runtime][Builtin] Using float32 accumulation in attention
kernel (#16667)
ad1da4ee57 is described below
commit ad1da4ee5712264886c3ea385ffedd25a8998d85
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Mar 3 21:31:27 2024 -0500
[Runtime][Builtin] Using float32 accumulation in attention kernel (#16667)
Prior to this PR, the TIR attention kernels do not cast matmul
operands to fp32 before multiplying.
For models like Phi-2 which may have large Q/K/V data (at the level
of a few hundred), the fp16 multiplication exceeds the range of
fp16, and leads to the attention result being NaN sometimes.
This PR fixes this issue.
---
.../test_runtime_builtin_paged_attention_kv_cache_tir.py | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index 2a4f7e87bd..365420dd12 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -902,7 +902,7 @@ def _attention_prefill(h_kv, h_q, d, dtype):
i, j, k =
T.axis.remap("SSR", [li, lj, lk])
with T.init():
S_local[i, j] = 0.0
- S_local[i, j] += Q_smem[i,
k] * K_smem[j, k] * attn_score_scaling_factor * sm_scale
+ S_local[i, j] +=
T.cast(Q_smem[i, k], "float32") * T.cast(K_smem[j, k], "float32") *
attn_score_scaling_factor * sm_scale
T.tvm_storage_sync("shared")
for li, lj in T.grid(tile_x, tile_z):
with T.block("S_store"):
@@ -960,7 +960,7 @@ def _attention_prefill(h_kv, h_q, d, dtype):
i, j, k =
T.axis.remap("SSR", [li, lj, lk])
with T.init():
O_local[i, j] *=
T.exp2(m_prev_smem[i] - m_smem[i])
- O_local[i, j] += S_smem[i,
k] * V_smem[k, j]
+ O_local[i, j] += S_smem[i,
k] * T.cast(V_smem[k, j], "float32")
# Store O from smem to gmem
for li, lj in T.grid(tile_x, tile_y):
@@ -1196,7 +1196,7 @@ def _attention_decode(num_kv_heads, num_qo_heads,
head_dim, qkv_dtype):
# compute S = Q * K * sm_scale
S_reduce_local[0] = 0
for vec in T.serial(VEC_SIZE):
- S_reduce_local[0] += Q_local[vec]
* K_local[vec] * attn_score_scaling_factor * sm_scale
+ S_reduce_local[0] +=
T.cast(Q_local[vec], "float32") * T.cast(K_local[vec], "float32") *
attn_score_scaling_factor * sm_scale
with T.block("block_cross_thread"):
T.reads(S_reduce_local[0])
@@ -1230,7 +1230,7 @@ def _attention_decode(num_kv_heads, num_qo_heads,
head_dim, qkv_dtype):
for vec in T.vectorized(VEC_SIZE):
V_local[vec] = V_smem[tz * bdy *
tile_size_per_bdx + j, tx * VEC_SIZE + vec]
for vec in T.vectorized(VEC_SIZE):
- O_local[vec] += V_local[vec] *
S_local[j]
+ O_local[vec] +=
T.cast(V_local[vec], "float32") * S_local[j]
if bdz > 1:
# allreduce over bdz
@@ -1445,7 +1445,7 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype):
i, j, k =
T.axis.remap("SSR", [li, lj, lk])
with T.init():
S_local[i, j] = 0.0
- S_local[i, j] += Q_smem[i,
k] * K_smem[j, k] * attn_score_scaling_factor * sm_scale
+ S_local[i, j] +=
T.cast(Q_smem[i, k], "float32") * T.cast(K_smem[j, k], "float32") *
attn_score_scaling_factor * sm_scale
T.tvm_storage_sync("shared")
for li, lj in T.grid(tile_x, tile_z):
with T.block("S_store"):
@@ -1503,7 +1503,7 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype):
i, j, k =
T.axis.remap("SSR", [li, lj, lk])
with T.init():
O_local[i, j] *=
T.exp2(m_prev_smem[i] - m_smem[i])
- O_local[i, j] += S_smem[i,
k] * V_smem[k, j]
+ O_local[i, j] += S_smem[i,
k] * T.cast(V_smem[k, j], "float32")
# Store O from smem to gmem
for li, lj in T.grid(tile_x, tile_y):