This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8a6c9bf230 [Relax][PyTorch] Add ReLU6 Op Support for Exported Program and FX graph (#17918)
8a6c9bf230 is described below

commit 8a6c9bf230b9ffdaf1e675ed9798532661d2e39b
Author: Deivanayaki S <[email protected]>
AuthorDate: Sat May 10 19:28:19 2025 +0530

    [Relax][PyTorch] Add ReLU6 Op Support for Exported Program and FX graph (#17918)
    
    * add relu6 op support into relax frontend
    
    * fix lint issues
    
    * fix unity issue in test script
    
    * fix issues in msc test script
    
    * fix relu6 layout value in msc test script
    
    * define relu6 op in relax_opcode file
    
    * define relu6 op in torch codegen file
    
    * update relu6 op implementation using clip op
    
    * update test script
    
    ---------
    
    Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki>
    Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki.>
---
 python/tvm/relax/frontend/nn/op.py                 | 22 ++++++++++
 .../frontend/torch/exported_program_translator.py  |  2 +
 python/tvm/relax/frontend/torch/fx_translator.py   |  5 +--
 python/tvm/relax/op/nn/__init__.py                 |  1 +
 python/tvm/relax/op/nn/nn.py                       | 21 +++++++++-
 .../relax/test_frontend_from_exported_program.py   | 47 ++++++++++++++++++++++
 tests/python/relax/test_frontend_from_fx.py        | 19 +++++----
 tests/python/relax/test_frontend_nn_op.py          |  2 +
 tests/python/relax/test_op_nn.py                   |  5 +++
 9 files changed, 113 insertions(+), 11 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index 86be98cba7..f7df8d4417 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -930,6 +930,28 @@ def relu(x: Tensor, name: str = "relu") -> Tensor:
     return wrap_nested(_op.nn.relu(x._expr), name)
 
 
+def relu6(x: Tensor, name: str = "relu6") -> Tensor:
+    r"""ReLU6 activation function.
+
+    .. math::
+        \text{ReLU6}(x) = \min(\max(x, 0), 6)
+
+    Parameters
+    ----------
+    x : Tensor
+        The input data.
+
+    name : str
+        Name hint passed to ``wrap_nested`` for the result.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result, i.e. ``x`` clamped to the range ``[0, 6]``.
+    """
+    return wrap_nested(_op.nn.relu6(x._expr), name)
+
+
 def silu(x: Tensor, name: str = "silu") -> Tensor:
     r"""Sigmoid Linear Unit function
 
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index d69d5bcfa1..efa3de3a10 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -347,6 +347,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "reciprocal.default": self._reciprocal,
             "relu.default": self._unary_op(relax.op.nn.relu),
             "relu_.default": self._unary_op(relax.op.nn.relu),
+            "relu6.default": self._unary_op(relax.op.nn.relu6),
+            "relu6_.default": self._unary_op(relax.op.nn.relu6),
             "round.default": self._round,
             "rsqrt.default": self._unary_op(relax.op.rsqrt),
             "rsub.Tensor": self._rsub,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index b2a1f5eae1..fc12f877e0 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -716,9 +716,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             nn.LogSoftmax: self._log_softmax_module,
             nn.PReLU: self._prelu_module,
             nn.ReLU: self._unary_op(relax.op.nn.relu),
-            nn.ReLU6: lambda node: self.block_builder.emit(
-                relax.op.clip(self.env[node.args[0]], 0, 6)
-            ),
+            nn.ReLU6: self._unary_op(relax.op.nn.relu6),
             nn.Sigmoid: self._unary_op(relax.op.sigmoid),
             nn.SELU: self._unary_op(relax.op.nn.selu),
             nn.SiLU: self._unary_op(relax.op.nn.silu),
@@ -790,6 +788,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "prelu": self._prelu,
             "reciprocal": self._reciprocal,
             "relu": self._unary_op(relax.op.nn.relu),
+            "relu6": self._unary_op(relax.op.nn.relu6),
             "round": self._round,
             "rsqrt": self._unary_op(relax.op.rsqrt),
             "selu": self._unary_op(relax.op.nn.selu),
diff --git a/python/tvm/relax/op/nn/__init__.py b/python/tvm/relax/op/nn/__init__.py
index 08ecda275c..62fa0d53a9 100644
--- a/python/tvm/relax/op/nn/__init__.py
+++ b/python/tvm/relax/op/nn/__init__.py
@@ -46,6 +46,7 @@ from .nn import (
     pixel_shuffle,
     prelu,
     relu,
+    relu6,
     rms_norm,
     selu,
     silu,
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index b68d488e26..c6beea3158 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -17,7 +17,7 @@
 """Relax Neural Network (NN) operators"""
 from typing import List, Optional, Tuple, Union
 
-from tvm import DataType
+from tvm import DataType, relax
 from tvm.tir import FloatImm
 
 from ...expr import Expr
@@ -1267,6 +1267,25 @@ def relu(data: Expr) -> Expr:
     return _ffi_api.relu(data)  # type: ignore
 
 
+def relu6(data: Expr) -> Expr:
+    r"""ReLU6 activation function, expressed as ``clip(data, 0, 6)``.
+
+    .. math::
+        \text{ReLU6}(x) = \min(\max(x, 0), 6)
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result, built as a ``relax.op.clip`` call.
+    """
+    return relax.op.clip(data, 0, 6)
+
+
 def leakyrelu(data: Expr, alpha: float = 0.01) -> Expr:
     """Rectified linear unit.
 
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 4cb9e903a1..f01da1336e 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -514,6 +514,53 @@ def test_extended_unary_ops():
 
     verify_model(MinModel(), example_args, {}, expected_min)
 
+    # relu6
+    class ReLU6_1(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.relu6 = torch.nn.ReLU6()
+
+        def forward(self, x):
+            return self.relu6(x)
+
+    class ReLU6_2(torch.nn.Module):
+        def forward(self, x):
+            return torch.nn.functional.relu6(x)
+
+    class ReLU6_3(torch.nn.Module):
+        def forward(self, x):
+            return torch.ops.aten.relu6_(x)
+
+    @tvm.script.ir_module
+    class expected_relu6_1:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
+                    x, R.prim_value(T.float64(0.0)), 
R.prim_value(T.float64(6.0))
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected_relu6_2:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu6(x)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(ReLU6_1(), example_args, {}, expected_relu6_1)
+    verify_model(ReLU6_2(), example_args, {}, expected_relu6_2)
+    verify_model(ReLU6_3(), example_args, {}, expected_relu6_2)
+
 
 def test_hardtanh():
     class Hardtanh(torch.nn.Module):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 7928975301..f507071b07 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3208,27 +3208,32 @@ def test_extended_unary_ops():
     verify_model(ReLU1(), input_info, {}, expected_relu)
 
     # relu6
-    class ReLU6(Module):
+    class ReLU6_1(torch.nn.Module):
         def __init__(self):
             super().__init__()
             self.relu6 = torch.nn.ReLU6()
 
-        def forward(self, input):
-            return self.relu6(input)
+        def forward(self, x):
+            return self.relu6(x)
+
+    class ReLU6_2(torch.nn.Module):
+        def forward(self, x):
+            return torch.nn.functional.relu6(x)
 
     @tvm.script.ir_module
-    class expected_relu6:
+    class expected:
         @R.function
         def main(
-            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+            inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
         ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
             with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.clip(input_1, 0, 6)
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.nn.relu6(inp_0)
                 gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
                 R.output(gv)
             return gv
 
-    verify_model(ReLU6(), input_info, {}, expected_relu6)
+    verify_model(ReLU6_1(), input_info, {}, expected)
+    verify_model(ReLU6_2(), input_info, {}, expected)
 
     # selu
     class Selu1(Module):
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index 1af13f0487..5c400ef8be 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -385,6 +385,7 @@ def test_nn():
     class Model(Module):
         def test(self, x: Tensor, weight: Tensor, bias: Tensor):
             relu_out = op.relu(x)
+            relu6_out = op.relu6(x)
             silu_out = op.silu(x)
             gelu_out = op.gelu(x)
             sigmoid_out = op.sigmoid(x)
@@ -409,6 +410,7 @@ def test_nn():
         R.func_attr({"num_input": 4})
         with R.dataflow():
             relu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.relu(x)
+            relu6: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.relu6(x)
             silu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.silu(x)
             gelu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.gelu(x)
             sigmoid: R.Tensor((2, 3, 4, 5), dtype="float32") = R.sigmoid(x)
diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py
index 1bf4444848..a0ff507ef8 100644
--- a/tests/python/relax/test_op_nn.py
+++ b/tests/python/relax/test_op_nn.py
@@ -74,9 +74,12 @@ def test_linear_unit_infer_struct_info():
 
     _check_inference(bb, relax.op.nn.relu(x0), relax.TensorStructInfo((2, 3), 
"float32"))
     _check_inference(bb, relax.op.nn.relu(x6), relax.TensorStructInfo((2, 3), 
"float32", vdev0))
+    _check_inference(bb, relax.op.nn.relu6(x0), relax.TensorStructInfo((2, 3), 
"float32"))
+    _check_inference(bb, relax.op.nn.relu6(x6), relax.TensorStructInfo((2, 3), 
"float32", vdev0))
     _check_inference(bb, relax.op.nn.silu(x1), 
relax.TensorStructInfo(dtype="float32", ndim=3))
     _check_inference(bb, relax.op.nn.gelu(x2), 
relax.TensorStructInfo(dtype="float32"))
     _check_inference(bb, relax.op.nn.relu(x3), relax.TensorStructInfo((2, 3), 
dtype=""))
+    _check_inference(bb, relax.op.nn.relu6(x3), relax.TensorStructInfo((2, 3), 
dtype=""))
     _check_inference(bb, relax.op.nn.gelu(x4), 
relax.TensorStructInfo(dtype=""))
     _check_inference(bb, relax.op.nn.leakyrelu(x0), relax.TensorStructInfo((2, 
3), "float32"))
     _check_inference(bb, relax.op.nn.leakyrelu(x5), relax.TensorStructInfo((3, 
4), dtype=""))
@@ -93,6 +96,7 @@ def test_linear_unit_infer_struct_info_shape_symbolic():
 
     _check_inference(bb, relax.op.nn.silu(x0), relax.TensorStructInfo((m, n), 
"float32"))
     _check_inference(bb, relax.op.nn.relu(x1), relax.TensorStructInfo((4, n), 
"float32"))
+    _check_inference(bb, relax.op.nn.relu6(x1), relax.TensorStructInfo((4, n), 
"float32"))
     _check_inference(bb, relax.op.nn.leakyrelu(x1), relax.TensorStructInfo((4, 
n), "float32"))
     _check_inference(bb, relax.op.nn.softplus(x1), relax.TensorStructInfo((4, 
n), "float32"))
 
@@ -106,6 +110,7 @@ def test_linear_unit_infer_struct_info_shape_var():
 
     _check_inference(bb, relax.op.nn.gelu(x0), relax.TensorStructInfo(s0, 
"float32"))
     _check_inference(bb, relax.op.nn.relu(x1), relax.TensorStructInfo(s1, 
"float32"))
+    _check_inference(bb, relax.op.nn.relu6(x1), relax.TensorStructInfo(s1, 
"float32"))
     _check_inference(bb, relax.op.nn.leakyrelu(x1), relax.TensorStructInfo(s1, 
"float32"))
     _check_inference(bb, relax.op.nn.softplus(x1), relax.TensorStructInfo(s1, 
"float32"))
 

Reply via email to