krishnaraj36 opened a new pull request, #17482: URL: https://github.com/apache/tvm/pull/17482
Improvements
- Added Transpose to K for better vectorization during Matmul.
- Improved Load Schedule.
- Performance improved by a bit more than 2x in most cases.

Llama-2 7B observation

| kernel | baseline | optimized |
|---|---|---|
| batch_prefill_ragged_kv | 15 ms | 7.1 ms |

This PR fixes the issue addressed in PR [#17446](https://github.com/apache/tvm/pull/17466). The correctness issue is caused by incorrect code generation during the unroll phase. We therefore removed the explicit unroll and observed little to no performance degradation.

We generated the OpenCL kernels by extracting the generated modules after setting num_qo_heads=28 in [apache-tvm/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py](https://github.qualcomm.com/gpgpu/apache-tvm/blob/85e15d494d5a42360859941cbc972c4f175c3b94/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py#L36)
Line 36 in [85e15d4](https://github.qualcomm.com/gpgpu/apache-tvm/commit/85e15d494d5a42360859941cbc972c4f175c3b94): num_qo_heads = 32

Original PR Codegen

int cur_L_3 = ((((((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + LH_start) + 1) / 7) + (((((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + LH_start) + 1) % 7) >> 31)) + q_indptr[(b_idx_1 + q_indptr_elem_offset)]); if (cur_L_3 < q_indptr[((b_idx_1 + q_indptr_elem_offset) + 1)]) { vstore4((convert_half4((O_local[3] / ((float4)(d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 1)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 1)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 1)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 1)]))))), 0, output + (((((cur_L_3 * 3584) + ((convert_int(get_group_id(1))) * 896)) + ((((((((convert_int(get_local_id(1))) * 8) + 
(((convert_int(get_local_id(0))) >> 4) * 4)) + LH_start) + 1) % 7) + (7 & (((((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + LH_start) + 1) % 7) >> 31))) * 128)) + (((convert_int(get_local_id(0))) & 15) * 8)) + 4)); } int cur_L_4 = ((((((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + LH_start) - 2147483637) / 7) - -306783377) + q_indptr[(b_idx_1 + q_indptr_elem_offset)]); if (cur_L_4 < q_indptr[((b_idx_1 + q_indptr_elem_offset) + 1)]) { vstore4((convert_half4((O_local[4] / ((float4)(d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 2)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 2)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 2)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 2)]))))), 0, output + ((((cur_L_4 * 3584) + ((convert_int(get_group_id(1))) * 896)) + (((((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + LH_start) - 2147483637) % 7) * 128)) + (((convert_int(get_local_id(0))) & 15) * 8))); }

In the O_store block, we noticed that large and incorrect pointer offsets were being generated during subsequent stages of the unroll (for example, the constants 2147483637 and -306783377 in the cur_L_4 computation above). This shows up indirectly as zero elements in the output and as compute instability. Fusing the unroll loops so that they unroll together does not seem to resolve this. Oddly enough, the initial test case does not seem to trigger the issue and works as intended. 
int cur_L_3 = ((((((convert_int(get_local_id(0))) >> 4) + ((LH_start + 1) >> 2)) >> 1) + q_indptr[(b_idx_1 + q_indptr_elem_offset)]) + (convert_int(get_local_id(1)))); if (cur_L_3 < q_indptr[((b_idx_1 + q_indptr_elem_offset) + 1)]) { vstore4((convert_half4((O_local[3] / ((float4)(d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 1)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 1)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 1)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 1)]))))), 0, output + (((((cur_L_3 * 4096) + ((convert_int(get_group_id(1))) * 1024)) + (((((((convert_int(get_local_id(0))) >> 4) * 4) + (LH_start & 7)) + 1) & 7) * 128)) + (((convert_int(get_local_id(0))) & 15) * 8)) + 4)); } int cur_L_4 = ((((((convert_int(get_local_id(0))) >> 4) + ((LH_start + 2) >> 2)) >> 1) + q_indptr[(b_idx_1 + q_indptr_elem_offset)]) + (convert_int(get_local_id(1)))); if (cur_L_4 < q_indptr[((b_idx_1 + q_indptr_elem_offset) + 1)]) { vstore4((convert_half4((O_local[4] / ((float4)(d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 2)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 2)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 2)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 2)]))))), 0, output + ((((cur_L_4 * 4096) + ((convert_int(get_group_id(1))) * 1024)) + (((((((convert_int(get_local_id(0))) >> 4) * 4) + (LH_start & 7)) + 2) & 7) * 128)) + (((convert_int(get_local_id(0))) & 15) * 8))); } -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
