This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new adce5a4745 [Relax][PyTorch] Add support for prod, std and var ops (#17772)
adce5a4745 is described below

commit adce5a4745d55cca05dc26abfa03b444ef2a4810
Author: Shushi Hong <[email protected]>
AuthorDate: Sat Mar 22 16:02:35 2025 +0800

    [Relax][PyTorch] Add support for prod, std and var ops (#17772)
    
    * Update fx_translator.py
    
    * Update test_frontend_from_fx.py
    
    * Update base_fx_graph_translator.py
---
 .../frontend/torch/base_fx_graph_translator.py     | 21 ++++++++
 python/tvm/relax/frontend/torch/fx_translator.py   |  3 ++
 tests/python/relax/test_frontend_from_fx.py        | 60 ++++++++++++++++++++++
 3 files changed, 84 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 6bbc9d5de6..4bfdb8c1bc 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -908,6 +908,20 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
         return self.block_builder.emit(relax.op.mean(x, dim, keepdims=keepdim))
 
+    def _prod(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
+        return self.block_builder.emit(relax.op.prod(x, dim, keepdims=keepdim))
+
+    def _std(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
+        return self.block_builder.emit(relax.op.std(x, dim, keepdims=keepdim))
+
     def _sum(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False
@@ -915,6 +929,13 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
             return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim))
         return self.block_builder.emit(relax.op.sum(args[0], args[1]))
 
+    def _var(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
+        return self.block_builder.emit(relax.op.variance(x, dim, keepdims=keepdim))
+
     ########## Search ##########
 
     def _argmax_argmin(self, op: Callable) -> Callable:
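
The three new handlers share one argument-resolution pattern: an FX call node
may carry `dim` and `keepdim` either positionally or as keyword arguments, so
each handler reads positional args first, then falls back to kwargs, then to
the PyTorch defaults. A standalone sketch of that pattern (the helper name is
hypothetical, and it reads raw node.args rather than the importer's
retrieve_args output):

    import torch
    from torch import fx

    def resolve_reduce_args(node: fx.Node):
        # Positional args win; otherwise kwargs; otherwise the defaults.
        x = node.args[0]
        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
        keepdim = node.args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
        return x, dim, keepdim

    class M(torch.nn.Module):
        def forward(self, x):
            return torch.std(x, 1, keepdim=True)  # dim positional, keepdim as kwarg

    for node in fx.symbolic_trace(M()).graph.nodes:
        if node.op == "call_function" and node.target is torch.std:
            print(resolve_reduce_args(node))  # -> (x node, 1, True)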
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 98de2e114b..022a7bffea 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -777,7 +777,10 @@ class TorchFXImporter(BaseFXGraphImporter):
             "lerp": self._lerp,
             # statistical
             "mean": self._mean,
+            "prod": self._prod,
+            "std": self._std,
             "sum": self._sum,
+            "var": self._var,
             # search
             "argmax": self._argmax_argmin(relax.op.argmax),
             "argmin": self._argmax_argmin(relax.op.argmin),
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 962a6accff..726ff6f8e8 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4255,5 +4255,65 @@ def test_lerp():
     )
 
 
+def test_std():
+    class Std(Module):
+        def forward(self, x):
+            return torch.std(x)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.std(inp_0, axis=None, keepdims=False)
+                gv: R.Tensor((), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Std(), [([5, 3], "float32")], {}, Expected)
+
+
+def test_var():
+    class Var(Module):
+        def forward(self, x):
+            return torch.var(x)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.variance(inp_0, axis=None, keepdims=False)
+                gv: R.Tensor((), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Var(), [([5, 3], "float32")], {}, Expected)
+
+
+def test_prod():
+    class Prod(Module):
+        def forward(self, x):
+            return torch.prod(x)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.prod(inp_0, axis=None, keepdims=False)
+                gv: R.Tensor((), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Prod(), [([5, 3], "float32")], {}, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
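
End to end, the new mappings can be exercised through the public FX frontend;
a minimal sketch, assuming the standard from_fx entry point and the same
input-info format the tests above pass to verify_model:

    import torch
    from torch import fx, nn
    from tvm.relax.frontend.torch import from_fx

    class Stats(nn.Module):
        def forward(self, x):
            # Each call now lowers to a Relax op:
            # prod -> R.prod, std -> R.std, var -> R.variance.
            return torch.prod(x), torch.std(x), torch.var(x)

    graph_model = fx.symbolic_trace(Stats())
    mod = from_fx(graph_model, [([5, 3], "float32")])
    mod.show()  # prints the imported IRModule in TVMScript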
