This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 36380389fd [Relax][PyTorch] Refactor binary ops tests (#17672)
36380389fd is described below

commit 36380389fd1fe10f4bcd027688ea657a22f27964
Author: Shushi Hong <[email protected]>
AuthorDate: Wed Feb 26 03:54:26 2025 +0800

    [Relax][PyTorch] Refactor binary ops tests (#17672)
    
    This PR refactors the binary ops tests and combines similar tests.
---
 tests/python/relax/test_frontend_from_fx.py | 693 ++++------------------------
 1 file changed, 93 insertions(+), 600 deletions(-)

diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 8b4ea5c8cc..9e7e1ff2ea 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import operator
 import pytest
 import torch
 import torch.nn.functional as F
@@ -1482,692 +1483,184 @@ def test_groupnorm():
     verify_model(model, input_info, binding, expected1)
 
 
-def test_binary():
+operator_binary_1 = [
+    (operator.add, R.add),
+    (operator.sub, R.subtract),
+    (operator.mul, R.multiply),
+    (operator.truediv, R.divide),
+    (operator.floordiv, R.floor_divide),
+    (operator.pow, R.power),
+    (operator.mod, R.mod),
+]
+
+
[email protected]("op, relax_op", operator_binary_1)
+def test_binary1(op, relax_op):
     input_info1 = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")]
     input_info2 = [([1, 3, 10, 10], "float32")]
-    input_info3 = [([1, 3, 10, 10], "int32"), ([1, 3, 10, 10], "int32")]
-    input_info4 = [([1, 3, 10, 10], "int32")]
 
-    # Add
-    class Add1(Module):
+    class Binary1(Module):
+        def __init__(self, op):
+            super().__init__()
+            self.op = op
+
         def forward(self, lhs, rhs):
-            return lhs + rhs
+            return self.op(lhs, rhs)
 
     @tvm.script.ir_module
-    class expected1:
+    class expected_binary1:
         @R.function
         def main(
             lhs: R.Tensor((1, 3, 10, 10), dtype="float32"),
             rhs: R.Tensor((1, 3, 10, 10), dtype="float32"),
         ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lhs, rhs)
-                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
-                R.output(gv)
-            return gv
-
-    class Add2(Module):
-        def forward(self, lhs):
-            return lhs + 1.0
-
-    @tvm.script.ir_module
-    class expected2:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lhs_1, 
R.const(1.0))
-                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
-                R.output(gv)
-            return gv
-
-    verify_model(Add1(), input_info1, {}, expected1)
-    verify_model(Add2(), input_info2, {}, expected2)
-
-    # Sub
-    class Sub1(Module):
-        def forward(self, lhs, rhs):
-            return lhs - rhs
-
-    @tvm.script.ir_module
-    class expected3:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-            rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.subtract(lhs_1, rhs_1)
-                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
-                R.output(gv)
-            return gv
-
-    class Sub2(Module):
-        def forward(self, lhs):
-            return lhs - 1.0
-
-    @tvm.script.ir_module
-    class expected4:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.subtract(lhs_1, R.const(1.0))
-                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
-                R.output(gv)
-            return gv
-
-    verify_model(Sub1(), input_info1, {}, expected3)
-    verify_model(Sub2(), input_info2, {}, expected4)
-
-    # Mul
-    class Mul1(Module):
-        def forward(self, lhs, rhs):
-            return lhs * rhs
-
-    @tvm.script.ir_module
-    class expected5:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-            rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.multiply(lhs_1, rhs_1)
-                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
-                R.output(gv)
-            return gv
-
-    class Mul2(Module):
-        def forward(self, lhs):
-            return lhs * 1.0
-
-    @tvm.script.ir_module
-    class expected6:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.multiply(lhs_1, R.const(1.0))
-                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
-                R.output(gv)
-            return gv
-
-    verify_model(Mul1(), input_info1, {}, expected5)
-    verify_model(Mul2(), input_info2, {}, expected6)
-
-    # True div
-    class TrueDiv1(Module):
-        def forward(self, lhs, rhs):
-            return lhs / rhs
-
-    @tvm.script.ir_module
-    class expected7:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-            rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.divide(lhs_1, rhs_1)
-                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
-                R.output(gv)
-            return gv
-
-    class TrueDiv2(Module):
-        def forward(self, lhs):
-            return lhs / 1.0
-
-    @tvm.script.ir_module
-    class expected8:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
-            # block 0
             with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.divide(lhs_1, R.const(1.0))
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = relax_op(lhs, 
rhs)
                 gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
                 R.output(gv)
             return gv
 
-    verify_model(TrueDiv1(), input_info1, {}, expected7)
-    verify_model(TrueDiv2(), input_info2, {}, expected8)
-
-    # Floor div
-    class FloorDiv1(Module):
-        def forward(self, lhs, rhs):
-            return lhs // rhs
-
-    @tvm.script.ir_module
-    class expected9:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-            rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.floor_divide(lhs_1, rhs_1)
-                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
-                R.output(gv)
-            return gv
-
-    class FloorDiv2(Module):
-        def forward(self, lhs):
-            return lhs // 1.0
-
-    @tvm.script.ir_module
-    class expected10:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.floor_divide(lhs_1, R.const(1.0))
-                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
-                R.output(gv)
-            return gv
-
-    verify_model(FloorDiv1(), input_info1, {}, expected9)
-    verify_model(FloorDiv2(), input_info2, {}, expected10)
-
-    # Power
-    class Power1(Module):
-        def forward(self, lhs, rhs):
-            return lhs**rhs
-
-    @tvm.script.ir_module
-    class expected11:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-            rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.power(lhs_1, 
rhs_1)
-                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
-                R.output(gv)
-            return gv
-
-    class Power2(Module):
-        def forward(self, lhs):
-            return lhs**1.0
-
-    @tvm.script.ir_module
-    class expected12:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.power(lhs_1, 
R.const(1.0))
-                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
-                R.output(gv)
-            return gv
-
-    verify_model(Power1(), input_info1, {}, expected11)
-    verify_model(Power2(), input_info2, {}, expected12)
-
-    # LT
-    class LT1(Module):
-        def forward(self, lhs, rhs):
-            return lhs < rhs
-
-    @tvm.script.ir_module
-    class expected13:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-            rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-        ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less(lhs_1, 
rhs_1)
-                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
-                R.output(gv)
-            return gv
-
-    class LT2(Module):
-        def forward(self, lhs):
-            return lhs < 1.0
-
-    @tvm.script.ir_module
-    class expected14:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-        ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less(lhs_1, 
R.const(1.0))
-                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
-                R.output(gv)
-            return gv
-
-    verify_model(LT1(), input_info1, {}, expected13)
-    verify_model(LT2(), input_info2, {}, expected14)
-
-    # Mod
-    class Mod1(Module):
-        def forward(self, lhs, rhs):
-            return lhs % rhs
-
-    @tvm.script.ir_module
-    class expected15:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-            rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.mod(lhs_1, 
rhs_1)
-                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
-                R.output(gv)
-            return gv
+    class Binary2(Module):
+        def __init__(self, op):
+            super().__init__()
+            self.op = op
 
-    class Mod2(Module):
         def forward(self, lhs):
-            return lhs % 1.0
+            return self.op(lhs, 1.0)
 
     @tvm.script.ir_module
-    class expected16:
+    class expected_binary2:
         @R.function
         def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            lhs: R.Tensor((1, 3, 10, 10), dtype="float32"),
         ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
-            # block 0
             with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.mod(lhs_1, 
R.const(1.0))
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = relax_op(lhs, 
R.const(1.0))
                 gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
                 R.output(gv)
             return gv
 
-    verify_model(Mod1(), input_info1, {}, expected15)
-    verify_model(Mod2(), input_info2, {}, expected16)
-
-    # Ge
-    class Ge1(Module):
-        def forward(self, lhs, rhs):
-            return lhs >= rhs
-
-    @tvm.script.ir_module
-    class expected17:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-            rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-        ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = 
R.greater_equal(lhs_1, rhs_1)
-                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
-                R.output(gv)
-
-            return gv
-
-    class Ge2(Module):
-        def forward(self, lhs):
-            return lhs >= 1.0
-
-    @tvm.script.ir_module
-    class expected18:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-        ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = 
R.greater_equal(lhs_1, R.const(1.0))
-                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
-                R.output(gv)
-
-            return gv
-
-    verify_model(Ge1(), input_info1, {}, expected17)
-    verify_model(Ge2(), input_info2, {}, expected18)
-
-    # Gt
-    class Gt1(Module):
-        def forward(self, lhs, rhs):
-            return lhs > rhs
-
-    @tvm.script.ir_module
-    class expected19:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-            rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-        ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(lhs_1, 
rhs_1)
-                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
-                R.output(gv)
+    verify_model(Binary1(op), input_info1, {}, expected_binary1)
+    verify_model(Binary2(op), input_info2, {}, expected_binary2)
 
-            return gv
 
-    class Gt2(Module):
-        def forward(self, lhs):
-            return lhs > 1.0
+operator_binary_2 = [
+    (operator.eq, R.equal),
+    (operator.ne, R.not_equal),
+    (operator.lt, R.less),
+    (operator.le, R.less_equal),
+    (operator.gt, R.greater),
+    (operator.ge, R.greater_equal),
+]
 
-    @tvm.script.ir_module
-    class expected20:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-        ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(lhs_1, 
R.const(1.0))
-                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
-                R.output(gv)
 
-            return gv
[email protected]("op, relax_op", operator_binary_2)
+def test_binary2(op, relax_op):
+    input_info1 = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")]
+    input_info2 = [([1, 3, 10, 10], "float32")]
 
-    verify_model(Gt1(), input_info1, {}, expected19)
-    verify_model(Gt2(), input_info2, {}, expected20)
+    class Binary1(Module):
+        def __init__(self, op):
+            super().__init__()
+            self.op = op
 
-    # Le
-    class Le1(Module):
         def forward(self, lhs, rhs):
-            return lhs <= rhs
+            return self.op(lhs, rhs)
 
     @tvm.script.ir_module
-    class expected21:
+    class expected_binary1:
         @R.function
         def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-            rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-        ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = 
R.less_equal(lhs_1, rhs_1)
-                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
-                R.output(gv)
-
-            return gv
-
-    class Le2(Module):
-        def forward(self, lhs):
-            return lhs <= 1.0
-
-    @tvm.script.ir_module
-    class expected22:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            lhs: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            rhs: R.Tensor((1, 3, 10, 10), dtype="float32"),
         ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
-            # block 0
             with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = 
R.less_equal(lhs_1, R.const(1.0))
+                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = relax_op(lhs, rhs)
                 gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
                 R.output(gv)
-
             return gv
 
-    verify_model(Le1(), input_info1, {}, expected21)
-    verify_model(Le2(), input_info2, {}, expected22)
-
-    # Ne
-    class Ne1(Module):
-        def forward(self, lhs, rhs):
-            return lhs != rhs
-
-    @tvm.script.ir_module
-    class expected23:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-            rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-        ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = 
R.not_equal(lhs_1, rhs_1)
-                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
-                R.output(gv)
-
-            return gv
+    class Binary2(Module):
+        def __init__(self, op):
+            super().__init__()
+            self.op = op
 
-    class Ne2(Module):
         def forward(self, lhs):
-            return lhs != 1.0
+            return self.op(lhs, 1.0)
 
     @tvm.script.ir_module
-    class expected24:
+    class expected_binary2:
         @R.function
         def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            lhs: R.Tensor((1, 3, 10, 10), dtype="float32"),
         ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
-            # block 0
             with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = 
R.not_equal(lhs_1, R.const(1.0))
+                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = relax_op(lhs, 
R.const(1.0))
                 gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
                 R.output(gv)
-
             return gv
 
-    verify_model(Ne1(), input_info1, {}, expected23)
-    verify_model(Ne2(), input_info2, {}, expected24)
-
-    # Lshift
-    class LShift1(Module):
-        def forward(self, lhs, rhs):
-            return lhs << rhs
-
-    @tvm.script.ir_module
-    class expected25:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
-            rhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
-        ) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = 
R.left_shift(lhs_1, rhs_1)
-                gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
-                R.output(gv)
-
-            return gv
+    verify_model(Binary1(op), input_info1, {}, expected_binary1)
+    verify_model(Binary2(op), input_info2, {}, expected_binary2)
 
