This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8a6c9bf230 [Relax][PyTorch] Add ReLU6 Op Support for Exported Program and FX graph (#17918)
8a6c9bf230 is described below
commit 8a6c9bf230b9ffdaf1e675ed9798532661d2e39b
Author: Deivanayaki S <[email protected]>
AuthorDate: Sat May 10 19:28:19 2025 +0530
[Relax][PyTorch] Add ReLU6 Op Support for Exported Program and FX graph (#17918)
* add relu6 op support into relax frontend
* fix lint issues
* fix unity issue in test script
* fix issues in msc test script
* fix relu6 layout value in msc test script
* define relu6 op in relax_opcode file
* define relu6 op in torch codegen file
* update relu6 op implementation using clip op
* update test script
---------
Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki>
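The implementation note above ("update relu6 op implementation using clip op") means the new operator is sugar over the existing clip op. A minimal NumPy sketch of that equivalence (illustrative only, not the patched code):

    import numpy as np

    def relu6_reference(x):
        # ReLU6(x) = min(max(x, 0), 6), i.e. a clip to the [0, 6] range,
        # which is how relax.op.nn.relu6 is lowered in this patch.
        return np.clip(x, 0.0, 6.0)

    print(relu6_reference(np.array([-1.5, 3.0, 8.0])))  # [0.  3.  6. ]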
---
python/tvm/relax/frontend/nn/op.py | 22 ++++++++++
.../frontend/torch/exported_program_translator.py | 2 +
python/tvm/relax/frontend/torch/fx_translator.py | 5 +--
python/tvm/relax/op/nn/__init__.py | 1 +
python/tvm/relax/op/nn/nn.py | 21 +++++++++-
.../relax/test_frontend_from_exported_program.py | 47 ++++++++++++++++++++++
tests/python/relax/test_frontend_from_fx.py | 19 +++++----
tests/python/relax/test_frontend_nn_op.py | 2 +
tests/python/relax/test_op_nn.py | 5 +++
9 files changed, 113 insertions(+), 11 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index 86be98cba7..f7df8d4417 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -930,6 +930,28 @@ def relu(x: Tensor, name: str = "relu") -> Tensor:
return wrap_nested(_op.nn.relu(x._expr), name)
+def relu6(x: Tensor, name: str = "relu6") -> Tensor:
+ r"""ReLU6 activation function.
+
+ .. math::
+ \text{ReLU6}(x) = \min(\max(x, 0), 6)
+
+ Parameters
+ ----------
+ x : Tensor
+ The input data.
+
+ name : str
+ Name hint.
+
+ Returns
+ -------
+ result : Tensor
+ The computed result.
+ """
+ return wrap_nested(_op.nn.relu6(x._expr), name)
+
+
def silu(x: Tensor, name: str = "silu") -> Tensor:
r"""Sigmoid Linear Unit function
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index d69d5bcfa1..efa3de3a10 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -347,6 +347,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"reciprocal.default": self._reciprocal,
"relu.default": self._unary_op(relax.op.nn.relu),
"relu_.default": self._unary_op(relax.op.nn.relu),
+ "relu6.default": self._unary_op(relax.op.nn.relu6),
+ "relu6_.default": self._unary_op(relax.op.nn.relu6),
"round.default": self._round,
"rsqrt.default": self._unary_op(relax.op.rsqrt),
"rsub.Tensor": self._rsub,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index b2a1f5eae1..fc12f877e0 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -716,9 +716,7 @@ class TorchFXImporter(BaseFXGraphImporter):
nn.LogSoftmax: self._log_softmax_module,
nn.PReLU: self._prelu_module,
nn.ReLU: self._unary_op(relax.op.nn.relu),
- nn.ReLU6: lambda node: self.block_builder.emit(
- relax.op.clip(self.env[node.args[0]], 0, 6)
- ),
+ nn.ReLU6: self._unary_op(relax.op.nn.relu6),
nn.Sigmoid: self._unary_op(relax.op.sigmoid),
nn.SELU: self._unary_op(relax.op.nn.selu),
nn.SiLU: self._unary_op(relax.op.nn.silu),
@@ -790,6 +788,7 @@ class TorchFXImporter(BaseFXGraphImporter):
"prelu": self._prelu,
"reciprocal": self._reciprocal,
"relu": self._unary_op(relax.op.nn.relu),
+ "relu6": self._unary_op(relax.op.nn.relu6),
"round": self._round,
"rsqrt": self._unary_op(relax.op.rsqrt),
"selu": self._unary_op(relax.op.nn.selu),
diff --git a/python/tvm/relax/op/nn/__init__.py b/python/tvm/relax/op/nn/__init__.py
index 08ecda275c..62fa0d53a9 100644
--- a/python/tvm/relax/op/nn/__init__.py
+++ b/python/tvm/relax/op/nn/__init__.py
@@ -46,6 +46,7 @@ from .nn import (
pixel_shuffle,
prelu,
relu,
+ relu6,
rms_norm,
selu,
silu,
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index b68d488e26..c6beea3158 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -17,7 +17,7 @@
"""Relax Neural Network (NN) operators"""
from typing import List, Optional, Tuple, Union
-from tvm import DataType
+from tvm import DataType, relax
from tvm.tir import FloatImm
from ...expr import Expr
@@ -1267,6 +1267,25 @@ def relu(data: Expr) -> Expr:
return _ffi_api.relu(data) # type: ignore
+def relu6(data: Expr) -> Expr:
+ r"""ReLU6 activation function.
+
+ .. math::
+ \text{ReLU6}(x) = \min(\max(x, 0), 6)
+
+ Parameters
+ ----------
+ data : relax.Expr
+ The input data.
+
+ Returns
+ -------
+ result : relax.Expr
+ The computed result.
+ """
+ return relax.op.clip(data, 0, 6)
+
+
def leakyrelu(data: Expr, alpha: float = 0.01) -> Expr:
"""Rectified linear unit.
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 4cb9e903a1..f01da1336e 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -514,6 +514,53 @@ def test_extended_unary_ops():
verify_model(MinModel(), example_args, {}, expected_min)
+ # relu6
+ class ReLU6_1(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.relu6 = torch.nn.ReLU6()
+
+ def forward(self, x):
+ return self.relu6(x)
+
+ class ReLU6_2(torch.nn.Module):
+ def forward(self, x):
+ return torch.nn.functional.relu6(x)
+
+ class ReLU6_3(torch.nn.Module):
+ def forward(self, x):
+ return torch.ops.aten.relu6_(x)
+
+ @tvm.script.ir_module
+ class expected_relu6_1:
+ @R.function
+ def main(
+ x: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
+ x, R.prim_value(T.float64(0.0)), R.prim_value(T.float64(6.0))
+ )
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ @tvm.script.ir_module
+ class expected_relu6_2:
+ @R.function
+ def main(
+ x: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu6(x)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ verify_model(ReLU6_1(), example_args, {}, expected_relu6_1)
+ verify_model(ReLU6_2(), example_args, {}, expected_relu6_2)
+ verify_model(ReLU6_3(), example_args, {}, expected_relu6_2)
+
def test_hardtanh():
class Hardtanh(torch.nn.Module):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 7928975301..f507071b07 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3208,27 +3208,32 @@ def test_extended_unary_ops():
verify_model(ReLU1(), input_info, {}, expected_relu)
# relu6
- class ReLU6(Module):
+ class ReLU6_1(torch.nn.Module):
def __init__(self):
super().__init__()
self.relu6 = torch.nn.ReLU6()
- def forward(self, input):
- return self.relu6(input)
+ def forward(self, x):
+ return self.relu6(x)
+
+ class ReLU6_2(torch.nn.Module):
+ def forward(self, x):
+ return torch.nn.functional.relu6(x)
@tvm.script.ir_module
- class expected_relu6:
+ class expected:
@R.function
def main(
- input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
with R.dataflow():
- lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(input_1, 0, 6)
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu6(inp_0)
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
R.output(gv)
return gv
- verify_model(ReLU6(), input_info, {}, expected_relu6)
+ verify_model(ReLU6_1(), input_info, {}, expected)
+ verify_model(ReLU6_2(), input_info, {}, expected)
# selu
class Selu1(Module):
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index 1af13f0487..5c400ef8be 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -385,6 +385,7 @@ def test_nn():
class Model(Module):
def test(self, x: Tensor, weight: Tensor, bias: Tensor):
relu_out = op.relu(x)
+ relu6_out = op.relu6(x)
silu_out = op.silu(x)
gelu_out = op.gelu(x)
sigmoid_out = op.sigmoid(x)
@@ -409,6 +410,7 @@ def test_nn():
R.func_attr({"num_input": 4})
with R.dataflow():
relu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.relu(x)
+ relu6: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.relu6(x)
silu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.silu(x)
gelu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.gelu(x)
sigmoid: R.Tensor((2, 3, 4, 5), dtype="float32") = R.sigmoid(x)
diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py
index 1bf4444848..a0ff507ef8 100644
--- a/tests/python/relax/test_op_nn.py
+++ b/tests/python/relax/test_op_nn.py
@@ -74,9 +74,12 @@ def test_linear_unit_infer_struct_info():
_check_inference(bb, relax.op.nn.relu(x0), relax.TensorStructInfo((2, 3), "float32"))
_check_inference(bb, relax.op.nn.relu(x6), relax.TensorStructInfo((2, 3), "float32", vdev0))
+ _check_inference(bb, relax.op.nn.relu6(x0), relax.TensorStructInfo((2, 3), "float32"))
+ _check_inference(bb, relax.op.nn.relu6(x6), relax.TensorStructInfo((2, 3), "float32", vdev0))
_check_inference(bb, relax.op.nn.silu(x1), relax.TensorStructInfo(dtype="float32", ndim=3))
_check_inference(bb, relax.op.nn.gelu(x2), relax.TensorStructInfo(dtype="float32"))
_check_inference(bb, relax.op.nn.relu(x3), relax.TensorStructInfo((2, 3), dtype=""))
+ _check_inference(bb, relax.op.nn.relu6(x3), relax.TensorStructInfo((2, 3), dtype=""))
_check_inference(bb, relax.op.nn.gelu(x4), relax.TensorStructInfo(dtype=""))
_check_inference(bb, relax.op.nn.leakyrelu(x0), relax.TensorStructInfo((2, 3), "float32"))
_check_inference(bb, relax.op.nn.leakyrelu(x5), relax.TensorStructInfo((3, 4), dtype=""))
@@ -93,6 +96,7 @@ def test_linear_unit_infer_struct_info_shape_symbolic():
_check_inference(bb, relax.op.nn.silu(x0), relax.TensorStructInfo((m, n), "float32"))
_check_inference(bb, relax.op.nn.relu(x1), relax.TensorStructInfo((4, n), "float32"))
+ _check_inference(bb, relax.op.nn.relu6(x1), relax.TensorStructInfo((4, n), "float32"))
_check_inference(bb, relax.op.nn.leakyrelu(x1), relax.TensorStructInfo((4, n), "float32"))
_check_inference(bb, relax.op.nn.softplus(x1), relax.TensorStructInfo((4, n), "float32"))
@@ -106,6 +110,7 @@ def test_linear_unit_infer_struct_info_shape_var():
_check_inference(bb, relax.op.nn.gelu(x0), relax.TensorStructInfo(s0, "float32"))
_check_inference(bb, relax.op.nn.relu(x1), relax.TensorStructInfo(s1, "float32"))
+ _check_inference(bb, relax.op.nn.relu6(x1), relax.TensorStructInfo(s1, "float32"))
_check_inference(bb, relax.op.nn.leakyrelu(x1), relax.TensorStructInfo(s1, "float32"))
_check_inference(bb, relax.op.nn.softplus(x1), relax.TensorStructInfo(s1, "float32"))