This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 176d01e612 [Relax][PyTorch] Support more unary ops for ExportedProgram
importer (#17421)
176d01e612 is described below
commit 176d01e61276b0e94910fd904363ef4cd91fb8b5
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Sat Sep 28 05:12:17 2024 +0900
[Relax][PyTorch] Support more unary ops for ExportedProgram importer
(#17421)
* support more unary ops
* support clamp
* support gelu
* support hardsigmoid
* support hardswish
* support hardtanh
* support leaky_relu
* support log_softmax
* support round
* support softmax
* support tril and triu
* skip flaky test
---
.../frontend/torch/base_fx_graph_translator.py | 74 +++
.../frontend/torch/exported_program_translator.py | 38 ++
python/tvm/relax/frontend/torch/fx_translator.py | 74 ---
.../relax/test_frontend_from_exported_program.py | 705 ++++++++++++++++++++-
tests/python/relay/test_to_mixed_precision.py | 1 +
5 files changed, 812 insertions(+), 80 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 6a001b5a04..d52b3d598f 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -111,6 +111,80 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
return convert
+ def _clamp(self, node: fx.Node) -> relax.Expr:
+ args = self.retrieve_args(node)
+ a_min = args[1] if len(args) > 1 else node.kwargs["min"]
+ a_max = args[2] if len(args) > 2 else node.kwargs["max"]
+ if not isinstance(a_min, (int, float)):
+ raise ValueError(
+ f"TVM only supports constant min value for torch.clamp/clip, "
+ f"but got {a_min} with type {type(a_min)}"
+ )
+ if not isinstance(a_max, (int, float)):
+ raise ValueError(
+ f"TVM only supports constant max value for torch.clamp/clip, "
+ f"but got {a_max} with type {type(a_max)}"
+ )
+ return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max))
+
+ def _gelu(self, node: fx.Node) -> relax.Expr:
+ approximate = node.kwargs.get("approximate", "none")
+ if approximate == "none":
+ return
self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]]))
+ elif approximate == "tanh":
+ return
self.block_builder.emit(relax.op.nn.gelu_tanh(self.env[node.args[0]]))
+ else:
+ raise KeyError("Unregonized approximate algorithm for gelu:
{}.".format(approximate))
+
+ def _hardsigmoid(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ x = args[0]
+ dtype = x.struct_info.dtype
+ x0 = relax.op.add(x, relax.const(3, dtype))
+ x1 = relax.op.clip(x0, 0, 6)
+ return self.block_builder.emit(relax.op.divide(x1, relax.const(6,
dtype)))
+
+ def _hardswish(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ x = args[0]
+ dtype = x.struct_info.dtype
+ x0 = relax.op.add(x, relax.const(3, dtype))
+ x1 = relax.op.clip(x0, 0, 6)
+ x2 = relax.op.divide(x1, relax.const(6, dtype))
+ return self.block_builder.emit(relax.op.multiply(x, x2))
+
+ def _leakyrelu(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ alpha = node.args[1] if len(node.args) > 1 else
node.kwargs.get("negative_slope", 0.01)
+ return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha))
+
+ def _log_softmax(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim",
-1)
+ return self.block_builder.emit(relax.op.nn.log_softmax(x, dim))
+
+ def _round(self, node: fx.Node) -> relax.Expr:
+ if node.kwargs.get("decimals", 0) != 0:
+ raise ValueError("specifying decimals for round is not supported
yet")
+ arg = self.env[node.args[0]]
+ return self.block_builder.emit(relax.op.round(arg))
+
+ def _softmax(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim",
-1)
+ return self.block_builder.emit(relax.op.nn.softmax(x, dim))
+
+ def _tril_triu(self, op: Callable) -> Callable:
+ from torch import fx
+
+ def convert(node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ k = node.args[1] if len(node.args) > 1 else
node.kwargs.get("diagonal", 0)
+ assert isinstance(k, int)
+ return self.block_builder.emit(op(x, k))
+
+ return convert
+
########## Neural Network ##########
def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 9af422d1c3..1ceddad7d7 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -64,13 +64,51 @@ class ExportedProgramImporter(BaseFXGraphImporter):
return parameters_buffers_constants, user_inputs
+ ########## Unary Ops ##########
+
+ def _hardtanh(self, node: fx.Node) -> relax.Expr:
+ args = self.retrieve_args(node)
+ x = args[0]
+ min_val = node.args[1] if len(args) > 1 else node.kwargs("min_val",
-1.0)
+ max_val = node.args[2] if len(args) > 2 else node.kwargs("max_val",
1.0)
+ return self.block_builder.emit(relax.op.clip(x, min_val, max_val))
+
def create_convert_map(
self,
) -> Dict[str, Callable[[fx.Node], relax.Var]]:
return {
# unary
+ "acos.default": self._unary_op(relax.op.acos),
+ "acosh.default": self._unary_op(relax.op.acosh),
+ "asin.default": self._unary_op(relax.op.asin),
+ "asinh.default": self._unary_op(relax.op.asinh),
+ "atan.default": self._unary_op(relax.op.atan),
+ "atanh.default": self._unary_op(relax.op.atanh),
+ "clamp.default": self._clamp,
+ "cos.default": self._unary_op(relax.op.cos),
+ "cosh.default": self._unary_op(relax.op.cosh),
"dropout.default": lambda node: self.env[node.args[0]],
+ "exp.default": self._unary_op(relax.op.exp),
+ "gelu.default": self._gelu,
+ "hardsigmoid.default": self._hardsigmoid,
+ "hardswish.default": self._hardswish,
+ "hardtanh.default": self._hardtanh,
+ "leaky_relu.default": self._leakyrelu,
+ "log_softmax.int": self._log_softmax,
+ "neg.default": self._unary_op(relax.op.negative),
"relu.default": self._unary_op(relax.op.nn.relu),
+ "round.default": self._round,
+ "rsqrt.default": self._unary_op(relax.op.rsqrt),
+ "sigmoid.default": self._unary_op(relax.op.sigmoid),
+ "silu.default": self._unary_op(relax.op.nn.silu),
+ "sin.default": self._unary_op(relax.op.sin),
+ "sinh.default": self._unary_op(relax.op.sinh),
+ "softmax.int": self._softmax,
+ "sqrt.default": self._unary_op(relax.op.sqrt),
+ "tan.default": self._unary_op(relax.op.tan),
+ "tanh.default": self._unary_op(relax.op.tanh),
+ "tril.default": self._tril_triu(relax.op.tril),
+ "triu.default": self._tril_triu(relax.op.triu),
# neural network
"adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
"conv2d.default": self._conv2d,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index ec53cf23ed..6f7c6fa2c5 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -62,64 +62,12 @@ class TorchFXImporter(BaseFXGraphImporter):
########## Unary Ops ##########
- def _clamp(self, node: fx.Node) -> relax.Expr:
- args = self.retrieve_args(node)
- a_min = args[1] if len(args) > 1 else node.kwargs["min"]
- a_max = args[2] if len(args) > 2 else node.kwargs["max"]
- if not isinstance(a_min, (int, float)):
- raise ValueError(
- f"TVM only supports constant min value for torch.clamp/clip, "
- f"but got {a_min} with type {type(a_min)}"
- )
- if not isinstance(a_max, (int, float)):
- raise ValueError(
- f"TVM only supports constant max value for torch.clamp/clip, "
- f"but got {a_max} with type {type(a_max)}"
- )
- return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max))
-
- def _gelu(self, node: fx.Node) -> relax.Expr:
- approximate = node.kwargs.get("approximate", "none")
- if approximate == "none":
- return
self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]]))
- elif approximate == "tanh":
- return
self.block_builder.emit(relax.op.nn.gelu_tanh(self.env[node.args[0]]))
- else:
- raise KeyError("Unregonized approximate algorithm for gelu:
{}.".format(approximate))
-
- def _hardsigmoid(self, node: fx.Node) -> relax.Var:
- args = self.retrieve_args(node)
- x = args[0]
- dtype = x.struct_info.dtype
- x0 = relax.op.add(x, relax.const(3, dtype))
- x1 = relax.op.clip(x0, 0, 6)
- return self.block_builder.emit(relax.op.divide(x1, relax.const(6,
dtype)))
-
- def _hardswish(self, node: fx.Node) -> relax.Var:
- args = self.retrieve_args(node)
- x = args[0]
- dtype = x.struct_info.dtype
- x0 = relax.op.add(x, relax.const(3, dtype))
- x1 = relax.op.clip(x0, 0, 6)
- x2 = relax.op.divide(x1, relax.const(6, dtype))
- return self.block_builder.emit(relax.op.multiply(x, x2))
-
- def _leakyrelu(self, node: fx.Node) -> relax.Var:
- x = self.env[node.args[0]]
- alpha = node.args[1] if len(node.args) > 1 else
node.kwargs.get("negative_slope", 0.01)
- return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha))
-
def _leakyrelu_module(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
alpha = module.negative_slope
return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha))
- def _log_softmax(self, node: fx.Node) -> relax.Var:
- x = self.env[node.args[0]]
- dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim",
-1)
- return self.block_builder.emit(relax.op.nn.log_softmax(x, dim))
-
def _log_softmax_module(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
@@ -127,17 +75,6 @@ class TorchFXImporter(BaseFXGraphImporter):
assert dim is not None
return self.block_builder.emit(relax.op.nn.log_softmax(x, dim))
- def _round(self, node: fx.Node) -> relax.Expr:
- if node.kwargs.get("decimals", 0) != 0:
- raise ValueError("specifying decimals for round is not supported
yet")
- arg = self.env[node.args[0]]
- return self.block_builder.emit(relax.op.round(arg))
-
- def _softmax(self, node: fx.Node) -> relax.Var:
- x = self.env[node.args[0]]
- dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim",
-1)
- return self.block_builder.emit(relax.op.nn.softmax(x, dim))
-
def _softmax_module(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
@@ -159,17 +96,6 @@ class TorchFXImporter(BaseFXGraphImporter):
return convert
- def _tril_triu(self, op: Callable) -> Callable:
- from torch import fx
-
- def convert(node: fx.Node) -> relax.Var:
- x = self.env[node.args[0]]
- k = node.args[1] if len(node.args) > 1 else
node.kwargs.get("diagonal", 0)
- assert isinstance(k, int)
- return self.block_builder.emit(op(x, k))
-
- return convert
-
########## Binary Ops ##########
def _binary_op(self, relax_op: Callable, intrinsic_op: Callable) ->
Callable:
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 112390fe60..6c17d96004 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -39,6 +39,166 @@ def verify_model(torch_model, example_args, binding,
expected):
def test_unary():
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+ # acos
+ class Acos(Module):
+ def forward(self, input):
+ return torch.acos(input)
+
+ @tvm.script.ir_module
+ class expected_acos:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.acos(input_1)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ verify_model(Acos(), example_args, {}, expected_acos)
+
+ # acosh
+ class Acosh(Module):
+ def forward(self, input):
+ return torch.acosh(input)
+
+ @tvm.script.ir_module
+ class expected_acosh:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.acosh(input_1)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ verify_model(Acosh(), example_args, {}, expected_acosh)
+
+ # asin
+ class Asin(Module):
+ def forward(self, input):
+ return torch.asin(input)
+
+ @tvm.script.ir_module
+ class expected_asin:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.asin(input_1)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ verify_model(Asin(), example_args, {}, expected_asin)
+
+ # asinh
+ class Asinh(Module):
+ def forward(self, input):
+ return torch.asinh(input)
+
+ @tvm.script.ir_module
+ class expected_asinh:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.asinh(input_1)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ verify_model(Asinh(), example_args, {}, expected_asinh)
+
+ # atan
+ class Atan(Module):
+ def forward(self, input):
+ return torch.atan(input)
+
+ @tvm.script.ir_module
+ class expected_atan:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.atan(input_1)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ verify_model(Atan(), example_args, {}, expected_atan)
+
+ # atanh
+ class Atanh(Module):
+ def forward(self, input):
+ return torch.atanh(input)
+
+ @tvm.script.ir_module
+ class expected_atanh:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.atanh(input_1)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ verify_model(Atanh(), example_args, {}, expected_atanh)
+
+ # cos
+ class Cos(Module):
+ def forward(self, input):
+ return torch.cos(input)
+
+ @tvm.script.ir_module
+ class expected_cos:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.cos(input_1)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ verify_model(Cos(), example_args, {}, expected_cos)
+
+ # cosh
+ class Cosh(Module):
+ def forward(self, input):
+ return torch.cosh(input)
+
+ @tvm.script.ir_module
+ class expected_cosh:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.cosh(input_1)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ verify_model(Cosh(), example_args, {}, expected_cosh)
+
# dropout
class Dropout1(Module):
def __init__(self):
@@ -53,7 +213,7 @@ def test_unary():
return torch.dropout(input, 0.5, train=True)
@tvm.script.ir_module
- class expected1:
+ class expected_dropout:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
@@ -64,8 +224,47 @@ def test_unary():
R.output(gv)
return gv
- verify_model(Dropout1(), example_args, {}, expected1)
- verify_model(Dropout2(), example_args, {}, expected1)
+ verify_model(Dropout1(), example_args, {}, expected_dropout)
+ verify_model(Dropout2(), example_args, {}, expected_dropout)
+
+ # exp
+ class Exp(Module):
+ def forward(self, input):
+ return torch.exp(input)
+
+ @tvm.script.ir_module
+ class expected_exp:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ verify_model(Exp(), example_args, {}, expected_exp)
+
+ # neg
+ class Neg(Module):
+ def forward(self, input):
+ return -input
+
+ @I.ir_module
+ class expected_neg:
+ @R.function
+ def main(
+ inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.negative(inp_0)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ verify_model(Neg(), example_args, {}, expected_neg)
# relu
class ReLU0(Module):
@@ -81,7 +280,7 @@ def test_unary():
return torch.nn.functional.relu(input)
@tvm.script.ir_module
- class expected:
+ class expected_relu:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
@@ -93,8 +292,502 @@ def test_unary():
R.output(gv)
return gv
- verify_model(ReLU0(), example_args, {}, expected)
- verify_model(ReLU1(), example_args, {}, expected)
+ verify_model(ReLU0(), example_args, {}, expected_relu)
+ verify_model(ReLU1(), example_args, {}, expected_relu)
+
+ # rsqrt
+ class Rsqrt(Module):
+ def forward(self, input):
+ return torch.rsqrt(input)
+
+ @I.ir_module
+ class expected_rsqrt:
+ @R.function
+ def main(
+ inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.rsqrt(inp_0)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+ verify_model(Rsqrt(), example_args, {}, expected_rsqrt)
+
+ # sigmoid
+ class Sigmoid(Module):
+ def __init__(self):
+ super().__init__()
+ self.sigmoid = torch.nn.Sigmoid()
+
+ def forward(self, input):
+ return self.sigmoid(input)
+
+ class Sigmoid2(Module):
+ def forward(self, input):
+ return torch.sigmoid(input)
+
+ @tvm.script.ir_module
+ class expected_sigmoid:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.sigmoid(input_1)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ verify_model(Sigmoid(), example_args, {}, expected_sigmoid)
+ verify_model(Sigmoid2(), example_args, {}, expected_sigmoid)
+
+ # silu
+ class SiLU(Module):
+ def __init__(self):
+ super().__init__()
+ self.silu = torch.nn.SiLU()
+
+ def forward(self, input):
+ return self.silu(input)
+
+ class SiLU2(Module):
+ def forward(self, input):
+ return torch.nn.functional.silu(input)
+
+ @tvm.script.ir_module
+ class expected_silu:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.nn.silu(input_1)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ verify_model(SiLU(), example_args, {}, expected_silu)
+ verify_model(SiLU2(), example_args, {}, expected_silu)
+
+ # sin
+ class Sin(Module):
+ def forward(self, input: torch.Tensor):
+ return torch.sin(input)
+
+ @tvm.script.ir_module
+ class expected_sin:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sin(input_1)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ verify_model(Sin(), example_args, {}, expected_sin)
+
+ # sinh
+ class Sinh(Module):
+ def forward(self, input):
+ return torch.sinh(input)
+
+ @tvm.script.ir_module
+ class expected_sinh:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sinh(input_1)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ verify_model(Sinh(), example_args, {}, expected_sinh)
+
+ # sqrt
+ class Sqrt(Module):
+ def forward(self, input):
+ return torch.sqrt(input)
+
+ @tvm.script.ir_module
+ class expected_sqrt:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sqrt(input_1)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ verify_model(Sqrt(), example_args, {}, expected_sqrt)
+
+ # tan
+ class Tan(Module):
+ def forward(self, input):
+ return torch.tan(input)
+
+ @tvm.script.ir_module
+ class expected_tan:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.tan(input_1)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ verify_model(Tan(), example_args, {}, expected_tan)
+
+ # tanh
+ class Tanh(Module):
+ def forward(self, input):
+ return torch.tanh(input)
+
+ @tvm.script.ir_module
+ class expected_tanh:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.tanh(input_1)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ verify_model(Tanh(), example_args, {}, expected_tanh)
+
+
+def test_clamp():
+ class Clamp(Module):
+ def forward(self, input):
+ return torch.clamp(input, min=0.1, max=0.5)
+
+ @tvm.script.ir_module
+ class expected_clamp:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.clip(input_1, 0.1, 0.5)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+ verify_model(Clamp(), example_args, {}, expected_clamp)
+
+
+def test_gelu():
+ class Gelu(Module):
+ def __init__(self):
+ super().__init__()
+ self.gelu = torch.nn.GELU()
+
+ def forward(self, input):
+ return self.gelu(input)
+
+ class Gelu2(Module):
+ def forward(self, input):
+ return torch.nn.functional.gelu(input)
+
+ @tvm.script.ir_module
+ class expected_gelu:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.nn.gelu(input_1)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+ verify_model(Gelu(), example_args, {}, expected_gelu)
+ verify_model(Gelu2(), example_args, {}, expected_gelu)
+
+
+def test_hardsigmoid():
+ class Hardsigmoid(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.hs = torch.nn.Hardsigmoid()
+
+ def forward(self, input):
+ return self.hs(input)
+
+ class Hardsigmoid2(torch.nn.Module):
+ def forward(self, input):
+ return torch.nn.functional.hardsigmoid(input)
+
+ @tvm.script.ir_module
+ class expected_hardsigmoid:
+ @R.function
+ def main(
+ inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0,
R.const(3, "float32"))
+ lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0,
6)
+ lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
+ lv1, R.const(6, "float32")
+ )
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv2,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+ verify_model(Hardsigmoid(), example_args, {}, expected_hardsigmoid)
+ verify_model(Hardsigmoid2(), example_args, {}, expected_hardsigmoid)
+
+
+def test_hardswish():
+ class Hardswish(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.hs = torch.nn.Hardswish()
+
+ def forward(self, input):
+ return self.hs(input)
+
+ class Hardswish2(torch.nn.Module):
+ def forward(self, input):
+ return torch.nn.functional.hardswish(input)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0,
R.const(3, "float32"))
+ lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0,
6)
+ lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
+ lv1, R.const(6, "float32")
+ )
+ lv3: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.multiply(inp_0, lv2)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv3,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+ verify_model(Hardswish(), example_args, {}, expected1)
+ verify_model(Hardswish2(), example_args, {}, expected1)
+
+
+def test_hardtanh():
+ class Hardtanh(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.ht = torch.nn.Hardtanh()
+
+ def forward(self, input):
+ return self.ht(input)
+
+ class Hardtanh2(torch.nn.Module):
+ def forward(self, input):
+ return torch.nn.functional.hardtanh(input)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
+ inp_0, R.prim_value(T.float64(-1.0)),
R.prim_value(T.float64(1.0))
+ )
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+ verify_model(Hardtanh(), example_args, {}, expected1)
+ verify_model(Hardtanh2(), example_args, {}, expected1)
+
+
+def test_leakyrelu():
+ import torch
+ from torch.nn import Module
+
+ torch.set_grad_enabled(False)
+
+ class LeakyReLU0(Module):
+ def __init__(self):
+ super().__init__()
+ self.leakyrelu = torch.nn.LeakyReLU(0.02)
+
+ def forward(self, input):
+ return self.leakyrelu(input)
+
+ class LeakyReLU1(Module):
+ def forward(self, input):
+ return torch.nn.functional.leaky_relu(input, 0.02)
+
+ @tvm.script.ir_module
+ class expected:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.nn.leakyrelu(input_1, 0.02)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+ verify_model(LeakyReLU0(), example_args, {}, expected)
+ verify_model(LeakyReLU1(), example_args, {}, expected)
+
+
+def test_logsoftmax():
+ class LogSoftmax(Module):
+ def __init__(self):
+ super().__init__()
+ self.lsm = torch.nn.LogSoftmax(dim=1)
+
+ def forward(self, input):
+ return self.lsm(input)
+
+ class LogSoftmax2(Module):
+ def forward(self, input):
+ return torch.nn.functional.log_softmax(input, dim=1)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.nn.log_softmax(input_1, axis=1)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+ verify_model(LogSoftmax(), example_args, {}, expected1)
+ verify_model(LogSoftmax2(), example_args, {}, expected1)
+
+
+def test_round():
+ class Round(Module):
+ def forward(self, input):
+ return torch.round(input)
+
+ @tvm.script.ir_module
+ class expected:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.round(input_1)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+ verify_model(Round(), example_args, {}, expected)
+
+
+def test_softmax():
+ class Softmax(Module):
+ def __init__(self):
+ super().__init__()
+ self.sm = torch.nn.Softmax(dim=1)
+
+ def forward(self, input):
+ return self.sm(input)
+
+ class Softmax2(Module):
+ def forward(self, input):
+ return torch.nn.functional.softmax(input, dim=1)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.nn.softmax(input_1, axis=1)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+ verify_model(Softmax(), example_args, {}, expected1)
+ verify_model(Softmax2(), example_args, {}, expected1)
+
+
+def test_tril_triu():
+ example_args = (torch.randn(10, 10, dtype=torch.float32),)
+
+ class Tril(Module):
+ def forward(self, input):
+ return torch.tril(input, 1)
+
+ @tvm.script.ir_module
+ class expected_tril:
+ @R.function
+ def main(
+ input_1: R.Tensor((10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((10, 10), dtype="float32") = R.tril(input_1, 1)
+ gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ verify_model(Tril(), example_args, {}, expected_tril)
+
+ class Triu(Module):
+ def forward(self, input):
+ return torch.triu(input, 1)
+
+ @tvm.script.ir_module
+ class expected_triu:
+ @R.function
+ def main(
+ input_1: R.Tensor((10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((10, 10), dtype="float32") = R.triu(input_1, 1)
+ gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ verify_model(Triu(), example_args, {}, expected_triu)
def test_adaptive_avgpool2d():
diff --git a/tests/python/relay/test_to_mixed_precision.py
b/tests/python/relay/test_to_mixed_precision.py
index ae5172f6ca..a8032ce0d2 100644
--- a/tests/python/relay/test_to_mixed_precision.py
+++ b/tests/python/relay/test_to_mixed_precision.py
@@ -98,6 +98,7 @@ def test_lstm(target_precision):
)
[email protected](reason="Flaky test")
def test_lstm_float64():
"""Tests if can handle other mixed precision types.