tinywisdom opened a new issue, #18315:
URL: https://github.com/apache/tvm/issues/18315
### Summary
Passing a 0-dimensional `torch.bool` scalar into a TVM Relax VM function via
DLPack results in a dtype mismatch at the VM boundary. The Relax IR expects
`R.Tensor((), dtype="bool")` for the parameter, but at runtime the argument is
treated as a different dtype (likely `uint8`), causing:
```
ValueError: ... expect Tensor with dtype bool but get ...
Aborted (core dumped)
```
This occurs when compiling a minimal PyTorch program (exported with
`torch.export`) to Relax and then invoking the VM with inputs created through
`tvm.nd.from_dlpack(to_dlpack(...))`.
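The suspected failure mode depends on how DLPack describes a boolean. In `dlpack.h`, a `DLDataType` is a (type code, bits, lanes) triple, and the dedicated `kDLBool` code only exists since DLPack v0.8. A producer or consumer that does not use it may describe a 1-byte bool as `kDLUInt` with 8 bits, and that reads back as `uint8`. Below is a hypothetical pure-Python model of that mapping, assuming this is the mechanism at play. The `dtype_str` helper is illustrative and is not a TVM or PyTorch API. It does not establish which side of the boundary is at fault here.

```python
# Hypothetical model of how a DLPack consumer might map a DLDataType
# (type code, bits, lanes) to a dtype string. Type codes follow dlpack.h;
# kDLBool was added in DLPack v0.8. Not TVM's or PyTorch's actual code.
from enum import IntEnum


class DLDataTypeCode(IntEnum):
    kDLInt = 0
    kDLUInt = 1
    kDLFloat = 2
    kDLOpaqueHandle = 3
    kDLBfloat = 4
    kDLComplex = 5
    kDLBool = 6


_PREFIX = {
    DLDataTypeCode.kDLInt: "int",
    DLDataTypeCode.kDLUInt: "uint",
    DLDataTypeCode.kDLFloat: "float",
    DLDataTypeCode.kDLBfloat: "bfloat",
    DLDataTypeCode.kDLComplex: "complex",
}


def dtype_str(code: int, bits: int, lanes: int = 1) -> str:
    """Render a DLDataType triple as a dtype string such as 'uint8' or 'bool'."""
    code = DLDataTypeCode(code)
    if code == DLDataTypeCode.kDLBool:
        base = "bool"
    elif code in _PREFIX:
        base = f"{_PREFIX[code]}{bits}"
    else:
        raise ValueError(f"unsupported type code {code!r}")
    return base if lanes == 1 else f"{base}x{lanes}"


# The same one-byte boolean payload, described two ways:
legacy = dtype_str(DLDataTypeCode.kDLUInt, 8)  # bool tagged as a plain byte
modern = dtype_str(DLDataTypeCode.kDLBool, 8)  # bool with the dedicated code
print(legacy, modern)  # uint8 bool
```

In this model, the bytes are identical in both cases. Only the type code decides whether the consumer sees `bool` or `uint8`.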
### Actual behavior
The VM entry check (`CheckTensorInfo`) sees a dtype other than `bool` and
throws an InternalError/ValueError. The received dtype is left blank in the
message, which often indicates a mis-typed scalar. This suggests that the
DLPack → `tvm.nd.NDArray` path either interprets the 0-D `bool` as a non-bool
dtype (commonly `uint8`) or drops the dtype information for 0-D tensors.
```
About to call TVM VM with (x, flag) where flag is 0-d torch.bool via DLPack
...
terminate called after throwing an instance of 'tvm::runtime::InternalError'
what(): ...
ValueError: Check failed: (DataType(ptr->dtype) == dtype) is false:
ErrorContext(fn=main, loc=param[1], param=flag, annotation=R.Tensor((), dtype="bool"))
expect Tensor with dtype bool but get
Aborted (core dumped)
```
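The failing check amounts to an exact comparison between the dtype in the parameter's `R.Tensor` annotation and the dtype carried by the incoming array. A buffer tagged `uint8` therefore fails even if its single byte is a perfectly valid boolean. Below is a hypothetical Python model of only that dtype comparison. `ParamAnnotation` and `check_param_dtype` are illustrative names, and the message format loosely imitates the log. This is not TVM's implementation.

```python
# Hypothetical model of the dtype half of the VM's parameter check.
# It loosely imitates the error text in the log above; it is not TVM's code.
from dataclasses import dataclass


@dataclass(frozen=True)
class ParamAnnotation:
    fn: str
    index: int
    name: str
    dtype: str  # dtype from the R.Tensor annotation, e.g. "bool"


def check_param_dtype(ann: ParamAnnotation, actual_dtype: str) -> None:
    """Raise ValueError if the runtime dtype differs from the annotated one."""
    if actual_dtype != ann.dtype:
        raise ValueError(
            f"ErrorContext(fn={ann.fn}, loc=param[{ann.index}], param={ann.name}) "
            f"expect Tensor with dtype {ann.dtype} but get {actual_dtype}"
        )


flag_ann = ParamAnnotation(fn="main", index=1, name="flag", dtype="bool")
check_param_dtype(flag_ann, "bool")  # passes silently
try:
    check_param_dtype(flag_ann, "uint8")  # the suspected mis-typed scalar
except ValueError as err:
    print(err)
```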
### Environment
+ OS: Ubuntu 22.04.4 LTS (x86_64)
+ TVM version: release v0.21.0
+ Python: 3.10.16
+ LLVM: 17.0.6
### Steps to reproduce
```python
# repro_tvm_bool_flag_crash.py
import torch
import torch.nn as nn
import numpy as np
from torch.export import export as torch_export
from torch.utils.dlpack import to_dlpack, from_dlpack
import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_exported_program


class TinyCond(nn.Module):
    # Minimal conditional: `flag` is a 0-D boolean
    def forward(self, x, flag):
        t = (x + 1).sum()
        f = (x - 1).sum()
        return torch.where(flag, t, f)


def main():
    torch.manual_seed(0)
    m = TinyCond().eval()

    # Inputs: x is a regular tensor; flag is a 0-D torch.bool scalar
    x = torch.randn(2, 3)
    flag = torch.randint(0, 2, (), dtype=torch.bool)

    # Sanity check on the PyTorch side
    with torch.inference_mode():
        _ = m(x, flag)

    # Export → Relax
    ep = torch_export(m, (x, flag))
    mod = from_exported_program(ep)

    # Build for CPU to minimize deps
    target = "llvm"
    dev = tvm.cpu(0)
    exec_mod = relax.build(mod, target=target)
    vm = relax.VirtualMachine(exec_mod, dev)

    # Critical step: feed the 0-D torch.bool via DLPack
    tvm_x = tvm.nd.from_dlpack(to_dlpack(x))
    tvm_flag = tvm.nd.from_dlpack(to_dlpack(flag))

    print("About to call TVM VM with (x, flag) where flag is 0-d torch.bool via DLPack ...")
    # Expected crash: InternalError/ValueError from CheckTensorInfo on dtype mismatch
    vm["main"](tvm_x, tvm_flag)


if __name__ == "__main__":
    main()
```
### Triage
* needs-triage
* bug