This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 65121c878a [Relay][Frontend] Add support for aten::concat (#16199)
65121c878a is described below

commit 65121c878aed37adb6434b0238e36611a59881d7
Author: Jongho Choi <sweetco...@snu.ac.kr>
AuthorDate: Sat Dec 9 10:34:26 2023 +0900

    [Relay][Frontend] Add support for aten::concat (#16199)
    
    * Update pytorch.py
    
    * Add concat test
    
    * rm whitespace
    
    * Add disable docstring
    
    * update comment
---
 python/tvm/relay/frontend/pytorch.py          |  1 +
 tests/python/frontend/pytorch/test_forward.py | 22 ++++++++++++++++++++++
 2 files changed, 23 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 9583575bfc..c507da13a7 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -4051,6 +4051,7 @@ class PyTorchOpConverter:
             "aten::squeeze": self.squeeze,
             "aten::unsqueeze": self.unsqueeze,
             "aten::cat": self.concatenate,
+            "aten::concat": self.concatenate,
             "aten::slice": self.slice,
             "aten::narrow": self.narrow,
             "aten::split": self.split,
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 6109141dea..56afe72ecd 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -720,9 +720,31 @@ def test_forward_concatenate():
             c = (args[0][:, :, 2] + 5) * 13
             return torch.cat([t.unsqueeze(2) for t in [a, b, c]], 2)
 
+    class Concatenate3(Module):
+        """
+        torch.concat is preserved as aten::concat only when in a nested module.
+        (In most cases, it is converted to aten::cat instead of aten::concat.)
+        """
+
+        def __init__(self):
+            super().__init__()
+
+            class _Concatenate(Module):
+                def forward(self, *args):
+                    a = (args[0][:, :, 0] + 2) * 7
+                    b = (args[0][:, :, 1] + 3) * 11
+                    c = (args[0][:, :, 2] + 5) * 13
+                    return torch.concat([t.unsqueeze(2) for t in [a, b, c]], 2)
+
+            self.mod = _Concatenate()
+
+        def forward(self, *args):
+            return self.mod(*args)
+
     input_data = torch.rand(input_shape).float()
     verify_model(Concatenate1().float().eval(), input_data=input_data)
     verify_model(Concatenate2().float().eval(), input_data=input_data)
+    verify_model(Concatenate3().float().eval(), input_data=input_data)
 
 
 @tvm.testing.uses_gpu
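
For reference, a minimal hypothetical sketch (not part of the commit) of how the new aten::concat mapping can be exercised through relay.frontend.from_pytorch. It mirrors the nested-module pattern of the Concatenate3 test above; the Outer/_Inner module names, the input name "x0", and the input shape are illustrative assumptions only.

import torch
from torch.nn import Module
from tvm import relay


class Outer(Module):
    """Wraps torch.concat in a nested module (as in the Concatenate3 test),
    the pattern under which the traced graph may keep aten::concat."""

    def __init__(self):
        super().__init__()

        class _Inner(Module):
            def forward(self, x):
                # torch.concat is an alias of torch.cat (PyTorch >= 1.10)
                return torch.concat([x, x], 1)

        self.inner = _Inner()

    def forward(self, x):
        return self.inner(x)


input_shape = (1, 3, 10, 10)
input_data = torch.rand(input_shape).float()
scripted = torch.jit.trace(Outer().eval(), input_data)
# from_pytorch takes a list of (input name, shape) pairs; "x0" is arbitrary.
mod, params = relay.frontend.from_pytorch(scripted, [("x0", input_shape)])
print(mod)  # the resulting Relay module should contain a concatenate op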
