This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fd7fecc95a [Relax][PyTorch] Add support for eye op in fx graph (#17908)
fd7fecc95a is described below

commit fd7fecc95a0ecdf0e9c3240790a104ce697171bb
Author: Shushi Hong <[email protected]>
AuthorDate: Wed Apr 30 19:28:27 2025 +0800

    [Relax][PyTorch] Add support for eye op in fx graph (#17908)
---
 python/tvm/relax/frontend/torch/fx_translator.py |  1 +
 tests/python/relax/test_frontend_from_fx.py      | 18 ++++++++++++++++++
 2 files changed, 19 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 39e562e06a..6d1d218651 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -821,6 +821,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "clone": lambda node: self.env[node.args[0]],
             "empty": self._empty,
             "empty_like": self._empty_like,
+            "eye": self._eye,
             "fill": self._fill,
             "fill_": self._inplace_fill,
             "full": self._full,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 0d50b4a112..eeabcd40bd 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -5342,5 +5342,23 @@ def test_bfloat16():
     verify_model(BFloat16Model(), [([10, 10], "bfloat16"), ([10, 10], "bfloat16")], {}, Expected)
 
 
+def test_eye():
+    import numpy as np
+
+    class Eye(Module):
+        def forward(self, input):
+            return torch.eye(3)
+
+    graph_model = fx.symbolic_trace(Eye())
+    mod = from_fx(graph_model, [([3, 3], "float32")])
+    assert len(mod["main"].body.blocks) == 1
+    assert len(mod["main"].body.blocks[0].bindings) == 1
+    assert isinstance(mod["main"].body.blocks[0].bindings[0].value, relax.Constant)
+    tvm.testing.assert_allclose(
+        mod["main"].body.blocks[0].bindings[0].value.data.numpy(),
+        np.eye(3, dtype="float32"),
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to