This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fe1b228210 [Relax][Pytorch] Add support for bitwise_or op (#17871)
fe1b228210 is described below
commit fe1b228210da0f04a084da5bece1013d3633cf6c
Author: kavin-mcw <[email protected]>
AuthorDate: Mon Apr 21 20:32:49 2025 +0530
[Relax][Pytorch] Add support for bitwise_or op (#17871)
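This PR adds support for the bitwise OR operation (`bitwise_or` and its
in-place variant `bitwise_or_`, in both Scalar and Tensor overloads) to the
Relax PyTorch frontends (exported-program and FX translators). The op is used
in the Mistral/Mistral-3B-Instruct model.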
---
python/tvm/relax/frontend/torch/exported_program_translator.py | 4 ++++
python/tvm/relax/frontend/torch/fx_translator.py | 2 ++
tests/python/relax/test_frontend_from_exported_program.py | 2 ++
tests/python/relax/test_frontend_from_fx.py | 2 ++
4 files changed, 10 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index ab55ded36c..ed6740a25e 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -328,6 +328,10 @@ class ExportedProgramImporter(BaseFXGraphImporter):
# binary
"add.Tensor": self._binary_op(relax.op.add, operator.add),
"add_.Tensor": self._binary_op(relax.op.add, operator.add),
+ "bitwise_or_.Scalar": self._binary_op(relax.op.bitwise_or,
operator.or_),
+ "bitwise_or.Scalar": self._binary_op(relax.op.bitwise_or,
operator.or_),
+ "bitwise_or_.Tensor": self._binary_op(relax.op.bitwise_or,
operator.or_),
+ "bitwise_or.Tensor": self._binary_op(relax.op.bitwise_or,
operator.or_),
"div.Tensor": self._binary_op(relax.op.divide, operator.truediv),
"eq.Scalar": self._binary_op(relax.op.equal, operator.eq),
"eq.Tensor": self._binary_op(relax.op.equal, operator.eq),
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index ed42f995bb..548320bd85 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -679,6 +679,8 @@ class TorchFXImporter(BaseFXGraphImporter):
# binary
"add": self._binary_op(relax.op.add, operator.add),
"and_": self._binary_op(relax.op.bitwise_and, operator.and_),
+ "bitwise_or_": self._binary_op(relax.op.bitwise_or, operator.or_),
+ "bitwise_or": self._binary_op(relax.op.bitwise_or, operator.or_),
"eq": self._binary_op(relax.op.equal, operator.eq),
"floordiv": self._binary_op(relax.op.floor_divide,
operator.floordiv),
"ge": self._binary_op(relax.op.greater_equal, operator.ge),
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index ce68089048..a386a989f0 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -845,6 +845,8 @@ def test_tril_triu():
operator_binary_1 = [
(operator.add, R.add),
(torch.ops.aten.add_, R.add),
+ (torch.ops.aten.bitwise_or, R.bitwise_or),
+ (torch.ops.aten.bitwise_or_, R.bitwise_or),
(operator.sub, R.subtract),
(operator.mul, R.multiply),
(torch.ops.aten.mul_, R.multiply),
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 2498fec35c..e8db6af347 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1769,6 +1769,8 @@ def test_binary2(op, relax_op):
operator_binary_3 = [
+ (torch.ops.aten.bitwise_or_, R.bitwise_or),
+ (torch.ops.aten.bitwise_or, R.bitwise_or),
(operator.lshift, R.left_shift),
(operator.rshift, R.right_shift),
(operator.and_, R.bitwise_and),