This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 176d01e612 [Relax][PyTorch] Support more unary ops for ExportedProgram importer (#17421)
176d01e612 is described below

commit 176d01e61276b0e94910fd904363ef4cd91fb8b5
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Sat Sep 28 05:12:17 2024 +0900

    [Relax][PyTorch] Support more unary ops for ExportedProgram importer (#17421)
    
    * support more unary ops
    
    * support clamp
    
    * support gelu
    
    * support hardsigmoid
    
    * support hardswish
    
    * support hardtanh
    
    * support leaky_relu
    
    * support log_softmax
    
    * support round
    
    * support softmax
    
    * support tril and triu
    
    * skip flaky test
---
 .../frontend/torch/base_fx_graph_translator.py     |  74 +++
 .../frontend/torch/exported_program_translator.py  |  38 ++
 python/tvm/relax/frontend/torch/fx_translator.py   |  74 ---
 .../relax/test_frontend_from_exported_program.py   | 705 ++++++++++++++++++++-
 tests/python/relay/test_to_mixed_precision.py      |   1 +
 5 files changed, 812 insertions(+), 80 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 6a001b5a04..d52b3d598f 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -111,6 +111,80 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
 
         return convert
 
+    def _clamp(self, node: fx.Node) -> relax.Expr:
+        args = self.retrieve_args(node)
+        a_min = args[1] if len(args) > 1 else node.kwargs["min"]
+        a_max = args[2] if len(args) > 2 else node.kwargs["max"]
+        if not isinstance(a_min, (int, float)):
+            raise ValueError(
+                f"TVM only supports constant min value for torch.clamp/clip, "
+                f"but got {a_min} with type {type(a_min)}"
+            )
+        if not isinstance(a_max, (int, float)):
+            raise ValueError(
+                f"TVM only supports constant max value for torch.clamp/clip, "
+                f"but got {a_max} with type {type(a_max)}"
+            )
+        return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max))
+
+    def _gelu(self, node: fx.Node) -> relax.Expr:
+        approximate = node.kwargs.get("approximate", "none")
+        if approximate == "none":
+            return self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]]))
+        elif approximate == "tanh":
+            return self.block_builder.emit(relax.op.nn.gelu_tanh(self.env[node.args[0]]))
+        else:
+            raise KeyError("Unregonized approximate algorithm for gelu: {}.".format(approximate))
+
+    def _hardsigmoid(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dtype = x.struct_info.dtype
+        x0 = relax.op.add(x, relax.const(3, dtype))
+        x1 = relax.op.clip(x0, 0, 6)
+        return self.block_builder.emit(relax.op.divide(x1, relax.const(6, dtype)))
+
+    def _hardswish(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dtype = x.struct_info.dtype
+        x0 = relax.op.add(x, relax.const(3, dtype))
+        x1 = relax.op.clip(x0, 0, 6)
+        x2 = relax.op.divide(x1, relax.const(6, dtype))
+        return self.block_builder.emit(relax.op.multiply(x, x2))
+
+    def _leakyrelu(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("negative_slope", 0.01)
+        return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha))
+
+    def _log_softmax(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
+        return self.block_builder.emit(relax.op.nn.log_softmax(x, dim))
+
+    def _round(self, node: fx.Node) -> relax.Expr:
+        if node.kwargs.get("decimals", 0) != 0:
+            raise ValueError("specifying decimals for round is not supported yet")
+        arg = self.env[node.args[0]]
+        return self.block_builder.emit(relax.op.round(arg))
+
+    def _softmax(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
+        return self.block_builder.emit(relax.op.nn.softmax(x, dim))
+
+    def _tril_triu(self, op: Callable) -> Callable:
+        from torch import fx
+
+        def convert(node: fx.Node) -> relax.Var:
+            x = self.env[node.args[0]]
+            k = node.args[1] if len(node.args) > 1 else node.kwargs.get("diagonal", 0)
+            assert isinstance(k, int)
+            return self.block_builder.emit(op(x, k))
+
+        return convert
+
     ########## Neural Network ##########
 
     def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 9af422d1c3..1ceddad7d7 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -64,13 +64,51 @@ class ExportedProgramImporter(BaseFXGraphImporter):
 
         return parameters_buffers_constants, user_inputs
 
+    ########## Unary Ops ##########
+
+    def _hardtanh(self, node: fx.Node) -> relax.Expr:
+        # Lower hardtanh to clip; bounds come positionally or from kwargs (a dict), else -1.0 / 1.0.
+        args = self.retrieve_args(node)
+        min_val = node.args[1] if len(args) > 1 else node.kwargs.get("min_val", -1.0)
+        max_val = node.args[2] if len(args) > 2 else node.kwargs.get("max_val", 1.0)
+        return self.block_builder.emit(relax.op.clip(args[0], min_val, max_val))
+
     def create_convert_map(
         self,
     ) -> Dict[str, Callable[[fx.Node], relax.Var]]:
         return {
             # unary
+            "acos.default": self._unary_op(relax.op.acos),
+            "acosh.default": self._unary_op(relax.op.acosh),
+            "asin.default": self._unary_op(relax.op.asin),
+            "asinh.default": self._unary_op(relax.op.asinh),
+            "atan.default": self._unary_op(relax.op.atan),
+            "atanh.default": self._unary_op(relax.op.atanh),
+            "clamp.default": self._clamp,
+            "cos.default": self._unary_op(relax.op.cos),
+            "cosh.default": self._unary_op(relax.op.cosh),
             "dropout.default": lambda node: self.env[node.args[0]],
+            "exp.default": self._unary_op(relax.op.exp),
+            "gelu.default": self._gelu,
+            "hardsigmoid.default": self._hardsigmoid,
+            "hardswish.default": self._hardswish,
+            "hardtanh.default": self._hardtanh,
+            "leaky_relu.default": self._leakyrelu,
+            "log_softmax.int": self._log_softmax,
+            "neg.default": self._unary_op(relax.op.negative),
             "relu.default": self._unary_op(relax.op.nn.relu),
+            "round.default": self._round,
+            "rsqrt.default": self._unary_op(relax.op.rsqrt),
+            "sigmoid.default": self._unary_op(relax.op.sigmoid),
+            "silu.default": self._unary_op(relax.op.nn.silu),
+            "sin.default": self._unary_op(relax.op.sin),
+            "sinh.default": self._unary_op(relax.op.sinh),
+            "softmax.int": self._softmax,
+            "sqrt.default": self._unary_op(relax.op.sqrt),
+            "tan.default": self._unary_op(relax.op.tan),
+            "tanh.default": self._unary_op(relax.op.tanh),
+            "tril.default": self._tril_triu(relax.op.tril),
+            "triu.default": self._tril_triu(relax.op.triu),
             # neural network
             "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
             "conv2d.default": self._conv2d,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index ec53cf23ed..6f7c6fa2c5 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -62,64 +62,12 @@ class TorchFXImporter(BaseFXGraphImporter):
 
     ########## Unary Ops ##########
 
-    def _clamp(self, node: fx.Node) -> relax.Expr:
-        args = self.retrieve_args(node)
-        a_min = args[1] if len(args) > 1 else node.kwargs["min"]
-        a_max = args[2] if len(args) > 2 else node.kwargs["max"]
-        if not isinstance(a_min, (int, float)):
-            raise ValueError(
-                f"TVM only supports constant min value for torch.clamp/clip, "
-                f"but got {a_min} with type {type(a_min)}"
-            )
-        if not isinstance(a_max, (int, float)):
-            raise ValueError(
-                f"TVM only supports constant max value for torch.clamp/clip, "
-                f"but got {a_max} with type {type(a_max)}"
-            )
-        return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max))
-
-    def _gelu(self, node: fx.Node) -> relax.Expr:
-        approximate = node.kwargs.get("approximate", "none")
-        if approximate == "none":
-            return self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]]))
-        elif approximate == "tanh":
-            return self.block_builder.emit(relax.op.nn.gelu_tanh(self.env[node.args[0]]))
-        else:
-            raise KeyError("Unregonized approximate algorithm for gelu: {}.".format(approximate))
-
-    def _hardsigmoid(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        x = args[0]
-        dtype = x.struct_info.dtype
-        x0 = relax.op.add(x, relax.const(3, dtype))
-        x1 = relax.op.clip(x0, 0, 6)
-        return self.block_builder.emit(relax.op.divide(x1, relax.const(6, dtype)))
-
-    def _hardswish(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        x = args[0]
-        dtype = x.struct_info.dtype
-        x0 = relax.op.add(x, relax.const(3, dtype))
-        x1 = relax.op.clip(x0, 0, 6)
-        x2 = relax.op.divide(x1, relax.const(6, dtype))
-        return self.block_builder.emit(relax.op.multiply(x, x2))
-
-    def _leakyrelu(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("negative_slope", 0.01)
-        return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha))
-
     def _leakyrelu_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
         alpha = module.negative_slope
         return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha))
 
