This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a8c75802ad [Relax][PyTorch] Enable run_ep_decomposition by default
(#18471)
a8c75802ad is described below
commit a8c75802ad462bc093d2bab51c79b3bd3303355a
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Thu Nov 20 17:07:53 2025 +0800
[Relax][PyTorch] Enable run_ep_decomposition by default (#18471)
## Why
We have finished migrating our tests, so we can now make running PyTorch's
ExportedProgram decomposition (run_ep_decomposition) the default.
## How
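Update exported_program_translator.py to default run_ep_decomposition to True
and add a handler for native_layer_norm.default; update the tests to drop the
now-redundant explicit run_ep_decomposition=True arguments.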
---
.../frontend/torch/exported_program_translator.py | 15 +-
.../relax/test_frontend_from_exported_program.py | 524 ++++++++++-----------
2 files changed, 259 insertions(+), 280 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 6aa118ee5c..a2b9b2afa4 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -208,6 +208,15 @@ class ExportedProgramImporter(BaseFXGraphImporter):
)
)
+ def _native_layer_norm(self, node: fx.Node) -> relax.Var:
+ # native_layer_norm signature: (input, normalized_shape, weight, bias,
eps)
+ x = self.env[node.args[0]]
+ normalized_shape = node.args[1]
+ gamma = self.env.get(node.args[2], None) if len(node.args) > 2 else
None
+ beta = self.env.get(node.args[3], None) if len(node.args) > 3 else None
+ eps = node.args[4] if len(node.args) > 4 else 1e-05
+ return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape)
+
def _upsample_impl(
self,
x: relax.Expr,
@@ -1058,6 +1067,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"instance_norm.default": self._instance_norm,
"native_group_norm.default": self._native_group_norm,
"layer_norm.default": self._layer_norm,
+ "native_layer_norm.default": self._native_layer_norm,
"linear.default": self._linear,
"lstm.input": self._lstm,
"gru.input": self._gru,
@@ -1403,7 +1413,7 @@ def from_exported_program(
keep_params_as_input: bool = False,
unwrap_unit_return_tuple: bool = False,
no_bind_return_tuple: bool = False,
- run_ep_decomposition: bool = False,
+ run_ep_decomposition: bool = True,
) -> tvm.IRModule:
"""Convert a PyTorch ExportedProgram to a Relax program
@@ -1426,8 +1436,7 @@ def from_exported_program(
run_ep_decomposition : bool
A boolean flag indicating whether to run PyTorch's decomposition on the
exported program before translation. When True, high-level operators
will
- be decomposed into their constituent parts. Defaults to False for
backward
- compatibility.
+ be decomposed into their constituent parts. Defaults to True.
Returns
-------
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index acd1344ec9..1429dec5e7 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -32,7 +32,7 @@ from tvm.relax.frontend.torch import from_exported_program
def verify_model(
- torch_model, example_args, binding, expected, dynamic_shapes=None,
run_ep_decomposition=False
+ torch_model, example_args, binding, expected, dynamic_shapes=None,
run_ep_decomposition=True
):
exported_program = export(torch_model, args=example_args,
dynamic_shapes=dynamic_shapes)
mod = from_exported_program(exported_program,
run_ep_decomposition=run_ep_decomposition)
@@ -94,7 +94,7 @@ def test_basic_unary_ops(pytorch_op, relax_op):
R.output(gv)
return gv
- verify_model(UnaryOp(), example_args, {}, expected,
run_ep_decomposition=True)
+ verify_model(UnaryOp(), example_args, {}, expected)
operator_bool_unary = [
@@ -123,7 +123,7 @@ def test_bool_unary_ops(pytorch_op, relax_op):
R.output(gv)
return gv
- verify_model(UnaryOp(), example_args, {}, expected,
run_ep_decomposition=True)
+ verify_model(UnaryOp(), example_args, {}, expected)
def test_sqrt_integer_input():
@@ -147,7 +147,7 @@ def test_sqrt_integer_input():
R.output(gv)
return gv
- verify_model(SqrtIntModel(), example_args, {}, expected_int64,
run_ep_decomposition=True)
+ verify_model(SqrtIntModel(), example_args, {}, expected_int64)
example_args_int32 = (torch.tensor([[1, 4, 9]], dtype=torch.int32),)
@@ -164,7 +164,7 @@ def test_sqrt_integer_input():
R.output(gv)
return gv
- verify_model(SqrtIntModel(), example_args_int32, {}, expected_int32,
run_ep_decomposition=True)
+ verify_model(SqrtIntModel(), example_args_int32, {}, expected_int32)
def test_extended_unary_ops():
@@ -203,8 +203,8 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(Celu1(), example_args, {}, expected_celu,
run_ep_decomposition=True)
- verify_model(Celu2(), example_args, {}, expected_celu,
run_ep_decomposition=True)
+ verify_model(Celu1(), example_args, {}, expected_celu)
+ verify_model(Celu2(), example_args, {}, expected_celu)
# clamp
class Clamp(Module):
@@ -227,7 +227,7 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(Clamp(), example_args, {}, expected_clamp,
run_ep_decomposition=True)
+ verify_model(Clamp(), example_args, {}, expected_clamp)
class ClampMinOnly(Module):
def forward(self, input):
@@ -247,9 +247,7 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(
- ClampMinOnly(), example_args, {}, expected_clamp_min_only,
run_ep_decomposition=True
- )
+ verify_model(ClampMinOnly(), example_args, {}, expected_clamp_min_only)
class ClampTensors(Module):
def forward(self, input):
@@ -277,9 +275,7 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(
- ClampTensors(), example_args, {}, expected_clamp_tensors,
run_ep_decomposition=True
- )
+ verify_model(ClampTensors(), example_args, {}, expected_clamp_tensors)
# dropout
@@ -335,9 +331,9 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(Dropout1(), example_args, {}, expected_dropout_for_1_2,
run_ep_decomposition=True)
- verify_model(Dropout2(), example_args, {}, expected_dropout_for_1_2,
run_ep_decomposition=True)
- verify_model(Dropout3(), example_args, {}, expected_dropout_for_3,
run_ep_decomposition=True)
+ verify_model(Dropout1(), example_args, {}, expected_dropout_for_1_2)
+ verify_model(Dropout2(), example_args, {}, expected_dropout_for_1_2)
+ verify_model(Dropout3(), example_args, {}, expected_dropout_for_3)
# elu
class Elu(Module):
@@ -380,8 +376,8 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(Elu(), example_args, {}, expected_elu,
run_ep_decomposition=True)
- verify_model(Elu2(), example_args, {}, expected_elu,
run_ep_decomposition=True)
+ verify_model(Elu(), example_args, {}, expected_elu)
+ verify_model(Elu2(), example_args, {}, expected_elu)
# hardsigmoid
class Hardsigmoid(torch.nn.Module):
@@ -419,8 +415,8 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(Hardsigmoid(), example_args, {}, expected_hardsigmoid,
run_ep_decomposition=True)
- verify_model(Hardsigmoid2(), example_args, {}, expected_hardsigmoid,
run_ep_decomposition=True)
+ verify_model(Hardsigmoid(), example_args, {}, expected_hardsigmoid)
+ verify_model(Hardsigmoid2(), example_args, {}, expected_hardsigmoid)
# hardwish
class Hardswish(torch.nn.Module):
@@ -492,15 +488,9 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(
- Hardswish(), example_args, {}, expected_hardswish_for_1_2,
run_ep_decomposition=True
- )
- verify_model(
- Hardswish2(), example_args, {}, expected_hardswish_for_1_2,
run_ep_decomposition=True
- )
- verify_model(
- Hardswish3(), example_args, {}, expected_hardswish_for_3,
run_ep_decomposition=True
- )
+ verify_model(Hardswish(), example_args, {}, expected_hardswish_for_1_2)
+ verify_model(Hardswish2(), example_args, {}, expected_hardswish_for_1_2)
+ verify_model(Hardswish3(), example_args, {}, expected_hardswish_for_3)
# isfinite
class IsFinite(Module):
@@ -524,7 +514,7 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(IsFinite(), example_args, {}, expected_isfinite,
run_ep_decomposition=True)
+ verify_model(IsFinite(), example_args, {}, expected_isfinite)
# log2
class Log2(Module):
@@ -546,7 +536,7 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(Log2(), example_args, {}, Expected_log2,
run_ep_decomposition=True)
+ verify_model(Log2(), example_args, {}, Expected_log2)
# log10
class Log10(Module):
@@ -568,7 +558,7 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(Log10(), example_args, {}, Expected_log10,
run_ep_decomposition=True)
+ verify_model(Log10(), example_args, {}, Expected_log10)
# log1p
class Log1p(Module):
@@ -589,7 +579,7 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(Log1p(), example_args, {}, Expected_log1p,
run_ep_decomposition=True)
+ verify_model(Log1p(), example_args, {}, Expected_log1p)
# reciprocal
class Reciprocal(Module):
@@ -610,7 +600,7 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(Reciprocal(), example_args, {}, expected_reciprocal,
run_ep_decomposition=True)
+ verify_model(Reciprocal(), example_args, {}, expected_reciprocal)
# Returns the maximum value of all elements in the input tensor.
class MaxModel(Module):
@@ -629,7 +619,7 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(MaxModel(), example_args, {}, expected_max,
run_ep_decomposition=True)
+ verify_model(MaxModel(), example_args, {}, expected_max)
# Returns the minimum value of all elements in the input tensor.
class MinModel(Module):
@@ -648,7 +638,7 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(MinModel(), example_args, {}, expected_min,
run_ep_decomposition=True)
+ verify_model(MinModel(), example_args, {}, expected_min)
# relu6
class ReLU6_1(torch.nn.Module):
@@ -712,9 +702,9 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(ReLU6_1(), example_args, {}, expected_relu6_1,
run_ep_decomposition=True)
- verify_model(ReLU6_2(), example_args, {}, expected_relu6_2,
run_ep_decomposition=True)
- verify_model(ReLU6_3(), example_args, {}, expected_relu6_3,
run_ep_decomposition=True)
+ verify_model(ReLU6_1(), example_args, {}, expected_relu6_1)
+ verify_model(ReLU6_2(), example_args, {}, expected_relu6_2)
+ verify_model(ReLU6_3(), example_args, {}, expected_relu6_3)
# selu
class SELU(Module):
@@ -749,7 +739,7 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(SELU(), example_args, {}, expected_selu,
run_ep_decomposition=True)
+ verify_model(SELU(), example_args, {}, expected_selu)
# silu
class SiLU(Module):
@@ -769,7 +759,7 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(SiLU(), example_args, {}, expected_silu,
run_ep_decomposition=True)
+ verify_model(SiLU(), example_args, {}, expected_silu)
# silu_
class SiLU_(Module):
@@ -797,7 +787,7 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(SiLU_(), example_args, {}, expected_silu_,
run_ep_decomposition=True)
+ verify_model(SiLU_(), example_args, {}, expected_silu_)
# square
class Square(Module):
@@ -818,7 +808,7 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(Square(), example_args, {}, expected_square,
run_ep_decomposition=True)
+ verify_model(Square(), example_args, {}, expected_square)
# relu_
class ReLU_(Module):
@@ -837,7 +827,7 @@ def test_extended_unary_ops():
R.output(gv)
return gv
- verify_model(ReLU_(), example_args, {}, expected_relu_,
run_ep_decomposition=True)
+ verify_model(ReLU_(), example_args, {}, expected_relu_)
def test_hardtanh():
@@ -891,9 +881,9 @@ def test_hardtanh():
return gv
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
- verify_model(Hardtanh(), example_args, {}, expected_for_1_2,
run_ep_decomposition=True)
- verify_model(Hardtanh2(), example_args, {}, expected_for_1_2,
run_ep_decomposition=True)
- verify_model(Hardtanh3(), example_args, {}, expected_hardtanh_for_3,
run_ep_decomposition=True)
+ verify_model(Hardtanh(), example_args, {}, expected_for_1_2)
+ verify_model(Hardtanh2(), example_args, {}, expected_for_1_2)
+ verify_model(Hardtanh3(), example_args, {}, expected_hardtanh_for_3)
def test_softplus():
@@ -939,8 +929,8 @@ def test_softplus():
return gv
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
- verify_model(Softplus0(), example_args, {}, expected,
run_ep_decomposition=True)
- verify_model(Softplus1(), example_args, {}, expected,
run_ep_decomposition=True)
+ verify_model(Softplus0(), example_args, {}, expected)
+ verify_model(Softplus1(), example_args, {}, expected)
def test_leakyrelu():
@@ -997,9 +987,9 @@ def test_leakyrelu():
return gv
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
- verify_model(LeakyReLU0(), example_args, {}, expected_for_1_2,
run_ep_decomposition=True)
- verify_model(LeakyReLU1(), example_args, {}, expected_for_1_2,
run_ep_decomposition=True)
- verify_model(LeakyReLU2(), example_args, {}, expected_for_3,
run_ep_decomposition=True)
+ verify_model(LeakyReLU0(), example_args, {}, expected_for_1_2)
+ verify_model(LeakyReLU1(), example_args, {}, expected_for_1_2)
+ verify_model(LeakyReLU2(), example_args, {}, expected_for_3)
def test_logaddexp():
@@ -1044,7 +1034,7 @@ def test_logaddexp():
torch.randn(1, 3, 10, 10, dtype=torch.float32),
torch.randn(1, 3, 10, 10, dtype=torch.float32),
)
- verify_model(LogAddExp(), example_args, {}, expected,
run_ep_decomposition=True)
+ verify_model(LogAddExp(), example_args, {}, expected)
def test_logsoftmax():
@@ -1074,8 +1064,8 @@ def test_logsoftmax():
return gv
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
- verify_model(LogSoftmax(), example_args, {}, expected1,
run_ep_decomposition=True)
- verify_model(LogSoftmax2(), example_args, {}, expected1,
run_ep_decomposition=True)
+ verify_model(LogSoftmax(), example_args, {}, expected1)
+ verify_model(LogSoftmax2(), example_args, {}, expected1)
def test_prelu():
@@ -1113,8 +1103,8 @@ def test_prelu():
return gv
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
- verify_model(Prelu1(), example_args, {}, expected,
run_ep_decomposition=True)
- verify_model(Prelu2(), example_args, {}, expected,
run_ep_decomposition=True)
+ verify_model(Prelu1(), example_args, {}, expected)
+ verify_model(Prelu2(), example_args, {}, expected)
def test_softmax():
@@ -1144,8 +1134,8 @@ def test_softmax():
return gv
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
- verify_model(Softmax(), example_args, {}, expected1,
run_ep_decomposition=True)
- verify_model(Softmax2(), example_args, {}, expected1,
run_ep_decomposition=True)
+ verify_model(Softmax(), example_args, {}, expected1)
+ verify_model(Softmax2(), example_args, {}, expected1)
def test_softsign():
@@ -1176,8 +1166,8 @@ def test_softsign():
return gv
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
- verify_model(Softsign(), example_args, {}, expected_softsign,
run_ep_decomposition=True)
- verify_model(Softsign2(), example_args, {}, expected_softsign,
run_ep_decomposition=True)
+ verify_model(Softsign(), example_args, {}, expected_softsign)
+ verify_model(Softsign2(), example_args, {}, expected_softsign)
def test_softshrink():
@@ -1216,8 +1206,8 @@ def test_softshrink():
return gv
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
- verify_model(Softshrink(), example_args, {}, expected_softshrink,
run_ep_decomposition=True)
- verify_model(Softshrink2(), example_args, {}, expected_softshrink,
run_ep_decomposition=True)
+ verify_model(Softshrink(), example_args, {}, expected_softshrink)
+ verify_model(Softshrink2(), example_args, {}, expected_softshrink)
def test_tril_triu():
@@ -1251,7 +1241,7 @@ def test_tril_triu():
R.output(gv)
return gv
- verify_model(Tril(), example_args, {}, expected_tril,
run_ep_decomposition=True)
+ verify_model(Tril(), example_args, {}, expected_tril)
class Triu(Module):
def forward(self, input):
@@ -1281,7 +1271,7 @@ def test_tril_triu():
R.output(gv)
return gv
- verify_model(Triu(), example_args, {}, expected_triu,
run_ep_decomposition=True)
+ verify_model(Triu(), example_args, {}, expected_triu)
operator_binary_1 = [
@@ -1389,8 +1379,8 @@ def test_binary1(op, relax_op):
expected1 = expected_binary1_inplace if op in inplace_ops else
expected_binary1
expected2 = expected_binary2_inplace if op in inplace_ops else
expected_binary2
- verify_model(Binary1(op), example_args1, {}, expected1,
run_ep_decomposition=True)
- verify_model(Binary2(op), example_args2, {}, expected2,
run_ep_decomposition=True)
+ verify_model(Binary1(op), example_args1, {}, expected1)
+ verify_model(Binary2(op), example_args2, {}, expected2)
operator_binary_2 = [
@@ -1452,8 +1442,8 @@ def test_binary2(op, relax_op):
R.output(gv)
return gv
- verify_model(Binary1(op), example_args1, {}, expected_binary1,
run_ep_decomposition=True)
- verify_model(Binary2(op), example_args2, {}, expected_binary2,
run_ep_decomposition=True)
+ verify_model(Binary1(op), example_args1, {}, expected_binary1)
+ verify_model(Binary2(op), example_args2, {}, expected_binary2)
def test_binary3():
@@ -1481,7 +1471,7 @@ def test_binary3():
R.output(gv)
return gv
- verify_model(Max1(), example_args1, {}, expected_max1,
run_ep_decomposition=True)
+ verify_model(Max1(), example_args1, {}, expected_max1)
# Min
class Min1(Module):
@@ -1501,7 +1491,7 @@ def test_binary3():
R.output(gv)
return gv
- verify_model(Min1(), example_args1, {}, expected_min1,
run_ep_decomposition=True)
+ verify_model(Min1(), example_args1, {}, expected_min1)
# RSub
class RSub1(Module):
@@ -1536,8 +1526,8 @@ def test_binary3():
R.output(gv)
return gv
- verify_model(RSub1(), example_args1, {}, expected_rsub1,
run_ep_decomposition=True)
- verify_model(RSub2(), example_args2, {}, expected_rsub2,
run_ep_decomposition=True)
+ verify_model(RSub1(), example_args1, {}, expected_rsub1)
+ verify_model(RSub2(), example_args2, {}, expected_rsub2)
# IsIn
@@ -1566,7 +1556,7 @@ def test_isin():
torch.randn(10, 10, dtype=torch.float32),
torch.randn(8, dtype=torch.float32),
)
- verify_model(IsInModel(), example_args, {}, expected,
run_ep_decomposition=True)
+ verify_model(IsInModel(), example_args, {}, expected)
def test_div_mode():
@@ -1591,7 +1581,7 @@ def test_div_mode():
torch.randn(64, 64, dtype=torch.float32),
torch.randn(64, dtype=torch.float32),
)
- verify_model(DivModel(), example_args, {}, expected_div,
run_ep_decomposition=True)
+ verify_model(DivModel(), example_args, {}, expected_div)
# Case 2: Division with trunc rounding
class DivTruncModel(torch.nn.Module):
@@ -1611,7 +1601,7 @@ def test_div_mode():
R.output(gv)
return gv
- verify_model(DivTruncModel(), example_args, {}, expected_div_trunc,
run_ep_decomposition=True)
+ verify_model(DivTruncModel(), example_args, {}, expected_div_trunc)
# Case 3: Division with floor rounding
class DivFloorModel(torch.nn.Module):
@@ -1630,7 +1620,7 @@ def test_div_mode():
R.output(gv)
return gv
- verify_model(DivFloorModel(), example_args, {}, expected_div_floor,
run_ep_decomposition=True)
+ verify_model(DivFloorModel(), example_args, {}, expected_div_floor)
def test_batchnorm2d():
@@ -1685,7 +1675,7 @@ def test_batchnorm2d():
"w3": model.bn.running_mean.detach().numpy(),
"w4": model.bn.running_var.detach().numpy(),
}
- verify_model(model, example_args, binding, expected1,
run_ep_decomposition=True)
+ verify_model(model, example_args, binding, expected1)
def test_adaptive_avgpool1d():
@@ -1718,8 +1708,8 @@ def test_adaptive_avgpool1d():
return gv
example_args = (torch.randn(1, 3, 10, dtype=torch.float32),)
- verify_model(AdaptiveAvgPool1d0(), example_args, {}, expected1,
run_ep_decomposition=True)
- verify_model(AdaptiveAvgPool1d1(), example_args, {}, expected1,
run_ep_decomposition=True)
+ verify_model(AdaptiveAvgPool1d0(), example_args, {}, expected1)
+ verify_model(AdaptiveAvgPool1d1(), example_args, {}, expected1)
def test_adaptive_avgpool2d():
@@ -1751,8 +1741,8 @@ def test_adaptive_avgpool2d():
return gv
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
- verify_model(AdaptiveAvgPool2d0(), example_args, {}, expected1,
run_ep_decomposition=True)
- verify_model(AdaptiveAvgPool2d1(), example_args, {}, expected1,
run_ep_decomposition=True)
+ verify_model(AdaptiveAvgPool2d0(), example_args, {}, expected1)
+ verify_model(AdaptiveAvgPool2d1(), example_args, {}, expected1)
def test_adaptive_avgpool3d():
@@ -1783,8 +1773,8 @@ def test_adaptive_avgpool3d():
return gv
example_args = (torch.randn(1, 3, 8, 8, 8, dtype=torch.float32),)
- verify_model(AdaptiveAvgPool3d0(), example_args, {}, expected1,
run_ep_decomposition=True)
- verify_model(AdaptiveAvgPool3d1(), example_args, {}, expected1,
run_ep_decomposition=True)
+ verify_model(AdaptiveAvgPool3d0(), example_args, {}, expected1)
+ verify_model(AdaptiveAvgPool3d1(), example_args, {}, expected1)
def test_addmm():
@@ -1842,8 +1832,8 @@ def test_addmm():
torch.randn(10, 10, dtype=torch.float32),
)
- verify_model(Addmm1(), example_args, {}, expected1,
run_ep_decomposition=True)
- verify_model(Addmm2(), example_args, {}, expected2,
run_ep_decomposition=True)
+ verify_model(Addmm1(), example_args, {}, expected1)
+ verify_model(Addmm2(), example_args, {}, expected2)
def test_avg_pool1d():
@@ -1946,10 +1936,10 @@ def test_avg_pool1d():
return gv
example_args = (torch.randn(1, 3, 10, dtype=torch.float32),)
- verify_model(AvgPool1d1(), example_args, {}, expected1,
run_ep_decomposition=True)
- verify_model(AvgPool1d2(), example_args, {}, expected2,
run_ep_decomposition=True)
- verify_model(AvgPool1d3(), example_args, {}, expected2,
run_ep_decomposition=True)
- verify_model(AvgPool1d4(), example_args, {}, expected3,
run_ep_decomposition=True)
+ verify_model(AvgPool1d1(), example_args, {}, expected1)
+ verify_model(AvgPool1d2(), example_args, {}, expected2)
+ verify_model(AvgPool1d3(), example_args, {}, expected2)
+ verify_model(AvgPool1d4(), example_args, {}, expected3)
def test_avg_pool2d():
@@ -2039,10 +2029,10 @@ def test_avg_pool2d():
return gv
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
- verify_model(AvgPool2d1(), example_args, {}, expected1,
run_ep_decomposition=True)
- verify_model(AvgPool2d2(), example_args, {}, expected2,
run_ep_decomposition=True)
- verify_model(AvgPool2d3(), example_args, {}, expected2,
run_ep_decomposition=True)
- verify_model(AvgPool2d4(), example_args, {}, expected3,
run_ep_decomposition=True)
+ verify_model(AvgPool2d1(), example_args, {}, expected1)
+ verify_model(AvgPool2d2(), example_args, {}, expected2)
+ verify_model(AvgPool2d3(), example_args, {}, expected2)
+ verify_model(AvgPool2d4(), example_args, {}, expected3)
def test_avg_pool3d():
@@ -2135,10 +2125,10 @@ def test_avg_pool3d():
return gv
example_args = (torch.randn(1, 3, 8, 8, 8, dtype=torch.float32),)
- verify_model(AvgPool3d1(), example_args, {}, expected1,
run_ep_decomposition=True)
- verify_model(AvgPool3d2(), example_args, {}, expected2,
run_ep_decomposition=True)
- verify_model(AvgPool3d3(), example_args, {}, expected2,
run_ep_decomposition=True)
- verify_model(AvgPool3d4(), example_args, {}, expected3,
run_ep_decomposition=True)
+ verify_model(AvgPool3d1(), example_args, {}, expected1)
+ verify_model(AvgPool3d2(), example_args, {}, expected2)
+ verify_model(AvgPool3d3(), example_args, {}, expected2)
+ verify_model(AvgPool3d4(), example_args, {}, expected3)
def test_baddbmm():
@@ -2372,15 +2362,15 @@ def test_conv_transpose1d():
model = ConvTranspose1d1()
binding = {"w1": model.conv.weight.detach().numpy(), "w2":
model.conv.bias.detach().numpy()}
- verify_model(model, example_args, binding, expected1,
run_ep_decomposition=True)
+ verify_model(model, example_args, binding, expected1)
model = ConvTranspose1d1Func()
binding = {"w1": model.weight.detach().numpy(), "w2":
model.bias.detach().numpy()}
- verify_model(model, example_args, binding, expected1,
run_ep_decomposition=True)
+ verify_model(model, example_args, binding, expected1)
model = ConvTranspose1d2()
binding = {"w1": model.conv.weight.detach().numpy()}
- verify_model(model, example_args, binding, expected2,
run_ep_decomposition=True)
+ verify_model(model, example_args, binding, expected2)
def test_conv_transpose2d():
@@ -2466,15 +2456,15 @@ def test_conv_transpose2d():
model = ConvTranspose2d1()
binding = {"w1": model.conv.weight.detach().numpy(), "w2":
model.conv.bias.detach().numpy()}
- verify_model(model, example_args, binding, expected1,
run_ep_decomposition=True)
+ verify_model(model, example_args, binding, expected1)
model = ConvTranspose2d1Func()
binding = {"w1": model.weight.detach().numpy(), "w2":
model.bias.detach().numpy()}
- verify_model(model, example_args, binding, expected1,
run_ep_decomposition=True)
+ verify_model(model, example_args, binding, expected1)
model = ConvTranspose2d2()
binding = {"w1": model.conv.weight.detach().numpy()}
- verify_model(model, example_args, binding, expected2,
run_ep_decomposition=True)
+ verify_model(model, example_args, binding, expected2)
def test_conv1d():
@@ -2558,15 +2548,15 @@ def test_conv1d():
model = Conv1D1()
binding = {"w1": model.conv.weight.detach().numpy(), "w2":
model.conv.bias.detach().numpy()}
- verify_model(model, example_args, binding, expected1,
run_ep_decomposition=True)
+ verify_model(model, example_args, binding, expected1)
model = Conv1D1Func()
binding = {"w1": model.weight.detach().numpy(), "w2":
model.bias.detach().numpy()}
- verify_model(model, example_args, binding, expected1,
run_ep_decomposition=True)
+ verify_model(model, example_args, binding, expected1)
model = Conv1D2()
binding = {"w1": model.conv.weight.detach().numpy()}
- verify_model(model, example_args, binding, expected2,
run_ep_decomposition=True)
+ verify_model(model, example_args, binding, expected2)
def test_conv2d():
@@ -2650,15 +2640,15 @@ def test_conv2d():
model = Conv2D1()
binding = {"w1": model.conv.weight.detach().numpy(), "w2":
model.conv.bias.detach().numpy()}
- verify_model(model, example_args, binding, expected1,
run_ep_decomposition=True)
+ verify_model(model, example_args, binding, expected1)
model = Conv2D1Func()
binding = {"w1": model.weight.numpy(), "w2": model.bias.numpy()}
- verify_model(model, example_args, binding, expected1,
run_ep_decomposition=True)
+ verify_model(model, example_args, binding, expected1)
model = Conv2D2()
binding = {"w1": model.conv.weight.detach().numpy()}
- verify_model(model, example_args, binding, expected2,
run_ep_decomposition=True)
+ verify_model(model, example_args, binding, expected2)
def test_conv3d():
@@ -2742,15 +2732,15 @@ def test_conv3d():
model = Conv3D1()
binding = {"w1": model.conv.weight.detach().numpy(), "w2":
model.conv.bias.detach().numpy()}
- verify_model(model, example_args, binding, expected1,
run_ep_decomposition=True)
+ verify_model(model, example_args, binding, expected1)
model = Conv3D1Func()
binding = {"w1": model.weight.detach().numpy(), "w2":
model.bias.detach().numpy()}
- verify_model(model, example_args, binding, expected1,
run_ep_decomposition=True)
+ verify_model(model, example_args, binding, expected1)
model = Conv3D2()
binding = {"w1": model.conv.weight.detach().numpy()}
- verify_model(model, example_args, binding, expected2,
run_ep_decomposition=True)
+ verify_model(model, example_args, binding, expected2)
def test_pad():
@@ -2973,9 +2963,7 @@ def test_pad():
return gv
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
- verify_model(
- PadModel(pad=[1, 1, 2, 2]), example_args, {}, expected_constant,
run_ep_decomposition=True
- )
+ verify_model(PadModel(pad=[1, 1, 2, 2]), example_args, {},
expected_constant)
verify_model(
PadModel(pad=[1, 1, 2, 2], mode="reflect"),
example_args,
@@ -3037,12 +3025,8 @@ def test_pixel_shuffle():
return gv
example_args = (torch.randn(1, 8, 10, 15, dtype=torch.float32),)
- verify_model(
- PixelShuffle1(upscale_factor=2), example_args, {}, expected,
run_ep_decomposition=True
- )
- verify_model(
- PixelShuffle2(upscale_factor=2), example_args, {}, expected,
run_ep_decomposition=True
- )
+ verify_model(PixelShuffle1(upscale_factor=2), example_args, {}, expected)
+ verify_model(PixelShuffle2(upscale_factor=2), example_args, {}, expected)
def test_einsum():
@@ -3115,7 +3099,7 @@ def test_outer():
torch.randn(3, dtype=torch.float32),
torch.randn(4, dtype=torch.float32),
)
- verify_model(Outer(), example_args, {}, expected,
run_ep_decomposition=True)
+ verify_model(Outer(), example_args, {}, expected)
def test_embedding():
@@ -3145,7 +3129,7 @@ def test_embedding():
model = Embedding()
binding = {"w1": model.embedding.weight.detach().numpy()}
- verify_model(model, example_args, binding, expected1,
run_ep_decomposition=True)
+ verify_model(model, example_args, binding, expected1)
def test_groupnorm():
@@ -3194,7 +3178,7 @@ def test_groupnorm():
"w1": model.gn.weight.detach().numpy(),
"w2": model.gn.bias.detach().numpy(),
}
- verify_model(model, example_args, binding, expected1,
run_ep_decomposition=True)
+ verify_model(model, example_args, binding, expected1)
def test_instancenorm2d():
@@ -3239,7 +3223,7 @@ def test_instancenorm2d():
"w1": torch.ones(3).detach().numpy(),
"w2": torch.zeros(3).detach().numpy(),
}
- verify_model(model, example_args, binding, expected1,
run_ep_decomposition=True)
+ verify_model(model, example_args, binding, expected1)
def test_layernorm():
@@ -3354,15 +3338,15 @@ def test_linear():
model = Dense1()
binding = {"w1": model.linear.weight.detach().numpy(), "w2":
model.linear.bias.detach().numpy()}
- verify_model(model, example_args, binding, expected1,
run_ep_decomposition=True)
+ verify_model(model, example_args, binding, expected1)
model = Dense1Func()
binding = {"w1": model.weight.detach().numpy(), "w2":
model.bias.detach().numpy()}
- verify_model(model, example_args, binding, expected1,
run_ep_decomposition=True)
+ verify_model(model, example_args, binding, expected1)
model = Dense2()
binding = {"w1": model.linear.weight.detach().numpy()}
- verify_model(model, example_args, binding, expected2,
run_ep_decomposition=True)
+ verify_model(model, example_args, binding, expected2)
def test_maxpool1d():
@@ -3479,9 +3463,9 @@ def test_maxpool1d():
example_args3 = (torch.randn(1, 3, 10, dtype=torch.float32),)
# Verify the models
- verify_model(MaxPool1d(), example_args1, {}, expected1,
run_ep_decomposition=True)
- verify_model(MaxPool1d_functional(), example_args2, {}, expected2,
run_ep_decomposition=True)
- verify_model(MaxPool1d2(), example_args3, {}, expected3,
run_ep_decomposition=True)
+ verify_model(MaxPool1d(), example_args1, {}, expected1)
+ verify_model(MaxPool1d_functional(), example_args2, {}, expected2)
+ verify_model(MaxPool1d2(), example_args3, {}, expected3)
def test_maxpool2d():
@@ -3596,10 +3580,10 @@ def test_maxpool2d():
return gv
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
- verify_model(MaxPool2d(), example_args, {}, expected1,
run_ep_decomposition=True)
- verify_model(MaxPool2d_functional(), example_args, {}, expected1,
run_ep_decomposition=True)
- verify_model(MaxPool2d2(), example_args, {}, expected2,
run_ep_decomposition=True)
- verify_model(MaxPool2d3(), example_args, {}, expected3,
run_ep_decomposition=True)
+ verify_model(MaxPool2d(), example_args, {}, expected1)
+ verify_model(MaxPool2d_functional(), example_args, {}, expected1)
+ verify_model(MaxPool2d2(), example_args, {}, expected2)
+ verify_model(MaxPool2d3(), example_args, {}, expected3)
def test_maxpool3d():
@@ -3718,10 +3702,10 @@ def test_maxpool3d():
example_args3 = (torch.randn(1, 3, 10, 10, 10, dtype=torch.float32),)
# Verify the models with expected IR modules
- verify_model(MaxPool3d(), example_args1, {}, expected1,
run_ep_decomposition=True)
- verify_model(MaxPool3d_functional(), example_args1, {}, expected1,
run_ep_decomposition=True)
- verify_model(MaxPool3d2(), example_args2, {}, expected2,
run_ep_decomposition=True)
- verify_model(MaxPool3d3(), example_args3, {}, expected3,
run_ep_decomposition=True)
+ verify_model(MaxPool3d(), example_args1, {}, expected1)
+ verify_model(MaxPool3d_functional(), example_args1, {}, expected1)
+ verify_model(MaxPool3d2(), example_args2, {}, expected2)
+ verify_model(MaxPool3d3(), example_args3, {}, expected3)
def test_scaled_dot_product_attention():
@@ -4039,10 +4023,10 @@ def test_unbind():
return gv
example_args = (torch.randn(3, 3, 10, 10, dtype=torch.float32),)
- verify_model(Unbind1(), example_args, {}, expected1,
run_ep_decomposition=True)
- verify_model(Unbind2(), example_args, {}, expected2,
run_ep_decomposition=True)
+ verify_model(Unbind1(), example_args, {}, expected1)
+ verify_model(Unbind2(), example_args, {}, expected2)
single_dim_args = (torch.randn(3, 1, 3, dtype=torch.float32),)
- verify_model(Unbind2(), single_dim_args, {}, expected3,
run_ep_decomposition=True)
+ verify_model(Unbind2(), single_dim_args, {}, expected3)
def test_interpolate():
@@ -4536,15 +4520,9 @@ def test_interpolate():
return gv
example_args = (torch.randn(1, 3, 112, 112, dtype=torch.float32),)
- verify_model(
- InterpolateBilinear(), example_args, {}, expected_bilinear,
run_ep_decomposition=True
- )
- verify_model(
- InterpolateNearest(), example_args, {}, expected_nearest,
run_ep_decomposition=True
- )
- verify_model(
- InterpolateBicubic(), example_args, {}, expected_bicubic,
run_ep_decomposition=True
- )
+ verify_model(InterpolateBilinear(), example_args, {}, expected_bilinear)
+ verify_model(InterpolateNearest(), example_args, {}, expected_nearest)
+ verify_model(InterpolateBicubic(), example_args, {}, expected_bicubic)
def test_mean():
@@ -4581,8 +4559,8 @@ def test_mean():
return gv
example_args = (torch.randn(256, 256, dtype=torch.float32),)
- verify_model(Mean(), example_args, {}, Expected1,
run_ep_decomposition=True)
- verify_model(MeanKeepDim(), example_args, {}, Expected2,
run_ep_decomposition=True)
+ verify_model(Mean(), example_args, {}, Expected1)
+ verify_model(MeanKeepDim(), example_args, {}, Expected2)
def test_sum():
@@ -4604,7 +4582,7 @@ def test_sum():
return gv
example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
- verify_model(Sum(), example_args, {}, expected1, run_ep_decomposition=True)
+ verify_model(Sum(), example_args, {}, expected1)
def test_argmax_argmin():
@@ -4648,8 +4626,8 @@ def test_argmax_argmin():
R.output(gv)
return gv
- verify_model(Argmax1(), example_args, {}, expected_argmax1,
run_ep_decomposition=True)
- verify_model(Argmax2(), example_args, {}, expected_argmax2,
run_ep_decomposition=True)
+ verify_model(Argmax1(), example_args, {}, expected_argmax1)
+ verify_model(Argmax2(), example_args, {}, expected_argmax2)
class Argmin1(Module):
def __init__(self) -> None:
@@ -4689,8 +4667,8 @@ def test_argmax_argmin():
R.output(gv)
return gv
- verify_model(Argmin1(), example_args, {}, expected_argmin1,
run_ep_decomposition=True)
- verify_model(Argmin2(), example_args, {}, expected_argmin2,
run_ep_decomposition=True)
+ verify_model(Argmin1(), example_args, {}, expected_argmin1)
+ verify_model(Argmin2(), example_args, {}, expected_argmin2)
def test_cat_concat():
@@ -4737,10 +4715,10 @@ def test_cat_concat():
return gv
example_args = (torch.randn(2, 3, dtype=torch.float32), torch.randn(2, 3,
dtype=torch.float32))
- verify_model(Cat0(), example_args, {}, Expected1,
run_ep_decomposition=True)
- verify_model(Cat1(), example_args, {}, Expected2,
run_ep_decomposition=True)
- verify_model(Cat2(), example_args, {}, Expected2,
run_ep_decomposition=True)
- verify_model(Cat3(), example_args, {}, Expected1,
run_ep_decomposition=True)
+ verify_model(Cat0(), example_args, {}, Expected1)
+ verify_model(Cat1(), example_args, {}, Expected2)
+ verify_model(Cat2(), example_args, {}, Expected2)
+ verify_model(Cat3(), example_args, {}, Expected1)
def test_cumsum():
@@ -4762,7 +4740,7 @@ def test_cumsum():
return gv
example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
- verify_model(Cumsum(), example_args, {}, expected1,
run_ep_decomposition=True)
+ verify_model(Cumsum(), example_args, {}, expected1)
def test_expand():
@@ -4788,8 +4766,8 @@ def test_expand():
return gv
example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
- verify_model(Expand1(), example_args, {}, expected1,
run_ep_decomposition=True)
- verify_model(Expand2(), example_args, {}, expected1,
run_ep_decomposition=True)
+ verify_model(Expand1(), example_args, {}, expected1)
+ verify_model(Expand2(), example_args, {}, expected1)
def test_flatten():
@@ -4815,7 +4793,7 @@ def test_flatten():
return gv
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
- verify_model(Flatten(), example_args, {}, expected1,
run_ep_decomposition=True)
+ verify_model(Flatten(), example_args, {}, expected1)
def test_meshgrid():
@@ -4865,8 +4843,8 @@ def test_meshgrid():
torch.randn(3, dtype=torch.float32),
torch.randn(3, dtype=torch.float32),
)
- verify_model(Meshgrid1(), example_args, {}, expected1,
run_ep_decomposition=True)
- verify_model(Meshgrid2(), example_args, {}, expected2,
run_ep_decomposition=True)
+ verify_model(Meshgrid1(), example_args, {}, expected1)
+ verify_model(Meshgrid2(), example_args, {}, expected2)
def test_permute():
@@ -4892,8 +4870,8 @@ def test_permute():
return gv
example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
- verify_model(Permute1(), example_args, {}, expected1,
run_ep_decomposition=True)
- verify_model(Permute2(), example_args, {}, expected1,
run_ep_decomposition=True)
+ verify_model(Permute1(), example_args, {}, expected1)
+ verify_model(Permute2(), example_args, {}, expected1)
def test_repeat():
@@ -4930,13 +4908,13 @@ def test_repeat():
return gv
example_args = (torch.randn(3, dtype=torch.float32),)
- verify_model(Tile1(), example_args, {}, expected1,
run_ep_decomposition=True)
+ verify_model(Tile1(), example_args, {}, expected1)
example_args = (torch.randn(1, 3, dtype=torch.float32),)
- verify_model(Tile2(), example_args, {}, expected2,
run_ep_decomposition=True)
+ verify_model(Tile2(), example_args, {}, expected2)
example_args = (torch.randn(1, 3, dtype=torch.float32),)
- verify_model(Tile2(), example_args, {}, expected2,
run_ep_decomposition=True)
+ verify_model(Tile2(), example_args, {}, expected2)
def test_reshape():
@@ -4958,7 +4936,7 @@ def test_reshape():
return gv
example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
- verify_model(Reshape(), example_args, {}, expected1,
run_ep_decomposition=True)
+ verify_model(Reshape(), example_args, {}, expected1)
def test_reshape_as():
@@ -4984,7 +4962,7 @@ def test_reshape_as():
torch.randn(1, 2, 3, 4, dtype=torch.float32),
torch.randn(2, 12, dtype=torch.float32),
)
- verify_model(ReshapeAs(), example_args, {}, expected1,
run_ep_decomposition=True)
+ verify_model(ReshapeAs(), example_args, {}, expected1)
def test_roll():
@@ -5062,9 +5040,9 @@ def test_roll():
example_input = torch.randint(0, 10, (4, 2), dtype=torch.int64)
# Run verification for each case
- verify_model(Roll1(), (example_input,), {}, Expected1,
run_ep_decomposition=True)
- verify_model(Roll2(), (example_input,), {}, Expected2,
run_ep_decomposition=True)
- verify_model(Roll3(), (example_input,), {}, Expected3,
run_ep_decomposition=True)
+ verify_model(Roll1(), (example_input,), {}, Expected1)
+ verify_model(Roll2(), (example_input,), {}, Expected2)
+ verify_model(Roll3(), (example_input,), {}, Expected3)
def test_select_slice():
@@ -5144,10 +5122,10 @@ def test_select_slice():
return gv
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
- verify_model(Slice1(), example_args, {}, expected1,
run_ep_decomposition=True)
+ verify_model(Slice1(), example_args, {}, expected1)
example_args = (torch.randn(8, 16, dtype=torch.float32),)
- verify_model(Slice2(), example_args, {}, expected2,
run_ep_decomposition=True)
+ verify_model(Slice2(), example_args, {}, expected2)
def test_slice_scatter():
@@ -5189,10 +5167,10 @@ def test_slice_scatter():
return gv
example_args = (torch.randn(8, 8, 10, 10, dtype=torch.float32),
torch.randn(8, 3, 10, 10))
- verify_model(SliceScatter1(), example_args, {}, expected1,
run_ep_decomposition=True)
+ verify_model(SliceScatter1(), example_args, {}, expected1)
example_args = (torch.randn(8, 16, dtype=torch.float32), torch.randn(6,
16))
- verify_model(SliceScatter2(), example_args, {}, expected2,
run_ep_decomposition=True)
+ verify_model(SliceScatter2(), example_args, {}, expected2)
def test_split():
@@ -5331,11 +5309,11 @@ def test_split():
return gv
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
- verify_model(Chunk(), example_args, {}, Expected,
run_ep_decomposition=True)
+ verify_model(Chunk(), example_args, {}, Expected)
example_args = (torch.randn(3, 3, 10, 10, dtype=torch.float32),)
- verify_model(Unbind1(), example_args, {}, expected1,
run_ep_decomposition=True)
- verify_model(Unbind2(), example_args, {}, expected2,
run_ep_decomposition=True)
+ verify_model(Unbind1(), example_args, {}, expected1)
+ verify_model(Unbind2(), example_args, {}, expected2)
def test_squeeze():
@@ -5373,8 +5351,8 @@ def test_squeeze():
example_args = (torch.randn(3, 1, 4, 1, dtype=torch.float32),)
- verify_model(Squeeze1(), example_args, {}, Expected1,
run_ep_decomposition=True)
- verify_model(Squeeze2(), example_args, {}, Expected2,
run_ep_decomposition=True)
+ verify_model(Squeeze1(), example_args, {}, Expected1)
+ verify_model(Squeeze2(), example_args, {}, Expected2)
def test_stack():
@@ -5439,10 +5417,10 @@ def test_stack():
example_args = (torch.randn(2, 3, dtype=torch.float32), torch.randn(2, 3,
dtype=torch.float32))
- verify_model(Stack0(), example_args, {}, Expected0,
run_ep_decomposition=True)
- verify_model(Stack1(), example_args, {}, Expected1,
run_ep_decomposition=True)
- verify_model(Stack2(), example_args, {}, Expected1,
run_ep_decomposition=True)
- verify_model(Stack3(), example_args, {}, Expected3,
run_ep_decomposition=True)
+ verify_model(Stack0(), example_args, {}, Expected0)
+ verify_model(Stack1(), example_args, {}, Expected1)
+ verify_model(Stack2(), example_args, {}, Expected1)
+ verify_model(Stack3(), example_args, {}, Expected3)
def test_tile():
@@ -5485,9 +5463,9 @@ def test_tile():
return gv
example_args = (torch.randn(1, 3, dtype=torch.float32),)
- verify_model(Tile1(), example_args, {}, expected1,
run_ep_decomposition=True)
- verify_model(Tile2(), example_args, {}, expected2,
run_ep_decomposition=True)
- verify_model(Tile3(), example_args, {}, expected2,
run_ep_decomposition=True)
+ verify_model(Tile1(), example_args, {}, expected1)
+ verify_model(Tile2(), example_args, {}, expected2)
+ verify_model(Tile3(), example_args, {}, expected2)
def test_transpose():
@@ -5509,7 +5487,7 @@ def test_transpose():
return gv
example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
- verify_model(Transpose(), example_args, {}, expected1,
run_ep_decomposition=True)
+ verify_model(Transpose(), example_args, {}, expected1)
def test_unsqueeze():
@@ -5549,8 +5527,8 @@ def test_unsqueeze():
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
- verify_model(Unsqueeze1(), example_args, {}, expected1,
run_ep_decomposition=True)
- verify_model(Unsqueeze2(), example_args, {}, expected2,
run_ep_decomposition=True)
+ verify_model(Unsqueeze1(), example_args, {}, expected1)
+ verify_model(Unsqueeze2(), example_args, {}, expected2)
def test_view():
@@ -5572,7 +5550,7 @@ def test_view():
return gv
example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
- verify_model(View(), example_args, {}, expected1,
run_ep_decomposition=True)
+ verify_model(View(), example_args, {}, expected1)
def test_arange():
@@ -5593,7 +5571,7 @@ def test_arange():
return gv
example_args = (torch.randn(10, 10, dtype=torch.float32),)
- verify_model(Arange(), example_args, {}, Expected,
run_ep_decomposition=True)
+ verify_model(Arange(), example_args, {}, Expected)
def test_hamming_window():
@@ -5620,7 +5598,7 @@ def test_hamming_window():
return gv
example_args = (torch.randn(10, 10, dtype=torch.float32),)
- verify_model(HammingWindow(), example_args, {}, Expected,
run_ep_decomposition=True)
+ verify_model(HammingWindow(), example_args, {}, Expected)
def test_contiguous():
@@ -5640,7 +5618,7 @@ def test_contiguous():
return gv
example_args = (torch.randn(10, 10, dtype=torch.float32),)
- verify_model(Contiguous(), example_args, {}, Expected,
run_ep_decomposition=True)
+ verify_model(Contiguous(), example_args, {}, Expected)
def test_clone():
@@ -5660,7 +5638,7 @@ def test_clone():
return gv
example_args = (torch.randn(10, 10, dtype=torch.float32),)
- verify_model(Clone(), example_args, {}, Expected,
run_ep_decomposition=True)
+ verify_model(Clone(), example_args, {}, Expected)
def test_empty():
@@ -5683,7 +5661,7 @@ def test_empty():
return gv
example_args = (torch.randn(10, 10, dtype=torch.float32),)
- verify_model(Empty(), example_args, {}, Expected,
run_ep_decomposition=True)
+ verify_model(Empty(), example_args, {}, Expected)
def test_empty_without_dtype():
@@ -5704,7 +5682,7 @@ def test_empty_without_dtype():
return gv
example_args = (torch.randn(10, 10, dtype=torch.float32),)
- verify_model(EmptyWithoutDtype(), example_args, {}, Expected,
run_ep_decomposition=True)
+ verify_model(EmptyWithoutDtype(), example_args, {}, Expected)
def test_fill():
@@ -5727,7 +5705,7 @@ def test_fill():
return gv
example_args = (torch.randn(10, 10, dtype=torch.float32),)
- verify_model(Fill(), example_args, {}, Expected, run_ep_decomposition=True)
+ verify_model(Fill(), example_args, {}, Expected)
def test_fill_inplace():
@@ -5753,7 +5731,7 @@ def test_fill_inplace():
return gv
example_args = (torch.randn(2, 3, dtype=torch.float32),)
- verify_model(FillInplace(), example_args, {}, Expected,
run_ep_decomposition=True)
+ verify_model(FillInplace(), example_args, {}, Expected)
def test_masked_fill():
@@ -5775,7 +5753,7 @@ def test_masked_fill():
return gv
example_args = (torch.randn(128, 128, dtype=torch.float32),
torch.rand(128, 128) < 0.5)
- verify_model(Masked_Fill(), example_args, {}, Expected,
run_ep_decomposition=True)
+ verify_model(Masked_Fill(), example_args, {}, Expected)
def test_masked_fill_inplace():
@@ -5799,7 +5777,7 @@ def test_masked_fill_inplace():
return gv
example_args = (torch.randn(128, 128, dtype=torch.float32),
torch.rand(128, 128) < 0.5)
- verify_model(Masked_Fill_Inplace(), example_args, {}, Expected,
run_ep_decomposition=True)
+ verify_model(Masked_Fill_Inplace(), example_args, {}, Expected)
def test_new_ones():
@@ -5823,7 +5801,7 @@ def test_new_ones():
return gv
example_args = (torch.randn(1, 2, 3, dtype=torch.float32),)
- verify_model(NewOnes(), example_args, {}, expected1,
run_ep_decomposition=True)
+ verify_model(NewOnes(), example_args, {}, expected1)
def test_new_zeros():
@@ -5846,7 +5824,7 @@ def test_new_zeros():
return gv
example_args = (torch.randn(1, 128, 128, dtype=torch.float32),)
- verify_model(NewZeros(), example_args, {}, expected1,
run_ep_decomposition=True)
+ verify_model(NewZeros(), example_args, {}, expected1)
def test_to_copy():
@@ -5937,11 +5915,11 @@ def test_to_copy():
return gv
example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
- verify_model(ToFloat(), example_args, {}, expected_float,
run_ep_decomposition=True)
- verify_model(ToHalf(), example_args, {}, expected_half,
run_ep_decomposition=True)
- verify_model(Type(), example_args, {}, expected_type,
run_ep_decomposition=True)
- verify_model(To1(), example_args, {}, expected_to1,
run_ep_decomposition=True)
- verify_model(To2(), example_args, {}, expected_to2,
run_ep_decomposition=True)
+ verify_model(ToFloat(), example_args, {}, expected_float)
+ verify_model(ToHalf(), example_args, {}, expected_half)
+ verify_model(Type(), example_args, {}, expected_type)
+ verify_model(To1(), example_args, {}, expected_to1)
+ verify_model(To2(), example_args, {}, expected_to2)
def test_keep_params():
@@ -5986,9 +5964,7 @@ def test_keep_params():
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
model = Conv2D1()
exported_program = torch.export.export(model, example_args)
- mod = from_exported_program(
- exported_program, keep_params_as_input=True, run_ep_decomposition=True
- )
+ mod = from_exported_program(exported_program, keep_params_as_input=True)
mod, params = detach_params(mod)
tvm.ir.assert_structural_equal(mod, expected1)
func = mod["main"]
@@ -6024,9 +6000,7 @@ def test_unwrap_unit_return_tuple():
example_args = (torch.randn(256, 256, dtype=torch.float32),)
exported_program = export(Identity(), args=example_args)
- mod = from_exported_program(
- exported_program, unwrap_unit_return_tuple=True,
run_ep_decomposition=True
- )
+ mod = from_exported_program(exported_program,
unwrap_unit_return_tuple=True)
tvm.ir.assert_structural_equal(mod, Expected)
@@ -6056,9 +6030,7 @@ def test_no_bind_return_tuple():
torch.randn(256, 256, dtype=torch.float32),
)
exported_program = export(Identity(), args=example_args)
- mod = from_exported_program(
- exported_program, no_bind_return_tuple=True, run_ep_decomposition=True
- )
+ mod = from_exported_program(exported_program, no_bind_return_tuple=True)
tvm.ir.assert_structural_equal(mod, Expected)
@@ -6081,7 +6053,7 @@ def test_empty_like():
example_args = (torch.randn(5, dtype=torch.float32),)
- verify_model(EmptyLike(), example_args, {}, Expected,
run_ep_decomposition=True)
+ verify_model(EmptyLike(), example_args, {}, Expected)
def test_one_hot():
@@ -6108,7 +6080,7 @@ def test_one_hot():
example_args = (torch.randint(0, 10, (5,), dtype=torch.int64),)
- verify_model(OneHot(), example_args, {}, Expected,
run_ep_decomposition=True)
+ verify_model(OneHot(), example_args, {}, Expected)
def test_ones_like():
@@ -6132,7 +6104,7 @@ def test_ones_like():
example_args = (torch.rand(128, 128, dtype=torch.float32),)
- verify_model(OnesLike(), example_args, {}, Expected,
run_ep_decomposition=True)
+ verify_model(OnesLike(), example_args, {}, Expected)
def test_zero_inplace():
@@ -6161,7 +6133,7 @@ def test_zero_inplace():
example_args = (torch.rand(128, 128, dtype=torch.float32),)
- verify_model(ZeroInplace(), example_args, {}, Expected,
run_ep_decomposition=True)
+ verify_model(ZeroInplace(), example_args, {}, Expected)
def test_zeros():
@@ -6185,7 +6157,7 @@ def test_zeros():
example_args = (torch.rand(128, 128, dtype=torch.float32),)
- verify_model(Zeros(), example_args, {}, Expected,
run_ep_decomposition=True)
+ verify_model(Zeros(), example_args, {}, Expected)
def test_zeros_like():
@@ -6208,7 +6180,7 @@ def test_zeros_like():
return gv
example_args = (torch.rand(128, 128, dtype=torch.float32),)
- verify_model(ZerosLike(), example_args, {}, Expected,
run_ep_decomposition=True)
+ verify_model(ZerosLike(), example_args, {}, Expected)
def test_type_as():
@@ -6234,7 +6206,7 @@ def test_type_as():
torch.rand(128, 128, dtype=torch.float16),
)
- verify_model(TypeAs(), example_args, {}, Expected,
run_ep_decomposition=True)
+ verify_model(TypeAs(), example_args, {}, Expected)
def test_select():
@@ -6256,7 +6228,7 @@ def test_select():
example_args = (torch.randn(2, 3, dtype=torch.float32),)
- verify_model(Select(), example_args, {}, Expected,
run_ep_decomposition=True)
+ verify_model(Select(), example_args, {}, Expected)
def test_unflatten():
@@ -6282,8 +6254,8 @@ def test_unflatten():
example_args = (torch.randn(2, 15, 7, dtype=torch.float32),)
- verify_model(Unflatten(), example_args, {}, Expected,
run_ep_decomposition=True)
- verify_model(Unflatten1(), example_args, {}, Expected,
run_ep_decomposition=True)
+ verify_model(Unflatten(), example_args, {}, Expected)
+ verify_model(Unflatten1(), example_args, {}, Expected)
def test_gather():
@@ -6360,10 +6332,10 @@ def test_gather():
torch.randint(0, 3, (2, 3), dtype=torch.int64),
)
- verify_model(Gather0(), example_args, {}, Expected0,
run_ep_decomposition=True)
- verify_model(Gather1(), example_args, {}, Expected1,
run_ep_decomposition=True)
- verify_model(Gather2(), example_args, {}, Expected2,
run_ep_decomposition=True)
- verify_model(Gather3(), example_args, {}, Expected3,
run_ep_decomposition=True)
+ verify_model(Gather0(), example_args, {}, Expected0)
+ verify_model(Gather1(), example_args, {}, Expected1)
+ verify_model(Gather2(), example_args, {}, Expected2)
+ verify_model(Gather3(), example_args, {}, Expected3)
def test_index_put():
@@ -6555,11 +6527,11 @@ def test_index_put():
return gv
# Run verification for each case
- verify_model(IndexPut1D(), example_args_1d, {}, Expected1D,
run_ep_decomposition=True)
- verify_model(IndexPut2D(), example_args_2d, {}, Expected2D,
run_ep_decomposition=True)
- verify_model(IndexPut3D(), example_args_3d, {}, Expected3D,
run_ep_decomposition=True)
- verify_model(IndexPut4D(), example_args_4d, {}, Expected4D,
run_ep_decomposition=True)
- verify_model(IndexPut5D(), example_args_5d, {}, Expected5D,
run_ep_decomposition=True)
+ verify_model(IndexPut1D(), example_args_1d, {}, Expected1D)
+ verify_model(IndexPut2D(), example_args_2d, {}, Expected2D)
+ verify_model(IndexPut3D(), example_args_3d, {}, Expected3D)
+ verify_model(IndexPut4D(), example_args_4d, {}, Expected4D)
+ verify_model(IndexPut5D(), example_args_5d, {}, Expected5D)
def test_flip():
@@ -6597,8 +6569,8 @@ def test_flip():
example_args = (torch.randn(2, 2, dtype=torch.float32),)
- verify_model(Flip0(), example_args, {}, Expected0,
run_ep_decomposition=True)
- verify_model(Flip1(), example_args, {}, Expected1,
run_ep_decomposition=True)
+ verify_model(Flip0(), example_args, {}, Expected0)
+ verify_model(Flip1(), example_args, {}, Expected1)
def test_take():
@@ -6625,7 +6597,7 @@ def test_take():
torch.randint(0, 5, (3,), dtype=torch.int64),
)
- verify_model(Take(), example_args, {}, Expected, run_ep_decomposition=True)
+ verify_model(Take(), example_args, {}, Expected)
def test_std():
@@ -6647,7 +6619,7 @@ def test_std():
return gv
example_args = (torch.randn(5, 3, dtype=torch.float32),)
- verify_model(Std(), example_args, {}, Expected, run_ep_decomposition=True)
+ verify_model(Std(), example_args, {}, Expected)
def test_var():
@@ -6668,7 +6640,7 @@ def test_var():
return gv
example_args = (torch.randn(5, 3, dtype=torch.float32),)
- verify_model(Var(), example_args, {}, Expected, run_ep_decomposition=True)
+ verify_model(Var(), example_args, {}, Expected)
def test_prod():
@@ -6689,7 +6661,7 @@ def test_prod():
return gv
example_args = (torch.randn(5, 3, dtype=torch.float32),)
- verify_model(Prod(), example_args, {}, Expected, run_ep_decomposition=True)
+ verify_model(Prod(), example_args, {}, Expected)
def test_cumprod():
@@ -6710,7 +6682,7 @@ def test_cumprod():
return gv
example_input = torch.randn(5, 3, dtype=torch.float32)
- verify_model(Cumprod(), (example_input,), {}, Expected,
run_ep_decomposition=True)
+ verify_model(Cumprod(), (example_input,), {}, Expected)
def test_where():
@@ -6736,7 +6708,7 @@ def test_where():
x = torch.randn(5, 3, dtype=torch.float32)
y = torch.randn(5, 3, dtype=torch.float32)
- verify_model(Where(), (condition, x, y), {}, Expected,
run_ep_decomposition=True)
+ verify_model(Where(), (condition, x, y), {}, Expected)
def test_bucketize():
@@ -6761,7 +6733,7 @@ def test_bucketize():
input_tensor = torch.arange(0, 20)
boundaries = torch.arange(0, 20, 2)
- verify_model(Bucketize(), (input_tensor, boundaries), {}, Expected,
run_ep_decomposition=True)
+ verify_model(Bucketize(), (input_tensor, boundaries), {}, Expected)
def test_argsort():
@@ -6788,7 +6760,7 @@ def test_argsort():
return gv
example_args = (torch.randn(5, 3, dtype=torch.float32),)
- verify_model(Argsort(), example_args, {}, Expected,
run_ep_decomposition=True)
+ verify_model(Argsort(), example_args, {}, Expected)
def test_topk():
@@ -6816,7 +6788,7 @@ def test_topk():
return gv
example_args = (torch.randn(5, 3, dtype=torch.float32),)
- verify_model(Topk(), example_args, {}, Expected, run_ep_decomposition=True)
+ verify_model(Topk(), example_args, {}, Expected)
def test_dynamic_shape():
@@ -6872,7 +6844,7 @@ def test_broadcast_to():
return gv
example_args = (torch.randn(5, 1, dtype=torch.float32),)
- verify_model(BroadcastTo(), example_args, {}, Expected,
run_ep_decomposition=True)
+ verify_model(BroadcastTo(), example_args, {}, Expected)
def test_narrow():
@@ -6901,7 +6873,7 @@ def test_narrow():
return gv
example_args = (torch.randn(5, 3, dtype=torch.float32),)
- verify_model(Narrow(), example_args, {}, Expected,
run_ep_decomposition=True)
+ verify_model(Narrow(), example_args, {}, Expected)
def test_item():
@@ -6920,7 +6892,7 @@ def test_item():
return gv
example_args = (torch.randn(1, dtype=torch.float32),)
- verify_model(Item(), example_args, {}, Expected, run_ep_decomposition=True)
+ verify_model(Item(), example_args, {}, Expected)
def test_norm():
@@ -7032,9 +7004,7 @@ def test_norm():
example_args = (torch.randn(1, 3, 5, 3, dtype=torch.float32),)
for (p, dim, keepdim), expected in norms:
- verify_model(
- Norm(p, dim=dim, keepdim=keepdim), example_args, {}, expected,
run_ep_decomposition=True
- )
+ verify_model(Norm(p, dim=dim, keepdim=keepdim), example_args, {},
expected)
def test_eye():
@@ -7095,10 +7065,10 @@ def test_eye():
return gv
example_args1 = (torch.randn(3, 5, dtype=torch.float32),)
- verify_model(Eye1(), example_args1, {}, Expected1,
run_ep_decomposition=True)
+ verify_model(Eye1(), example_args1, {}, Expected1)
example_args2 = (torch.randn(5, dtype=torch.float32),)
- verify_model(Eye2(), example_args2, {}, Expected2,
run_ep_decomposition=True)
+ verify_model(Eye2(), example_args2, {}, Expected2)
def test_cross_entropy():
@@ -7146,7 +7116,7 @@ def test_cross_entropy():
return gv
example_args1 = (torch.randn(4, 3, dtype=torch.float32),)
- verify_model(CrossEntropyModule(), example_args1, {}, Expected1,
run_ep_decomposition=True)
+ verify_model(CrossEntropyModule(), example_args1, {}, Expected1)
def test_linspace():
@@ -7178,7 +7148,7 @@ def test_linspace():
return gv
example_args = (torch.randn(9, 9, dtype=torch.float32),)
- verify_model(Linspace(), example_args, {}, Expected,
run_ep_decomposition=True)
+ verify_model(Linspace(), example_args, {}, Expected)
@pytest.mark.parametrize(
@@ -7215,7 +7185,7 @@ def test_dtypes(torch_dtype, relax_dtype):
R.output(gv)
return gv
- verify_model(Model(), example_args, {}, Expected,
run_ep_decomposition=True)
+ verify_model(Model(), example_args, {}, Expected)
def test_mm():
@@ -7241,7 +7211,7 @@ def test_mm():
R.output(gv)
return gv
- verify_model(MatrixMultiply(), example_args, {}, Expected,
run_ep_decomposition=True)
+ verify_model(MatrixMultiply(), example_args, {}, Expected)
def test_lstm():
@@ -7266,7 +7236,7 @@ def test_lstm():
with torch.no_grad():
pytorch_output = model(x)
exported_program = export(model, args=(x,))
- mod = from_exported_program(exported_program, run_ep_decomposition=True)
+ mod = from_exported_program(exported_program)
target = tvm.target.Target("llvm")
ex = relax.build(mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu())
@@ -7302,7 +7272,7 @@ def test_lstm():
with torch.no_grad():
pytorch_output2 = model2(x2)
exported_program2 = export(model2, args=(x2,))
- mod2 = from_exported_program(exported_program2, run_ep_decomposition=True)
+ mod2 = from_exported_program(exported_program2)
ex2 = relax.build(mod2, target)
vm2 = relax.VirtualMachine(ex2, tvm.cpu())
x2_tvm = tvm.runtime.tensor(x2.numpy())
@@ -7334,7 +7304,7 @@ def test_tensor_none_tuple():
R.output(gv)
return gv
- verify_model(TensorNoneModel(), example_args, {}, Expected,
run_ep_decomposition=True)
+ verify_model(TensorNoneModel(), example_args, {}, Expected)
def test_gru():
@@ -7359,7 +7329,7 @@ def test_gru():
with torch.no_grad():
pytorch_output = model(x)
exported_program = export(model, args=(x,))
- mod = from_exported_program(exported_program, run_ep_decomposition=True)
+ mod = from_exported_program(exported_program)
target = tvm.target.Target("llvm")
ex = relax.build(mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu())
@@ -7395,7 +7365,7 @@ def test_gru():
with torch.no_grad():
pytorch_output2 = model2(x2)
exported_program2 = export(model2, args=(x2,))
- mod2 = from_exported_program(exported_program2, run_ep_decomposition=True)
+ mod2 = from_exported_program(exported_program2)
ex2 = relax.build(mod2, target)
vm2 = relax.VirtualMachine(ex2, tvm.cpu())
x2_tvm = tvm.runtime.tensor(x2.numpy())
@@ -7432,7 +7402,7 @@ def test_dynamic_shape_with_range_constraints():
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
exported_program = export(DynamicModel(), args=example_args,
dynamic_shapes=dynamic_shapes)
- mod = from_exported_program(exported_program, run_ep_decomposition=True)
+ mod = from_exported_program(exported_program)
tvm.ir.assert_structural_equal(mod, Expected)
@@ -7466,7 +7436,7 @@ def test_dynamic_shape_with_addition_constraints():
dynamic_shapes = {"x": {0: batch}, "y": {0: batch + 1}}
exported_program = export(ConcatModel(), args=example_args,
dynamic_shapes=dynamic_shapes)
- mod = from_exported_program(exported_program, run_ep_decomposition=True)
+ mod = from_exported_program(exported_program)
tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
@@ -7500,7 +7470,7 @@ def test_dynamic_shape_with_subtraction_constraints():
dynamic_shapes = {"x": {0: batch}, "y": {0: batch - 1}}
exported_program = export(ConcatModel(), args=example_args,
dynamic_shapes=dynamic_shapes)
- mod = from_exported_program(exported_program, run_ep_decomposition=True)
+ mod = from_exported_program(exported_program)
tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
@@ -7534,7 +7504,7 @@ def test_dynamic_shape_with_multiplication_constraints():
dynamic_shapes = {"x": {0: batch}, "y": {0: batch * 2}}
exported_program = export(ConcatModel(), args=example_args,
dynamic_shapes=dynamic_shapes)
- mod = from_exported_program(exported_program, run_ep_decomposition=True)
+ mod = from_exported_program(exported_program)
tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)