This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 45a2a4082e [Relax][PyTorch] Add decomposed operator support for Binary (#18458)
45a2a4082e is described below

commit 45a2a4082e40e38fa6993d48e4bfb1ce45f97520
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Sun Nov 16 00:34:38 2025 +0800

    [Relax][PyTorch] Add decomposed operator support for Binary (#18458)
    
    ## Related Issue
    
    - https://github.com/apache/tvm/pull/18401
    
    ## Why
    
    - When `run_ep_decomposition=True` is enabled, PyTorch decomposes binary
    operators into lower-level operations, some of which were not supported
    by the importer, causing conversion errors
    
    ## How
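    - Added support for `bitwise_and.Tensor`, `bitwise_and.Scalar`,
    `bitwise_xor.Tensor`, `bitwise_xor.Scalar`, `maximum.default` and
    `minimum.default`
    - Updated `test_binary1`, `test_binary2` and `test_binary3` to use
    `run_ep_decomposition=True`, adding separate expected modules for the
    in-place ops (`add_`, `bitwise_or_`, `mul_`)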
---
 .../frontend/torch/exported_program_translator.py  |  6 +++
 .../relax/test_frontend_from_exported_program.py   | 53 ++++++++++++++++++----
 2 files changed, 51 insertions(+), 8 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py 
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 8c1cf80094..2a119e111b 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -898,8 +898,12 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             # binary
             "add.Tensor": self._binary_op(relax.op.add, operator.add),
             "add_.Tensor": self._binary_op(relax.op.add, operator.add),
+            "bitwise_and.Tensor": self._binary_op(relax.op.bitwise_and, 
operator.and_),
+            "bitwise_and.Scalar": self._binary_op(relax.op.bitwise_and, 
operator.and_),
             "bitwise_or_.Scalar": self._binary_op(relax.op.bitwise_or, 
operator.or_),
             "bitwise_or.Scalar": self._binary_op(relax.op.bitwise_or, 
operator.or_),
+            "bitwise_xor.Tensor": self._binary_op(relax.op.bitwise_xor, 
operator.xor),
+            "bitwise_xor.Scalar": self._binary_op(relax.op.bitwise_xor, 
operator.xor),
             "bitwise_or_.Tensor": self._binary_op(relax.op.bitwise_or, 
operator.or_),
             "bitwise_or.Tensor": self._binary_op(relax.op.bitwise_or, 
operator.or_),
             "div.Scalar": self._binary_op(relax.op.divide, operator.truediv),
@@ -929,6 +933,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "min.other": self._binary_op(relax.op.minimum, min),
             "max.default": self._unary_op(relax.op.max),
             "min.default": self._unary_op(relax.op.min),
+            "maximum.default": self._binary_op(relax.op.maximum, 
torch.maximum),
+            "minimum.default": self._binary_op(relax.op.minimum, 
torch.minimum),
             "remainder.Tensor": self._binary_op(relax.op.floor_mod, 
operator.mod),
             "remainder.Scalar": self._binary_op(relax.op.floor_mod, 
operator.mod),
             "mul.Tensor": self._binary_op(relax.op.multiply, operator.mul),
diff --git a/tests/python/relax/test_frontend_from_exported_program.py 
b/tests/python/relax/test_frontend_from_exported_program.py
index 4bf0417108..f571ee1fd9 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1291,6 +1291,21 @@ def test_binary1(op, relax_op):
                 R.output(gv)
             return gv
 
+    @tvm.script.ir_module
+    class expected_binary1_inplace:
+        @R.function
+        def main(
+            lhs: R.Tensor((10, 10), dtype="float32"),
+            rhs: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), 
dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = relax_op(lhs, rhs)
+                gv: R.Tuple(
+                    R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), 
dtype="float32")
+                ) = (lv, lv)
+                R.output(gv)
+            return gv
+
     class Binary2(Module):
         def __init__(self, op):
             super().__init__()
@@ -1311,8 +1326,30 @@ def test_binary1(op, relax_op):
                 R.output(gv)
             return gv
 
-    verify_model(Binary1(op), example_args1, {}, expected_binary1)
-    verify_model(Binary2(op), example_args2, {}, expected_binary2)
+    @tvm.script.ir_module
+    class expected_binary2_inplace:
+        @R.function
+        def main(
+            lhs: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), 
dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = relax_op(lhs, 
R.const(1.0))
+                gv: R.Tuple(
+                    R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), 
dtype="float32")
+                ) = (lv, lv)
+                R.output(gv)
+            return gv
+
+    inplace_ops = [
+        torch.ops.aten.add_,
+        torch.ops.aten.bitwise_or_,
+        torch.ops.aten.mul_,
+    ]
+
+    expected1 = expected_binary1_inplace if op in inplace_ops else 
expected_binary1
+    expected2 = expected_binary2_inplace if op in inplace_ops else 
expected_binary2
+    verify_model(Binary1(op), example_args1, {}, expected1, 
run_ep_decomposition=True)
+    verify_model(Binary2(op), example_args2, {}, expected2, 
run_ep_decomposition=True)
 
 
 operator_binary_2 = [
@@ -1374,8 +1411,8 @@ def test_binary2(op, relax_op):
                 R.output(gv)
             return gv
 
-    verify_model(Binary1(op), example_args1, {}, expected_binary1)
-    verify_model(Binary2(op), example_args2, {}, expected_binary2)
+    verify_model(Binary1(op), example_args1, {}, expected_binary1, 
run_ep_decomposition=True)
+    verify_model(Binary2(op), example_args2, {}, expected_binary2, 
run_ep_decomposition=True)
 
 
 def test_binary3():
@@ -1403,7 +1440,7 @@ def test_binary3():
                 R.output(gv)
             return gv
 
-    verify_model(Max1(), example_args1, {}, expected_max1)
+    verify_model(Max1(), example_args1, {}, expected_max1, 
run_ep_decomposition=True)
 
     # Min
     class Min1(Module):
@@ -1423,7 +1460,7 @@ def test_binary3():
                 R.output(gv)
             return gv
 
-    verify_model(Min1(), example_args1, {}, expected_min1)
+    verify_model(Min1(), example_args1, {}, expected_min1, 
run_ep_decomposition=True)
 
     # RSub
     class RSub1(Module):
@@ -1458,8 +1495,8 @@ def test_binary3():
                 R.output(gv)
             return gv
 
-    verify_model(RSub1(), example_args1, {}, expected_rsub1)
-    verify_model(RSub2(), example_args2, {}, expected_rsub2)
+    verify_model(RSub1(), example_args1, {}, expected_rsub1, 
run_ep_decomposition=True)
+    verify_model(RSub2(), example_args2, {}, expected_rsub2, 
run_ep_decomposition=True)
 
 
 # IsIn

Reply via email to