-    def _log_softmax(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
-        return self.block_builder.emit(relax.op.nn.log_softmax(x, dim))
-
     def _log_softmax_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
@@ -127,17 +75,6 @@ class TorchFXImporter(BaseFXGraphImporter):
         assert dim is not None
         return self.block_builder.emit(relax.op.nn.log_softmax(x, dim))
 
-    def _round(self, node: fx.Node) -> relax.Expr:
-        if node.kwargs.get("decimals", 0) != 0:
-            raise ValueError("specifying decimals for round is not supported yet")
-        arg = self.env[node.args[0]]
-        return self.block_builder.emit(relax.op.round(arg))
-
-    def _softmax(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
-        return self.block_builder.emit(relax.op.nn.softmax(x, dim))
-
     def _softmax_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
@@ -159,17 +96,6 @@ class TorchFXImporter(BaseFXGraphImporter):
 
         return convert
 
-    def _tril_triu(self, op: Callable) -> Callable:
-        from torch import fx
-
-        def convert(node: fx.Node) -> relax.Var:
-            x = self.env[node.args[0]]
-            k = node.args[1] if len(node.args) > 1 else node.kwargs.get("diagonal", 0)
-            assert isinstance(k, int)
-            return self.block_builder.emit(op(x, k))
-
-        return convert
-
     ########## Binary Ops ##########
 
     def _binary_op(self, relax_op: Callable, intrinsic_op: Callable) -> Callable:
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 112390fe60..6c17d96004 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -39,6 +39,166 @@ def verify_model(torch_model, example_args, binding, expected):
 def test_unary():
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
 
+    # acos
+    class Acos(Module):
+        def forward(self, input):
+            return torch.acos(input)
+
+    @tvm.script.ir_module
+    class expected_acos:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.acos(input_1)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Acos(), example_args, {}, expected_acos)
+
+    # acosh
+    class Acosh(Module):
+        def forward(self, input):
+            return torch.acosh(input)
+
+    @tvm.script.ir_module
+    class expected_acosh:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.acosh(input_1)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Acosh(), example_args, {}, expected_acosh)
+
+    # asin
+    class Asin(Module):
+        def forward(self, input):
+            return torch.asin(input)
+
+    @tvm.script.ir_module
+    class expected_asin:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.asin(input_1)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Asin(), example_args, {}, expected_asin)
+
+    # asinh
+    class Asinh(Module):
+        def forward(self, input):
+            return torch.asinh(input)
+
+    @tvm.script.ir_module
+    class expected_asinh:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.asinh(input_1)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Asinh(), example_args, {}, expected_asinh)
+
+    # atan
+    class Atan(Module):
+        def forward(self, input):
+            return torch.atan(input)
+
+    @tvm.script.ir_module
+    class expected_atan:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.atan(input_1)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Atan(), example_args, {}, expected_atan)
+
+    # atanh
+    class Atanh(Module):
+        def forward(self, input):
+            return torch.atanh(input)
+
+    @tvm.script.ir_module
+    class expected_atanh:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.atanh(input_1)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Atanh(), example_args, {}, expected_atanh)
+
+    # cos
+    class Cos(Module):
+        def forward(self, input):
+            return torch.cos(input)
+
+    @tvm.script.ir_module
+    class expected_cos:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.cos(input_1)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Cos(), example_args, {}, expected_cos)
+
+    # cosh
+    class Cosh(Module):
+        def forward(self, input):
+            return torch.cosh(input)
+
+    @tvm.script.ir_module
+    class expected_cosh:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.cosh(input_1)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Cosh(), example_args, {}, expected_cosh)
+
     # dropout
     class Dropout1(Module):
         def __init__(self):
