This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new cad1c68bb9 [Relax][Pytorch] Update SELU Implementation Using Decomposed Core-Level Ops (#17797)
cad1c68bb9 is described below

commit cad1c68bb99f1fc853b3217e6e3e9f3e43f72777
Author: Deivanayaki S <[email protected]>
AuthorDate: Thu Apr 3 06:29:51 2025 +0530

    [Relax][Pytorch] Update SELU Implementation Using Decomposed Core-Level Ops (#17797)
    
    * Integrate SELU into core ops for native R.nn.selu support
    
    * fix trailing whitespace issue
    
    * fix SELU mapping issue in fx_graph and lint issue
    
    * update the test script of selu in fx graph
    
    * modify test script to fix selu module check
    
    * format documentation
    
    ---------
    
    Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki>
---
 .../frontend/torch/base_fx_graph_translator.py     | 31 ----------------------
 .../frontend/torch/exported_program_translator.py  |  2 +-
 python/tvm/relax/frontend/torch/fx_translator.py   |  4 +--
 python/tvm/relax/op/nn/__init__.py                 |  1 +
 python/tvm/relax/op/nn/nn.py                       | 24 +++++++++++++++++
 python/tvm/relax/transform/legalize_ops/nn.py      | 18 +++++++++++++
 src/relax/op/nn/nn.cc                              |  3 +++
 src/relax/op/nn/nn.h                               |  3 +++
 .../relax/test_frontend_from_exported_program.py   | 18 +++----------
 tests/python/relax/test_frontend_from_fx.py        | 18 +++----------
 10 files changed, 58 insertions(+), 64 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 74c620a33d..890f925079 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -340,37 +340,6 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         # Combine the positive and negative shrink results
         return self.block_builder.emit(relax.op.add(shrink_pos, shrink_neg))
 
-    def _selu(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        alpha = node.args[1] if len(node.args) > 1 else 
node.kwargs.get("alpha", 1.6732631921768188)
-        gamma = node.args[2] if len(node.args) > 2 else 
node.kwargs.get("gamma", 1.0507009873554805)
-        dtype = x.struct_info.dtype
-
-        if isinstance(alpha, (int, float)):
-            alpha = relax.const(alpha, dtype)
-        else:
-            if not isinstance(alpha, relax.Var):
-                alpha = self.block_builder.emit(relax.const(alpha, dtype))
-
-        if isinstance(gamma, (int, float)):
-            gamma = relax.const(gamma, dtype)
-        else:
-            if not isinstance(gamma, relax.Var):
-                gamma = self.block_builder.emit(relax.const(gamma, dtype))
-
-        # gamma * (ReLU(x) + alpha * (exp(x) - 1))
-        return self.block_builder.emit(
-            relax.op.multiply(
-                gamma,
-                relax.op.add(
-                    relax.op.nn.relu(x),
-                    relax.op.multiply(
-                        alpha, relax.op.subtract(relax.op.exp(x), 
relax.const(1, dtype))
-                    ),
-                ),
-            )
-        )
-
     def _tril_triu(self, op: Callable) -> Callable:
         from torch import fx
 
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index b35cf4ce20..2e7c682aa3 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -280,7 +280,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "relu.default": self._unary_op(relax.op.nn.relu),
             "round.default": self._round,
             "rsqrt.default": self._unary_op(relax.op.rsqrt),
-            "selu.default": self._selu,
+            "selu.default": self._unary_op(relax.op.nn.selu),
             "sigmoid.default": self._unary_op(relax.op.sigmoid),
             "sign.default": self._unary_op(relax.op.sign),
             "silu.default": self._unary_op(relax.op.nn.silu),
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index c3d605a329..3ddf919c2e 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -650,7 +650,7 @@ class TorchFXImporter(BaseFXGraphImporter):
                 relax.op.clip(self.env[node.args[0]], 0, 6)
             ),
             nn.Sigmoid: self._unary_op(relax.op.sigmoid),
-            nn.SELU: self._selu,
+            nn.SELU: self._unary_op(relax.op.nn.selu),
             nn.SiLU: self._unary_op(relax.op.nn.silu),
             nn.Softmax: self._softmax_module,
             nn.Tanh: self._unary_op(relax.op.tanh),
@@ -710,7 +710,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "relu": self._unary_op(relax.op.nn.relu),
             "round": self._round,
             "rsqrt": self._unary_op(relax.op.rsqrt),
-            "selu": self._selu,
+            "selu": self._unary_op(relax.op.nn.selu),
             "sigmoid": self._unary_op(relax.op.sigmoid),
             "sign": self._unary_op(relax.op.sign),
             "silu": self._unary_op(relax.op.nn.silu),
diff --git a/python/tvm/relax/op/nn/__init__.py b/python/tvm/relax/op/nn/__init__.py
index 61212f33d8..e45982a0fe 100644
--- a/python/tvm/relax/op/nn/__init__.py
+++ b/python/tvm/relax/op/nn/__init__.py
@@ -45,6 +45,7 @@ from .nn import (
     pad,
     relu,
     rms_norm,
+    selu,
     silu,
     softmax,
 )
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 09a7df5149..5232eea047 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -1304,6 +1304,30 @@ def gelu_tanh(data: Expr) -> Expr:
     return _ffi_api.gelu_tanh(data)  # type: ignore
 
 
+def selu(data: Expr) -> Expr:
+    r"""Scaled Exponential Linear Unit (SELU).
+
+    .. math::
+        \text{SELU}(x) = \lambda \begin{cases}
+            x & \text{if } x > 0 \\
+            \alpha (e^x - 1) & \text{if } x \leq 0
+        \end{cases}
+
+    where :math:`\lambda \approx 1.0507` and :math:`\alpha \approx 1.6733`.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    return _ffi_api.selu(data)
+
+
 def silu(data: Expr) -> Expr:
     r"""Sigmoid Linear Unit function
 
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index 4c8bdbc661..fd3db841e6 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -505,6 +505,24 @@ def _nn_gelu_tanh(bb: BlockBuilder, call: Call) -> Expr:
     return bb.call_te(te_gelu_tanh, call.args[0], 
primfunc_name_hint="gelu_tanh")
 
 
+@register_legalize("relax.nn.selu")
+def _nn_selu(bb: BlockBuilder, call: Call) -> Expr:
+    def te_selu(x: te.Tensor):
+        dtype = x.dtype
+        alpha = tir.const(1.6732632423543772848170429916717, dtype)
+        scale = tir.const(1.0507009873554804934193349852946, dtype)
+
+        # Compute SELU
+        # SELU(x) = scale∗(max(0,x)+min(0,α∗(exp(x)−1)))
+        positive_part = topi.maximum(x, tir.const(0, dtype))
+        negative_part = topi.minimum(
+            tir.const(0, dtype), alpha * (topi.exp(x) - tir.const(1, dtype))
+        )
+        return scale * (positive_part + negative_part)
+
+    return bb.call_te(te_selu, call.args[0], primfunc_name_hint="selu")
+
+
 @register_legalize("relax.nn.silu")
 def _nn_silu(bb: BlockBuilder, call: Call) -> Expr:
     def te_silu(x: te.Tensor):
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index c768ea19af..4a5a9a7016 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -34,6 +34,9 @@ RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(gelu, "nn.gelu", /*require_float_dtype=*/tru
 /* relax.nn.gelu_tanh */
 RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(gelu_tanh, "nn.gelu_tanh", 
/*require_float_dtype=*/true);
 
+/* relax.nn.selu */
+RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(selu, "nn.selu", 
/*require_float_dtype=*/true);
+
 /* relax.nn.silu */
 RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(silu, "nn.silu", 
/*require_float_dtype=*/true);
 
diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h
index 28c14139b9..d6db36aba5 100644
--- a/src/relax/op/nn/nn.h
+++ b/src/relax/op/nn/nn.h
@@ -57,6 +57,9 @@ Expr gelu(Expr data);
 /*! \brief Gaussian Error Linear Units function approximated by tanh. */
 Expr gelu_tanh(Expr data);
 
+/*! \brief Scaled Exponential Linear Unit function. */
+Expr selu(Expr data);
+
 /*! \brief Sigmoid Linear Unit function. */
 Expr silu(Expr data);
 
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index e37ee0e404..2175f9aa39 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -544,24 +544,12 @@ def test_extended_unary_ops():
     class expected_selu:
         @R.function
         def main(
-            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+            input: R.Tensor((1, 3, 10, 10), dtype="float32")
         ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
             with R.dataflow():
-                lv_relu: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.nn.relu(input_1)
-                lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.exp(input_1)
-                lv_sub: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
-                    lv_exp, R.const(1.0, "float32")
-                )
-                lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.multiply(
-                    R.const(1.6732631921768188, "float32"), lv_sub
-                )
-                lv_add: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.add(lv_relu, lv_scaled)
-                lv_selu: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.multiply(
-                    R.const(1.0507010221481323, "float32"), lv_add
-                )
-                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = 
(lv_selu,)
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.nn.selu(input)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
                 R.output(gv)
-
             return gv
 
     verify_model(Selu1(), example_args, {}, expected_selu)
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index b8d7f0b14e..d913baf13a 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -2429,23 +2429,11 @@ def test_extended_unary_ops():
     class expected_selu:
         @R.function
         def main(
-            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+            inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
         ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
-            # block 0
             with R.dataflow():
-                lv_relu: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.nn.relu(input_1)
-                lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.exp(input_1)
-                lv_sub: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
-                    lv_exp, R.const(1.0, "float32")
-                )
-                lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.multiply(
-                    R.const(1.6732631921768188, "float32"), lv_sub
-                )
-                lv_add: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.add(lv_relu, lv_scaled)
-                lv_selu: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.multiply(
-                    R.const(1.0507009873554805, "float32"), lv_add
-                )
-                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv_selu
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.nn.selu(inp_0)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
                 R.output(gv)
             return gv
 

Reply via email to