This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new adce5a4745 [Relax][PyTorch] Add support for prod, std and var ops (#17772)
adce5a4745 is described below
commit adce5a4745d55cca05dc26abfa03b444ef2a4810
Author: Shushi Hong <[email protected]>
AuthorDate: Sat Mar 22 16:02:35 2025 +0800
[Relax][PyTorch] Add support for prod, std and var ops (#17772)
* Update fx_translator.py
* Update test_frontend_from_fx.py
* Update base_fx_graph_translator.py
---
.../frontend/torch/base_fx_graph_translator.py | 21 ++++++++
python/tvm/relax/frontend/torch/fx_translator.py | 3 ++
tests/python/relax/test_frontend_from_fx.py | 60 ++++++++++++++++++++++
3 files changed, 84 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 6bbc9d5de6..4bfdb8c1bc 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -908,6 +908,20 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
keepdim = args[2] if len(node.args) > 2 else
node.kwargs.get("keepdim", False)
return self.block_builder.emit(relax.op.mean(x, dim, keepdims=keepdim))
+    def _prod(self, node: fx.Node) -> relax.Var:
+        """Convert ``torch.prod`` / ``Tensor.prod`` into ``relax.op.prod``.
+
+        ``dim`` and ``keepdim`` may be passed positionally or as keywords;
+        when absent they fall back to ``None`` and ``False``.
+        """
+        operands = self.retrieve_args(node)
+        n_positional = len(node.args)
+        if n_positional > 1:
+            axis = operands[1]
+        else:
+            axis = node.kwargs.get("dim", None)
+        if n_positional > 2:
+            keep = operands[2]
+        else:
+            keep = node.kwargs.get("keepdim", False)
+        return self.block_builder.emit(relax.op.prod(operands[0], axis, keepdims=keep))
+
+    def _std(self, node: fx.Node) -> relax.Var:
+        """Convert ``torch.std`` / ``Tensor.std``.
+
+        PyTorch divides by ``N - correction`` and defaults to ``correction=1``
+        (Bessel's correction), whereas ``relax.op.std`` is the population
+        statistic. ``relax.op.std`` is therefore emitted directly only when
+        ``correction == 0``; otherwise the square root of the corrected
+        variance is emitted.
+        """
+        x, dim, correction, keepdim = self._parse_std_var_args(node)
+        if correction == 0:
+            return self.block_builder.emit(relax.op.std(x, dim, keepdims=keepdim))
+        variance = self._corrected_variance(x, dim, correction, keepdim)
+        return self.block_builder.emit(relax.op.sqrt(variance))
+
+    def _parse_std_var_args(self, node: fx.Node):
+        """Return ``(x, dim, correction, keepdim)`` for a ``std``/``var`` call.
+
+        Handles both PyTorch signatures:
+        ``(input, dim=None, *, correction=1, keepdim=False)`` and the legacy
+        ``(input, dim, unbiased=True, keepdim=False)`` / ``(input, unbiased)``.
+        In the legacy positional form the third argument is ``unbiased``,
+        not ``keepdim``.
+        """
+        args = self.retrieve_args(node)
+        x = args[0]
+        dim = node.kwargs.get("dim", None)
+        unbiased = node.kwargs.get("unbiased", None)
+        keepdim = node.kwargs.get("keepdim", False)
+        if len(node.args) > 1:
+            # A bool in the dim slot selects the ``(input, unbiased)`` overload.
+            if isinstance(args[1], bool):
+                unbiased = args[1]
+            else:
+                dim = args[1]
+        if len(node.args) > 2:
+            unbiased = args[2]
+        if len(node.args) > 3:
+            keepdim = args[3]
+        correction = node.kwargs.get("correction", None)
+        if correction is None:
+            # PyTorch's default is Bessel's correction (correction=1).
+            correction = 1 if unbiased is None or unbiased else 0
+        return x, dim, correction, keepdim
+
+    def _corrected_variance(self, x: relax.Expr, dim, correction, keepdim: bool) -> relax.Var:
+        """Emit the variance of ``x`` over ``dim`` with PyTorch ``correction`` semantics.
+
+        The population variance from ``relax.op.variance`` is rescaled by
+        ``N / max(0, N - correction)``, where ``N`` is the number of reduced
+        elements. As in PyTorch, a zero denominator yields inf/nan.
+        """
+        variance = self.block_builder.emit(relax.op.variance(x, dim, keepdims=keepdim))
+        if correction == 0:
+            return variance
+        dtype = x.struct_info.dtype
+        count = self._static_reduce_count(x, dim)
+        if count is not None:
+            # Static shape: fold the rescaling factor into a single constant.
+            denom = max(0, count - correction)
+            scale = count / denom if denom > 0 else float("inf")
+            factor = relax.const(scale, dtype)
+            return self.block_builder.emit(relax.op.multiply(variance, factor))
+        # Dynamic shape: count the reduced elements at runtime.
+        ones = self.block_builder.emit(relax.op.ones_like(x))
+        count_t = self.block_builder.emit(relax.op.sum(ones, dim, keepdims=keepdim))
+        shifted = self.block_builder.emit(
+            relax.op.subtract(count_t, relax.const(correction, dtype))
+        )
+        denom_t = self.block_builder.emit(relax.op.maximum(shifted, relax.const(0.0, dtype)))
+        scale_t = self.block_builder.emit(relax.op.divide(count_t, denom_t))
+        return self.block_builder.emit(relax.op.multiply(variance, scale_t))
+
+    @staticmethod
+    def _static_reduce_count(x: relax.Expr, dim):
+        """Return how many elements of ``x`` a reduction over ``dim`` covers.
+
+        Returns ``None`` when the shape, or any reduced extent, is not a
+        compile-time integer.
+        """
+        from tvm import tir
+
+        shape = x.struct_info.shape
+        if not isinstance(shape, relax.ShapeExpr):
+            return None
+        extents = shape.values
+        ndim = len(extents)
+        if dim is None or ndim == 0:
+            axes = range(ndim)
+        elif isinstance(dim, (list, tuple)):
+            axes = [int(axis) % ndim for axis in dim]
+        else:
+            axes = [int(dim) % ndim]
+        count = 1
+        for axis in axes:
+            extent = extents[axis]
+            if not isinstance(extent, tir.IntImm):
+                return None
+            count *= extent.value
+        return count
+
     def _sum(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False
@@ -915,6 +929,13 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
             return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim))
         return self.block_builder.emit(relax.op.sum(args[0], args[1]))
+    def _var(self, node: fx.Node) -> relax.Var:
+        """Convert ``torch.var`` / ``Tensor.var``, honoring ``correction``/``unbiased``.
+
+        PyTorch's default is the unbiased (``correction=1``) estimator, while
+        ``relax.op.variance`` alone is the population variance.
+        """
+        x, dim, correction, keepdim = self._parse_std_var_args(node)
+        return self._corrected_variance(x, dim, correction, keepdim)
+
     ########## Search ##########
     def _argmax_argmin(self, op: Callable) -> Callable:
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 98de2e114b..022a7bffea 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -777,7 +777,10 @@ class TorchFXImporter(BaseFXGraphImporter):
             "lerp": self._lerp,
             # statistical
             "mean": self._mean,
+            "prod": self._prod,
+            "std": self._std,
             "sum": self._sum,
+            "var": self._var,
             # search
             "argmax": self._argmax_argmin(relax.op.argmax),
             "argmin": self._argmax_argmin(relax.op.argmin),
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 962a6accff..726ff6f8e8 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4255,5 +4255,65 @@ def test_lerp():
     )
