This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fd7fecc95a [Relax][PyTorch] Add support for eye op in fx graph (#17908)
fd7fecc95a is described below
commit fd7fecc95a0ecdf0e9c3240790a104ce697171bb
Author: Shushi Hong <[email protected]>
AuthorDate: Wed Apr 30 19:28:27 2025 +0800
[Relax][PyTorch] Add support for eye op in fx graph (#17908)
---
python/tvm/relax/frontend/torch/fx_translator.py | 1 +
tests/python/relax/test_frontend_from_fx.py | 18 ++++++++++++++++++
2 files changed, 19 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 39e562e06a..6d1d218651 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -821,6 +821,7 @@ class TorchFXImporter(BaseFXGraphImporter):
"clone": lambda node: self.env[node.args[0]],
"empty": self._empty,
"empty_like": self._empty_like,
+ "eye": self._eye,
"fill": self._fill,
"fill_": self._inplace_fill,
"full": self._full,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 0d50b4a112..eeabcd40bd 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -5342,5 +5342,23 @@ def test_bfloat16():
     verify_model(BFloat16Model(), [([10, 10], "bfloat16"), ([10, 10], "bfloat16")], {}, Expected)
 
 
+def test_eye():
+    import numpy as np
+
+    class Eye(Module):
+        def forward(self, input):
+            return torch.eye(3)
+
+    graph_model = fx.symbolic_trace(Eye())
+    mod = from_fx(graph_model, [([3, 3], "float32")])
+    assert len(mod["main"].body.blocks) == 1
+    assert len(mod["main"].body.blocks[0].bindings) == 1
+    assert isinstance(mod["main"].body.blocks[0].bindings[0].value, relax.Constant)
+    tvm.testing.assert_allclose(
+        mod["main"].body.blocks[0].bindings[0].value.data.numpy(),
+        np.eye(3, dtype="float32"),
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()