This is an automated email from the ASF dual-hosted git repository.

yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f4cdb3f3ef [Relax][PyTorch] Add support for ge, gt, le, mod, ne ops (#17664)
f4cdb3f3ef is described below

commit f4cdb3f3ef130a6a9dbe005ef420d953f4150213
Author: Shushi Hong <[email protected]>
AuthorDate: Wed Feb 19 05:34:22 2025 +0800

    [Relax][PyTorch] Add support for ge, gt, le, mod, ne ops (#17664)
    
    * Update fx_translator.py
    
    * Update test_frontend_from_fx.py
---
 python/tvm/relax/frontend/torch/fx_translator.py |   5 +
 tests/python/relax/test_frontend_from_fx.py      | 290 ++++++++++++++++++-----
 2 files changed, 238 insertions(+), 57 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 724bb3fc20..d49cfa6893 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -662,13 +662,18 @@ class TorchFXImporter(BaseFXGraphImporter):
             "add": self._binary_op(relax.op.add, operator.add),
             "eq": self._binary_op(relax.op.equal, operator.eq),
             "floordiv": self._binary_op(relax.op.floor_divide, operator.floordiv),
+            "ge": self._binary_op(relax.op.greater_equal, operator.ge),
+            "gt": self._binary_op(relax.op.greater, operator.gt),
             "iadd": self._binary_op(relax.op.add, operator.add),
+            "le": self._binary_op(relax.op.less_equal, operator.le),
             "lt": self._binary_op(relax.op.less, operator.lt),
             "matmul": self._binary_op(
                 partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul
             ),
             "max": self._binary_op(relax.op.maximum, max),
+            "mod": self._binary_op(relax.op.mod, operator.mod),
             "mul": self._binary_op(relax.op.multiply, operator.mul),
+            "ne": self._binary_op(relax.op.not_equal, operator.ne),
             "pow": self._binary_op(relax.op.power, operator.pow),
             "sub": self._binary_op(relax.op.subtract, operator.sub),
             "truediv": self._binary_op(relax.op.divide, operator.truediv),
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 19bc15b192..371343b60a 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1759,6 +1759,209 @@ def test_binary():
     verify_model(LT1(), input_info1, {}, expected13)
     verify_model(LT2(), input_info2, {}, expected14)
 
+    # Mod
+    class Mod1(Module):
+        def forward(self, lhs, rhs):
+            return lhs % rhs
+
+    @tvm.script.ir_module
+    class expected15:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.mod(lhs_1, rhs_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    class Mod2(Module):
+        def forward(self, lhs):
+            return lhs % 1.0
+
+    @tvm.script.ir_module
+    class expected16:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.mod(lhs_1, R.const(1.0))
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Mod1(), input_info1, {}, expected15)
+    verify_model(Mod2(), input_info2, {}, expected16)
+
+    # Ge
+    class Ge1(Module):
+        def forward(self, lhs, rhs):
+            return lhs >= rhs
+
+    @tvm.script.ir_module
+    class expected17:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater_equal(lhs_1, rhs_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
+                R.output(gv)
+
+            return gv
+
+    class Ge2(Module):
+        def forward(self, lhs):
+            return lhs >= 1.0
+
+    @tvm.script.ir_module
+    class expected18:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater_equal(lhs_1, R.const(1.0))
+                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
+                R.output(gv)
+
+            return gv
+
+    verify_model(Ge1(), input_info1, {}, expected17)
+    verify_model(Ge2(), input_info2, {}, expected18)
+
+    # Gt
+    class Gt1(Module):
+        def forward(self, lhs, rhs):
+            return lhs > rhs
+
+    @tvm.script.ir_module
+    class expected19:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(lhs_1, rhs_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
+                R.output(gv)
+
+            return gv
+
+    class Gt2(Module):
+        def forward(self, lhs):
+            return lhs > 1.0
+
+    @tvm.script.ir_module
+    class expected20:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(lhs_1, R.const(1.0))
+                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
+                R.output(gv)
+
+            return gv
+
+    verify_model(Gt1(), input_info1, {}, expected19)
+    verify_model(Gt2(), input_info2, {}, expected20)
+
+    # Le
+    class Le1(Module):
+        def forward(self, lhs, rhs):
+            return lhs <= rhs
+
+    @tvm.script.ir_module
+    class expected21:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less_equal(lhs_1, rhs_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
+                R.output(gv)
+
+            return gv
+
+    class Le2(Module):
+        def forward(self, lhs):
+            return lhs <= 1.0
+
+    @tvm.script.ir_module
+    class expected22:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less_equal(lhs_1, R.const(1.0))
+                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
+                R.output(gv)
+
+            return gv
+
+    verify_model(Le1(), input_info1, {}, expected21)
+    verify_model(Le2(), input_info2, {}, expected22)
+
+    # Ne
+    class Ne1(Module):
+        def forward(self, lhs, rhs):
+            return lhs != rhs
+
+    @tvm.script.ir_module
+    class expected23:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.not_equal(lhs_1, rhs_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
+                R.output(gv)
+
+            return gv
+
+    class Ne2(Module):
+        def forward(self, lhs):
+            return lhs != 1.0
+
+    @tvm.script.ir_module
+    class expected24:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.not_equal(lhs_1, R.const(1.0))
+                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
+                R.output(gv)
+
+            return gv
+
+    verify_model(Ne1(), input_info1, {}, expected23)
+    verify_model(Ne2(), input_info2, {}, expected24)
+
 
 def test_size():
     input_info = [([1, 3, 10, 10], "float32")]
@@ -1981,6 +2184,36 @@ def test_basic_unary_ops(pytorch_op, relax_op):
     verify_model(Unary(), input_info, {}, expected_unary)
 
 
+operator_bool_unary = [
+    (torch.isnan, R.isnan),
+    (torch.isinf, R.isinf),
+    (torch.isfinite, R.isfinite),
+]
+
+
[email protected]("pytorch_op, relax_op", operator_bool_unary)
+def test_bool_unary_ops(pytorch_op, relax_op):
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    class Unary(Module):
+        def forward(self, input):
+            return pytorch_op(input)
+
+    @tvm.script.ir_module
+    class expected_unary:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = relax_op(input_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Unary(), input_info, {}, expected_unary)
+
+
 def test_extended_unary_ops():
     input_info = [([1, 3, 10, 10], "float32")]
 
@@ -2201,63 +2434,6 @@ def test_extended_unary_ops():
     verify_model(LogSoftmax(), input_info, {}, expected_log_softmax)
     verify_model(LogSoftmax2(), input_info, {}, expected_log_softmax)
 
-    # isfinite
-    class IsFinite(Module):
-        def forward(self, input):
-            return torch.isfinite(input)
-
-    @tvm.script.ir_module
-    class expected_isfinite:
-        @R.function
-        def main(
-            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
-        ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.isfinite(input_1)
-                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
-                R.output(gv)
-            return gv
-
-    verify_model(IsFinite(), input_info, {}, expected_isfinite)
-
-    # isinf
-    class IsInf(Module):
-        def forward(self, input):
-            return torch.isinf(input)
-
-    @tvm.script.ir_module
-    class expected_isinf:
-        @R.function
-        def main(
-            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
-        ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.isinf(input_1)
-                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
-                R.output(gv)
-            return gv
-
-    verify_model(IsInf(), input_info, {}, expected_isinf)
-
-    # isnan
-    class IsNan(Module):
-        def forward(self, input):
-            return torch.isnan(input)
-
-    @tvm.script.ir_module
-    class expected_isnan:
-        @R.function
-        def main(
-            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
-        ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
-            with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.isnan(input_1)
-                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
-                R.output(gv)
-            return gv
-
-    verify_model(IsNan(), input_info, {}, expected_isnan)
-
     # relu
     class ReLU0(Module):
         def __init__(self):

Reply via email to