This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 340eb73948 [Relax][PyTorch] Support several unary ops for
ExportedProgram importer (#17679)
340eb73948 is described below
commit 340eb73948ea7a6afbe2ac05b246f49c36b1fd9f
Author: Shushi Hong <[email protected]>
AuthorDate: Wed Feb 26 02:58:49 2025 +0800
[Relax][PyTorch] Support several unary ops for ExportedProgram importer
(#17679)
---
.../frontend/torch/exported_program_translator.py | 11 +
.../relax/test_frontend_from_exported_program.py | 525 ++++++---------------
2 files changed, 148 insertions(+), 388 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 7bcd20c462..1c676d0267 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -156,33 +156,44 @@ class ExportedProgramImporter(BaseFXGraphImporter):
return {
# unary
+ "abs.default": self._unary_op(relax.op.abs),
"acos.default": self._unary_op(relax.op.acos),
"acosh.default": self._unary_op(relax.op.acosh),
"asin.default": self._unary_op(relax.op.asin),
"asinh.default": self._unary_op(relax.op.asinh),
"atan.default": self._unary_op(relax.op.atan),
"atanh.default": self._unary_op(relax.op.atanh),
+ "bitwise_not.default": self._unary_op(relax.op.bitwise_not),
+ "ceil.default": self._unary_op(relax.op.ceil),
"clamp.default": self._clamp,
"cos.default": self._unary_op(relax.op.cos),
"cosh.default": self._unary_op(relax.op.cosh),
"dropout.default": lambda node: self.env[node.args[0]],
+ "erf.default": self._unary_op(relax.op.erf),
"exp.default": self._unary_op(relax.op.exp),
+ "floor.default": self._unary_op(relax.op.floor),
"gelu.default": self._gelu,
"hardsigmoid.default": self._hardsigmoid,
"hardswish.default": self._hardswish,
"hardtanh.default": self._hardtanh,
+ "isfinite.default": self._unary_op(relax.op.isfinite),
+ "isinf.default": self._unary_op(relax.op.isinf),
+ "isnan.default": self._unary_op(relax.op.isnan),
"leaky_relu.default": self._leakyrelu,
+ "log.default": self._unary_op(relax.op.log),
"log_softmax.int": self._log_softmax,
"neg.default": self._unary_op(relax.op.negative),
"relu.default": self._unary_op(relax.op.nn.relu),
"round.default": self._round,
"rsqrt.default": self._unary_op(relax.op.rsqrt),
"sigmoid.default": self._unary_op(relax.op.sigmoid),
+ "sign.default": self._unary_op(relax.op.sign),
"silu.default": self._unary_op(relax.op.nn.silu),
"sin.default": self._unary_op(relax.op.sin),
"sinh.default": self._unary_op(relax.op.sinh),
"softmax.int": self._softmax,
"sqrt.default": self._unary_op(relax.op.sqrt),
+ "square.default": self._unary_op(relax.op.square),
"tan.default": self._unary_op(relax.op.tan),
"tanh.default": self._unary_op(relax.op.tanh),
"tril.default": self._tril_triu(relax.op.tril),
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 0d8425fc7f..33379e74ac 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import pytest
import torch
from torch.nn import Module
from torch.export import export
@@ -36,235 +37,241 @@ def verify_model(torch_model, example_args, binding, expected):
tvm.ir.assert_structural_equal(mod, expected)
-def test_unary():
+operator_basic_unary = [
+ (torch.abs, R.abs),
+ (torch.acos, R.acos),
+ (torch.acosh, R.acosh),
+ (torch.asin, R.asin),
+ (torch.asinh, R.asinh),
+ (torch.atan, R.atan),
+ (torch.atanh, R.atanh),
+ (torch.bitwise_not, R.bitwise_not),
+ (torch.ceil, R.ceil),
+ (torch.cos, R.cos),
+ (torch.cosh, R.cosh),
+ (torch.erf, R.erf),
+ (torch.exp, R.exp),
+ (torch.floor, R.floor),
+ (torch.log, R.log),
+ (torch.neg, R.negative),
+ (torch.round, R.round),
+ (torch.rsqrt, R.rsqrt),
+ (torch.sin, R.sin),
+ (torch.sinh, R.sinh),
+ (torch.sign, R.sign),
+ (torch.sqrt, R.sqrt),
+ (torch.square, R.square),
+ (torch.tan, R.tan),
+ (torch.tanh, R.tanh),
+]
+
+
+@pytest.mark.parametrize("pytorch_op, relax_op", operator_basic_unary)
+def test_basic_unary_ops(pytorch_op, relax_op):
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
- # acos
- class Acos(Module):
+ class UnaryOp(Module):
def forward(self, input):
- return torch.acos(input)
+ return pytorch_op(input)
@tvm.script.ir_module
- class expected_acos:
+ class expected:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
- # block 0
with R.dataflow():
- lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.acos(input_1)
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
relax_op(input_1)
gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
R.output(gv)
return gv
- verify_model(Acos(), example_args, {}, expected_acos)
+ verify_model(UnaryOp(), example_args, {}, expected)
- # acosh
- class Acosh(Module):
- def forward(self, input):
- return torch.acosh(input)
- @tvm.script.ir_module
- class expected_acosh:
- @R.function
- def main(
- input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
- ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
- # block 0
- with R.dataflow():
- lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.acosh(input_1)
- gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
- R.output(gv)
- return gv
+operator_bool_unary = [
+ (torch.isfinite, R.isfinite),
+ (torch.isinf, R.isinf),
+ (torch.isnan, R.isnan),
+]
+
- verify_model(Acosh(), example_args, {}, expected_acosh)
+@pytest.mark.parametrize("pytorch_op, relax_op", operator_bool_unary)
+def test_bool_unary_ops(pytorch_op, relax_op):
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
- # asin
- class Asin(Module):
+ class UnaryOp(Module):
def forward(self, input):
- return torch.asin(input)
+ return pytorch_op(input)
@tvm.script.ir_module
- class expected_asin:
+ class expected:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
- ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
- # block 0
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="bool")):
with R.dataflow():
- lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.asin(input_1)
- gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ lv: R.Tensor((1, 3, 10, 10), dtype="bool") = relax_op(input_1)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="bool")) = (lv,)
R.output(gv)
return gv
- verify_model(Asin(), example_args, {}, expected_asin)
+ verify_model(UnaryOp(), example_args, {}, expected)
- # asinh
- class Asinh(Module):
- def forward(self, input):
- return torch.asinh(input)
- @tvm.script.ir_module
- class expected_asinh:
- @R.function
- def main(
- input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
- ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
- # block 0
- with R.dataflow():
- lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.asinh(input_1)
- gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
- R.output(gv)
- return gv
-
- verify_model(Asinh(), example_args, {}, expected_asinh)
+def test_extended_unary_ops():
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
- # atan
- class Atan(Module):
+ # clamp
+ class Clamp(Module):
def forward(self, input):
- return torch.atan(input)
+ return torch.clamp(input, min=0.1, max=0.5)
@tvm.script.ir_module
- class expected_atan:
+ class expected_clamp:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
# block 0
with R.dataflow():
- lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.atan(input_1)
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.clip(input_1, 0.1, 0.5)
gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
R.output(gv)
return gv
- verify_model(Atan(), example_args, {}, expected_atan)
-
- # atanh
- class Atanh(Module):
- def forward(self, input):
- return torch.atanh(input)
+ verify_model(Clamp(), example_args, {}, expected_clamp)
- @tvm.script.ir_module
- class expected_atanh:
- @R.function
- def main(
- input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
- ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
- # block 0
- with R.dataflow():
- lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.atanh(input_1)
- gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
- R.output(gv)
- return gv
+ # dropout
+ class Dropout1(Module):
+ def __init__(self):
+ super().__init__()
+ self.dropout = torch.nn.Dropout(0.5)
- verify_model(Atanh(), example_args, {}, expected_atanh)
+ def forward(self, input):
+ return self.dropout(input)
- # cos
- class Cos(Module):
+ class Dropout2(Module):
def forward(self, input):
- return torch.cos(input)
+ return torch.dropout(input, 0.5, train=True)
@tvm.script.ir_module
- class expected_cos:
+ class expected_dropout:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
# block 0
with R.dataflow():
- lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.cos(input_1)
- gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) =
(input_1,)
R.output(gv)
return gv
- verify_model(Cos(), example_args, {}, expected_cos)
+ verify_model(Dropout1(), example_args, {}, expected_dropout)
+ verify_model(Dropout2(), example_args, {}, expected_dropout)
+
+ # gelu
+ class Gelu(Module):
+ def __init__(self):
+ super().__init__()
+ self.gelu = torch.nn.GELU()
+
+ def forward(self, input):
+ return self.gelu(input)
- # cosh
- class Cosh(Module):
+ class Gelu2(Module):
def forward(self, input):
- return torch.cosh(input)
+ return torch.nn.functional.gelu(input)
@tvm.script.ir_module
- class expected_cosh:
+ class expected_gelu:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
# block 0
with R.dataflow():
- lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.cosh(input_1)
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.nn.gelu(input_1)
gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
R.output(gv)
return gv
- verify_model(Cosh(), example_args, {}, expected_cosh)
+ verify_model(Gelu(), example_args, {}, expected_gelu)
+ verify_model(Gelu2(), example_args, {}, expected_gelu)
- # dropout
- class Dropout1(Module):
+ # hardsigmoid
+ class Hardsigmoid(torch.nn.Module):
def __init__(self):
super().__init__()
- self.dropout = torch.nn.Dropout(0.5)
+ self.hs = torch.nn.Hardsigmoid()
def forward(self, input):
- return self.dropout(input)
+ return self.hs(input)
- class Dropout2(Module):
+ class Hardsigmoid2(torch.nn.Module):
def forward(self, input):
- return torch.dropout(input, 0.5, train=True)
+ return torch.nn.functional.hardsigmoid(input)
@tvm.script.ir_module
- class expected_dropout:
+ class expected_hardsigmoid:
@R.function
def main(
- input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
- # block 0
with R.dataflow():
- gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) =
(input_1,)
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0,
R.const(3, "float32"))
+ lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0,
6)
+ lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
+ lv1, R.const(6, "float32")
+ )
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv2,)
R.output(gv)
return gv
- verify_model(Dropout1(), example_args, {}, expected_dropout)
- verify_model(Dropout2(), example_args, {}, expected_dropout)
+ verify_model(Hardsigmoid(), example_args, {}, expected_hardsigmoid)
+ verify_model(Hardsigmoid2(), example_args, {}, expected_hardsigmoid)
+
+ # hardwish
+ class Hardswish(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.hs = torch.nn.Hardswish()
- # exp
- class Exp(Module):
def forward(self, input):
- return torch.exp(input)
+ return self.hs(input)
+
+ class Hardswish2(torch.nn.Module):
+ def forward(self, input):
+ return torch.nn.functional.hardswish(input)
@tvm.script.ir_module
- class expected_exp:
+ class expected1:
@R.function
def main(
- input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
- # block 0
with R.dataflow():
- lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
- gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0,
R.const(3, "float32"))
+ lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0,
6)
+ lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
+ lv1, R.const(6, "float32")
+ )
+ lv3: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.multiply(inp_0, lv2)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv3,)
R.output(gv)
return gv
- verify_model(Exp(), example_args, {}, expected_exp)
+ verify_model(Hardswish(), example_args, {}, expected1)
+ verify_model(Hardswish2(), example_args, {}, expected1)
- # neg
- class Neg(Module):
- def forward(self, input):
- return -input
+ # hardtanh
+ test_hardtanh()
- @I.ir_module
- class expected_neg:
- @R.function
- def main(
- inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
- ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
- with R.dataflow():
- lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.negative(inp_0)
- gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
- R.output(gv)
- return gv
+ # leakyrelu
+ test_leakyrelu()
- verify_model(Neg(), example_args, {}, expected_neg)
+ # log_softmax
+ test_logsoftmax()
# relu
class ReLU0(Module):
@@ -295,26 +302,6 @@ def test_unary():
verify_model(ReLU0(), example_args, {}, expected_relu)
verify_model(ReLU1(), example_args, {}, expected_relu)
- # rsqrt
- class Rsqrt(Module):
- def forward(self, input):
- return torch.rsqrt(input)
-
- @I.ir_module
- class expected_rsqrt:
- @R.function
- def main(
- inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
- ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
- with R.dataflow():
- lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.rsqrt(inp_0)
- gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
- R.output(gv)
- return gv
-
- example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
- verify_model(Rsqrt(), example_args, {}, expected_rsqrt)
-
# sigmoid
class Sigmoid(Module):
def __init__(self):
@@ -373,227 +360,11 @@ def test_unary():
verify_model(SiLU(), example_args, {}, expected_silu)
verify_model(SiLU2(), example_args, {}, expected_silu)
- # sin
- class Sin(Module):
- def forward(self, input: torch.Tensor):
- return torch.sin(input)
-
- @tvm.script.ir_module
- class expected_sin:
- @R.function
- def main(
- input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
- ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
- # block 0
- with R.dataflow():
- lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sin(input_1)
- gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
- R.output(gv)
- return gv
-
- verify_model(Sin(), example_args, {}, expected_sin)
-
- # sinh
- class Sinh(Module):
- def forward(self, input):
- return torch.sinh(input)
+ # softmax
+ test_softmax()
- @tvm.script.ir_module
- class expected_sinh:
- @R.function
- def main(
- input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
- ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
- # block 0
- with R.dataflow():
- lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sinh(input_1)
- gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
- R.output(gv)
- return gv
-
- verify_model(Sinh(), example_args, {}, expected_sinh)
-
- # sqrt
- class Sqrt(Module):
- def forward(self, input):
- return torch.sqrt(input)
-
- @tvm.script.ir_module
- class expected_sqrt:
- @R.function
- def main(
- input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
- ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
- # block 0
- with R.dataflow():
- lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sqrt(input_1)
- gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
- R.output(gv)
- return gv
-
- verify_model(Sqrt(), example_args, {}, expected_sqrt)
-
- # tan
- class Tan(Module):
- def forward(self, input):
- return torch.tan(input)
-
- @tvm.script.ir_module
- class expected_tan:
- @R.function
- def main(
- input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
- ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
- # block 0
- with R.dataflow():
- lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.tan(input_1)
- gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
- R.output(gv)
- return gv
-
- verify_model(Tan(), example_args, {}, expected_tan)
-
- # tanh
- class Tanh(Module):
- def forward(self, input):
- return torch.tanh(input)
-
- @tvm.script.ir_module
- class expected_tanh:
- @R.function
- def main(
- input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
- ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
- # block 0
- with R.dataflow():
- lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.tanh(input_1)
- gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
- R.output(gv)
- return gv
-
- verify_model(Tanh(), example_args, {}, expected_tanh)
-
-
-def test_clamp():
- class Clamp(Module):
- def forward(self, input):
- return torch.clamp(input, min=0.1, max=0.5)
-
- @tvm.script.ir_module
- class expected_clamp:
- @R.function
- def main(
- input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
- ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
- # block 0
- with R.dataflow():
- lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.clip(input_1, 0.1, 0.5)
- gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
- R.output(gv)
- return gv
-
- example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
- verify_model(Clamp(), example_args, {}, expected_clamp)
-
-
-def test_gelu():
- class Gelu(Module):
- def __init__(self):
- super().__init__()
- self.gelu = torch.nn.GELU()
-
- def forward(self, input):
- return self.gelu(input)
-
- class Gelu2(Module):
- def forward(self, input):
- return torch.nn.functional.gelu(input)
-
- @tvm.script.ir_module
- class expected_gelu:
- @R.function
- def main(
- input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
- ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
- # block 0
- with R.dataflow():
- lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.nn.gelu(input_1)
- gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
- R.output(gv)
- return gv
-
- example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
- verify_model(Gelu(), example_args, {}, expected_gelu)
- verify_model(Gelu2(), example_args, {}, expected_gelu)
-
-
-def test_hardsigmoid():
- class Hardsigmoid(torch.nn.Module):
- def __init__(self):
- super().__init__()
- self.hs = torch.nn.Hardsigmoid()
-
- def forward(self, input):
- return self.hs(input)
-
- class Hardsigmoid2(torch.nn.Module):
- def forward(self, input):
- return torch.nn.functional.hardsigmoid(input)
-
- @tvm.script.ir_module
- class expected_hardsigmoid:
- @R.function
- def main(
- inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
- ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
- with R.dataflow():
- lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0,
R.const(3, "float32"))
- lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0,
6)
- lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
- lv1, R.const(6, "float32")
- )
- gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv2,)
- R.output(gv)
- return gv
-
- example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
- verify_model(Hardsigmoid(), example_args, {}, expected_hardsigmoid)
- verify_model(Hardsigmoid2(), example_args, {}, expected_hardsigmoid)
-
-
-def test_hardswish():
- class Hardswish(torch.nn.Module):
- def __init__(self):
- super().__init__()
- self.hs = torch.nn.Hardswish()
-
- def forward(self, input):
- return self.hs(input)
-
- class Hardswish2(torch.nn.Module):
- def forward(self, input):
- return torch.nn.functional.hardswish(input)
-
- @tvm.script.ir_module
- class expected1:
- @R.function
- def main(
- inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
- ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
- with R.dataflow():
- lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0,
R.const(3, "float32"))
- lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0,
6)
- lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
- lv1, R.const(6, "float32")
- )
- lv3: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.multiply(inp_0, lv2)
- gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv3,)
- R.output(gv)
- return gv
-
- example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
- verify_model(Hardswish(), example_args, {}, expected1)
- verify_model(Hardswish2(), example_args, {}, expected1)
+ # tril, triu
+ test_tril_triu()
def test_hardtanh():
@@ -695,28 +466,6 @@ def test_logsoftmax():
verify_model(LogSoftmax2(), example_args, {}, expected1)
-def test_round():
- class Round(Module):
- def forward(self, input):
- return torch.round(input)
-
- @tvm.script.ir_module
- class expected:
- @R.function
- def main(
- input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
- ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
- # block 0
- with R.dataflow():
- lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.round(input_1)
- gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
- R.output(gv)
- return gv
-
- example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
- verify_model(Round(), example_args, {}, expected)
-
-
def test_softmax():
class Softmax(Module):
def __init__(self):