This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 91dafd9 Add version 11.1 in finding CUDA libdevice (#7033)
91dafd9 is described below
commit 91dafd9ae7c9af9c496b3f7a4e357ec98dc44a66
Author: Thomas Viehmann <[email protected]>
AuthorDate: Fri Dec 4 18:37:30 2020 +0100
Add version 11.1 in finding CUDA libdevice (#7033)
* Add CUDA 11.1 libdevice
Maybe we should use a >= version check instead of listing each version.
I also added a fallback to detect the version if version.txt is
missing. Calling nvcc for this was inspired by what PyTorch
does when compiling extension modules.
---
python/tvm/contrib/nvcc.py | 19 ++++++++++++++++---
1 file changed, 16 insertions(+), 3 deletions(-)
diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py
index 53a507f..89548b7 100644
--- a/python/tvm/contrib/nvcc.py
+++ b/python/tvm/contrib/nvcc.py
@@ -148,8 +148,20 @@ def get_cuda_version(cuda_path):
with open(version_file_path) as f:
version_str = f.readline().replace("\n", "").replace("\r", "")
return float(version_str.split(" ")[2][:2])
- except:
- raise RuntimeError("Cannot read cuda version file")
+ except FileNotFoundError:
+ pass
+
+ cmd = [os.path.join(cuda_path, "bin", "nvcc"), "--version"]
+ proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
+ (out, _) = proc.communicate()
+ out = py_str(out)
+ if proc.returncode == 0:
+ release_line = [l for l in out.split("\n") if "release" in l][0]
+ release_fields = [s.strip() for s in release_line.split(",")]
+ release_version = [f[1:] for f in release_fields if
f.startswith("V")][0]
+ major_minor = ".".join(release_version.split(".")[:2])
+ return float(major_minor)
+ raise RuntimeError("Cannot read cuda version file")
@tvm._ffi.register_func("tvm_callback_libdevice_path")
@@ -174,7 +186,7 @@ def find_libdevice_path(arch):
selected_ver = 0
selected_path = None
cuda_ver = get_cuda_version(cuda_path)
- if cuda_ver in (9.0, 9.1, 10.0, 10.1, 10.2, 11.0):
+ if cuda_ver in (9.0, 9.1, 10.0, 10.1, 10.2, 11.0, 11.1):
path = os.path.join(lib_path, "libdevice.10.bc")
else:
for fn in os.listdir(lib_path):
@@ -219,6 +231,7 @@ def parse_compute_version(compute_version):
minor = int(split_ver[1])
return major, minor
except (IndexError, ValueError) as err:
+ # pylint: disable=raise-missing-from
raise RuntimeError("Compute version parsing error: " + str(err))