This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6f4ac2312b [Relay][Pytorch] Add support for `aten::tile` (#17277)
6f4ac2312b is described below
commit 6f4ac2312b9bbcbfb465ead0de410ab7dd1494a4
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Mon Aug 19 22:31:50 2024 +0900
[Relay][Pytorch] Add support for `aten::tile` (#17277)
* add test for torch.tile
* add support for `aten::tile`
---
python/tvm/relay/frontend/pytorch.py | 11 +++++++++++
tests/python/frontend/pytorch/test_forward.py | 24 ++++++++++++++++++++++++
2 files changed, 35 insertions(+)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 1f78d77390..0d93ff987c 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -4022,6 +4022,16 @@ class PyTorchOpConverter:
attn_weight = _op.reshape(attn_weight, newshape=[-4, batch_size, -1, -2])
return attn_weight
+ def tile(self, inputs, input_types):
+ data = inputs[0]
+ reps = []
+ for r in inputs[1]:
+ if isinstance(r, int):
+ reps.append(r)
+ else:
+ reps.append(int(_infer_value(r, {}).numpy()))
+ return _op.tile(data, reps)
+
# Operator mappings
def create_convert_map(self):
self.convert_map = {
@@ -4302,6 +4312,7 @@ class PyTorchOpConverter:
"aten::swapaxes": self.transpose,
"aten::linalg_vector_norm": self.linalg_vector_norm,
"aten::scaled_dot_product_attention": self.scaled_dot_product_attention,
+ "aten::tile": self.tile,
}
def update_convert_map(self, custom_map):
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index a273af8fb8..9f8fac9306 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -5658,6 +5658,30 @@ def test_parameterlist():
verify_model(ParamListModel().float().eval(), input_data=input_data)
+@tvm.testing.uses_gpu
+def test_forward_tile():
+    """test_forward_tile"""
+ torch.set_grad_enabled(False)
+ input_shape = [1, 3]
+
+ class Tile1(Module):
+ def forward(self, *args):
+ return args[0].tile(1, 1)
+
+ class Tile2(Module):
+ def forward(self, *args):
+ return args[0].tile(4, 2)
+
+ class Tile3(Module):
+ def forward(self, *args):
+ return args[0].tile(4, 2, 1)
+
+ input_data = torch.rand(input_shape).float()
+ verify_model(Tile1().float().eval(), input_data=input_data)
+ verify_model(Tile2().float().eval(), input_data=input_data)
+ verify_model(Tile3().float().eval(), input_data=input_data)
+
+
class TestSetSpan:
"""test structural equal between translated / hand-crafted relay IR with
span tagged."""