This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c5675dd10d [Relax][PyTorch] Support linspace op for ExportedProgram importer (#17889)
c5675dd10d is described below
commit c5675dd10de4364bd1fd627df358ffcfb3efbbb9
Author: Shushi Hong <[email protected]>
AuthorDate: Fri Apr 25 10:27:24 2025 +0800
[Relax][PyTorch] Support linspace op for ExportedProgram importer (#17889)
* Update base_fx_graph_translator.py
* Update exported_program_translator.py
* Update test_frontend_from_exported_program.py
* Update test_frontend_from_exported_program.py
* Update test_frontend_from_exported_program.py
* fix lint
* Update test_frontend_from_exported_program.py
---
.../frontend/torch/base_fx_graph_translator.py | 23 ++++++++++++++++++++++
.../frontend/torch/exported_program_translator.py | 1 +
.../relax/test_frontend_from_exported_program.py | 21 ++++++++++++++++++++
3 files changed, 45 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 5dd78be483..3e81ff1f0b 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1465,6 +1465,29 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
self.env[node.args[0]] = output
return output
+    def _linspace(self, node: fx.Node) -> relax.Var:
+        """Convert ``aten.linspace`` into a Relax ``arange`` (or ``full``).
+
+        ``linspace(start, end, steps)`` is lowered to
+        ``arange(start, end + step / 2, step)`` with
+        ``step = (end - start) / (steps - 1)``. Because ``arange`` excludes its
+        upper bound, padding it by half a step includes ``end`` while leaving a
+        half-step margin against floating-point rounding of the element count.
+
+        Parameters
+        ----------
+        node : fx.Node
+            The ``linspace`` call. ``start``, ``end`` and ``steps`` are read
+            positionally. ``dtype`` is keyword-only in the aten schema, so it is
+            read from ``node.kwargs`` first, then from a 4th positional
+            argument, and finally falls back to ``torch.get_default_dtype()``.
+
+        Returns
+        -------
+        relax.Var
+            The emitted 1-D tensor.
+
+        Raises
+        ------
+        ValueError
+            If ``steps`` is a negative Python int.
+        """
+        args = self.retrieve_args(node)
+        start, stop, steps = args[0], args[1], args[2]
+
+        # Only concrete ints can be range-checked; symbolic `steps` falls
+        # through to the general arithmetic branch below, as before.
+        static_steps = isinstance(steps, int)
+        if static_steps and steps < 0:
+            raise ValueError(
+                f"linspace: number of steps must be non-negative, got {steps}"
+            )
+
+        # `dtype` is keyword-only for aten.linspace, so check kwargs first;
+        # retrieve_args() only covers positional arguments.
+        dtype = node.kwargs.get("dtype", None)
+        if dtype is None and len(args) > 3:
+            dtype = args[3]
+        if dtype is None:
+            import torch
+
+            dtype = torch.get_default_dtype()
+        dtype = self._convert_data_type(str(dtype))
+
+        if static_steps and steps <= 1:
+            # 0 steps -> empty tensor, 1 step -> [start]: a unit step over a
+            # range of length `steps` yields exactly `steps` elements.
+            step, end = 1, start + steps
+        elif start == stop:
+            # The computed step would be 0, which is invalid for arange;
+            # every element equals `start`.
+            return self.block_builder.emit(
+                relax.op.full((steps,), relax.const(start, dtype), dtype)
+            )
+        else:
+            step = (stop - start) / (steps - 1)
+            end = stop + step / 2
+
+        return self.block_builder.emit(
+            relax.op.arange(start=start, end=end, step=step, dtype=dtype)
+        )
+
def _masked_fill(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
mask = self.env[node.args[1]]
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index db5ca01399..a3ab575c4b 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -475,6 +475,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"full_like.default": self._full_like,
"index_select.default": self._index_select,
"lift_fresh_copy.default": self._to_copy,
+ "linspace.default": self._linspace,
"masked_fill.Scalar": self._masked_fill,
"new_ones.default": self._new_ones,
"one_hot.default": self._one_hot,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 589d3f5bae..e3b6f4ad9c 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -4817,5 +4817,26 @@ def test_eye():
verify_model(Eye2(), example_args2, {}, Expected2)
+def test_linspace():
+    """linspace(0, 1, steps=9) should import as a single Relax arange."""
+
+    class Linspace(Module):
+        def forward(self, input):
+            return torch.linspace(0, 1, steps=9, dtype=torch.float32)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            input: R.Tensor((9, 9), dtype="float32")
+        ) -> R.Tuple(R.Tensor((9,), dtype="float32")):
+            with R.dataflow():
+                # step = (1 - 0) / (9 - 1) = 0.125; end = 1 + step / 2 = 1.0625
+                lv: R.Tensor((9,), dtype="float32") = R.arange(
+                    0, 1.0625, 0.125, dtype="float32"
+                )
+                gv: R.Tuple(R.Tensor((9,), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(9, 9, dtype=torch.float32),)
+    verify_model(Linspace(), example_args, {}, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()