This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 25a37e73fd [Relax][PyTorch] Add support for sparse matrix multiplication and random number generation (#18499)
25a37e73fd is described below
commit 25a37e73fd696c76cdeebcf392953c5b1de0da04
Author: Dayuxiaoshui <[email protected]>
AuthorDate: Fri Nov 28 14:07:14 2025 +0800
[Relax][PyTorch] Add support for sparse matrix multiplication and random number generation (#18499)
This commit adds support for sparse matrix multiplication and random
number generation in the PyTorch frontend.
Changes:
- Add _sparse_mm() method to handle sparse matrix multiplication
- Add _sparse_addmm() method to handle sparse addmm operations
- Add _randn() method to handle torch.randn random number generation
- Register these operations in the convert_map
The fix ensures that PyTorch models containing sparse matrix operations
or random number generation can be successfully converted to TVM Relax
modules.
Fixes #18476
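For context, a minimal end-to-end sketch of how the new conversion paths are
exercised. The module, shapes, and values below are illustrative only;
from_exported_program is the existing Relax frontend entry point, and whether
torch.export.export accepts sparse COO inputs directly may depend on the
PyTorch version:

    import torch
    from tvm.relax.frontend.torch import from_exported_program

    class SparseMM(torch.nn.Module):
        def forward(self, sparse_input, dense_input):
            return torch.sparse.mm(sparse_input, dense_input)

    # Illustrative COO sparse matrix and dense right-hand side.
    indices = torch.tensor([[0, 1], [1, 0]])
    values = torch.tensor([1.0, 2.0])
    sparse_input = torch.sparse_coo_tensor(indices, values, size=(2, 2))
    dense_input = torch.randn(2, 2)

    exported = torch.export.export(SparseMM(), (sparse_input, dense_input))
    mod = from_exported_program(exported)  # TVM Relax IRModule
    mod.show()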
---
.../frontend/torch/exported_program_translator.py | 45 +++++++++++
.../relax/test_frontend_from_exported_program.py | 90 ++++++++++++++++++++++
2 files changed, 135 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 1f60d02a79..04e5330ce6 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -919,6 +919,49 @@ class ExportedProgramImporter(BaseFXGraphImporter):
        )
        return self.block_builder.emit(relax.op.zeros(size, dtype))
+    def _sparse_mm(self, node: fx.Node) -> relax.Var:
+        """Handle sparse matrix multiplication by converting sparse tensor to dense."""
+        args = self.retrieve_args(node)
+        sparse_input = args[0]
+        dense_input = args[1]
+        # Convert sparse tensor to dense if needed
+        # Note: sparse_input should already be converted to dense in _convert_pytorch_tensor_to_tvm
+        # Use regular matrix multiplication
+        return self.block_builder.emit(
+            relax.op.linear_algebra.matmul(sparse_input, dense_input, out_dtype="float32")
+        )
+
+    def _sparse_addmm(self, node: fx.Node) -> relax.Var:
+        """Handle sparse addmm (beta * input + alpha * sparse_mm(mat1, mat2))."""
+        args = self.retrieve_args(node)
+        input_tensor = args[0]  # beta * input
+        sparse_mat1 = args[1]  # sparse matrix
+        dense_mat2 = args[2]  # dense matrix
+        alpha = node.kwargs.get("alpha", 1.0)
+        beta = node.kwargs.get("beta", 1.0)
+
+        # Convert sparse tensor to dense if needed
+        # Note: sparse_mat1 should already be converted to dense in _convert_pytorch_tensor_to_tvm
+        # Compute alpha * sparse_mm(mat1, mat2)
+        matmul_result = self.block_builder.emit(
+            relax.op.linear_algebra.matmul(sparse_mat1, dense_mat2, out_dtype="float32")
+        )
+
+        if alpha != 1.0:
+            alpha_const = relax.const(alpha, matmul_result.struct_info.dtype)
+            matmul_result = self.block_builder.emit(relax.op.multiply(matmul_result, alpha_const))
+
+        # Compute beta * input + alpha * matmul_result
+        if beta != 0.0:
+            if beta != 1.0:
+                beta_const = relax.const(beta, input_tensor.struct_info.dtype)
+                input_scaled = self.block_builder.emit(relax.op.multiply(input_tensor, beta_const))
+            else:
+                input_scaled = input_tensor
+            return self.block_builder.emit(relax.op.add(input_scaled, matmul_result))
+        else:
+            return matmul_result
+
    def _grid_sampler_2d(self, node: fx.Node) -> relax.Var:
        """Convert torch.nn.functional.grid_sample to relax.op.image.grid_sample."""
        args = self.retrieve_args(node)
@@ -1212,6 +1255,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
"adaptive_avg_pool3d.default": self._adaptive_avg_pool3d,
"addmm.default": self._addmm,
+ "_sparse_mm.default": self._sparse_mm,
+ "_sparse_addmm.default": self._sparse_addmm,
"avg_pool1d.default": self._avg_pool1d,
"avg_pool2d.default": self._avg_pool2d,
"avg_pool3d.default": self._avg_pool3d,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 31743c2d12..fe3ff28aea 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2038,6 +2038,63 @@ def test_addmm():
    verify_model(Addmm2(), example_args, {}, expected2)
+def test_sparse_addmm():
+    class SparseAddmm1(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x1, x2, x3):
+            return torch.sparse.addmm(x1, x2, x3)
+
+    class SparseAddmm2(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x1, x2, x3):
+            return torch.sparse.addmm(x1, x2, x3, beta=0.8, alpha=0.5)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            x1: R.Tensor((10, 10), dtype="float32"),
+            x2: R.Tensor((10, 10), dtype="float32"),
+            x3: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.matmul(x2, x3, out_dtype="float32")
+                lv1: R.Tensor((10, 10), dtype="float32") = R.add(x1, lv)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            x1: R.Tensor((10, 10), dtype="float32"),
+            x2: R.Tensor((10, 10), dtype="float32"),
+            x3: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.matmul(x2, x3, out_dtype="float32")
+                lv1: R.Tensor((10, 10), dtype="float32") = R.multiply(lv, R.const(0.5, "float32"))
+                lv2: R.Tensor((10, 10), dtype="float32") = R.multiply(x1, R.const(0.8, "float32"))
+                lv3: R.Tensor((10, 10), dtype="float32") = R.add(lv2, lv1)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv3,)
+                R.output(gv)
+            return gv
+
+    example_args = (
+        torch.randn(10, 10, dtype=torch.float32),
+        torch.randn(10, 10, dtype=torch.float32),
+        torch.randn(10, 10, dtype=torch.float32),
+    )
+
+    verify_model(SparseAddmm1(), example_args, {}, expected1)
+    verify_model(SparseAddmm2(), example_args, {}, expected2)
+
+
def test_avg_pool1d():
    class AvgPool1d1(Module):
        def __init__(self):
@@ -7741,6 +7798,39 @@ def test_mm():
    verify_model(MatrixMultiply(), example_args, {}, Expected)
+def test_sparse_mm():
+    class SparseMatrixMultiply(Module):
+        def forward(self, sparse_input, dense_input):
+            return torch.sparse.mm(sparse_input, dense_input)
+
+    indices = torch.tensor([[0, 1, 2], [2, 0, 1]])
+    values = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
+    sparse_input = torch.sparse_coo_tensor(indices, values, size=(3, 100))
+    dense_input = torch.randn(100, 50, dtype=torch.float32)
+
+    example_args = (sparse_input, dense_input)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            sparse_input: R.Tensor((3, 100), dtype="float32"),
+            dense_input: R.Tensor((100, 50), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((3, 50), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((3, 50), dtype="float32") = R.full(
+                    R.shape([3, 50]), R.const(0.0, "float32"), dtype="float32"
+                )
+                lv1: R.Tensor((3, 50), dtype="float32") = R.matmul(
+                    sparse_input, dense_input, out_dtype="float32"
+                )
+                gv: R.Tuple(R.Tensor((3, 50), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    verify_model(SparseMatrixMultiply(), example_args, {}, Expected)
+
+
def test_lstm():
    class BasicLSTM(nn.Module):
        def __init__(self):