This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e7bcf17e4f [Relax][PyTorch] Support MatrixMultiply op for ExportedProgram importer (#18343)
e7bcf17e4f is described below

commit e7bcf17e4f18e82ee4fd65d8caee10b72a1386bd
Author: Shushi Hong <[email protected]>
AuthorDate: Thu Sep 25 15:05:21 2025 -0400

    [Relax][PyTorch] Support MatrixMultiply op for ExportedProgram importer (#18343)
    
    This PR adds support for `mm.default` in the ExportedProgram importer.
    
    Resolves issue #18339.
---
 .../frontend/torch/exported_program_translator.py  |  3 +++
 .../relax/test_frontend_from_exported_program.py   | 26 ++++++++++++++++++++++
 2 files changed, 29 insertions(+)
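
For reference, a minimal end-to-end sketch of how the new mapping can be exercised through the public importer API (the module and shapes are illustrative, not part of this commit):

    import torch
    from torch.export import export
    from tvm.relax.frontend.torch import from_exported_program

    class MatrixMultiply(torch.nn.Module):
        def forward(self, a, b):
            # Exports as aten's mm.default, which this commit maps to
            # R.matmul(..., out_dtype="float32") in Relax.
            return torch.mm(a, b)

    # Example inputs only; any 2-D float32 operands with matching inner dims work.
    example_args = (torch.randn(2, 3), torch.randn(3, 4))
    exported_program = export(MatrixMultiply(), example_args)
    mod = from_exported_program(exported_program)
    mod.show()  # prints the imported Relax IRModule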

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 7c20d1b1a4..3cf07effec 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -434,6 +434,9 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "matmul.default": self._binary_op(
                 partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul
             ),
+            "mm.default": self._binary_op(
+                partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul
+            ),
             "max.other": self._binary_op(relax.op.maximum, max),
             "min.other": self._binary_op(relax.op.minimum, min),
             "max.default": self._unary_op(relax.op.max),
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 2871e3f4cd..ead341de28 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -5914,6 +5914,32 @@ def test_dtypes(torch_dtype, relax_dtype):
     verify_model(Model(), example_args, {}, Expected)
 
 
+def test_mm():
+    class MatrixMultiply(Module):
+        def forward(self, a, b):
+            return torch.mm(a, b)
+
+    example_args = (
+        torch.randn(2, 3, dtype=torch.float32),
+        torch.randn(3, 4, dtype=torch.float32),
+    )
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            a: R.Tensor((2, 3), dtype="float32"),
+            b: R.Tensor((3, 4), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2, 4), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2, 4), dtype="float32") = R.matmul(a, b, out_dtype="float32")
+                gv: R.Tuple(R.Tensor((2, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(MatrixMultiply(), example_args, {}, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
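
Design note: the new "mm.default" entry reuses the same _binary_op conversion as "matmul.default", so torch.mm(a, b) lowers to a single R.matmul call with out_dtype="float32", matching the Expected module in the test above. A minimal sketch of the equivalent Relax construction (variable names and shapes are illustrative):

    from tvm import relax

    # Illustrative only: the expression the converter is expected to build for mm.
    a = relax.Var("a", relax.TensorStructInfo((2, 3), "float32"))
    b = relax.Var("b", relax.TensorStructInfo((3, 4), "float32"))
    out = relax.op.matmul(a, b, out_dtype="float32")  # same call as registered in the converter table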