This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6f4ac2312b [Relay][Pytorch] Add support for `aten::tile` (#17277)
6f4ac2312b is described below

commit 6f4ac2312b9bbcbfb465ead0de410ab7dd1494a4
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Mon Aug 19 22:31:50 2024 +0900

    [Relay][Pytorch] Add support for `aten::tile` (#17277)
    
    * add test for torch.tile
    
    * add support for `aten::tile`
---
 python/tvm/relay/frontend/pytorch.py          | 11 +++++++++++
 tests/python/frontend/pytorch/test_forward.py | 24 ++++++++++++++++++++++++
 2 files changed, 35 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 1f78d77390..0d93ff987c 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -4022,6 +4022,16 @@ class PyTorchOpConverter:
             attn_weight = _op.reshape(attn_weight, newshape=[-4, batch_size, -1, -2])
         return attn_weight
 
+    def tile(self, inputs, input_types):
+        data = inputs[0]
+        reps = []
+        for r in inputs[1]:
+            if isinstance(r, int):
+                reps.append(r)
+            else:
+                reps.append(int(_infer_value(r, {}).numpy()))
+        return _op.tile(data, reps)
+
     # Operator mappings
     def create_convert_map(self):
         self.convert_map = {
@@ -4302,6 +4312,7 @@ class PyTorchOpConverter:
             "aten::swapaxes": self.transpose,
             "aten::linalg_vector_norm": self.linalg_vector_norm,
             "aten::scaled_dot_product_attention": self.scaled_dot_product_attention,
+            "aten::tile": self.tile,
         }
 
     def update_convert_map(self, custom_map):
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index a273af8fb8..9f8fac9306 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -5658,6 +5658,30 @@ def test_parameterlist():
     verify_model(ParamListModel().float().eval(), input_data=input_data)
 
 
+@tvm.testing.uses_gpu
+def test_forward_tile():
+    """test_forward_repeat"""
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3]
+
+    class Tile1(Module):
+        def forward(self, *args):
+            return args[0].tile(1, 1)
+
+    class Tile2(Module):
+        def forward(self, *args):
+            return args[0].tile(4, 2)
+
+    class Tile3(Module):
+        def forward(self, *args):
+            return args[0].tile(4, 2, 1)
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(Tile1().float().eval(), input_data=input_data)
+    verify_model(Tile2().float().eval(), input_data=input_data)
+    verify_model(Tile3().float().eval(), input_data=input_data)
+
+
 class TestSetSpan:
     """test structural equal between translated / hand-crafted relay IR with span tagged."""
 

Reply via email to