This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 929b8f49ac [Relax][PyTorch] Add support for torch.permute (#17184)
929b8f49ac is described below

commit 929b8f49ac73db3c6c7430bc1a414d4210e1aae5
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Tue Jul 23 06:28:04 2024 +0900

    [Relax][PyTorch] Add support for torch.permute (#17184)
    
    * add testcase
    
    * support torch.permute
---
 python/tvm/relax/frontend/torch/fx_translator.py | 4 ++++
 tests/python/relax/test_frontend_from_fx.py      | 9 +++++++--
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 5ed0f18deb..f9a5d9c33f 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -550,7 +550,11 @@ class TorchFXImporter:
         return self.block_builder.emit(relax.op.reshape(x, new_shape))
 
     def _permute(self, node: fx.node.Node) -> relax.Var:
+        import torch  # type: ignore
+
         args = self.retrieve_args(node)
+        if isinstance(args[1], (torch.Size, tuple, list)):
+            return self.block_builder.emit(relax.op.permute_dims(args[0], 
tuple(args[1])))
         return self.block_builder.emit(relax.op.permute_dims(args[0], 
args[1:]))
 
     def _reshape(self, node: fx.node.Node) -> relax.Var:
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index dd2719f8ce..46c079aa99 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3029,10 +3029,14 @@ def test_datatype():
 def test_permute():
     input_info = [([1, 2, 3, 4], "float32")]
 
-    class Permute(Module):
+    class Permute1(Module):
         def forward(self, x):
             return x.permute(0, 3, 2, 1)
 
+    class Permute2(Module):
+        def forward(self, x):
+            return torch.permute(x, (0, 3, 2, 1))
+
     @tvm.script.ir_module
     class expected1:
         @R.function
@@ -3046,7 +3050,8 @@ def test_permute():
                 R.output(gv)
             return gv
 
-    verify_model(Permute(), input_info, {}, expected1)
+    verify_model(Permute1(), input_info, {}, expected1)
+    verify_model(Permute2(), input_info, {}, expected1)
 
 
 def test_reshape():

Reply via email to