This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2da3798dd1 [Relay][Frontend][Torch] add aten:broadcast_to  (#16319)
2da3798dd1 is described below

commit 2da3798dd150d15d4d560b495d22422c9eb23194
Author: Huan Mei <[email protected]>
AuthorDate: Mon Jan 1 07:41:56 2024 +0800

    [Relay][Frontend][Torch] add aten:broadcast_to  (#16319)
    
    Recently, I worked with the Stable Video Diffusion model, which contains the `aten::broadcast_to` op, but TVM does not support it.
    
    Add support for it here.
---
 python/tvm/relay/frontend/pytorch.py          | 16 ++++++++++++++++
 tests/python/frontend/pytorch/test_forward.py | 25 +++++++++++++++++++++++++
 2 files changed, 41 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 54004c379d..0213dcc488 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2328,6 +2328,21 @@ class PyTorchOpConverter:
             res_shape = list(torch.broadcast_tensors(*map(torch.empty, infer_shape_value))[0].shape)
         return [_op.broadcast_to(tensor, res_shape) for tensor in tensor_list]
 
+    def broadcast_to(self, inputs, input_types):
+        tensor = inputs[0]
+        new_shape = inputs[1]
+        import torch
+
+        if not isinstance(new_shape, (list, tuple, torch.Size)):
+            msg = f"Data type {type(new_shape)} could not be parsed in broadcast_to op"
+            raise AssertionError(msg)
+
+        for i, dim in enumerate(new_shape):
+            if not isinstance(dim, int):
+                new_shape[i] = int(_infer_value(dim, {}).numpy())
+
+        return _op.broadcast_to(tensor, new_shape)
+
     def Bool(self, inputs, input_types):
         assert len(inputs) == 1
         return inputs[0]
@@ -4190,6 +4205,7 @@ class PyTorchOpConverter:
             "aten::upsample_nearest3d": self.make_upsample3d("nearest_neighbor"),
             "aten::expand_as": self.expand_as,
             "aten::broadcast_tensors": self.broadcast_tensors,
+            "aten::broadcast_to": self.broadcast_to,
             "aten::lt": self.make_elemwise("less"),
             "aten::gt": self.make_elemwise("greater"),
             "aten::le": self.make_elemwise("less_equal"),
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 56afe72ecd..6178a58b6d 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -2162,6 +2162,31 @@ def test_forward_broadcast_tensors():
     verify_model(BroadCastTensors2().float().eval(), input_data=[x, y, z])
 
 
+@tvm.testing.uses_gpu
+def test_forward_broadcast_to():
+    """test_forward_broadcast_to"""
+    torch.set_grad_enabled(False)
+
+    class BroadCastTo1(Module):
+        def forward(self, x):
+            return torch.broadcast_to(x, (3, 3))
+
+    x = torch.tensor([1, 2, 3])
+    verify_model(BroadCastTo1().float().eval(), input_data=[x])
+
+    class BroadCastTo2(Module):
+        def __init__(self):
+            super().__init__()
+            self.y = torch.tensor(1)
+            self.z = torch.tensor(2)
+
+        def forward(self, x):
+            return torch.broadcast_to(x, (self.y + self.z, 3))
+
+    x = torch.tensor([1, 2, 3])
+    verify_model(BroadCastTo2().float().eval(), input_data=[x])
+
+
 @tvm.testing.uses_gpu
 def test_forward_pow():
     """test_forward_pow"""

Reply via email to