This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new bf61216566 [Relax][PyTorch] Add Softplus Op Support for Exported 
Program and FX graph (#17806)
bf61216566 is described below

commit bf61216566e4d691b28d0950f99e2e083ffd934d
Author: Deivanayaki S <[email protected]>
AuthorDate: Wed Apr 9 15:26:46 2025 +0530

    [Relax][PyTorch] Add Softplus Op Support for Exported Program and FX graph 
(#17806)
    
    * add softplus op into exported program and fx graph frontend
    
    * fixing trailing whitespace issue
    
    * fixing lint issues
    
    * fix lint issue on docs
    
    * modify description to avoid cpplint issue
    
    * update softplus function with threshold attr
    
    * remove trailing spaces in softplus func
    
    * fix lint issues in legalize func
    
    * fixing cpplint issue
    
    * test script for both exported and fx graph
    
    * trim trailing spaces in test script
    
    * fix lint issues in test script
    
    * unit test script is added in test frontend op files
    
    * fixing lint issues in test_op_nn file
    
    * fixing attribute error in test script
    
    * fixing lint issues in test script functions
    
    * adding softplus wrapper function in op file
    
    ---------
    
    Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki>
---
 include/tvm/relax/attrs/nn.h                       | 13 +++++++
 python/tvm/relax/frontend/nn/op.py                 | 26 ++++++++++++++
 .../frontend/torch/base_fx_graph_translator.py     |  6 ++++
 .../frontend/torch/exported_program_translator.py  |  1 +
 python/tvm/relax/frontend/torch/fx_translator.py   |  9 +++++
 python/tvm/relax/op/nn/__init__.py                 |  1 +
 python/tvm/relax/op/nn/nn.py                       | 25 ++++++++++++++
 python/tvm/relax/transform/legalize_ops/nn.py      | 10 ++++++
 python/tvm/topi/nn/elemwise.py                     | 33 ++++++++++++++++++
 src/relax/op/nn/nn.cc                              | 21 ++++++++++++
 src/relax/op/nn/nn.h                               |  3 ++
 .../relax/test_frontend_from_exported_program.py   | 40 ++++++++++++++++++++++
 tests/python/relax/test_frontend_from_fx.py        | 40 ++++++++++++++++++++++
 tests/python/relax/test_frontend_nn_op.py          |  4 +++
 tests/python/relax/test_op_nn.py                   |  5 +++
 15 files changed, 237 insertions(+)

diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index 8f63012e09..0adcf29772 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -455,6 +455,19 @@ struct LeakyReluAttrs : public 
tvm::AttrsNode<LeakyReluAttrs> {
   }
 };
 
+/*! \brief Attributes used in softplus operators */
+struct SoftplusAttrs : public tvm::AttrsNode<SoftplusAttrs> {
+  double beta;
+  double threshold;
+
+  TVM_DECLARE_ATTRS(SoftplusAttrs, "relax.attrs.SoftplusAttrs") {
+    TVM_ATTR_FIELD(beta).describe(
+        "Scaling factor controlling the sharpness of the Softplus 
transition.");
+    TVM_ATTR_FIELD(threshold).describe(
+        "Value determining when to use linear approximation for numerical 
stability.");
+  }
+};
+
 /*! \brief Attributes used in batch_norm operator */
 struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> {
   int axis;
diff --git a/python/tvm/relax/frontend/nn/op.py 
b/python/tvm/relax/frontend/nn/op.py
index 23045f7c4e..e81ff7c5ad 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -1046,6 +1046,32 @@ def softmax(x: Tensor, axis: int = -1, name: str = 
"softmax") -> Tensor:
     return wrap_nested(_op.nn.softmax(x._expr, axis), name)
 
 
+def softplus(x: Tensor, beta: float = 1.0, threshold: float = 20.0, name: str 
= "softplus"):
+    r"""Softplus activation function.
+
+    .. math::
+        \text{Softplus}(x) = \frac{1}{\beta} \log(1 + e^{\beta x})
+
+    Parameters
+    ----------
+    x : Tensor
+        The input data.
+
+    beta : float, optional
+        Controls the smoothness of the transition. Default is 1.0.
+
+    threshold : float, optional
+        The value beyond which the function is approximated as linear
+        to avoid numerical instability. Default is 20.0.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.nn.softplus(x._expr, beta=beta, 
threshold=threshold), name)
+
+
 def tanh(x: Tensor, name: str = "tanh") -> Tensor:
     r"""Applies the hyperbolic tangent function.
 
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py 
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index affbd81e1c..d1a42d645c 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -308,6 +308,12 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 
-1)
         return self.block_builder.emit(relax.op.nn.softmax(x, dim))
 
+    def _softplus(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        beta = node.args[1] if len(node.args) > 1 else node.kwargs.get("beta", 
1.0)
+        threshold = node.args[2] if len(node.args) > 2 else 
node.kwargs.get("threshold", 20.0)
+        return self.block_builder.emit(relax.op.nn.softplus(x, beta, 
threshold))
+
     def _softshrink(self, node: fx.Node) -> relax.Var:
         """
         Applies the Softshrink activation function in Relax.
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py 
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 5e38d2ff6c..73742f952b 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -305,6 +305,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "sin.default": self._unary_op(relax.op.sin),
             "sinh.default": self._unary_op(relax.op.sinh),
             "softmax.int": self._softmax,
+            "softplus.default": self._softplus,
             "softshrink.default": self._softshrink,
             "sqrt.default": self._unary_op(relax.op.sqrt),
             "square.default": self._unary_op(relax.op.square),
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py 
b/python/tvm/relax/frontend/torch/fx_translator.py
index a151a57ae6..f3732b3472 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -72,6 +72,13 @@ class TorchFXImporter(BaseFXGraphImporter):
         alpha = module.negative_slope
         return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha))
 
+    def _softplus_module(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        module = self.named_modules[node.target]
+        beta = module.beta
+        threshold = module.threshold
+        return self.block_builder.emit(relax.op.nn.softplus(x, beta, 
threshold))
+
     def _log2(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         return self.block_builder.emit(
@@ -622,6 +629,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             nn.SELU: self._unary_op(relax.op.nn.selu),
             nn.SiLU: self._unary_op(relax.op.nn.silu),
             nn.Softmax: self._softmax_module,
+            nn.Softplus: self._softplus_module,
             nn.Tanh: self._unary_op(relax.op.tanh),
             # neural network
             nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d_module,
@@ -686,6 +694,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "sin": self._unary_op(relax.op.sin),
             "sinh": self._unary_op(relax.op.sinh),
             "softmax": self._softmax,
+            "softplus": self._softplus,
             "sqrt": self._unary_op(relax.op.sqrt),
             "square": self._unary_op(relax.op.square),
             "tan": self._unary_op(relax.op.tan),
diff --git a/python/tvm/relax/op/nn/__init__.py 
b/python/tvm/relax/op/nn/__init__.py
index e45982a0fe..9d56058e46 100644
--- a/python/tvm/relax/op/nn/__init__.py
+++ b/python/tvm/relax/op/nn/__init__.py
@@ -48,4 +48,5 @@ from .nn import (
     selu,
     silu,
     softmax,
+    softplus,
 )
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 5232eea047..17197b010e 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -1378,6 +1378,31 @@ def softmax(data: Expr, axis: int = -1) -> Expr:
     return _ffi_api.softmax(data, axis)  # type: ignore
 
 
+def softplus(data: Expr, beta: float = 1.0, threshold: float = 20.0) -> Expr:
+    r"""Softplus activation function.
+
+    .. math:: \text{Softplus}(x) = \frac{1}{\beta} \log(1 + e^{\beta x})
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data.
+
+    beta : float, optional
+        Controls the smoothness of the transition. Default is 1.0.
+
+    threshold : float, optional
+        The value beyond which the function is approximated as linear
+        to avoid numerical instability. Default is 20.0.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    return _ffi_api.softplus(data, beta, threshold)
+
+
 def log_softmax(data: Expr, axis: int = -1) -> Expr:
     r"""Computes log softmax.
 
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py 
b/python/tvm/relax/transform/legalize_ops/nn.py
index fd3db841e6..98fa3ef1ea 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -531,6 +531,16 @@ def _nn_silu(bb: BlockBuilder, call: Call) -> Expr:
     return bb.call_te(te_silu, call.args[0], primfunc_name_hint="silu")
 
 
+@register_legalize("relax.nn.softplus")
+def _nn_softplus(bb: BlockBuilder, call: Call) -> Expr:
+    return bb.call_te(
+        topi.nn.softplus,
+        call.args[0],
+        call.attrs.beta,
+        call.attrs.threshold,
+    )
+
+
 @register_legalize("relax.nn.softmax")
 def _nn_softmax(bb: BlockBuilder, call: Call) -> Expr:
     return bb.call_te(topi.nn.softmax, call.args[0], call.attrs.axis)
diff --git a/python/tvm/topi/nn/elemwise.py b/python/tvm/topi/nn/elemwise.py
index a80047d900..2b174f8f1e 100644
--- a/python/tvm/topi/nn/elemwise.py
+++ b/python/tvm/topi/nn/elemwise.py
@@ -65,6 +65,39 @@ def leaky_relu(x, alpha):
     return te.compute(x.shape, _compute)
 
 
[email protected]_scope(tag=tag.ELEMWISE)
+def softplus(x, beta=1.0, threshold=20.0):
+    """Compute Softplus activation for input x with numerical stability.
+
+    Parameters
+    ----------
+    x : tvm.te.Tensor
+        Input tensor.
+
+    beta : float, optional
+        The scaling factor β in the Softplus formula (default is 1.0).
+
+    threshold : float, optional
+        The threshold value for numerical stability (default is 20.0).
+
+    Returns
+    -------
+    y : tvm.te.Tensor
+        The result.
+    """
+
+    def _compute(*indices):
+        value = x(*indices)
+        b = tvm.tir.const(beta, value.dtype)
+        t = tvm.tir.const(threshold, value.dtype)
+
+        return tvm.tir.Select(
+            b * value > t, value, (1 / b) * tvm.tir.log(1 + tvm.tir.exp(b * 
value))
+        )
+
+    return te.compute(x.shape, _compute)
+
+
 @tvm.te.tag_scope(tag=tag.BROADCAST)
 def prelu(x, slope, axis=1):
     """PReLU.
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index 4a5a9a7016..7f545af130 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -60,6 +60,27 @@ TVM_REGISTER_OP("relax.nn.leakyrelu")
                                 
InferStructInfoUnaryArith</*require_float_dtype=*/true>)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.nn.softplus */
+TVM_REGISTER_NODE_TYPE(SoftplusAttrs);
+
+Expr softplus(Expr data, double beta, double threshold) {
+  auto attrs = make_object<SoftplusAttrs>();
+  attrs->beta = beta;
+  attrs->threshold = threshold;
+  static const Op& op = Op::Get("relax.nn.softplus");
+  return Call(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.softplus").set_body_typed(softplus);
+
+TVM_REGISTER_OP("relax.nn.softplus")
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .set_attrs_type<SoftplusAttrs>()
+    .set_attr<FInferStructInfo>("FInferStructInfo",
+                                
InferStructInfoUnaryArith</*require_float_dtype=*/true>)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 /* relax.nn.softmax */
 TVM_REGISTER_NODE_TYPE(SoftmaxAttrs);
 
diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h
index d6db36aba5..3f5571af82 100644
--- a/src/relax/op/nn/nn.h
+++ b/src/relax/op/nn/nn.h
@@ -66,6 +66,9 @@ Expr silu(Expr data);
 /*! \brief Softmax function. */
 Expr softmax(Expr data, int axis);
 
+/*! \brief Softplus function. */
+Expr softplus(Expr data, double beta, double threshold);
+
 /*! \brief LogSoftmax function. */
 Expr log_softmax(Expr data, int axis);
 
diff --git a/tests/python/relax/test_frontend_from_exported_program.py 
b/tests/python/relax/test_frontend_from_exported_program.py
index a3c939fcb6..58f28fb1b3 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -419,6 +419,9 @@ def test_extended_unary_ops():
     # leakyrelu
     test_leakyrelu()
 
+    # softplus
+    test_softplus()
+
     # log2
     class Log2(Module):
         def forward(self, x):
@@ -680,6 +683,43 @@ def test_hardtanh():
     verify_model(Hardtanh3(), example_args, {}, expected1)
 
 
+def test_softplus():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+
+    class Softplus0(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.softplus = torch.nn.Softplus(1.0, 20.0)
+
+        def forward(self, x):
+            return self.softplus(x)
+
+    class Softplus1(Module):
+        def forward(self, input):
+            return torch.nn.functional.softplus(input, 1.0, 20.0)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.softplus(
+                    x, beta=1.0, threshold=20.0
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+    verify_model(Softplus0(), example_args, {}, expected)
+    verify_model(Softplus1(), example_args, {}, expected)
+
+
 def test_leakyrelu():
     import torch
     from torch.nn import Module
diff --git a/tests/python/relax/test_frontend_from_fx.py 
b/tests/python/relax/test_frontend_from_fx.py
index 2c5560b577..fd9bfdf633 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -749,6 +749,43 @@ def test_einsum():
     verify_model(Einsum2(), [([5], "float32"), ([4], "float32")], {}, 
Expected2)
 
 
[email protected]_gpu
+def test_softplus():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+
+    class Softplus0(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.softplus = torch.nn.Softplus(1.0, 20.0)
+
+        def forward(self, x):
+            return self.softplus(x)
+
+    class Softplus1(Module):
+        def forward(self, input):
+            return torch.nn.functional.softplus(input, 1.0, 20.0)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(inp_0: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 
10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.nn.softplus(
+                    inp_0, beta=1.0, threshold=20.0
+                )
+                gv: R.Tensor((10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    input_info = [([10, 10], "float32")]
+    verify_model(Softplus0(), input_info, {}, expected)
+    verify_model(Softplus1(), input_info, {}, expected)
+
+
 @tvm.testing.requires_gpu
 def test_leakyrelu():
     import torch
@@ -2226,6 +2263,9 @@ def test_extended_unary_ops():
     # leaky_relu
     test_leakyrelu()
 
+    # softplus
+    test_softplus()
+
     # log2
     class Log2(Module):
         def forward(self, x):
diff --git a/tests/python/relax/test_frontend_nn_op.py 
b/tests/python/relax/test_frontend_nn_op.py
index 6e63b0e4c0..ed81aa49ed 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -391,6 +391,7 @@ def test_nn():
             tanh_out = op.tanh(x)
             exp_out = op.exp(x)
             negative_out = op.negative(x)
+            softplus_out = op.softplus(x, beta=1.0, threshold=20.0)
             softmax_out = op.softmax(x, axis=2)
             rms_norm_out = op.rms_norm(x, weight, axes=[-2, -1])
             rms_norm_with_bias_out = op.rms_norm(x, weight, axes=[-2, -1])
@@ -413,6 +414,9 @@ def test_nn():
             tanh: R.Tensor((2, 3, 4, 5), dtype="float32") = R.tanh(x)
             exp: R.Tensor((2, 3, 4, 5), dtype="float32") = R.exp(x)
             negative: R.Tensor((2, 3, 4, 5), dtype="float32") = R.negative(x)
+            softplus: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.softplus(
+                x, beta=1.0, threshold=20.0
+            )
             softmax: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.softmax(x, 
axis=2)
             rms_norm: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.rms_norm(
                 x, weight, axes=[-2, -1], epsilon=1.0000000000000001e-05
diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py
index ec4551872f..2401153c61 100644
--- a/tests/python/relax/test_op_nn.py
+++ b/tests/python/relax/test_op_nn.py
@@ -27,6 +27,7 @@ def test_op_correctness():
     x = relax.Var("x", R.Tensor((2, 3), "float32"))
     assert relax.op.nn.relu(x).op == Op.get("relax.nn.relu")
     assert relax.op.nn.leakyrelu(x).op == Op.get("relax.nn.leakyrelu")
+    assert relax.op.nn.softplus(x).op == Op.get("relax.nn.softplus")
     assert relax.op.nn.gelu(x).op == Op.get("relax.nn.gelu")
     assert relax.op.nn.silu(x).op == Op.get("relax.nn.silu")
     assert relax.op.nn.softmax(x).op == Op.get("relax.nn.softmax")
@@ -75,6 +76,8 @@ def test_linear_unit_infer_struct_info():
     _check_inference(bb, relax.op.nn.gelu(x4), 
relax.TensorStructInfo(dtype=""))
     _check_inference(bb, relax.op.nn.leakyrelu(x0), relax.TensorStructInfo((2, 
3), "float32"))
     _check_inference(bb, relax.op.nn.leakyrelu(x5), relax.TensorStructInfo((3, 
4), dtype=""))
+    _check_inference(bb, relax.op.nn.softplus(x0), relax.TensorStructInfo((2, 
3), "float32"))
+    _check_inference(bb, relax.op.nn.softplus(x5), relax.TensorStructInfo((3, 
4), dtype=""))
 
 
 def test_linear_unit_infer_struct_info_shape_symbolic():
@@ -87,6 +90,7 @@ def test_linear_unit_infer_struct_info_shape_symbolic():
     _check_inference(bb, relax.op.nn.silu(x0), relax.TensorStructInfo((m, n), 
"float32"))
     _check_inference(bb, relax.op.nn.relu(x1), relax.TensorStructInfo((4, n), 
"float32"))
     _check_inference(bb, relax.op.nn.leakyrelu(x1), relax.TensorStructInfo((4, 
n), "float32"))
+    _check_inference(bb, relax.op.nn.softplus(x1), relax.TensorStructInfo((4, 
n), "float32"))
 
 
 def test_linear_unit_infer_struct_info_shape_var():
@@ -99,6 +103,7 @@ def test_linear_unit_infer_struct_info_shape_var():
     _check_inference(bb, relax.op.nn.gelu(x0), relax.TensorStructInfo(s0, 
"float32"))
     _check_inference(bb, relax.op.nn.relu(x1), relax.TensorStructInfo(s1, 
"float32"))
     _check_inference(bb, relax.op.nn.leakyrelu(x1), relax.TensorStructInfo(s1, 
"float32"))
+    _check_inference(bb, relax.op.nn.softplus(x1), relax.TensorStructInfo(s1, 
"float32"))
 
 
 def test_linear_unit_infer_struct_info_more_input_dtype():

Reply via email to