tinywisdom opened a new issue, #18315:
URL: https://github.com/apache/tvm/issues/18315

   ### Summary
   
   Passing a 0-dimensional `torch.bool` scalar into a TVM Relax VM function via 
DLPack results in a dtype mismatch at the VM boundary. The Relax IR expects 
`R.Tensor((), dtype="bool")` for the parameter, but at runtime the argument is 
treated as a different dtype (likely `uint8`), causing:
   ```
   ValueError: ... expect Tensor with dtype bool but get ...
   Aborted (core dumped)
   ```
   
   This occurs when compiling a minimal PyTorch program (exported with 
`torch.export`) to Relax and then invoking the VM with inputs created through 
`tvm.nd.from_dlpack(to_dlpack(...))`.
   
   ### Actual behavior
   
   The VM entry check (`CheckTensorInfo`) sees a dtype other than `bool` and 
throws an InternalError/ValueError; the blank dtype after "but get" in the 
message below also points to a mis-typed scalar. This suggests the DLPack → 
`tvm.nd.NDArray` path interprets the 0-D `bool` tensor as a non-bool dtype 
(commonly `uint8`), or drops the dtype information entirely for 0-D tensors.
   
   ```
    About to call TVM VM with (x, flag) where flag is 0-d torch.bool via DLPack ...
   terminate called after throwing an instance of 'tvm::runtime::InternalError'
     what():  ...
   ValueError: Check failed: (DataType(ptr->dtype) == dtype) is false:
      ErrorContext(fn=main, loc=param[1], param=flag, annotation=R.Tensor((), dtype="bool"))
     expect Tensor with dtype bool but get
   Aborted (core dumped)
   ```
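
   For context, here is an illustrative sketch (not TVM's actual code) of why an 
exact dtype comparison can fail here. DLPack encodes every dtype as a 
`(type_code, bits, lanes)` triple; the specific triples below are assumptions 
based on the DLPack specification and common PyTorch/TVM conventions:
   
   ```python
   # Illustrative only: DLPack dtype triples, per the DLPack spec.
   from collections import namedtuple
   
   DLDataType = namedtuple("DLDataType", ["code", "bits", "lanes"])
   
   # Type codes from the DLPack specification.
   kDLInt, kDLUInt, kDLFloat = 0, 1, 2
   kDLBool = 6  # added in DLPack v0.8
   
   # What PyTorch plausibly exports for a torch.bool tensor: (kDLBool, 8, 1)
   # on newer versions, or (kDLUInt, 8, 1) on older ones (an assumption here).
   torch_bool_newer = DLDataType(kDLBool, 8, 1)
   torch_bool_older = DLDataType(kDLUInt, 8, 1)
   
   # What TVM's string dtype "bool" has historically meant: a 1-bit uint.
   tvm_bool = DLDataType(kDLUInt, 1, 1)
   
   # An exact triple comparison fails either way, matching the error above.
   assert torch_bool_newer != tvm_bool
   assert torch_bool_older != tvm_bool
   ```
   
   Either disagreement (type code or bit width) would make a strict 
`DataType(ptr->dtype) == dtype` check fail even though both sides mean "boolean".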
   
   ### Environment
   
   + OS: Ubuntu 22.04.4 LTS (x86_64)
   + TVM version: release v0.21.0
   + Python: 3.10.16
   + LLVM: 17.0.6
   
   ### Steps to reproduce
   
   ```python
   # repro_tvm_bool_flag_crash.py
    import torch
    import torch.nn as nn

    from torch.export import export as torch_export
    from torch.utils.dlpack import to_dlpack

    import tvm
    from tvm import relax
    from tvm.relax.frontend.torch import from_exported_program
   
   class TinyCond(nn.Module):
       # Minimal conditional: `flag` is a 0-D boolean
       def forward(self, x, flag):
           t = (x + 1).sum()
           f = (x - 1).sum()
           return torch.where(flag, t, f)
   
   def main():
       torch.manual_seed(0)
       m = TinyCond().eval()
   
       # Inputs: x is regular tensor; flag is 0-D torch.bool scalar
       x = torch.randn(2, 3)
       flag = torch.randint(0, 2, (), dtype=torch.bool)
   
       # Sanity check on PyTorch side
       with torch.inference_mode():
           _ = m(x, flag)
   
       # Export → Relax
       ep = torch_export(m, (x, flag))
       mod = from_exported_program(ep)
   
       # Build for CPU to minimize deps
       target = "llvm"
       dev = tvm.cpu(0)
       exec_mod = relax.build(mod, target=target)
       vm = relax.VirtualMachine(exec_mod, dev)
   
       # Critical step: feed 0-D torch.bool via DLPack
       tvm_x = tvm.nd.from_dlpack(to_dlpack(x))
       tvm_flag = tvm.nd.from_dlpack(to_dlpack(flag))
   
        print("About to call TVM VM with (x, flag) where flag is 0-d torch.bool via DLPack ...")
        # Expect crash: InternalError/ValueError from CheckTensorInfo on dtype mismatch
       vm["main"](tvm_x, tvm_flag)
   
   if __name__ == "__main__":
       main()
   
   ```
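
   The failing boundary check can be pictured with this standalone sketch 
(`check_tensor_info` and the dtype strings are hypothetical stand-ins for TVM's 
internal `CheckTensorInfo`, which lives in C++):
   
   ```python
   # Hypothetical stand-in for the VM's dtype check at the function boundary.
   
   def check_tensor_info(param_name: str, annotated_dtype: str, actual_dtype: str) -> None:
       """Raise, as the VM does, when the incoming dtype differs from the annotation."""
       if actual_dtype != annotated_dtype:
           raise ValueError(
               f"ErrorContext(fn=main, param={param_name}, "
               f'annotation=R.Tensor((), dtype="{annotated_dtype}")): '
               f"expect Tensor with dtype {annotated_dtype} but get {actual_dtype}"
           )
   
   # A torch.bool scalar that arrives over DLPack as an 8-bit integer trips it:
   try:
       check_tensor_info("flag", "bool", "uint8")
   except ValueError as e:
       print(e)
   
   # The same call succeeds once the dtype survives the DLPack round-trip:
   check_tensor_info("flag", "bool", "bool")
   ```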
   
   ### Triage
   
   * needs-triage
   * bug
   

