This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new bf61216566 [Relax][PyTorch] Add Softplus Op Support for Exported
Program and FX graph (#17806)
bf61216566 is described below
commit bf61216566e4d691b28d0950f99e2e083ffd934d
Author: Deivanayaki S <[email protected]>
AuthorDate: Wed Apr 9 15:26:46 2025 +0530
[Relax][PyTorch] Add Softplus Op Support for Exported Program and FX graph
(#17806)
* add softplus op into exported program and fx graph frontend
* fixing trailing whitespace issue
* fixing lint issues
* fix lint issue on docs
* modify description to avoid cpplints issue
* update softplus function with threshold attr
* remove trailing spaces in softplus func
* fix lint issues in legalize func
* fixing cpp lints issue
* test script for both exported and fx graph
* trim trailing spaces in test script
* fix lint issues in test script
* unit test script is added in test frontend op files
* fixing lint issues in test_op_nn file
* fixing attribute error in test script
* fixing lint issues in test script functions
* adding softplus wrapper function in op file
---------
Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki>
---
include/tvm/relax/attrs/nn.h | 13 +++++++
python/tvm/relax/frontend/nn/op.py | 26 ++++++++++++++
.../frontend/torch/base_fx_graph_translator.py | 6 ++++
.../frontend/torch/exported_program_translator.py | 1 +
python/tvm/relax/frontend/torch/fx_translator.py | 9 +++++
python/tvm/relax/op/nn/__init__.py | 1 +
python/tvm/relax/op/nn/nn.py | 25 ++++++++++++++
python/tvm/relax/transform/legalize_ops/nn.py | 10 ++++++
python/tvm/topi/nn/elemwise.py | 33 ++++++++++++++++++
src/relax/op/nn/nn.cc | 21 ++++++++++++
src/relax/op/nn/nn.h | 3 ++
.../relax/test_frontend_from_exported_program.py | 40 ++++++++++++++++++++++
tests/python/relax/test_frontend_from_fx.py | 40 ++++++++++++++++++++++
tests/python/relax/test_frontend_nn_op.py | 4 +++
tests/python/relax/test_op_nn.py | 5 +++
15 files changed, 237 insertions(+)
diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index 8f63012e09..0adcf29772 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -455,6 +455,19 @@ struct LeakyReluAttrs : public
tvm::AttrsNode<LeakyReluAttrs> {
}
};
+/*! \brief Attributes used in softplus operators */
+struct SoftplusAttrs : public tvm::AttrsNode<SoftplusAttrs> {
+ double beta;
+ double threshold;
+
+ TVM_DECLARE_ATTRS(SoftplusAttrs, "relax.attrs.SoftplusAttrs") {
+ TVM_ATTR_FIELD(beta).describe(
+ "Scaling factor controlling the sharpness of the Softplus transition.");
+ TVM_ATTR_FIELD(threshold).describe(
+ "Value determining when to use linear approximation for numerical stability.");
+ }
+};
+
/*! \brief Attributes used in batch_norm operator */
struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> {
int axis;
diff --git a/python/tvm/relax/frontend/nn/op.py
b/python/tvm/relax/frontend/nn/op.py
index 23045f7c4e..e81ff7c5ad 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -1046,6 +1046,32 @@ def softmax(x: Tensor, axis: int = -1, name: str =
"softmax") -> Tensor:
return wrap_nested(_op.nn.softmax(x._expr, axis), name)
+def softplus(x: Tensor, beta: float = 1.0, threshold: float = 20.0, name: str
= "softplus"):
+ r"""Softplus activation function.
+
+ .. math::
+ \text{Softplus}(x) = \frac{1}{\beta} \log(1 + e^{\beta x})
+
+ Parameters
+ ----------
+ data : relax.Expr
+ The input data.
+
+ beta : float, optional
+ Controls the smoothness of the transition. Default is 1.0.
+
+ threshold : float, optional
+ The value beyond which the function is approximated as linear
+ to avoid numerical instability. Default is 20.0.
+
+ Returns
+ -------
+ result : relax.Expr
+ The computed result.
+ """
+ return wrap_nested(_op.nn.softplus(x._expr, beta=beta,
threshold=threshold), name)
+
+
def tanh(x: Tensor, name: str = "tanh") -> Tensor:
r"""Applies the hyperbolic tangent function.
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index affbd81e1c..d1a42d645c 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -308,6 +308,12 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim",
-1)
return self.block_builder.emit(relax.op.nn.softmax(x, dim))
+ def _softplus(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ beta = node.args[1] if len(node.args) > 1 else node.kwargs.get("beta",
1.0)
+ threshold = node.args[2] if len(node.args) > 2 else
node.kwargs.get("threshold", 20.0)
+ return self.block_builder.emit(relax.op.nn.softplus(x, beta,
threshold))
+
def _softshrink(self, node: fx.Node) -> relax.Var:
"""
Applies the Softshrink activation function in Relax.
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 5e38d2ff6c..73742f952b 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -305,6 +305,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"sin.default": self._unary_op(relax.op.sin),
"sinh.default": self._unary_op(relax.op.sinh),
"softmax.int": self._softmax,
+ "softplus.default": self._softplus,
"softshrink.default": self._softshrink,
"sqrt.default": self._unary_op(relax.op.sqrt),
"square.default": self._unary_op(relax.op.square),
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index a151a57ae6..f3732b3472 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -72,6 +72,13 @@ class TorchFXImporter(BaseFXGraphImporter):
alpha = module.negative_slope
return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha))
+ def _softplus_module(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ module = self.named_modules[node.target]
+ beta = module.beta
+ threshold = module.threshold
+ return self.block_builder.emit(relax.op.nn.softplus(x, beta,
threshold))
+
def _log2(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
return self.block_builder.emit(
@@ -622,6 +629,7 @@ class TorchFXImporter(BaseFXGraphImporter):
nn.SELU: self._unary_op(relax.op.nn.selu),
nn.SiLU: self._unary_op(relax.op.nn.silu),
nn.Softmax: self._softmax_module,
+ nn.Softplus: self._softplus_module,
nn.Tanh: self._unary_op(relax.op.tanh),
# neural network
nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d_module,
@@ -686,6 +694,7 @@ class TorchFXImporter(BaseFXGraphImporter):
"sin": self._unary_op(relax.op.sin),
"sinh": self._unary_op(relax.op.sinh),
"softmax": self._softmax,
+ "softplus": self._softplus,
"sqrt": self._unary_op(relax.op.sqrt),
"square": self._unary_op(relax.op.square),
"tan": self._unary_op(relax.op.tan),
diff --git a/python/tvm/relax/op/nn/__init__.py
b/python/tvm/relax/op/nn/__init__.py
index e45982a0fe..9d56058e46 100644
--- a/python/tvm/relax/op/nn/__init__.py
+++ b/python/tvm/relax/op/nn/__init__.py
@@ -48,4 +48,5 @@ from .nn import (
selu,
silu,
softmax,
+ softplus,
)
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 5232eea047..17197b010e 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -1378,6 +1378,31 @@ def softmax(data: Expr, axis: int = -1) -> Expr:
return _ffi_api.softmax(data, axis) # type: ignore
+def softplus(data: Expr, beta: float = 1.0, threshold: float = 20.0) -> Expr:
+ r"""Softplus activation function.
+
+ .. math:: \text{Softplus}(x) = \frac{1}{\beta} \log(1 + e^{\beta x})
+
+ Parameters
+ ----------
+ data : relax.Expr
+ The input data.
+
+ beta : float, optional
+ Controls the smoothness of the transition. Default is 1.0.
+
+ threshold : float, optional
+ The value beyond which the function is approximated as linear
+ to avoid numerical instability. Default is 20.0.
+
+ Returns
+ -------
+ result : relax.Expr
+ The computed result.
+ """
+ return _ffi_api.softplus(data, beta, threshold)
+
+
def log_softmax(data: Expr, axis: int = -1) -> Expr:
r"""Computes log softmax.
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py
b/python/tvm/relax/transform/legalize_ops/nn.py
index fd3db841e6..98fa3ef1ea 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -531,6 +531,16 @@ def _nn_silu(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(te_silu, call.args[0], primfunc_name_hint="silu")
+@register_legalize("relax.nn.softplus")
+def _nn_softplus(bb: BlockBuilder, call: Call) -> Expr:
+ return bb.call_te(
+ topi.nn.softplus,
+ call.args[0],
+ call.attrs.beta,
+ call.attrs.threshold,
+ )
+
+
@register_legalize("relax.nn.softmax")
def _nn_softmax(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(topi.nn.softmax, call.args[0], call.attrs.axis)
diff --git a/python/tvm/topi/nn/elemwise.py b/python/tvm/topi/nn/elemwise.py
index a80047d900..2b174f8f1e 100644
--- a/python/tvm/topi/nn/elemwise.py
+++ b/python/tvm/topi/nn/elemwise.py
@@ -65,6 +65,39 @@ def leaky_relu(x, alpha):
return te.compute(x.shape, _compute)
+@tvm.te.tag_scope(tag=tag.ELEMWISE)
+def softplus(x, beta=1.0, threshold=20.0):
+ """Compute Softplus activation for input x with numerical stability.
+
+ Parameters
+ ----------
+ x : tvm.te.Tensor
+ Input tensor.
+
+ beta : float, optional
+ The scaling factor β in the Softplus formula (default is 1.0).
+
+ threshold : float, optional
+ The threshold value for numerical stability (default is 20.0).
+
+ Returns
+ -------
+ y : tvm.te.Tensor
+ The result.
+ """
+
+ def _compute(*indices):
+ value = x(*indices)
+ b = tvm.tir.const(beta, value.dtype)
+ t = tvm.tir.const(threshold, value.dtype)
+
+ return tvm.tir.Select(
+ b * value > t, value, (1 / b) * tvm.tir.log(1 + tvm.tir.exp(b *
value))
+ )
+
+ return te.compute(x.shape, _compute)
+
+
@tvm.te.tag_scope(tag=tag.BROADCAST)
def prelu(x, slope, axis=1):
"""PReLU.
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index 4a5a9a7016..7f545af130 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -60,6 +60,27 @@ TVM_REGISTER_OP("relax.nn.leakyrelu")
InferStructInfoUnaryArith</*require_float_dtype=*/true>)
.set_attr<Bool>("FPurity", Bool(true));
+/* relax.nn.softplus */
+TVM_REGISTER_NODE_TYPE(SoftplusAttrs);
+
+Expr softplus(Expr data, double beta, double threshold) {
+ auto attrs = make_object<SoftplusAttrs>();
+ attrs->beta = beta;
+ attrs->threshold = threshold;
+ static const Op& op = Op::Get("relax.nn.softplus");
+ return Call(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.softplus").set_body_typed(softplus);
+
+TVM_REGISTER_OP("relax.nn.softplus")
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_attrs_type<SoftplusAttrs>()
+ .set_attr<FInferStructInfo>("FInferStructInfo",
+
InferStructInfoUnaryArith</*require_float_dtype=*/true>)
+ .set_attr<Bool>("FPurity", Bool(true));
+
/* relax.nn.softmax */
TVM_REGISTER_NODE_TYPE(SoftmaxAttrs);
diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h
index d6db36aba5..3f5571af82 100644
--- a/src/relax/op/nn/nn.h
+++ b/src/relax/op/nn/nn.h
@@ -66,6 +66,9 @@ Expr silu(Expr data);
/*! \brief Softmax function. */
Expr softmax(Expr data, int axis);
+/*! \brief Softplus function. */
+Expr softplus(Expr data, double beta, double threshold);
+
/*! \brief LogSoftmax function. */
Expr log_softmax(Expr data, int axis);
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index a3c939fcb6..58f28fb1b3 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -419,6 +419,9 @@ def test_extended_unary_ops():
# leakyrelu
test_leakyrelu()
+ # softplus
+ test_softplus()
+
# log2
class Log2(Module):
def forward(self, x):
@@ -680,6 +683,43 @@ def test_hardtanh():
verify_model(Hardtanh3(), example_args, {}, expected1)
+def test_softplus():
+ import torch
+ from torch.nn import Module
+
+ torch.set_grad_enabled(False)
+
+ class Softplus0(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.softplus = torch.nn.Softplus(1.0, 20.0)
+
+ def forward(self, x):
+ return self.softplus(x)
+
+ class Softplus1(Module):
+ def forward(self, input):
+ return torch.nn.functional.softplus(input, 1.0, 20.0)
+
+ @tvm.script.ir_module
+ class expected:
+ @R.function
+ def main(
+ x: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.softplus(
+ x, beta=1.0, threshold=20.0
+ )
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+ verify_model(Softplus0(), example_args, {}, expected)
+ verify_model(Softplus1(), example_args, {}, expected)
+
+
def test_leakyrelu():
import torch
from torch.nn import Module
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index 2c5560b577..fd9bfdf633 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -749,6 +749,43 @@ def test_einsum():
verify_model(Einsum2(), [([5], "float32"), ([4], "float32")], {},
Expected2)
+@tvm.testing.requires_gpu
+def test_softplus():
+ import torch
+ from torch.nn import Module
+
+ torch.set_grad_enabled(False)
+
+ class Softplus0(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.softplus = torch.nn.Softplus(1.0, 20.0)
+
+ def forward(self, x):
+ return self.softplus(x)
+
+ class Softplus1(Module):
+ def forward(self, input):
+ return torch.nn.functional.softplus(input, 1.0, 20.0)
+
+ @tvm.script.ir_module
+ class expected:
+ @R.function
+ def main(inp_0: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10,
10), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((10, 10), dtype="float32") = R.nn.softplus(
+ inp_0, beta=1.0, threshold=20.0
+ )
+ gv: R.Tensor((10, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ input_info = [([10, 10], "float32")]
+ verify_model(Softplus0(), input_info, {}, expected)
+ verify_model(Softplus1(), input_info, {}, expected)
+
+
@tvm.testing.requires_gpu
def test_leakyrelu():
import torch
@@ -2226,6 +2263,9 @@ def test_extended_unary_ops():
# leaky_relu
test_leakyrelu()
+ # softplus
+ test_softplus()
+
# log2
class Log2(Module):
def forward(self, x):
diff --git a/tests/python/relax/test_frontend_nn_op.py
b/tests/python/relax/test_frontend_nn_op.py
index 6e63b0e4c0..ed81aa49ed 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -391,6 +391,7 @@ def test_nn():
tanh_out = op.tanh(x)
exp_out = op.exp(x)
negative_out = op.negative(x)
+ softplus_out = op.softplus(x, beta=1.0, threshold=20.0)
softmax_out = op.softmax(x, axis=2)
rms_norm_out = op.rms_norm(x, weight, axes=[-2, -1])
rms_norm_with_bias_out = op.rms_norm(x, weight, axes=[-2, -1])
@@ -413,6 +414,9 @@ def test_nn():
tanh: R.Tensor((2, 3, 4, 5), dtype="float32") = R.tanh(x)
exp: R.Tensor((2, 3, 4, 5), dtype="float32") = R.exp(x)
negative: R.Tensor((2, 3, 4, 5), dtype="float32") = R.negative(x)
+ softplus: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.softplus(
+ x, beta=1.0, threshold=20.0
+ )
softmax: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.softmax(x,
axis=2)
rms_norm: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.rms_norm(
x, weight, axes=[-2, -1], epsilon=1.0000000000000001e-05
diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py
index ec4551872f..2401153c61 100644
--- a/tests/python/relax/test_op_nn.py
+++ b/tests/python/relax/test_op_nn.py
@@ -27,6 +27,7 @@ def test_op_correctness():
x = relax.Var("x", R.Tensor((2, 3), "float32"))
assert relax.op.nn.relu(x).op == Op.get("relax.nn.relu")
assert relax.op.nn.leakyrelu(x).op == Op.get("relax.nn.leakyrelu")
+ assert relax.op.nn.softplus(x).op == Op.get("relax.nn.softplus")
assert relax.op.nn.gelu(x).op == Op.get("relax.nn.gelu")
assert relax.op.nn.silu(x).op == Op.get("relax.nn.silu")
assert relax.op.nn.softmax(x).op == Op.get("relax.nn.softmax")
@@ -75,6 +76,8 @@ def test_linear_unit_infer_struct_info():
_check_inference(bb, relax.op.nn.gelu(x4),
relax.TensorStructInfo(dtype=""))
_check_inference(bb, relax.op.nn.leakyrelu(x0), relax.TensorStructInfo((2,
3), "float32"))
_check_inference(bb, relax.op.nn.leakyrelu(x5), relax.TensorStructInfo((3,
4), dtype=""))
+ _check_inference(bb, relax.op.nn.softplus(x0), relax.TensorStructInfo((2,
3), "float32"))
+ _check_inference(bb, relax.op.nn.softplus(x5), relax.TensorStructInfo((3,
4), dtype=""))
def test_linear_unit_infer_struct_info_shape_symbolic():
@@ -87,6 +90,7 @@ def test_linear_unit_infer_struct_info_shape_symbolic():
_check_inference(bb, relax.op.nn.silu(x0), relax.TensorStructInfo((m, n),
"float32"))
_check_inference(bb, relax.op.nn.relu(x1), relax.TensorStructInfo((4, n),
"float32"))
_check_inference(bb, relax.op.nn.leakyrelu(x1), relax.TensorStructInfo((4,
n), "float32"))
+ _check_inference(bb, relax.op.nn.softplus(x1), relax.TensorStructInfo((4,
n), "float32"))
def test_linear_unit_infer_struct_info_shape_var():
@@ -99,6 +103,7 @@ def test_linear_unit_infer_struct_info_shape_var():
_check_inference(bb, relax.op.nn.gelu(x0), relax.TensorStructInfo(s0,
"float32"))
_check_inference(bb, relax.op.nn.relu(x1), relax.TensorStructInfo(s1,
"float32"))
_check_inference(bb, relax.op.nn.leakyrelu(x1), relax.TensorStructInfo(s1,
"float32"))
+ _check_inference(bb, relax.op.nn.softplus(x1), relax.TensorStructInfo(s1,
"float32"))
def test_linear_unit_infer_struct_info_more_input_dtype():