DickJC123 opened a new pull request #20443:
URL: https://github.com/apache/incubator-mxnet/pull/20443
## Description ##
This PR makes RTC (as invoked by our Python unittests and other model
scripts) work with CUDA enhanced compatibility.
As such, it is an extension of PR
https://github.com/apache/incubator-mxnet/pull/19364, which brought that
functionality to the C++ backend. This PR keeps
test_operator_gpu.py::test_cuda_rtc from failing on systems that rely on CUDA
enhanced compatibility, though those systems may not be part of upstream CI at
present.
The changes in this PR are:
- break off the calculation of the max supported arch into a separate
function GetMaxSupportedArch(), and enhance it to use nvrtcGetSupportedArchs()
when CUDA_VERSION >= 11.2
- wrap GetMaxSupportedArch() as MXGetMaxSupportedArch() and add it to the C
API
- use MXGetMaxSupportedArch() in a newly created Python utility function,
get_rtc_compile_opts(ctx)
- enhance test_cuda_rtc to use this new function
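The first change can be pictured with a small Python sketch. It assumes the C++ GetMaxSupportedArch() behaves roughly like this: on CUDA 11.2+ it takes the maximum of the arch list reported by nvrtcGetSupportedArchs(), and on older toolkits it falls back to a per-version table. The function name `sketch_max_supported_arch` and the fallback table are hypothetical illustrations, not the PR's code. The table lists only well-known milestones (sm_75 in CUDA 10, sm_80 in CUDA 11.0, sm_86 in CUDA 11.1). CUDA_VERSION is encoded as 1000*major + 10*minor, so 11.2 is 11020.

```python
from typing import Optional, Sequence

# Hypothetical fallback table for toolkits that predate nvrtcGetSupportedArchs()
# (added in CUDA 11.2). Keys are CUDA_VERSION values (1000*major + 10*minor),
# checked from newest to oldest.
_FALLBACK_MAX_ARCH = [
    (11010, 86),  # CUDA 11.1 added sm_86
    (11000, 80),  # CUDA 11.0 added sm_80 (A100)
    (10000, 75),  # CUDA 10.x supports up to sm_75 (Turing)
]


def sketch_max_supported_arch(cuda_version: int,
                              supported_archs: Optional[Sequence[int]] = None) -> int:
    """Hypothetical sketch of the max-supported-arch calculation."""
    if cuda_version >= 11020 and supported_archs:
        # CUDA >= 11.2: trust the arch list the nvrtc library reports.
        return max(supported_archs)
    # Older toolkits (or no list supplied): use the static table.
    for min_version, arch in _FALLBACK_MAX_ARCH:
        if cuda_version >= min_version:
            return arch
    raise ValueError(f"unsupported CUDA_VERSION {cuda_version}")
```

The `supported_archs` argument stands in for the array that nvrtcGetSupportedArchs() fills in. The sample values in a call like `sketch_max_supported_arch(11040, [70, 75, 80, 86])` are illustrative only.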
The current approach to RTC in Python code, which can fail under CUDA
enhanced compatibility:
```python
# `source` holds the CUDA kernel source code as a string
module = mx.rtc.CudaModule(source)
```
With this PR, the approach that succeeds under CUDA enhanced
compatibility is:
```python
import mxnet as mx

ctx = mx.gpu(0)  # any GPU context
# get_rtc_compile_opts() is the Python utility added by this PR
module = mx.rtc.CudaModule(source, options=get_rtc_compile_opts(ctx))
```
get_rtc_compile_opts() returns the list of options best suited to the
system and the GPU context. Currently this is a single option, either
`--gpu-architecture=compute_NN` or `--gpu-architecture=sm_NN`, as the
system requires.
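One plausible selection rule is sketched below. It is an illustrative assumption, not the PR's actual implementation of get_rtc_compile_opts(). If the toolkit's nvrtc can emit SASS and knows the device's arch, the sketch requests `sm_NN`. Otherwise it requests PTX (`compute_NN`) for the newest arch the toolkit supports, so the driver can JIT it for the actual device. The name `sketch_rtc_compile_opts` and the `toolkit_emits_sass` flag are hypothetical. Here `max_supported_arch` plays the role of the value MXGetMaxSupportedArch() reports.

```python
from typing import List


def sketch_rtc_compile_opts(device_arch: int, max_supported_arch: int,
                            toolkit_emits_sass: bool = True) -> List[str]:
    """Hypothetical sketch: pick the nvrtc --gpu-architecture option.

    device_arch        -- compute capability of the target GPU, e.g. 80 for A100
    max_supported_arch -- newest arch the nvrtc toolchain supports
    toolkit_emits_sass -- assumption: whether this nvrtc can produce SASS at all
    """
    if toolkit_emits_sass and device_arch <= max_supported_arch:
        # The toolkit knows this GPU: compile all the way to SASS so the
        # (possibly older) driver never has to JIT-compile anything.
        return [f"--gpu-architecture=sm_{device_arch}"]
    # Otherwise emit PTX for the newest arch the toolkit supports and rely
    # on the driver to JIT-compile it for the actual device.
    return [f"--gpu-architecture=compute_{min(device_arch, max_supported_arch)}"]
```

Returning a list rather than a bare string mirrors the description above. It also leaves room for more options later, and it can be passed straight to `mx.rtc.CudaModule(source, options=...)`.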
## Background ##
Starting with CUDA 11.1, a user can accept minor-release upgrades of the
CUDA toolkit (potentially picking up support for a newer GPU arch) without
upgrading the driver (per
https://docs.nvidia.com/deploy/cuda-compatibility/index.html). In such cases,
the toolkit's nvrtc compiler should not only compile CUDA code to PTX but also
translate the PTX to SASS, since the driver would be unable to JIT-compile the
PTX to SASS for the newer GPU arch. This is controlled by the nvrtc
`--gpu-architecture` option. For example, to compile to SASS for the Ampere A100
the option is `--gpu-architecture=sm_80`. To compile only to PTX, and so rely on
the driver's ability to JIT-compile to SASS, the option is
`--gpu-architecture=compute_80`.
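The naming convention behind these two options is a small piece of arithmetic. A GPU's compute capability (major, minor) becomes the two-digit arch number NN, so the A100's compute capability 8.0 becomes 80. That number then pairs with the `sm_` prefix (SASS) or the `compute_` prefix (PTX). The helper names below are hypothetical and serve only as a worked example.

```python
from typing import Dict


def arch_number(major: int, minor: int) -> int:
    # Compute capability (8, 0) -> 80, as used in sm_80 / compute_80.
    return 10 * major + minor


def gpu_architecture_options(major: int, minor: int) -> Dict[str, str]:
    nn = arch_number(major, minor)
    return {
        # nvrtc compiles all the way to SASS; no driver JIT is needed.
        "sass": f"--gpu-architecture=sm_{nn}",
        # nvrtc stops at PTX; the driver JIT-compiles it to SASS at load time.
        "ptx": f"--gpu-architecture=compute_{nn}",
    }
```

For the A100, `gpu_architecture_options(8, 0)` yields the `sm_80` and `compute_80` forms quoted above.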
## Checklist ##
### Essentials ###
- [X] PR's title starts with a category (e.g. [BUGFIX], [MODEL], [TUTORIAL],
[FEATURE], [DOC], etc)
- [X] Changes are complete (i.e. I finished coding on this PR)
- [~] All changes have test coverage [Verified privately, but ideally
upstream's CI would have systems that stress this PR]
- [X] Code is well-documented