This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 95d769e738 [Relay][Pytorch] Add support for `aten::bitwise_and` (#16105)
95d769e738 is described below
commit 95d769e7381c50ed25477ba985bbd0e1b0c8f1b1
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Sat Nov 11 04:18:09 2023 +0900
[Relay][Pytorch] Add support for `aten::bitwise_and` (#16105)
add support for aten::bitwise_and
---
python/tvm/relay/frontend/pytorch.py | 9 +++++++++
tests/python/frontend/pytorch/test_forward.py | 27 +++++++++++++++++++++++++++
2 files changed, 36 insertions(+)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 402ab59202..bdfd8f78b2 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2327,6 +2327,14 @@ class PyTorchOpConverter:
return _op.bitwise_xor(lhs, rhs)
+ def bitwise_and(self, inputs, input_types):
+ lhs = inputs[0]
+ rhs = inputs[1]
+ lhs = _op.cast(lhs, "bool") if input_types[0] == "bool" else _op.cast(lhs, "int")
+ rhs = _op.cast(rhs, "bool") if input_types[1] == "bool" else _op.cast(rhs, "int")
+
+ return _op.bitwise_and(lhs, rhs)
+
def logical_not(self, inputs, input_types):
data = _wrap_const(inputs[0])
return _op.logical_not(_op.cast(data, "bool"))
@@ -4033,6 +4041,7 @@ class PyTorchOpConverter:
"aten::logical_xor": self.logical_xor,
"aten::bitwise_not": self.bitwise_not,
"aten::bitwise_xor": self.bitwise_xor,
+ "aten::bitwise_and": self.bitwise_and,
"aten::Bool": self.Bool,
"aten::Float": self.Float,
"aten::rsub": self.rsub,
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index b9c1b6ce9c..894bea60ed 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -3695,6 +3695,33 @@ def test_forward_bitwise_xor():
verify_model(BitwiseXor2().float().eval(), input_data=[lhs])
+def test_forward_bitwise_and():
+ """test_forward_bitwise_and"""
+ torch.set_grad_enabled(False)
+
+ class BitwiseAnd1(Module):
+ def forward(self, *args):
+ return torch.bitwise_and(args[0], args[1])
+
+ class BitwiseAnd2(Module):
+ def forward(self, *args):
+ rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
+ if torch.cuda.is_available():
+ rhs = rhs.cuda()
+ return torch.bitwise_and(args[0], rhs)
+
+ lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
+ rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
+ verify_model(BitwiseAnd1().float().eval(), input_data=[lhs, rhs])
+
+ lhs = torch.tensor([True, True, False])
+ rhs = torch.tensor([False, True, False])
+ verify_model(BitwiseAnd1().float().eval(), input_data=[lhs, rhs])
+
+ lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
+ verify_model(BitwiseAnd2().float().eval(), input_data=[lhs])
+
+
@tvm.testing.uses_gpu
def test_forward_logical_xor():
"""test_forward_logical_xor"""