tinywisdom opened a new issue, #19551:
URL: https://github.com/apache/tvm/issues/19551
### Summary
A minimal PyTorch model using `torch.prod(x, dtype=torch.bool)` can be
exported to Relax, but fails during `tvm.compile(..., target="llvm")`.
PyTorch eager handles this case successfully:
```python
import torch

x = torch.zeros((1, 1, 16, 16), dtype=torch.bool)
torch.prod(x, dtype=torch.bool)
# tensor(False)
```
After `torch.export` and `from_exported_program`, TVM produces a Relax
program containing:
```
R.prod(x, axis=None, keepdims=False)
```
where `x` has dtype `bool`. However, `tvm.compile` fails during LLVM code
generation with:
```
InternalError: Check failed: (t.is_float()) is false:
```
The stack trace shows that the failure occurs in `CodeGenLLVM::CreateMul`,
suggesting that the bool reduction is lowered to a multiplication-based
reduction over bool values. For a bool `prod`, this should instead be lowered
to a valid logical-AND-style reduction, cast to a supported integer
representation, or rejected earlier with a clear unsupported-dtype diagnostic.
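For bool inputs, the first two strategies listed above give the same result. A product over values restricted to {0, 1} is 1 exactly when every value is 1, which is a logical AND (`all`). The following minimal pure-Python sketch checks this exhaustively for small inputs. It does not use TVM or PyTorch, and the helper names `bool_prod_via_and` and `bool_prod_via_int_cast` are hypothetical and for illustration only.

```python
import itertools
import math
import operator
from functools import reduce


def bool_prod_via_and(values):
    """Hypothetical helper: logical-AND reduction with identity True."""
    return reduce(operator.and_, values, True)


def bool_prod_via_int_cast(values):
    """Hypothetical helper: cast each bool to int, multiply, cast the product back to bool."""
    return bool(math.prod(int(v) for v in values))


# Exhaustively compare both strategies against all() for every 4-element bool input.
for bits in itertools.product([False, True], repeat=4):
    assert bool_prod_via_and(bits) == bool_prod_via_int_cast(bits) == all(bits)

# The all-False (1, 1, 16, 16) input from this report, flattened to 256 elements.
zeros = [False] * (1 * 1 * 16 * 16)
print(bool_prod_via_and(zeros))       # False
print(bool_prod_via_int_cast(zeros))  # False
```

Both helpers return `False` for the reported input, which matches PyTorch eager's `tensor(False)`. The AND form keeps the reduction entirely in the bool domain. The cast form would reuse an integer multiply path at the cost of two casts. Which of these fits TVM's lowering better is not established here.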
### Expected behavior
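`tvm.compile` should either succeed, producing `False` as PyTorch eager does,
or reject the bool `prod` with a clear unsupported-dtype diagnostic. For
reference, the exported Relax IR applies `R.prod` to a bool tensor: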
```
@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((1, 1, 16, 16), dtype="bool")) -> R.Tuple(R.Tensor((), dtype="bool")):
        with R.dataflow():
            lv: R.Tensor((), dtype="bool") = R.prod(x, axis=None, keepdims=False)
            gv: R.Tuple(R.Tensor((), dtype="bool")) = (lv,)
            R.output(gv)
        return gv
```
### Actual behavior
`tvm.compile` fails for the LLVM target.
The relevant part of the stack trace is:
```
tvm.tir.build
-> codegen_build
-> CodeGenLLVM::AddFunctionInternal
-> CodeGenLLVM::VisitStmt_(BufferStoreNode)
-> CodeGenLLVM::MakeValue
-> CodeGenLLVM::VisitExpr_(CastNode)
-> CodeGenLLVM::CreateMul
-> InternalError: Check failed: (t.is_float()) is false:
```
Full observed failure:
```
tvm.error.InternalError: Check failed: (t.is_float()) is false:
```
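The message above stops right after `Check failed: (t.is_float()) is false:` and gives no further explanation. The third option from the summary is to reject a bool `prod` earlier with a clear diagnostic. The minimal sketch below shows what that could look like. It is plain Python, not TVM code. `UnsupportedDtypeError` and `check_prod_input_dtype` are hypothetical names. Where such a check would live in TVM (frontend, legalization, or TIR lowering) is left open.

```python
class UnsupportedDtypeError(TypeError):
    """Hypothetical error type; not part of TVM."""


def check_prod_input_dtype(dtype: str) -> None:
    """Hypothetical early check: reject bool before it reaches a multiply-based codegen path."""
    if dtype == "bool":
        raise UnsupportedDtypeError(
            "prod: input dtype 'bool' is not supported by the multiplication-based "
            "reduction; cast the input to an integer dtype or use a logical-AND "
            "reduction instead"
        )


# Usage: a bool input fails with an actionable message instead of an empty check.
try:
    check_prod_input_dtype("bool")
except UnsupportedDtypeError as err:
    print(f"rejected: {err}")

check_prod_input_dtype("int32")  # passes silently
```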
### Environment
* TVM: 0.23.0
* LLVM: 17.0.6
* Python: 3.10.16 (from stack paths)
* NumPy: 2.2.6
### Steps to reproduce
```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import sys
import platform
import traceback

import torch
import tvm


class MyModel(torch.nn.Module):
    def forward(self, x):
        return torch.prod(x, dtype=torch.bool)


def main():
    print("=" * 80)
    print("Environment")
    print("=" * 80)
    print("python:", sys.version.replace("\n", " "))
    print("platform:", platform.platform())
    print("torch:", torch.__version__)
    print("tvm:", getattr(tvm, "__version__", "<unknown>"))
    print("tvm path:", getattr(tvm, "__file__", "<unknown>"))

    model = MyModel().eval()
    x = torch.zeros((1, 1, 16, 16), dtype=torch.bool)
    with torch.no_grad():
        eager = model(x)

    print("=" * 80)
    print("PyTorch eager")
    print("=" * 80)
    print("input shape:", tuple(x.shape), "dtype:", x.dtype)
    print("eager:", eager, eager.dtype, eager.shape)

    ep = torch.export.export(model, (x,))
    from tvm.relax.frontend.torch import from_exported_program

    ir_mod = from_exported_program(ep)
    print("=" * 80)
    print("Exported Relax IR")
    print("=" * 80)
    print(ir_mod.script(show_meta=True))

    print("=" * 80)
    print("tvm.compile with LLVM")
    print("=" * 80)
    ex = tvm.compile(
        ir_mod,
        target=tvm.target.Target("llvm"),
        relax_pipeline="default",
        tir_pipeline="default",
    )
    print("compile: OK")
    print(ex)


if __name__ == "__main__":
    try:
        main()
    except Exception:
        print("compile: FAILED")
        traceback.print_exc()
```
### Triage
* needs-triage
* bug
--
This is an automated message from the Apache Git Service.