This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 45a2a4082e [Relax][PyTorch] Add decomposed operator support for Binary
(#18458)
45a2a4082e is described below
commit 45a2a4082e40e38fa6993d48e4bfb1ce45f97520
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Sun Nov 16 00:34:38 2025 +0800
[Relax][PyTorch] Add decomposed operator support for Binary (#18458)
## Related Issue
- https://github.com/apache/tvm/pull/18401
## Why
- When `run_ep_decomposition=True` is enabled, PyTorch decomposes binary
operators into lower-level operations, some of which are not
supported, causing errors
## How
- Added support for `bitwise_and.Tensor`, `bitwise_and.Scalar`,
`bitwise_xor.Tensor` and `bitwise_xor.Scalar`
- Updated `test_binary` to use `run_ep_decomposition=True`
---
.../frontend/torch/exported_program_translator.py | 6 +++
.../relax/test_frontend_from_exported_program.py | 53 ++++++++++++++++++----
2 files changed, 51 insertions(+), 8 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 8c1cf80094..2a119e111b 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -898,8 +898,12 @@ class ExportedProgramImporter(BaseFXGraphImporter):
# binary
"add.Tensor": self._binary_op(relax.op.add, operator.add),
"add_.Tensor": self._binary_op(relax.op.add, operator.add),
+ "bitwise_and.Tensor": self._binary_op(relax.op.bitwise_and,
operator.and_),
+ "bitwise_and.Scalar": self._binary_op(relax.op.bitwise_and,
operator.and_),
"bitwise_or_.Scalar": self._binary_op(relax.op.bitwise_or,
operator.or_),
"bitwise_or.Scalar": self._binary_op(relax.op.bitwise_or,
operator.or_),
+ "bitwise_xor.Tensor": self._binary_op(relax.op.bitwise_xor,
operator.xor),
+ "bitwise_xor.Scalar": self._binary_op(relax.op.bitwise_xor,
operator.xor),
"bitwise_or_.Tensor": self._binary_op(relax.op.bitwise_or,
operator.or_),
"bitwise_or.Tensor": self._binary_op(relax.op.bitwise_or,
operator.or_),
"div.Scalar": self._binary_op(relax.op.divide, operator.truediv),
@@ -929,6 +933,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"min.other": self._binary_op(relax.op.minimum, min),
"max.default": self._unary_op(relax.op.max),
"min.default": self._unary_op(relax.op.min),
+ "maximum.default": self._binary_op(relax.op.maximum,
torch.maximum),
+ "minimum.default": self._binary_op(relax.op.minimum,
torch.minimum),
"remainder.Tensor": self._binary_op(relax.op.floor_mod,
operator.mod),
"remainder.Scalar": self._binary_op(relax.op.floor_mod,
operator.mod),
"mul.Tensor": self._binary_op(relax.op.multiply, operator.mul),
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 4bf0417108..f571ee1fd9 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1291,6 +1291,21 @@ def test_binary1(op, relax_op):
R.output(gv)
return gv
+ @tvm.script.ir_module
+ class expected_binary1_inplace:
+ @R.function
+ def main(
+ lhs: R.Tensor((10, 10), dtype="float32"),
+ rhs: R.Tensor((10, 10), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10),
dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((10, 10), dtype="float32") = relax_op(lhs, rhs)
+ gv: R.Tuple(
+ R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10),
dtype="float32")
+ ) = (lv, lv)
+ R.output(gv)
+ return gv
+
class Binary2(Module):
def __init__(self, op):
super().__init__()
@@ -1311,8 +1326,30 @@ def test_binary1(op, relax_op):
R.output(gv)
return gv
- verify_model(Binary1(op), example_args1, {}, expected_binary1)
- verify_model(Binary2(op), example_args2, {}, expected_binary2)
+ @tvm.script.ir_module
+ class expected_binary2_inplace:
+ @R.function
+ def main(
+ lhs: R.Tensor((10, 10), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10),
dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((10, 10), dtype="float32") = relax_op(lhs,
R.const(1.0))
+ gv: R.Tuple(
+ R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10),
dtype="float32")
+ ) = (lv, lv)
+ R.output(gv)
+ return gv
+
+ inplace_ops = [
+ torch.ops.aten.add_,
+ torch.ops.aten.bitwise_or_,
+ torch.ops.aten.mul_,
+ ]
+
+ expected1 = expected_binary1_inplace if op in inplace_ops else
expected_binary1
+ expected2 = expected_binary2_inplace if op in inplace_ops else
expected_binary2
+ verify_model(Binary1(op), example_args1, {}, expected1,
run_ep_decomposition=True)
+ verify_model(Binary2(op), example_args2, {}, expected2,
run_ep_decomposition=True)
operator_binary_2 = [
@@ -1374,8 +1411,8 @@ def test_binary2(op, relax_op):
R.output(gv)
return gv
- verify_model(Binary1(op), example_args1, {}, expected_binary1)
- verify_model(Binary2(op), example_args2, {}, expected_binary2)
+ verify_model(Binary1(op), example_args1, {}, expected_binary1,
run_ep_decomposition=True)
+ verify_model(Binary2(op), example_args2, {}, expected_binary2,
run_ep_decomposition=True)
def test_binary3():
@@ -1403,7 +1440,7 @@ def test_binary3():
R.output(gv)
return gv
- verify_model(Max1(), example_args1, {}, expected_max1)
+ verify_model(Max1(), example_args1, {}, expected_max1,
run_ep_decomposition=True)
# Min
class Min1(Module):
@@ -1423,7 +1460,7 @@ def test_binary3():
R.output(gv)
return gv
- verify_model(Min1(), example_args1, {}, expected_min1)
+ verify_model(Min1(), example_args1, {}, expected_min1,
run_ep_decomposition=True)
# RSub
class RSub1(Module):
@@ -1458,8 +1495,8 @@ def test_binary3():
R.output(gv)
return gv
- verify_model(RSub1(), example_args1, {}, expected_rsub1)
- verify_model(RSub2(), example_args2, {}, expected_rsub2)
+ verify_model(RSub1(), example_args1, {}, expected_rsub1,
run_ep_decomposition=True)
+ verify_model(RSub2(), example_args2, {}, expected_rsub2,
run_ep_decomposition=True)
# IsIn