This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 111ddf7428 [Relax][PyTorch] Support eye op for ExportedProgram importer (#17864)
111ddf7428 is described below

commit 111ddf7428c57a55fa5548d4bffa16ed40e6e8cc
Author: Shushi Hong <[email protected]>
AuthorDate: Mon Apr 21 19:48:58 2025 +0800

    [Relax][PyTorch] Support eye op for ExportedProgram importer (#17864)
---
 .../frontend/torch/base_fx_graph_translator.py     | 12 +++++++
 .../frontend/torch/exported_program_translator.py  |  2 ++
 .../relax/test_frontend_from_exported_program.py   | 40 ++++++++++++++++++++++
 3 files changed, 54 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index ae4c918900..733a5d6b1a 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1416,6 +1416,18 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         x = self.env[node.args[0]]
         return self.block_builder.emit(relax.op.zeros_like(x))
 
+    def _eye(self, node: fx.Node) -> relax.Var:
+        """Convert ``aten.eye.default`` / ``aten.eye.m`` into ``relax.op.eye``."""
+        args = self.retrieve_args(node)
+        n = args[0]
+        # eye.m supplies the column count; eye.default builds a square matrix.
+        m = args[1] if len(args) > 1 else n
+        # "dtype" is keyword-only and may be absent (or None) when the caller
+        # relied on torch's default dtype; assume float32 instead of KeyError.
+        dtype = node.kwargs.get("dtype") or "float32"
+        dtype = self._convert_data_type(str(dtype), self.env)
+        return self.block_builder.emit(relax.op.eye(n, m, dtype=dtype))
+
     def _fill(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         x = args[0]
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 9326072875..af1393329e 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -453,6 +453,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "clone.default": lambda node: self.env[node.args[0]],
             "empty.memory_format": self._empty,
             "empty_like.default": self._empty_like,
+            "eye.default": self._eye,
+            "eye.m": self._eye,
             "fill.Scalar": self._fill,
             "full.default": self._full,
             "full_like.default": self._full_like,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 80c0bd5fb4..ce68089048 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -4377,5 +4377,45 @@ def test_narrow():
     verify_model(Narrow(), example_args, {}, Expected)
 
 
+def test_eye():
+    class Eye1(Module):
+        def forward(self, input):
+            return torch.eye(3, 5, dtype=torch.float32)
+
+    @tvm.script.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            input: R.Tensor((3, 5), dtype="float32")
+        ) -> R.Tuple(R.Tensor((3, 5), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((3, 5), dtype="float32") = R.eye(3, 5, dtype="float32")
+                gv: R.Tuple(R.Tensor((3, 5), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class Eye2(Module):
+        def forward(self, input):
+            return torch.eye(5, dtype=torch.float32)
+
+    @tvm.script.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            input: R.Tensor((5,), dtype="float32")
+        ) -> R.Tuple(R.Tensor((5, 5), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((5, 5), dtype="float32") = R.eye(5, dtype="float32")
+                gv: R.Tuple(R.Tensor((5, 5), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args1 = (torch.randn(3, 5, dtype=torch.float32),)
+    verify_model(Eye1(), example_args1, {}, Expected1)
+
+    example_args2 = (torch.randn(5, dtype=torch.float32),)
+    verify_model(Eye2(), example_args2, {}, Expected2)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to