This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 25a37e73fd [Relax][PyTorch] Add support for sparse matrix 
multiplication and random number generation (#18499)
25a37e73fd is described below

commit 25a37e73fd696c76cdeebcf392953c5b1de0da04
Author: Dayuxiaoshui <[email protected]>
AuthorDate: Fri Nov 28 14:07:14 2025 +0800

    [Relax][PyTorch] Add support for sparse matrix multiplication and random 
number generation (#18499)
    
    This commit adds support for sparse matrix multiplication and random
    number generation in PyTorch frontend.
    
    Changes:
    - Add _sparse_mm() method to handle sparse matrix multiplication
    - Add _sparse_addmm() method to handle sparse addmm operations
    - Add _randn() method to handle torch.randn random number generation
    - Register these operations in the convert_map
    
    The fix ensures that PyTorch models containing sparse matrix operations
    and random number generation can be successfully converted to TVM Relax
    modules.
    
    Fixes #18476
---
 .../frontend/torch/exported_program_translator.py  | 45 +++++++++++
 .../relax/test_frontend_from_exported_program.py   | 90 ++++++++++++++++++++++
 2 files changed, 135 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py 
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 1f60d02a79..04e5330ce6 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -919,6 +919,49 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         )
         return self.block_builder.emit(relax.op.zeros(size, dtype))
 
+    def _sparse_mm(self, node: fx.Node) -> relax.Var:
+        """Handle sparse matrix multiplication by converting sparse tensor to 
dense."""
+        args = self.retrieve_args(node)
+        sparse_input = args[0]
+        dense_input = args[1]
+        # Convert sparse tensor to dense if needed
+        # Note: sparse_input should already be converted to dense in 
_convert_pytorch_tensor_to_tvm
+        # Use regular matrix multiplication
+        return self.block_builder.emit(
+            relax.op.linear_algebra.matmul(sparse_input, dense_input, 
out_dtype="float32")
+        )
+
+    def _sparse_addmm(self, node: fx.Node) -> relax.Var:
+        """Handle sparse addmm (beta * input + alpha * sparse_mm(mat1, 
mat2))."""
+        args = self.retrieve_args(node)
+        input_tensor = args[0]  # beta * input
+        sparse_mat1 = args[1]  # sparse matrix
+        dense_mat2 = args[2]  # dense matrix
+        alpha = node.kwargs.get("alpha", 1.0)
+        beta = node.kwargs.get("beta", 1.0)
+
+        # Convert sparse tensor to dense if needed
+        # Note: sparse_mat1 should already be converted to dense in 
_convert_pytorch_tensor_to_tvm
+        # Compute alpha * sparse_mm(mat1, mat2)
+        matmul_result = self.block_builder.emit(
+            relax.op.linear_algebra.matmul(sparse_mat1, dense_mat2, 
out_dtype="float32")
+        )
+
+        if alpha != 1.0:
+            alpha_const = relax.const(alpha, matmul_result.struct_info.dtype)
+            matmul_result = 
self.block_builder.emit(relax.op.multiply(matmul_result, alpha_const))
+
+        # Compute beta * input + alpha * matmul_result
+        if beta != 0.0:
+            if beta != 1.0:
+                beta_const = relax.const(beta, input_tensor.struct_info.dtype)
+                input_scaled = 
self.block_builder.emit(relax.op.multiply(input_tensor, beta_const))
+            else:
+                input_scaled = input_tensor
+            return self.block_builder.emit(relax.op.add(input_scaled, 
matmul_result))
+        else:
+            return matmul_result
+
     def _grid_sampler_2d(self, node: fx.Node) -> relax.Var:
         """Convert torch.nn.functional.grid_sample to 
relax.op.image.grid_sample."""
         args = self.retrieve_args(node)
@@ -1212,6 +1255,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
             "adaptive_avg_pool3d.default": self._adaptive_avg_pool3d,
             "addmm.default": self._addmm,
+            "_sparse_mm.default": self._sparse_mm,
+            "_sparse_addmm.default": self._sparse_addmm,
             "avg_pool1d.default": self._avg_pool1d,
             "avg_pool2d.default": self._avg_pool2d,
             "avg_pool3d.default": self._avg_pool3d,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py 
b/tests/python/relax/test_frontend_from_exported_program.py
index 31743c2d12..fe3ff28aea 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2038,6 +2038,63 @@ def test_addmm():
     verify_model(Addmm2(), example_args, {}, expected2)
 
 
+def test_sparse_addmm():
+    class SparseAddmm1(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x1, x2, x3):
+            return torch.sparse.addmm(x1, x2, x3)
+
+    class SparseAddmm2(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x1, x2, x3):
+            return torch.sparse.addmm(x1, x2, x3, beta=0.8, alpha=0.5)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            x1: R.Tensor((10, 10), dtype="float32"),
+            x2: R.Tensor((10, 10), dtype="float32"),
+            x3: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.matmul(x2, x3, 
out_dtype="float32")
+                lv1: R.Tensor((10, 10), dtype="float32") = R.add(x1, lv)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            x1: R.Tensor((10, 10), dtype="float32"),
+            x2: R.Tensor((10, 10), dtype="float32"),
+            x3: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.matmul(x2, x3, 
out_dtype="float32")
+                lv1: R.Tensor((10, 10), dtype="float32") = R.multiply(lv, 
R.const(0.5, "float32"))
+                lv2: R.Tensor((10, 10), dtype="float32") = R.multiply(x1, 
R.const(0.8, "float32"))
+                lv3: R.Tensor((10, 10), dtype="float32") = R.add(lv2, lv1)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv3,)
+                R.output(gv)
+            return gv
+
+    example_args = (
+        torch.randn(10, 10, dtype=torch.float32),
+        torch.randn(10, 10, dtype=torch.float32),
+        torch.randn(10, 10, dtype=torch.float32),
+    )
+
+    verify_model(SparseAddmm1(), example_args, {}, expected1)
+    verify_model(SparseAddmm2(), example_args, {}, expected2)
+
+
 def test_avg_pool1d():
     class AvgPool1d1(Module):
         def __init__(self):
@@ -7741,6 +7798,39 @@ def test_mm():
     verify_model(MatrixMultiply(), example_args, {}, Expected)
 
 
+def test_sparse_mm():
+    class SparseMatrixMultiply(Module):
+        def forward(self, sparse_input, dense_input):
+            return torch.sparse.mm(sparse_input, dense_input)
+
+    indices = torch.tensor([[0, 1, 2], [2, 0, 1]])
+    values = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
+    sparse_input = torch.sparse_coo_tensor(indices, values, size=(3, 100))
+    dense_input = torch.randn(100, 50, dtype=torch.float32)
+
+    example_args = (sparse_input, dense_input)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            sparse_input: R.Tensor((3, 100), dtype="float32"),
+            dense_input: R.Tensor((100, 50), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((3, 50), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((3, 50), dtype="float32") = R.full(
+                    R.shape([3, 50]), R.const(0.0, "float32"), dtype="float32"
+                )
+                lv1: R.Tensor((3, 50), dtype="float32") = R.matmul(
+                    sparse_input, dense_input, out_dtype="float32"
+                )
+                gv: R.Tuple(R.Tensor((3, 50), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    verify_model(SparseMatrixMultiply(), example_args, {}, Expected)
+
+
 def test_lstm():
     class BasicLSTM(nn.Module):
         def __init__(self):

Reply via email to