This is an automated email from the ASF dual-hosted git repository.

yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ad1da4ee57 [Runtime][Builtin] Using float32 accumulation in attention 
kernel (#16667)
ad1da4ee57 is described below

commit ad1da4ee5712264886c3ea385ffedd25a8998d85
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Mar 3 21:31:27 2024 -0500

    [Runtime][Builtin] Using float32 accumulation in attention kernel (#16667)
    
    Prior to this PR, the TIR attention kernels did not cast matmul
    operands to fp32 before multiplying.
    For models like Phi-2, which may have large Q/K/V values (on the order
    of a few hundred), the fp16 products can exceed the representable
    range of fp16 and sometimes cause the attention result to be NaN.
    
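    This PR fixes the issue by casting the Q/K/V operands to float32
    before they are multiplied, so the products are computed and
    accumulated in float32.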
---
 .../test_runtime_builtin_paged_attention_kv_cache_tir.py     | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git 
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py 
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index 2a4f7e87bd..365420dd12 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -902,7 +902,7 @@ def _attention_prefill(h_kv, h_q, d, dtype):
                                                     i, j, k = 
T.axis.remap("SSR", [li, lj, lk])
                                                     with T.init():
                                                         S_local[i, j] = 0.0
-                                                    S_local[i, j] += Q_smem[i, 
k] * K_smem[j, k] * attn_score_scaling_factor * sm_scale
+                                                    S_local[i, j] += 
T.cast(Q_smem[i, k], "float32") * T.cast(K_smem[j, k], "float32") * 
attn_score_scaling_factor * sm_scale
                                         T.tvm_storage_sync("shared")
                                         for li, lj in T.grid(tile_x, tile_z):
                                             with T.block("S_store"):
@@ -960,7 +960,7 @@ def _attention_prefill(h_kv, h_q, d, dtype):
                                                     i, j, k = 
T.axis.remap("SSR", [li, lj, lk])
                                                     with T.init():
                                                         O_local[i, j] *= 
T.exp2(m_prev_smem[i] - m_smem[i])
-                                                    O_local[i, j] += S_smem[i, 
k] * V_smem[k, j]
+                                                    O_local[i, j] += S_smem[i, 
k] * T.cast(V_smem[k, j], "float32")
 
                                     # Store O from smem to gmem
                                     for li, lj in T.grid(tile_x, tile_y):
@@ -1196,7 +1196,7 @@ def _attention_decode(num_kv_heads, num_qo_heads, 
head_dim, qkv_dtype):
                                         # compute S = Q * K * sm_scale
                                         S_reduce_local[0] = 0
                                         for vec in T.serial(VEC_SIZE):
-                                            S_reduce_local[0] += Q_local[vec] 
* K_local[vec] * attn_score_scaling_factor * sm_scale
+                                            S_reduce_local[0] += 
T.cast(Q_local[vec], "float32") * T.cast(K_local[vec], "float32") * 
attn_score_scaling_factor * sm_scale
 
                                         with T.block("block_cross_thread"):
                                             T.reads(S_reduce_local[0])
@@ -1230,7 +1230,7 @@ def _attention_decode(num_kv_heads, num_qo_heads, 
head_dim, qkv_dtype):
                                         for vec in T.vectorized(VEC_SIZE):
                                             V_local[vec] = V_smem[tz * bdy * 
tile_size_per_bdx + j, tx * VEC_SIZE + vec]
                                         for vec in T.vectorized(VEC_SIZE):
-                                            O_local[vec] += V_local[vec] * 
S_local[j]
+                                            O_local[vec] += 
T.cast(V_local[vec], "float32") * S_local[j]
 
                                 if bdz > 1:
                                     # allreduce over bdz
@@ -1445,7 +1445,7 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype):
                                                     i, j, k = 
T.axis.remap("SSR", [li, lj, lk])
                                                     with T.init():
                                                         S_local[i, j] = 0.0
-                                                    S_local[i, j] += Q_smem[i, 
k] * K_smem[j, k] * attn_score_scaling_factor * sm_scale
+                                                    S_local[i, j] += 
T.cast(Q_smem[i, k], "float32") * T.cast(K_smem[j, k], "float32") * 
attn_score_scaling_factor * sm_scale
                                         T.tvm_storage_sync("shared")
                                         for li, lj in T.grid(tile_x, tile_z):
                                             with T.block("S_store"):
@@ -1503,7 +1503,7 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype):
                                                     i, j, k = 
T.axis.remap("SSR", [li, lj, lk])
                                                     with T.init():
                                                         O_local[i, j] *= 
T.exp2(m_prev_smem[i] - m_smem[i])
-                                                    O_local[i, j] += S_smem[i, 
k] * V_smem[k, j]
+                                                    O_local[i, j] += S_smem[i, 
k] * T.cast(V_smem[k, j], "float32")
 
                                     # Store O from smem to gmem
                                     for li, lj in T.grid(tile_x, tile_y):

Reply via email to