This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4ef582a331 [Relax][PyTorch] Add support for linspace op in fx graph (#17915)
4ef582a331 is described below

commit 4ef582a3319f30fac2716091f835e493ec161ffd
Author: Shushi Hong <[email protected]>
AuthorDate: Sun May 4 21:44:14 2025 +0800

    [Relax][PyTorch] Add support for linspace op in fx graph (#17915)
    
    * Update fx_translator.py
    
    * Update test_frontend_from_fx.py
---
 python/tvm/relax/frontend/torch/fx_translator.py |  1 +
 tests/python/relax/test_frontend_from_fx.py      | 18 ++++++++++++++++++
 2 files changed, 19 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 2688f83c86..5f65f86a43 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -835,6 +835,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "fill_": self._inplace_fill,
             "full": self._full,
             "index_select": self._index_select,
+            "linspace": self._linspace,
             "masked_fill_": self._inplace_masked_fill,
             "masked_fill": self._masked_fill,
             "masked_scatter": self._masked_scatter,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 2ab20fbb11..490a2309aa 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -5396,5 +5396,23 @@ def test_eye():
     )
 
 
+def test_linspace():
+    import numpy as np
+
+    class Linspace(Module):
+        def forward(self, input):
+            return torch.linspace(0, 1, steps=9)
+
+    graph_model = fx.symbolic_trace(Linspace())
+    mod = from_fx(graph_model, [([9, 9], "float32")])
+    assert len(mod["main"].body.blocks) == 1
+    assert len(mod["main"].body.blocks[0].bindings) == 1
+    assert isinstance(mod["main"].body.blocks[0].bindings[0].value, relax.Constant)
+    tvm.testing.assert_allclose(
+        mod["main"].body.blocks[0].bindings[0].value.data.numpy(),
+        np.linspace(0, 1, num=9, dtype="float32"),
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to