This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new cad1c68bb9 [Relax][Pytorch] Update SELU Implementation Using Decomposed Core-Level Ops (#17797)
cad1c68bb9 is described below
commit cad1c68bb99f1fc853b3217e6e3e9f3e43f72777
Author: Deivanayaki S <[email protected]>
AuthorDate: Thu Apr 3 06:29:51 2025 +0530
[Relax][Pytorch] Update SELU Implementation Using Decomposed Core-Level Ops (#17797)
* Integrate SELU into core ops for native R.nn.selu support
* fix trailing whitespace issue
* fix selu mapping issue in fx_graph and lint issue
* update the test script of selu in fx graph
* modify test script to fix selu module check
* format documentations
---------
Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki>
---
.../frontend/torch/base_fx_graph_translator.py | 31 ----------------------
.../frontend/torch/exported_program_translator.py | 2 +-
python/tvm/relax/frontend/torch/fx_translator.py | 4 +--
python/tvm/relax/op/nn/__init__.py | 1 +
python/tvm/relax/op/nn/nn.py | 24 +++++++++++++++++
python/tvm/relax/transform/legalize_ops/nn.py | 18 +++++++++++++
src/relax/op/nn/nn.cc | 3 +++
src/relax/op/nn/nn.h | 3 +++
.../relax/test_frontend_from_exported_program.py | 18 +++----------
tests/python/relax/test_frontend_from_fx.py | 18 +++----------
10 files changed, 58 insertions(+), 64 deletions(-)
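For context, the sketch below is illustrative only (not part of this commit, untested): it builds a one-op Relax module using the new relax.op.nn.selu added by this patch and runs LegalizeOps, which expands the op into the scale * (max(0, x) + min(0, alpha * (exp(x) - 1))) computation registered in legalize_ops/nn.py. The shape and dtype are arbitrary choices for the example.

from tvm import relax

# Build a one-op Relax function that calls the new relax.op.nn.selu.
bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((1, 3, 10, 10), "float32"))
with bb.function("main", [x]):
    with bb.dataflow():
        lv = bb.emit(relax.op.nn.selu(x))
        gv = bb.emit_output(lv)
    bb.emit_func_output(gv)

mod = bb.get()
# Lower the high-level op to the TE implementation registered below.
mod = relax.transform.LegalizeOps()(mod)
mod.show()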
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 74c620a33d..890f925079 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -340,37 +340,6 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
# Combine the positive and negative shrink results
return self.block_builder.emit(relax.op.add(shrink_pos, shrink_neg))
- def _selu(self, node: fx.Node) -> relax.Var:
- x = self.env[node.args[0]]
- alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("alpha", 1.6732631921768188)
- gamma = node.args[2] if len(node.args) > 2 else node.kwargs.get("gamma", 1.0507009873554805)
- dtype = x.struct_info.dtype
-
- if isinstance(alpha, (int, float)):
- alpha = relax.const(alpha, dtype)
- else:
- if not isinstance(alpha, relax.Var):
- alpha = self.block_builder.emit(relax.const(alpha, dtype))
-
- if isinstance(gamma, (int, float)):
- gamma = relax.const(gamma, dtype)
- else:
- if not isinstance(gamma, relax.Var):
- gamma = self.block_builder.emit(relax.const(gamma, dtype))
-
- # gamma * (ReLU(x) + alpha * (exp(x) - 1))
- return self.block_builder.emit(
- relax.op.multiply(
- gamma,
- relax.op.add(
- relax.op.nn.relu(x),
- relax.op.multiply(
- alpha, relax.op.subtract(relax.op.exp(x), relax.const(1, dtype))
- ),
- ),
- )
- )
-
def _tril_triu(self, op: Callable) -> Callable:
from torch import fx
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index b35cf4ce20..2e7c682aa3 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -280,7 +280,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"relu.default": self._unary_op(relax.op.nn.relu),
"round.default": self._round,
"rsqrt.default": self._unary_op(relax.op.rsqrt),
- "selu.default": self._selu,
+ "selu.default": self._unary_op(relax.op.nn.selu),
"sigmoid.default": self._unary_op(relax.op.sigmoid),
"sign.default": self._unary_op(relax.op.sign),
"silu.default": self._unary_op(relax.op.nn.silu),
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index c3d605a329..3ddf919c2e 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -650,7 +650,7 @@ class TorchFXImporter(BaseFXGraphImporter):
relax.op.clip(self.env[node.args[0]], 0, 6)
),
nn.Sigmoid: self._unary_op(relax.op.sigmoid),
- nn.SELU: self._selu,
+ nn.SELU: self._unary_op(relax.op.nn.selu),
nn.SiLU: self._unary_op(relax.op.nn.silu),
nn.Softmax: self._softmax_module,
nn.Tanh: self._unary_op(relax.op.tanh),
@@ -710,7 +710,7 @@ class TorchFXImporter(BaseFXGraphImporter):
"relu": self._unary_op(relax.op.nn.relu),
"round": self._round,
"rsqrt": self._unary_op(relax.op.rsqrt),
- "selu": self._selu,
+ "selu": self._unary_op(relax.op.nn.selu),
"sigmoid": self._unary_op(relax.op.sigmoid),
"sign": self._unary_op(relax.op.sign),
"silu": self._unary_op(relax.op.nn.silu),
diff --git a/python/tvm/relax/op/nn/__init__.py b/python/tvm/relax/op/nn/__init__.py
index 61212f33d8..e45982a0fe 100644
--- a/python/tvm/relax/op/nn/__init__.py
+++ b/python/tvm/relax/op/nn/__init__.py
@@ -45,6 +45,7 @@ from .nn import (
pad,
relu,
rms_norm,
+ selu,
silu,
softmax,
)
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 09a7df5149..5232eea047 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -1304,6 +1304,30 @@ def gelu_tanh(data: Expr) -> Expr:
return _ffi_api.gelu_tanh(data) # type: ignore
+def selu(data: Expr) -> Expr:
+ r"""Scaled Exponential Linear Unit (SELU).
+
+ .. math::
+ \text{SELU}(x) = \lambda \begin{cases}
+ x & \text{if } x > 0 \\
+ \alpha (e^x - 1) & \text{if } x \leq 0
+ \end{cases}
+
+ where :math:`\lambda \approx 1.0507` and :math:`\alpha \approx 1.6733`.
+
+ Parameters
+ ----------
+ data : relax.Expr
+ The input data.
+
+ Returns
+ -------
+ result : relax.Expr
+ The computed result.
+ """
+ return _ffi_api.selu(data)
+
+
def silu(data: Expr) -> Expr:
r"""Sigmoid Linear Unit function
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index 4c8bdbc661..fd3db841e6 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -505,6 +505,24 @@ def _nn_gelu_tanh(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(te_gelu_tanh, call.args[0], primfunc_name_hint="gelu_tanh")
+@register_legalize("relax.nn.selu")
+def _nn_selu(bb: BlockBuilder, call: Call) -> Expr:
+ def te_selu(x: te.Tensor):
+ dtype = x.dtype
+ alpha = tir.const(1.6732632423543772848170429916717, dtype)
+ scale = tir.const(1.0507009873554804934193349852946, dtype)
+
+ # Compute SELU
+ # SELU(x) = scale∗(max(0,x)+min(0,α∗(exp(x)−1)))
+ positive_part = topi.maximum(x, tir.const(0, dtype))
+ negative_part = topi.minimum(
+ tir.const(0, dtype), alpha * (topi.exp(x) - tir.const(1, dtype))
+ )
+ return scale * (positive_part + negative_part)
+
+ return bb.call_te(te_selu, call.args[0], primfunc_name_hint="selu")
+
+
@register_legalize("relax.nn.silu")
def _nn_silu(bb: BlockBuilder, call: Call) -> Expr:
def te_silu(x: te.Tensor):
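As a quick sanity check of the decomposition registered above (illustrative only, not part of the patch), the closed form scale * (max(0, x) + min(0, alpha * (exp(x) - 1))) can be compared against PyTorch's reference SELU on random input; the constants are the same ones used in te_selu.

import numpy as np
import torch

# Constants used by the legalization above (the values PyTorch documents for SELU).
alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946

x = np.random.randn(1, 3, 10, 10).astype("float32")
# scale * (max(0, x) + min(0, alpha * (exp(x) - 1)))
ours = scale * (np.maximum(x, 0.0) + np.minimum(0.0, alpha * (np.exp(x) - 1.0)))
ref = torch.nn.functional.selu(torch.from_numpy(x)).numpy()
np.testing.assert_allclose(ours, ref, rtol=1e-5, atol=1e-5)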
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index c768ea19af..4a5a9a7016 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -34,6 +34,9 @@ RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(gelu, "nn.gelu", /*require_float_dtype=*/tru
/* relax.nn.gelu_tanh */
RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(gelu_tanh, "nn.gelu_tanh", /*require_float_dtype=*/true);
+/* relax.nn.selu */
+RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(selu, "nn.selu", /*require_float_dtype=*/true);
+
/* relax.nn.silu */
RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(silu, "nn.silu", /*require_float_dtype=*/true);
diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h
index 28c14139b9..d6db36aba5 100644
--- a/src/relax/op/nn/nn.h
+++ b/src/relax/op/nn/nn.h
@@ -57,6 +57,9 @@ Expr gelu(Expr data);
/*! \brief Gaussian Error Linear Units function approximated by tanh. */
Expr gelu_tanh(Expr data);
+/*! \brief Scaled Exponential Linear Unit function. */
+Expr selu(Expr data);
+
/*! \brief Sigmoid Linear Unit function. */
Expr silu(Expr data);
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index e37ee0e404..2175f9aa39 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -544,24 +544,12 @@ def test_extended_unary_ops():
class expected_selu:
@R.function
def main(
- input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ input: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
with R.dataflow():
- lv_relu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1)
- lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
- lv_sub: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
- lv_exp, R.const(1.0, "float32")
- )
- lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
- R.const(1.6732631921768188, "float32"), lv_sub
- )
- lv_add: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_relu, lv_scaled)
- lv_selu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
- R.const(1.0507010221481323, "float32"), lv_add
- )
- gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv_selu,)
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.selu(input)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
R.output(gv)
-
return gv
verify_model(Selu1(), example_args, {}, expected_selu)
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index b8d7f0b14e..d913baf13a 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -2429,23 +2429,11 @@ def test_extended_unary_ops():
class expected_selu:
@R.function
def main(
- input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
- # block 0
with R.dataflow():
- lv_relu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1)
- lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
- lv_sub: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
- lv_exp, R.const(1.0, "float32")
- )
- lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
- R.const(1.6732631921768188, "float32"), lv_sub
- )
- lv_add: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_relu, lv_scaled)
- lv_selu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
- R.const(1.0507009873554805, "float32"), lv_add
- )
- gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv_selu
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.selu(inp_0)
+ gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
R.output(gv)
return gv