This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2c80e5c7a4 [Relax][PyTorch] Add support for index_select (#17790)
2c80e5c7a4 is described below
commit 2c80e5c7a4d8d8c4125226beb16d792b964f3e5f
Author: Hugo Latendresse <[email protected]>
AuthorDate: Mon Mar 31 22:15:38 2025 -0400
[Relax][PyTorch] Add support for index_select (#17790)
* suddenly copy.default is unsupported
* wip
* Able to split uneven tensors!
Remaining TODOs
Make sure this also works when a list of indices is passed
Write tests at the py and cpp level
Cleanup
* split size test passes!
* test sizes and lists
* just one func
* cleanup
* no assert
* linting
* chunk
* remove unused modulo
* fixed first test
* fixed second test and lint
* linting
* fix one test
* chunk not passing anymore
* get_item error
* chunk unit tests
* index select test passes
* fix test
* cleanup
---
python/tvm/relax/frontend/torch/base_fx_graph_translator.py | 6 ++++++
.../tvm/relax/frontend/torch/exported_program_translator.py | 3 ++-
python/tvm/relax/frontend/torch/fx_translator.py | 6 ------
tests/python/relax/test_from_exported_to_cuda.py | 12 ++++++++++++
4 files changed, 20 insertions(+), 7 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 839b4eb1bd..74c620a33d 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1243,6 +1243,12 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
value = args[1] if isinstance(args[1], relax.Expr) else
relax.const(args[1], dtype)
return self.block_builder.emit(relax.op.full(x.struct_info.shape,
value, dtype))
+ def _index_select(self, node: fx.Node) -> relax.Var:  # handler registered for "index_select.default"; converts it via relax.op.take
+ x = self.env[node.args[0]]  # relax value looked up in self.env for the input-tensor node
+ dim = node.args[1]  # used as a plain value (not looked up in self.env); NOTE(review): only positional dim is read, node.kwargs is ignored
+ index = self.env[node.args[2]]  # relax value looked up in self.env for the index-tensor node
+ return self.block_builder.emit(relax.op.take(x, index, dim))  # emits take(x, index, dim) through the block builder and returns the bound var
+
def _new_ones(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
self_var = args[0]
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index a28da6ee72..97ccc6393c 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -390,7 +390,6 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"reshape.default": self._reshape,
# tensor creation
"_to_copy.default": self._to_copy,
- "lift_fresh_copy.default": self._to_copy,
"detach.default": self._detach,
"detach_.default": self._detach,
"arange.start": self._arange,
@@ -399,6 +398,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"empty.memory_format": self._empty,
"empty_like.default": self._empty_like,
"fill.Scalar": self._fill,
+ "index_select.default": self._index_select,
+ "lift_fresh_copy.default": self._to_copy,
"new_ones.default": self._new_ones,
"one_hot.default": self._one_hot,
# other
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index c4008a9396..c3d605a329 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -489,12 +489,6 @@ class TorchFXImporter(BaseFXGraphImporter):
)
)
- def _index_select(self, node: fx.Node) -> relax.Var:
- x = self.env[node.args[0]]
- dim = node.args[1]
- index = self.env[node.args[2]]
- return self.block_builder.emit(relax.op.take(x, index, dim))
-
def _inplace_masked_fill(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
mask = self.env[node.args[1]]
diff --git a/tests/python/relax/test_from_exported_to_cuda.py
b/tests/python/relax/test_from_exported_to_cuda.py
index 64babdc43a..19b8f80a23 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -467,5 +467,17 @@ def test_chunk_too_many(target, dev):
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module,
target, dev)
[email protected]_targets("cuda")
+def test_index_select(target, dev):
+ class IndexSelectModel(nn.Module):
+ def forward(self, x):
+ indices = torch.tensor([0, 2])
+ return torch.index_select(x, 0, indices)
+
+ raw_data = np.random.rand(3, 4).astype("float32")
+ torch_module = IndexSelectModel().eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module,
target, dev)
+
+
if __name__ == "__main__":
tvm.testing.main()