This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a8c75802ad [Relax][PyTorch] Enable run_ep_decomposition by default 
(#18471)
a8c75802ad is described below

commit a8c75802ad462bc093d2bab51c79b3bd3303355a
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Thu Nov 20 17:07:53 2025 +0800

    [Relax][PyTorch] Enable run_ep_decomposition by default (#18471)
    
    ## Why
    
    We have finished migrating our tests, so we can now enable
    run_ep_decomposition by default.
    
    ## How
    
    Update tests and exported_program_translator.py
---
 .../frontend/torch/exported_program_translator.py  |  15 +-
 .../relax/test_frontend_from_exported_program.py   | 524 ++++++++++-----------
 2 files changed, 259 insertions(+), 280 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py 
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 6aa118ee5c..a2b9b2afa4 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -208,6 +208,15 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             )
         )
 
+    def _native_layer_norm(self, node: fx.Node) -> relax.Var:
+        # native_layer_norm signature: (input, normalized_shape, weight, bias, 
eps)
+        x = self.env[node.args[0]]
+        normalized_shape = node.args[1]
+        gamma = self.env.get(node.args[2], None) if len(node.args) > 2 else 
None
+        beta = self.env.get(node.args[3], None) if len(node.args) > 3 else None
+        eps = node.args[4] if len(node.args) > 4 else 1e-05
+        return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape)
+
     def _upsample_impl(
         self,
         x: relax.Expr,
@@ -1058,6 +1067,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "instance_norm.default": self._instance_norm,
             "native_group_norm.default": self._native_group_norm,
             "layer_norm.default": self._layer_norm,
+            "native_layer_norm.default": self._native_layer_norm,
             "linear.default": self._linear,
             "lstm.input": self._lstm,
             "gru.input": self._gru,
@@ -1403,7 +1413,7 @@ def from_exported_program(
     keep_params_as_input: bool = False,
     unwrap_unit_return_tuple: bool = False,
     no_bind_return_tuple: bool = False,
-    run_ep_decomposition: bool = False,
+    run_ep_decomposition: bool = True,
 ) -> tvm.IRModule:
     """Convert a PyTorch ExportedProgram to a Relax program
 
@@ -1426,8 +1436,7 @@ def from_exported_program(
     run_ep_decomposition : bool
         A boolean flag indicating whether to run PyTorch's decomposition on the
         exported program before translation. When True, high-level operators 
will
-        be decomposed into their constituent parts. Defaults to False for 
backward
-        compatibility.
+        be decomposed into their constituent parts. Defaults to True.
 
     Returns
     -------
diff --git a/tests/python/relax/test_frontend_from_exported_program.py 
b/tests/python/relax/test_frontend_from_exported_program.py
index acd1344ec9..1429dec5e7 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -32,7 +32,7 @@ from tvm.relax.frontend.torch import from_exported_program
 
 
 def verify_model(
-    torch_model, example_args, binding, expected, dynamic_shapes=None, 
run_ep_decomposition=False
+    torch_model, example_args, binding, expected, dynamic_shapes=None, 
run_ep_decomposition=True
 ):
     exported_program = export(torch_model, args=example_args, 
dynamic_shapes=dynamic_shapes)
     mod = from_exported_program(exported_program, 
run_ep_decomposition=run_ep_decomposition)
@@ -94,7 +94,7 @@ def test_basic_unary_ops(pytorch_op, relax_op):
                 R.output(gv)
             return gv
 
-    verify_model(UnaryOp(), example_args, {}, expected, 
run_ep_decomposition=True)
+    verify_model(UnaryOp(), example_args, {}, expected)
 
 
 operator_bool_unary = [
@@ -123,7 +123,7 @@ def test_bool_unary_ops(pytorch_op, relax_op):
                 R.output(gv)
             return gv
 
-    verify_model(UnaryOp(), example_args, {}, expected, 
run_ep_decomposition=True)
+    verify_model(UnaryOp(), example_args, {}, expected)
 
 
 def test_sqrt_integer_input():
@@ -147,7 +147,7 @@ def test_sqrt_integer_input():
                 R.output(gv)
             return gv
 
-    verify_model(SqrtIntModel(), example_args, {}, expected_int64, 
run_ep_decomposition=True)
+    verify_model(SqrtIntModel(), example_args, {}, expected_int64)
 
     example_args_int32 = (torch.tensor([[1, 4, 9]], dtype=torch.int32),)
 
@@ -164,7 +164,7 @@ def test_sqrt_integer_input():
                 R.output(gv)
             return gv
 
-    verify_model(SqrtIntModel(), example_args_int32, {}, expected_int32, 
run_ep_decomposition=True)
+    verify_model(SqrtIntModel(), example_args_int32, {}, expected_int32)
 
 
 def test_extended_unary_ops():
@@ -203,8 +203,8 @@ def test_extended_unary_ops():
                 R.output(gv)
             return gv
 
-    verify_model(Celu1(), example_args, {}, expected_celu, 
run_ep_decomposition=True)
-    verify_model(Celu2(), example_args, {}, expected_celu, 
run_ep_decomposition=True)
+    verify_model(Celu1(), example_args, {}, expected_celu)
+    verify_model(Celu2(), example_args, {}, expected_celu)
 
     # clamp
     class Clamp(Module):
@@ -227,7 +227,7 @@ def test_extended_unary_ops():
                 R.output(gv)
             return gv
 
-    verify_model(Clamp(), example_args, {}, expected_clamp, 
run_ep_decomposition=True)
+    verify_model(Clamp(), example_args, {}, expected_clamp)
 
     class ClampMinOnly(Module):
         def forward(self, input):
@@ -247,9 +247,7 @@ def test_extended_unary_ops():
                 R.output(gv)
             return gv
 
-    verify_model(
-        ClampMinOnly(), example_args, {}, expected_clamp_min_only, 
run_ep_decomposition=True
-    )
+    verify_model(ClampMinOnly(), example_args, {}, expected_clamp_min_only)
 
     class ClampTensors(Module):
         def forward(self, input):
@@ -277,9 +275,7 @@ def test_extended_unary_ops():
                 R.output(gv)
             return gv
 
-    verify_model(
-        ClampTensors(), example_args, {}, expected_clamp_tensors, 
run_ep_decomposition=True
-    )
+    verify_model(ClampTensors(), example_args, {}, expected_clamp_tensors)
 
     # dropout
 
@@ -335,9 +331,9 @@ def test_extended_unary_ops():
                 R.output(gv)
             return gv
 
-    verify_model(Dropout1(), example_args, {}, expected_dropout_for_1_2, 
run_ep_decomposition=True)
-    verify_model(Dropout2(), example_args, {}, expected_dropout_for_1_2, 
run_ep_decomposition=True)
-    verify_model(Dropout3(), example_args, {}, expected_dropout_for_3, 
run_ep_decomposition=True)
+    verify_model(Dropout1(), example_args, {}, expected_dropout_for_1_2)
+    verify_model(Dropout2(), example_args, {}, expected_dropout_for_1_2)
+    verify_model(Dropout3(), example_args, {}, expected_dropout_for_3)
 
     # elu
     class Elu(Module):
@@ -380,8 +376,8 @@ def test_extended_unary_ops():
                 R.output(gv)
             return gv
 
-    verify_model(Elu(), example_args, {}, expected_elu, 
run_ep_decomposition=True)
-    verify_model(Elu2(), example_args, {}, expected_elu, 
run_ep_decomposition=True)
+    verify_model(Elu(), example_args, {}, expected_elu)
+    verify_model(Elu2(), example_args, {}, expected_elu)
 
     # hardsigmoid
     class Hardsigmoid(torch.nn.Module):
@@ -419,8 +415,8 @@ def test_extended_unary_ops():
                 R.output(gv)
             return gv
 
-    verify_model(Hardsigmoid(), example_args, {}, expected_hardsigmoid, 
run_ep_decomposition=True)
-    verify_model(Hardsigmoid2(), example_args, {}, expected_hardsigmoid, 
run_ep_decomposition=True)
+    verify_model(Hardsigmoid(), example_args, {}, expected_hardsigmoid)
+    verify_model(Hardsigmoid2(), example_args, {}, expected_hardsigmoid)
 
     # hardwish
     class Hardswish(torch.nn.Module):
@@ -492,15 +488,9 @@ def test_extended_unary_ops():
                 R.output(gv)
             return gv
 
-    verify_model(
-        Hardswish(), example_args, {}, expected_hardswish_for_1_2, 
run_ep_decomposition=True
-    )
-    verify_model(
-        Hardswish2(), example_args, {}, expected_hardswish_for_1_2, 
run_ep_decomposition=True
-    )
-    verify_model(
-        Hardswish3(), example_args, {}, expected_hardswish_for_3, 
run_ep_decomposition=True
-    )
+    verify_model(Hardswish(), example_args, {}, expected_hardswish_for_1_2)
+    verify_model(Hardswish2(), example_args, {}, expected_hardswish_for_1_2)
+    verify_model(Hardswish3(), example_args, {}, expected_hardswish_for_3)
 
     # isfinite
     class IsFinite(Module):
@@ -524,7 +514,7 @@ def test_extended_unary_ops():
                 R.output(gv)
             return gv
 
-    verify_model(IsFinite(), example_args, {}, expected_isfinite, 
run_ep_decomposition=True)
+    verify_model(IsFinite(), example_args, {}, expected_isfinite)
 
     # log2
     class Log2(Module):
@@ -546,7 +536,7 @@ def test_extended_unary_ops():
                 R.output(gv)
             return gv
 
-    verify_model(Log2(), example_args, {}, Expected_log2, 
run_ep_decomposition=True)
+    verify_model(Log2(), example_args, {}, Expected_log2)
 
     # log10
     class Log10(Module):
@@ -568,7 +558,7 @@ def test_extended_unary_ops():
                 R.output(gv)
             return gv
 
-    verify_model(Log10(), example_args, {}, Expected_log10, 
run_ep_decomposition=True)
+    verify_model(Log10(), example_args, {}, Expected_log10)
 
     # log1p
     class Log1p(Module):
@@ -589,7 +579,7 @@ def test_extended_unary_ops():
                 R.output(gv)
             return gv
 
-    verify_model(Log1p(), example_args, {}, Expected_log1p, 
run_ep_decomposition=True)
+    verify_model(Log1p(), example_args, {}, Expected_log1p)
 
     # reciprocal
     class Reciprocal(Module):
@@ -610,7 +600,7 @@ def test_extended_unary_ops():
                 R.output(gv)
             return gv
 
-    verify_model(Reciprocal(), example_args, {}, expected_reciprocal, 
run_ep_decomposition=True)
+    verify_model(Reciprocal(), example_args, {}, expected_reciprocal)
 
     # Returns the maximum value of all elements in the input tensor.
     class MaxModel(Module):
@@ -629,7 +619,7 @@ def test_extended_unary_ops():
                 R.output(gv)
             return gv
 
-    verify_model(MaxModel(), example_args, {}, expected_max, 
run_ep_decomposition=True)
+    verify_model(MaxModel(), example_args, {}, expected_max)
 
     # Returns the minimum value of all elements in the input tensor.
     class MinModel(Module):
@@ -648,7 +638,7 @@ def test_extended_unary_ops():
                 R.output(gv)
             return gv
 
-    verify_model(MinModel(), example_args, {}, expected_min, 
run_ep_decomposition=True)
+    verify_model(MinModel(), example_args, {}, expected_min)
 
     # relu6
     class ReLU6_1(torch.nn.Module):
@@ -712,9 +702,9 @@ def test_extended_unary_ops():
                 R.output(gv)
             return gv
 
-    verify_model(ReLU6_1(), example_args, {}, expected_relu6_1, 
run_ep_decomposition=True)
-    verify_model(ReLU6_2(), example_args, {}, expected_relu6_2, 
run_ep_decomposition=True)
-    verify_model(ReLU6_3(), example_args, {}, expected_relu6_3, 
run_ep_decomposition=True)
+    verify_model(ReLU6_1(), example_args, {}, expected_relu6_1)
+    verify_model(ReLU6_2(), example_args, {}, expected_relu6_2)
+    verify_model(ReLU6_3(), example_args, {}, expected_relu6_3)
 
     # selu
     class SELU(Module):
@@ -749,7 +739,7 @@ def test_extended_unary_ops():
                 R.output(gv)
             return gv
 
-    verify_model(SELU(), example_args, {}, expected_selu, 
run_ep_decomposition=True)
+    verify_model(SELU(), example_args, {}, expected_selu)
 
     # silu
     class SiLU(Module):
@@ -769,7 +759,7 @@ def test_extended_unary_ops():
                 R.output(gv)
             return gv
 
-    verify_model(SiLU(), example_args, {}, expected_silu, 
run_ep_decomposition=True)
+    verify_model(SiLU(), example_args, {}, expected_silu)
 
     # silu_
     class SiLU_(Module):
@@ -797,7 +787,7 @@ def test_extended_unary_ops():
                 R.output(gv)
             return gv
 
-    verify_model(SiLU_(), example_args, {}, expected_silu_, 
run_ep_decomposition=True)
+    verify_model(SiLU_(), example_args, {}, expected_silu_)
 
     # square
     class Square(Module):
@@ -818,7 +808,7 @@ def test_extended_unary_ops():
                 R.output(gv)
             return gv
 
-    verify_model(Square(), example_args, {}, expected_square, 
run_ep_decomposition=True)
+    verify_model(Square(), example_args, {}, expected_square)
 
     # relu_
     class ReLU_(Module):
@@ -837,7 +827,7 @@ def test_extended_unary_ops():
                 R.output(gv)
             return gv
 
-    verify_model(ReLU_(), example_args, {}, expected_relu_, 
run_ep_decomposition=True)
+    verify_model(ReLU_(), example_args, {}, expected_relu_)
 
 
 def test_hardtanh():
@@ -891,9 +881,9 @@ def test_hardtanh():
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(Hardtanh(), example_args, {}, expected_for_1_2, 
run_ep_decomposition=True)
-    verify_model(Hardtanh2(), example_args, {}, expected_for_1_2, 
run_ep_decomposition=True)
-    verify_model(Hardtanh3(), example_args, {}, expected_hardtanh_for_3, 
run_ep_decomposition=True)
+    verify_model(Hardtanh(), example_args, {}, expected_for_1_2)
+    verify_model(Hardtanh2(), example_args, {}, expected_for_1_2)
+    verify_model(Hardtanh3(), example_args, {}, expected_hardtanh_for_3)
 
 
 def test_softplus():
@@ -939,8 +929,8 @@ def test_softplus():
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(Softplus0(), example_args, {}, expected, 
run_ep_decomposition=True)
-    verify_model(Softplus1(), example_args, {}, expected, 
run_ep_decomposition=True)
+    verify_model(Softplus0(), example_args, {}, expected)
+    verify_model(Softplus1(), example_args, {}, expected)
 
 
 def test_leakyrelu():
@@ -997,9 +987,9 @@ def test_leakyrelu():
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(LeakyReLU0(), example_args, {}, expected_for_1_2, 
run_ep_decomposition=True)
-    verify_model(LeakyReLU1(), example_args, {}, expected_for_1_2, 
run_ep_decomposition=True)
-    verify_model(LeakyReLU2(), example_args, {}, expected_for_3, 
run_ep_decomposition=True)
+    verify_model(LeakyReLU0(), example_args, {}, expected_for_1_2)
+    verify_model(LeakyReLU1(), example_args, {}, expected_for_1_2)
+    verify_model(LeakyReLU2(), example_args, {}, expected_for_3)
 
 
 def test_logaddexp():
@@ -1044,7 +1034,7 @@ def test_logaddexp():
         torch.randn(1, 3, 10, 10, dtype=torch.float32),
         torch.randn(1, 3, 10, 10, dtype=torch.float32),
     )
-    verify_model(LogAddExp(), example_args, {}, expected, 
run_ep_decomposition=True)
+    verify_model(LogAddExp(), example_args, {}, expected)
 
 
 def test_logsoftmax():
@@ -1074,8 +1064,8 @@ def test_logsoftmax():
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(LogSoftmax(), example_args, {}, expected1, 
run_ep_decomposition=True)
-    verify_model(LogSoftmax2(), example_args, {}, expected1, 
run_ep_decomposition=True)
+    verify_model(LogSoftmax(), example_args, {}, expected1)
+    verify_model(LogSoftmax2(), example_args, {}, expected1)
 
 
 def test_prelu():
@@ -1113,8 +1103,8 @@ def test_prelu():
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(Prelu1(), example_args, {}, expected, 
run_ep_decomposition=True)
-    verify_model(Prelu2(), example_args, {}, expected, 
run_ep_decomposition=True)
+    verify_model(Prelu1(), example_args, {}, expected)
+    verify_model(Prelu2(), example_args, {}, expected)
 
 
 def test_softmax():
@@ -1144,8 +1134,8 @@ def test_softmax():
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(Softmax(), example_args, {}, expected1, 
run_ep_decomposition=True)
-    verify_model(Softmax2(), example_args, {}, expected1, 
run_ep_decomposition=True)
+    verify_model(Softmax(), example_args, {}, expected1)
+    verify_model(Softmax2(), example_args, {}, expected1)
 
 
 def test_softsign():
@@ -1176,8 +1166,8 @@ def test_softsign():
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(Softsign(), example_args, {}, expected_softsign, 
run_ep_decomposition=True)
-    verify_model(Softsign2(), example_args, {}, expected_softsign, 
run_ep_decomposition=True)
+    verify_model(Softsign(), example_args, {}, expected_softsign)
+    verify_model(Softsign2(), example_args, {}, expected_softsign)
 
 
 def test_softshrink():
@@ -1216,8 +1206,8 @@ def test_softshrink():
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(Softshrink(), example_args, {}, expected_softshrink, 
run_ep_decomposition=True)
-    verify_model(Softshrink2(), example_args, {}, expected_softshrink, 
run_ep_decomposition=True)
+    verify_model(Softshrink(), example_args, {}, expected_softshrink)
+    verify_model(Softshrink2(), example_args, {}, expected_softshrink)
 
 
 def test_tril_triu():
@@ -1251,7 +1241,7 @@ def test_tril_triu():
                 R.output(gv)
             return gv
 
-    verify_model(Tril(), example_args, {}, expected_tril, 
run_ep_decomposition=True)
+    verify_model(Tril(), example_args, {}, expected_tril)
 
     class Triu(Module):
         def forward(self, input):
@@ -1281,7 +1271,7 @@ def test_tril_triu():
                 R.output(gv)
             return gv
 
-    verify_model(Triu(), example_args, {}, expected_triu, 
run_ep_decomposition=True)
+    verify_model(Triu(), example_args, {}, expected_triu)
 
 
 operator_binary_1 = [
@@ -1389,8 +1379,8 @@ def test_binary1(op, relax_op):
 
     expected1 = expected_binary1_inplace if op in inplace_ops else 
expected_binary1
     expected2 = expected_binary2_inplace if op in inplace_ops else 
expected_binary2
-    verify_model(Binary1(op), example_args1, {}, expected1, 
run_ep_decomposition=True)
-    verify_model(Binary2(op), example_args2, {}, expected2, 
run_ep_decomposition=True)
+    verify_model(Binary1(op), example_args1, {}, expected1)
+    verify_model(Binary2(op), example_args2, {}, expected2)
 
 
 operator_binary_2 = [
@@ -1452,8 +1442,8 @@ def test_binary2(op, relax_op):
                 R.output(gv)
             return gv
 
-    verify_model(Binary1(op), example_args1, {}, expected_binary1, 
run_ep_decomposition=True)
-    verify_model(Binary2(op), example_args2, {}, expected_binary2, 
run_ep_decomposition=True)
+    verify_model(Binary1(op), example_args1, {}, expected_binary1)
+    verify_model(Binary2(op), example_args2, {}, expected_binary2)
 
 
 def test_binary3():
@@ -1481,7 +1471,7 @@ def test_binary3():
                 R.output(gv)
             return gv
 
-    verify_model(Max1(), example_args1, {}, expected_max1, 
run_ep_decomposition=True)
+    verify_model(Max1(), example_args1, {}, expected_max1)
 
     # Min
     class Min1(Module):
@@ -1501,7 +1491,7 @@ def test_binary3():
                 R.output(gv)
             return gv
 
-    verify_model(Min1(), example_args1, {}, expected_min1, 
run_ep_decomposition=True)
+    verify_model(Min1(), example_args1, {}, expected_min1)
 
     # RSub
     class RSub1(Module):
@@ -1536,8 +1526,8 @@ def test_binary3():
                 R.output(gv)
             return gv
 
-    verify_model(RSub1(), example_args1, {}, expected_rsub1, 
run_ep_decomposition=True)
-    verify_model(RSub2(), example_args2, {}, expected_rsub2, 
run_ep_decomposition=True)
+    verify_model(RSub1(), example_args1, {}, expected_rsub1)
+    verify_model(RSub2(), example_args2, {}, expected_rsub2)
 
 
 # IsIn
@@ -1566,7 +1556,7 @@ def test_isin():
         torch.randn(10, 10, dtype=torch.float32),
         torch.randn(8, dtype=torch.float32),
     )
-    verify_model(IsInModel(), example_args, {}, expected, 
run_ep_decomposition=True)
+    verify_model(IsInModel(), example_args, {}, expected)
 
 
 def test_div_mode():
@@ -1591,7 +1581,7 @@ def test_div_mode():
         torch.randn(64, 64, dtype=torch.float32),
         torch.randn(64, dtype=torch.float32),
     )
-    verify_model(DivModel(), example_args, {}, expected_div, 
run_ep_decomposition=True)
+    verify_model(DivModel(), example_args, {}, expected_div)
 
     # Case 2: Division with trunc rounding
     class DivTruncModel(torch.nn.Module):
@@ -1611,7 +1601,7 @@ def test_div_mode():
                 R.output(gv)
             return gv
 
-    verify_model(DivTruncModel(), example_args, {}, expected_div_trunc, 
run_ep_decomposition=True)
+    verify_model(DivTruncModel(), example_args, {}, expected_div_trunc)
 
     # Case 3: Division with floor rounding
     class DivFloorModel(torch.nn.Module):
@@ -1630,7 +1620,7 @@ def test_div_mode():
                 R.output(gv)
             return gv
 
-    verify_model(DivFloorModel(), example_args, {}, expected_div_floor, 
run_ep_decomposition=True)
+    verify_model(DivFloorModel(), example_args, {}, expected_div_floor)
 
 
 def test_batchnorm2d():
@@ -1685,7 +1675,7 @@ def test_batchnorm2d():
         "w3": model.bn.running_mean.detach().numpy(),
         "w4": model.bn.running_var.detach().numpy(),
     }
-    verify_model(model, example_args, binding, expected1, 
run_ep_decomposition=True)
+    verify_model(model, example_args, binding, expected1)
 
 
 def test_adaptive_avgpool1d():
@@ -1718,8 +1708,8 @@ def test_adaptive_avgpool1d():
             return gv
 
     example_args = (torch.randn(1, 3, 10, dtype=torch.float32),)
-    verify_model(AdaptiveAvgPool1d0(), example_args, {}, expected1, 
run_ep_decomposition=True)
-    verify_model(AdaptiveAvgPool1d1(), example_args, {}, expected1, 
run_ep_decomposition=True)
+    verify_model(AdaptiveAvgPool1d0(), example_args, {}, expected1)
+    verify_model(AdaptiveAvgPool1d1(), example_args, {}, expected1)
 
 
 def test_adaptive_avgpool2d():
@@ -1751,8 +1741,8 @@ def test_adaptive_avgpool2d():
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(AdaptiveAvgPool2d0(), example_args, {}, expected1, 
run_ep_decomposition=True)
-    verify_model(AdaptiveAvgPool2d1(), example_args, {}, expected1, 
run_ep_decomposition=True)
+    verify_model(AdaptiveAvgPool2d0(), example_args, {}, expected1)
+    verify_model(AdaptiveAvgPool2d1(), example_args, {}, expected1)
 
 
 def test_adaptive_avgpool3d():
@@ -1783,8 +1773,8 @@ def test_adaptive_avgpool3d():
             return gv
 
     example_args = (torch.randn(1, 3, 8, 8, 8, dtype=torch.float32),)
-    verify_model(AdaptiveAvgPool3d0(), example_args, {}, expected1, 
run_ep_decomposition=True)
-    verify_model(AdaptiveAvgPool3d1(), example_args, {}, expected1, 
run_ep_decomposition=True)
+    verify_model(AdaptiveAvgPool3d0(), example_args, {}, expected1)
+    verify_model(AdaptiveAvgPool3d1(), example_args, {}, expected1)
 
 
 def test_addmm():
@@ -1842,8 +1832,8 @@ def test_addmm():
         torch.randn(10, 10, dtype=torch.float32),
     )
 
-    verify_model(Addmm1(), example_args, {}, expected1, 
run_ep_decomposition=True)
-    verify_model(Addmm2(), example_args, {}, expected2, 
run_ep_decomposition=True)
+    verify_model(Addmm1(), example_args, {}, expected1)
+    verify_model(Addmm2(), example_args, {}, expected2)
 
 
 def test_avg_pool1d():
@@ -1946,10 +1936,10 @@ def test_avg_pool1d():
             return gv
 
     example_args = (torch.randn(1, 3, 10, dtype=torch.float32),)
-    verify_model(AvgPool1d1(), example_args, {}, expected1, 
run_ep_decomposition=True)
-    verify_model(AvgPool1d2(), example_args, {}, expected2, 
run_ep_decomposition=True)
-    verify_model(AvgPool1d3(), example_args, {}, expected2, 
run_ep_decomposition=True)
-    verify_model(AvgPool1d4(), example_args, {}, expected3, 
run_ep_decomposition=True)
+    verify_model(AvgPool1d1(), example_args, {}, expected1)
+    verify_model(AvgPool1d2(), example_args, {}, expected2)
+    verify_model(AvgPool1d3(), example_args, {}, expected2)
+    verify_model(AvgPool1d4(), example_args, {}, expected3)
 
 
 def test_avg_pool2d():
@@ -2039,10 +2029,10 @@ def test_avg_pool2d():
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(AvgPool2d1(), example_args, {}, expected1, 
run_ep_decomposition=True)
-    verify_model(AvgPool2d2(), example_args, {}, expected2, 
run_ep_decomposition=True)
-    verify_model(AvgPool2d3(), example_args, {}, expected2, 
run_ep_decomposition=True)
-    verify_model(AvgPool2d4(), example_args, {}, expected3, 
run_ep_decomposition=True)
+    verify_model(AvgPool2d1(), example_args, {}, expected1)
+    verify_model(AvgPool2d2(), example_args, {}, expected2)
+    verify_model(AvgPool2d3(), example_args, {}, expected2)
+    verify_model(AvgPool2d4(), example_args, {}, expected3)
 
 
 def test_avg_pool3d():
@@ -2135,10 +2125,10 @@ def test_avg_pool3d():
             return gv
 
     example_args = (torch.randn(1, 3, 8, 8, 8, dtype=torch.float32),)
-    verify_model(AvgPool3d1(), example_args, {}, expected1, 
run_ep_decomposition=True)
-    verify_model(AvgPool3d2(), example_args, {}, expected2, 
run_ep_decomposition=True)
-    verify_model(AvgPool3d3(), example_args, {}, expected2, 
run_ep_decomposition=True)
-    verify_model(AvgPool3d4(), example_args, {}, expected3, 
run_ep_decomposition=True)
+    verify_model(AvgPool3d1(), example_args, {}, expected1)
+    verify_model(AvgPool3d2(), example_args, {}, expected2)
+    verify_model(AvgPool3d3(), example_args, {}, expected2)
+    verify_model(AvgPool3d4(), example_args, {}, expected3)
 
 
 def test_baddbmm():
@@ -2372,15 +2362,15 @@ def test_conv_transpose1d():
 
     model = ConvTranspose1d1()
     binding = {"w1": model.conv.weight.detach().numpy(), "w2": 
model.conv.bias.detach().numpy()}
-    verify_model(model, example_args, binding, expected1, 
run_ep_decomposition=True)
+    verify_model(model, example_args, binding, expected1)
 
     model = ConvTranspose1d1Func()
     binding = {"w1": model.weight.detach().numpy(), "w2": 
model.bias.detach().numpy()}
-    verify_model(model, example_args, binding, expected1, 
run_ep_decomposition=True)
+    verify_model(model, example_args, binding, expected1)
 
     model = ConvTranspose1d2()
     binding = {"w1": model.conv.weight.detach().numpy()}
-    verify_model(model, example_args, binding, expected2, 
run_ep_decomposition=True)
+    verify_model(model, example_args, binding, expected2)
 
 
 def test_conv_transpose2d():
@@ -2466,15 +2456,15 @@ def test_conv_transpose2d():
 
     model = ConvTranspose2d1()
     binding = {"w1": model.conv.weight.detach().numpy(), "w2": 
model.conv.bias.detach().numpy()}
-    verify_model(model, example_args, binding, expected1, 
run_ep_decomposition=True)
+    verify_model(model, example_args, binding, expected1)
 
     model = ConvTranspose2d1Func()
     binding = {"w1": model.weight.detach().numpy(), "w2": 
model.bias.detach().numpy()}
-    verify_model(model, example_args, binding, expected1, 
run_ep_decomposition=True)
+    verify_model(model, example_args, binding, expected1)
 
     model = ConvTranspose2d2()
     binding = {"w1": model.conv.weight.detach().numpy()}
-    verify_model(model, example_args, binding, expected2, 
run_ep_decomposition=True)
+    verify_model(model, example_args, binding, expected2)
 
 
 def test_conv1d():
@@ -2558,15 +2548,15 @@ def test_conv1d():
 
     model = Conv1D1()
     binding = {"w1": model.conv.weight.detach().numpy(), "w2": 
model.conv.bias.detach().numpy()}
-    verify_model(model, example_args, binding, expected1, 
run_ep_decomposition=True)
+    verify_model(model, example_args, binding, expected1)
 
     model = Conv1D1Func()
     binding = {"w1": model.weight.detach().numpy(), "w2": 
model.bias.detach().numpy()}
-    verify_model(model, example_args, binding, expected1, 
run_ep_decomposition=True)
+    verify_model(model, example_args, binding, expected1)
 
     model = Conv1D2()
     binding = {"w1": model.conv.weight.detach().numpy()}
-    verify_model(model, example_args, binding, expected2, 
run_ep_decomposition=True)
+    verify_model(model, example_args, binding, expected2)
 
 
 def test_conv2d():
@@ -2650,15 +2640,15 @@ def test_conv2d():
 
     model = Conv2D1()
     binding = {"w1": model.conv.weight.detach().numpy(), "w2": 
model.conv.bias.detach().numpy()}
-    verify_model(model, example_args, binding, expected1, 
run_ep_decomposition=True)
+    verify_model(model, example_args, binding, expected1)
 
     model = Conv2D1Func()
     binding = {"w1": model.weight.numpy(), "w2": model.bias.numpy()}
-    verify_model(model, example_args, binding, expected1, 
run_ep_decomposition=True)
+    verify_model(model, example_args, binding, expected1)
 
     model = Conv2D2()
     binding = {"w1": model.conv.weight.detach().numpy()}
-    verify_model(model, example_args, binding, expected2, 
run_ep_decomposition=True)
+    verify_model(model, example_args, binding, expected2)
 
 
 def test_conv3d():
@@ -2742,15 +2732,15 @@ def test_conv3d():
 
     model = Conv3D1()
     binding = {"w1": model.conv.weight.detach().numpy(), "w2": 
model.conv.bias.detach().numpy()}
-    verify_model(model, example_args, binding, expected1, 
run_ep_decomposition=True)
+    verify_model(model, example_args, binding, expected1)
 
     model = Conv3D1Func()
     binding = {"w1": model.weight.detach().numpy(), "w2": 
model.bias.detach().numpy()}
-    verify_model(model, example_args, binding, expected1, 
run_ep_decomposition=True)
+    verify_model(model, example_args, binding, expected1)
 
     model = Conv3D2()
     binding = {"w1": model.conv.weight.detach().numpy()}
-    verify_model(model, example_args, binding, expected2, 
run_ep_decomposition=True)
+    verify_model(model, example_args, binding, expected2)
 
 
 def test_pad():
@@ -2973,9 +2963,7 @@ def test_pad():
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(
-        PadModel(pad=[1, 1, 2, 2]), example_args, {}, expected_constant, 
run_ep_decomposition=True
-    )
+    verify_model(PadModel(pad=[1, 1, 2, 2]), example_args, {}, 
expected_constant)
     verify_model(
         PadModel(pad=[1, 1, 2, 2], mode="reflect"),
         example_args,
@@ -3037,12 +3025,8 @@ def test_pixel_shuffle():
             return gv
 
     example_args = (torch.randn(1, 8, 10, 15, dtype=torch.float32),)
-    verify_model(
-        PixelShuffle1(upscale_factor=2), example_args, {}, expected, 
run_ep_decomposition=True
-    )
-    verify_model(
-        PixelShuffle2(upscale_factor=2), example_args, {}, expected, 
run_ep_decomposition=True
-    )
+    verify_model(PixelShuffle1(upscale_factor=2), example_args, {}, expected)
+    verify_model(PixelShuffle2(upscale_factor=2), example_args, {}, expected)
 
 
 def test_einsum():
@@ -3115,7 +3099,7 @@ def test_outer():
         torch.randn(3, dtype=torch.float32),
         torch.randn(4, dtype=torch.float32),
     )
-    verify_model(Outer(), example_args, {}, expected, 
run_ep_decomposition=True)
+    verify_model(Outer(), example_args, {}, expected)
 
 
 def test_embedding():
@@ -3145,7 +3129,7 @@ def test_embedding():
 
     model = Embedding()
     binding = {"w1": model.embedding.weight.detach().numpy()}
-    verify_model(model, example_args, binding, expected1, 
run_ep_decomposition=True)
+    verify_model(model, example_args, binding, expected1)
 
 
 def test_groupnorm():
@@ -3194,7 +3178,7 @@ def test_groupnorm():
         "w1": model.gn.weight.detach().numpy(),
         "w2": model.gn.bias.detach().numpy(),
     }
-    verify_model(model, example_args, binding, expected1, 
run_ep_decomposition=True)
+    verify_model(model, example_args, binding, expected1)
 
 
 def test_instancenorm2d():
@@ -3239,7 +3223,7 @@ def test_instancenorm2d():
         "w1": torch.ones(3).detach().numpy(),
         "w2": torch.zeros(3).detach().numpy(),
     }
-    verify_model(model, example_args, binding, expected1, 
run_ep_decomposition=True)
+    verify_model(model, example_args, binding, expected1)
 
 
 def test_layernorm():
@@ -3354,15 +3338,15 @@ def test_linear():
 
     model = Dense1()
     binding = {"w1": model.linear.weight.detach().numpy(), "w2": 
model.linear.bias.detach().numpy()}
-    verify_model(model, example_args, binding, expected1, 
run_ep_decomposition=True)
+    verify_model(model, example_args, binding, expected1)
 
     model = Dense1Func()
     binding = {"w1": model.weight.detach().numpy(), "w2": 
model.bias.detach().numpy()}
-    verify_model(model, example_args, binding, expected1, 
run_ep_decomposition=True)
+    verify_model(model, example_args, binding, expected1)
 
     model = Dense2()
     binding = {"w1": model.linear.weight.detach().numpy()}
-    verify_model(model, example_args, binding, expected2, 
run_ep_decomposition=True)
+    verify_model(model, example_args, binding, expected2)
 
 
 def test_maxpool1d():
@@ -3479,9 +3463,9 @@ def test_maxpool1d():
     example_args3 = (torch.randn(1, 3, 10, dtype=torch.float32),)
 
     # Verify the models
-    verify_model(MaxPool1d(), example_args1, {}, expected1, 
run_ep_decomposition=True)
-    verify_model(MaxPool1d_functional(), example_args2, {}, expected2, 
run_ep_decomposition=True)
-    verify_model(MaxPool1d2(), example_args3, {}, expected3, 
run_ep_decomposition=True)
+    verify_model(MaxPool1d(), example_args1, {}, expected1)
+    verify_model(MaxPool1d_functional(), example_args2, {}, expected2)
+    verify_model(MaxPool1d2(), example_args3, {}, expected3)
 
 
 def test_maxpool2d():
@@ -3596,10 +3580,10 @@ def test_maxpool2d():
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(MaxPool2d(), example_args, {}, expected1, 
run_ep_decomposition=True)
-    verify_model(MaxPool2d_functional(), example_args, {}, expected1, 
run_ep_decomposition=True)
-    verify_model(MaxPool2d2(), example_args, {}, expected2, 
run_ep_decomposition=True)
-    verify_model(MaxPool2d3(), example_args, {}, expected3, 
run_ep_decomposition=True)
+    verify_model(MaxPool2d(), example_args, {}, expected1)
+    verify_model(MaxPool2d_functional(), example_args, {}, expected1)
+    verify_model(MaxPool2d2(), example_args, {}, expected2)
+    verify_model(MaxPool2d3(), example_args, {}, expected3)
 
 
 def test_maxpool3d():
@@ -3718,10 +3702,10 @@ def test_maxpool3d():
     example_args3 = (torch.randn(1, 3, 10, 10, 10, dtype=torch.float32),)
 
     # Verify the models with expected IR modules
-    verify_model(MaxPool3d(), example_args1, {}, expected1, 
run_ep_decomposition=True)
-    verify_model(MaxPool3d_functional(), example_args1, {}, expected1, 
run_ep_decomposition=True)
-    verify_model(MaxPool3d2(), example_args2, {}, expected2, 
run_ep_decomposition=True)
-    verify_model(MaxPool3d3(), example_args3, {}, expected3, 
run_ep_decomposition=True)
+    verify_model(MaxPool3d(), example_args1, {}, expected1)
+    verify_model(MaxPool3d_functional(), example_args1, {}, expected1)
+    verify_model(MaxPool3d2(), example_args2, {}, expected2)
+    verify_model(MaxPool3d3(), example_args3, {}, expected3)
 
 
 def test_scaled_dot_product_attention():
@@ -4039,10 +4023,10 @@ def test_unbind():
             return gv
 
     example_args = (torch.randn(3, 3, 10, 10, dtype=torch.float32),)
-    verify_model(Unbind1(), example_args, {}, expected1, 
run_ep_decomposition=True)
-    verify_model(Unbind2(), example_args, {}, expected2, 
run_ep_decomposition=True)
+    verify_model(Unbind1(), example_args, {}, expected1)
+    verify_model(Unbind2(), example_args, {}, expected2)
     single_dim_args = (torch.randn(3, 1, 3, dtype=torch.float32),)
-    verify_model(Unbind2(), single_dim_args, {}, expected3, 
run_ep_decomposition=True)
+    verify_model(Unbind2(), single_dim_args, {}, expected3)
 
 
 def test_interpolate():
@@ -4536,15 +4520,9 @@ def test_interpolate():
             return gv
 
     example_args = (torch.randn(1, 3, 112, 112, dtype=torch.float32),)
-    verify_model(
-        InterpolateBilinear(), example_args, {}, expected_bilinear, 
run_ep_decomposition=True
-    )
-    verify_model(
-        InterpolateNearest(), example_args, {}, expected_nearest, 
run_ep_decomposition=True
-    )
-    verify_model(
-        InterpolateBicubic(), example_args, {}, expected_bicubic, 
run_ep_decomposition=True
-    )
+    verify_model(InterpolateBilinear(), example_args, {}, expected_bilinear)
+    verify_model(InterpolateNearest(), example_args, {}, expected_nearest)
+    verify_model(InterpolateBicubic(), example_args, {}, expected_bicubic)
 
 
 def test_mean():
@@ -4581,8 +4559,8 @@ def test_mean():
             return gv
 
     example_args = (torch.randn(256, 256, dtype=torch.float32),)
-    verify_model(Mean(), example_args, {}, Expected1, 
run_ep_decomposition=True)
-    verify_model(MeanKeepDim(), example_args, {}, Expected2, 
run_ep_decomposition=True)
+    verify_model(Mean(), example_args, {}, Expected1)
+    verify_model(MeanKeepDim(), example_args, {}, Expected2)
 
 
 def test_sum():
@@ -4604,7 +4582,7 @@ def test_sum():
             return gv
 
     example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
-    verify_model(Sum(), example_args, {}, expected1, run_ep_decomposition=True)
+    verify_model(Sum(), example_args, {}, expected1)
 
 
 def test_argmax_argmin():
@@ -4648,8 +4626,8 @@ def test_argmax_argmin():
                 R.output(gv)
             return gv
 
-    verify_model(Argmax1(), example_args, {}, expected_argmax1, 
run_ep_decomposition=True)
-    verify_model(Argmax2(), example_args, {}, expected_argmax2, 
run_ep_decomposition=True)
+    verify_model(Argmax1(), example_args, {}, expected_argmax1)
+    verify_model(Argmax2(), example_args, {}, expected_argmax2)
 
     class Argmin1(Module):
         def __init__(self) -> None:
@@ -4689,8 +4667,8 @@ def test_argmax_argmin():
                 R.output(gv)
             return gv
 
-    verify_model(Argmin1(), example_args, {}, expected_argmin1, 
run_ep_decomposition=True)
-    verify_model(Argmin2(), example_args, {}, expected_argmin2, 
run_ep_decomposition=True)
+    verify_model(Argmin1(), example_args, {}, expected_argmin1)
+    verify_model(Argmin2(), example_args, {}, expected_argmin2)
 
 
 def test_cat_concat():
@@ -4737,10 +4715,10 @@ def test_cat_concat():
             return gv
 
     example_args = (torch.randn(2, 3, dtype=torch.float32), torch.randn(2, 3, 
dtype=torch.float32))
-    verify_model(Cat0(), example_args, {}, Expected1, 
run_ep_decomposition=True)
-    verify_model(Cat1(), example_args, {}, Expected2, 
run_ep_decomposition=True)
-    verify_model(Cat2(), example_args, {}, Expected2, 
run_ep_decomposition=True)
-    verify_model(Cat3(), example_args, {}, Expected1, 
run_ep_decomposition=True)
+    verify_model(Cat0(), example_args, {}, Expected1)
+    verify_model(Cat1(), example_args, {}, Expected2)
+    verify_model(Cat2(), example_args, {}, Expected2)
+    verify_model(Cat3(), example_args, {}, Expected1)
 
 
 def test_cumsum():
@@ -4762,7 +4740,7 @@ def test_cumsum():
             return gv
 
     example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
-    verify_model(Cumsum(), example_args, {}, expected1, 
run_ep_decomposition=True)
+    verify_model(Cumsum(), example_args, {}, expected1)
 
 
 def test_expand():
@@ -4788,8 +4766,8 @@ def test_expand():
             return gv
 
     example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
-    verify_model(Expand1(), example_args, {}, expected1, 
run_ep_decomposition=True)
-    verify_model(Expand2(), example_args, {}, expected1, 
run_ep_decomposition=True)
+    verify_model(Expand1(), example_args, {}, expected1)
+    verify_model(Expand2(), example_args, {}, expected1)
 
 
 def test_flatten():
@@ -4815,7 +4793,7 @@ def test_flatten():
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(Flatten(), example_args, {}, expected1, 
run_ep_decomposition=True)
+    verify_model(Flatten(), example_args, {}, expected1)
 
 
 def test_meshgrid():
@@ -4865,8 +4843,8 @@ def test_meshgrid():
         torch.randn(3, dtype=torch.float32),
         torch.randn(3, dtype=torch.float32),
     )
-    verify_model(Meshgrid1(), example_args, {}, expected1, 
run_ep_decomposition=True)
-    verify_model(Meshgrid2(), example_args, {}, expected2, 
run_ep_decomposition=True)
+    verify_model(Meshgrid1(), example_args, {}, expected1)
+    verify_model(Meshgrid2(), example_args, {}, expected2)
 
 
 def test_permute():
@@ -4892,8 +4870,8 @@ def test_permute():
             return gv
 
     example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
-    verify_model(Permute1(), example_args, {}, expected1, 
run_ep_decomposition=True)
-    verify_model(Permute2(), example_args, {}, expected1, 
run_ep_decomposition=True)
+    verify_model(Permute1(), example_args, {}, expected1)
+    verify_model(Permute2(), example_args, {}, expected1)
 
 
 def test_repeat():
@@ -4930,13 +4908,13 @@ def test_repeat():
             return gv
 
     example_args = (torch.randn(3, dtype=torch.float32),)
-    verify_model(Tile1(), example_args, {}, expected1, 
run_ep_decomposition=True)
+    verify_model(Tile1(), example_args, {}, expected1)
 
     example_args = (torch.randn(1, 3, dtype=torch.float32),)
-    verify_model(Tile2(), example_args, {}, expected2, 
run_ep_decomposition=True)
+    verify_model(Tile2(), example_args, {}, expected2)
 
     example_args = (torch.randn(1, 3, dtype=torch.float32),)
-    verify_model(Tile2(), example_args, {}, expected2, 
run_ep_decomposition=True)
+    verify_model(Tile2(), example_args, {}, expected2)
 
 
 def test_reshape():
@@ -4958,7 +4936,7 @@ def test_reshape():
             return gv
 
     example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
-    verify_model(Reshape(), example_args, {}, expected1, 
run_ep_decomposition=True)
+    verify_model(Reshape(), example_args, {}, expected1)
 
 
 def test_reshape_as():
@@ -4984,7 +4962,7 @@ def test_reshape_as():
         torch.randn(1, 2, 3, 4, dtype=torch.float32),
         torch.randn(2, 12, dtype=torch.float32),
     )
-    verify_model(ReshapeAs(), example_args, {}, expected1, 
run_ep_decomposition=True)
+    verify_model(ReshapeAs(), example_args, {}, expected1)
 
 
 def test_roll():
@@ -5062,9 +5040,9 @@ def test_roll():
     example_input = torch.randint(0, 10, (4, 2), dtype=torch.int64)
 
     # Run verification for each case
-    verify_model(Roll1(), (example_input,), {}, Expected1, 
run_ep_decomposition=True)
-    verify_model(Roll2(), (example_input,), {}, Expected2, 
run_ep_decomposition=True)
-    verify_model(Roll3(), (example_input,), {}, Expected3, 
run_ep_decomposition=True)
+    verify_model(Roll1(), (example_input,), {}, Expected1)
+    verify_model(Roll2(), (example_input,), {}, Expected2)
+    verify_model(Roll3(), (example_input,), {}, Expected3)
 
 
 def test_select_slice():
@@ -5144,10 +5122,10 @@ def test_select_slice():
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(Slice1(), example_args, {}, expected1, 
run_ep_decomposition=True)
+    verify_model(Slice1(), example_args, {}, expected1)
 
     example_args = (torch.randn(8, 16, dtype=torch.float32),)
-    verify_model(Slice2(), example_args, {}, expected2, 
run_ep_decomposition=True)
+    verify_model(Slice2(), example_args, {}, expected2)
 
 
 def test_slice_scatter():
@@ -5189,10 +5167,10 @@ def test_slice_scatter():
             return gv
 
     example_args = (torch.randn(8, 8, 10, 10, dtype=torch.float32), 
torch.randn(8, 3, 10, 10))
-    verify_model(SliceScatter1(), example_args, {}, expected1, 
run_ep_decomposition=True)
+    verify_model(SliceScatter1(), example_args, {}, expected1)
 
     example_args = (torch.randn(8, 16, dtype=torch.float32), torch.randn(6, 
16))
-    verify_model(SliceScatter2(), example_args, {}, expected2, 
run_ep_decomposition=True)
+    verify_model(SliceScatter2(), example_args, {}, expected2)
 
 
 def test_split():
@@ -5331,11 +5309,11 @@ def test_split():
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(Chunk(), example_args, {}, Expected, 
run_ep_decomposition=True)
+    verify_model(Chunk(), example_args, {}, Expected)
 
     example_args = (torch.randn(3, 3, 10, 10, dtype=torch.float32),)
-    verify_model(Unbind1(), example_args, {}, expected1, 
run_ep_decomposition=True)
-    verify_model(Unbind2(), example_args, {}, expected2, 
run_ep_decomposition=True)
+    verify_model(Unbind1(), example_args, {}, expected1)
+    verify_model(Unbind2(), example_args, {}, expected2)
 
 
 def test_squeeze():
@@ -5373,8 +5351,8 @@ def test_squeeze():
 
     example_args = (torch.randn(3, 1, 4, 1, dtype=torch.float32),)
 
-    verify_model(Squeeze1(), example_args, {}, Expected1, 
run_ep_decomposition=True)
-    verify_model(Squeeze2(), example_args, {}, Expected2, 
run_ep_decomposition=True)
+    verify_model(Squeeze1(), example_args, {}, Expected1)
+    verify_model(Squeeze2(), example_args, {}, Expected2)
 
 
 def test_stack():
@@ -5439,10 +5417,10 @@ def test_stack():
 
     example_args = (torch.randn(2, 3, dtype=torch.float32), torch.randn(2, 3, 
dtype=torch.float32))
 
-    verify_model(Stack0(), example_args, {}, Expected0, 
run_ep_decomposition=True)
-    verify_model(Stack1(), example_args, {}, Expected1, 
run_ep_decomposition=True)
-    verify_model(Stack2(), example_args, {}, Expected1, 
run_ep_decomposition=True)
-    verify_model(Stack3(), example_args, {}, Expected3, 
run_ep_decomposition=True)
+    verify_model(Stack0(), example_args, {}, Expected0)
+    verify_model(Stack1(), example_args, {}, Expected1)
+    verify_model(Stack2(), example_args, {}, Expected1)
+    verify_model(Stack3(), example_args, {}, Expected3)
 
 
 def test_tile():
@@ -5485,9 +5463,9 @@ def test_tile():
             return gv
 
     example_args = (torch.randn(1, 3, dtype=torch.float32),)
-    verify_model(Tile1(), example_args, {}, expected1, 
run_ep_decomposition=True)
-    verify_model(Tile2(), example_args, {}, expected2, 
run_ep_decomposition=True)
-    verify_model(Tile3(), example_args, {}, expected2, 
run_ep_decomposition=True)
+    verify_model(Tile1(), example_args, {}, expected1)
+    verify_model(Tile2(), example_args, {}, expected2)
+    verify_model(Tile3(), example_args, {}, expected2)
 
 
 def test_transpose():
@@ -5509,7 +5487,7 @@ def test_transpose():
             return gv
 
     example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
-    verify_model(Transpose(), example_args, {}, expected1, 
run_ep_decomposition=True)
+    verify_model(Transpose(), example_args, {}, expected1)
 
 
 def test_unsqueeze():
@@ -5549,8 +5527,8 @@ def test_unsqueeze():
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
 
-    verify_model(Unsqueeze1(), example_args, {}, expected1, 
run_ep_decomposition=True)
-    verify_model(Unsqueeze2(), example_args, {}, expected2, 
run_ep_decomposition=True)
+    verify_model(Unsqueeze1(), example_args, {}, expected1)
+    verify_model(Unsqueeze2(), example_args, {}, expected2)
 
 
 def test_view():
@@ -5572,7 +5550,7 @@ def test_view():
             return gv
 
     example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
-    verify_model(View(), example_args, {}, expected1, 
run_ep_decomposition=True)
+    verify_model(View(), example_args, {}, expected1)
 
 
 def test_arange():
@@ -5593,7 +5571,7 @@ def test_arange():
             return gv
 
     example_args = (torch.randn(10, 10, dtype=torch.float32),)
-    verify_model(Arange(), example_args, {}, Expected, 
run_ep_decomposition=True)
+    verify_model(Arange(), example_args, {}, Expected)
 
 
 def test_hamming_window():
@@ -5620,7 +5598,7 @@ def test_hamming_window():
             return gv
 
     example_args = (torch.randn(10, 10, dtype=torch.float32),)
-    verify_model(HammingWindow(), example_args, {}, Expected, 
run_ep_decomposition=True)
+    verify_model(HammingWindow(), example_args, {}, Expected)
 
 
 def test_contiguous():
@@ -5640,7 +5618,7 @@ def test_contiguous():
             return gv
 
     example_args = (torch.randn(10, 10, dtype=torch.float32),)
-    verify_model(Contiguous(), example_args, {}, Expected, 
run_ep_decomposition=True)
+    verify_model(Contiguous(), example_args, {}, Expected)
 
 
 def test_clone():
@@ -5660,7 +5638,7 @@ def test_clone():
             return gv
 
     example_args = (torch.randn(10, 10, dtype=torch.float32),)
-    verify_model(Clone(), example_args, {}, Expected, 
run_ep_decomposition=True)
+    verify_model(Clone(), example_args, {}, Expected)
 
 
 def test_empty():
@@ -5683,7 +5661,7 @@ def test_empty():
             return gv
 
     example_args = (torch.randn(10, 10, dtype=torch.float32),)
-    verify_model(Empty(), example_args, {}, Expected, 
run_ep_decomposition=True)
+    verify_model(Empty(), example_args, {}, Expected)
 
 
 def test_empty_without_dtype():
@@ -5704,7 +5682,7 @@ def test_empty_without_dtype():
             return gv
 
     example_args = (torch.randn(10, 10, dtype=torch.float32),)
-    verify_model(EmptyWithoutDtype(), example_args, {}, Expected, 
run_ep_decomposition=True)
+    verify_model(EmptyWithoutDtype(), example_args, {}, Expected)
 
 
 def test_fill():
@@ -5727,7 +5705,7 @@ def test_fill():
             return gv
 
     example_args = (torch.randn(10, 10, dtype=torch.float32),)
-    verify_model(Fill(), example_args, {}, Expected, run_ep_decomposition=True)
+    verify_model(Fill(), example_args, {}, Expected)
 
 
 def test_fill_inplace():
@@ -5753,7 +5731,7 @@ def test_fill_inplace():
             return gv
 
     example_args = (torch.randn(2, 3, dtype=torch.float32),)
-    verify_model(FillInplace(), example_args, {}, Expected, 
run_ep_decomposition=True)
+    verify_model(FillInplace(), example_args, {}, Expected)
 
 
 def test_masked_fill():
@@ -5775,7 +5753,7 @@ def test_masked_fill():
             return gv
 
     example_args = (torch.randn(128, 128, dtype=torch.float32), 
torch.rand(128, 128) < 0.5)
-    verify_model(Masked_Fill(), example_args, {}, Expected, 
run_ep_decomposition=True)
+    verify_model(Masked_Fill(), example_args, {}, Expected)
 
 
 def test_masked_fill_inplace():
@@ -5799,7 +5777,7 @@ def test_masked_fill_inplace():
             return gv
 
     example_args = (torch.randn(128, 128, dtype=torch.float32), 
torch.rand(128, 128) < 0.5)
-    verify_model(Masked_Fill_Inplace(), example_args, {}, Expected, 
run_ep_decomposition=True)
+    verify_model(Masked_Fill_Inplace(), example_args, {}, Expected)
 
 
 def test_new_ones():
@@ -5823,7 +5801,7 @@ def test_new_ones():
             return gv
 
     example_args = (torch.randn(1, 2, 3, dtype=torch.float32),)
-    verify_model(NewOnes(), example_args, {}, expected1, 
run_ep_decomposition=True)
+    verify_model(NewOnes(), example_args, {}, expected1)
 
 
 def test_new_zeros():
@@ -5846,7 +5824,7 @@ def test_new_zeros():
             return gv
 
     example_args = (torch.randn(1, 128, 128, dtype=torch.float32),)
-    verify_model(NewZeros(), example_args, {}, expected1, 
run_ep_decomposition=True)
+    verify_model(NewZeros(), example_args, {}, expected1)
 
 
 def test_to_copy():
@@ -5937,11 +5915,11 @@ def test_to_copy():
             return gv
 
     example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
-    verify_model(ToFloat(), example_args, {}, expected_float, 
run_ep_decomposition=True)
-    verify_model(ToHalf(), example_args, {}, expected_half, 
run_ep_decomposition=True)
-    verify_model(Type(), example_args, {}, expected_type, 
run_ep_decomposition=True)
-    verify_model(To1(), example_args, {}, expected_to1, 
run_ep_decomposition=True)
-    verify_model(To2(), example_args, {}, expected_to2, 
run_ep_decomposition=True)
+    verify_model(ToFloat(), example_args, {}, expected_float)
+    verify_model(ToHalf(), example_args, {}, expected_half)
+    verify_model(Type(), example_args, {}, expected_type)
+    verify_model(To1(), example_args, {}, expected_to1)
+    verify_model(To2(), example_args, {}, expected_to2)
 
 
 def test_keep_params():
@@ -5986,9 +5964,7 @@ def test_keep_params():
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
     model = Conv2D1()
     exported_program = torch.export.export(model, example_args)
-    mod = from_exported_program(
-        exported_program, keep_params_as_input=True, run_ep_decomposition=True
-    )
+    mod = from_exported_program(exported_program, keep_params_as_input=True)
     mod, params = detach_params(mod)
     tvm.ir.assert_structural_equal(mod, expected1)
     func = mod["main"]
@@ -6024,9 +6000,7 @@ def test_unwrap_unit_return_tuple():
 
     example_args = (torch.randn(256, 256, dtype=torch.float32),)
     exported_program = export(Identity(), args=example_args)
-    mod = from_exported_program(
-        exported_program, unwrap_unit_return_tuple=True, 
run_ep_decomposition=True
-    )
+    mod = from_exported_program(exported_program, 
unwrap_unit_return_tuple=True)
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
@@ -6056,9 +6030,7 @@ def test_no_bind_return_tuple():
         torch.randn(256, 256, dtype=torch.float32),
     )
     exported_program = export(Identity(), args=example_args)
-    mod = from_exported_program(
-        exported_program, no_bind_return_tuple=True, run_ep_decomposition=True
-    )
+    mod = from_exported_program(exported_program, no_bind_return_tuple=True)
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
@@ -6081,7 +6053,7 @@ def test_empty_like():
 
     example_args = (torch.randn(5, dtype=torch.float32),)
 
-    verify_model(EmptyLike(), example_args, {}, Expected, 
run_ep_decomposition=True)
+    verify_model(EmptyLike(), example_args, {}, Expected)
 
 
 def test_one_hot():
@@ -6108,7 +6080,7 @@ def test_one_hot():
 
     example_args = (torch.randint(0, 10, (5,), dtype=torch.int64),)
 
-    verify_model(OneHot(), example_args, {}, Expected, 
run_ep_decomposition=True)
+    verify_model(OneHot(), example_args, {}, Expected)
 
 
 def test_ones_like():
@@ -6132,7 +6104,7 @@ def test_ones_like():
 
     example_args = (torch.rand(128, 128, dtype=torch.float32),)
 
-    verify_model(OnesLike(), example_args, {}, Expected, 
run_ep_decomposition=True)
+    verify_model(OnesLike(), example_args, {}, Expected)
 
 
 def test_zero_inplace():
@@ -6161,7 +6133,7 @@ def test_zero_inplace():
 
     example_args = (torch.rand(128, 128, dtype=torch.float32),)
 
-    verify_model(ZeroInplace(), example_args, {}, Expected, 
run_ep_decomposition=True)
+    verify_model(ZeroInplace(), example_args, {}, Expected)
 
 
 def test_zeros():
@@ -6185,7 +6157,7 @@ def test_zeros():
 
     example_args = (torch.rand(128, 128, dtype=torch.float32),)
 
-    verify_model(Zeros(), example_args, {}, Expected, 
run_ep_decomposition=True)
+    verify_model(Zeros(), example_args, {}, Expected)
 
 
 def test_zeros_like():
@@ -6208,7 +6180,7 @@ def test_zeros_like():
             return gv
 
     example_args = (torch.rand(128, 128, dtype=torch.float32),)
-    verify_model(ZerosLike(), example_args, {}, Expected, 
run_ep_decomposition=True)
+    verify_model(ZerosLike(), example_args, {}, Expected)
 
 
 def test_type_as():
@@ -6234,7 +6206,7 @@ def test_type_as():
         torch.rand(128, 128, dtype=torch.float16),
     )
 
-    verify_model(TypeAs(), example_args, {}, Expected, 
run_ep_decomposition=True)
+    verify_model(TypeAs(), example_args, {}, Expected)
 
 
 def test_select():
@@ -6256,7 +6228,7 @@ def test_select():
 
     example_args = (torch.randn(2, 3, dtype=torch.float32),)
 
-    verify_model(Select(), example_args, {}, Expected, 
run_ep_decomposition=True)
+    verify_model(Select(), example_args, {}, Expected)
 
 
 def test_unflatten():
@@ -6282,8 +6254,8 @@ def test_unflatten():
 
     example_args = (torch.randn(2, 15, 7, dtype=torch.float32),)
 
-    verify_model(Unflatten(), example_args, {}, Expected, 
run_ep_decomposition=True)
-    verify_model(Unflatten1(), example_args, {}, Expected, 
run_ep_decomposition=True)
+    verify_model(Unflatten(), example_args, {}, Expected)
+    verify_model(Unflatten1(), example_args, {}, Expected)
 
 
 def test_gather():
@@ -6360,10 +6332,10 @@ def test_gather():
         torch.randint(0, 3, (2, 3), dtype=torch.int64),
     )
 
-    verify_model(Gather0(), example_args, {}, Expected0, 
run_ep_decomposition=True)
-    verify_model(Gather1(), example_args, {}, Expected1, 
run_ep_decomposition=True)
-    verify_model(Gather2(), example_args, {}, Expected2, 
run_ep_decomposition=True)
-    verify_model(Gather3(), example_args, {}, Expected3, 
run_ep_decomposition=True)
+    verify_model(Gather0(), example_args, {}, Expected0)
+    verify_model(Gather1(), example_args, {}, Expected1)
+    verify_model(Gather2(), example_args, {}, Expected2)
+    verify_model(Gather3(), example_args, {}, Expected3)
 
 
 def test_index_put():
@@ -6555,11 +6527,11 @@ def test_index_put():
             return gv
 
     # Run verification for each case
-    verify_model(IndexPut1D(), example_args_1d, {}, Expected1D, 
run_ep_decomposition=True)
-    verify_model(IndexPut2D(), example_args_2d, {}, Expected2D, 
run_ep_decomposition=True)
-    verify_model(IndexPut3D(), example_args_3d, {}, Expected3D, 
run_ep_decomposition=True)
-    verify_model(IndexPut4D(), example_args_4d, {}, Expected4D, 
run_ep_decomposition=True)
-    verify_model(IndexPut5D(), example_args_5d, {}, Expected5D, 
run_ep_decomposition=True)
+    verify_model(IndexPut1D(), example_args_1d, {}, Expected1D)
+    verify_model(IndexPut2D(), example_args_2d, {}, Expected2D)
+    verify_model(IndexPut3D(), example_args_3d, {}, Expected3D)
+    verify_model(IndexPut4D(), example_args_4d, {}, Expected4D)
+    verify_model(IndexPut5D(), example_args_5d, {}, Expected5D)
 
 
 def test_flip():
@@ -6597,8 +6569,8 @@ def test_flip():
 
     example_args = (torch.randn(2, 2, dtype=torch.float32),)
 
-    verify_model(Flip0(), example_args, {}, Expected0, 
run_ep_decomposition=True)
-    verify_model(Flip1(), example_args, {}, Expected1, 
run_ep_decomposition=True)
+    verify_model(Flip0(), example_args, {}, Expected0)
+    verify_model(Flip1(), example_args, {}, Expected1)
 
 
 def test_take():
@@ -6625,7 +6597,7 @@ def test_take():
         torch.randint(0, 5, (3,), dtype=torch.int64),
     )
 
-    verify_model(Take(), example_args, {}, Expected, run_ep_decomposition=True)
+    verify_model(Take(), example_args, {}, Expected)
 
 
 def test_std():
@@ -6647,7 +6619,7 @@ def test_std():
             return gv
 
     example_args = (torch.randn(5, 3, dtype=torch.float32),)
-    verify_model(Std(), example_args, {}, Expected, run_ep_decomposition=True)
+    verify_model(Std(), example_args, {}, Expected)
 
 
 def test_var():
@@ -6668,7 +6640,7 @@ def test_var():
             return gv
 
     example_args = (torch.randn(5, 3, dtype=torch.float32),)
-    verify_model(Var(), example_args, {}, Expected, run_ep_decomposition=True)
+    verify_model(Var(), example_args, {}, Expected)
 
 
 def test_prod():
@@ -6689,7 +6661,7 @@ def test_prod():
             return gv
 
     example_args = (torch.randn(5, 3, dtype=torch.float32),)
-    verify_model(Prod(), example_args, {}, Expected, run_ep_decomposition=True)
+    verify_model(Prod(), example_args, {}, Expected)
 
 
 def test_cumprod():
@@ -6710,7 +6682,7 @@ def test_cumprod():
             return gv
 
     example_input = torch.randn(5, 3, dtype=torch.float32)
-    verify_model(Cumprod(), (example_input,), {}, Expected, 
run_ep_decomposition=True)
+    verify_model(Cumprod(), (example_input,), {}, Expected)
 
 
 def test_where():
@@ -6736,7 +6708,7 @@ def test_where():
     x = torch.randn(5, 3, dtype=torch.float32)
     y = torch.randn(5, 3, dtype=torch.float32)
 
-    verify_model(Where(), (condition, x, y), {}, Expected, 
run_ep_decomposition=True)
+    verify_model(Where(), (condition, x, y), {}, Expected)
 
 
 def test_bucketize():
@@ -6761,7 +6733,7 @@ def test_bucketize():
     input_tensor = torch.arange(0, 20)
     boundaries = torch.arange(0, 20, 2)
 
-    verify_model(Bucketize(), (input_tensor, boundaries), {}, Expected, 
run_ep_decomposition=True)
+    verify_model(Bucketize(), (input_tensor, boundaries), {}, Expected)
 
 
 def test_argsort():
@@ -6788,7 +6760,7 @@ def test_argsort():
             return gv
 
     example_args = (torch.randn(5, 3, dtype=torch.float32),)
-    verify_model(Argsort(), example_args, {}, Expected, 
run_ep_decomposition=True)
+    verify_model(Argsort(), example_args, {}, Expected)
 
 
 def test_topk():
@@ -6816,7 +6788,7 @@ def test_topk():
             return gv
 
     example_args = (torch.randn(5, 3, dtype=torch.float32),)
-    verify_model(Topk(), example_args, {}, Expected, run_ep_decomposition=True)
+    verify_model(Topk(), example_args, {}, Expected)
 
 
 def test_dynamic_shape():
@@ -6872,7 +6844,7 @@ def test_broadcast_to():
             return gv
 
     example_args = (torch.randn(5, 1, dtype=torch.float32),)
-    verify_model(BroadcastTo(), example_args, {}, Expected, 
run_ep_decomposition=True)
+    verify_model(BroadcastTo(), example_args, {}, Expected)
 
 
 def test_narrow():
@@ -6901,7 +6873,7 @@ def test_narrow():
             return gv
 
     example_args = (torch.randn(5, 3, dtype=torch.float32),)
-    verify_model(Narrow(), example_args, {}, Expected, 
run_ep_decomposition=True)
+    verify_model(Narrow(), example_args, {}, Expected)
 
 
 def test_item():
@@ -6920,7 +6892,7 @@ def test_item():
             return gv
 
     example_args = (torch.randn(1, dtype=torch.float32),)
-    verify_model(Item(), example_args, {}, Expected, run_ep_decomposition=True)
+    verify_model(Item(), example_args, {}, Expected)
 
 
 def test_norm():
@@ -7032,9 +7004,7 @@ def test_norm():
     example_args = (torch.randn(1, 3, 5, 3, dtype=torch.float32),)
 
     for (p, dim, keepdim), expected in norms:
-        verify_model(
-            Norm(p, dim=dim, keepdim=keepdim), example_args, {}, expected, 
run_ep_decomposition=True
-        )
+        verify_model(Norm(p, dim=dim, keepdim=keepdim), example_args, {}, 
expected)
 
 
 def test_eye():
@@ -7095,10 +7065,10 @@ def test_eye():
             return gv
 
     example_args1 = (torch.randn(3, 5, dtype=torch.float32),)
-    verify_model(Eye1(), example_args1, {}, Expected1, 
run_ep_decomposition=True)
+    verify_model(Eye1(), example_args1, {}, Expected1)
 
     example_args2 = (torch.randn(5, dtype=torch.float32),)
-    verify_model(Eye2(), example_args2, {}, Expected2, 
run_ep_decomposition=True)
+    verify_model(Eye2(), example_args2, {}, Expected2)
 
 
 def test_cross_entropy():
@@ -7146,7 +7116,7 @@ def test_cross_entropy():
             return gv
 
     example_args1 = (torch.randn(4, 3, dtype=torch.float32),)
-    verify_model(CrossEntropyModule(), example_args1, {}, Expected1, 
run_ep_decomposition=True)
+    verify_model(CrossEntropyModule(), example_args1, {}, Expected1)
 
 
 def test_linspace():
@@ -7178,7 +7148,7 @@ def test_linspace():
             return gv
 
     example_args = (torch.randn(9, 9, dtype=torch.float32),)
-    verify_model(Linspace(), example_args, {}, Expected, run_ep_decomposition=True)
+    verify_model(Linspace(), example_args, {}, Expected)
 
 
 @pytest.mark.parametrize(
@@ -7215,7 +7185,7 @@ def test_dtypes(torch_dtype, relax_dtype):
                 R.output(gv)
             return gv
 
-    verify_model(Model(), example_args, {}, Expected, run_ep_decomposition=True)
+    verify_model(Model(), example_args, {}, Expected)
 
 
 def test_mm():
@@ -7241,7 +7211,7 @@ def test_mm():
                 R.output(gv)
             return gv
 
-    verify_model(MatrixMultiply(), example_args, {}, Expected, run_ep_decomposition=True)
+    verify_model(MatrixMultiply(), example_args, {}, Expected)
 
 
 def test_lstm():
@@ -7266,7 +7236,7 @@ def test_lstm():
     with torch.no_grad():
         pytorch_output = model(x)
     exported_program = export(model, args=(x,))
-    mod = from_exported_program(exported_program, run_ep_decomposition=True)
+    mod = from_exported_program(exported_program)
     target = tvm.target.Target("llvm")
     ex = relax.build(mod, target)
     vm = relax.VirtualMachine(ex, tvm.cpu())
@@ -7302,7 +7272,7 @@ def test_lstm():
     with torch.no_grad():
         pytorch_output2 = model2(x2)
     exported_program2 = export(model2, args=(x2,))
-    mod2 = from_exported_program(exported_program2, run_ep_decomposition=True)
+    mod2 = from_exported_program(exported_program2)
     ex2 = relax.build(mod2, target)
     vm2 = relax.VirtualMachine(ex2, tvm.cpu())
     x2_tvm = tvm.runtime.tensor(x2.numpy())
@@ -7334,7 +7304,7 @@ def test_tensor_none_tuple():
                 R.output(gv)
             return gv
 
-    verify_model(TensorNoneModel(), example_args, {}, Expected, run_ep_decomposition=True)
+    verify_model(TensorNoneModel(), example_args, {}, Expected)
 
 
 def test_gru():
@@ -7359,7 +7329,7 @@ def test_gru():
     with torch.no_grad():
         pytorch_output = model(x)
     exported_program = export(model, args=(x,))
-    mod = from_exported_program(exported_program, run_ep_decomposition=True)
+    mod = from_exported_program(exported_program)
     target = tvm.target.Target("llvm")
     ex = relax.build(mod, target)
     vm = relax.VirtualMachine(ex, tvm.cpu())
@@ -7395,7 +7365,7 @@ def test_gru():
     with torch.no_grad():
         pytorch_output2 = model2(x2)
     exported_program2 = export(model2, args=(x2,))
-    mod2 = from_exported_program(exported_program2, run_ep_decomposition=True)
+    mod2 = from_exported_program(exported_program2)
     ex2 = relax.build(mod2, target)
     vm2 = relax.VirtualMachine(ex2, tvm.cpu())
     x2_tvm = tvm.runtime.tensor(x2.numpy())
@@ -7432,7 +7402,7 @@ def test_dynamic_shape_with_range_constraints():
     dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
     exported_program = export(DynamicModel(), args=example_args, dynamic_shapes=dynamic_shapes)
 
-    mod = from_exported_program(exported_program, run_ep_decomposition=True)
+    mod = from_exported_program(exported_program)
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
@@ -7466,7 +7436,7 @@ def test_dynamic_shape_with_addition_constraints():
     dynamic_shapes = {"x": {0: batch}, "y": {0: batch + 1}}
     exported_program = export(ConcatModel(), args=example_args, dynamic_shapes=dynamic_shapes)
 
-    mod = from_exported_program(exported_program, run_ep_decomposition=True)
+    mod = from_exported_program(exported_program)
     tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
 
 
@@ -7500,7 +7470,7 @@ def test_dynamic_shape_with_subtraction_constraints():
     dynamic_shapes = {"x": {0: batch}, "y": {0: batch - 1}}
     exported_program = export(ConcatModel(), args=example_args, dynamic_shapes=dynamic_shapes)
 
-    mod = from_exported_program(exported_program, run_ep_decomposition=True)
+    mod = from_exported_program(exported_program)
     tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
 
 
@@ -7534,7 +7504,7 @@ def test_dynamic_shape_with_multiplication_constraints():
     dynamic_shapes = {"x": {0: batch}, "y": {0: batch * 2}}
     exported_program = export(ConcatModel(), args=example_args, dynamic_shapes=dynamic_shapes)
 
-    mod = from_exported_program(exported_program, run_ep_decomposition=True)
+    mod = from_exported_program(exported_program)
     tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
 
 

Reply via email to