This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 4a513825bc [relax][frontend]add relax frontend torch op: tan,asin,acos,atan,sinh,cosh,tanh,asinh,… (#15610)
4a513825bc is described below

commit 4a513825bcf29e685a1a0acf493667db8e2451ad
Author: HLearning <[email protected]>
AuthorDate: Fri Aug 25 14:35:32 2023 +0800

    [relax][frontend]add relax frontend torch op: tan,asin,acos,atan,sinh,cosh,tanh,asinh,… (#15610)
    
    add relax frontend torch op: tan,asin,acos,atan,sinh,cosh,tanh,asinh,acosh,atanh
    
    Co-authored-by: HLearning <[email protected]>
---
 python/tvm/relax/frontend/torch/fx_translator.py |  21 ++-
 tests/python/relax/test_frontend_from_fx.py      | 208 ++++++++++++++++++++++-
 2 files changed, 216 insertions(+), 13 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index be95a4880b..b5cee77d11 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -138,15 +138,9 @@ class TorchFXImporter:
 
     ########## Arithmetic ##########
 
-    def _cos(self, node: fx.node.Node) -> relax.Var:
-        return self.block_builder.emit(relax.op.cos(self.env[node.args[0]]))
-
     def _exp(self, node: fx.node.Node) -> relax.Var:
         return self.block_builder.emit(relax.op.exp(self.env[node.args[0]]))
 
-    def _sin(self, node: fx.node.Node) -> relax.Var:
-        return self.block_builder.emit(relax.op.sin(self.env[node.args[0]]))
-
     def _sigmoid(self, node: fx.node.Node) -> relax.Var:
         return self.block_builder.emit(relax.op.sigmoid(self.env[node.args[0]]))
 
@@ -1291,9 +1285,19 @@ class TorchFXImporter:
             nn.modules.sparse.Embedding: self._embedding,
             nn.CrossEntropyLoss: self._cross_entropy,
             # call_function and call_method
-            "cos": self._cos,
+            "sin": lambda node: self.block_builder.emit(relax.op.sin(self.env[node.args[0]])),
+            "cos": lambda node: self.block_builder.emit(relax.op.cos(self.env[node.args[0]])),
+            "tan": lambda node: self.block_builder.emit(relax.op.tan(self.env[node.args[0]])),
+            "asin": lambda node: self.block_builder.emit(relax.op.asin(self.env[node.args[0]])),
+            "acos": lambda node: self.block_builder.emit(relax.op.acos(self.env[node.args[0]])),
+            "atan": lambda node: self.block_builder.emit(relax.op.atan(self.env[node.args[0]])),
+            "sinh": lambda node: self.block_builder.emit(relax.op.sinh(self.env[node.args[0]])),
+            "cosh": lambda node: self.block_builder.emit(relax.op.cosh(self.env[node.args[0]])),
+            "tanh": lambda node: self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])),
+            "asinh": lambda node: self.block_builder.emit(relax.op.asinh(self.env[node.args[0]])),
+            "acosh": lambda node: self.block_builder.emit(relax.op.acosh(self.env[node.args[0]])),
+            "atanh": lambda node: self.block_builder.emit(relax.op.atanh(self.env[node.args[0]])),
             "exp": self._exp,
-            "sin": self._sin,
             "iadd": self._add,
             "add": self._add,
             "floordiv": self._floordiv,
@@ -1350,7 +1354,6 @@ class TorchFXImporter:
             "leaky_relu": self._leakyrelu,
             "gelu": self._gelu,
             "silu": lambda node: self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])),
-            "tanh": lambda node: self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])),
             "interpolate": self._interpolate,
             "size": self._size,
             "getattr": self._getattr,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 2b95d3897d..ec312767b4 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1807,7 +1807,7 @@ def test_unary():
             return torch.sin(input)
 
     @tvm.script.ir_module
-    class expected1:
+    class expected_sin:
         @R.function
         def main(
             input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
@@ -1819,7 +1819,7 @@ def test_unary():
                 R.output(gv)
             return gv
 
-    verify_model(Sin(), input_info, {}, expected1)
+    verify_model(Sin(), input_info, {}, expected_sin)
 
     # cos
     class Cos(Module):
@@ -1827,7 +1827,7 @@ def test_unary():
             return torch.cos(input)
 
     @tvm.script.ir_module
-    class expected2:
+    class expected_cos:
         @R.function
         def main(
             input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
@@ -1839,7 +1839,207 @@ def test_unary():
                 R.output(gv)
             return gv
 
-    verify_model(Cos(), input_info, {}, expected2)
+    verify_model(Cos(), input_info, {}, expected_cos)
+
+    # tan
+    class Tan(Module):
+        def forward(self, input):
+            return torch.tan(input)
+
+    @tvm.script.ir_module
+    class expected_tan:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.tan(input_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Tan(), input_info, {}, expected_tan)
+
+    # asin
+    class Asin(Module):
+        def forward(self, input):
+            return torch.asin(input)
+
+    @tvm.script.ir_module
+    class expected_asin:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.asin(input_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Asin(), input_info, {}, expected_asin)
+
+    # acos
+    class Acos(Module):
+        def forward(self, input):
+            return torch.acos(input)
+
+    @tvm.script.ir_module
+    class expected_acos:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.acos(input_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Acos(), input_info, {}, expected_acos)
+
+    # atan
+    class Atan(Module):
+        def forward(self, input):
+            return torch.atan(input)
+
+    @tvm.script.ir_module
+    class expected_atan:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.atan(input_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Atan(), input_info, {}, expected_atan)
+
+    # sinh
+    class Sinh(Module):
+        def forward(self, input):
+            return torch.sinh(input)
+
+    @tvm.script.ir_module
+    class expected_sinh:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sinh(input_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Sinh(), input_info, {}, expected_sinh)
+
+    # cosh
+    class Cosh(Module):
+        def forward(self, input):
+            return torch.cosh(input)
+
+    @tvm.script.ir_module
+    class expected_cosh:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.cosh(input_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Cosh(), input_info, {}, expected_cosh)
+
+    # tanh
+    class Tanh(Module):
+        def forward(self, input):
+            return torch.tanh(input)
+
+    @tvm.script.ir_module
+    class expected_tanh:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.tanh(input_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Tanh(), input_info, {}, expected_tanh)
+
+    # asinh
+    class Asinh(Module):
+        def forward(self, input):
+            return torch.asinh(input)
+
+    @tvm.script.ir_module
+    class expected_asinh:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.asinh(input_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Asinh(), input_info, {}, expected_asinh)
+
+    # acosh
+    class Acosh(Module):
+        def forward(self, input):
+            return torch.acosh(input)
+
+    @tvm.script.ir_module
+    class expected_acosh:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.acosh(input_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Acosh(), input_info, {}, expected_acosh)
+
+    # atanh
+    class Atanh(Module):
+        def forward(self, input):
+            return torch.atanh(input)
+
+    @tvm.script.ir_module
+    class expected_atanh:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.atanh(input_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Atanh(), input_info, {}, expected_atanh)
 
     # exp
     class Exp(Module):

Reply via email to