tinywisdom opened a new issue, #19549:
URL: https://github.com/apache/tvm/issues/19549
### Summary
A minimal `torch.topk` model compiled through TVM Relax for a CUDA target
fails at runtime with `CUDA_ERROR_INVALID_VALUE`.
The model is very small:
```python
values, _ = torch.topk(x, k=10, dim=-1)
return values
```
For input shape `(60, 1000, 1000)`, `tvm.compile(...)` succeeds, but the
Relax VM fails when launching the generated CUDA kernel `topk_kernel_2`.
The failing launch configuration is:
```
grid=(1,240000,1), block=(256,1,1)
// func_name=topk_kernel_2
```
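Here `gridDim.y = 240000`, which exceeds the CUDA limit of 65535 for the
grid y-dimension (`gridDim.y` and `gridDim.z` are capped at 65535; only
`gridDim.x` may go up to 2^31 - 1). It looks like the CUDA TopK
lowering/scheduling path generates an invalid launch configuration for this
large input.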
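The value 240000 factors as 60 × 1000 × 4: the 60000 rows that TopK reduces over, times ceil(1000 / 256) = 4 chunks of 256 elements per row. The sketch below reproduces that arithmetic under the assumption (not confirmed against TVM's TopK schedule) that `topk_kernel_2` assigns one y-block to each 256-element chunk of every row. `hypothesized_grid_y` is a hypothetical helper for illustration, not a TVM function.

```python
import math

# Documented CUDA limit for gridDim.y (and gridDim.z); gridDim.x allows 2**31 - 1.
MAX_GRID_Y = 65535


def hypothesized_grid_y(shape, threads_per_block=256):
    """Hypothetical mapping: one y-block per (row, threads_per_block-sized chunk).

    This is an assumption used to explain the observed launch, not TVM's
    actual TopK schedule.
    """
    rows = math.prod(shape[:-1])  # TopK reduces over the last dimension
    chunks_per_row = math.ceil(shape[-1] / threads_per_block)
    return rows * chunks_per_row


grid_y = hypothesized_grid_y((60, 1000, 1000))
print(grid_y, grid_y > MAX_GRID_Y)  # 240000 True
```

Under the same assumption, any input with more than 16383 rows of length 1000 would exceed the limit (16384 × 4 = 65536), while a `(16, 1000, 1000)` input (16000 rows, 64000 blocks) would not.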
### Minimal reproduction
```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import sys
import platform
import traceback

import numpy as np
import torch
import tvm
from tvm import relax


class TopKModel(torch.nn.Module):
    def forward(self, x):
        values, _ = torch.topk(x, k=10, dim=-1)
        return values


def make_tvm_array(arr, dev):
    if not isinstance(arr, np.ndarray):
        arr = np.array(arr)
    if not arr.flags["C_CONTIGUOUS"]:
        arr = np.ascontiguousarray(arr)
    # Some TVM builds expose tvm.nd.array, while newer/custom FFI builds may not.
    if hasattr(tvm, "nd") and hasattr(tvm.nd, "array"):
        try:
            return tvm.nd.array(arr, device=dev)
        except TypeError:
            return tvm.nd.array(arr, dev)
    try:
        from tvm.runtime import ndarray as _nd
        if hasattr(_nd, "array"):
            try:
                return _nd.array(arr, dev)
            except TypeError:
                return _nd.array(arr, device=dev)
    except Exception:
        pass
    if hasattr(tvm, "runtime") and hasattr(tvm.runtime, "tensor"):
        try:
            return tvm.runtime.tensor(arr, device=dev)
        except TypeError:
            return tvm.runtime.tensor(arr, dev)
    raise RuntimeError("Cannot construct TVM NDArray/Tensor in this TVM build")


def main():
    print("=" * 80)
    print("Environment")
    print("=" * 80)
    print("python:", sys.version.replace("\n", " "))
    print("platform:", platform.platform())
    print("torch:", torch.__version__)
    print("tvm:", getattr(tvm, "__version__", "<unknown>"))
    print("tvm path:", getattr(tvm, "__file__", "<unknown>"))
    print("CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES", ""))
    print("torch.cuda.is_available:", torch.cuda.is_available())
    print("tvm.cuda(0).exist:", tvm.cuda(0).exist)

    target = tvm.target.Target(
        "cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32"
    )
    dev = tvm.cuda(0)
    print("target:", target)

    model = TopKModel().eval()
    # Export on CPU. The compiled TVM module will run on CUDA.
    x = torch.rand(60, 1000, 1000, dtype=torch.float32)
    with torch.no_grad():
        eager_out = model(x)
    print("input shape:", tuple(x.shape))
    print("eager output shape:", tuple(eager_out.shape))

    ep = torch.export.export(model, (x,))
    from tvm.relax.frontend.torch import from_exported_program
    ir_mod = from_exported_program(ep)

    print("=" * 80)
    print("tvm.compile")
    print("=" * 80)
    ex = tvm.compile(
        ir_mod,
        target=target,
        relax_pipeline="default",
        tir_pipeline="default",
    )
    print("compile: OK")

    print("=" * 80)
    print("Relax VM run")
    print("=" * 80)
    vm = relax.VirtualMachine(ex, dev)
    fn = vm["main"]
    x_tvm = make_tvm_array(x.numpy(), dev)
    try:
        out = fn(x_tvm)
        # Synchronize so that asynchronous CUDA errors surface here and are
        # reported as a failure instead of being silently ignored.
        dev.sync()
        print("run: OK")
        if hasattr(out, "numpy"):
            out_np = out.numpy()
            print("output shape:", out_np.shape)
            print("output dtype:", out_np.dtype)
        else:
            print("output type:", type(out))
    except Exception:
        print("run: FAILED")
        traceback.print_exc()


if __name__ == "__main__":
    main()
```
### Actual behavior
The module compiles successfully:
```
tvm.compile
compile: OK
```
But the Relax VM fails during CUDA kernel launch:
```
RuntimeError: CUDALaunch CUDA_ERROR_INVALID_VALUE
grid=(1,240000,1), block=(256,1,1)
// func_name=topk_kernel_2
```
The exception also prints the generated CUDA source. The relevant part is
that the failing kernel is `topk_kernel_2`, launched with
`grid=(1,240000,1), block=(256,1,1)`.
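To check other failing launches quickly, the launch configuration can be pulled out of the `CUDALaunch` error text and compared with the documented CUDA limits for compute capability 3.0 and later. This is a minimal sketch: it assumes the message contains `grid=(x,y,z)` and `block=(x,y,z)` in the format shown above, and `parse_dims` / `launch_violations` are hypothetical helpers, not part of TVM.

```python
import re

# Documented CUDA launch limits (compute capability 3.0 and later).
MAX_GRID = {"x": 2**31 - 1, "y": 65535, "z": 65535}
MAX_BLOCK = {"x": 1024, "y": 1024, "z": 64}
MAX_THREADS_PER_BLOCK = 1024


def parse_dims(message, name):
    """Extract a `name=(a,b,c)` triple from a CUDALaunch error message."""
    match = re.search(rf"{name}=\((\d+),\s*(\d+),\s*(\d+)\)", message)
    if match is None:
        raise ValueError(f"no {name}=(...) found in message")
    return dict(zip("xyz", map(int, match.groups())))


def launch_violations(message):
    """List which launch dimensions in the message exceed CUDA's limits."""
    grid = parse_dims(message, "grid")
    block = parse_dims(message, "block")
    problems = [
        f"gridDim.{axis}={grid[axis]} > {MAX_GRID[axis]}"
        for axis in "xyz" if grid[axis] > MAX_GRID[axis]
    ]
    problems += [
        f"blockDim.{axis}={block[axis]} > {MAX_BLOCK[axis]}"
        for axis in "xyz" if block[axis] > MAX_BLOCK[axis]
    ]
    threads = block["x"] * block["y"] * block["z"]
    if threads > MAX_THREADS_PER_BLOCK:
        problems.append(f"{threads} threads per block > {MAX_THREADS_PER_BLOCK}")
    return problems


msg = "RuntimeError: CUDALaunch CUDA_ERROR_INVALID_VALUE grid=(1,240000,1), block=(256,1,1)"
print(launch_violations(msg))  # ['gridDim.y=240000 > 65535']
```

Only `gridDim.y` is flagged; the 256-thread block is well within limits, which is consistent with the grid y-dimension being the cause of `CUDA_ERROR_INVALID_VALUE`.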
### Environment
* TVM: 0.23.0
* LLVM: 17.0.6
* Python: 3.10.16 (from stack paths)
* NumPy: 2.2.6
### Triage
Please refer to the list of label tags
[here](https://github.com/apache/tvm/wiki/Issue-Triage-Labels) to find the
relevant tags and add them below in a bullet format (example below).
* needs-triage
* bug