This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new effdbeff1d [Relax][PyTorch] Support flip, gather, take ops for ExportedProgram importer (#17747)
effdbeff1d is described below
commit effdbeff1dd1f53cdfdbc015aadec85bdb5b29b4
Author: Shushi Hong <[email protected]>
AuthorDate: Fri Mar 14 20:15:19 2025 +0800
[Relax][PyTorch] Support flip, gather, take ops for ExportedProgram importer (#17747)
* Update exported_program_translator.py
* Update test_frontend_nn_exporter.py
* Update test_frontend_nn_exporter.py
* Update test_frontend_from_exported_program.py
* Update test_frontend_from_exported_program.py
---
.../frontend/torch/exported_program_translator.py | 3 +
.../relax/test_frontend_from_exported_program.py | 146 +++++++++++++++++++++
2 files changed, 149 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 0016046b0e..d67cacb960 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -305,6 +305,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"cumsum.default": self._cumsum,
"expand.default": self._expand,
"expand_as.default": self._expand_as,
+ "flip.default": self._flip,
+ "gather.default": self._gather,
"permute.default": self._permute,
"repeat.default": self._repeat,
"select.int": self._select,
@@ -312,6 +314,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"split.Tensor": self._split,
"squeeze.default": self._squeeze,
"squeeze.dim": self._squeeze,
+ "take.default": self._take,
"tile.default": self._tile,
"transpose.int": self._transpose,
"unsqueeze.default": lambda node: self.block_builder.emit(
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index e18986187d..e2be933050 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3425,5 +3425,151 @@ def test_no_bind_return_tuple():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_gather():  # each torch.gather dim (incl. negative) is expected as R.gather_elements(axis=dim)
+    class Gather0(Module):
+        def forward(self, data, indices):
+            return torch.gather(data, 0, indices)
+
+    class Gather1(Module):
+        def forward(self, data, indices):
+            return torch.gather(data, 1, indices)
+
+    class Gather2(Module):
+        def forward(self, data, indices):
+            return torch.gather(data, -1, indices)
+
+    class Gather3(Module):
+        def forward(self, data, indices):
+            return torch.gather(data, -2, indices)
+
+    @tvm.script.ir_module
+    class Expected0:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 3), dtype="float32"),
+            inp_1: R.Tensor((2, 3), dtype="int64"),
+        ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2, 3), dtype="float32") = R.gather_elements(inp_0, inp_1, axis=0)
+                gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 3), dtype="float32"),
+            inp_1: R.Tensor((2, 3), dtype="int64"),
+        ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2, 3), dtype="float32") = R.gather_elements(inp_0, inp_1, axis=1)
+                gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected2:  # negative dim is kept as-is (axis=-1), not normalized
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 3), dtype="float32"),
+            inp_1: R.Tensor((2, 3), dtype="int64"),
+        ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2, 3), dtype="float32") = R.gather_elements(inp_0, inp_1, axis=-1)
+                gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected3:  # negative dim is kept as-is (axis=-2), not normalized
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 3), dtype="float32"),
+            inp_1: R.Tensor((2, 3), dtype="int64"),
+        ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2, 3), dtype="float32") = R.gather_elements(inp_0, inp_1, axis=-2)
+                gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (
+        torch.randn(2, 3, dtype=torch.float32),
+        torch.randint(0, 2, (2, 3), dtype=torch.int64),  # values < 2: in bounds for dims 0/-2 (size 2) and 1/-1 (size 3)
+    )
+
+    verify_model(Gather0(), example_args, {}, Expected0)
+    verify_model(Gather1(), example_args, {}, Expected1)
+    verify_model(Gather2(), example_args, {}, Expected2)
+    verify_model(Gather3(), example_args, {}, Expected3)
+
+
+def test_flip():  # torch.flip(data, [d]) with a single dim is expected as R.flip(inp_0, axis=d)
+    class Flip0(Module):
+        def forward(self, data):
+            return torch.flip(data, [0])
+
+    class Flip1(Module):
+        def forward(self, data):
+            return torch.flip(data, [1])
+
+    @tvm.script.ir_module
+    class Expected0:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2, 2), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2, 2), dtype="float32") = R.flip(inp_0, axis=0)
+                gv: R.Tuple(R.Tensor((2, 2), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2, 2), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2, 2), dtype="float32") = R.flip(inp_0, axis=1)
+                gv: R.Tuple(R.Tensor((2, 2), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(2, 2, dtype=torch.float32),)  # single 2x2 float32 input shared by both cases
+
+    verify_model(Flip0(), example_args, {}, Expected0)
+    verify_model(Flip1(), example_args, {}, Expected1)
+
+
+def test_take():  # expected IR: int64 indices cast to int32, then R.take(..., axis=None)
+    class Take(Module):
+        def forward(self, data, indices):
+            return torch.take(data, indices)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5,), dtype="float32"),
+            inp_1: R.Tensor((3,), dtype="int64"),
+        ) -> R.Tuple(R.Tensor((3,), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((3,), dtype="int32") = R.astype(inp_1, dtype="int32")
+                lv1: R.Tensor((3,), dtype="float32") = R.take(inp_0, lv, axis=None)
+                gv: R.Tuple(R.Tensor((3,), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    example_args = (
+        torch.randn(5, dtype=torch.float32),
+        torch.randint(0, 5, (3,), dtype=torch.int64),  # values < 5: in bounds for the 5-element input
+    )
+
+    verify_model(Take(), example_args, {}, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()