This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f9ac3b98b1 [Relay][Pytorch] Add support for aten::swapaxes operator (#16079)
f9ac3b98b1 is described below
commit f9ac3b98b12badb727215097cdf380809cb01309
Author: Duc-Nhat Luong <[email protected]>
AuthorDate: Fri Nov 10 11:18:36 2023 +0900
[Relay][Pytorch] Add support for aten::swapaxes operator (#16079)
Support PyTorch's MaxViT model by adding support for the aten::swapaxes
operator.
Co-authored-by: Masahiro Hiramori <[email protected]>
---
python/tvm/relay/frontend/pytorch.py | 1 +
tests/python/frontend/pytorch/test_forward.py | 24 ++++++++++++++++++++++++
2 files changed, 25 insertions(+)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 81392a08ec..402ab59202 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -4108,6 +4108,7 @@ class PyTorchOpConverter:
             "aten::multinomial": self.multinomial,
             "aten::_weight_norm": self.weight_norm,
             "aten::copy_": self.inplace_copy,
+            "aten::swapaxes": self.transpose,
         }
 
     def update_convert_map(self, custom_map):
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index abdbda8e40..b9c1b6ce9c 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -5381,6 +5381,30 @@ def test_inplace_copy():
     verify_model(PartialDimensionInplaceCopy(), [inputs])
 
 
+@tvm.testing.uses_gpu
+def test_swapaxes():
+    """test_swapaxes"""
+    torch.set_grad_enabled(False)
+    input_shape = [2, 3, 10, 5]
+
+    class Swapaxes1(Module):
+        def forward(self, *args):
+            return args[0].swapaxes(2, 3)
+
+    class Swapaxes2(Module):
+        def forward(self, *args):
+            return args[0].swapaxes(-2, -1)
+
+    class Swapaxes3(Module):
+        def forward(self, *args):
+            return args[0].swapaxes(1, 1)
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(Swapaxes1().float().eval(), input_data=input_data)
+    verify_model(Swapaxes2().float().eval(), input_data=input_data)
+    verify_model(Swapaxes3().float().eval(), input_data=input_data)
+
+
 class TestSetSpan:
     """test structural equal between translated / hand-crafted relay IR with
     span tagged."""