This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f9ac3b98b1 [Relay][Pytorch] Add support for aten::swapaxes operator (#16079)
f9ac3b98b1 is described below

commit f9ac3b98b12badb727215097cdf380809cb01309
Author: Duc-Nhat Luong <[email protected]>
AuthorDate: Fri Nov 10 11:18:36 2023 +0900

    [Relay][Pytorch] Add support for aten::swapaxes operator (#16079)
    
    support the pytorch's maxvit model by adding the aten::swapaxes operator support.
    
    Co-authored-by: Masahiro Hiramori <[email protected]>
---
 python/tvm/relay/frontend/pytorch.py          |  1 +
 tests/python/frontend/pytorch/test_forward.py | 24 ++++++++++++++++++++++++
 2 files changed, 25 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 81392a08ec..402ab59202 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -4108,6 +4108,7 @@ class PyTorchOpConverter:
             "aten::multinomial": self.multinomial,
             "aten::_weight_norm": self.weight_norm,
             "aten::copy_": self.inplace_copy,
+            "aten::swapaxes": self.transpose,
         }
 
     def update_convert_map(self, custom_map):
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index abdbda8e40..b9c1b6ce9c 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -5381,6 +5381,30 @@ def test_inplace_copy():
     verify_model(PartialDimensionInplaceCopy(), [inputs])
 
 
+@tvm.testing.uses_gpu
+def test_swapaxes():
+    """test_swapaxes"""
+    torch.set_grad_enabled(False)
+    input_shape = [2, 3, 10, 5]
+
+    class Swapaxes1(Module):
+        def forward(self, *args):
+            return args[0].swapaxes(2, 3)
+
+    class Swapaxes2(Module):
+        def forward(self, *args):
+            return args[0].swapaxes(-2, -1)
+
+    class Swapaxes3(Module):
+        def forward(self, *args):
+            return args[0].swapaxes(1, 1)
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(Swapaxes1().float().eval(), input_data=input_data)
+    verify_model(Swapaxes2().float().eval(), input_data=input_data)
+    verify_model(Swapaxes3().float().eval(), input_data=input_data)
+
+
 class TestSetSpan:
     """test structural equal between translated / hand-crafted relay IR with 
span tagged."""
 

Reply via email to