This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e7bcf17e4f [Relax][PyTorch] Support MatrixMultiply op for ExportedProgram importer (#18343)
e7bcf17e4f is described below
commit e7bcf17e4f18e82ee4fd65d8caee10b72a1386bd
Author: Shushi Hong <[email protected]>
AuthorDate: Thu Sep 25 15:05:21 2025 -0400
[Relax][PyTorch] Support MatrixMultiply op for ExportedProgram importer (#18343)
This PR adds support for `mm.default` (`torch.mm`) to the ExportedProgram importer.
Resolves issue #18339.
---
.../frontend/torch/exported_program_translator.py | 3 +++
.../relax/test_frontend_from_exported_program.py | 26 ++++++++++++++++++++++
2 files changed, 29 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 7c20d1b1a4..3cf07effec 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -434,6 +434,9 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"matmul.default": self._binary_op(
partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul
),
+ "mm.default": self._binary_op(
+ partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul
+ ),
"max.other": self._binary_op(relax.op.maximum, max),
"min.other": self._binary_op(relax.op.minimum, min),
"max.default": self._unary_op(relax.op.max),
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 2871e3f4cd..ead341de28 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -5914,6 +5914,32 @@ def test_dtypes(torch_dtype, relax_dtype):
verify_model(Model(), example_args, {}, Expected)
+def test_mm():
+ class MatrixMultiply(Module):
+ def forward(self, a, b):
+ return torch.mm(a, b)
+
+ example_args = (
+ torch.randn(2, 3, dtype=torch.float32),
+ torch.randn(3, 4, dtype=torch.float32),
+ )
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ a: R.Tensor((2, 3), dtype="float32"),
+ b: R.Tensor((3, 4), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((2, 4), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((2, 4), dtype="float32") = R.matmul(a, b, out_dtype="float32")
+ gv: R.Tuple(R.Tensor((2, 4), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ verify_model(MatrixMultiply(), example_args, {}, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()
1