krishnaraj36 opened a new pull request, #17482:
URL: https://github.com/apache/tvm/pull/17482

   Improvements -
   Added Transpose to K for better Vectorization during Matmul. Improved Load 
Schedule.
   Performance improved by a bit more than 2x in most cases.
   Llama-2 7B observation
   -----------kernel----------------baseline----------optimized
   - ---batch_prefill_ragged_kv------15 ms-------------7.1 ms
   
   
   This PR fixes the issue addressed in the PR 
[#17446](https://github.com/apache/tvm/pull/17466). The correctness issue is 
caused by incorrect code-generation during the unroll phase. Thus, we removed 
the explicit unroll and noticed little to no performance degradation.
   
   We generated the OpenCL kernels below by extracting the generated modules 
after setting num_qo_heads=28 in
   
   
[apache-tvm/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py](https://github.qualcomm.com/gpgpu/apache-tvm/blob/85e15d494d5a42360859941cbc972c4f175c3b94/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py#L36)
   
   Line 36 in 
[85e15d4](https://github.qualcomm.com/gpgpu/apache-tvm/commit/85e15d494d5a42360859941cbc972c4f175c3b94)
   
    num_qo_heads = 32 
   
   Original PR Codegen
   int cur_L_3 = ((((((((convert_int(get_local_id(1))) * 8) + 
(((convert_int(get_local_id(0))) >> 4) * 4)) + LH_start) + 1) / 7) + 
(((((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 
4) * 4)) + LH_start) + 1) % 7) >> 31)) + q_indptr[(b_idx_1 + 
q_indptr_elem_offset)]);
   if (cur_L_3 < q_indptr[((b_idx_1 + q_indptr_elem_offset) + 1)]) {
       vstore4((convert_half4((O_local[3] / 
((float4)(d_smem[((((convert_int(get_local_id(1))) * 8) + 
(((convert_int(get_local_id(0))) >> 4) * 4)) + 1)], 
d_smem[((((convert_int(get_local_id(1))) * 8) + 
(((convert_int(get_local_id(0))) >> 4) * 4)) + 1)], 
d_smem[((((convert_int(get_local_id(1))) * 8) + 
(((convert_int(get_local_id(0))) >> 4) * 4)) + 1)], 
d_smem[((((convert_int(get_local_id(1))) * 8) + 
(((convert_int(get_local_id(0))) >> 4) * 4)) + 1)]))))), 0, output + 
(((((cur_L_3 * 3584) + ((convert_int(get_group_id(1))) * 896)) + 
((((((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) 
>> 4) * 4)) + LH_start) + 1) % 7) + (7 & (((((((convert_int(get_local_id(1))) * 
8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + LH_start) + 1) % 7) >> 
31))) * 128)) + (((convert_int(get_local_id(0))) & 15) * 8)) + 4));
   }
   int cur_L_4 = ((((((((convert_int(get_local_id(1))) * 8) + 
(((convert_int(get_local_id(0))) >> 4) * 4)) + LH_start) - 2147483637) / 7) - 
-306783377) + q_indptr[(b_idx_1 + q_indptr_elem_offset)]);
   if (cur_L_4 < q_indptr[((b_idx_1 + q_indptr_elem_offset) + 1)]) {
       vstore4((convert_half4((O_local[4] / 
((float4)(d_smem[((((convert_int(get_local_id(1))) * 8) + 
(((convert_int(get_local_id(0))) >> 4) * 4)) + 2)], 
d_smem[((((convert_int(get_local_id(1))) * 8) + 
(((convert_int(get_local_id(0))) >> 4) * 4)) + 2)], 
d_smem[((((convert_int(get_local_id(1))) * 8) + 
(((convert_int(get_local_id(0))) >> 4) * 4)) + 2)], 
d_smem[((((convert_int(get_local_id(1))) * 8) + 
(((convert_int(get_local_id(0))) >> 4) * 4)) + 2)]))))), 0, output + 
((((cur_L_4 * 3584) + ((convert_int(get_group_id(1))) * 896)) + 
(((((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 
4) * 4)) + LH_start) - 2147483637) % 7) * 128)) + 
(((convert_int(get_local_id(0))) & 15) * 8)));
   }
   In the O_store block we noticed that large and incorrect pointer offsets 
were being generated during subsequent stages of the unroll. This can be 
observed indirectly through the zero elements contained in the output and the 
compute instability.
   
   Fusing the unrolled loops so that they unroll together doesn't seem to resolve this.
   
   Oddly enough, the initial test case doesn't seem to trigger the issue and 
works as intended.
   
   int cur_L_3 = ((((((convert_int(get_local_id(0))) >> 4) + ((LH_start + 1) >> 
2)) >> 1) + q_indptr[(b_idx_1 + q_indptr_elem_offset)]) + 
(convert_int(get_local_id(1))));
   if (cur_L_3 < q_indptr[((b_idx_1 + q_indptr_elem_offset) + 1)]) {
       vstore4((convert_half4((O_local[3] / 
((float4)(d_smem[((((convert_int(get_local_id(1))) * 8) + 
(((convert_int(get_local_id(0))) >> 4) * 4)) + 1)], 
d_smem[((((convert_int(get_local_id(1))) * 8) + 
(((convert_int(get_local_id(0))) >> 4) * 4)) + 1)], 
d_smem[((((convert_int(get_local_id(1))) * 8) + 
(((convert_int(get_local_id(0))) >> 4) * 4)) + 1)], 
d_smem[((((convert_int(get_local_id(1))) * 8) + 
(((convert_int(get_local_id(0))) >> 4) * 4)) + 1)]))))), 0, output + 
(((((cur_L_3 * 4096) + ((convert_int(get_group_id(1))) * 1024)) + 
(((((((convert_int(get_local_id(0))) >> 4) * 4) + (LH_start & 7)) + 1) & 7) * 
128)) + (((convert_int(get_local_id(0))) & 15) * 8)) + 4));
   }
   int cur_L_4 = ((((((convert_int(get_local_id(0))) >> 4) + ((LH_start + 2) >> 
2)) >> 1) + q_indptr[(b_idx_1 + q_indptr_elem_offset)]) + 
(convert_int(get_local_id(1))));
    if (cur_L_4 < q_indptr[((b_idx_1 + q_indptr_elem_offset) + 1)]) {
       vstore4((convert_half4((O_local[4] / 
((float4)(d_smem[((((convert_int(get_local_id(1))) * 8) + 
(((convert_int(get_local_id(0))) >> 4) * 4)) + 2)], 
d_smem[((((convert_int(get_local_id(1))) * 8) + 
(((convert_int(get_local_id(0))) >> 4) * 4)) + 2)], 
d_smem[((((convert_int(get_local_id(1))) * 8) + 
(((convert_int(get_local_id(0))) >> 4) * 4)) + 2)], 
d_smem[((((convert_int(get_local_id(1))) * 8) + 
(((convert_int(get_local_id(0))) >> 4) * 4)) + 2)]))))), 0, output + 
((((cur_L_4 * 4096) + ((convert_int(get_group_id(1))) * 1024)) + 
(((((((convert_int(get_local_id(0))) >> 4) * 4) + (LH_start & 7)) + 2) & 7) * 
128)) + (((convert_int(get_local_id(0))) & 15) * 8)));
   }


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to