This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 81f7da8a7f [Relax][PyTorch] Support prod, std and var ops for ExportedProgram importer (#17785)
81f7da8a7f is described below
commit 81f7da8a7f55ab54898cac1bef401bc182364125
Author: Shushi Hong <[email protected]>
AuthorDate: Fri Mar 28 10:16:07 2025 +0800
[Relax][PyTorch] Support prod, std and var ops for ExportedProgram importer
(#17785)
* Update exported_program_translator.py
* Update test_frontend_from_exported_program.py
* Update test_frontend_from_exported_program.py
* Update test_frontend_from_exported_program.py
* Update test_frontend_from_exported_program.py
---
.../frontend/torch/exported_program_translator.py | 3 ++
.../relax/test_frontend_from_exported_program.py | 63 ++++++++++++++++++++++
2 files changed, 66 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 4319fbebe7..0f1dc11787 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -353,7 +353,10 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"upsample_nearest2d.vec": self._upsample_nearest2d,
# statistical
"mean.dim": self._mean,
+ "prod.default": self._prod,
+ "std.correction": self._std,
"sum.dim_IntList": self._sum,
+ "var.correction": self._var,
# search
"argmax.default": self._argmax_argmin(relax.op.argmax),
"argmin.default": self._argmax_argmin(relax.op.argmin),
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index b56ffc7d2e..739fe87dc9 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3704,5 +3704,68 @@ def test_take():
verify_model(Take(), example_args, {}, Expected)
+def test_std():
+ class Std(Module):
+ def forward(self, x):
+ return torch.std(x)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ inp_0: R.Tensor((5, 3), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((), dtype="float32") = R.std(inp_0, axis=None, keepdims=False)
+ gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(5, 3, dtype=torch.float32),)
+ verify_model(Std(), example_args, {}, Expected)
+
+
+def test_var():
+ class Var(Module):
+ def forward(self, x):
+ return torch.var(x)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ inp_0: R.Tensor((5, 3), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((), dtype="float32") = R.variance(inp_0, axis=None, keepdims=False)
+ gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(5, 3, dtype=torch.float32),)
+ verify_model(Var(), example_args, {}, Expected)
+
+
+def test_prod():
+ class Prod(Module):
+ def forward(self, x):
+ return torch.prod(x)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ inp_0: R.Tensor((5, 3), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((), dtype="float32") = R.prod(inp_0, axis=None, keepdims=False)
+ gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(5, 3, dtype=torch.float32),)
+ verify_model(Prod(), example_args, {}, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()