This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4ef582a331 [Relax][PyTorch] Add support for linspace op in fx graph (#17915)
4ef582a331 is described below
commit 4ef582a3319f30fac2716091f835e493ec161ffd
Author: Shushi Hong <[email protected]>
AuthorDate: Sun May 4 21:44:14 2025 +0800
[Relax][PyTorch] Add support for linspace op in fx graph (#17915)
* Update fx_translator.py
* Update test_frontend_from_fx.py
---
python/tvm/relax/frontend/torch/fx_translator.py | 1 +
tests/python/relax/test_frontend_from_fx.py | 18 ++++++++++++++++++
2 files changed, 19 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 2688f83c86..5f65f86a43 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -835,6 +835,7 @@ class TorchFXImporter(BaseFXGraphImporter):
"fill_": self._inplace_fill,
"full": self._full,
"index_select": self._index_select,
+ "linspace": self._linspace,
"masked_fill_": self._inplace_masked_fill,
"masked_fill": self._masked_fill,
"masked_scatter": self._masked_scatter,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 2ab20fbb11..490a2309aa 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -5396,5 +5396,23 @@ def test_eye():
)
+def test_linspace():
+    import numpy as np
+
+    class Linspace(Module):
+        def forward(self, input):
+            return torch.linspace(0, 1, steps=9)
+
+    graph_model = fx.symbolic_trace(Linspace())
+    mod = from_fx(graph_model, [([9, 9], "float32")])
+    assert len(mod["main"].body.blocks) == 1
+    assert len(mod["main"].body.blocks[0].bindings) == 1
+    assert isinstance(mod["main"].body.blocks[0].bindings[0].value, relax.Constant)
+    tvm.testing.assert_allclose(
+        mod["main"].body.blocks[0].bindings[0].value.data.numpy(),
+        np.linspace(0, 1, num=9, dtype="float32"),
+    )
+
+
if __name__ == "__main__":
tvm.testing.main()