LeiWang1999 commented on PR #15967:
URL: https://github.com/apache/tvm/pull/15967#issuecomment-1826236971
Hi @malixian, I tried the rocWMMA MetaSchedule support on my AMD MI210 GPU;
here is the codegen result:
```cpp
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
using float16_t = _Float16;
using float16x2
= __attribute__((__vector_size__(2 * sizeof(float16_t)))) float16_t;
using float16x4
= __attribute__((__vector_size__(4 * sizeof(float16_t)))) float16_t;
#include <rocwmma/rocwmma.hpp>
using int32x4
= __attribute__((__vector_size__(4 * sizeof(int)))) int;
using float32x4
= __attribute__((__vector_size__(4 * sizeof(float)))) float;
// Forward declaration of the generated kernel; the definition below repeats
// the same signature.
extern "C" __global__ void main_kernel(half* __restrict__ A, half*
__restrict__ B, float* __restrict__ C);
// Kernel definition. NOTE(review): `__launch_bounds__(64)` is printed twice
// here (before `extern "C"` and again after `void`); the second copy is
// redundant.
__launch_bounds__(64)extern "C" __global__ void __launch_bounds__(64)
main_kernel(half* __restrict__ A, half* __restrict__ B, float* __restrict__ C) {
// 16 float accumulator fragments (32x32, k = 8). They are indexed below as
// [m_outer * 8 + m_inner * 4 + n] with m_outer, m_inner in [0, 2), n in [0, 4),
// i.e. a 4 x 4 grid of 32x32 output tiles.
rocwmma::fragment<rocwmma::accumulator, 32, 32, 8, float>
C_reindex_shared_wmma_accumulator[16];
// Shared tile of A: 128 rows with a row stride of 40 halves (only 32 columns
// are written per row; presumably the extra 8 are padding).
__shared__ half A_reindex_shared[5120];
// Shared tile of B: 32 rows along the reduction axis with a row stride of
// 136 halves (only 128 columns are written; presumably the extra 8 are padding).
__shared__ half B_reindex_shared[4352];
// 8 A fragments: 4 row tiles x 2 reduction steps of 8 within the current
// 16-wide slice of the shared tile.
rocwmma::fragment<rocwmma::matrix_a, 32, 32, 8, half, rocwmma::row_major>
A_reindex_shared_wmma_matrix_a[8];
// 8 B fragments: 2 reduction steps of 8 x 4 column tiles of 32.
rocwmma::fragment<rocwmma::matrix_b, 32, 32, 8, half, rocwmma::row_major>
B_reindex_shared_wmma_matrix_b[8];
// Shared staging area for four contiguous 32x32 float tiles (one row of the
// accumulator grid) before they are copied out to C.
__shared__ float C_reindex_shared[4096];
// Zero all 16 accumulator fragments (2 * 2 * 4 = 16).
for (int ax0_0_3_init = 0; ax0_0_3_init < 2; ++ax0_0_3_init) {
for (int ax0_0_4_init = 0; ax0_0_4_init < 2; ++ax0_0_4_init) {
for (int ax1_0_4_init = 0; ax1_0_4_init < 4; ++ax1_0_4_init) {
rocwmma::fill_fragment(C_reindex_shared_wmma_accumulator[(((ax0_0_3_init * 8) +
(ax0_0_4_init * 4)) + ax1_0_4_init)], 0.000000e+00f);
}
}
}
// Reduction loop: 512 iterations, each consuming a 32-wide slice of the
// reduction axis (global column offset ax2_0_0 * 32).
for (int ax2_0_0 = 0; ax2_0_0 < 512; ++ax2_0_0) {
// Wait until every thread has finished reading the previous iteration's
// shared tiles before they are overwritten.
__syncthreads();
// Copy a 128 x 32 tile of A into shared memory: each iteration writes two
// rows (threadIdx.x >> 5 selects the row, threadIdx.x & 31 the column),
// assuming a block of exactly 64 threads.
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 64; ++ax0_ax1_fused_0) {
A_reindex_shared[(((ax0_ax1_fused_0 * 80) + ((((int)threadIdx.x) >> 5)
* 40)) + (((int)threadIdx.x) & 31))] = A[(((((((((int)blockIdx.y) >> 5) *
8388608) + ((((int)blockIdx.x) >> 2) * 2097152)) + (ax0_ax1_fused_0 * 32768)) +
((((int)threadIdx.x) >> 5) * 16384)) + (ax2_0_0 * 32)) + (((int)threadIdx.x) &
31))];
}
// Copy a 32 x 128 tile of B into shared memory, transposed: the shared row is
// the reduction index (ax0_ax1_fused_0_1 >> 1) and the shared column is the
// global row of B. NOTE(review): consecutive threads read global B with a
// stride of 16384 elements, so these loads are not contiguous across threads.
for (int ax0_ax1_fused_0_1 = 0; ax0_ax1_fused_0_1 < 64;
++ax0_ax1_fused_0_1) {
B_reindex_shared[((((ax0_ax1_fused_0_1 >> 1) * 136) +
((ax0_ax1_fused_0_1 & 1) * 64)) + ((int)threadIdx.x))] =
B[(((((((((int)blockIdx.y) & 31) * 8388608) + ((((int)blockIdx.x) & 3) *
2097152)) + ((ax0_ax1_fused_0_1 & 1) * 1048576)) + (((int)threadIdx.x) *
16384)) + (ax2_0_0 * 32)) + (ax0_ax1_fused_0_1 >> 1))];
}
// Make the freshly written shared tiles visible to all threads.
__syncthreads();
// Two 16-wide halves of the 32-wide shared slice.
for (int ax2_0_1 = 0; ax2_0_1 < 2; ++ax2_0_1) {
// Load A fragments: row tile ax0_0 (32 rows * 40 = 1280 elements apart),
// reduction step ax1_0 (8 columns apart), leading dimension 40.
for (int ax0_0 = 0; ax0_0 < 4; ++ax0_0) {
for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) {
rocwmma::load_matrix_sync(A_reindex_shared_wmma_matrix_a[((ax0_0 *
2) + ax1_0)], (&(A_reindex_shared[(((ax0_0 * 1280) + (ax2_0_1 * 16)) + (ax1_0 *
8))])), 40);
}
}
// Load B fragments: reduction step ax0_0_1 (8 rows * 136 = 1088 elements
// apart), column tile ax1_0_1 (32 columns apart), leading dimension 136.
for (int ax0_0_1 = 0; ax0_0_1 < 2; ++ax0_0_1) {
for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
rocwmma::load_matrix_sync(B_reindex_shared_wmma_matrix_b[((ax0_0_1
* 4) + ax1_0_1)], (&(B_reindex_shared[(((ax2_0_1 * 2176) + (ax0_0_1 * 1088)) +
(ax1_0_1 * 32))])), 136);
}
}
// Accumulate into each of the 4 x 4 output tiles: one MMA per reduction
// step ax2_0_2, reusing the accumulator as both input C and output D.
for (int ax0_0_3 = 0; ax0_0_3 < 2; ++ax0_0_3) {
for (int ax2_0_2 = 0; ax2_0_2 < 2; ++ax2_0_2) {
for (int ax0_0_4 = 0; ax0_0_4 < 2; ++ax0_0_4) {
for (int ax1_0_4 = 0; ax1_0_4 < 4; ++ax1_0_4) {
rocwmma::mma_sync(C_reindex_shared_wmma_accumulator[(((ax0_0_3
* 8) + (ax0_0_4 * 4)) + ax1_0_4)], A_reindex_shared_wmma_matrix_a[(((ax0_0_3 *
4) + (ax0_0_4 * 2)) + ax2_0_2)], B_reindex_shared_wmma_matrix_b[((ax2_0_2 * 4)
+ ax1_0_4)], C_reindex_shared_wmma_accumulator[(((ax0_0_3 * 8) + (ax0_0_4 * 4))
+ ax1_0_4)]);
}
}
}
}
}
}
// Epilogue: for each of the 4 rows of accumulator tiles, stage its four
// 32x32 tiles in shared memory, then copy them to C with float4 stores.
for (int ax2 = 0; ax2 < 4; ++ax2) {
// Wait until the previous row's staged data has been copied out.
__syncthreads();
// Store tile ax3 contiguously (32x32, leading dimension 32, row-major).
for (int ax3 = 0; ax3 < 4; ++ax3) {
rocwmma::store_matrix_sync((&(C_reindex_shared[(ax3 * 1024)])),
C_reindex_shared_wmma_accumulator[((ax2 * 4) + ax3)], 32,
rocwmma::mem_row_major);
}
__syncthreads();
// Cooperative copy: 16 iterations x 64 threads x 4 floats = 4096 floats,
// assuming a block of exactly 64 threads.
for (int ax0_ax1_ax3_ax4_ax5_fused_0 = 0; ax0_ax1_ax3_ax4_ax5_fused_0 <
16; ++ax0_ax1_ax3_ax4_ax5_fused_0) {
*(float4*)(C + ((((((((((((int)blockIdx.y) >> 5) * 8388608) +
((((int)blockIdx.x) >> 2) * 2097152)) + (ax2 * 524288)) +
((ax0_ax1_ax3_ax4_ax5_fused_0 & 3) * 131072)) + ((((int)threadIdx.x) >> 3) *
16384)) + ((((int)blockIdx.y) & 31) * 512)) + ((((int)blockIdx.x) & 3) * 128))
+ ((ax0_ax1_ax3_ax4_ax5_fused_0 >> 2) * 32)) + ((((int)threadIdx.x) & 7) * 4)))
= *(float4*)(C_reindex_shared + ((ax0_ax1_ax3_ax4_ax5_fused_0 * 256) +
(((int)threadIdx.x) * 4)));
}
}
}
```
The generated HIP code is a bit awkward. The kernel signature is emitted twice, as a forward declaration followed by the definition. The definition also repeats `__launch_bounds__(64)`:
```cpp
extern "C" __global__ void main_kernel(half* __restrict__ A, half* __restrict__ B, float* __restrict__ C);
__launch_bounds__(64)extern "C" __global__ void __launch_bounds__(64) main_kernel(half* __restrict__ A, half* __restrict__ B, float* __restrict__ C)
```
Surprisingly, hipcc still compiles this kernel successfully. `CodegenHIP::PrintExtraAttrs` should be removed, since it appears to be the source of the redundant `__launch_bounds__(64)`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]