-    class LShift2(Module):
-        def forward(self, lhs):
-            return lhs << 1
 
-    @tvm.script.ir_module
-    class expected26:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
-        ) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = 
R.left_shift(lhs_1, R.const(1))
-                gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
-                R.output(gv)
-
-            return gv
-
-    verify_model(LShift1(), input_info3, {}, expected25)
-    verify_model(LShift2(), input_info4, {}, expected26)
-
-    # Rshift
-    class RShift1(Module):
-        def forward(self, lhs, rhs):
-            return lhs >> rhs
-
-    @tvm.script.ir_module
-    class expected27:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
-            rhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
-        ) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = 
R.right_shift(lhs_1, rhs_1)
-                gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
-                R.output(gv)
-
-            return gv
-
-    class RShift2(Module):
-        def forward(self, lhs):
-            return lhs >> 1
-
-    @tvm.script.ir_module
-    class expected28:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
-        ) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = 
R.right_shift(lhs_1, R.const(1))
-                gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
-                R.output(gv)
-
-            return gv
-
-    verify_model(RShift1(), input_info3, {}, expected27)
-    verify_model(RShift2(), input_info4, {}, expected28)
-
-    # Bitwise and
-    class BitwiseAnd1(Module):
-        def forward(self, lhs, rhs):
-            return lhs & rhs
-
-    @tvm.script.ir_module
-    class expected29:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
-            rhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
-        ) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = 
R.bitwise_and(lhs_1, rhs_1)
-                gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
-                R.output(gv)
-
-            return gv
-
-    class BitwiseAnd2(Module):
-        def forward(self, lhs):
-            return lhs & 1
+operator_binary_3 = [
+    (operator.lshift, R.left_shift),
+    (operator.rshift, R.right_shift),
+    (operator.and_, R.bitwise_and),
+    (operator.or_, R.bitwise_or),
+    (operator.xor, R.bitwise_xor),
+]
 
