yzh119 opened a new pull request #10207:
URL: https://github.com/apache/tvm/pull/10207
Previously, the `LowerThreadAllReduce` pass would only emit code that uses
`shfl_down` when the reduce extent equals the warp size; when the reduce extent
is less than the warp size, the codegen falls back to emitting code that uses
shared memory, which is not efficient. Since CUDA supports sub-warp reduction
by specifying the mask, we can still use the shuffle-down approach for
reduction by changing the mask.
Example code:
```python
import tvm
import numpy as np
from tvm.script import tir as T
@T.prim_func
def reduce(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, [1024, 11])
B = T.match_buffer(b, [1024])
for i, j in T.grid(1024, 11):
with T.block("reduce"):
vi, vj = T.axis.remap("SR", [i, j])
with T.init():
B[vi] = 0.
B[vi] = B[vi] + A[vi, vj]
sch = tvm.tir.Schedule(reduce)
blk = sch.get_block("reduce")
i, j = sch.get_loops(blk)
sch.bind(i, "blockIdx.x")
sch.bind(j, "threadIdx.x")
f = tvm.build(sch.mod["main"], target="cuda")
print(f.imported_modules[0].get_source())
```
Emitted code before this PR:
```cuda
extern "C" __global__ void __launch_bounds__(11)
default_function_kernel0(float* __restrict__ A, float* __restrict__ B) {
__shared__ float red_buf0[11];
__syncthreads();
((volatile float*)red_buf0)[(((int)threadIdx.x))] = A[(((((int)blockIdx.x)
* 11) + ((int)threadIdx.x)))];
__syncthreads();
if (((int)threadIdx.x) < 3) {
((volatile float*)red_buf0)[(((int)threadIdx.x))] = (((volatile
float*)red_buf0)[(((int)threadIdx.x))] + ((volatile
float*)red_buf0)[((((int)threadIdx.x) + 8))]);
}
__syncthreads();
if (((int)threadIdx.x) < 4) {
float w_4_0 = (((volatile float*)red_buf0)[(((int)threadIdx.x))] +
((volatile float*)red_buf0)[((((int)threadIdx.x) + 4))]);
((volatile float*)red_buf0)[(((int)threadIdx.x))] = w_4_0;
float w_2_0 = (((volatile float*)red_buf0)[(((int)threadIdx.x))] +
((volatile float*)red_buf0)[((((int)threadIdx.x) + 2))]);
((volatile float*)red_buf0)[(((int)threadIdx.x))] = w_2_0;
float w_1_0 = (((volatile float*)red_buf0)[(((int)threadIdx.x))] +
((volatile float*)red_buf0)[((((int)threadIdx.x) + 1))]);
((volatile float*)red_buf0)[(((int)threadIdx.x))] = w_1_0;
}
__syncthreads();
B[(((int)blockIdx.x))] = ((volatile float*)red_buf0)[(0)];
}
```
Emitted code after this PR:
```cuda
extern "C" __global__ void __launch_bounds__(11)
default_function_kernel0(float* __restrict__ A, float* __restrict__ B) {
float red_buf0[1];
uint mask[1];
float t0[1];
red_buf0[(0)] = A[(((((int)blockIdx.x) * 11) + ((int)threadIdx.x)))];
mask[(0)] = (__activemask() & (uint)2047);
t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 8, 32);
red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]);
t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 4, 32);
red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]);
t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 2, 32);
red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]);
t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 1, 32);
red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]);
red_buf0[(0)] = __shfl_sync(mask[(0)], red_buf0[(0)], 0, 32);
B[(((int)blockIdx.x))] = red_buf0[(0)];
}
```
# Future work
CUDA 11 supports [cooperative group
reduction](https://developer.nvidia.com/blog/cuda-11-features-revealed/) which
we can directly use.
cc @vinx13 @junrushao1994
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]