tinywisdom opened a new issue, #18648:
URL: https://github.com/apache/tvm/issues/18648
### Summary
When using the TVM Relax Torch frontend to import a program produced by
`torch.export.export`, where the model registers a `torch.sparse_csr_tensor` as
a buffer, `from_exported_program` aborts the process with an uncaught C++
`c10::Error`:
> `layout_impl is only implemented for TensorImpl subclasses.`
`torch.export.export` itself succeeds. The crash happens during
`from_exported_program(ep)`.
---
### Environment
```text
OS: Linux x86_64
PyTorch: 2.9.0+cu128
TVM: 0.22.0
Python: 3.10.x
```
---
### Minimal Reproduction
```python
# repro_92022_export_tvm_frontend.py
import torch
import torch.nn as nn

import tvm
from tvm.relax.frontend.torch import from_exported_program

print("torch version:", torch.__version__)
print("tvm version:", getattr(tvm, "__version__", "unknown"))


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        crow_indices = torch.tensor([0, 1, 2], dtype=torch.int64)
        col_indices = torch.tensor([0, 1], dtype=torch.int64)
        values = torch.tensor([1.0, 1.0], dtype=torch.float32,
                              requires_grad=True)
        csr_tensor = torch.sparse_csr_tensor(
            crow_indices, col_indices, values, dtype=torch.float32
        )
        # Register sparse CSR tensor as a buffer
        self.register_buffer("csr_tensor", csr_tensor)
        # Explicitly enable grad as well
        self.csr_tensor.requires_grad_(True)

    def forward(self, x):
        # Convert buffer to sparse CSR layout again
        csr2 = self.csr_tensor.to_sparse(layout=torch.sparse_csr)
        y = torch.matmul(csr2, x)
        return y.sum()


def GetInput():
    return torch.ones((2, 1), dtype=torch.float32)


def main():
    model = MyModel().to("cpu").eval()
    x = GetInput().to("cpu")

    print("Start torch.export.export ...")
    ep = torch.export.export(model, (x,))
    print("torch.export.export done.")

    print("Start from_exported_program ...")
    ir_mod = from_exported_program(ep)
    print("from_exported_program done.")
    print(ir_mod)


if __name__ == "__main__":
    main()
```
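For reference, the CSR arrays in `MyModel.__init__` describe a 2x2 identity matrix (row `i` holds `values[crow_indices[i]:crow_indices[i + 1]]` at columns `col_indices[...]`, and the size is inferred as 2x2), so `forward(GetInput())` should evaluate to 2.0; that is the value a successful import would be expected to reproduce. The pure-Python sketch below checks this arithmetic without PyTorch. The helpers `csr_to_dense` and `matmul` are hypothetical names written for illustration only; they are not PyTorch or TVM APIs.

```python
def csr_to_dense(crow_indices, col_indices, values, n_cols):
    """Expand CSR arrays into a row-major list-of-lists dense matrix."""
    n_rows = len(crow_indices) - 1
    dense = [[0.0] * n_cols for _ in range(n_rows)]
    for row in range(n_rows):
        # Nonzeros of this row live in values[crow[row]:crow[row + 1]].
        for k in range(crow_indices[row], crow_indices[row + 1]):
            dense[row][col_indices[k]] = values[k]
    return dense


def matmul(a, b):
    """Naive dense product of two list-of-lists matrices."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


# Same arrays as in MyModel.__init__; 2 columns = max(col_indices) + 1.
A = csr_to_dense([0, 1, 2], [0, 1], [1.0, 1.0], n_cols=2)
x = [[1.0], [1.0]]  # mirrors GetInput(): torch.ones((2, 1))
y = matmul(A, x)
result = sum(v for row in y for v in row)  # mirrors y.sum()
print(A)       # [[1.0, 0.0], [0.0, 1.0]]
print(result)  # 2.0
```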
---
### Actual Behavior
Output on my machine:
```text
torch version: 2.9.0+cu128
tvm version: 0.22.0
.../SparseCsrTensorImpl.cpp:53: Sparse CSR tensor support is in beta state...
Start torch.export.export ...
torch.export.export done.
Start from_exported_program ...
terminate called after throwing an instance of 'c10::Error'
what(): layout_impl is only implemented for TensorImpl subclasses.
Exception raised from layout_impl at /pytorch/c10/core/TensorImpl.h:1094 (most recent call first):
frame #0: c10::Error::Error(...)
frame #1: c10::detail::torchCheckFail(...)
frame #2: ...
frame #3: torch::autograd::InputMetadata::InputMetadata(at::Tensor const&) + ...
frame #4: ...
<libtorch / libtorch_python frames omitted>
...
Aborted (core dumped)
```
`torch.export.export` completes successfully; the abort occurs only when
calling `from_exported_program(ep)`.
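Because the error escapes as an uncaught C++ exception, the interpreter itself terminates (`terminate called ...`, `Aborted (core dumped)`), so a Python `try`/`except` around `from_exported_program(ep)` cannot intercept it. If this case is tracked in a regression test, one option is to run the repro in a child process and inspect its exit status. A minimal sketch, assuming a POSIX system and the repro saved under the filename from its header comment; `run_repro` is a hypothetical helper, not part of TVM's test suite.

```python
import signal
import subprocess
import sys


def run_repro(script_path, timeout=600):
    """Run a script in a child interpreter and classify the outcome.

    A C++ abort (SIGABRT) kills only the child, so the caller survives
    and can inspect the exit status. On POSIX, subprocess reports a
    child killed by signal N as returncode -N.
    """
    proc = subprocess.run(
        [sys.executable, script_path],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    if proc.returncode == -signal.SIGABRT:
        status = "aborted"
    elif proc.returncode == 0:
        status = "ok"
    else:
        status = "failed"
    return status, proc.returncode, proc.stdout + proc.stderr


if __name__ == "__main__":
    status, code, output = run_repro("repro_92022_export_tvm_frontend.py")
    print(f"status={status} returncode={code}")
```

With the behavior reported above, the child would be expected to end with status `aborted` rather than raising a Python exception in the parent.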
### Triage
Please refer to the list of label tags
[here](https://github.com/apache/tvm/wiki/Issue-Triage-Labels) to find the
relevant tags and add them below in a bullet format (example below).
* needs-triage
* bug
--
This is an automated message from the Apache Git Service.