@@ -53,7 +213,7 @@ def test_unary():
             return torch.dropout(input, 0.5, train=True)
 
     @tvm.script.ir_module
-    class expected1:
+    class expected_dropout:
         @R.function
         def main(
             input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
@@ -64,8 +224,47 @@ def test_unary():
                 R.output(gv)
             return gv
 
-    verify_model(Dropout1(), example_args, {}, expected1)
-    verify_model(Dropout2(), example_args, {}, expected1)
+    verify_model(Dropout1(), example_args, {}, expected_dropout)
+    verify_model(Dropout2(), example_args, {}, expected_dropout)
+
+    # exp
+    class Exp(Module):
+        def forward(self, input):
+            return torch.exp(input)
+
+    @tvm.script.ir_module
+    class expected_exp:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Exp(), example_args, {}, expected_exp)
+
+    # neg
+    class Neg(Module):
+        def forward(self, input):
+            return -input
+
+    @I.ir_module
+    class expected_neg:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.negative(inp_0)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Neg(), example_args, {}, expected_neg)
 
     # relu
     class ReLU0(Module):
@@ -81,7 +280,7 @@ def test_unary():
             return torch.nn.functional.relu(input)
 
     @tvm.script.ir_module
-    class expected:
+    class expected_relu:
         @R.function
         def main(
             input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
@@ -93,8 +292,502 @@ def test_unary():
                 R.output(gv)
             return gv
 
-    verify_model(ReLU0(), example_args, {}, expected)
-    verify_model(ReLU1(), example_args, {}, expected)
+    verify_model(ReLU0(), example_args, {}, expected_relu)
+    verify_model(ReLU1(), example_args, {}, expected_relu)
+
+    # rsqrt
+    class Rsqrt(Module):
+        def forward(self, input):
+            return torch.rsqrt(input)
+
+    @I.ir_module
+    class expected_rsqrt:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.rsqrt(inp_0)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+    verify_model(Rsqrt(), example_args, {}, expected_rsqrt)
+
+    # sigmoid
+    class Sigmoid(Module):
+        def __init__(self):
+            super().__init__()
+            self.sigmoid = torch.nn.Sigmoid()
+
+        def forward(self, input):
+            return self.sigmoid(input)
+
+    class Sigmoid2(Module):
+        def forward(self, input):
+            return torch.sigmoid(input)
+
+    @tvm.script.ir_module
+    class expected_sigmoid:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sigmoid(input_1)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Sigmoid(), example_args, {}, expected_sigmoid)
+    verify_model(Sigmoid2(), example_args, {}, expected_sigmoid)
+
+    # silu
+    class SiLU(Module):
+        def __init__(self):
+            super().__init__()
+            self.silu = torch.nn.SiLU()
+
+        def forward(self, input):
+            return self.silu(input)
+
+    class SiLU2(Module):
+        def forward(self, input):
+            return torch.nn.functional.silu(input)
+
+    @tvm.script.ir_module
+    class expected_silu:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.silu(input_1)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(SiLU(), example_args, {}, expected_silu)
+    verify_model(SiLU2(), example_args, {}, expected_silu)
+
+    # sin
+    class Sin(Module):
+        def forward(self, input: torch.Tensor):
+            return torch.sin(input)
+
+    @tvm.script.ir_module
+    class expected_sin:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sin(input_1)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Sin(), example_args, {}, expected_sin)
+
+    # sinh
+    class Sinh(Module):
+        def forward(self, input):
+            return torch.sinh(input)
+
+    @tvm.script.ir_module
+    class expected_sinh:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sinh(input_1)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Sinh(), example_args, {}, expected_sinh)
+
+    # sqrt
+    class Sqrt(Module):
+        def forward(self, input):
+            return torch.sqrt(input)
+
+    @tvm.script.ir_module
+    class expected_sqrt:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sqrt(input_1)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Sqrt(), example_args, {}, expected_sqrt)
+
+    # tan
+    class Tan(Module):
+        def forward(self, input):
+            return torch.tan(input)
+
+    @tvm.script.ir_module
+    class expected_tan:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.tan(input_1)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Tan(), example_args, {}, expected_tan)
+
+    # tanh
+    class Tanh(Module):
+        def forward(self, input):
+            return torch.tanh(input)
+
+    @tvm.script.ir_module
+    class expected_tanh:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.tanh(input_1)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Tanh(), example_args, {}, expected_tanh)
+
+
+def test_clamp():
+    class Clamp(Module):
+        def forward(self, input):
+            return torch.clamp(input, min=0.1, max=0.5)
+
+    @tvm.script.ir_module
+    class expected_clamp:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(input_1, 0.1, 0.5)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+    verify_model(Clamp(), example_args, {}, expected_clamp)
+
+
+def test_gelu():
+    class Gelu(Module):
+        def __init__(self):
+            super().__init__()
+            self.gelu = torch.nn.GELU()
+
+        def forward(self, input):
+            return self.gelu(input)
+
+    class Gelu2(Module):
+        def forward(self, input):
+            return torch.nn.functional.gelu(input)
+
+    @tvm.script.ir_module
+    class expected_gelu:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.gelu(input_1)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+    verify_model(Gelu(), example_args, {}, expected_gelu)
+    verify_model(Gelu2(), example_args, {}, expected_gelu)
+
+
+def test_hardsigmoid():
+    class Hardsigmoid(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.hs = torch.nn.Hardsigmoid()
+
+        def forward(self, input):
+            return self.hs(input)
+
+    class Hardsigmoid2(torch.nn.Module):
+        def forward(self, input):
+            return torch.nn.functional.hardsigmoid(input)
+
+    @tvm.script.ir_module
+    class expected_hardsigmoid:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32"))
+                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0, 6)
+                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
+                    lv1, R.const(6, "float32")
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv2,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+    verify_model(Hardsigmoid(), example_args, {}, expected_hardsigmoid)
+    verify_model(Hardsigmoid2(), example_args, {}, expected_hardsigmoid)
+
+
+def test_hardswish():
+    class Hardswish(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.hs = torch.nn.Hardswish()
+
+        def forward(self, input):
+            return self.hs(input)
+
+    class Hardswish2(torch.nn.Module):
+        def forward(self, input):
+            return torch.nn.functional.hardswish(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32"))
+                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0, 6)
+                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
+                    lv1, R.const(6, "float32")
+                )
+                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(inp_0, lv2)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv3,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+    verify_model(Hardswish(), example_args, {}, expected1)
+    verify_model(Hardswish2(), example_args, {}, expected1)
+
+
+def test_hardtanh():
+    class Hardtanh(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.ht = torch.nn.Hardtanh()
+
+        def forward(self, input):
+            return self.ht(input)
+
+    class Hardtanh2(torch.nn.Module):
+        def forward(self, input):
+            return torch.nn.functional.hardtanh(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
+                    inp_0, R.prim_value(T.float64(-1.0)), R.prim_value(T.float64(1.0))
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+    verify_model(Hardtanh(), example_args, {}, expected1)
+    verify_model(Hardtanh2(), example_args, {}, expected1)
+
+
+def test_leakyrelu():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+
+    class LeakyReLU0(Module):
+        def __init__(self):
+            super().__init__()
+            self.leakyrelu = torch.nn.LeakyReLU(0.02)
+
+        def forward(self, input):
+            return self.leakyrelu(input)
+
+    class LeakyReLU1(Module):
+        def forward(self, input):
+            return torch.nn.functional.leaky_relu(input, 0.02)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.leakyrelu(input_1, 0.02)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+    verify_model(LeakyReLU0(), example_args, {}, expected)
+    verify_model(LeakyReLU1(), example_args, {}, expected)
+
+
+def test_logsoftmax():
+    class LogSoftmax(Module):
+        def __init__(self):
+            super().__init__()
+            self.lsm = torch.nn.LogSoftmax(dim=1)
+
+        def forward(self, input):
+            return self.lsm(input)
+
+    class LogSoftmax2(Module):
+        def forward(self, input):
+            return torch.nn.functional.log_softmax(input, dim=1)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.log_softmax(input_1, axis=1)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+    verify_model(LogSoftmax(), example_args, {}, expected1)
+    verify_model(LogSoftmax2(), example_args, {}, expected1)
+
+
+def test_round():
+    class Round(Module):
+        def forward(self, input):
+            return torch.round(input)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.round(input_1)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+    verify_model(Round(), example_args, {}, expected)
+
+
+def test_softmax():
+    class Softmax(Module):
+        def __init__(self):
+            super().__init__()
+            self.sm = torch.nn.Softmax(dim=1)
+
+        def forward(self, input):
+            return self.sm(input)
+
+    class Softmax2(Module):
+        def forward(self, input):
+            return torch.nn.functional.softmax(input, dim=1)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.softmax(input_1, axis=1)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+    verify_model(Softmax(), example_args, {}, expected1)
+    verify_model(Softmax2(), example_args, {}, expected1)
+
+
+def test_tril_triu():
+    example_args = (torch.randn(10, 10, dtype=torch.float32),)
+
+    class Tril(Module):
+        def forward(self, input):
+            return torch.tril(input, 1)
+
+    @tvm.script.ir_module
+    class expected_tril:
+        @R.function
+        def main(
+            input_1: R.Tensor((10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.tril(input_1, 1)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Tril(), example_args, {}, expected_tril)
+
+    class Triu(Module):
+        def forward(self, input):
+            return torch.triu(input, 1)
+
+    @tvm.script.ir_module
+    class expected_triu:
+        @R.function
+        def main(
+            input_1: R.Tensor((10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.triu(input_1, 1)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(Triu(), example_args, {}, expected_triu)
 
 
 def test_adaptive_avgpool2d():
diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py
index ae5172f6ca..a8032ce0d2 100644
--- a/tests/python/relay/test_to_mixed_precision.py
+++ b/tests/python/relay/test_to_mixed_precision.py
@@ -98,6 +98,7 @@ def test_lstm(target_precision):
     )
 
 
+@pytest.mark.skip(reason="Flaky test")
 def test_lstm_float64():
     """Tests if can handle other mixed precision types.
 


Reply via email to