-    @tvm.script.ir_module
-    class expected30:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
-        ) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = 
R.bitwise_and(lhs_1, R.const(1))
-                gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
-                R.output(gv)
 
-            return gv
[email protected]("op, relax_op", operator_binary_3)
+def test_binary3(op, relax_op):
+    input_info1 = [([1, 3, 10, 10], "int32"), ([1, 3, 10, 10], "int32")]
+    input_info2 = [([1, 3, 10, 10], "int32")]
 
-    verify_model(BitwiseAnd1(), input_info3, {}, expected29)
-    verify_model(BitwiseAnd2(), input_info4, {}, expected30)
+    class Binary1(Module):
+        def __init__(self, op):
+            super().__init__()
+            self.op = op
 
-    # Bitwise or
-    class BitwiseOr1(Module):
         def forward(self, lhs, rhs):
-            return lhs | rhs
+            return self.op(lhs, rhs)
 
     @tvm.script.ir_module
-    class expected31:
+    class expected_binary1:
         @R.function
         def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
-            rhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+            lhs: R.Tensor((1, 3, 10, 10), dtype="int32"),
+            rhs: R.Tensor((1, 3, 10, 10), dtype="int32"),
         ) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
-            # block 0
             with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = 
R.bitwise_or(lhs_1, rhs_1)
+                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = relax_op(lhs, 
rhs)
                 gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
                 R.output(gv)
