This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 12ad4fbcf4 [Relay][Frontend][Torch] fix pytorch frontend not support logical or (#16400)
12ad4fbcf4 is described below

commit 12ad4fbcf43f3d73d757e69b1b9c02e45a291ffa
Author: TaoMiao <[email protected]>
AuthorDate: Wed Jan 17 05:15:02 2024 +0800

    [Relay][Frontend][Torch] fix pytorch frontend not support logical or (#16400)
    
    add logical_or to relay pytorch frontend
---
 python/tvm/relay/frontend/pytorch.py          |  7 +++++++
 tests/python/frontend/pytorch/test_forward.py | 15 +++++++++++++++
 2 files changed, 22 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index b9650e6e9a..35f74544b8 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2672,6 +2672,12 @@ class PyTorchOpConverter:
 
         return _op.logical_and(lhs, rhs)
 
+    def logical_or(self, inputs, input_types):
+        lhs = _op.cast(inputs[0], "bool")
+        rhs = _op.cast(inputs[1], "bool")
+
+        return _op.logical_or(lhs, rhs)
+
     def nonzero(self, inputs, input_types, is_numpy_style=False):
         data = inputs[0]
         ret = _op.transform.argwhere(data)
@@ -4238,6 +4244,7 @@ class PyTorchOpConverter:
             "aten::unbind": self.unbind,
             "aten::__and__": self.logical_and,
             "aten::logical_and": self.logical_and,
+            "aten::logical_or": self.logical_or,
             "aten::_shape_as_tensor": self.shape_as_tensor,
             "aten::nonzero": self.nonzero,
             "aten::nonzero_numpy": self.nonzero_numpy,
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 9bf40cfcdd..bf96c21399 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -4882,6 +4882,21 @@ def test_logical_and():
     verify_model(test_fn, [a, b])
 
 
+def test_logical_or():
+    """test_logical_or"""
+
+    def test_fn(x, y):
+        return torch.logical_or(x, y)
+
+    a = torch.tensor([0, 1, 10, 0], dtype=torch.int8)
+    b = torch.tensor([4, 0, 1, 0], dtype=torch.int8)
+    verify_model(test_fn, [a, b])
+
+    a = torch.tensor([True, False, True])
+    b = torch.tensor([True, False, False])
+    verify_model(test_fn, [a, b])
+
+
 def test_masked_select():
     """test_masked_select"""
 

Reply via email to