This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 16ec246df9 [Relax][PyTorch] Add support for celu, selu, is_floating_point ops (#17702)
16ec246df9 is described below

commit 16ec246df91cf664cf88c12c068f140124603ead
Author: Shushi Hong <[email protected]>
AuthorDate: Tue Mar 4 12:35:52 2025 +0800

    [Relax][PyTorch] Add support for celu, selu, is_floating_point ops (#17702)
    
    * Update fx_translator.py
    
    * Update base_fx_graph_translator.py
    
    * Update test_frontend_from_fx.py
    
    * lint
    
    * lint
    
    * Update fx_translator.py
    
    * Update base_fx_graph_translator.py
    
    * Update test_frontend_from_fx.py
---
 .../frontend/torch/base_fx_graph_translator.py     |  65 ++++++++++++-
 python/tvm/relax/frontend/torch/fx_translator.py   |  11 +++
 tests/python/relax/test_frontend_from_fx.py        | 103 ++++++++++++++++++++-
 3 files changed, 175 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index e601f18181..4ce899685a 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -111,6 +111,34 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
 
         return convert
 
+    def _celu(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("alpha", 1.0)
+        dtype = x.struct_info.dtype
+
+        if isinstance(alpha, (int, float)):
+            alpha = relax.const(alpha, dtype)
+        else:
+            if not isinstance(alpha, relax.Var):
+                alpha = self.block_builder.emit(relax.const(alpha, dtype))
+
+        zero = relax.const(0, dtype)
+        # alpha * min(0, exp(x / alpha) - 1) + max(0, x)
+        return self.block_builder.emit(
+            relax.op.add(
+                relax.op.multiply(
+                    alpha,
+                    relax.op.minimum(
+                        zero,
+                        relax.op.subtract(
+                            relax.op.divide(relax.op.exp(x), alpha), relax.const(1, dtype)
+                        ),
+                    ),
+                ),
+                relax.op.nn.relu(x),
+            )
+        )
+
     def _clamp(self, node: fx.Node) -> relax.Expr:
         args = self.retrieve_args(node)
         a_min = args[1] if len(args) > 1 else node.kwargs["min"]
@@ -133,12 +161,12 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         dtype = x.struct_info.dtype
 
         if isinstance(alpha, (int, float)):
-            alpha = relax.const(alpha, dtype)
+            alpha = relax.const(-alpha, dtype)
         else:
             if not isinstance(alpha, relax.Var):
-                alpha = self.block_builder.emit(relax.const(alpha, dtype))
+                alpha = self.block_builder.emit(relax.const(-alpha, dtype))
 
-        # α⋅ReLU(1−exp(x))+ReLU(x)
+        # alpha * ReLU(1 − exp(x)) + ReLU(x)
         return self.block_builder.emit(
             relax.op.add(
                 relax.op.multiply(
@@ -203,6 +231,37 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
         return self.block_builder.emit(relax.op.nn.softmax(x, dim))
 
+    def _selu(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("alpha", 1.6732631921768188)
+        gamma = node.args[2] if len(node.args) > 2 else node.kwargs.get("gamma", 1.0507009873554805)
+        dtype = x.struct_info.dtype
+
+        if isinstance(alpha, (int, float)):
+            alpha = relax.const(alpha, dtype)
+        else:
+            if not isinstance(alpha, relax.Var):
+                alpha = self.block_builder.emit(relax.const(alpha, dtype))
+
+        if isinstance(gamma, (int, float)):
+            gamma = relax.const(gamma, dtype)
+        else:
+            if not isinstance(gamma, relax.Var):
+                gamma = self.block_builder.emit(relax.const(gamma, dtype))
+
+        # gamma * (ReLU(x) + alpha * (exp(x) - 1))
+        return self.block_builder.emit(
+            relax.op.multiply(
+                gamma,
+                relax.op.add(
+                    relax.op.nn.relu(x),
+                    relax.op.multiply(
+                        alpha, relax.op.subtract(relax.op.exp(x), relax.const(1, dtype))
+                    ),
+                ),
+            )
+        )
+
     def _tril_triu(self, op: Callable) -> Callable:
         from torch import fx
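
For reference, the celu lowering above can be sanity-checked against PyTorch
directly. A minimal sketch, not part of the commit, assuming the default alpha
of 1.0 (the value the bundled nn.CELU / functional.celu tests exercise; the
tensor shape just mirrors the tests' input_info):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 10, 10)
    alpha = 1.0

    # alpha * min(0, exp(x / alpha) - 1) + max(0, x), the decomposition the
    # converter builds out of existing Relax ops
    ref = alpha * torch.minimum(torch.zeros_like(x), torch.exp(x / alpha) - 1) + torch.relu(x)
    assert torch.allclose(F.celu(x, alpha=alpha), ref, atol=1e-6)

The selu converter follows the same pattern, folding PyTorch's alpha
(1.6732631921768188) and scale (1.0507009873554805) constants into the graph.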
 
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index bbad7c0c70..af84f71bbf 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -527,6 +527,12 @@ class TorchFXImporter(BaseFXGraphImporter):
     def _half(self, node: fx.Node) -> relax.Var:
         return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16"))
 
+    def _is_floating_point(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        return relax.const(
+            x.struct_info.dtype in ["float16", "float32", "float64", "bfloat16"], "bool"
+        )
+
     def _to(self, node: fx.Node) -> relax.Var:
         import torch
 
@@ -580,6 +586,7 @@ class TorchFXImporter(BaseFXGraphImporter):
         return {
             ## call_module
             # unary
+            nn.CELU: self._celu,
             nn.Dropout: lambda node: self.env[node.args[0]],
             nn.ELU: self._elu,
             nn.GELU: self._gelu,
@@ -594,6 +601,7 @@ class TorchFXImporter(BaseFXGraphImporter):
                 relax.op.clip(self.env[node.args[0]], 0, 6)
             ),
             nn.Sigmoid: self._unary_op(relax.op.sigmoid),
+            nn.SELU: self._selu,
             nn.SiLU: self._unary_op(relax.op.nn.silu),
             nn.Softmax: self._softmax_module,
             nn.Tanh: self._unary_op(relax.op.tanh),
@@ -625,6 +633,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "atanh": self._unary_op(relax.op.atanh),
             "bitwise_not": self._unary_op(relax.op.bitwise_not),
             "ceil": self._unary_op(relax.op.ceil),
+            "celu": self._celu,
             "clamp": self._clamp,
             "cos": self._unary_op(relax.op.cos),
             "cosh": self._unary_op(relax.op.cosh),
@@ -648,6 +657,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "relu": self._unary_op(relax.op.nn.relu),
             "round": self._round,
             "rsqrt": self._unary_op(relax.op.rsqrt),
+            "selu": self._selu,
             "sigmoid": self._unary_op(relax.op.sigmoid),
             "sign": self._unary_op(relax.op.sign),
             "silu": self._unary_op(relax.op.nn.silu),
@@ -753,6 +763,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "astype": self._type,
             "float": self._float,
             "half": self._half,
+            "is_floating_point": self._is_floating_point,
             "to": self._to,
             "type": self._type,
             # other
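
With these mappings registered, modules using the new ops go through the FX
importer like any other. A rough usage sketch (the module name is illustrative
and this assumes a TVM build with the Relax PyTorch frontend available):

    import torch
    from torch import fx, nn
    from tvm.relax.frontend.torch import from_fx

    class Activations(nn.Module):
        def forward(self, x):
            return nn.functional.selu(nn.functional.celu(x))

    # input_info is a list of (shape, dtype) pairs, as in the tests below
    traced = fx.symbolic_trace(Activations())
    mod = from_fx(traced, [((1, 3, 10, 10), "float32")])
    mod.show()  # print the imported Relax IRModule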
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 797ce05a3f..e9fa796531 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1917,6 +1917,50 @@ def test_bool_unary_ops(pytorch_op, relax_op):
 def test_extended_unary_ops():
     input_info = [([1, 3, 10, 10], "float32")]
 
+    # celu
+    class Celu1(Module):
+        def __init__(self):
+            super().__init__()
+            self.celu = torch.nn.CELU()
+
+        def forward(self, input):
+            return self.celu(input)
+
+    class Celu2(Module):
+        def forward(self, input):
+            return torch.nn.functional.celu(input)
+
+    # alpha * min(0, exp(x / alpha) - 1) + max(0, x)
+    @tvm.script.ir_module
+    class expected_celu:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
+                lv_div: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
+                    lv, R.const(1.0, "float32")
+                )
+                lv_sub: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
+                    lv_div, R.const(1.0, "float32")
+                )
+                lv_min: R.Tensor((1, 3, 10, 10), dtype="float32") = R.minimum(
+                    R.const(0.0, "float32"), lv_sub
+                )
+                lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
+                    R.const(1.0, "float32"), lv_min
+                )
+                lv_relu_x: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1)
+                lv_celu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_scaled, lv_relu_x)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv_celu
+                R.output(gv)
+            return gv
+
+    verify_model(Celu1(), input_info, {}, expected_celu)
+    verify_model(Celu2(), input_info, {}, expected_celu)
+
     # clamp
     class Clamp(Module):
         def forward(self, input):
@@ -2018,7 +2062,7 @@ def test_extended_unary_ops():
                     lv_one_minus_exp
                 )
                 lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
-                    R.const(1.0, dtype="float32"), lv_relu_one_minus_exp
+                    R.const(-1.0, dtype="float32"), lv_relu_one_minus_exp
                 )
                 lv_relu_x: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1)
                 lv_elu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_scaled, lv_relu_x)
@@ -2256,6 +2300,46 @@ def test_extended_unary_ops():
 
     verify_model(ReLU6(), input_info, {}, expected_relu6)
 
+    # selu
+    class Selu1(Module):
+        def __init__(self):
+            super().__init__()
+            self.selu = torch.nn.SELU()
+
+        def forward(self, input):
+            return self.selu(input)
+
+    class Selu2(Module):
+        def forward(self, input):
+            return torch.nn.functional.selu(input)
+
+    @tvm.script.ir_module
+    class expected_selu:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv_relu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1)
+                lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
+                lv_sub: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
+                    lv_exp, R.const(1.0, "float32")
+                )
+                lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
+                    R.const(1.6732631921768188, "float32"), lv_sub
+                )
+                lv_add: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_relu, lv_scaled)
+                lv_selu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
+                    R.const(1.0507009873554805, "float32"), lv_add
+                )
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv_selu
+                R.output(gv)
+            return gv
+
+    verify_model(Selu1(), input_info, {}, expected_selu)
+    verify_model(Selu2(), input_info, {}, expected_selu)
+
     # sigmoid
     class Sigmoid(Module):
         def __init__(self):
@@ -3802,5 +3886,22 @@ def test_masked_scatter():
     )
 
 
+def test_is_floating_point():
+    class IsFloatingPoint(Module):
+        def forward(self, x):
+            return torch.is_floating_point(x)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(inp_0: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((), dtype="bool"):
+            with R.dataflow():
+                gv: R.Tensor((), dtype="bool") = R.const(True, "bool")
+                R.output(gv)
+            return gv
+
+    verify_model(IsFloatingPoint(), [([2, 3], "float32")], {}, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
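
To run just the touched tests locally, something along these lines should work
(paths assume the repository root; the -k filter names are taken from the test
functions above):

    import pytest

    pytest.main([
        "tests/python/relax/test_frontend_from_fx.py",
        "-k", "extended_unary or is_floating_point",
    ])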
