This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c5675dd10d [Relax][PyTorch] Support linspace op for ExportedProgram importer (#17889)
c5675dd10d is described below

commit c5675dd10de4364bd1fd627df358ffcfb3efbbb9
Author: Shushi Hong <[email protected]>
AuthorDate: Fri Apr 25 10:27:24 2025 +0800

    [Relax][PyTorch] Support linspace op for ExportedProgram importer (#17889)
    
    * Update base_fx_graph_translator.py
    
    * Update exported_program_translator.py
    
    * Update test_frontend_from_exported_program.py
    
    * Update test_frontend_from_exported_program.py
    
    * Update test_frontend_from_exported_program.py
    
    * fix lint
    
    * Update test_frontend_from_exported_program.py
---
 .../frontend/torch/base_fx_graph_translator.py     | 23 ++++++++++++++++++++++
 .../frontend/torch/exported_program_translator.py  |  1 +
 .../relax/test_frontend_from_exported_program.py   | 21 ++++++++++++++++++++
 3 files changed, 45 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 5dd78be483..3e81ff1f0b 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1465,6 +1465,29 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         self.env[node.args[0]] = output
         return output
 
+    def _linspace(self, node: fx.Node) -> relax.Var:  # lower linspace to relax arange
+        args = self.retrieve_args(node)
+        start = args[0]
+        stop = args[1]
+        step = args[2]  # torch `steps` (sample count); rebound to the stride below
+
+        if step != 1:
+            step = (stop - start) / (step - 1)  # NOTE(review): 0 if start == stop; confirm arange
+            stop = stop + (step / 2)  # half-stride pad, presumably so exclusive end keeps `stop`
+        else:
+            stop = start + step  # steps == 1: arange(start, start + 1, 1) -> [start]
+
+        if len(args) <= 3 or args[3] is None:  # NOTE(review): aten dtype is kw-only; check kwargs?
+            import torch
+
+            dtype = self._convert_data_type(str(torch.get_default_dtype()))  # torch default
+        else:
+            dtype = self._convert_data_type(args[3])
+
+        return self.block_builder.emit(
+            relax.op.arange(start=start, end=stop, step=step, dtype=dtype)
+        )
+
     def _masked_fill(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         mask = self.env[node.args[1]]
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index db5ca01399..a3ab575c4b 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -475,6 +475,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "full_like.default": self._full_like,
             "index_select.default": self._index_select,
             "lift_fresh_copy.default": self._to_copy,
+            "linspace.default": self._linspace,
             "masked_fill.Scalar": self._masked_fill,
             "new_ones.default": self._new_ones,
             "one_hot.default": self._one_hot,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 589d3f5bae..e3b6f4ad9c 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -4817,5 +4817,26 @@ def test_eye():
     verify_model(Eye2(), example_args2, {}, Expected2)
 
 
+def test_linspace():
+    class Linspace(Module):
+        def forward(self, input):  # `input` is unused; output depends only on constants
+            return torch.linspace(0, 1, steps=9, dtype=torch.float32)  # NOTE(review): float32 is usually torch's default dtype, so the dtype path is not really exercised
+
+    @tvm.script.ir_module
+    class Expected:  # linspace(0, 1, 9) -> arange(0, 1 + 0.125 / 2, 0.125)
+        @R.function
+        def main(
+            input: R.Tensor((9, 9), dtype="float32")
+        ) -> R.Tuple(R.Tensor((9,), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((9,), dtype="float32") = R.arange(0, 1.0625, 0.125, dtype="float32")
+                gv: R.Tuple(R.Tensor((9,), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(9, 9, dtype=torch.float32),)
+    verify_model(Linspace(), example_args, {}, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to