This is an automated email from the ASF dual-hosted git repository.

yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new aaf185bd4b [Relax][PyTorch] Support softshrink op for ExportedProgram (#17786)
aaf185bd4b is described below

commit aaf185bd4bc837350f2976da6e02003b1e13af81
Author: Deivanayaki S <[email protected]>
AuthorDate: Mon Mar 31 23:38:38 2025 +0530

    [Relax][PyTorch] Support softshrink op for ExportedProgram (#17786)
    
    * Add softshrink op support to the ExportedProgram frontend, with a test script
    
    * Fix lint issue
    
    * Update the formatting to fix lint issues
    
    * Modify the code format to fix lint issue
    
    ---------
    
    Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki>
---
 .../frontend/torch/base_fx_graph_translator.py     | 33 ++++++++++++++
 .../frontend/torch/exported_program_translator.py  |  1 +
 .../relax/test_frontend_from_exported_program.py   | 51 ++++++++++++++++++++++
 3 files changed, 85 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index fe0ae412a2..839b4eb1bd 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -307,6 +307,39 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
         return self.block_builder.emit(relax.op.nn.softmax(x, dim))
 
+    def _softshrink(self, node: fx.Node) -> relax.Var:
+        """
+        Applies the Softshrink activation function in Relax.
+
+        Softshrink(x) =
+            x - λ    if x > λ
+            x + λ    if x < -λ
+            0        otherwise
+
+        Args:
+            node (fx.Node): The input node containing the tensor and lambda value.
+
+        Returns:
+            relax.Var: The resulting tensor after applying Softshrink.
+        """
+        args = self.retrieve_args(node)
+        x = args[0]
+        lambd = relax.const(args[1] if len(args) > 1 else 0.5, x.struct_info.dtype)
+
+        # Apply Softshrink transformation with masking
+        shrink_pos = relax.op.multiply(
+            relax.op.subtract(x, lambd),
+            relax.op.astype(relax.op.greater(x, lambd), x.struct_info.dtype),
+        )
+
+        shrink_neg = relax.op.multiply(
+            relax.op.add(x, lambd),
+            relax.op.astype(relax.op.less(x, relax.op.negative(lambd)), x.struct_info.dtype),
+        )
+
+        # Combine the positive and negative shrink results
+        return self.block_builder.emit(relax.op.add(shrink_pos, shrink_neg))
+
     def _selu(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("alpha", 1.6732631921768188)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 0f1dc11787..a28da6ee72 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -282,6 +282,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "sin.default": self._unary_op(relax.op.sin),
             "sinh.default": self._unary_op(relax.op.sinh),
             "softmax.int": self._softmax,
+            "softshrink.default": self._softshrink,
             "sqrt.default": self._unary_op(relax.op.sqrt),
             "square.default": self._unary_op(relax.op.square),
             "tan.default": self._unary_op(relax.op.tan),
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 739fe87dc9..98f0f1d9ca 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -607,6 +607,9 @@ def test_extended_unary_ops():
     # softmax
     test_softmax()
 
+    # softshrink
+    test_softshrink()
+
     # tril, triu
     test_tril_triu()
 
@@ -741,6 +744,54 @@ def test_softmax():
     verify_model(Softmax2(), example_args, {}, expected1)
 
 
+def test_softshrink():
+    class Softshrink(Module):
+        def __init__(self):
+            super().__init__()
+            self.softshrink = torch.nn.Softshrink(lambd=0.5)
+
+        def forward(self, input):
+            return self.softshrink(input)
+
+    class Softshrink2(Module):
+        def forward(self, input):
+            return torch.nn.functional.softshrink(input, lambd=0.5)
+
+    @tvm.script.ir_module
+    class expected_softshrink:
+        @R.function
+        def main(
+            input: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
+                    input, R.const(0.5, "float32")
+                )
+                lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(
+                    input, R.const(0.5, "float32")
+                )
+                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.astype(lv1, "float32")
+                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lv, lv2)
+
+                lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(
+                    input, R.const(0.5, "float32")
+                )
+                lv5: R.Tensor((), dtype="float32") = R.negative(R.const(0.5, "float32"))
+                lv6: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less(input, lv5)
+                lv7: R.Tensor((1, 3, 10, 10), dtype="float32") = R.astype(lv6, "float32")
+                lv8: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lv4, lv7)
+
+                lv9: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv3, lv8)
+
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv9,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+    verify_model(Softshrink(), example_args, {}, expected_softshrink)
+    verify_model(Softshrink2(), example_args, {}, expected_softshrink)
+
+
 def test_tril_triu():
     example_args = (torch.randn(10, 10, dtype=torch.float32),)
 
