This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 4a513825bc [relax][frontend] Add relax frontend torch ops:
tan, asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh (#15610)
4a513825bc is described below
commit 4a513825bcf29e685a1a0acf493667db8e2451ad
Author: HLearning <[email protected]>
AuthorDate: Fri Aug 25 14:35:32 2023 +0800
[relax][frontend] Add relax frontend torch ops:
tan, asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh (#15610)
Add relax frontend torch ops: tan, asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh.
Co-authored-by: HLearning <[email protected]>
---
python/tvm/relax/frontend/torch/fx_translator.py | 21 ++-
tests/python/relax/test_frontend_from_fx.py | 208 ++++++++++++++++++++++-
2 files changed, 216 insertions(+), 13 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index be95a4880b..b5cee77d11 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -138,15 +138,9 @@ class TorchFXImporter:
########## Arithmetic ##########
- def _cos(self, node: fx.node.Node) -> relax.Var:
- return self.block_builder.emit(relax.op.cos(self.env[node.args[0]]))
-
def _exp(self, node: fx.node.Node) -> relax.Var:
return self.block_builder.emit(relax.op.exp(self.env[node.args[0]]))
- def _sin(self, node: fx.node.Node) -> relax.Var:
- return self.block_builder.emit(relax.op.sin(self.env[node.args[0]]))
-
def _sigmoid(self, node: fx.node.Node) -> relax.Var:
return
self.block_builder.emit(relax.op.sigmoid(self.env[node.args[0]]))
@@ -1291,9 +1285,19 @@ class TorchFXImporter:
nn.modules.sparse.Embedding: self._embedding,
nn.CrossEntropyLoss: self._cross_entropy,
# call_function and call_method
- "cos": self._cos,
+ "sin": lambda node:
self.block_builder.emit(relax.op.sin(self.env[node.args[0]])),
+ "cos": lambda node:
self.block_builder.emit(relax.op.cos(self.env[node.args[0]])),
+ "tan": lambda node:
self.block_builder.emit(relax.op.tan(self.env[node.args[0]])),
+ "asin": lambda node:
self.block_builder.emit(relax.op.asin(self.env[node.args[0]])),
+ "acos": lambda node:
self.block_builder.emit(relax.op.acos(self.env[node.args[0]])),
+ "atan": lambda node:
self.block_builder.emit(relax.op.atan(self.env[node.args[0]])),
+ "sinh": lambda node:
self.block_builder.emit(relax.op.sinh(self.env[node.args[0]])),
+ "cosh": lambda node:
self.block_builder.emit(relax.op.cosh(self.env[node.args[0]])),
+ "tanh": lambda node:
self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])),
+ "asinh": lambda node:
self.block_builder.emit(relax.op.asinh(self.env[node.args[0]])),
+ "acosh": lambda node:
self.block_builder.emit(relax.op.acosh(self.env[node.args[0]])),
+ "atanh": lambda node:
self.block_builder.emit(relax.op.atanh(self.env[node.args[0]])),
"exp": self._exp,
- "sin": self._sin,
"iadd": self._add,
"add": self._add,
"floordiv": self._floordiv,
@@ -1350,7 +1354,6 @@ class TorchFXImporter:
"leaky_relu": self._leakyrelu,
"gelu": self._gelu,
"silu": lambda node:
self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])),
- "tanh": lambda node:
self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])),
"interpolate": self._interpolate,
"size": self._size,
"getattr": self._getattr,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 2b95d3897d..ec312767b4 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1807,7 +1807,7 @@ def test_unary():
return torch.sin(input)
@tvm.script.ir_module
- class expected1:
+ class expected_sin:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
@@ -1819,7 +1819,7 @@ def test_unary():
R.output(gv)
return gv
- verify_model(Sin(), input_info, {}, expected1)
+ verify_model(Sin(), input_info, {}, expected_sin)
# cos
class Cos(Module):
@@ -1827,7 +1827,7 @@ def test_unary():
return torch.cos(input)
@tvm.script.ir_module
- class expected2:
+ class expected_cos:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
@@ -1839,7 +1839,207 @@ def test_unary():
R.output(gv)
return gv
- verify_model(Cos(), input_info, {}, expected2)
+ verify_model(Cos(), input_info, {}, expected_cos)
+
+ # tan
+ class Tan(Module):
+ def forward(self, input):
+ return torch.tan(input)
+
+ @tvm.script.ir_module
+ class expected_tan:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.tan(input_1)
+ gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Tan(), input_info, {}, expected_tan)
+
+ # asin
+ class Asin(Module):
+ def forward(self, input):
+ return torch.asin(input)
+
+ @tvm.script.ir_module
+ class expected_asin:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.asin(input_1)
+ gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Asin(), input_info, {}, expected_asin)
+
+ # acos
+ class Acos(Module):
+ def forward(self, input):
+ return torch.acos(input)
+
+ @tvm.script.ir_module
+ class expected_acos:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.acos(input_1)
+ gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Acos(), input_info, {}, expected_acos)
+
+ # atan
+ class Atan(Module):
+ def forward(self, input):
+ return torch.atan(input)
+
+ @tvm.script.ir_module
+ class expected_atan:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.atan(input_1)
+ gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Atan(), input_info, {}, expected_atan)
+
+ # sinh
+ class Sinh(Module):
+ def forward(self, input):
+ return torch.sinh(input)
+
+ @tvm.script.ir_module
+ class expected_sinh:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sinh(input_1)
+ gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Sinh(), input_info, {}, expected_sinh)
+
+ # cosh
+ class Cosh(Module):
+ def forward(self, input):
+ return torch.cosh(input)
+
+ @tvm.script.ir_module
+ class expected_cosh:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.cosh(input_1)
+ gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Cosh(), input_info, {}, expected_cosh)
+
+ # tanh
+ class Tanh(Module):
+ def forward(self, input):
+ return torch.tanh(input)
+
+ @tvm.script.ir_module
+ class expected_tanh:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.tanh(input_1)
+ gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Tanh(), input_info, {}, expected_tanh)
+
+ # asinh
+ class Asinh(Module):
+ def forward(self, input):
+ return torch.asinh(input)
+
+ @tvm.script.ir_module
+ class expected_asinh:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.asinh(input_1)
+ gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Asinh(), input_info, {}, expected_asinh)
+
+ # acosh
+ class Acosh(Module):
+ def forward(self, input):
+ return torch.acosh(input)
+
+ @tvm.script.ir_module
+ class expected_acosh:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.acosh(input_1)
+ gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Acosh(), input_info, {}, expected_acosh)
+
+ # atanh
+ class Atanh(Module):
+ def forward(self, input):
+ return torch.atanh(input)
+
+ @tvm.script.ir_module
+ class expected_atanh:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.atanh(input_1)
+ gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Atanh(), input_info, {}, expected_atanh)
# exp
class Exp(Module):