MoebiusMeow opened a new pull request, #12275:
URL: https://github.com/apache/tvm/pull/12275
This PR proposes a fix for #12274 by adding a single range check to the
sub-warp reduction.
### Before
```cuda
extern "C" __global__ void __launch_bounds__(11)
default_function_kernel0(float* __restrict__ A, float* __restrict__ B) {
  float red_buf0[1];
  uint mask[1];
  float t0[1];
  red_buf0[0] = A[((((int)blockIdx.x) * 11) + ((int)threadIdx.x))];
  mask[0] = __activemask();
  t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 8, 32);
  red_buf0[0] = max(red_buf0[0], t0[0]);
  t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 4, 32);
  red_buf0[0] = max(red_buf0[0], t0[0]);
  t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 2, 32);
  red_buf0[0] = max(red_buf0[0], t0[0]);
  t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 1, 32);
  red_buf0[0] = max(red_buf0[0], t0[0]);
  red_buf0[0] = __shfl_sync(mask[0], red_buf0[0], 0, 32);
  B[((int)blockIdx.x)] = red_buf0[0];
}
```
### After
```cuda
extern "C" __global__ void __launch_bounds__(11)
default_function_kernel0(float* __restrict__ A, float* __restrict__ B) {
  float red_buf0[1];
  uint mask[1];
  float t0[1];
  red_buf0[0] = A[((((int)blockIdx.x) * 11) + ((int)threadIdx.x))];
  mask[0] = __activemask();
  t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 8, 32);
  if (((int)threadIdx.x) < 3) {
    red_buf0[0] = max(red_buf0[0], t0[0]);
  }
  t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 4, 32);
  red_buf0[0] = max(red_buf0[0], t0[0]);
  t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 2, 32);
  red_buf0[0] = max(red_buf0[0], t0[0]);
  t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 1, 32);
  red_buf0[0] = max(red_buf0[0], t0[0]);
  red_buf0[0] = __shfl_sync(mask[0], red_buf0[0], 0, 32);
  B[((int)blockIdx.x)] = red_buf0[0];
}
```
### Analysis
We cannot move `__shfl_down_sync` into the conditional, because every lane named
in the mask must execute it (otherwise the result is undefined), but we can
decide whether each lane uses the returned value.
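The pattern can be sketched as a standalone CUDA kernel for any group size `n`. This is a minimal illustration under stated assumptions, not TVM's codegen: one reduction group per block with `blockDim.x == n` and `n <= 32`, an explicitly computed mask in place of `__activemask()`, and the hypothetical names `subwarp_max` and `check_subwarp_max`. Every lane executes each `__shfl_down_sync`; only the first step guards whether the shuffled value is used.

```cuda
#include <algorithm>
#include <cassert>
#include <cfloat>
#include <vector>
#include <cuda_runtime.h>

// Hypothetical kernel (not TVM-generated): each block reduces n floats to their
// maximum with warp shuffles. Assumes 1 <= n <= 32 and blockDim.x == n.
__global__ void subwarp_max(const float* A, float* B, int n) {
  int lane = threadIdx.x;
  float v = A[blockIdx.x * n + lane];
  // Explicit mask of the n participating lanes (used instead of __activemask()).
  unsigned mask = (n == 32) ? 0xffffffffu : ((1u << n) - 1u);

  // First offset: largest power of two below n (8 for n == 11, 16 for n == 32);
  // 1 when n == 1.
  int offset = 1;
  while (offset * 2 < n) offset *= 2;

  // Step 1: every lane in `mask` executes the shuffle; only the use is guarded,
  // because lanes with lane + offset >= n would read a lane outside the group.
  float other = __shfl_down_sync(mask, v, offset, 32);
  if (lane + offset < n) v = fmaxf(v, other);

  // Later steps are unguarded: lane 0 only depends on lanes below the first
  // offset, which all hold valid values. Other lanes may pick up dirty values.
  for (offset /= 2; offset > 0; offset /= 2) {
    other = __shfl_down_sync(mask, v, offset, 32);
    v = fmaxf(v, other);
  }
  if (lane == 0) B[blockIdx.x] = v;
}

// Host-side check against a CPU reference; returns false on a mismatch or CUDA error.
bool check_subwarp_max(int n, int blocks = 4) {
  std::vector<float> h(n * blocks);
  for (int i = 0; i < n * blocks; ++i) h[i] = float((i * 37) % 101) - 50.0f;

  float* dA = nullptr;
  float* dB = nullptr;
  if (cudaMalloc(&dA, h.size() * sizeof(float)) != cudaSuccess) return false;
  if (cudaMalloc(&dB, blocks * sizeof(float)) != cudaSuccess) {
    cudaFree(dA);
    return false;
  }
  cudaMemcpy(dA, h.data(), h.size() * sizeof(float), cudaMemcpyHostToDevice);
  subwarp_max<<<blocks, n>>>(dA, dB, n);

  std::vector<float> out(blocks);
  cudaError_t err = cudaGetLastError();
  if (err == cudaSuccess)
    err = cudaMemcpy(out.data(), dB, blocks * sizeof(float), cudaMemcpyDeviceToHost);
  cudaFree(dA);
  cudaFree(dB);
  if (err != cudaSuccess) return false;

  for (int b = 0; b < blocks; ++b) {
    float ref = -FLT_MAX;
    for (int i = 0; i < n; ++i) ref = std::max(ref, h[b * n + i]);
    if (out[b] != ref) return false;
  }
  return true;
}
```

Calling `assert(check_subwarp_max(11));` exercises the same 11-thread configuration as the generated kernels. Peeling the first step out of the loop keeps the remaining loop body branch-free, much like the generated "After" code.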
Given how the reduction proceeds, only the first shuffle needs a range check:
after it, lanes 0 to 7 all hold valid data, and the "dirty" values that later
shuffles produce in higher lanes never reach thread 0. The fix therefore adds
very little performance overhead.
In addition, no check is needed when performing a full-warp reduction.
An additional unit test for `max` reduction is also added to
`tests/python/unittest/test_subwarp_reduction_cuda.py`.
cc @vinx13 @junrushao1994 @yzh119
--
This is an automated message from the Apache Git Service.