ArmageddonKnight opened a new pull request #8678: URL: https://github.com/apache/tvm/pull/8678
# Short Summary

Sometimes, when executing CUDA kernels, we might encounter the error `CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES` (e.g., [here](https://discuss.tvm.apache.org/t/cuda-got-error-cuda-error-launch-out-of-resources/4173)). This happens because **the nvcc compiler allocates too many registers per thread**. When we launch a CUDA kernel with too many threads, the GPU notices that the kernel requests more registers than are available on the chip and therefore refuses to launch it. This implies that we need a way of telling nvcc how many threads per block to expect. Luckily, the `__launch_bounds__` directive achieves exactly that.

In this patch, we add `__launch_bounds__` to the CUDA code generation procedure. It is printed automatically whenever the number of threads per block is detected to be a constant integer. Passing this information to nvcc allows it to spill registers if needed, which might hurt performance, but is still better than a CUDA kernel that fails to launch.

# Q & A

**Q: Would this affect AutoTVM and the auto-scheduler submodule?**

A: No. Although the number of threads changes at each trial in those cases, it is set to a constant by the time the code generation phase runs. Furthermore, when the number of threads per block is not a constant, `__launch_bounds__` is simply not printed.

Any feedback on this patch is appreciated. @comaniac @icemelon @yzhliu @yidawang
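To illustrate the effect of the directive, here is a minimal hand-written sketch (not TVM's actual generated code; the kernel and the bound of 256 are hypothetical) of what an annotated kernel looks like. The `__launch_bounds__(256)` qualifier tells nvcc that the kernel will never be launched with more than 256 threads per block, so it can budget registers per thread accordingly, spilling to local memory if necessary instead of producing a kernel that exceeds the register file at launch time:

```cuda
// Hypothetical kernel: __launch_bounds__(256) promises nvcc that
// blockDim.x * blockDim.y * blockDim.z <= 256 at every launch site,
// letting it cap per-thread register usage so the launch cannot fail
// with CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES for this reason.
extern "C" __global__ void __launch_bounds__(256)
vector_add(const float* __restrict__ a,
           const float* __restrict__ b,
           float* __restrict__ c, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    c[i] = a[i] + b[i];
  }
}
```

Note that the promise cuts both ways: launching such a kernel with more threads per block than the stated bound is itself a launch error, which is why the patch only emits the directive when the thread count is a compile-time constant.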