-
             return gv
 
-    class BitwiseOr2(Module):
-        def forward(self, lhs):
-            return lhs | 1
-
-    @tvm.script.ir_module
-    class expected32:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
-        ) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = 
R.bitwise_or(lhs_1, R.const(1))
-                gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
-                R.output(gv)
-
-            return gv
-
-    verify_model(BitwiseOr1(), input_info3, {}, expected31)
-    verify_model(BitwiseOr2(), input_info4, {}, expected32)
-
-    # Bitwise xor
-    class BitwiseXor1(Module):
-        def forward(self, lhs, rhs):
-            return lhs ^ rhs
-
-    @tvm.script.ir_module
-    class expected33:
-        @R.function
-        def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
-            rhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
-        ) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
-            # block 0
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = 
R.bitwise_xor(lhs_1, rhs_1)
-                gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
-                R.output(gv)
-
-            return gv
+    class Binary2(Module):
+        def __init__(self, op):
+            super().__init__()
+            self.op = op
 
-    class BitwiseXor2(Module):
         def forward(self, lhs):
-            return lhs ^ 1
+            return self.op(lhs, 1)
 
     @tvm.script.ir_module
-    class expected34:
+    class expected_binary2:
         @R.function
         def main(
-            lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+            lhs: R.Tensor((1, 3, 10, 10), dtype="int32"),
         ) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
-            # block 0
             with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = 
R.bitwise_xor(lhs_1, R.const(1))
+                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = relax_op(lhs, 
R.const(1))
                 gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
                 R.output(gv)
-
             return gv
 
-    verify_model(BitwiseXor1(), input_info3, {}, expected33)
-    verify_model(BitwiseXor2(), input_info4, {}, expected34)
+    verify_model(Binary1(op), input_info1, {}, expected_binary1)
+    verify_model(Binary2(op), input_info2, {}, expected_binary2)
 
 
 def test_size():

Reply via email to