This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 95d769e738 [Relay][Pytorch] Add support for `aten::bitwise_and` (#16105)
95d769e738 is described below

commit 95d769e7381c50ed25477ba985bbd0e1b0c8f1b1
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Sat Nov 11 04:18:09 2023 +0900

    [Relay][Pytorch] Add support for `aten::bitwise_and` (#16105)
    
    add support for aten::bitwise_and
---
 python/tvm/relay/frontend/pytorch.py          |  9 +++++++++
 tests/python/frontend/pytorch/test_forward.py | 27 +++++++++++++++++++++++++++
 2 files changed, 36 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 402ab59202..bdfd8f78b2 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2327,6 +2327,14 @@ class PyTorchOpConverter:
 
         return _op.bitwise_xor(lhs, rhs)
 
+    def bitwise_and(self, inputs, input_types):
+        lhs = inputs[0]
+        rhs = inputs[1]
+        lhs = _op.cast(lhs, "bool") if input_types[0] == "bool" else _op.cast(lhs, "int")
+        rhs = _op.cast(rhs, "bool") if input_types[1] == "bool" else _op.cast(rhs, "int")
+
+        return _op.bitwise_and(lhs, rhs)
+
     def logical_not(self, inputs, input_types):
         data = _wrap_const(inputs[0])
         return _op.logical_not(_op.cast(data, "bool"))
@@ -4033,6 +4041,7 @@ class PyTorchOpConverter:
             "aten::logical_xor": self.logical_xor,
             "aten::bitwise_not": self.bitwise_not,
             "aten::bitwise_xor": self.bitwise_xor,
+            "aten::bitwise_and": self.bitwise_and,
             "aten::Bool": self.Bool,
             "aten::Float": self.Float,
             "aten::rsub": self.rsub,
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index b9c1b6ce9c..894bea60ed 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -3695,6 +3695,33 @@ def test_forward_bitwise_xor():
     verify_model(BitwiseXor2().float().eval(), input_data=[lhs])
 
 
+def test_forward_bitwise_and():
+    """test_forward_bitwise_and"""
+    torch.set_grad_enabled(False)
+
+    class BitwiseAnd1(Module):
+        def forward(self, *args):
+            return torch.bitwise_and(args[0], args[1])
+
+    class BitwiseAnd2(Module):
+        def forward(self, *args):
+            rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
+            if torch.cuda.is_available():
+                rhs = rhs.cuda()
+            return torch.bitwise_and(args[0], rhs)
+
+    lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
+    rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
+    verify_model(BitwiseAnd1().float().eval(), input_data=[lhs, rhs])
+
+    lhs = torch.tensor([True, True, False])
+    rhs = torch.tensor([False, True, False])
+    verify_model(BitwiseAnd1().float().eval(), input_data=[lhs, rhs])
+
+    lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
+    verify_model(BitwiseAnd2().float().eval(), input_data=[lhs])
+
+
 @tvm.testing.uses_gpu
 def test_forward_logical_xor():
     """test_forward_logical_xor"""

Reply via email to