This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 929b8f49ac [Relax][PyTorch] Add support for torch.permute (#17184)
929b8f49ac is described below
commit 929b8f49ac73db3c6c7430bc1a414d4210e1aae5
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Tue Jul 23 06:28:04 2024 +0900
[Relax][PyTorch] Add support for torch.permute (#17184)
* add testcase
* support torch.permute
---
python/tvm/relax/frontend/torch/fx_translator.py | 4 ++++
tests/python/relax/test_frontend_from_fx.py | 9 +++++++--
2 files changed, 11 insertions(+), 2 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 5ed0f18deb..f9a5d9c33f 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -550,7 +550,11 @@ class TorchFXImporter:
return self.block_builder.emit(relax.op.reshape(x, new_shape))
def _permute(self, node: fx.node.Node) -> relax.Var:
+ import torch # type: ignore
+
args = self.retrieve_args(node)
+ if isinstance(args[1], (torch.Size, tuple, list)):
+ return self.block_builder.emit(relax.op.permute_dims(args[0],
tuple(args[1])))
return self.block_builder.emit(relax.op.permute_dims(args[0],
args[1:]))
def _reshape(self, node: fx.node.Node) -> relax.Var:
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index dd2719f8ce..46c079aa99 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3029,10 +3029,14 @@ def test_datatype():
def test_permute():
input_info = [([1, 2, 3, 4], "float32")]
- class Permute(Module):
+ class Permute1(Module):
def forward(self, x):
return x.permute(0, 3, 2, 1)
+ class Permute2(Module):
+ def forward(self, x):
+ return torch.permute(x, (0, 3, 2, 1))
+
@tvm.script.ir_module
class expected1:
@R.function
@@ -3046,7 +3050,8 @@ def test_permute():
R.output(gv)
return gv
- verify_model(Permute(), input_info, {}, expected1)
+ verify_model(Permute1(), input_info, {}, expected1)
+ verify_model(Permute2(), input_info, {}, expected1)
def test_reshape():