This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2c80e5c7a4 [Relax][PyTorch] Add support for index_select  (#17790)
2c80e5c7a4 is described below

commit 2c80e5c7a4d8d8c4125226beb16d792b964f3e5f
Author: Hugo Latendresse <[email protected]>
AuthorDate: Mon Mar 31 22:15:38 2025 -0400

    [Relax][PyTorch] Add support for index_select  (#17790)
    
    * suddenly copy.default is unsupported
    
    * wip
    
    * Able to split uneven tensors!
    
    Remaining TODOs
    Make sure that also works if we input a list of indices
    Write tests at the py and cpp level
    Cleanup
    
    * split size test passes!
    
    * test sizes and lists
    
    * just one func
    
    * cleanup
    
    * no assert
    
    * linting
    
    * chunk
    
    * remove unused modulo
    
    * fixed first test
    
    * fixed second test and lint
    
    * linting
    
    * fix one test
    
    * chunk not passing anymore
    
    * get_item error
    
    * chunk unit tests
    
    * index select test passes
    
    * fix test
    
    * cleanup
---
 python/tvm/relax/frontend/torch/base_fx_graph_translator.py  |  6 ++++++
 .../tvm/relax/frontend/torch/exported_program_translator.py  |  3 ++-
 python/tvm/relax/frontend/torch/fx_translator.py             |  6 ------
 tests/python/relax/test_from_exported_to_cuda.py             | 12 ++++++++++++
 4 files changed, 20 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 839b4eb1bd..74c620a33d 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1243,6 +1243,12 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype)
         return self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype))
 
+    def _index_select(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dim = node.args[1]
+        index = self.env[node.args[2]]
+        return self.block_builder.emit(relax.op.take(x, index, dim))
+
     def _new_ones(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         self_var = args[0]
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index a28da6ee72..97ccc6393c 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -390,7 +390,6 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "reshape.default": self._reshape,
             # tensor creation
             "_to_copy.default": self._to_copy,
-            "lift_fresh_copy.default": self._to_copy,
             "detach.default": self._detach,
             "detach_.default": self._detach,
             "arange.start": self._arange,
@@ -399,6 +398,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "empty.memory_format": self._empty,
             "empty_like.default": self._empty_like,
             "fill.Scalar": self._fill,
+            "index_select.default": self._index_select,
+            "lift_fresh_copy.default": self._to_copy,
             "new_ones.default": self._new_ones,
             "one_hot.default": self._one_hot,
             # other
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index c4008a9396..c3d605a329 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -489,12 +489,6 @@ class TorchFXImporter(BaseFXGraphImporter):
             )
         )
 
-    def _index_select(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        dim = node.args[1]
-        index = self.env[node.args[2]]
-        return self.block_builder.emit(relax.op.take(x, index, dim))
-
     def _inplace_masked_fill(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         mask = self.env[node.args[1]]
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index 64babdc43a..19b8f80a23 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -467,5 +467,17 @@ def test_chunk_too_many(target, dev):
     assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
 
 
+@tvm.testing.parametrize_targets("cuda")
+def test_index_select(target, dev):
+    class IndexSelectModel(nn.Module):
+        def forward(self, x):
+            indices = torch.tensor([0, 2])
+            return torch.index_select(x, 0, indices)
+
+    raw_data = np.random.rand(3, 4).astype("float32")
+    torch_module = IndexSelectModel().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to