This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d17f753 Support aten::flip (#8398)
d17f753 is described below
commit d17f75384d83111b9211ef0e6e0570c706a97e49
Author: delldu <[email protected]>
AuthorDate: Sun Jul 4 10:33:49 2021 +0800
Support aten::flip (#8398)
* Support test aten::flip
* Support aten::flip
---
python/tvm/relay/frontend/pytorch.py | 6 ++++++
tests/python/frontend/pytorch/test_forward.py | 20 ++++++++++++++++++++
2 files changed, 26 insertions(+)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 118af5a..909b804 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2322,6 +2322,11 @@ class PyTorchOpConverter:
weights = _op.full(_expr.const(1), (num_class,),
dtype=input_types[0])
return _op.nn.nll_loss(predictions, targets, weights, reduction,
ignore_index)
+ def flip(self, inputs, input_types):
+ data = inputs[0]
+ axis = inputs[1]
+ return _op.transform.reverse(data, axis=axis[0])
+
# Operator mappings
def create_convert_map(self):
self.convert_map = {
@@ -2536,6 +2541,7 @@ class PyTorchOpConverter:
"aten::_unique2": self.unique,
"aten::nll_loss": self.nll_loss,
"aten::nll_loss2d": self.nll_loss,
+ "aten::flip": self.flip,
}
def update_convert_map(self, custom_map):
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 2ec2810..f76ea9a 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -3893,6 +3893,25 @@ def test_forward_nll_loss():
verify_model(torch.nn.NLLLoss(reduction="none").eval(),
input_data=[predictions, targets])
[email protected]_gpu
+def test_forward_flip():
+ torch.set_grad_enabled(False)
+
+ class Flip(Module):
+ def __init__(self, axis=0):
+ super().__init__()
+ self.axis = axis
+
+ def forward(self, x):
+ return x.flip([self.axis])
+
+ input = torch.randn(2, 3, 4)
+ verify_model(Flip(axis=0), input_data=input)
+ verify_model(Flip(axis=1), input_data=input)
+ verify_model(Flip(axis=2), input_data=input)
+ verify_model(Flip(axis=-1), input_data=input)
+
+
if __name__ == "__main__":
# some structural tests
test_forward_traced_function()
@@ -4035,6 +4054,7 @@ if __name__ == "__main__":
test_hard_swish()
test_hard_sigmoid()
test_forward_nll_loss()
+ test_forward_flip()
# Model tests
test_resnet18()