This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7d38cf25c5 [Relax][PyTorch] Support several binary ops for ExportedProgram importer (#17689)
7d38cf25c5 is described below

commit 7d38cf25c52c2930e6a204258bc59f93c8af6644
Author: Shushi Hong <[email protected]>
AuthorDate: Fri Feb 28 14:08:08 2025 +0800

    [Relax][PyTorch] Support several binary ops for ExportedProgram importer (#17689)
    
    * Update exported_program_translator.py
    
    * Update test_frontend_from_exported_program.py
    
    * Update test_frontend_from_exported_program.py
---
 .../frontend/torch/exported_program_translator.py  |  17 +
 .../relax/test_frontend_from_exported_program.py   | 359 +++++----------------
 2 files changed, 103 insertions(+), 273 deletions(-)
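
For reference, here is a minimal sketch of how the newly supported operators
reach the importer (assuming a TVM build that includes this commit; the
module, shapes, and variable names below are illustrative, not part of the
patch):

    import torch
    from torch.export import export
    from tvm.relax.frontend.torch import from_exported_program

    class GreaterEqual(torch.nn.Module):
        def forward(self, lhs, rhs):
            # Traced as aten ge.Tensor, which this commit maps to
            # relax.op.greater_equal in the converter table.
            return lhs >= rhs

    example_args = (torch.randn(10, 10), torch.randn(10, 10))
    exported_program = export(GreaterEqual(), example_args)
    mod = from_exported_program(exported_program)  # Relax IRModule
    mod.show()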

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 0acc6ec1a0..c8d9d12505 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -204,16 +204,33 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "eq.Scalar": self._binary_op(relax.op.equal, operator.eq),
             "eq.Tensor": self._binary_op(relax.op.equal, operator.eq),
             "floor_divide.default": self._binary_op(relax.op.floor_divide, 
operator.floordiv),
+            "ge.Scalar": self._binary_op(relax.op.greater_equal, operator.ge),
+            "ge.Tensor": self._binary_op(relax.op.greater_equal, operator.ge),
+            "gt.Scalar": self._binary_op(relax.op.greater, operator.gt),
+            "gt.Tensor": self._binary_op(relax.op.greater, operator.gt),
+            "le.Scalar": self._binary_op(relax.op.less_equal, operator.le),
+            "le.Tensor": self._binary_op(relax.op.less_equal, operator.le),
             "lt.Scalar": self._binary_op(relax.op.less, operator.lt),
             "lt.Tensor": self._binary_op(relax.op.less, operator.lt),
             "matmul.default": self._binary_op(
                partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul
             ),
             "max.other": self._binary_op(relax.op.maximum, max),
+            "min.other": self._binary_op(relax.op.minimum, min),
+            "remainder.Tensor": self._binary_op(relax.op.mod, operator.mod),
+            "remainder.Scalar": self._binary_op(relax.op.mod, operator.mod),
             "mul.Tensor": self._binary_op(relax.op.multiply, operator.mul),
+            "ne.Tensor": self._binary_op(relax.op.not_equal, operator.ne),
+            "ne.Scalar": self._binary_op(relax.op.not_equal, operator.ne),
             "pow.Tensor_Scalar": self._binary_op(relax.op.power, operator.pow),
             "pow.Tensor_Tensor": self._binary_op(relax.op.power, operator.pow),
             "sub.Tensor": self._binary_op(relax.op.subtract, operator.sub),
+            "__and__.Tensor": self._binary_op(relax.op.bitwise_and, operator.and_),
+            "__and__.Scalar": self._binary_op(relax.op.bitwise_and, operator.and_),
+            "__or__.Tensor": self._binary_op(relax.op.bitwise_or, operator.or_),
+            "__or__.Scalar": self._binary_op(relax.op.bitwise_or, operator.or_),
+            "__xor__.Tensor": self._binary_op(relax.op.bitwise_xor, operator.xor),
+            "__xor__.Scalar": self._binary_op(relax.op.bitwise_xor, operator.xor),
             # neural network
             "_native_batch_norm_legit_no_training.default": 
self._batch_norm_legit_no_training,
             "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 52cdc12bb7..8ca335c2fe 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import operator
 import pytest
 import torch
 from torch.nn import Module
@@ -542,233 +543,142 @@ def test_tril_triu():
     verify_model(Triu(), example_args, {}, expected_triu)
 
 
-def test_binary():
+operator_binary_1 = [
+    (operator.add, R.add),
+    (operator.sub, R.subtract),
+    (operator.mul, R.multiply),
+    (operator.truediv, R.divide),
+    (operator.floordiv, R.floor_divide),
+    (operator.pow, R.power),
+    (operator.mod, R.mod),
+    (operator.and_, R.bitwise_and),
+    (operator.or_, R.bitwise_or),
+    (operator.xor, R.bitwise_xor),
+]
+
+
[email protected]("op, relax_op", operator_binary_1)
+def test_binary1(op, relax_op):
     example_args1 = (
         torch.randn(10, 10, dtype=torch.float32),
         torch.randn(10, 10, dtype=torch.float32),
     )
     example_args2 = (torch.randn(10, 10, dtype=torch.float32),)
 
-    # Add
-    class Add1(Module):
+    class Binary1(Module):
+        def __init__(self, op):
+            super().__init__()
+            self.op = op
+
         def forward(self, lhs, rhs):
-            return lhs + rhs
+            return self.op(lhs, rhs)
 
     @tvm.script.ir_module
-    class expected_add1:
+    class expected_binary1:
         @R.function
         def main(
             lhs: R.Tensor((10, 10), dtype="float32"),
             rhs: R.Tensor((10, 10), dtype="float32"),
         ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
-            # block 0
             with R.dataflow():
-                lv: R.Tensor((10, 10), dtype="float32") = R.add(lhs, rhs)
+                lv: R.Tensor((10, 10), dtype="float32") = relax_op(lhs, rhs)
                 gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
                 R.output(gv)
             return gv
 
-    class Add2(Module):
-        def forward(self, lhs):
-            return lhs + 1.0
-
-    @tvm.script.ir_module
-    class expected_add2:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((10, 10), dtype="float32"),
-        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((10, 10), dtype="float32") = R.add(lhs_1, R.const(1.0))
-                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
-                R.output(gv)
-            return gv
-
-    verify_model(Add1(), example_args1, {}, expected_add1)
-    verify_model(Add2(), example_args2, {}, expected_add2)
-
-    # True div
-    class TrueDiv1(Module):
-        def forward(self, lhs, rhs):
-            return lhs / rhs
-
-    @tvm.script.ir_module
-    class expected_truediv1:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((10, 10), dtype="float32"),
-            rhs_1: R.Tensor((10, 10), dtype="float32"),
-        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((10, 10), dtype="float32") = R.divide(lhs_1, rhs_1)
-                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
-                R.output(gv)
-            return gv
+    class Binary2(Module):
+        def __init__(self, op):
+            super().__init__()
+            self.op = op
 
-    class TrueDiv2(Module):
         def forward(self, lhs):
-            return lhs / 1.0
+            return self.op(lhs, 1.0)
 
     @tvm.script.ir_module
-    class expected_truediv2:
+    class expected_binary2:
         @R.function
         def main(
-            lhs_1: R.Tensor((10, 10), dtype="float32"),
+            lhs: R.Tensor((10, 10), dtype="float32"),
         ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
-            # block 0
             with R.dataflow():
-                lv: R.Tensor((10, 10), dtype="float32") = R.divide(lhs_1, R.const(1.0))
+                lv: R.Tensor((10, 10), dtype="float32") = relax_op(lhs, R.const(1.0))
                 gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
                 R.output(gv)
             return gv
 
-    verify_model(TrueDiv1(), example_args1, {}, expected_truediv1)
-    verify_model(TrueDiv2(), example_args2, {}, expected_truediv2)
-
-    # EQ
-    class EQ1(Module):
-        def forward(self, lhs, rhs):
-            return lhs == rhs
+    verify_model(Binary1(op), example_args1, {}, expected_binary1)
+    verify_model(Binary2(op), example_args2, {}, expected_binary2)
 
-    @tvm.script.ir_module
-    class expected_eq1:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((10, 10), dtype="float32"),
-            rhs_1: R.Tensor((10, 10), dtype="float32"),
-        ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((10, 10), dtype="bool") = R.equal(lhs_1, rhs_1)
-                gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv,)
-                R.output(gv)
-            return gv
-
-    class EQ2(Module):
-        def forward(self, lhs):
-            return lhs == 1.0
 
-    @tvm.script.ir_module
-    class expected_eq2:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((10, 10), dtype="float32"),
-        ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((10, 10), dtype="bool") = R.equal(lhs_1, R.const(1.0))
-                gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv,)
-                R.output(gv)
-            return gv
-
-    verify_model(EQ1(), example_args1, {}, expected_eq1)
-    verify_model(EQ2(), example_args2, {}, expected_eq2)
-
-    # Floor div
-    class FloorDiv1(Module):
-        def forward(self, lhs, rhs):
-            return lhs // rhs
+operator_binary_2 = [
+    (operator.eq, R.equal),
+    (operator.ne, R.not_equal),
+    (operator.lt, R.less),
+    (operator.le, R.less_equal),
+    (operator.gt, R.greater),
+    (operator.ge, R.greater_equal),
+]
 
-    @tvm.script.ir_module
-    class expected_floordiv1:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((10, 10), dtype="float32"),
-            rhs_1: R.Tensor((10, 10), dtype="float32"),
-        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((10, 10), dtype="float32") = R.floor_divide(lhs_1, rhs_1)
-                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
-                R.output(gv)
-            return gv
 
-    class FloorDiv2(Module):
-        def forward(self, lhs):
-            return lhs // 1.0
-
-    @tvm.script.ir_module
-    class expected_floordiv2:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((10, 10), dtype="float32"),
-        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((10, 10), dtype="float32") = R.floor_divide(lhs_1, R.const(1.0))
-                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
-                R.output(gv)
-            return gv
[email protected]("op, relax_op", operator_binary_2)
+def test_binary2(op, relax_op):
+    example_args1 = (
+        torch.randn(10, 10, dtype=torch.float32),
+        torch.randn(10, 10, dtype=torch.float32),
+    )
+    example_args2 = (torch.randn(10, 10, dtype=torch.float32),)
 
-    verify_model(FloorDiv1(), example_args1, {}, expected_floordiv1)
-    verify_model(FloorDiv2(), example_args2, {}, expected_floordiv2)
+    class Binary1(Module):
+        def __init__(self, op):
+            super().__init__()
+            self.op = op
 
-    # LT
-    class LT1(Module):
         def forward(self, lhs, rhs):
-            return lhs < rhs
+            return self.op(lhs, rhs)
 
     @tvm.script.ir_module
-    class expected_lt1:
+    class expected_binary1:
         @R.function
         def main(
-            lhs_1: R.Tensor((10, 10), dtype="float32"),
-            rhs_1: R.Tensor((10, 10), dtype="float32"),
+            lhs: R.Tensor((10, 10), dtype="float32"),
+            rhs: R.Tensor((10, 10), dtype="float32"),
         ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")):
-            # block 0
             with R.dataflow():
-                lv: R.Tensor((10, 10), dtype="bool") = R.less(lhs_1, rhs_1)
+                lv: R.Tensor((10, 10), dtype="bool") = relax_op(lhs, rhs)
                 gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv,)
                 R.output(gv)
             return gv
 
-    class LT2(Module):
+    class Binary2(Module):
+        def __init__(self, op):
+            super().__init__()
+            self.op = op
+
         def forward(self, lhs):
-            return lhs < 1.0
+            return self.op(lhs, 1.0)
 
     @tvm.script.ir_module
-    class expected_lt2:
+    class expected_binary2:
         @R.function
         def main(
-            lhs_1: R.Tensor((10, 10), dtype="float32"),
+            lhs: R.Tensor((10, 10), dtype="float32"),
         ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")):
-            # block 0
             with R.dataflow():
-                lv: R.Tensor((10, 10), dtype="bool") = R.less(lhs_1, R.const(1.0))
+                lv: R.Tensor((10, 10), dtype="bool") = relax_op(lhs, R.const(1.0))
                 gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv,)
                 R.output(gv)
             return gv
 
-    verify_model(LT1(), example_args1, {}, expected_lt1)
-    verify_model(LT2(), example_args2, {}, expected_lt2)
-
-    # MatMul
-    class MatMul1(Module):
-        def __init__(self):
-            super().__init__()
+    verify_model(Binary1(op), example_args1, {}, expected_binary1)
+    verify_model(Binary2(op), example_args2, {}, expected_binary2)
 
-        def forward(self, x, y):
-            return torch.matmul(x, y)
 
-    @tvm.script.ir_module
-    class expected_matmul1:
-        @R.function
-        def main(
-            input_1: R.Tensor((10, 10), dtype="float32"),
-            input_2: R.Tensor((10, 10), dtype="float32"),
-        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((10, 10), dtype="float32") = R.matmul(
-                    input_1, input_2, out_dtype="float32"
-                )
-                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
-                R.output(gv)
-            return gv
-
-    verify_model(MatMul1(), example_args1, {}, expected_matmul1)
+def test_binary3():
+    example_args1 = (
+        torch.randn(10, 10, dtype=torch.float32),
+        torch.randn(10, 10, dtype=torch.float32),
+    )
+    example_args2 = (torch.randn(10, 10, dtype=torch.float32),)
 
     # Max
     class Max1(Module):
@@ -790,122 +700,25 @@ def test_binary():
 
     verify_model(Max1(), example_args1, {}, expected_max1)
 
-    # Mul
-    class Mul1(Module):
-        def forward(self, lhs, rhs):
-            return lhs * rhs
-
-    @tvm.script.ir_module
-    class expected_mul1:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((10, 10), dtype="float32"),
-            rhs_1: R.Tensor((10, 10), dtype="float32"),
-        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((10, 10), dtype="float32") = R.multiply(lhs_1, rhs_1)
-                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
-                R.output(gv)
-            return gv
-
-    class Mul2(Module):
-        def forward(self, lhs):
-            return lhs * 1.0
-
-    @tvm.script.ir_module
-    class expected_mul2:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((10, 10), dtype="float32"),
-        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((10, 10), dtype="float32") = R.multiply(lhs_1, R.const(1.0))
-                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
-                R.output(gv)
-            return gv
-
-    verify_model(Mul1(), example_args1, {}, expected_mul1)
-    verify_model(Mul2(), example_args2, {}, expected_mul2)
-
-    # Power
-    class Power1(Module):
-        def forward(self, lhs, rhs):
-            return lhs**rhs
-
-    @tvm.script.ir_module
-    class expected_power1:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((10, 10), dtype="float32"),
-            rhs_1: R.Tensor((10, 10), dtype="float32"),
-        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((10, 10), dtype="float32") = R.power(lhs_1, rhs_1)
-                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
-                R.output(gv)
-            return gv
-
-    class Power2(Module):
-        def forward(self, lhs):
-            return lhs**1.0
-
-    @tvm.script.ir_module
-    class expected_power2:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((10, 10), dtype="float32"),
-        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((10, 10), dtype="float32") = R.power(lhs_1, R.const(1.0))
-                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
-                R.output(gv)
-            return gv
-
-    verify_model(Power1(), example_args1, {}, expected_power1)
-    verify_model(Power2(), example_args2, {}, expected_power2)
-
-    # Sub
-    class Sub1(Module):
-        def forward(self, lhs, rhs):
-            return lhs - rhs
-
-    @tvm.script.ir_module
-    class expected_sub1:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((10, 10), dtype="float32"),
-            rhs_1: R.Tensor((10, 10), dtype="float32"),
-        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((10, 10), dtype="float32") = R.subtract(lhs_1, rhs_1)
-                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
-                R.output(gv)
-            return gv
-
-    class Sub2(Module):
-        def forward(self, lhs):
-            return lhs - 1.0
+    # Min
+    class Min1(Module):
+        def forward(self, x, y):
+            return torch.min(x, y)
 
-    @tvm.script.ir_module
-    class expected_sub2:
+    @I.ir_module
+    class expected_min1:
         @R.function
         def main(
-            lhs_1: R.Tensor((10, 10), dtype="float32"),
+            inp_0: R.Tensor((10, 10), dtype="float32"),
+            inp_1: R.Tensor((10, 10), dtype="float32"),
         ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
-            # block 0
             with R.dataflow():
-                lv: R.Tensor((10, 10), dtype="float32") = R.subtract(lhs_1, R.const(1.0))
+                lv: R.Tensor((10, 10), dtype="float32") = R.minimum(inp_0, inp_1)
                 gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
                 R.output(gv)
             return gv
 
-    verify_model(Sub1(), example_args1, {}, expected_sub1)
-    verify_model(Sub2(), example_args2, {}, expected_sub2)
+    verify_model(Min1(), example_args1, {}, expected_min1)
 
 
 @pytest.mark.skipif(

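To run just the binary-op tests touched by this change locally, an invocation
along these lines should select them (hypothetical command, not part of the
patch):

    pytest tests/python/relax/test_frontend_from_exported_program.py -k "binary"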