This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7e6165e896 Fix BufferError when converting PyTorch models with sparse
tensors (#18492)
7e6165e896 is described below
commit 7e6165e8962cb2d8bf9f5d0709e56db635378e6c
Author: Dayuxiaoshui <[email protected]>
AuthorDate: Mon Nov 24 03:28:53 2025 +0800
Fix BufferError when converting PyTorch models with sparse tensors (#18492)
This commit fixes issue #18474 by adding support for sparse tensor
conversion in the PyTorch ExportedProgram importer. The fix
automatically detects sparse tensors (non-strided layout) and converts
them to dense tensors before DLPack conversion.
Changes:
- Add _convert_pytorch_tensor_to_tvm() static method to handle sparse
tensor detection and conversion
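- Automatically convert sparse tensors to dense using .to_dense() before
DLPack conversion
- Fall back to a NumPy-based copy (.cpu().contiguous().numpy()) when DLPack
conversion raises RuntimeError or BufferError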
- Update parameter/buffer/constant binding to use the new conversion
method
- Update the keep_params_as_input path (the "params" attribute attached to
"main") to use the new conversion method
The fix ensures that PyTorch models containing sparse tensors can be
successfully converted to TVM Relax modules without raising BufferError.
Fixes #18474
---
.../frontend/torch/exported_program_translator.py | 38 ++++++++++++++++++----
1 file changed, 32 insertions(+), 6 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index d7975a8dde..883be88837 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -34,6 +34,36 @@ class ExportedProgramImporter(BaseFXGraphImporter):
from torch import fx
+ @staticmethod
+ def _convert_pytorch_tensor_to_tvm(tensor_value: torch.Tensor) ->
tvm.runtime.Tensor:
+ """Convert a PyTorch tensor to TVM tensor, handling sparse tensors.
+
+ Parameters
+ ----------
+ tensor_value : torch.Tensor
+ The PyTorch tensor to convert.
+
+ Returns
+ -------
+ tvm.runtime.Tensor
+ The converted TVM tensor.
+ """
+ # PyTorch sparse tensors (layout != torch.strided) must be converted
to dense.
+ if tensor_value.layout != torch.strided:
+ tensor_to_convert = tensor_value.to_dense()
+ else:
+ tensor_to_convert = tensor_value
+ tensor_detached = tensor_to_convert.detach()
+
+ # Try DLPack conversion first (faster)
+ try:
+ return tvm.runtime.from_dlpack(tensor_detached)
+ except (RuntimeError, BufferError):
+ # Fallback: convert to numpy and then to TVM tensor
+ # This handles cases where DLPack conversion fails
+ tensor_cpu = tensor_detached.cpu().contiguous()
+ return tvm.runtime.tensor(tensor_cpu.numpy())
+
########## Unary Ops ##########
def _hardtanh(self, node: fx.Node) -> relax.Expr:
@@ -1502,18 +1532,14 @@ class ExportedProgramImporter(BaseFXGraphImporter):
if tensor_name == spec.target:
bind_name = spec.arg.name
break
- try:
- binding[bind_name] =
tvm.runtime.from_dlpack(tensor_value.detach())
- except RuntimeError:
- tensor_cpu = tensor_value.detach().cpu().contiguous()
- binding[bind_name] = tvm.runtime.tensor(tensor_cpu.numpy())
+ binding[bind_name] =
self._convert_pytorch_tensor_to_tvm(tensor_value)
mod = self.block_builder.get()
mod = relax.transform.BindParams("main", binding)(mod)
if keep_params_as_input:
parameters = dict(exported_program.named_parameters())
- params = [tvm.runtime.from_dlpack(p.detach()) for p in
parameters.values()]
+ params = [self._convert_pytorch_tensor_to_tvm(p) for p in
parameters.values()]
mod["main"] = mod["main"].with_attr("params", params)
return mod