+def test_std():
+    class Std(Module):
+        def forward(self, x):
+            return torch.std(x)
+
+    class StdBiased(Module):
+        def forward(self, x):
+            return torch.std(x, dim=1, keepdim=True, correction=0)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((), dtype="float32"):
+            # torch.std defaults to correction=1: sqrt(var * N / (N - 1)) with N = 15.
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.variance(inp_0, axis=None, keepdims=False)
+                lv1: R.Tensor((), dtype="float32") = R.multiply(
+                    lv, R.const(1.0714285714285714, "float32")
+                )
+                lv2: R.Tensor((), dtype="float32") = R.sqrt(lv1)
+                gv: R.Tensor((), dtype="float32") = lv2
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class ExpectedBiased:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((5, 1), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((5, 1), dtype="float32") = R.std(inp_0, axis=[1], keepdims=True)
+                gv: R.Tensor((5, 1), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Std(), [([5, 3], "float32")], {}, Expected)
+    verify_model(StdBiased(), [([5, 3], "float32")], {}, ExpectedBiased)
+
+
+def test_var():
+    class Var(Module):
+        def forward(self, x):
+            return torch.var(x)
+
+    class VarBiased(Module):
+        def forward(self, x):
+            # Legacy positional overload: (input, dim, unbiased, keepdim).
+            return torch.var(x, 1, False, True)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((), dtype="float32"):
+            # torch.var defaults to correction=1: var * N / (N - 1) with N = 15.
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.variance(inp_0, axis=None, keepdims=False)
+                lv1: R.Tensor((), dtype="float32") = R.multiply(
+                    lv, R.const(1.0714285714285714, "float32")
+                )
+                gv: R.Tensor((), dtype="float32") = lv1
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class ExpectedBiased:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((5, 1), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((5, 1), dtype="float32") = R.variance(
+                    inp_0, axis=[1], keepdims=True
+                )
+                gv: R.Tensor((5, 1), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Var(), [([5, 3], "float32")], {}, Expected)
+    verify_model(VarBiased(), [([5, 3], "float32")], {}, ExpectedBiased)
+
+
+def test_prod():
+    """torch.prod(x) without ``dim`` must map to R.prod(axis=None, keepdims=False)."""
+
+    class Prod(Module):
+        def forward(self, x):
+            return torch.prod(x)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((), dtype="float32"):
+            # Reducing over every axis yields a 0-d tensor.
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.prod(inp_0, axis=None, keepdims=False)
+                gv: R.Tensor((), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Prod(), [([5, 3], "float32")], {}, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()