This is an automated email from the ASF dual-hosted git repository.
yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new aaf185bd4b [Relax][PyTorch] Support softshrink op for ExportedProgram (#17786)
aaf185bd4b is described below
commit aaf185bd4bc837350f2976da6e02003b1e13af81
Author: Deivanayaki S <[email protected]>
AuthorDate: Mon Mar 31 23:38:38 2025 +0530
[Relax][PyTorch] Support softshrink op for ExportedProgram (#17786)
* Add softshrink op support to the ExportedProgram frontend, with test script code
* fix lint issue
* update the formatting to fix lint issues
* modify the code format to fix lint issue
---------
Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki>
---
.../frontend/torch/base_fx_graph_translator.py | 33 ++++++++++++++
.../frontend/torch/exported_program_translator.py | 1 +
.../relax/test_frontend_from_exported_program.py | 51 ++++++++++++++++++++++
3 files changed, 85 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index fe0ae412a2..839b4eb1bd 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -307,6 +307,39 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
return self.block_builder.emit(relax.op.nn.softmax(x, dim))
+ def _softshrink(self, node: fx.Node) -> relax.Var:
+ """
+ Applies the Softshrink activation function in Relax.
+
+ Softshrink(x) =
+ x - λ if x > λ
+ x + λ if x < -λ
+ 0 otherwise
+
+ Args:
+ node (fx.Node): The input node containing the tensor and lambda value.
+
+ Returns:
+ relax.Var: The resulting tensor after applying Softshrink.
+ """
+ args = self.retrieve_args(node)
+ x = args[0]
+ lambd = relax.const(args[1] if len(args) > 1 else 0.5, x.struct_info.dtype)
+
+ # Apply Softshrink transformation with masking
+ shrink_pos = relax.op.multiply(
+ relax.op.subtract(x, lambd),
+ relax.op.astype(relax.op.greater(x, lambd), x.struct_info.dtype),
+ )
+
+ shrink_neg = relax.op.multiply(
+ relax.op.add(x, lambd),
+ relax.op.astype(relax.op.less(x, relax.op.negative(lambd)), x.struct_info.dtype),
+ )
+
+ # Combine the positive and negative shrink results
+ return self.block_builder.emit(relax.op.add(shrink_pos, shrink_neg))
+
def _selu(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("alpha", 1.6732631921768188)
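Note: the _softshrink handler above lowers the op to elementwise Relax primitives rather than a dedicated intrinsic; it masks the shifted tensors with the comparison results and sums the two branches. The sketch below (illustrative only, not part of the commit) checks that this decomposition matches PyTorch's reference softshrink. It assumes torch and numpy are installed; softshrink_decomposed is a hypothetical helper name.

    import torch
    import numpy as np

    def softshrink_decomposed(x: torch.Tensor, lambd: float = 0.5) -> torch.Tensor:
        # Same masking scheme as _softshrink: (x - lambd) where x > lambd,
        # plus (x + lambd) where x < -lambd, and 0 elsewhere.
        shrink_pos = (x - lambd) * (x > lambd).to(x.dtype)
        shrink_neg = (x + lambd) * (x < -lambd).to(x.dtype)
        return shrink_pos + shrink_neg

    x = torch.randn(1, 3, 10, 10)
    np.testing.assert_allclose(
        softshrink_decomposed(x).numpy(),
        torch.nn.functional.softshrink(x, lambd=0.5).numpy(),
        rtol=1e-6,
        atol=1e-6,
    )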
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 0f1dc11787..a28da6ee72 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -282,6 +282,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"sin.default": self._unary_op(relax.op.sin),
"sinh.default": self._unary_op(relax.op.sinh),
"softmax.int": self._softmax,
+ "softshrink.default": self._softshrink,
"sqrt.default": self._unary_op(relax.op.sqrt),
"square.default": self._unary_op(relax.op.square),
"tan.default": self._unary_op(relax.op.tan),
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 739fe87dc9..98f0f1d9ca 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -607,6 +607,9 @@ def test_extended_unary_ops():
# softmax
test_softmax()
+ # softshrink
+ test_softshrink()
+
# tril, triu
test_tril_triu()
@@ -741,6 +744,54 @@ def test_softmax():
verify_model(Softmax2(), example_args, {}, expected1)
+def test_softshrink():
+ class Softshrink(Module):
+ def __init__(self):
+ super().__init__()
+ self.softshrink = torch.nn.Softshrink(lambd=0.5)
+
+ def forward(self, input):
+ return self.softshrink(input)
+
+ class Softshrink2(Module):
+ def forward(self, input):
+ return torch.nn.functional.softshrink(input, lambd=0.5)
+
+ @tvm.script.ir_module
+ class expected_softshrink:
+ @R.function
+ def main(
+ input: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
+ input, R.const(0.5, "float32")
+ )
+ lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(
+ input, R.const(0.5, "float32")
+ )
+ lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.astype(lv1, "float32")
+ lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lv, lv2)
+
+ lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(
+ input, R.const(0.5, "float32")
+ )
+ lv5: R.Tensor((), dtype="float32") = R.negative(R.const(0.5, "float32"))
+ lv6: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less(input, lv5)
+ lv7: R.Tensor((1, 3, 10, 10), dtype="float32") = R.astype(lv6, "float32")
+ lv8: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lv4, lv7)
+
+ lv9: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv3, lv8)
+
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv9,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+ verify_model(Softshrink(), example_args, {}, expected_softshrink)
+ verify_model(Softshrink2(), example_args, {}, expected_softshrink)
+
+
def test_tril_triu():
example_args = (torch.randn(10, 10, dtype=torch.float32),)