This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f88e43f [PyTorch] Fix neg indexing issue for `aten::flatten` (#10796)
f88e43f is described below
commit f88e43f18a74c38d352047a668bfeb260cfe50b7
Author: Colin Y. Li <[email protected]>
AuthorDate: Mon Mar 28 13:27:49 2022 +0800
[PyTorch] Fix neg indexing issue for `aten::flatten` (#10796)
---
python/tvm/relay/frontend/pytorch.py | 3 +++
tests/python/frontend/pytorch/test_forward.py | 31 ++++++++++++++++++---------
2 files changed, 24 insertions(+), 10 deletions(-)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 9d3980d..e0bc935 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1252,8 +1252,11 @@ class PyTorchOpConverter:
end = int(inputs[2])
dshape = get_const_tuple(self.infer_shape_with_prelude(data))
ndim = len(dshape)
+ if start < 0:
+ start += ndim
if end < 0:
end += ndim
+ assert start <= end, "start dim cannot come after end dim"
new_shape = [0] * start
new_shape.append(-1)
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 1470520..285d857 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1289,17 +1289,28 @@ def test_forward_reshape():
@tvm.testing.uses_gpu
def test_flatten():
- class Flatten(Module):
- def forward(self, x):
- return torch.flatten(x)
+ def _test_flatten(start_dim, end_dim):
+ return lambda inp: torch.flatten(inp, start_dim, end_dim)
- class BatchFlatten(Module):
- def forward(self, x):
- return torch.flatten(x, start_dim=1)
+ inp = torch.rand((3, 5, 2, 2))
+
+ # [3, 5, 2, 2] -> [60]
+ verify_model(_test_flatten(0, -1), inp)
+ verify_model(_test_flatten(0, 3), inp)
+ verify_model(_test_flatten(-4, 3), inp)
+ verify_model(_test_flatten(-4, -1), inp)
- inp = torch.rand((5, 2, 2))
- verify_model(Flatten(), input_data=inp)
- verify_model(BatchFlatten(), input_data=inp)
+ # [3, 5, 2, 2] -> [3, 5, 2, 2]
+ verify_model(_test_flatten(3, -1), inp)
+ verify_model(_test_flatten(-1, -1), inp)
+ verify_model(_test_flatten(0, -4), inp)
+ verify_model(_test_flatten(-4, -4), inp)
+
+ # [3, 5, 2, 2] -> [3, 10, 2]
+ verify_model(_test_flatten(1, 2), inp)
+ verify_model(_test_flatten(1, -2), inp)
+ verify_model(_test_flatten(-3, 2), inp)
+ verify_model(_test_flatten(-3, -2), inp)
@tvm.testing.uses_gpu
@@ -4249,7 +4260,7 @@ def test_mod():
return torch.fmod(x, y)
def test_remainder(x, y):
- return torch.fmod(x, y)
+ return torch.remainder(x, y)
for test_fn in [test_fmod, test_remainder]:
verify_model(test_fn, [torch.tensor([-3.0, -2, -1, 1, 2, 3]),
torch.tensor(2)])