This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2da3798dd1 [Relay][Frontend][Torch] add aten:broadcast_to (#16319)
2da3798dd1 is described below
commit 2da3798dd150d15d4d560b495d22422c9eb23194
Author: Huan Mei <[email protected]>
AuthorDate: Mon Jan 1 07:41:56 2024 +0800
[Relay][Frontend][Torch] add aten:broadcast_to (#16319)
Recently, I worked with the Stable Video Diffusion model, which contains
the `aten::broadcast_to` op, but TVM does not support it.
Add support for it here.
---
python/tvm/relay/frontend/pytorch.py | 16 ++++++++++++++++
tests/python/frontend/pytorch/test_forward.py | 25 +++++++++++++++++++++++++
2 files changed, 41 insertions(+)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 54004c379d..0213dcc488 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2328,6 +2328,21 @@ class PyTorchOpConverter:
res_shape = list(torch.broadcast_tensors(*map(torch.empty,
infer_shape_value))[0].shape)
return [_op.broadcast_to(tensor, res_shape) for tensor in tensor_list]
+ def broadcast_to(self, inputs, input_types):
+ tensor = inputs[0]
+ new_shape = inputs[1]
+ import torch
+
+ if not isinstance(new_shape, (list, tuple, torch.Size)):
+            msg = f"Data type {type(new_shape)} could not be parsed in broadcast_to op"
+ raise AssertionError(msg)
+
+ for i, dim in enumerate(new_shape):
+ if not isinstance(dim, int):
+ new_shape[i] = int(_infer_value(dim, {}).numpy())
+
+ return _op.broadcast_to(tensor, new_shape)
+
def Bool(self, inputs, input_types):
assert len(inputs) == 1
return inputs[0]
@@ -4190,6 +4205,7 @@ class PyTorchOpConverter:
"aten::upsample_nearest3d":
self.make_upsample3d("nearest_neighbor"),
"aten::expand_as": self.expand_as,
"aten::broadcast_tensors": self.broadcast_tensors,
+ "aten::broadcast_to": self.broadcast_to,
"aten::lt": self.make_elemwise("less"),
"aten::gt": self.make_elemwise("greater"),
"aten::le": self.make_elemwise("less_equal"),
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 56afe72ecd..6178a58b6d 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -2162,6 +2162,31 @@ def test_forward_broadcast_tensors():
verify_model(BroadCastTensors2().float().eval(), input_data=[x, y, z])
[email protected]_gpu
+def test_forward_broadcast_to():
+ """test_forward_broadcast_to"""
+ torch.set_grad_enabled(False)
+
+ class BroadCastTo1(Module):
+ def forward(self, x):
+ return torch.broadcast_to(x, (3, 3))
+
+ x = torch.tensor([1, 2, 3])
+ verify_model(BroadCastTo1().float().eval(), input_data=[x])
+
+ class BroadCastTo2(Module):
+ def __init__(self):
+ super().__init__()
+ self.y = torch.tensor(1)
+ self.z = torch.tensor(2)
+
+ def forward(self, x):
+ return torch.broadcast_to(x, (self.y + self.z, 3))
+
+ x = torch.tensor([1, 2, 3])
+ verify_model(BroadCastTo2().float().eval(), input_data=[x])
+
+
@tvm.testing.uses_gpu
def test_forward_pow():
"""test_forward_pow"""