This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 81f7da8a7f [Relax][PyTorch] Support prod, std and var ops for ExportedProgram importer (#17785)
81f7da8a7f is described below

commit 81f7da8a7f55ab54898cac1bef401bc182364125
Author: Shushi Hong <[email protected]>
AuthorDate: Fri Mar 28 10:16:07 2025 +0800

    [Relax][PyTorch] Support prod, std and var ops for ExportedProgram importer (#17785)
    
    * Update exported_program_translator.py
    
    * Update test_frontend_from_exported_program.py
    
    * Update test_frontend_from_exported_program.py
    
    * Update test_frontend_from_exported_program.py
    
    * Update test_frontend_from_exported_program.py
---
 .../frontend/torch/exported_program_translator.py  |  3 ++
 .../relax/test_frontend_from_exported_program.py   | 63 ++++++++++++++++++++++
 2 files changed, 66 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 4319fbebe7..0f1dc11787 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -353,7 +353,10 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "upsample_nearest2d.vec": self._upsample_nearest2d,
             # statistical
             "mean.dim": self._mean,
+            "prod.default": self._prod,
+            "std.correction": self._std,
             "sum.dim_IntList": self._sum,
+            "var.correction": self._var,
             # search
             "argmax.default": self._argmax_argmin(relax.op.argmax),
             "argmin.default": self._argmax_argmin(relax.op.argmin),
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index b56ffc7d2e..739fe87dc9 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3704,5 +3704,68 @@ def test_take():
     verify_model(Take(), example_args, {}, Expected)
 
 
+def test_std():
+    class Std(Module):
+        def forward(self, x):
+            return torch.std(x)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.std(inp_0, axis=None, keepdims=False)
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(5, 3, dtype=torch.float32),)
+    verify_model(Std(), example_args, {}, Expected)
+
+
+def test_var():
+    class Var(Module):
+        def forward(self, x):
+            return torch.var(x)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.variance(inp_0, axis=None, keepdims=False)
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(5, 3, dtype=torch.float32),)
+    verify_model(Var(), example_args, {}, Expected)
+
+
+def test_prod():
+    class Prod(Module):
+        def forward(self, x):
+            return torch.prod(x)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.prod(inp_0, axis=None, keepdims=False)
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(5, 3, dtype=torch.float32),)
+    verify_model(Prod(), example_args, {}, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to