This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 12e72c29e8 [Relax][PyTorch] Add support for elu, hardtanh ops (#17694)
12e72c29e8 is described below

commit 12e72c29e877bced07f4d0cb5a0861e92cddb38c
Author: Shushi Hong <[email protected]>
AuthorDate: Sun Mar 2 11:47:52 2025 +0800

    [Relax][PyTorch] Add support for elu, hardtanh ops (#17694)
    
    * Update fx_translator.py
    
    * Update test_frontend_from_fx.py
    
    * Update base_fx_graph_translator.py
    
    * Update fx_translator.py
    
    * Update test_frontend_from_fx.py
    
    * Update fx_translator.py
    
    * Update fx_translator.py
    
    * lint
    
    * lint
    
    * lint
---
 .../frontend/torch/base_fx_graph_translator.py     | 29 +++++++++
 python/tvm/relax/frontend/torch/fx_translator.py   |  4 ++
 tests/python/relax/test_frontend_from_fx.py        | 70 ++++++++++++++++++++++
 3 files changed, 103 insertions(+)

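In PyTorch terms, the two activations covered here are torch.nn.ELU / torch.nn.functional.elu and torch.nn.Hardtanh / torch.nn.functional.hardtanh. As a rough sketch of what the frontend can now accept (the class name and tensor shapes are illustrative, not part of the patch), a model like the following can be traced with torch.fx and imported through tvm.relax.frontend.torch.from_fx:

    import torch

    class Activations(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # defaults: alpha=1.0 for ELU, clamp bounds [-1, 1] for Hardtanh
            self.elu = torch.nn.ELU()
            self.hardtanh = torch.nn.Hardtanh()

        def forward(self, x):
            return self.hardtanh(self.elu(x))

The three file diffs below add the shared lowering helpers, register them in the FX importer's conversion tables, and extend the frontend tests.
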
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index d84993c68d..e601f18181 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -127,6 +127,28 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
             )
         return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max))
 
+    def _elu(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("alpha", 1.0)
+        dtype = x.struct_info.dtype
+
+        if isinstance(alpha, (int, float)):
+            alpha = relax.const(alpha, dtype)
+        else:
+            if not isinstance(alpha, relax.Var):
+                alpha = self.block_builder.emit(relax.const(alpha, dtype))
+
+        # α⋅ReLU(1−exp(x))+ReLU(x)
+        return self.block_builder.emit(
+            relax.op.add(
+                relax.op.multiply(
+                    alpha,
+                    relax.op.nn.relu(relax.op.subtract(relax.const(1, dtype), relax.op.exp(x))),
+                ),
+                relax.op.nn.relu(x),
+            )
+        )
+
     def _gelu(self, node: fx.Node) -> relax.Expr:
         approximate = node.kwargs.get("approximate", "none")
         if approximate == "none":
@@ -153,6 +175,13 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         x2 = relax.op.divide(x1, relax.const(6, dtype))
         return self.block_builder.emit(relax.op.multiply(x, x2))
 
+    def _hardtanh(self, node: fx.Node) -> relax.Expr:
+        args = self.retrieve_args(node)
+        x = args[0]
+        min_val = node.kwargs.get("min_val", -1.0)
+        max_val = node.kwargs.get("max_val", 1.0)
+        return self.block_builder.emit(relax.op.clip(x, min_val, max_val))
+
     def _leakyrelu(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("negative_slope", 0.01)
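
The new _hardtanh helper above lowers the op to relax.op.clip with the given min_val/max_val bounds. The identity it relies on can be sanity-checked directly in PyTorch (a standalone sketch, not part of the patch):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 10, 10)
    # hardtanh with its default bounds is exactly a clamp to [-1, 1],
    # which is what relax.op.clip expresses on the Relax side
    assert torch.allclose(F.hardtanh(x), torch.clamp(x, -1.0, 1.0))
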
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index dffe2b60eb..bbad7c0c70 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -581,9 +581,11 @@ class TorchFXImporter(BaseFXGraphImporter):
             ## call_module
             # unary
             nn.Dropout: lambda node: self.env[node.args[0]],
+            nn.ELU: self._elu,
             nn.GELU: self._gelu,
             nn.Hardsigmoid: self._hardsigmoid,
             nn.Hardswish: self._hardswish,
+            nn.Hardtanh: self._hardtanh,
             nn.Identity: lambda node: self.env[node.args[0]],
             nn.LeakyReLU: self._leakyrelu_module,
             nn.LogSoftmax: self._log_softmax_module,
@@ -627,12 +629,14 @@ class TorchFXImporter(BaseFXGraphImporter):
             "cos": self._unary_op(relax.op.cos),
             "cosh": self._unary_op(relax.op.cosh),
             "dropout": lambda node: self.env[node.args[0]],
+            "elu": self._elu,
             "erf": self._unary_op(relax.op.erf),
             "exp": self._unary_op(relax.op.exp),
             "floor": self._unary_op(relax.op.floor),
             "gelu": self._gelu,
             "hardsigmoid": self._hardsigmoid,
             "hardswish": self._hardswish,
+            "hardtanh": self._hardtanh,
             "isfinite": self._unary_op(relax.op.isfinite),
             "isinf": self._unary_op(relax.op.isinf),
             "isnan": self._unary_op(relax.op.isnan),
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 9e7e1ff2ea..797ce05a3f 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1989,6 +1989,46 @@ def test_extended_unary_ops():
     verify_model(Dropout1(), input_info, {}, expected_dropout)
     verify_model(Dropout2(), input_info, {}, expected_dropout)
 
+    # elu
+    class Elu(Module):
+        def __init__(self):
+            super().__init__()
+            self.elu = torch.nn.ELU()
+
+        def forward(self, input):
+            return self.elu(input)
+
+    class Elu2(Module):
+        def forward(self, input):
+            return torch.nn.functional.elu(input)
+
+    @tvm.script.ir_module
+    class expected_elu:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
+                lv_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
+                    R.const(1.0, dtype="float32"), lv_exp
+                )
+                lv_relu_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(
+                    lv_one_minus_exp
+                )
+                lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
+                    R.const(1.0, dtype="float32"), lv_relu_one_minus_exp
+                )
+                lv_relu_x: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1)
+                lv_elu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_scaled, lv_relu_x)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv_elu
+                R.output(gv)
+            return gv
+
+    verify_model(Elu(), input_info, {}, expected_elu)
+    verify_model(Elu2(), input_info, {}, expected_elu)
+
     # gelu
     class Gelu(Module):
         def __init__(self):
@@ -2086,6 +2126,36 @@ def test_extended_unary_ops():
     verify_model(Hardswish(), input_info, {}, expected_hardswish)
     verify_model(Hardswish2(), input_info, {}, expected_hardswish)
 
+    # hardtanh
+    class Hardtanh(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.ht = torch.nn.Hardtanh()
+
+        def forward(self, input):
+            return self.ht(input)
+
+    class Hardtanh2(torch.nn.Module):
+        def forward(self, input):
+            return torch.nn.functional.hardtanh(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
+                    inp_0, R.prim_value(T.float64(-1.0)), R.prim_value(T.float64(1.0))
+                )
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Hardtanh(), input_info, {}, expected1)
+    verify_model(Hardtanh2(), input_info, {}, expected1)
+
     # logical_not
     class LogicalNot(Module):
         def forward(self, input):

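For reference, the verify_model calls above presumably boil down to tracing the module, importing it with from_fx, and comparing the result structurally against the expected module. A minimal standalone sketch of that flow, using the Hardtanh test case defined in the diff (parameter binding is omitted since the module has no weights):

    import tvm
    from torch import fx
    from tvm.relax.frontend.torch import from_fx

    graph_model = fx.symbolic_trace(Hardtanh())
    mod = from_fx(graph_model, [([1, 3, 10, 10], "float32")])
    tvm.ir.assert_structural_equal(mod, expected1)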