This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new effdbeff1d [Relax][PyTorch] Support flip, gather, take ops for 
ExportedProgram importer (#17747)
effdbeff1d is described below

commit effdbeff1dd1f53cdfdbc015aadec85bdb5b29b4
Author: Shushi Hong <[email protected]>
AuthorDate: Fri Mar 14 20:15:19 2025 +0800

    [Relax][PyTorch] Support flip, gather, take ops for ExportedProgram 
importer (#17747)
    
    * Update exported_program_translator.py
    
    * Update test_frontend_nn_exporter.py
    
    * Update test_frontend_nn_exporter.py
    
    * Update test_frontend_from_exported_program.py
    
    * Update test_frontend_from_exported_program.py
---
 .../frontend/torch/exported_program_translator.py  |   3 +
 .../relax/test_frontend_from_exported_program.py   | 146 +++++++++++++++++++++
 2 files changed, 149 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 0016046b0e..d67cacb960 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -305,6 +305,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "cumsum.default": self._cumsum,
             "expand.default": self._expand,
             "expand_as.default": self._expand_as,
+            "flip.default": self._flip,
+            "gather.default": self._gather,
             "permute.default": self._permute,
             "repeat.default": self._repeat,
             "select.int": self._select,
@@ -312,6 +314,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "split.Tensor": self._split,
             "squeeze.default": self._squeeze,
             "squeeze.dim": self._squeeze,
+            "take.default": self._take,
             "tile.default": self._tile,
             "transpose.int": self._transpose,
             "unsqueeze.default": lambda node: self.block_builder.emit(
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index e18986187d..e2be933050 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3425,5 +3425,151 @@ def test_no_bind_return_tuple():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_gather():
+    class Gather0(Module):
+        def forward(self, data, indices):
+            return torch.gather(data, 0, indices)
+
+    class Gather1(Module):
+        def forward(self, data, indices):
+            return torch.gather(data, 1, indices)
+
+    class Gather2(Module):
+        def forward(self, data, indices):
+            return torch.gather(data, -1, indices)
+
+    class Gather3(Module):
+        def forward(self, data, indices):
+            return torch.gather(data, -2, indices)
+
+    @tvm.script.ir_module
+    class Expected0:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 3), dtype="float32"),
+            inp_1: R.Tensor((2, 3), dtype="int64"),
+        ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2, 3), dtype="float32") = R.gather_elements(inp_0, inp_1, axis=0)
+                gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 3), dtype="float32"),
+            inp_1: R.Tensor((2, 3), dtype="int64"),
+        ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2, 3), dtype="float32") = R.gather_elements(inp_0, inp_1, axis=1)
+                gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 3), dtype="float32"),
+            inp_1: R.Tensor((2, 3), dtype="int64"),
+        ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2, 3), dtype="float32") = R.gather_elements(inp_0, inp_1, axis=-1)
+                gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected3:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 3), dtype="float32"),
+            inp_1: R.Tensor((2, 3), dtype="int64"),
+        ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2, 3), dtype="float32") = R.gather_elements(inp_0, inp_1, axis=-2)
+                gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (
+        torch.randn(2, 3, dtype=torch.float32),
+        torch.randint(0, 3, (2, 3), dtype=torch.int64),
+    )
+
+    verify_model(Gather0(), example_args, {}, Expected0)
+    verify_model(Gather1(), example_args, {}, Expected1)
+    verify_model(Gather2(), example_args, {}, Expected2)
+    verify_model(Gather3(), example_args, {}, Expected3)
+
+
+def test_flip():
+    class Flip0(Module):
+        def forward(self, data):
+            return torch.flip(data, [0])
+
+    class Flip1(Module):
+        def forward(self, data):
+            return torch.flip(data, [1])
+
+    @tvm.script.ir_module
+    class Expected0:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2, 2), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2, 2), dtype="float32") = R.flip(inp_0, axis=0)
+                gv: R.Tuple(R.Tensor((2, 2), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2, 2), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2, 2), dtype="float32") = R.flip(inp_0, axis=1)
+                gv: R.Tuple(R.Tensor((2, 2), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(2, 2, dtype=torch.float32),)
+
+    verify_model(Flip0(), example_args, {}, Expected0)
+    verify_model(Flip1(), example_args, {}, Expected1)
+
+
+def test_take():
+    class Take(Module):
+        def forward(self, data, indices):
+            return torch.take(data, indices)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5,), dtype="float32"),
+            inp_1: R.Tensor((3,), dtype="int64"),
+        ) -> R.Tuple(R.Tensor((3,), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((3,), dtype="int32") = R.astype(inp_1, dtype="int32")
+                lv1: R.Tensor((3,), dtype="float32") = R.take(inp_0, lv, axis=None)
+                gv: R.Tuple(R.Tensor((3,), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    example_args = (
+        torch.randn(5, dtype=torch.float32),
+        torch.randint(0, 5, (3,), dtype=torch.int64),
+    )
+
+    verify_model(Take(), example_args, {}, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to