tinywisdom opened a new issue, #19550:
URL: https://github.com/apache/tvm/issues/19550
### Summary
A minimal PyTorch model using integer tensor `pow` can be exported and
converted to Relax, but fails during `tvm.compile` in the default Relax
pipeline.
In PyTorch eager mode, the operation is valid and returns the expected result:
```python
import torch

x = torch.tensor([-1, 1], dtype=torch.int64)
x.pow(4)
# tensor([1, 1])
```
However, after `torch.export` and `from_exported_program`, TVM lowers this
operation to:
```
R.power(x, R.const(4, "int64"))
```
During `relax.transform.LegalizeOps`, this reaches TOPI/TIR power, which
checks that the input dtype is floating-point and raises:
```
InternalError: Check failed: (x.dtype().is_float()) is false: power only
applies to float
```
This suggests that one of the following should change:
1. The PyTorch frontend should not lower integer `aten.pow` directly to
`R.power` if integer power is unsupported by Relax/TOPI legalization.
2. Relax/TOPI should support integer tensor power for integer constant
exponents.
3. The frontend or compiler should report a clear unsupported-dtype
diagnostic instead of failing with an internal assertion during legalization.
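For option 2, one common way to handle a non-negative integer constant exponent is exponentiation by squaring, which needs only multiplications. The sketch below is plain Python on lists, not TVM code. The helper name `int_pow_by_squaring` is hypothetical, and Python's unbounded integers do not model int64 overflow.

```python
# Hypothetical helper (not part of TVM): elementwise integer power for a
# non-negative constant exponent, using only multiplications.
# Note: Python ints are unbounded, so int64 overflow is not modeled here.
from typing import List


def int_pow_by_squaring(values: List[int], exponent: int) -> List[int]:
    """Raise each element of `values` to the constant `exponent`."""
    if exponent < 0:
        # Negative integer exponents have no exact integer result in general.
        raise ValueError("negative exponents are not supported for integer power")
    result = [1] * len(values)  # running product, starts at x**0
    base = list(values)         # holds x**(2**i) after i squarings
    e = exponent
    while e:
        if e & 1:
            # Fold the current power-of-two factor into the result.
            result = [r * b for r, b in zip(result, base)]
        e >>= 1
        if e:
            # Square only when more exponent bits remain.
            base = [b * b for b in base]
    return result


print(int_pow_by_squaring([-1, 1], 4))  # [1, 1], same as x.pow(4) in eager mode
```

For the exponent 4 in this issue, the expansion amounts to two squarings, which would lower to elementwise multiplies that already accept integer dtypes.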
### Actual behavior
`tvm.compile` fails during the default Relax pipeline, specifically in
`LegalizeOps`:
```
InternalError: Check failed: (x.dtype().is_float()) is false: power only
applies to float
```
The relevant part of the stack trace is:
```
tvm.relax.build
-> relax_pipeline(mod)
-> relax.transform.LegalizeOps
-> binary.py: binary_call_te
-> bb.call_te(te_func, arg0, arg1)
-> topi.broadcast.power
-> tvm::pow
-> ICHECK(x.dtype().is_float()) << "power only applies to float"
```
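For illustration only, a guard at this point in legalization could dispatch on dtype and exponent and raise a readable error instead of hitting the `ICHECK`. The function below is a hypothetical plain-Python sketch, not TVM's `LegalizeOps` or TOPI API. The dtype sets and the returned plan strings are assumptions made for the example.

```python
# Hypothetical sketch (not TVM's actual LegalizeOps code): choose how a
# power op could be legalized from its dtype, and fail with a readable
# message instead of an internal assertion.
from typing import Optional

FLOAT_DTYPES = {"float16", "float32", "float64", "bfloat16"}
INT_DTYPES = {"int8", "int16", "int32", "int64",
              "uint8", "uint16", "uint32", "uint64"}


def plan_power_legalization(dtype: str, const_exponent: Optional[int]) -> str:
    """Return a description of how power(x, exponent) could be lowered."""
    if dtype in FLOAT_DTYPES:
        # Existing path: TOPI power accepts floating-point inputs.
        return "topi.power"
    if dtype in INT_DTYPES and const_exponent is not None and const_exponent >= 0:
        # Integer base with a known non-negative exponent: expand into multiplies.
        return f"multiply-chain(exponent={const_exponent})"
    # Anything else gets a user-facing error that names the dtype.
    raise TypeError(
        f"power: unsupported combination dtype={dtype!r}, "
        f"exponent={const_exponent!r}; integer power requires a "
        "non-negative constant exponent"
    )


print(plan_power_legalization("int64", 4))  # multiply-chain(exponent=4)
```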
### Environment
* TVM: 0.23.0
* LLVM: 17.0.6
* Python: 3.10.16 (from stack paths)
* NumPy: 2.2.6
### Steps to reproduce
```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import sys
import platform

import torch
import tvm


class MyModel(torch.nn.Module):
    def forward(self, x):
        return x.pow(4)


def main():
    print("=" * 80)
    print("Environment")
    print("=" * 80)
    print("python:", sys.version.replace("\n", " "))
    print("platform:", platform.platform())
    print("torch:", torch.__version__)
    print("tvm:", getattr(tvm, "__version__", "<unknown>"))
    print("tvm path:", getattr(tvm, "__file__", "<unknown>"))

    model = MyModel().eval()
    x = torch.tensor([-1, 1], dtype=torch.int64)
    with torch.no_grad():
        eager_out = model(x)
    print("=" * 80)
    print("PyTorch eager")
    print("=" * 80)
    print("input:", x)
    print("eager output:", eager_out)

    ep = torch.export.export(model, (x,))
    from tvm.relax.frontend.torch import from_exported_program

    ir_mod = from_exported_program(ep)
    print("=" * 80)
    print("Exported Relax IR")
    print("=" * 80)
    print(ir_mod.script(show_meta=True))

    print("=" * 80)
    print("tvm.compile")
    print("=" * 80)
    target = tvm.target.Target("llvm")
    ex = tvm.compile(
        ir_mod,
        target=target,
        relax_pipeline="default",
        tir_pipeline="default",
    )
    print("compile: OK")
    print(ex)


if __name__ == "__main__":
    main()
```
### Triage
* needs-triage
* bug