This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 12ad4fbcf4 [Relay][Frontend][Torch] fix pytorch frontend not support
logical or (#16400)
12ad4fbcf4 is described below
commit 12ad4fbcf43f3d73d757e69b1b9c02e45a291ffa
Author: TaoMiao <[email protected]>
AuthorDate: Wed Jan 17 05:15:02 2024 +0800
[Relay][Frontend][Torch] fix pytorch frontend not support logical or
(#16400)
Add a converter for aten::logical_or to the Relay PyTorch frontend.
---
python/tvm/relay/frontend/pytorch.py | 7 +++++++
tests/python/frontend/pytorch/test_forward.py | 15 +++++++++++++++
2 files changed, 22 insertions(+)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index b9650e6e9a..35f74544b8 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2672,6 +2672,12 @@ class PyTorchOpConverter:
return _op.logical_and(lhs, rhs)
+ def logical_or(self, inputs, input_types):  # converter registered for "aten::logical_or" in the convert map
+ lhs = _op.cast(inputs[0], "bool")  # cast both operands to bool so non-bool inputs (e.g. int8) are accepted
+ rhs = _op.cast(inputs[1], "bool")
+
+ return _op.logical_or(lhs, rhs)  # same cast-then-combine pattern as logical_and above
+
def nonzero(self, inputs, input_types, is_numpy_style=False):
data = inputs[0]
ret = _op.transform.argwhere(data)
@@ -4238,6 +4244,7 @@ class PyTorchOpConverter:
"aten::unbind": self.unbind,
"aten::__and__": self.logical_and,
"aten::logical_and": self.logical_and,
+ "aten::logical_or": self.logical_or,
"aten::_shape_as_tensor": self.shape_as_tensor,
"aten::nonzero": self.nonzero,
"aten::nonzero_numpy": self.nonzero_numpy,
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 9bf40cfcdd..bf96c21399 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -4882,6 +4882,21 @@ def test_logical_and():
verify_model(test_fn, [a, b])
+def test_logical_or():
+ """test_logical_or"""
+
+ def test_fn(x, y):  # wraps torch.logical_or so verify_model can compare PyTorch vs. TVM output
+ return torch.logical_or(x, y)
+
+ a = torch.tensor([0, 1, 10, 0], dtype=torch.int8)  # int8 inputs exercise the bool cast in the converter
+ b = torch.tensor([4, 0, 1, 0], dtype=torch.int8)  # pairs cover (0,nz), (nz,0), (nz,nz) and the all-zero (0,0) case
+ verify_model(test_fn, [a, b])
+
+ a = torch.tensor([True, False, True])  # native bool inputs
+ b = torch.tensor([True, False, False])
+ verify_model(test_fn, [a, b])
+
+
def test_masked_select():
"""test_masked_select"""