This is an automated email from the ASF dual-hosted git repository.
yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f4cdb3f3ef [Relax][PyTorch] Add support for ge, gt, le, mod, ne ops (#17664)
f4cdb3f3ef is described below
commit f4cdb3f3ef130a6a9dbe005ef420d953f4150213
Author: Shushi Hong <[email protected]>
AuthorDate: Wed Feb 19 05:34:22 2025 +0800
[Relax][PyTorch] Add support for ge, gt, le, mod, ne ops (#17664)
* Update fx_translator.py
* Update test_frontend_from_fx.py
---
python/tvm/relax/frontend/torch/fx_translator.py | 5 +
tests/python/relax/test_frontend_from_fx.py | 290 ++++++++++++++++++-----
2 files changed, 238 insertions(+), 57 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 724bb3fc20..d49cfa6893 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -662,13 +662,18 @@ class TorchFXImporter(BaseFXGraphImporter):
"add": self._binary_op(relax.op.add, operator.add),
"eq": self._binary_op(relax.op.equal, operator.eq),
"floordiv": self._binary_op(relax.op.floor_divide,
operator.floordiv),
+ "ge": self._binary_op(relax.op.greater_equal, operator.ge),
+ "gt": self._binary_op(relax.op.greater, operator.gt),
"iadd": self._binary_op(relax.op.add, operator.add),
+ "le": self._binary_op(relax.op.less_equal, operator.le),
"lt": self._binary_op(relax.op.less, operator.lt),
"matmul": self._binary_op(
partial(relax.op.linear_algebra.matmul, out_dtype="float32"),
operator.matmul
),
"max": self._binary_op(relax.op.maximum, max),
+ "mod": self._binary_op(relax.op.mod, operator.mod),
"mul": self._binary_op(relax.op.multiply, operator.mul),
+ "ne": self._binary_op(relax.op.not_equal, operator.ne),
"pow": self._binary_op(relax.op.power, operator.pow),
"sub": self._binary_op(relax.op.subtract, operator.sub),
"truediv": self._binary_op(relax.op.divide, operator.truediv),
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 19bc15b192..371343b60a 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1759,6 +1759,209 @@ def test_binary():
verify_model(LT1(), input_info1, {}, expected13)
verify_model(LT2(), input_info2, {}, expected14)
+ # Mod
+ class Mod1(Module):
+ def forward(self, lhs, rhs):
+ return lhs % rhs
+
+ @tvm.script.ir_module
+ class expected15:
+ @R.function
+ def main(
+ lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.mod(lhs_1, rhs_1)
+ gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ class Mod2(Module):
+ def forward(self, lhs):
+ return lhs % 1.0
+
+ @tvm.script.ir_module
+ class expected16:
+ @R.function
+ def main(
+ lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.mod(lhs_1, R.const(1.0))
+ gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Mod1(), input_info1, {}, expected15)
+ verify_model(Mod2(), input_info2, {}, expected16)
+
+ # Ge
+ class Ge1(Module):
+ def forward(self, lhs, rhs):
+ return lhs >= rhs
+
+ @tvm.script.ir_module
+ class expected17:
+ @R.function
+ def main(
+ lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater_equal(lhs_1, rhs_1)
+ gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
+ R.output(gv)
+
+ return gv
+
+ class Ge2(Module):
+ def forward(self, lhs):
+ return lhs >= 1.0
+
+ @tvm.script.ir_module
+ class expected18:
+ @R.function
+ def main(
+ lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater_equal(lhs_1, R.const(1.0))
+ gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
+ R.output(gv)
+
+ return gv
+
+ verify_model(Ge1(), input_info1, {}, expected17)
+ verify_model(Ge2(), input_info2, {}, expected18)
+
+ # Gt
+ class Gt1(Module):
+ def forward(self, lhs, rhs):
+ return lhs > rhs
+
+ @tvm.script.ir_module
+ class expected19:
+ @R.function
+ def main(
+ lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(lhs_1, rhs_1)
+ gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
+ R.output(gv)
+
+ return gv
+
+ class Gt2(Module):
+ def forward(self, lhs):
+ return lhs > 1.0
+
+ @tvm.script.ir_module
+ class expected20:
+ @R.function
+ def main(
+ lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(lhs_1, R.const(1.0))
+ gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
+ R.output(gv)
+
+ return gv
+
+ verify_model(Gt1(), input_info1, {}, expected19)
+ verify_model(Gt2(), input_info2, {}, expected20)
+
+ # Le
+ class Le1(Module):
+ def forward(self, lhs, rhs):
+ return lhs <= rhs
+
+ @tvm.script.ir_module
+ class expected21:
+ @R.function
+ def main(
+ lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less_equal(lhs_1, rhs_1)
+ gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
+ R.output(gv)
+
+ return gv
+
+ class Le2(Module):
+ def forward(self, lhs):
+ return lhs <= 1.0
+
+ @tvm.script.ir_module
+ class expected22:
+ @R.function
+ def main(
+ lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less_equal(lhs_1, R.const(1.0))
+ gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
+ R.output(gv)
+
+ return gv
+
+ verify_model(Le1(), input_info1, {}, expected21)
+ verify_model(Le2(), input_info2, {}, expected22)
+
+ # Ne
+ class Ne1(Module):
+ def forward(self, lhs, rhs):
+ return lhs != rhs
+
+ @tvm.script.ir_module
+ class expected23:
+ @R.function
+ def main(
+ lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.not_equal(lhs_1, rhs_1)
+ gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
+ R.output(gv)
+
+ return gv
+
+ class Ne2(Module):
+ def forward(self, lhs):
+ return lhs != 1.0
+
+ @tvm.script.ir_module
+ class expected24:
+ @R.function
+ def main(
+ lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.not_equal(lhs_1, R.const(1.0))
+ gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
+ R.output(gv)
+
+ return gv
+
+ verify_model(Ne1(), input_info1, {}, expected23)
+ verify_model(Ne2(), input_info2, {}, expected24)
+
def test_size():
input_info = [([1, 3, 10, 10], "float32")]
@@ -1981,6 +2184,36 @@ def test_basic_unary_ops(pytorch_op, relax_op):
verify_model(Unary(), input_info, {}, expected_unary)
+operator_bool_unary = [
+ (torch.isnan, R.isnan),
+ (torch.isinf, R.isinf),
+ (torch.isfinite, R.isfinite),
+]
+
+
+@pytest.mark.parametrize("pytorch_op, relax_op", operator_bool_unary)
+def test_bool_unary_ops(pytorch_op, relax_op):
+ input_info = [([1, 3, 10, 10], "float32")]
+
+ class Unary(Module):
+ def forward(self, input):
+ return pytorch_op(input)
+
+ @tvm.script.ir_module
+ class expected_unary:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="bool") = relax_op(input_1)
+ gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Unary(), input_info, {}, expected_unary)
+
+
def test_extended_unary_ops():
input_info = [([1, 3, 10, 10], "float32")]
@@ -2201,63 +2434,6 @@ def test_extended_unary_ops():
verify_model(LogSoftmax(), input_info, {}, expected_log_softmax)
verify_model(LogSoftmax2(), input_info, {}, expected_log_softmax)
- # isfinite
- class IsFinite(Module):
- def forward(self, input):
- return torch.isfinite(input)
-
- @tvm.script.ir_module
- class expected_isfinite:
- @R.function
- def main(
- input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
- ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
- with R.dataflow():
- lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.isfinite(input_1)
- gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
- R.output(gv)
- return gv
-
- verify_model(IsFinite(), input_info, {}, expected_isfinite)
-
- # isinf
- class IsInf(Module):
- def forward(self, input):
- return torch.isinf(input)
-
- @tvm.script.ir_module
- class expected_isinf:
- @R.function
- def main(
- input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
- ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
- with R.dataflow():
- lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.isinf(input_1)
- gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
- R.output(gv)
- return gv
-
- verify_model(IsInf(), input_info, {}, expected_isinf)
-
- # isnan
- class IsNan(Module):
- def forward(self, input):
- return torch.isnan(input)
-
- @tvm.script.ir_module
- class expected_isnan:
- @R.function
- def main(
- input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
- ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
- with R.dataflow():
- lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.isnan(input_1)
- gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
- R.output(gv)
- return gv
-
- verify_model(IsNan(), input_info, {}, expected_isnan)
-
# relu
class ReLU0(Module):
def __init__(self):