This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 111ddf7428 [Relax][PyTorch] Support eye op for ExportedProgram importer (#17864)
111ddf7428 is described below
commit 111ddf7428c57a55fa5548d4bffa16ed40e6e8cc
Author: Shushi Hong <[email protected]>
AuthorDate: Mon Apr 21 19:48:58 2025 +0800
[Relax][PyTorch] Support eye op for ExportedProgram importer (#17864)
---
.../frontend/torch/base_fx_graph_translator.py | 7 ++++
.../frontend/torch/exported_program_translator.py | 2 ++
.../relax/test_frontend_from_exported_program.py | 40 ++++++++++++++++++++++
3 files changed, 49 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index ae4c918900..733a5d6b1a 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1416,6 +1416,13 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
x = self.env[node.args[0]]
return self.block_builder.emit(relax.op.zeros_like(x))
+ def _eye(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ n = args[0]
+ m = args[1] if len(args) > 1 else n
+ dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
+ return self.block_builder.emit(relax.op.eye(n, m, dtype=dtype))
+
def _fill(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 9326072875..af1393329e 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -453,6 +453,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"clone.default": lambda node: self.env[node.args[0]],
"empty.memory_format": self._empty,
"empty_like.default": self._empty_like,
+ "eye.default": self._eye,
+ "eye.m": self._eye,
"fill.Scalar": self._fill,
"full.default": self._full,
"full_like.default": self._full_like,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 80c0bd5fb4..ce68089048 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -4377,5 +4377,45 @@ def test_narrow():
verify_model(Narrow(), example_args, {}, Expected)
+def test_eye():
+ class Eye1(Module):
+ def forward(self, input):
+ return torch.eye(3, 5, dtype=torch.float32)
+
+ @tvm.script.ir_module
+ class Expected1:
+ @R.function
+ def main(
+ input: R.Tensor((3, 5), dtype="float32")
+ ) -> R.Tuple(R.Tensor((3, 5), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((3, 5), dtype="float32") = R.eye(3, 5,
dtype="float32")
+ gv: R.Tuple(R.Tensor((3, 5), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ class Eye2(Module):
+ def forward(self, input):
+ return torch.eye(5, dtype=torch.float32)
+
+ @tvm.script.ir_module
+ class Expected2:
+ @R.function
+ def main(
+ input: R.Tensor((5,), dtype="float32")
+ ) -> R.Tuple(R.Tensor((5, 5), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((5, 5), dtype="float32") = R.eye(5,
dtype="float32")
+ gv: R.Tuple(R.Tensor((5, 5), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args1 = (torch.randn(3, 5, dtype=torch.float32),)
+ verify_model(Eye1(), example_args1, {}, Expected1)
+
+ example_args2 = (torch.randn(5, dtype=torch.float32),)
+ verify_model(Eye2(), example_args2, {}, Expected2)
+
+
if __name__ == "__main__":
tvm.testing.main()